## Install the required libraries

In [None]:
!pip install -q wandb datasets trl

In [None]:
import torch
print("torch version:", torch.__version__)

In [None]:
import os

import wandb
print("wandb version:", wandb.__version__)

# Never commit an API key to the notebook. Take it from the WANDB_API_KEY
# environment variable; when that is unset, key=None lets wandb fall back to
# its own credential lookup / interactive prompt.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
import os

from huggingface_hub import login

# Never commit an access token to the notebook. Take it from the HF_TOKEN
# environment variable; when that is unset, token=None makes login() ask for
# the token interactively.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
# --- Model and dataset identifiers ---
# Checkpoint to start from, and the name under which the tuned model is uploaded.
base_model_name = "simon-arc-lab-model647"
result_model_name = "simon-arc-lab-model648"
base_model_path = "neoneye/" + base_model_name  # id passed to from_pretrained()
dataset_path = "neoneye/simon-arc-solve-fractal-v8"

# --- Tokenization limits (in tokens) and optimizer learning rate ---
max_input_length, max_target_length = 1024, 128
my_learning_rate = 1e-6

## Preprocess data


In [None]:
from datasets import load_dataset

# Load every split of the dataset; only the 'train' split is used below.
full_dataset = load_dataset(dataset_path)
train_dataset = full_dataset['train']

# Work on a short prefix of the training split to keep the run small.
# The count is clamped to the split size so that a split with fewer than
# 1000 rows does not make select() fail on an out-of-range index.
dataset = train_dataset.select(range(min(1000, len(train_dataset))))

# Define a function that adds an 'id' field to each row
def add_id(example, idx):
    """Store the row's position under the 'id' key and hand the same row back.

    Args:
        example: one dataset row (a mutable mapping); it is modified in place.
        idx: the index supplied by map(..., with_indices=True).

    Returns:
        The same `example` object, now containing 'id'.
    """
    example.update({'id': idx})
    return example

# Give every row an integer 'id' equal to its position in the subset.
# with_indices=True makes map() pass that position to add_id as its second argument.
# Later cells build their 'row_{id}' lookup keys from this column.
dataset = dataset.map(add_id, with_indices=True)

# Print the dataset summary to confirm the new column is present.
print(dataset)

The printed dataset is the 1,000-row subset of the training split, now with an extra `id` column. Let's look at one particular example:

In [None]:
# Print every field of a single row (index 5); enlarge the range to see more rows.
for i in range(1):
  row_index = i + 5
  example = dataset[row_index]
  print("example:", row_index)
  print("instruction:", example["instruction"])
  print("input:", example["input"])
  print("output:", example["output"])
  print("id:", example["id"])

In [None]:
!git clone https://github.com/neoneye/simon-arc-lab.git

In [None]:
!cd simon-arc-lab && sh test.sh

In [None]:
import os
import sys

# The repository was cloned with a relative path, so resolve it against the
# current working directory instead of assuming Colab's /content directory.
repo_path = os.path.abspath('simon-arc-lab')

# Drop any earlier entries first, so re-running this cell never leaves duplicates,
# then register the checkout so `simon_arc_lab` can be imported.
sys.path[:] = [entry for entry in sys.path if entry != repo_path]
sys.path.append(repo_path)
print(sys.path)

In [None]:
import json
from simon_arc_lab.task import Task

# Lookup tables keyed by 'row_{id}', where {id} is the row's 'id' column.
all_tasks = {}          # task_id -> Task built from the row's 'arc_task' JSON
all_test_indexes = {}   # task_id -> the row's 'test_index' value

for row in dataset:
  row_id = row["id"]
  test_index = row["test_index"]
  task_id = f'row_{row_id}'

  # 'arc_task' holds the task as a JSON string; parse it and wrap it in a Task.
  task = Task.create_with_arcagi1_json(json.loads(row["arc_task"]))
  task.metadata_task_id = task_id

  all_tasks[task_id] = task
  all_test_indexes[task_id] = test_index

print('number of tasks:', len(all_tasks))
print('number of test_indexes:', len(all_test_indexes))

In [None]:
from simon_arc_lab.task_similarity import TaskSimilarity

# One TaskSimilarity per task, stored under the same 'row_{id}' key as all_tasks.
task_id_to_task_similarity = {}
for task_id, task in all_tasks.items():
  similarity = TaskSimilarity.create_with_task(task)
  # The summary is not used afterwards; the call is kept in case it has side
  # effects on the instance - TODO confirm and drop it if it has none.
  summary = similarity.summary()
  task_id_to_task_similarity[task_id] = similarity

print('number of tasksimilarity instances:', len(task_id_to_task_similarity))

In [None]:
from transformers import RobertaTokenizer

# Load the tokenizer stored under the same id as the base model (base_model_path).
tokenizer = RobertaTokenizer.from_pretrained(base_model_path)

def preprocess_examples(examples):
  """Tokenize a batch of rows into model inputs plus labels.

  Args:
    examples: a batch in column form, with the lists 'instruction', 'input'
      and 'output' (one entry per row).

  Returns:
    The tokenizer output for "instruction\\ninput" (padded/truncated to
    max_input_length) with an added "labels" entry: the token ids of 'output'
    (padded/truncated to max_target_length) in which every padding id is
    replaced by -100, the ignore index used for the labels.
  """
  # concatenate "instruction" and "input" with a newline
  prompts = [
    f"{instruction}\n{input_data}"
    for instruction, input_data in zip(examples['instruction'], examples['input'])
  ]
  model_inputs = tokenizer(prompts, max_length=max_input_length, padding="max_length", truncation=True)

  # encode the outputs
  target_ids = tokenizer(examples['output'], max_length=max_target_length, padding="max_length", truncation=True).input_ids

  # Ask the tokenizer for its padding id instead of assuming it is 0, so a real
  # token with id 0 is never masked and a non-zero padding id is still masked.
  # This matches clean_labels(), which also relies on tokenizer.pad_token_id.
  pad_id = tokenizer.pad_token_id
  model_inputs["labels"] = [
    [token_id if token_id != pad_id else -100 for token_id in row]
    for row in target_ids
  ]

  return model_inputs

In [None]:
from datasets import DatasetDict

# Use a fixed seed for both splits so that the train/validation/test membership
# is the same on every run.
split_seed = 42

# Split the dataset into train and test (80% train, 20% test)
train_testvalid = dataset.train_test_split(test_size=0.2, seed=split_seed)
train_test = DatasetDict({
    'train': train_testvalid['train'],
    'test': train_testvalid['test']
})

# Split the training portion again to create a validation set (10% of that portion)
train_valid = train_test['train'].train_test_split(test_size=0.1, seed=split_seed)

# Combine to create a final dataset dictionary
final_datasets = DatasetDict({
    'train': train_valid['train'],
    'validation': train_valid['test'],
    'test': train_test['test']
})

# Print the dataset splits
print(final_datasets)

# Apply the preprocessing function to all splits
final_datasets = final_datasets.map(preprocess_examples, batched=True)

# Set format for PyTorch DataLoader.
# 'id' is kept next to the model tensors because the training loop uses it to
# look up the task that belongs to each row.
final_datasets.set_format(type="torch", columns=['input_ids', 'attention_mask', 'labels', 'id'])

# Create DataLoaders
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    final_datasets['train'],
    shuffle=True,
    batch_size=64,  # same value as batch_size in the PPOConfig further down
    drop_last=True, # This ensures that any incomplete batches (those with fewer than batch_size samples) are dropped. This way, all batches passed to the ppo_trainer will have the expected batch size.
)
valid_dataloader = DataLoader(final_datasets['validation'], batch_size=4)
test_dataloader = DataLoader(final_datasets['test'], batch_size=4)

print("DataLoaders created successfully.")

The cell above calls `.map()` on the HuggingFace dataset object, which applies the preprocessing function in batches (by default a batch size of 1,000 is used!) - hence super fast.

It then sets the format to "torch" and creates the PyTorch dataloaders. Let's inspect one batch:

In [None]:
# Pull one training batch and decode its first row as a sanity check.
batch = next(iter(train_dataloader))
print("batch.keys:\n", batch.keys())

first_input_ids = batch['input_ids'][0]
print("\ninput_ids:\n", tokenizer.decode(first_input_ids))

# -100 marks masked label positions and is not a token id, so drop it before decoding.
labels = batch['labels'][0]
kept_label_ids = [label for label in labels if label != -100]
decoded = tokenizer.decode(kept_label_ids)
print("\ndecoded\n", decoded)

## Fine-tune using PPO (TRL)

We fine-tune the model with reinforcement learning, using the `PPOTrainer` from the TRL library. For every batch, the model generates a response for each prompt, a reward function scores each response, and `ppo_trainer.step()` updates the model from those rewards. The training loop itself is written in plain PyTorch, so tensors are moved to the right device explicitly with `.to(device)`.

Of course, you could also train the model in other ways:
* using regular PyTorch
* using the HuggingFace Trainer (in this case, the Seq2SeqTrainer)
* using HuggingFace Accelerate
* etc.

In [None]:
from trl import PPOTrainer, PPOConfig
from trl import AutoModelForSeq2SeqLMWithValueHead

from transformers import AutoTokenizer
from transformers import T5ForConditionalGeneration, AdamW, get_linear_schedule_with_warmup

In [None]:
# Remove -100 and padding tokens before decoding
def clean_labels(labels):
    """Strip the -100 markers and padding ids from every label sequence.

    Args:
        labels: a list of token-id lists.

    Returns:
        A new list of lists holding the remaining ids in their original order.
    """
    cleaned = []
    for sequence in labels:
        kept = []
        for token in sequence:
            if token == -100 or token == tokenizer.pad_token_id:
                continue
            kept.append(token)
        cleaned.append(kept)
    return cleaned


In [None]:
import torch

# Dynamically check and import TPU module if TPU is available
def get_device():
    """Pick the device to run on: TPU (via torch_xla) first, then CUDA GPU, then CPU.

    Returns:
        The XLA device when torch_xla imports and hands out a device, otherwise
        torch.device('cuda') or torch.device('cpu').
    """
    try:
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
    except (ImportError, RuntimeError):
        # ImportError: torch_xla is not installed.
        # RuntimeError: torch_xla is installed but could not provide a device
        # (presumably a runtime without a TPU attached). Fall through to the
        # GPU/CPU checks instead of aborting the notebook.
        pass
    else:
        print("TPU is detected.")
        return device

    if torch.cuda.is_available():
        print("GPU is detected.")
        return torch.device('cuda')

    print("No TPU or GPU detected. Using CPU")
    return torch.device('cpu')

# Detect device
device = get_device()
print("Using device:", device)

In [None]:
#from simon_arc_lab.rle.deserialize import deserialize, DecodeRLEError

#image = deserialize('1 1 5')
#print(image)

In [None]:
from tqdm.notebook import tqdm
import numpy as np
from simon_arc_lab.rle.deserialize import deserialize, DeserializeError
from simon_arc_lab.image_pixel_similarity import image_pixel_similarity_jaccard_index
from simon_arc_lab.image_transition_similarity import image_transition_similarity

class CustomPPOTrainer(PPOTrainer):
    """PPOTrainer subclass that records a step budget and defines AdamW / warm-up helpers.

    NOTE(review): nothing in this file calls create_optimizer() or
    create_scheduler(). Confirm that the PPOTrainer base class invokes them;
    otherwise the AdamW and warm-up settings below never take effect.
    """

    def __init__(self, *args, total_training_steps=None, **kwargs):
        """Build the base trainer, then store the total number of training steps.

        Args:
            *args, **kwargs: forwarded unchanged to PPOTrainer.__init__.
            total_training_steps: step count used by create_scheduler(). When
                None it is computed as self.config.epochs * len(self.dataloader).
        """
        super().__init__(*args, **kwargs)
        if total_training_steps is None:
            # If total_training_steps is not provided, calculate it
            # NOTE(review): this branch needs the config to expose `epochs` and the
            # trainer to own a dataloader - confirm both hold for the way it is built.
            total_training_steps = self.config.epochs * len(self.dataloader)  # Adjust as needed
        self.total_training_steps = total_training_steps

    def create_optimizer(self):
        """Set self.optimizer to AdamW over all model parameters, with the learning rate taken from the config."""
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-6,
            weight_decay=0.01,
        )

    def create_scheduler(self):
        """Set self.scheduler to a linear schedule with warm-up over total_training_steps.

        Reads self.optimizer, so create_optimizer() has to run first.
        NOTE(review): the result is stored as self.scheduler - verify that this is
        the attribute the base class actually steps.
        """
        num_warmup_steps = int(0.1 * self.total_training_steps)  # 10% warm-up
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=self.total_training_steps,
        )

# Number of passes over train_dataloader; the training loop below performs at most
# one PPO step per batch, so this product is the upper bound on PPO steps.
num_epochs = 4
total_training_steps = num_epochs * len(train_dataloader)

# Load the base checkpoint wrapped with a value head, as used by the PPO trainer.
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(base_model_path)

# Ensure model is on the correct device
model = model.to(device)

ppo_config = PPOConfig(
    model_name=base_model_path,
    learning_rate=my_learning_rate,
    # Same value as the batch_size of train_dataloader (64); that loader drops
    # incomplete batches, so every PPO step receives exactly this many samples.
    batch_size=64,
    # mini_batch_size * gradient_accumulation_steps = 8 * 8 = 64 = batch_size.
    mini_batch_size=8,
    gradient_accumulation_steps=8,
    max_grad_norm=0.5,
    cliprange=0.05,
    init_kl_coef=0.1,
    target=0.01,
    #adap_kl_ctrl=True,
    # log_with='wandb',  # Uncomment if using logging
)

# No dataset is handed to the trainer; batches come from train_dataloader in the
# training loop, and the step budget is passed in explicitly.
ppo_trainer = CustomPPOTrainer(
    config=ppo_config,
    model=model,
    tokenizer=tokenizer,
    total_training_steps=total_training_steps,
)

def compute_jaccard_index(predicted_str, expected_str, row_id) -> int:
    task_id = f'row_{row_id}'
    #return 0.5
    #print(f"compute_jaccard_index for task_id: {task_id}")
    #print("predicted_str", predicted_str)
    #print("expected_str", expected_str)
    if task_id not in all_tasks:
      raise ValueError(f"Task ID {task_id} not found in all_tasks.")
    task = all_tasks[task_id]

    if task_id not in task_id_to_task_similarity:
      raise ValueError(f"Task ID {task_id} not found in task_id_to_task_similarity.")
    task_similarity = task_id_to_task_similarity[task_id]

    if task_id not in all_test_indexes:
      raise ValueError(f"Task ID {task_id} not found in all_test_indexes.")
    test_index = all_test_indexes[task_id]

    try:
      image = deserialize(predicted_str)
    except DeserializeError as e:
      #print(f"Error decoding RLE string: {e}")
      return -1.0 + (e.score / 100.0) * 0.5
      #return -1.0
    score1 = task_similarity.measure_test_prediction(image, test_index) / 100.0

    expected_output = task.test_output(test_index)
    score2 = image_pixel_similarity_jaccard_index(expected_output, image) / 100.0
    score3_intersection, score3_union = image_transition_similarity(expected_output, image)
    if score3_union > 0:
      score3 = score3_intersection / score3_union
    else:
      score3 = 0
    #print(f"task_id: {task_id} test_index: {test_index} score: {score1} {score2} {score3}")
    score = (1.0 + score1 + score2 + score3) / 4.0
    return score

def compute_rewards(samples, expected_outputs, row_ids):
    """Score every generated sample against its expected output.

    The three sequences are walked in lockstep (stopping at the shortest), and
    each triple is handed to compute_jaccard_index.

    Returns:
        A list with one reward per triple, in input order.
    """
    return [
        compute_jaccard_index(sample, expected, row_id)
        for sample, expected, row_id in zip(samples, expected_outputs, row_ids)
    ]

def reward_list_to_str(reward_list) -> str:
  """Render rewards as space-separated integers: each value times 100, truncated toward zero."""
  return ' '.join(str(int(reward * 100)) for reward in reward_list)

# Sampling settings passed to model.generate() for every training batch.
generate_kwargs = dict(
    max_new_tokens=128,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    temperature=0.9,
)

# PPO training loop: for every batch, sample responses, score them, and run one
# PPO step with those scores as rewards.
for epoch in range(num_epochs):
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # 'id' is the row index used to look up the task when scoring.
        task_ids = batch['id']
        query_tensors = batch['input_ids']
        expected_output_labels = batch['labels']

        # Sample one response per prompt on the accelerator, then bring the
        # result back to the CPU for decoding and for the PPO step.
        # NOTE(review): the prompts are padded to max_input_length and generate()
        # receives no attention mask here - confirm this is intended.
        query_tensors_ondevice = query_tensors.to(device)
        response_tensors_ondevice = model.generate(
            query_tensors_ondevice,
            **generate_kwargs,
        )
        response_tensors = response_tensors_ondevice.cpu()

        # Decode the reference answers (labels without -100 / padding ids).
        cleaned_expected_output_labels = clean_labels(expected_output_labels.tolist())
        expected_outputs = tokenizer.batch_decode(cleaned_expected_output_labels, skip_special_tokens=True)

        # Decode the sampled responses and score each one.
        batch_samples = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
        reward_list = compute_rewards(batch_samples, expected_outputs, task_ids)
        print("rewards", reward_list_to_str(reward_list))

        # Skip the batch entirely when every reward is negative. In
        # compute_jaccard_index the negative values come from the branch where
        # the response could not be deserialized.
        np_reward_list = np.array(reward_list)
        if np.all(np_reward_list < 0.0):
          print("skip PPO step as it won't result in any parameter updates")
          continue

        # Limit the number of strongly negative rewards per batch: the first
        # three rewards below -0.5 are kept as they are, every further one is
        # replaced by 0.0.
        number_of_negatives_allowed = 3
        count_negative = 0
        for reward_index, reward in enumerate(reward_list):
          if reward < -0.5:
            count_negative += 1
            if count_negative > number_of_negatives_allowed:
              reward_list[reward_index] = 0.0

        # Disabled experiment: reward normalisation (kept for reference).
        #min_reward = np.min(reward_list)
        #max_reward = np.max(reward_list)
        #mean_reward = np.mean(reward_list)
        #std_reward = np.std(reward_list) + 1e-8
        #normalized_reward_list = (reward_list - mean_reward) / std_reward
        #normalized_reward_list = np.clip(normalized_reward_list, -1.0, 1.0)
        #normalized_reward_list = 2 * (reward_list - min_reward) / (max_reward - min_reward + 1e-8) - 1
        #print("rewards normalized", reward_list_to_str(normalized_reward_list))

        # One scalar float32 tensor per sample.
        rewards = [torch.tensor(reward, dtype=torch.float32) for reward in reward_list]
        #rewards = [torch.tensor(reward, dtype=torch.float32) for reward in normalized_reward_list]
        #rewards = [torch.tensor(0.0, dtype=torch.float32) for reward in reward_list]

        # Convert query_tensors and response_tensors to lists of tensors
        # (one tensor per sample, split along the batch dimension).
        query_tensors_list = list(torch.unbind(query_tensors, dim=0))  # Convert tuple to list
        response_tensors_list = list(torch.unbind(response_tensors, dim=0))  # Convert tuple to list

        # The returned stats are not used afterwards in this cell.
        stats = ppo_trainer.step(query_tensors_list, response_tensors_list, rewards)



In [None]:
# Write the trained model and the tokenizer into the same folder; the upload
# cell below publishes that folder.
save_directory = "mymodel" # save in the current working directory, you can change this of course
model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)

In [None]:
from huggingface_hub import HfApi, HfFolder

# Set your repository details: the folder written by the save cell is what gets uploaded.
model_path = save_directory

# Create a repository if it doesn't exist (exist_ok=True makes re-runs harmless);
# the repository is created as private.
api = HfApi()
username = api.whoami()['name']
# NOTE(review): create_repo receives the bare model name while upload_folder below
# uses "username/name" - presumably both resolve to the same repository under the
# logged-in account; confirm. repo_url is not used afterwards in this cell.
repo_url = api.create_repo(repo_id=result_model_name, exist_ok=True, private=True)

# Upload files to the repository
from huggingface_hub import upload_folder
upload_folder(
    folder_path=model_path,
    repo_id=f"{username}/{result_model_name}",
    commit_message="Initial model upload"
)