In [1]:
# conda environment (project)
# /home/student/.conda/envs/project/bin/python 
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import os, gc
import torch

from transformers import set_seed
from datasets import load_dataset
from evaluate import *
from arc.arc import ARCSolver

from datasets import Dataset
from utils import render_grid

In [None]:
# Load the raw data and draw one pooled sample across tasks.
# Despite the name, this frame is not only inspected here: `train_dataset`
# is built from it further down.
# NOTE(review): `random=112` is presumably the sampling seed and `n_row` the
# number of rows returned -- confirm against `sample_data`.
data_path = "dataset"
dataset, task_list = load_data(data_path)
df300 = sample_data(dataset, task_list, random=112, n_row=30000)
df300.head(n=5)

In [3]:
# One sampled frame per task: task_samples[t] is drawn with indices=[t],
# so the list position doubles as the task index used by later cells.
task_samples = [
    sample_data(dataset, task_list, n_row=1000, indices=[t])
    for t in range(300)
]

In [None]:
# EDA: for every selected task, pick `n_sample` shuffled rows and render the
# first demonstration pair followed by the first test pair.
task_indices = [1, 6, 8, 10, 13, 15, 17]  # tasks to inspect
n_sample = 1
for task_idx in task_indices:
    print(task_idx)
    picked_rows = Dataset.from_pandas(task_samples[task_idx]).shuffle().select(range(n_sample))
    for data in picked_rows:
        # Only the first demonstration pair is shown; a row without any
        # demonstration pairs goes straight to the test pair.
        case = next(iter(data['train']), None)
        if case is not None:
            print("==================================================")
            print("Example input")
            render_grid(case['input'])
            print("Example output")
            render_grid(case['output'])
        print("==================================================")
        print("Example test input")
        render_grid(data['test'][0]['input'])
        print("Example test output")
        render_grid(data['test'][0]['output'])
    print("==================================================")

In [None]:
# Split the tasks into "simple" (every inspected input/output pair keeps its
# grid shape) and "hard" (at least one inspected pair changes shape).
#
# The rows inspected per task are chosen by a shuffle. It is seeded explicitly
# because this cell runs before `set_seed(...)` is called further down; without
# a seed the split -- and everything sampled from it -- could differ between
# runs of the notebook.
SHAPE_CHECK_SEED = 0   # any fixed integer; only reproducibility matters
SHAPE_CHECK_ROWS = 3   # number of rows inspected per task


def _keeps_shape(case):
    """Return True if the case's output grid has the same shape as its input.

    Shape is taken as (number of rows, length of the first row); rows are not
    checked for raggedness. Two empty grids count as the same shape.
    """
    grid_in, grid_out = case['input'], case['output']
    if len(grid_in) != len(grid_out):
        return False
    if not grid_in:
        return True
    return len(grid_in[0]) == len(grid_out[0])


simple_tasks = []
hard_tasks = []
# Iterate over what was actually sampled instead of a hard-coded task count.
for task_idx, task_df in enumerate(task_samples):
    inspected_rows = (
        Dataset.from_pandas(task_df)
        .shuffle(seed=SHAPE_CHECK_SEED)
        .select(range(SHAPE_CHECK_ROWS))
    )
    # Every demonstration pair plus the first test pair of each inspected row
    # must keep its shape; all() stops at the first mismatch.
    check = all(
        _keeps_shape(case)
        for data in inspected_rows
        for case in [*data['train'], data['test'][0]]
    )
    if check:
        simple_tasks.append(task_idx)
    else:
        hard_tasks.append(task_idx)
print(simple_tasks)

In [6]:
# Fix the RNG seed before the solver is constructed.
set_seed(1234567890)

# The Hugging Face token is read from the environment; `token` is None when
# HF_TOKEN is not set.
token = os.environ.get("HF_TOKEN")
solver = ARCSolver(hf_token=token, model_id="Qwen/Qwen3-1.7B")

In [7]:
# Training data preparation. The training calls themselves are left disabled;
# uncomment both of them to fine-tune.
# solver.prepare_train()
n_train, n_eval = 30000, 500

# Same sampling arguments for both subsets; only the task subset differs.
dfsimple, dfhard = (
    sample_data(dataset, task_list, n_row=n_train + n_eval, indices=subset, random=56)
    for subset in (simple_tasks, hard_tasks)
)

# The training set is the first `n_train` rows of the pooled sample `df300`.
train_dataset = Dataset.from_pandas(df300).select(range(n_train))
# solver.train(train_dataset)

In [None]:
# Select the adapter to evaluate -- make sure this id names the right model.
adapter_id = "20250711_170015"
n_eval = 100
solver.prepare_evaluation(select_adapter=adapter_id)

In [None]:
# Evaluate the model task by task (eval set).
# For each of the first `n_eval` rows of a task, the solver is asked for 10
# predictions; the row's score is the best `check_match` result among the
# non-None predictions (0 if there are none). The per-task score is the mean
# row score times 100.
scores = []
n_eval = 20
task_indices = [1,6,8,9,10,13,14,15,17] # list(range(20))
scores_task = []
for task in task_indices:
    eval_dataset = Dataset.from_pandas(task_samples[task]).select(range(n_eval))
    for eval_data in tqdm(eval_dataset):
        # Debugging aid: render_grid(eval_data["test"][0]['input']) shows the
        # test input, render_grid(p) each non-None prediction, and
        # render_grid(eval_data["test"][0]['output']) the expected output.
        # solver.train_testtime(eval_data) # TTT
        preds = [ solver.predict(eval_data, use_ttt=False) for _ in range(10) ] # augmented sampling
        matches = [
            check_match(p, eval_data["test"][0]["output"])
            for p in preds
            if p is not None
        ]
        # Leading 0 is the floor used when no prediction was produced.
        s = max([0] + matches)
        scores.append(s)
    score = np.mean(scores) * 100
    scores_task.append(score)
    print(f"Evaluation score: {score:.2f}", flush=True)
    # Start the next task with an empty accumulator.
    scores.clear()

In [None]:
from matplotlib import pyplot as plt

# Bar chart of the per-task evaluation scores: one bar per entry of
# `scores_task`, labelled with the matching entry of `task_indices`.
# The bar positions are derived from the data rather than hard-coded, so the
# cell keeps working when the list of evaluated tasks changes.
if len(scores_task) != len(task_indices):
    raise ValueError(
        f"scores_task has {len(scores_task)} entries but task_indices has "
        f"{len(task_indices)}; re-run the evaluation cell before plotting."
    )

x = np.arange(len(task_indices))
fig, ax = plt.subplots()
ax.bar(x, scores_task)
ax.set_xticks(x)
ax.set_xticklabels(task_indices)
ax.set_ylim(0, 100)  # scores are percentages
ax.set_xlabel("Task index")
ax.set_ylabel("Evaluation score (%)")
ax.set_title("Evaluation score per task")
plt.show()