In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from arc_ahrm import ArcAHRM
from load_dataset import LoadDataset
import os
import random
import cv2
import numpy as np
import matplotlib.pyplot as plt
import json
from IPython.display import clear_output
from visualization import Visualization

from train import train

# Render every matplotlib figure with the dark theme (applied once; the
# original applied the same style sheet twice, which had no extra effect).
plt.style.use('dark_background')

# Working directory at start-up; the dataset directories listed in the main
# block are appended to this path.
script_directory = os.getcwd()

def plt_show(img):
    """Replace the current cell output with a matplotlib rendering of ``img``.

    The image is passed through ``cv2.cvtColor`` with ``cv2.COLOR_BGR2RGB``
    before being handed to ``plt.imshow``.

    Args:
        img: Image accepted by ``cv2.cvtColor`` -- presumably a BGR array
            produced with OpenCV; TODO confirm against the caller in
            ``train``.
    """
    # Clear the previous output right away instead of waiting for new output.
    clear_output(wait=False)
    rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.imshow(rgb_img)
    plt.show()

if __name__ == "__main__":

    # ---- Model: built, moved to the GPU, then cast to bfloat16 ----
    d_model = 512
    arc_ahrm = ArcAHRM(d_model)
    arc_ahrm = arc_ahrm.to("cuda")
    arc_ahrm = arc_ahrm.to(torch.bfloat16)

    # ---- Learning rates: embed/combiner/decoder run at a tenth of base ----
    base_lr = 1e-3
    emb_lr = base_lr / 10  # / (286 - 13)
    rec_lr = base_lr
    dec_lr = emb_lr

    # One Adam parameter group per sub-module. The group order is kept as-is
    # because the optimizer state loaded below is matched to groups by order.
    param_groups = [
        {"params": arc_ahrm.ahrm.parameters(), "lr": rec_lr},
        {"params": arc_ahrm.grid_embed.parameters(), "lr": emb_lr},
        {"params": arc_ahrm.grid_combiner.parameters(), "lr": emb_lr},
        {"params": arc_ahrm.grid_decode.parameters(), "lr": dec_lr},
    ]
    optimizer = torch.optim.Adam(param_groups)

    # Placeholder; replaced by the loss history stored in the checkpoint.
    epoch_losses = []

    # ---- Resume model, optimizer and loss history from a saved run ----
    # NOTE(review): loading the optimizer state also restores the
    # hyperparameters saved in its param groups, so the learning rates above
    # presumably only matter when they match the checkpoint -- confirm.
    checkpoint = torch.load("./saved/pt/arc_ahrm_tet_rec_1e-3_emb_1e-4_4675.pth")
    arc_ahrm.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch_losses = checkpoint["epoch_losses"]

    # Dataset directories, appended to ``script_directory`` when loading.
    directories = [
        "/ARC-AGI/data/training",
        "/ARC-AGI/data/evaluation",
        "/ARC-AGI-2/data/training",
        # "/ARC-AGI-2/data/evaluation"  (currently left out of this run)
    ]

    # ---- Run name: first letter of each directory's last component, plus
    # the two learning rates in compact exponent form (e.g. 1e-03 -> 1e-3) ----
    dataset_tag = "".join(path.rsplit("/", 1)[-1][0] for path in directories)
    name = f"arc_ahrm_{dataset_tag}_rec_{rec_lr:.0e}_emb_{emb_lr:.0e}"
    for padded, compact in (("e+0", "e+"), ("e-0", "e-")):
        name = name.replace(padded, compact)

    print(name)

    # ---- Merge the tasks of every directory into a single dict; entries
    # from later directories overwrite earlier ones with the same key ----
    tasks = {}
    for directory in directories:
        tasks.update(LoadDataset.load_arc_tasks(script_directory + directory))

    # ---- Hand everything to the training loop ----
    train(
        draw_func=plt_show,
        arc_ahrm=arc_ahrm,
        optimizer=optimizer,
        epoch_losses=epoch_losses,
        tasks=tasks,
        batch_size=128 - 32,
        epochs=3000,
        save_each=25,
        save_name=name,
    )
