<a href="https://colab.research.google.com/github/gmshroff/metaLearning2022/blob/main/project_data_code/arc_few_shot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q --upgrade --force-reinstall --no-deps kaggle

In [None]:
from google.colab import files

In [None]:
import numpy as np
import pandas as pd
import json, os
import matplotlib.pyplot as plt 
from matplotlib import colors
from PIL import Image
import io
import random
import copy
import pickle

In [None]:
files.upload()

In [None]:
!mkdir /root/.kaggle

In [None]:
!mv ./kaggle.json /root/.kaggle/.

In [None]:
!chmod 600 /root/.kaggle/kaggle.json

In [None]:
!kaggle datasets download -d gmshroff/few-shot-arc

In [None]:
!unzip few-shot-arc.zip

In [None]:
class ARC():
    """Base class for drawing ARC (Abstraction and Reasoning Corpus) grids.

    The drawing methods read ``self.cmap`` and ``self.norm`` (a matplotlib
    colormap and normaliser). This class does not set them; the notebook
    assigns them on each instance before plotting.
    """
    def __init__(self,trn_dir='./training_orig/',tes_dir='./test_eval/'):
        # The directory arguments are accepted for subclass compatibility but are
        # unused: this class loads nothing. NOTE(review): task data appears to
        # arrive via unpickled instances -- confirm before relying on __init__.
        pass
    def plot_task(self,task,kind='orig',show=True,ways=4):
        """Plot each train/test example of ``task`` as one column of grid images.

        Row 0 shows the example input. With ``kind='orig'``, row 1 shows the
        output grid. With ``kind='fewshot'``, rows 1..ways show the candidate
        outputs, and the candidate whose index equals the example's ``'label'``
        is titled 'CORRECT'.

        Args:
            task: dict with ``"train"`` and ``"test"`` lists of examples. Each
                example is a dict with ``"input"`` and ``"output"`` (plus
                ``"label"`` for 'fewshot').
            kind: 'orig' or 'fewshot'.
            show: if False, test outputs/candidates are neither read nor drawn.
            ways: number of candidate rows drawn for 'fewshot'. Call with ways=4
                for the padded case and ways=6 for the unpadded case.

        Raises:
            ValueError: if ``kind`` is not 'orig' or 'fewshot'.
        """
        if kind not in ('orig','fewshot'):
            raise ValueError(f"kind must be 'orig' or 'fewshot', got {kind!r}")
        n = len(task["train"]) + len(task["test"])
        # squeeze=False keeps axs 2-D even when there is only one column.
        if kind=='orig':
            _, axs = plt.subplots(2, n, figsize=(4*n,8), dpi=50, squeeze=False)
        else:
            _, axs = plt.subplots(ways+1, n, figsize=(6*n,12), dpi=100, squeeze=False)
        plt.subplots_adjust(wspace=0, hspace=0)
        col = 0
        # (split key, title prefix for in/out, title prefix for candidates, draw outputs?)
        for split, prefix, cand_prefix, draw_out in (("train","Train","Out",True),
                                                    ("test","Test","Test",show)):
            for i, example in enumerate(task[split]):
                self._plot_example(axs,col,example,kind,ways,i,prefix,cand_prefix,draw_out)
                col += 1
        plt.tight_layout()
        plt.show()
    def _plot_example(self,axs,col,example,kind,ways,i,prefix,cand_prefix,draw_out):
        """Draw one example into column ``col`` of ``axs`` (helper for plot_task)."""
        cmap,norm=self.cmap,self.norm
        axs[0][col].imshow(np.array(example["input"]), cmap=cmap, norm=norm)
        axs[0][col].set_title(f'{prefix}-{i} in')
        if not draw_out:
            # Outputs are not even read, so hidden test outputs are allowed.
            return
        if kind=='orig':
            axs[1][col].imshow(np.array(example["output"]), cmap=cmap, norm=norm)
            axs[1][col].set_title(f'{prefix}-{i} out')
            return
        for j in range(ways):
            iscorrect='CORRECT' if j==example['label'] else ''
            axs[j+1][col].imshow(np.array(example["output"][j]), cmap=cmap, norm=norm)
            axs[j+1][col].set_title(f'{cand_prefix}-{i},{j} '+iscorrect)
    def example2img(self,example):
        """Render a grid with the instance colormap and return it as a PIL Image.

        The figure is 0.5 inch per cell: width follows the number of columns and
        height follows the number of rows.
        """
        grid=np.array(example)
        n_rows,n_cols=grid.shape[0],grid.shape[1]
        # matplotlib figsize is (width, height); previously (rows, cols) was
        # passed, which transposed the figure shape for non-square grids.
        fig=plt.Figure(figsize=(.5*n_cols,.5*n_rows))
        ax = fig.add_subplot()
        ax.imshow(grid, cmap=self.cmap, norm=self.norm)
        # Save to an in-memory PNG buffer and reopen it as a PIL image.
        buf = io.BytesIO()
        fig.savefig(buf)
        buf.seek(0)
        return Image.open(buf)
    def example2numpy(self,example):
        """Return the grid ``example`` as a numpy array."""
        return np.array(example)

In [None]:
class FewShotARC(ARC):
    """Few-shot ARC tasks, split into meta-train and meta-test lists.

    NOTE(review): in this notebook, instances come from ``pickle.load`` of
    ``FewShotARC.pickle``; pickle does not run ``__init__``. ``trn_tasks`` and
    ``tes_tasks`` are read below but are not set by the visible ``ARC.__init__``,
    so direct construction here would raise AttributeError. Presumably the
    original builder set them -- confirm.
    """
    def __init__(self,trn_dir='./training_orig/',tes_dir='./test_eval/',ways=6):
        # Forward the caller's directories (previously hard-coded defaults were passed).
        super().__init__(trn_dir=trn_dir,tes_dir=tes_dir)
        # Presumably the number of non-correct candidates per example (ways - 1).
        self.nrand=ways-1
        self.ntrain=len(self.trn_tasks)
        self.ntest=len(self.tes_tasks)
        self.meta_train_tasks=[]
        self.meta_test_tasks=[]
    def get_fs_task(self,taskid,kind='meta_train'):
        """Return few-shot task ``taskid`` from the meta-train or meta-test list.

        Raises:
            ValueError: if ``kind`` is not 'meta_train' or 'meta_test'
                (previously None was returned silently).
        """
        if kind=='meta_train': return self.meta_train_tasks[taskid]
        if kind=='meta_test': return self.meta_test_tasks[taskid]
        raise ValueError(f"kind must be 'meta_train' or 'meta_test', got {kind!r}")
    def get_examples(self,taskid,trte,inout,kind='meta_train'):
        """Return field ``inout`` of every example in split ``trte`` of one task.

        Args:
            taskid: index of the task within the chosen meta split.
            trte: key of the example list inside the task, e.g. 'train' or 'test'.
            inout: key to extract from each example, e.g. 'input' or 'output'.
            kind: 'meta_train' or 'meta_test'.
        """
        # Look up only the requested task. The old code looped over an int and
        # called an undefined get_task.
        task=self.get_fs_task(taskid,kind)
        return [example[inout] for example in task[trte]]

In [None]:
class FewShotPaddedARC(ARC):
    """Padded variant of the few-shot ARC tasks (meta-train / meta-test lists).

    NOTE(review): in this notebook, instances come from ``pickle.load`` of
    ``FewShotPaddedARC.pickle``; pickle does not run ``__init__``. ``trn_tasks``
    and ``tes_tasks`` are read below but are not set by the visible
    ``ARC.__init__``, so direct construction here would raise AttributeError.
    Presumably the original builder set them -- confirm.
    """
    def __init__(self,trn_dir='./training_orig/',tes_dir='./test_eval/',ways=6):
        # Forward the caller's directories (previously hard-coded defaults were passed).
        super().__init__(trn_dir=trn_dir,tes_dir=tes_dir)
        # Presumably the number of non-correct candidates per example (ways - 1).
        self.nrand=ways-1
        self.ntrain=len(self.trn_tasks)
        self.ntest=len(self.tes_tasks)
        self.meta_train_tasks=[]
        self.meta_test_tasks=[]
    def get_fs_task(self,taskid,kind='meta_train'):
        """Return few-shot task ``taskid`` from the meta-train or meta-test list.

        Raises:
            ValueError: if ``kind`` is not 'meta_train' or 'meta_test'
                (previously None was returned silently).
        """
        if kind=='meta_train': return self.meta_train_tasks[taskid]
        if kind=='meta_test': return self.meta_test_tasks[taskid]
        raise ValueError(f"kind must be 'meta_train' or 'meta_test', got {kind!r}")
    def get_examples(self,taskid,trte,inout,kind='meta_train'):
        """Return field ``inout`` of every example in split ``trte`` of one task.

        Args:
            taskid: index of the task within the chosen meta split.
            trte: key of the example list inside the task, e.g. 'train' or 'test'.
            inout: key to extract from each example, e.g. 'input' or 'output'.
            kind: 'meta_train' or 'meta_test'.
        """
        # Look up only the requested task. The old code looped over an int and
        # called an undefined get_task.
        task=self.get_fs_task(taskid,kind)
        return [example[inout] for example in task[trte]]

In [None]:
with open('./FewShotARC.pickle','rb') as f: a=pickle.load(f)

In [None]:
a.cmap=colors.ListedColormap(['#000000', '#0074D9','#FF4136','#2ECC40','#FFDC00','#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'])

In [None]:
a.norm=colors.Normalize(vmin=0, vmax=9)

In [None]:
task=a.get_fs_task(11)

In [None]:
task['train'][0].keys()

In [None]:
a.plot_task(task,kind='fewshot')

In [None]:
with open('./FewShotPaddedARC.pickle','rb') as f: b=pickle.load(f)

In [None]:
b.cmap=colors.ListedColormap(['#000000', '#0074D9','#FF4136','#2ECC40','#FFDC00','#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'])

In [None]:
b.norm=colors.Normalize(vmin=0, vmax=9)

In [None]:
padded_task=b.get_fs_task(11)

In [None]:
b.plot_task(padded_task,kind='fewshot')