In [1]:
import os
from PIL import Image
import numpy as np
import json
import random
import torch

In [2]:
class ARC_Task:
  def __init__(self, inputs, outputs, test_input, test_output):
    self.inputs = inputs
    self.outputs = outputs
    self.test_input = test_input
    self.test_output = test_output

In [3]:
from matplotlib.pyplot import imshow
%matplotlib inline

def visualize(tensor):
  #torch.argmax(tensor.reshape(30, 30, 11)).item()
  arr = np.zeros((30, 30))
  for x in range(30):
    for y in range(30):
      arr[x, y] = torch.argmax(tensor[x, y]).item()
  convert(arr)

def visualize2(tensor):
  arr = np.zeros((30, 30))
  for x in range(30):
    for y in range(30):
      difs = [torch.abs(v - tensor[x, y]) for v in color2vector]
      sums = [torch.sum(difs[i]) for i in range(11)]
      index_min = min(range(len(sums)), key=sums.__getitem__)
      #print(index_min)
      arr[x, y] = index_min
  convert(arr)

converter = {}
converter[0] = np.array([0, 0, 0])
converter[1] = np.array([0, 116, 217])
converter[2] = np.array([255, 65, 54])
converter[3] = np.array([46, 204, 64])
converter[4] = np.array([255, 220, 0])
converter[5] = np.array([170, 170, 170])
converter[6] = np.array([240, 18, 190])
converter[7] = np.array([255, 113, 27])
converter[8] = np.array([127, 219, 255])
converter[9] = np.array([135, 12, 37])
converter[10] = np.array([255, 255, 255])

def convert(X):
  Y = np.zeros(shape=(X.shape[0], X.shape[1], 3), dtype=np.uint8)
  for x in range(X.shape[0]):
    for y in range(X.shape[1]):
      Y[x, y] = converter[X[x, y]]
  imshow(Y)

In [4]:
def fetch(directory):
  data = []
  MAX_LENGTH = 30
  os.chdir(directory)
  files = [f for f in os.listdir() if f.endswith('json')]
  for f in files:
    with open(f, 'r') as file:
      json_file = json.loads(file.read())
      inputs = [one_hot_encode_3d(np.array(train_exmaples['input'])) for train_exmaples in json_file['train']]
      outputs = [one_hot_encode_3d(np.array(train_exmaples['output'])) for train_exmaples in json_file['train']]
      test_input = [one_hot_encode_3d(np.array(train_exmaples['input'])) for train_exmaples in json_file['test']]
      test_output = [one_hot_encode_3d(np.array(train_exmaples['output'])) for train_exmaples in json_file['test']]
      data.append(ARC_Task(inputs, outputs, test_input[0], test_output[0]))

  return data

def one_hot_encode(array):
  MAX_LENGTH = 30
  arr = torch.zeros((MAX_LENGTH, MAX_LENGTH, 11))
  size = array.shape
  for x in range(MAX_LENGTH):
    for y in range(MAX_LENGTH):
      if x < size[0] and y < size[1]:
        arr[x, y, array[x, y]] = 1
      else:
        arr[x, y, 10] = 1
  return arr.flatten()

def one_hot_encode_3d(array):
  MAX_LENGTH = 30
  arr = torch.zeros((MAX_LENGTH, MAX_LENGTH, 11))
  size = array.shape
  for x in range(MAX_LENGTH):
    for y in range(MAX_LENGTH):
      if x < size[0] and y < size[1]:
        arr[x, y, array[x, y]] = 1
      else:
        arr[x, y, 10] = 1
  return arr.reshape((11, MAX_LENGTH, MAX_LENGTH))

In [5]:
from torch.utils.data import Dataset

class ARCDataset(Dataset):
  def __init__(self, data):
    self.data = data

  def __getitem__(self, idx):
    return self.data[idx]

  def __len__(self):
    return len(self.data)

In [6]:
import torch.nn as nn
import torch.nn.functional as F

class ConvVAE(nn.Module):
  def __init__(self, image_channels = 11, kernel_size = 3,
               latent_dim = 200, init_channels = 8):
    super(ConvVAE, self).__init__()
    self.latent_dim = latent_dim
    self.enc1 = nn.Conv2d(
        in_channels=image_channels, out_channels=init_channels, 
        kernel_size=kernel_size, stride=2, padding=1
    )
    self.enc2 = nn.Conv2d(
        in_channels=init_channels, out_channels=init_channels*2, 
        kernel_size=kernel_size, stride=2, padding=1
    )
    self.enc3 = nn.Conv2d(
        in_channels=init_channels*2, out_channels=init_channels*4, 
        kernel_size=kernel_size, stride=2, padding=1
    )
    self.enc4 = nn.Conv2d(
        in_channels=init_channels*4, out_channels=64, 
        kernel_size=kernel_size, stride=2, padding=0
    )
    self.fc1 = nn.Linear(64, 128)
    self.fc_mu = nn.Linear(128, latent_dim)
    self.fc_log_var = nn.Linear(128, latent_dim)
    self.fc2 = nn.Linear(latent_dim, 64)

    self.dec1 = nn.ConvTranspose2d(
        in_channels=64, out_channels=init_channels*8,
        kernel_size=kernel_size, stride=1, padding=0
    )
    self.dec2 = nn.ConvTranspose2d(
        in_channels=init_channels*8, out_channels=init_channels*4,
        kernel_size=kernel_size, stride=2, padding=1
    )
    self.dec3 = nn.ConvTranspose2d(
        in_channels=init_channels*4, out_channels=init_channels*2,
        kernel_size=kernel_size, stride=2, padding=1
    )
    self.dec4 = nn.ConvTranspose2d(
        in_channels=init_channels*2, out_channels=init_channels*1,
        kernel_size=kernel_size, stride=2, padding=1
    )


  def reparameterize(self, mu, log_var):
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    sample = mu + (eps * std)
    return sample

  def forward(self, x):
    task_vector = torch.zeros((1, 64))
    mu_acc = torch.zeros((1, self.latent_dim))
    log_var_acc = torch.zeros((1, self.latent_dim))
    for input, output in zip(task.inputs, task.outputs):
      latent_input, mu_input, log_var_input = self.run_encoder(input)
      latent_output, mu_output, log_var_output = self.run_encoder(output)
      task_vector += latent_output - latent_input
      mu_acc += mu_output - mu_input
      log_var_acc += log_var_output - log_var_input
    task_vector /= len(task.inputs)
    mu = mu_acc / len(task.inputs)
    log_var = log_var_acc / len(task.inputs)
    latent_prediction, _, _ = self.run_encoder(task.test_input)
    prediction = self.run_decoder(latent_prediction + task_vector)
    return prediction, mu, log_var

  def run_encoder(self, x):
    x = F.relu(self.enc1(x))
    x = F.relu(self.enc2(x))
    x = F.relu(self.enc3(x))
    #print(x.size())
    x = F.relu(self.enc4(x))
    batch, _, _, = x.shape
    x = F.adaptive_avg_pool2d(x, 1).reshape(1, -1)
    hidden = self.fc1(x)
    mu = self.fc_mu(hidden)
    log_var = self.fc_log_var(hidden)

    z = self.reparameterize(mu, log_var)
    z = self.fc2(z) # latent
    return z, mu, log_var

  def run_decoder(self, x):
    x = F.relu(self.dec1(x))
    x = F.relu(self.dec2(x))
    x = F.relu(self.dec3(x))
    reconstruction = torch.sigmoid(self.dec4(x))
    return reconstruction

In [7]:
def final_loss(loss, mu, log_var):
  KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
  return KLD + loss

In [8]:
train_dataset = ARCDataset(fetch("/content/train"))
len(train_dataset)

FileNotFoundError: ignored

In [9]:
model = ConvVAE()
model.train()

ConvVAE(
  (enc1): Conv2d(11, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (enc2): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (enc3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (enc4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
  (fc1): Linear(in_features=64, out_features=128, bias=True)
  (fc_mu): Linear(in_features=128, out_features=200, bias=True)
  (fc_log_var): Linear(in_features=128, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=64, bias=True)
  (dec1): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (dec2): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (dec3): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (dec4): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)

In [None]:
import torch.optim as optim

mse = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
n_epochs = 10

for epoch in range(n_epochs):
  loss = 0

  for task in train_dataset:
    optimizer.zero_grad()
    reconstruction, mu, log_var = model(task)
    mse_loss = mse(prediction, task.test_output)
    train_loss = final_loss(mse_loss, mu, log_var)
    train_loss.backward()
    optimizer.step()
    loss += train_loss.item()

  if epoch % 5 == 0:
    print(epoch, loss)