## Fine-tuning of a Large Language Model on AugARC Data

### Training script for any open-source LLM on the ARC Augmented Training Data with 2000 tasks

In [None]:
import transformers
import torch
from datasets import  Dataset
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTTrainer

### Load LLM and Tokenizer

In [None]:
# Hub id of the base checkpoint that will be fine-tuned.
model_name = "meta-llama/Meta-Llama-3-70B"

# 4-bit NF4 quantization with float16 compute.
# The keyword names must match BitsAndBytesConfig exactly: a misspelled or
# unknown keyword is not applied and the library default is used instead.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the quantized base model and let `device_map="auto"` decide placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_safetensors=True,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config
)

In [None]:
# Padding and truncation are per-call tokenization options, not loader
# options, so only the maximum sequence length is configured at load time.
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=4096)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Turn on gradient checkpointing on the loaded model.
model.gradient_checkpointing_enable()
# Run PEFT's preparation step for training on a k-bit (quantized) model;
# the returned model replaces the original reference.
model = prepare_model_for_kbit_training(model)

### Load LoRA Adapter

In [None]:
# LoRA adapter settings: rank 32, alpha 16, bias parameters left untouched.
# `task_type` has to be the exact PEFT identifier "CAUSAL_LM"; a misspelled
# value is not recognised as the causal language-modelling task.
config = LoraConfig(
    r=32,
    lora_alpha=16,
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:
# Wrap the prepared base model with the LoRA adapter described by `config`.
model=get_peft_model(model, config)

### Dataset preparation

In [None]:
import os
import json

def read_json_files(directory):
    """Load every ``.json`` file located directly inside *directory*.

    Files are visited in sorted filename order so that the returned list is
    reproducible between runs (``os.listdir`` gives no ordering guarantee).
    A file that cannot be parsed is reported on stdout and skipped.

    Args:
        directory: Path of the folder to scan; sub-folders are not entered.

    Returns:
        A list holding the decoded content of each successfully parsed file.
    """
    json_data = []

    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.json'):
            continue

        file_path = os.path.join(directory, filename)
        # Read as UTF-8 explicitly instead of relying on the locale default.
        with open(file_path, 'r', encoding='utf-8') as file:
            try:
                json_data.append(json.load(file))
            except json.JSONDecodeError as e:
                print(f"Error reading {filename}: {e}")

    return json_data

# Folder (relative to the working directory) that is scanned for task files.
directory = 'arc_data/training'
# Decoded content of every readable .json file in that folder.
all_json_data = read_json_files(directory)

In [None]:
def transform_input(data):
    """Render the 'train' and 'test' cases of a task as prompt text.

    Args:
        data: Mapping with 'train' and 'test' entries, each a sequence of
            cases that carry an 'input' grid and an 'output' grid.

    Returns:
        A dict with the keys 'train' and 'test'. Each value is the
        concatenation, over the cases of that split, of an ``###Input:``
        section followed by an ``###Output:`` section.
    """
    def render(grid):
        # One text line per non-empty row; cells separated by single spaces.
        return ''.join(
            ' '.join(str(cell) for cell in line) + '\n'
            for line in grid
            if len(line) > 0
        )

    rendered = {}
    for split in ('train', 'test'):
        rendered[split] = ''.join(
            f"\n###Input:\n{render(pair['input'])}\n###Output:\n{render(pair['output'])}"
            for pair in data[split]
        )
    return rendered

In [None]:
def extract_after_output(text):
    """Return what follows the first ``###Output:`` line marker in *text*.

    When the marker does not occur, *text* is returned unchanged.
    """
    _, marker, remainder = text.partition('###Output:\n')
    # `marker` is empty exactly when the separator was not found.
    return remainder if marker else text

def extract_before_output(text):
    """Return what precedes the first ``###Output:`` line marker in *text*.

    When the marker does not occur, *text* is returned unchanged.
    """
    # partition() yields the whole string as the head when nothing matches.
    head, _, _ = text.partition('###Output:\n')
    return head

In [None]:
# Instruction text placed at the start of every training prompt.
DEFAULT_PROMPT = "We are playing a game which involves transforming a 2D input grid of digits into an output grid of digits. Every below pair of grids contains the same transformation. Each Input grid is followed by an Output grid which applies the same transformation as previous Input/Output pairs. Given the provided examples, output the correct grid for the last input"

def generate_train_prompt(data_point):
    """Assemble one training record from a rendered task.

    Args:
        data_point: Mapping with the rendered 'train' and 'test' strings.

    Returns:
        A dict with 'text' (instruction, demonstrations and test section
        joined by newlines) and 'labels' (the stripped text that follows
        the first output marker of the test section).
    """
    demos, query = data_point['train'], data_point['test']
    return {
        'text': f'{DEFAULT_PROMPT}\n{demos}\n{query}',
        'labels': extract_after_output(query).strip(),
    }

In [None]:
def flip_2d_list(matrix, flip_type):
    """Mirror a 2D list along one axis.

    Args:
        matrix: Sequence of rows.
        flip_type: 'horizontal' reverses the cells of every row;
            'vertical' reverses the order of the rows.

    Returns:
        The mirrored grid; the argument itself is not modified.

    Raises:
        ValueError: If *flip_type* is neither 'horizontal' nor 'vertical'.
    """
    if flip_type not in ('horizontal', 'vertical'):
        raise ValueError("Invalid flip type. Use 'horizontal' or 'vertical'.")
    if flip_type == 'vertical':
        return matrix[::-1]
    return [cells[::-1] for cells in matrix]

def rotate_matrix_90_degrees(matrix):
    """Return *matrix* rotated a quarter turn clockwise, as a list of lists."""
    # Reading the columns of the bottom-to-top grid gives the rotated rows.
    bottom_up = matrix[::-1]
    return [list(column) for column in zip(*bottom_up)]

def rotate_matrix_270_degrees(matrix):
    """Return *matrix* rotated a quarter turn counter-clockwise, as a list of lists."""
    # Transpose first, then flip the row order.
    transposed = [list(column) for column in zip(*matrix)]
    transposed.reverse()
    return transposed


In [None]:
def transform_input_horizontal(data):
    """Render a task as prompt text with every grid flipped horizontally.

    Args:
        data: Mapping with 'train' and 'test' entries, each a sequence of
            cases that carry an 'input' grid and an 'output' grid.

    Returns:
        A dict with the keys 'train' and 'test'; each value concatenates
        one ``###Input:`` / ``###Output:`` section pair per case.
    """
    def as_text(grid):
        lines = []
        for cells in grid:
            if len(cells) == 0:
                continue  # an empty row contributes no text at all
            lines.append(' '.join(map(str, cells)) + '\n')
        return ''.join(lines)

    result = {}
    for split in ['train', 'test']:
        chunks = []
        for example in data[split]:
            grid_in = flip_2d_list(example['input'], 'horizontal')
            grid_out = flip_2d_list(example['output'], 'horizontal')
            chunks.append(
                f'\n###Input:\n{as_text(grid_in)}\n###Output:\n{as_text(grid_out)}'
            )
        result[split] = ''.join(chunks)
    return result


In [None]:
def transform_input_vertical(data):
    """Render a task as prompt text with every grid flipped vertically.

    Args:
        data: Mapping with 'train' and 'test' entries, each a sequence of
            cases that carry an 'input' grid and an 'output' grid.

    Returns:
        A dict with the keys 'train' and 'test'; each value concatenates
        one ``###Input:`` / ``###Output:`` section pair per case.
    """
    def as_text(grid):
        # Rows become space-separated lines; empty rows produce nothing.
        return ''.join(
            ' '.join(str(value) for value in cells) + '\n'
            for cells in grid
            if len(cells) > 0
        )

    result = {}
    for split in ['train', 'test']:
        sections = ''
        for example in data[split]:
            flipped_in = flip_2d_list(example['input'], 'vertical')
            flipped_out = flip_2d_list(example['output'], 'vertical')
            text_in = as_text(flipped_in)
            text_out = as_text(flipped_out)
            # Append this case's section pair to the split's text.
            sections += f'\n###Input:\n{text_in}\n###Output:\n{text_out}'
        result[split] = sections
    return result


In [None]:
def transform_input_270(data):
    """Render a task as prompt text with every grid rotated by 270 degrees.

    Args:
        data: Mapping with 'train' and 'test' entries, each a sequence of
            cases that carry an 'input' grid and an 'output' grid.

    Returns:
        A dict with the keys 'train' and 'test'; each value concatenates
        one ``###Input:`` / ``###Output:`` section pair per case.
    """
    def as_text(grid):
        pieces = []
        for cells in grid:
            if len(cells) > 0:
                # Space-separated cells, newline-terminated.
                pieces.append(' '.join(map(str, cells)))
                pieces.append('\n')
        return ''.join(pieces)

    def section(example):
        turned_in = rotate_matrix_270_degrees(example['input'])
        turned_out = rotate_matrix_270_degrees(example['output'])
        return f'\n###Input:\n{as_text(turned_in)}\n###Output:\n{as_text(turned_out)}'

    result = {}
    for split in ['train', 'test']:
        result[split] = ''.join(section(example) for example in data[split])
    return result

In [None]:
def transform_input_90(data):
    """Render a task as prompt text with every grid rotated by 90 degrees.

    Args:
        data: Mapping with 'train' and 'test' entries, each a sequence of
            cases that carry an 'input' grid and an 'output' grid.

    Returns:
        A dict with the keys 'train' and 'test'; each value concatenates
        one ``###Input:`` / ``###Output:`` section pair per case.
    """
    def as_text(grid):
        # Skip empty rows: they add neither cells nor a line break.
        rows = (cells for cells in grid if len(cells) > 0)
        return ''.join(' '.join(map(str, cells)) + '\n' for cells in rows)

    result = {}
    for split in ['train', 'test']:
        parts = []
        for example in data[split]:
            turned_in = rotate_matrix_90_degrees(example['input'])
            turned_out = rotate_matrix_90_degrees(example['output'])
            parts.append(
                f'\n###Input:\n{as_text(turned_in)}\n###Output:\n{as_text(turned_out)}'
            )
        result[split] = ''.join(parts)
    return result


In [None]:
# Each task yields five training records, in this order: the untouched
# rendering, the two rotations, then the vertical and horizontal flips.
augmentations = (
    transform_input,
    transform_input_90,
    transform_input_270,
    transform_input_vertical,
    transform_input_horizontal,
)

train_data = []
for task in all_json_data:
    train_data.extend(generate_train_prompt(view(task)) for view in augmentations)

In [None]:
# Turn the list of record dicts into a `datasets.Dataset` for the trainer.
train_dataset = Dataset.from_list(train_data)

### Training

In [3]:
# Configure hyperparameters
# batch_size: per-device training batch size.
batch_size=8
# steps: number of gradient accumulation steps (not the total step count).
steps=64
# lr: learning rate handed to the trainer.
lr=0.0004
# epochs: number of training epochs; also used in the saved model's name.
epochs=24

In [None]:
# Training configuration: fp16 training, cosine learning-rate schedule with
# a 5% warm-up, paged 8-bit AdamW, and one checkpoint plus one log entry
# per epoch written under ./experiments.
# NOTE(review): `remove_unused_columns=False` together with
# `label_names=['labels']` keeps the string-valued 'labels' column of the
# dataset around — confirm the installed TRL/Transformers versions handle
# that column as intended.
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=steps,
    learning_rate=lr,
    fp16=True,
    num_train_epochs=epochs,
    save_strategy="epoch",
    save_safetensors=True,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    output_dir='./experiments',
    remove_unused_columns=False,
    warmup_ratio=0.05,
    logging_strategy='epoch',
    label_names=['labels'],
    group_by_length=True
)

# Supervised fine-tuning on the 'text' field, capped at 4096 tokens.
# NOTE(review): `model` has already been wrapped by get_peft_model and the
# same LoRA config is passed again as `peft_config` — confirm the installed
# TRL version accepts both together.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    tokenizer=tokenizer,
    dataset_text_field='text',
    peft_config=config,
    max_seq_length=4096
)

In [None]:
# The generation cache is not needed for training and does not combine with
# gradient checkpointing, so switch it off before starting.
model.config.use_cache = False
# Logged metrics accumulate in `trainer.state.log_history`, a list that the
# Trainer maintains itself; it needs no manual initialisation.
trainer.train()

### Save the fine-tuned model

In [None]:
# Save under a name that records the base model and the epoch count.
model.save_pretrained(f'{model_name}_{epochs}_epochs_augmented_2000_lr_0_0004')