# Main

In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before executing each cell, so edits to source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# ARC

In [2]:
import os, json
from typing import List, Dict
import plotly.graph_objects as go
from plotly import subplots

In [15]:
class Plotter:
    """Draws grids of integer cell values as plotly images with gridlines.

    Every method is a classmethod; the class holds no per-instance state.
    """

    # Maps a cell value to the RGB colour it is drawn with.
    COLOR_MAP = {-1: (255, 255, 255), 0: (0, 0, 0), 1: (37, 150, 190),
                 2: (255, 65, 54), 3: (46, 204, 64), 4: (255, 220, 0),
                 5: (170, 170, 170), 6: (240, 18, 190), 7: (255, 133, 27),
                 8: (127, 219, 255), 9: (135, 12, 37)}

    def __init__(self) -> None:
        pass

    @classmethod
    def plot_grid(cls, grid: List[List[int]], name=None):
        """Plot a single grid in its own figure.

        Args:
            grid: rows of cell values; every value must be a COLOR_MAP key.
            name: optional label printed before the figure is shown.
        """
        fig = subplots.make_subplots(1, 1)
        cls._add_grid_to_figure(fig, grid, 1, 1)
        cls._add_gridlines_to_figure(fig, grid)
        if name:
            print(name)
        fig.show()

    @classmethod
    def plot_grid_pairs(cls, grids, name=None):
        """Plot grid pairs, one pair per subplot row.

        Args:
            grids: ``[[grid, grid], [grid, grid], ...]``; each inner sequence
                is laid out left to right starting at column 1 of a
                two-column subplot figure.
            name: optional label printed before the figure is shown.
        """
        if not grids:
            # Nothing to draw; a zero-row subplot figure cannot be built.
            if name:
                print(name)
            return
        fig = subplots.make_subplots(len(grids), 2)
        for i, grid_pair in enumerate(grids):
            for j, grid in enumerate(grid_pair):
                # plotly subplot rows/columns are 1-based.
                cls._add_grid_to_figure(fig, grid, i + 1, j + 1)
                cls._add_gridlines_to_figure(fig, grid, row=i + 1, col=j + 1)
        if name:
            print(name)
        fig.update_layout(height=400 * len(grids))
        fig.show()

    @classmethod
    def _convert_to_plot_grid(cls, grid: List[List[int]]):
        """Return the grid with every cell value replaced by its RGB colour.

        Raises:
            KeyError: if a cell value is not a key of COLOR_MAP.
        """
        return [[cls.COLOR_MAP[cell] for cell in row] for row in grid]

    @classmethod
    def _add_grid_to_figure(cls, fig: "go.Figure", grid, row_id=1, col_id=1):
        """Convert the grid to colours and add it as an Image trace.

        The trace is placed in subplot (row_id, col_id) of ``fig``.
        """
        plot_grid = cls._convert_to_plot_grid(grid)
        fig.add_trace(go.Image(z=plot_grid), row_id, col_id)

    @classmethod
    def _add_gridlines_to_figure(cls, fig: "go.Figure", grid,
                                 color="sienna", width=2, row=1, col=1):
        """Draw cell-boundary lines over the grid in subplot (row, col).

        Cell centres sit on integer coordinates, so boundaries lie at
        ``k - 0.5``. A grid with H rows has H + 1 horizontal boundaries
        (and likewise W + 1 vertical ones); all of them are drawn so the
        grid is closed on every side. An empty grid draws nothing.
        """
        H = len(grid)
        W = len(grid[0]) if grid else 0
        if not H or not W:
            return
        # Horizontal boundaries, top edge through bottom edge.
        for y in range(H + 1):
            fig.add_shape(type="line", x0=-0.5, y0=y - 0.5,
                          x1=W - 0.5, y1=y - 0.5,
                          line=dict(color=color, width=width),
                          row=row, col=col)
        # Vertical boundaries, left edge through right edge.
        for x in range(W + 1):
            fig.add_shape(type="line", x0=x - 0.5, y0=-0.5,
                          x1=x - 0.5, y1=H - 0.5,
                          line=dict(color=color, width=width),
                          row=row, col=col)
        


In [12]:
class Problem:
    """A single named problem made of train and test input/output samples."""

    def __init__(self, problem_dict:dict) -> None:
        """Keep the raw dict and cache its name, counts and grid lists."""
        self.problem = problem_dict
        self.name = problem_dict['name']
        breakdown = self._get_breakdown()
        (self.num_train, self.num_test,
         self.train_inputs, self.train_outputs,
         self.test_inputs, self.test_outputs) = breakdown

    def get_grid(self, i=0, type_input=True, type_train=True):
        """Return the i'th train/test input/output grid.

        An index past the end is clamped to the last sample.
        """
        if type_train:
            samples = self.problem['train']
        else:
            samples = self.problem['test']
        if type_input:
            side = "input"
        else:
            side = "output"
        index = min(i, len(samples)-1)
        return samples[index][side]

    def plot(self):
        """Show every train pair, then every test pair, as input/output rows."""
        plotter = Plotter()
        train_pairs = list(zip(self.train_inputs, self.train_outputs))
        plotter.plot_grid_pairs(train_pairs, name=f"{self.name}-train")
        test_pairs = list(zip(self.test_inputs, self.test_outputs))
        plotter.plot_grid_pairs(test_pairs, name=f"{self.name}-test")

    def _get_breakdown(self):
        """Split the raw problem into counts and per-split grid lists.

        Returns a tuple: (num_train, num_test, train_inputs, train_outputs,
        test_inputs, test_outputs).
        """
        train_samples = self.problem['train']
        test_samples = self.problem['test']
        train_in = [sample['input'] for sample in train_samples]
        train_out = [sample['output'] for sample in train_samples]
        test_in = [sample['input'] for sample in test_samples]
        test_out = [sample['output'] for sample in test_samples]
        return (len(train_samples), len(test_samples),
                train_in, train_out, test_in, test_out)
        

        

        

In [13]:
class ARC:
    """Loads the training and evaluation problem sets from disk.

    Attributes:
        problems: problems read from ``repo/data/training/``.
        problems_eval: problems read from ``repo/data/evaluation/``.
    """

    def __init__(self) -> None:
        self.problems:List[Problem] = self.load_problem_set(training=True)
        self.problems_eval:List[Problem] = self.load_problem_set(training=False)

    def load_problem_set(self, training=True):
        """Load either the training set or the evaluation set.

        Each ``.json`` file in the chosen directory is parsed into a dict,
        given a ``'name'`` entry (the file name up to its first dot) and
        wrapped in a Problem. Entries that are not ``.json`` files are
        skipped. Problems are returned in ``os.listdir`` order, which is
        arbitrary.

        Args:
            training: read the training directory when true, otherwise the
                evaluation directory.

        Returns:
            The list of loaded Problem objects.
        """
        loc = "repo/data/training/" if training else "repo/data/evaluation/"
        problems:List[Problem] = []
        for file in os.listdir(loc):
            # Ignore stray entries such as checkpoint folders or OS metadata.
            if not file.endswith(".json"):
                continue
            # The context manager closes the handle even if parsing fails.
            with open(os.path.join(loc, file), encoding="utf-8") as handle:
                problem = json.load(handle)
            problem['name'] = file.split('.')[0]
            problems.append(Problem(problem))
        return problems
    
    
        
    


# Wibbly

In [20]:
# Load both problem sets from disk, then plot the first five training
# problems; each plot() call shows a "-train" figure and a "-test" figure.
arc = ARC()
for problem in arc.problems[:5]:
    problem.plot()

    



a85d4709-train


a85d4709-test


c8cbb738-train


c8cbb738-test


8e1813be-train


8e1813be-test


a699fb00-train


a699fb00-test


5c2c9af4-train


5c2c9af4-test
