# Test-time Training Exploration

## Goal

Can I solve tasks using test-time training?

I want to explore different TTT techniques such as hindsight experience replay and RL to see if a model can solve novel tasks that cannot be solved with the base model.

I have to focus on the techniques, not on efficiency.

## Code

### Imports

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from tqdm.auto import tqdm
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer, AutoConfig
import matplotlib.pyplot as plt
import matplotlib as mpl
from datasets import Dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig

from arc25.training_tasks import *
from arc25.encoders import create_grid_encoder
from arc25.prompting import create_prompt_from_task, pretty_print_prompt
from arc25.plot import plot_task
from arc25.code_execution import safe_code_execution
from arc25.utils import set_random_seed

import sys
sys.path.append(os.path.realpath("../scripts"))
from finetuning import get_data_collator


plt.plot()
plt.close('all')
plt.rcParams["figure.figsize"] = (20, 5)
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 12

### Task definition

In [None]:
# Toy task with a single sample: the input is a 9x9 grid created with color 0 and
# the output is a copy of it with one vertical line drawn per column index.
input_img = create_img((9, 9), color=0)
output_img = input_img.copy()
n_columns = input_img.shape[1]
for col in range(n_columns):
    # color is the column index shifted by one, so no line uses color 0
    draw_vertical_line(output_img, col, color=col + 1)

task = Task(inputs=[input_img], outputs=[output_img], code='', name='manual')

## Hindsight Experience Replay (HER)

### Load model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

In [None]:
# Local paths of the base model and of the LoRA checkpoint that is loaded on top of it
base_model_path = '/home/gbarbadillo/models/Qwen2.5-Coder-0.5B-Instruct'
lora_path = '/mnt/hdd0/Kaggle/arc25/trainings/20250430_first_trainings/steps_6400/checkpoint-6400'

# Load the base model, leaving dtype and device placement to transformers ("auto")
model = AutoModelForCausalLM.from_pretrained(
    base_model_path, torch_dtype="auto", device_map="auto")
# The tokenizer is read from the LoRA checkpoint folder, not from the base model folder
tokenizer = AutoTokenizer.from_pretrained(lora_path)
# Attach the LoRA adapter; is_trainable=True loads it as trainable instead of
# frozen, which the fine-tuning done later in this notebook needs
model = PeftModel.from_pretrained(model, lora_path, is_trainable=True)

In [None]:
# Prompt template version and grid encoder passed to every create_prompt_from_task
# call in this notebook (both for inference and for training prompts)
prompt_version = 'code-from-examples-v3'
grid_encoder = create_grid_encoder('GridShapeEncoder(RowNumberEncoder(MinimalGridEncoder()))')

In [None]:
# A bare `raise` with no active exception fails with a RuntimeError.
# NOTE(review): presumably used as a deliberate stop so that running all the
# cells does not continue past the model loading - confirm.
raise

### Verify that task is not solvable

In [None]:
# Build the inference prompt for the task (is_train_prompt=False), then tokenize it
# as a batch of one and move the tensors to the device of the model
prompt = create_prompt_from_task(
    task, prompt_version=prompt_version, grid_encoder=grid_encoder, tokenizer=tokenizer, is_train_prompt=False)
model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

In [None]:
# Sample 256 completions for the single prompt
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.5,
    top_p=0.95,
    num_return_sequences=256
)
# Drop the first len(prompt) token ids of every sequence, keeping only what follows the prompt
generated_ids = generated_ids[:, len(model_inputs.input_ids[0]):]
print(f'Generated ids shape: {generated_ids.shape}')
# Back to text, without special tokens
predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

| batch size | inference time(s) | throughput (preds/s) |
|------------|-------------------|----------------------|
| 1          | 6.4               | 0.2                  |
| 4          | 7.4               | 0.5                  |
| 16         | 8.5               | 1.9                  |
| 64         | 9                 | 7.1                  |
| 128        | 10.9              | 11.7                 |
| 256        | 15.3              | 16.7                 |
| 512        | 30.1              | 17.0                 |

A batch size of 256 might be the sweet spot: 256 predictions take 15.3 s, only slightly more than the 12.8 s needed to make two predictions with batch size 1.

In [None]:
# Remove every newline + code fence sequence from the raw predictions
predicted_codes = [raw_prediction.replace('\n```', '') for raw_prediction in predictions]

# Run each candidate program on the task inputs. The ones that run become new
# tasks whose outputs are whatever the generated code produced; failures are
# printed and skipped.
new_tasks = []
for predicted_code in tqdm(predicted_codes):
    try:
        predicted_output = safe_code_execution(predicted_code, task.inputs)
        new_task = Task(inputs=task.inputs, outputs=predicted_output, code=predicted_code, name=task.name)
        new_tasks.append(new_task)
    except Exception as e:
        print(f'Error executing code: {predicted_code}')
        print(e)

In [None]:
# Keep one task per distinct first output (the first occurrence wins).
# Starting from an empty list makes this work when no prediction could be executed,
# and np.array_equal returns False for outputs of different shape instead of
# relying on an element-wise `==` that cannot be broadcast.
new_tasks_with_unique_outputs = []
for new_task in new_tasks:
    is_repeated = any(
        np.array_equal(new_task.outputs[0], t.outputs[0]) for t in new_tasks_with_unique_outputs)
    if not is_repeated:
        new_tasks_with_unique_outputs.append(new_task)
print(f'Number of unique outputs: {len(new_tasks_with_unique_outputs)}')

In [None]:
# Fraction of pixels of the first output that match the target, for every unique prediction.
# A prediction with a different shape than the target cannot be compared
# pixel by pixel, so it scores 0.
pixel_accuracies = []
target_output = np.asarray(task.outputs[0])
for new_task in tqdm(new_tasks_with_unique_outputs):
    predicted = np.asarray(new_task.outputs[0])
    if predicted.shape != target_output.shape:
        pixel_accuracies.append(0.0)
    else:
        pixel_accuracies.append(float(np.mean(predicted == target_output)))
# default avoids a ValueError when there are no predictions at all
print(f'Max pixel accuracy: {max(pixel_accuracies, default=0.0)}')

In [None]:
# Distribution of the pixel accuracies computed above, with log-scaled counts
plt.hist(pixel_accuracies, bins=20, log=True);

In [None]:
# Order the unique predictions by ascending pixel accuracy of their first output,
# so the worst one comes first and the best one last
new_tasks_with_unique_outputs = sorted(
    new_tasks_with_unique_outputs,
    key=lambda candidate: float(np.mean(candidate.outputs[0] == task.outputs[0])),
)

In [None]:
# First element of the list sorted by ascending pixel accuracy: the worst prediction
plot_task(new_tasks_with_unique_outputs[0])

In [None]:
# Last element of the list sorted by ascending pixel accuracy: the best prediction
plot_task(new_tasks_with_unique_outputs[-1])

This proves that the fine-tuned model, used as it is, is not capable of solving the task. What if we further fine-tune it on the new tasks?

### HER

In [None]:
def hindsight_experience_replay(task, max_epochs=10):
    """
    Use hindsight experience replay to try to solve new tasks

    On every epoch candidate programs are sampled for the task with `inference`
    and the model is fine-tuned with `finetuning` on the tasks that those
    programs define (same inputs, outputs produced by the generated code).
    The loop stops as soon as one candidate reproduces the outputs of the task,
    printing its code.

    Parameters
    ----------
    task : Task
        Task that we want to solve.
    max_epochs : int
        Maximum number of inference + fine-tuning rounds.
    """
    for epoch in range(max_epochs):
        print(f'Epoch {epoch}')
        new_tasks = inference(task)
        solutions = [new_task for new_task in new_tasks if _reproduces_outputs(new_task, task)]
        if solutions:
            # Nothing else to learn for this task, further fine-tuning is not needed
            print(f'Task solved at epoch {epoch} with code:\n{solutions[0].code}')
            return
        if not new_tasks:
            print('There are no new tasks, skipping fine-tuning on this epoch')
            continue
        finetuning(new_tasks)


def _reproduces_outputs(candidate, task):
    """Return True if every output of `candidate` is equal to the matching output of `task`."""
    if len(candidate.outputs) != len(task.outputs):
        return False
    return all(np.array_equal(predicted, expected)
               for predicted, expected in zip(candidate.outputs, task.outputs))


def inference(task, num_return_sequences=256):
    """
    Sample candidate programs for the task and convert them to new tasks

    The generated code is run on the task inputs. Every program that runs
    defines a new task with the same inputs and with the outputs that the
    program produced. Histograms of the pixel accuracy (measured on the first
    output) are shown for all the predictions and for the unique ones.

    Parameters
    ----------
    task : Task
        Task used to build the prompt, its first output is the target for the
        pixel accuracy.
    num_return_sequences : int
        Number of predictions sampled from the model in a single batch.

    Returns
    -------
    list
        New tasks with unique first output, sorted by ascending pixel accuracy.
        The list is empty if none of the predictions could be executed.
    """
    prompt = create_prompt_from_task(
        task, prompt_version=prompt_version, grid_encoder=grid_encoder, tokenizer=tokenizer, is_train_prompt=False)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    # Fine-tuning may leave the model in training mode; eval mode disables
    # dropout (if there is any) while sampling
    model.eval()
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.5,
        top_p=0.95,
        num_return_sequences=num_return_sequences,
    )
    # Keep only the token ids that come after the prompt
    generated_ids = generated_ids[:, len(model_inputs.input_ids[0]):]
    print(f'Generated ids shape: {generated_ids.shape}')
    predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    predicted_codes = [prediction.replace('\n```', '') for prediction in predictions]
    new_tasks = []
    pixel_accuracies = []
    for predicted_code in tqdm(predicted_codes):
        try:
            predicted_output = safe_code_execution(predicted_code, task.inputs)
            accuracy = _pixel_accuracy(predicted_output[0], task.outputs[0])
            new_task = Task(inputs=task.inputs, outputs=predicted_output, code=predicted_code, name=task.name)
        except Exception as e:
            print(f'Error executing code: {predicted_code}')
            print(e)
            continue
        # Append both together, only once everything worked, so the two lists stay aligned
        new_tasks.append(new_task)
        pixel_accuracies.append(accuracy)
    if not new_tasks:
        print('None of the predictions could be executed')
        return []
    plt.hist(pixel_accuracies, bins=np.linspace(0, 1, 20), log=True, label='all predictions');

    # Keep the first task found for each distinct first output
    unique_tasks, unique_accuracies = [], []
    for new_task, accuracy in zip(new_tasks, pixel_accuracies):
        if not any(np.array_equal(new_task.outputs[0], t.outputs[0]) for t in unique_tasks):
            unique_tasks.append(new_task)
            unique_accuracies.append(accuracy)
    print(f'Number of unique outputs: {len(unique_tasks)}')
    print(f'Max pixel accuracy: {max(unique_accuracies)}')

    plt.hist(unique_accuracies, bins=np.linspace(0, 1, 20), log=True, label='unique predictions', alpha=0.5);
    plt.legend()
    plt.show()
    # Stable sort on the accuracy only: worst prediction first, best prediction last
    ranked = sorted(zip(unique_accuracies, unique_tasks), key=lambda pair: pair[0])
    return [new_task for _, new_task in ranked]


def _pixel_accuracy(prediction, target):
    """Return the fraction of pixels of `prediction` equal to `target`, 0.0 if the shapes differ."""
    prediction, target = np.asarray(prediction), np.asarray(target)
    if prediction.shape != target.shape:
        return 0.0
    return float(np.mean(prediction == target))


def finetuning(new_tasks, output_dir='/mnt/hdd0/Kaggle/arc25/trainings/20250505_TTT/debug', learning_rate=1e-5):
    """
    Fine-tune the model for one epoch on the training prompts of the new tasks

    Parameters
    ----------
    new_tasks : list
        Tasks used to create the training prompts. If it is empty nothing is done.
    output_dir : str
        Folder given to the trainer as output directory.
    learning_rate : float
        Learning rate of the training.
    """
    if not new_tasks:
        # There is nothing to learn from, do not create a trainer for an empty dataset
        print('No tasks to fine-tune on, skipping training')
        return

    prompts = [
        create_prompt_from_task(
            new_task, prompt_version=prompt_version, grid_encoder=grid_encoder,
            tokenizer=tokenizer, is_train_prompt=True)
        for new_task in new_tasks
    ]
    train_dataset = Dataset.from_dict({'text': prompts})

    training_arguments = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=1,
        warmup_ratio=0.1,
        learning_rate=learning_rate,
        lr_scheduler_type='constant_with_warmup', #constant_with_warmup, cosine, cosine_with_restarts
        gradient_checkpointing=False,
        optim="paged_adamw_8bit",
        max_grad_norm=1.0,

        dataset_text_field="text",
        max_seq_length=4096,

        do_eval=True,
        eval_strategy="no", #TODO: previously it was steps
        logging_steps=10, #50,
        log_level="info",
        report_to='none',

        # parameters added to make the code work with accelerate
        # https://huggingface.co/transformers/v4.9.1/main_classes/trainer.html#trainingarguments
        ddp_find_unused_parameters=False, # only used with accelerate, got a warning saying that it slows down if True

        ignore_data_skip=True, # otherwise it takes too long to start training when resuming from checkpoint

        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
    )

    # A new trainer is created on every call, but it trains the same global model
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        data_collator=get_data_collator(tokenizer),
        args=training_arguments,
    )
    trainer.train()

In [None]:
# Run hindsight experience replay on the current `task` for at most 3 epochs
hindsight_experience_replay(task, max_epochs=3)

In [None]:
# Second toy task: the 9x9 grid is split in 3x3 cells and a rectangle with a
# different color (1 to 9) is drawn on each cell.
input_img = create_img((9, 9), color=0)
output_img = input_img.copy()
# First corner of every cell, in the order in which the colors are assigned
corners = [
    (x, y)
    for x in range(0, input_img.shape[1], 3)
    for y in range(0, input_img.shape[0], 3)
]
for color, (x, y) in enumerate(corners, start=1):
    draw_rectangle(output_img, (x, y), (x + 2, y + 2), color=color)
task = Task(inputs=[input_img], outputs=[output_img], code='', name='manual')
plot_task(task)
hindsight_experience_replay(task, max_epochs=5)

## TODO

- Encapsulate steps into function
- Try on other tasks (I might think of more complex tasks with pixels)
- Parametrize batch size
- Print successful code
- Better progress visualization