# Hindsight Experience Replay v2

## Goal

Improve the Hindsight Experience Replay (HER) algorithm.

## Imports

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from tqdm.auto import tqdm
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from transformers import AutoTokenizer, AutoConfig
import matplotlib.pyplot as plt
import matplotlib as mpl
from datasets import Dataset
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import logging
from IPython.display import Markdown, display
import torch
import random
from typing import List
from dataclasses import field
import wandb

from arc25.training_tasks import *
from arc25.encoders import create_grid_encoder
from arc25.prompting import create_prompt_from_task, pretty_print_prompt
from arc25.plot import plot_task, plot_grids_with_shape, plot_grid
from arc25.code_execution import safe_code_execution, validate_code
from arc25.utils import set_random_seed, get_timestamp
from arc25.logging import configure_logging, log_execution_time

configure_logging()

import sys
sys.path.append(os.path.realpath("../scripts"))
from finetuning import get_data_collator


# Draw and discard an empty plot first so the notebook's plotting backend is already
# initialised when the defaults below are set (presumably so it does not reset them) — TODO confirm.
plt.plot()
plt.close('all')
# Default style for every figure of the notebook (plt.rcParams is mpl.rcParams).
mpl.rcParams.update({
    'figure.figsize': (20, 5),
    'lines.linewidth': 3,
    'font.size': 12,
})

## Code

In [None]:
def _create_vertical_lines_task():
    """Build the '9-vertical-lines' task: 9x9 blank input, column x painted with color x + 1."""
    input_img = create_img((9, 9), color=0)
    output_img = input_img.copy()
    for x in range(input_img.shape[1]):
        draw_vertical_line(output_img, x, color=x+1)
    return Task(inputs=[input_img], outputs=[output_img], code='', name='9-vertical-lines')


def _create_squares_task(shape, name):
    """Build a task whose output tiles a blank ``shape`` grid with 2x2 squares.

    Squares are placed column by column (outer loop over x, inner loop over y) and are
    colored with the cycle 1, 2, ..., 9, 1, 2, ... (the background color 0 is skipped).
    """
    input_img = create_img(shape, color=0)
    output_img = input_img.copy()
    color = 0
    for x in range(0, input_img.shape[1], 2):
        for y in range(0, input_img.shape[0], 2):
            color = color % 9 + 1  # 0 -> 1, 1 -> 2, ..., 8 -> 9, 9 -> 1
            draw_rectangle(output_img, (y, x), (y+1, x+1), color=color)
    return Task(inputs=[input_img], outputs=[output_img], code='', name=name)


def _create_chick_task():
    """Build the 'chick' task: 10x10 blank input, output is a fixed hand-made 10x10 picture."""
    input_img = create_img((10, 10), color=0)
    output_img = Img([
        [8, 8, 8, 8, 4, 4, 8, 8, 8, 8],
        [8, 8, 4, 4, 4, 4, 4, 4, 8, 8],
        [8, 4, 4, 0, 4, 4, 0, 4, 4, 8],
        [8, 4, 2, 4, 4, 7, 4, 2, 4, 8],
        [8, 4, 4, 4, 7, 7, 4, 4, 4, 8],
        [8, 8, 4, 4, 4, 4, 4, 4, 8, 8],
        [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
        [4, 4, 4, 4, 4, 4, 4, 4, 4, 4],
        [8, 4, 4, 4, 4, 4, 4, 4, 4, 8],
        [8, 8, 4, 7, 4, 4, 7, 4, 8, 8],
    ])
    return Task(inputs=[input_img], outputs=[output_img], code='', name='chick')


def get_task(task_name):
    """Return the toy task called ``task_name``.

    Every task has a single blank input grid and a target output grid, and no
    reference code. Only the requested task is built.

    Raises:
        ValueError: if ``task_name`` is not one of the available task names.
    """
    # Name -> zero-argument builder; insertion order is the order shown in the error message.
    task_builders = {
        '9-vertical-lines': _create_vertical_lines_task,
        '20-squares': lambda: _create_squares_task((10, 8), '20-squares'),
        '12-squares': lambda: _create_squares_task((6, 8), '12-squares'),
        '16-squares': lambda: _create_squares_task((8, 8), '16-squares'),
        'chick': _create_chick_task,
    }
    if task_name not in task_builders:
        raise ValueError(f"Task {task_name} not found. Available tasks: {list(task_builders)}")
    return task_builders[task_name]()

In [None]:
from collections import namedtuple

# Results of one HER epoch:
#   best_prediction: the generated task whose first output has the highest pixel accuracy
#   pixel_accuracies: pixel accuracy of every successfully executed prediction of the epoch
EpochResults = namedtuple("EpochResults", ["best_prediction", 'pixel_accuracies'])
# One sampling configuration for model.generate: number of sequences and sampling temperature
InferenceParams = namedtuple("InferenceParams", ["num_return_sequences", "temperature"])

@log_execution_time
def hindsight_experience_replay(task, cfg):
    """
    Use hindsight experience replay to try to solve new tasks.

    Each epoch samples programs for ``task`` with the current model, scores their outputs
    by pixel accuracy against the task's expected output and, unless a perfect prediction
    was found, fine-tunes the model on the generated tasks (whose outputs are the programs'
    own outputs). Progress is logged to Weights & Biases and plotted in the notebook.

    Args:
        task: the task to solve.
        cfg: a ``Config`` with the model paths, prompting and sampling settings.

    Returns:
        list of ``EpochResults``, one per completed epoch.
    """
    wandb.init(project='HER_v2', name=f'{task.name}_{get_timestamp()}', config=cfg, reinit=True)
    # Always close the wandb run, even if loading, inference or training fails,
    # so the next call does not log into a dangling run.
    try:
        fig = plt.figure()
        plot_task(task); plt.suptitle('Task to solve'); plt.tight_layout()
        wandb.log({"task": wandb.Image(fig)}); plt.show()
        model, tokenizer = load_model(cfg.base_model_path, cfg.lora_path)
        metrics = []
        for epoch in range(1, cfg.max_epochs + 1):
            logging.info(f'Starting epoch {epoch}...')
            new_tasks, pixel_accuracies = inference(
                task, model, tokenizer, cfg.grid_encoder, cfg.prompt_version,
                inference_params=cfg.inference_params)
            # inference returns the tasks sorted by ascending accuracy: the best one is last
            best_task = new_tasks[-1]
            metrics.append(EpochResults(best_prediction=best_task, pixel_accuracies=pixel_accuracies))
            fig = plt.figure()
            plot_grid(best_task.outputs[0])
            wandb.log({"epoch": epoch, "max_pixel_accuracy": max(pixel_accuracies),
                       "mean_pixel_accuracy": np.mean(pixel_accuracies),
                       "min_pixel_accuracy": min(pixel_accuracies),
                       "best_prediction": wandb.Image(fig),
                       'pixel_accuracy': wandb.Histogram(pixel_accuracies),
                       'best_code': wandb.Html(f'<pre>{best_task.code}</pre>'),
                       },
                       step=epoch, commit=True)
            plt.close(fig)
            plot_metrics_evolution(metrics)
            if np.max(pixel_accuracies) == 1:
                # Uses the module-level logging functions like the rest of the notebook
                # (no `logger` object is defined here).
                logging.info(f'Found a perfect prediction at epoch {epoch}!')
                break
            if not cfg.use_accuracy_for_sorting:
                logging.info('Shuffling the tasks, no information about the accuracy is used')
                random.shuffle(new_tasks)
            finetuning(new_tasks, model, tokenizer, cfg.grid_encoder, cfg.prompt_version)
        if metrics:  # empty when cfg.max_epochs < 1
            display(Markdown(f'# Best prediction code\n\n```python\n{metrics[-1].best_prediction.code}\n```'))
    finally:
        wandb.finish()
    return metrics

@log_execution_time
def load_model(base_model_path, lora_path):
    """Load the base causal LM with a trainable LoRA adapter on top, plus its tokenizer.

    Args:
        base_model_path: path of the base model weights.
        lora_path: path of the LoRA adapter; the tokenizer is loaded from here too.

    Returns:
        (peft_model, tokenizer)
    """
    logging.info(f"Loading model from {base_model_path} and LoRA from {lora_path}")
    # Release unoccupied cached GPU memory before loading a new copy of the model.
    torch.cuda.empty_cache()
    base_model = AutoModelForCausalLM.from_pretrained(base_model_path, torch_dtype="auto", device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(lora_path)
    # is_trainable=True keeps the adapter weights trainable so `finetuning` can update them.
    peft_model = PeftModel.from_pretrained(base_model, lora_path, is_trainable=True)
    return peft_model, tokenizer

def _pixel_accuracy(prediction, target):
    """Fraction of cells where ``prediction`` equals ``target``; 0.0 when the shapes differ.

    An explicit shape check avoids NumPy broadcasting (silently wrong accuracies) or
    raising on incompatible shapes (NumPy >= 1.25).
    """
    prediction, target = np.asarray(prediction), np.asarray(target)
    if prediction.shape != target.shape:
        return 0.0
    return float(np.mean(prediction == target))


def _grid_key(grid):
    """Hashable key identifying a grid by its shape and cell values (used for deduplication)."""
    grid = np.asarray(grid)
    return grid.shape, tuple(grid.ravel().tolist())


@log_execution_time
def inference(task, model, tokenizer, grid_encoder, prompt_version, inference_params):
    """Sample programs for ``task`` with ``model`` and turn their outputs into new tasks.

    For every entry of ``inference_params`` the model samples ``num_return_sequences``
    completions at the given temperature (top_p=0.95, at most 1024 new tokens). Each
    completion is executed on ``task.inputs``; programs that fail are logged and skipped.
    Every program that runs becomes a new task whose outputs are the program's own outputs.

    Returns:
        (new_tasks, pixel_accuracies):
            new_tasks: one task per distinct first output (the first program producing it
                is kept), sorted by ascending pixel accuracy, so the best one is last.
            pixel_accuracies: pixel accuracy of every program that ran, duplicates included.

    Raises:
        RuntimeError: if none of the generated programs could be executed.
    """
    prompt = create_prompt_from_task(
        task, prompt_version=prompt_version, grid_encoder=grid_encoder, tokenizer=tokenizer, is_train_prompt=False)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    prompt_length = model_inputs.input_ids.shape[1]

    # The trainer leaves the model in train mode after finetuning; sample in eval mode so
    # that dropout layers (if any) are disabled during generation.
    model.eval()
    predicted_codes = []
    for params in inference_params:
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=params.temperature,
            top_p=0.95,
            num_return_sequences=params.num_return_sequences,
        )
        # Keep only the generated continuation, not the echoed prompt.
        predictions = tokenizer.batch_decode(generated_ids[:, prompt_length:], skip_special_tokens=True)
        predicted_codes.extend(prediction.replace('\n```', '') for prediction in predictions)

    new_tasks = []
    pixel_accuracies = []
    for predicted_code in tqdm(predicted_codes):
        try:
            predicted_output = safe_code_execution(predicted_code, task.inputs)
            validated_code = validate_code(predicted_code, task.inputs)
            new_task = Task(inputs=task.inputs, outputs=predicted_output, code=validated_code, name=task.name)
            accuracy = _pixel_accuracy(new_task.outputs[0], task.outputs[0])
        except Exception as e:
            # Generated code is expected to fail often; skip it but keep a trace.
            logging.warning(f'Error executing code: {predicted_code}\n{e}')
            continue
        # Append both together so new_tasks and pixel_accuracies stay aligned.
        new_tasks.append(new_task)
        pixel_accuracies.append(accuracy)

    if not new_tasks:
        raise RuntimeError(f'None of the {len(predicted_codes)} predicted programs could be executed')

    # Deduplicate by first output, keeping the first task that produced each grid
    # (dicts preserve insertion order).
    unique_tasks = {}
    for new_task in new_tasks:
        unique_tasks.setdefault(_grid_key(new_task.outputs[0]), new_task)
    new_tasks_with_unique_outputs = list(unique_tasks.values())
    logging.info(f'Number of unique outputs: {len(new_tasks_with_unique_outputs)}/{len(new_tasks)}')
    logging.info(f'Max pixel accuracy: {max(pixel_accuracies)}')
    # Ascending (stable) order: callers take the last element as the best prediction.
    new_tasks_with_unique_outputs.sort(key=lambda t: _pixel_accuracy(t.outputs[0], task.outputs[0]))
    return new_tasks_with_unique_outputs, pixel_accuracies

@log_execution_time
def finetuning(new_tasks, model, tokenizer, grid_encoder, prompt_version):
    """Fine-tune ``model`` for one epoch on training prompts built from ``new_tasks``.

    Args:
        new_tasks: tasks to train on, one training prompt per task.
        model: the trainable (LoRA) model; the trainer updates it in place.
        tokenizer: used to build the prompts and handed to the trainer.
        grid_encoder, prompt_version: forwarded to ``create_prompt_from_task``.

    No checkpoints are saved (``save_strategy='no'``) and nothing is reported to external
    trackers (``report_to='none'``).

    NOTE(review): the Hugging Face Trainer samples the training set with a random sampler
    by default, so the order of ``new_tasks`` (sorted by accuracy vs. shuffled) probably
    never reaches the optimizer — confirm before drawing conclusions from
    ``use_accuracy_for_sorting``.
    """
    train_prompts = [
        create_prompt_from_task(
            new_task, prompt_version=prompt_version, grid_encoder=grid_encoder,
            tokenizer=tokenizer, is_train_prompt=True)
        for new_task in new_tasks
    ]
    train_dataset = Dataset.from_dict({'text': train_prompts})

    training_arguments = SFTConfig(
        # Output / checkpointing
        output_dir=None,
        save_strategy='no',
        ignore_data_skip=True,  # otherwise it takes too long to start training when resuming from checkpoint
        # Optimization: a single pass over the data, batch size 1, small constant learning rate
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        learning_rate=1e-5,
        lr_scheduler_type='constant_with_warmup',
        warmup_ratio=0.1,
        optim="paged_adamw_8bit",
        max_grad_norm=1.0,
        gradient_checkpointing=False,
        ddp_find_unused_parameters=False,  # only used with accelerate; True slows training down
        # Data
        dataset_text_field="text",
        max_seq_length=4096,
        # Evaluation and logging (eval_strategy="no" disables periodic evaluation)
        do_eval=True,
        eval_strategy="no",
        logging_steps=10,
        log_level="info",
        report_to='none',
    )

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        data_collator=get_data_collator(tokenizer),
        args=training_arguments,
    )
    trainer.train()


def plot_best_prediction(task, best_prediction, accuracy):
    """Plot the expected outputs next to the predicted ones, then show the prediction's code.

    Args:
        task: task holding the expected outputs.
        best_prediction: task holding the predicted outputs and the code that produced them.
        accuracy: accuracy shown in the title, as a fraction (formatted as a percentage).
    """
    grids = task.outputs + best_prediction.outputs
    title = f'Best prediction accuracy: {accuracy:.1%}'
    plot_grids_with_shape(grids, suptitle=title)
    code_markdown = f'```python\n{best_prediction.code}\n```'
    display(Markdown(code_markdown))


def plot_metrics_evolution(metrics, max_cols=10):
    """Plot how the predictions evolve over the HER epochs.

    First draws the stacked pixel-accuracy histograms (``plot_score_histograms``). Then it
    opens a new figure with the best prediction of every epoch, titled with the epoch number
    (1-based, like the epochs logged by ``hindsight_experience_replay``) and the epoch's
    best pixel accuracy.

    Args:
        metrics: sequence of ``EpochResults``, one per epoch, in order. Nothing is plotted
            when it is empty.
        max_cols: maximum number of epochs per row; further epochs wrap onto new rows so
            long runs stay readable. With at most ``max_cols`` epochs there is a single row.
    """
    if not metrics:
        return
    plot_score_histograms(metrics)

    n_cols = min(len(metrics), max(1, max_cols))
    n_rows = (len(metrics) + n_cols - 1) // n_cols  # ceiling division
    width, height = plt.rcParams['figure.figsize']
    # Explicit new figure: otherwise the subplots could be drawn on top of the histogram
    # figure when plt.show() does not close it (e.g. non-interactive backends).
    plt.figure(figsize=(width, height * n_rows))
    for epoch, epoch_results in enumerate(metrics, start=1):
        plt.subplot(n_rows, n_cols, epoch)
        plot_grid(epoch_results.best_prediction.outputs[0])
        plt.title(f'Epoch {epoch} acc: {max(epoch_results.pixel_accuracies):.1%}')
    plt.suptitle('Evolution of best predictions')
    plt.tight_layout()
    plt.show()


def plot_score_histograms(metrics, offset_scale=1):
    """
    Plots stacked (y-offset) histograms of the pixel accuracies of each epoch.

    Each epoch's curve is log1p of its bin counts over 99 equal-width bins in [0, 1],
    drawn on its own baseline. Baselines are spaced by the tallest curve of *all* epochs
    times ``offset_scale``, so with ``offset_scale >= 1`` the curves are evenly spaced and
    never overlap. Colors follow the viridis colormap from the first to the last epoch, and
    epochs are labelled from 1 (matching ``hindsight_experience_replay``).

    Args:
        metrics: sequence of ``EpochResults``; nothing is plotted when it is empty.
        offset_scale: multiplier of the vertical spacing between consecutive histograms.
    """
    if not metrics:
        return
    cmap = mpl.colormaps['viridis']
    norm = plt.Normalize(0, len(metrics) - 1)
    bins = np.linspace(0, 1, 100)
    bin_centers = 0.5 * (bins[1:] + bins[:-1])

    log_counts = [np.log1p(np.histogram(epoch_results.pixel_accuracies, bins=bins)[0])
                  for epoch_results in metrics]
    # One spacing shared by every epoch keeps the baselines evenly spaced; a per-epoch
    # spacing would let a tall histogram overlap the next one.
    spacing = max(float(np.max(counts)) for counts in log_counts) * offset_scale

    plt.figure(figsize=(10, 6))
    for i, counts in enumerate(log_counts):
        offset = i * spacing
        plt.fill_between(bin_centers, offset, counts + offset, color=cmap(norm(i)),
                         label=f'Epoch {i + 1}', alpha=0.5)

    plt.xlabel("Pixel accuracy")
    plt.ylabel("Epoch ->")
    plt.title("Evolution of pixel accuracy")
    plt.yticks([])  # Hide y-ticks since they don't represent absolute values
    plt.grid(axis='x')
    plt.legend()  # the per-epoch labels are otherwise never shown
    plt.tight_layout()
    plt.show()

## First experiments

In [None]:
@dataclass
class Config:
    """Settings of one hindsight experience replay run (read by hindsight_experience_replay).

    The order of the annotated fields below is the positional constructor signature.
    """
    # Base causal LM, loaded with AutoModelForCausalLM in load_model
    base_model_path: str = '/home/gbarbadillo/models/Qwen2.5-Coder-0.5B-Instruct'
    # LoRA adapter checkpoint; load_model also loads the tokenizer from this path
    lora_path: str = '/mnt/hdd0/Kaggle/arc25/trainings/20250430_first_trainings/steps_6400/checkpoint-6400'
    # Prompt template name forwarded to create_prompt_from_task
    prompt_version: str = 'code-from-examples-v3'
    # NOTE: no type annotation, so this is a plain class attribute rather than a dataclass
    # field: it is built once when the class is defined, shared by all instances, and
    # cannot be passed to Config(...).
    grid_encoder = create_grid_encoder('GridShapeEncoder(RowNumberEncoder(MinimalGridEncoder()))')
    # Maximum number of epochs; the loop stops earlier on a prediction with 100% pixel accuracy
    max_epochs: int = 10
    # NOTE(review): temperature and n_predictions are not read by any function in this
    # notebook; sampling is controlled by inference_params.
    temperature: float = 0.5
    n_predictions: int = 256 # 256 seems to be the best for my hardware
    # If False, the generated tasks are shuffled before finetuning instead of being kept
    # sorted by ascending pixel accuracy
    use_accuracy_for_sorting: bool = True
    # Sampling configurations; the samples of all of them are pooled every epoch
    inference_params: List[InferenceParams] = field(default_factory=lambda: [
        InferenceParams(num_return_sequences=16, temperature=0.25),
        InferenceParams(num_return_sequences=128, temperature=0.5),
        InferenceParams(num_return_sequences=128, temperature=0.75),
    ])

In [None]:
# Run HER on the '12-squares' task, sampling 256 programs per epoch at a single temperature (0.5).
# The trailing ';' keeps the notebook from displaying the returned metrics.
task = get_task('12-squares')
cfg = Config(inference_params=[
    InferenceParams(num_return_sequences=256, temperature=0.5),
])
hindsight_experience_replay(task, cfg);

In [None]:
# Deliberate stop: a bare `raise` with no active exception raises RuntimeError, so
# "Run All" halts here and the scratch cells below are not executed.
raise

In [None]:
# Scratch experiment: how does wandb render a logged code snippet?
# NOTE(review): presumably needs an active wandb run; hindsight_experience_replay
# finishes its own run before returning.
code_str = """
def greet(name):
    print(f"Hello, {name}!")

greet("W&B")
"""

# Format as Markdown with syntax highlighting
# NOTE(review): wandb.Html expects HTML, so the Markdown fences are probably shown
# verbatim; hindsight_experience_replay logs code wrapped in <pre> instead.
md_code = f"```python\n{code_str}\n```"

wandb.log({"code_snippet": wandb.Html(md_code)})

In [None]:
# Alternative: log the raw code as a plain string value.
wandb.log({"code_snippet_text": code_str})

In [None]:
# Scratch: log a plot of the 'chick' target output to wandb as an image.
task = get_task('chick')

fig = plt.figure()
plot_grid(task.outputs[0])
wandb.log({"task": wandb.Image(fig)})

In [None]:
# Scratch: check how explicit `step` values appear in wandb by logging the 'chick' output
# with `step` rows/columns cropped from every border (steps 1 to 4).
task = get_task('chick')

for step in range(1, 5):
    fig = plt.figure()
    plot_grid(task.outputs[0][step:-step, step:-step])
    wandb.log({"task": wandb.Image(fig)}, step=step, commit=True)
    plt.close(fig)  # free the figure; it has already been sent to wandb

In [None]:
# End the current wandb run, if any.
wandb.finish()

In [None]:
# Inspect the first row of the current task's first output grid.
task.outputs[0][0]

In [None]:
# Deliberate stop: a bare `raise` with no active exception raises RuntimeError, so
# "Run All" halts here and the cells below are not executed.
raise

In [None]:
# Run HER on the 'chick' task with a wider spread of sampling temperatures
# (16 + 16 + 64 + 128 + 128 = 352 programs per epoch).
# The trailing ';' keeps the notebook from displaying the returned metrics.
task = get_task('chick')
cfg = Config(inference_params=[
    InferenceParams(num_return_sequences=16, temperature=0.1),
    InferenceParams(num_return_sequences=16, temperature=0.25),
    InferenceParams(num_return_sequences=64, temperature=0.5),
    InferenceParams(num_return_sequences=128, temperature=0.75),
    InferenceParams(num_return_sequences=128, temperature=0.9),
])
hindsight_experience_replay(task, cfg);

## Results

I believe the problem is that the number of distinct predictions decreases in the later epochs. We still need exploration to reach the perfect solution.

```
cfg = Config(inference_params=[
    InferenceParams(num_return_sequences=16, temperature=0.25),
    InferenceParams(num_return_sequences=128, temperature=0.5),
    InferenceParams(num_return_sequences=128, temperature=0.75),
])
97% 951s


cfg = Config(inference_params=[
    InferenceParams(num_return_sequences=16, temperature=0.1),
    InferenceParams(num_return_sequences=16, temperature=0.25),
    InferenceParams(num_return_sequences=64, temperature=0.5),
    InferenceParams(num_return_sequences=128, temperature=0.75),
    InferenceParams(num_return_sequences=128, temperature=0.9),
])
97% 1535s
```


## TODO

- Better selection of the task kept for each unique output (choose the shortest)
- Print the code on each epoch
- Show the evolution of the task length
- Better display of the results when there are more than 10 epochs
- What is the best way to distribute the temperatures?
- Improve the metrics plot
- Maybe I have to keep all the tasks to preserve diversity and to avoid retraining on the same task over and over
- Add wandb logging with images