# LLM fine-tuning

## Goal

Fine-tune an LLM to learn to count objects in a grid, or to solve ARC tasks.

I might do 2 steps of fine-tuning:

1. Learning priors, e.g. learning to count
2. Solve ARC tasks

## References

- https://github.com/ironbar/prompt_recovery/blob/main/notebooks/012_fine-tune_llama.ipynb
- https://github.com/ironbar/prompt_recovery/blob/main/notebooks/020_fine-tune_final_ensemble.ipynb
- https://www.kaggle.com/code/ironbar/few-shot-prompting-for-arc24

## Imports

In [2]:
# import os
# import json
# from abc import ABC, abstractmethod
# import numpy as np
# from termcolor import colored
# from tqdm.auto import tqdm
# import matplotlib.pyplot as plt
# import matplotlib as mpl
# from matplotlib import colors
# import wandb
from typing import Optional

# import torch
# from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, pipeline
# from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
# from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
# from datasets import Dataset

# plt.plot()
# plt.close('all')
# plt.rcParams["figure.figsize"] = (20, 5)
# mpl.rcParams['lines.linewidth'] = 3
# mpl.rcParams['font.size'] = 16

## Configuration

In [3]:
class cfg:
    # Run configuration for the Phi-3 experiment. Every attribute that does not start
    # with '__' is dumped to cfg.json in output_dir.

    # Base model weights
    model_path = "/home/gbarbadillo/data/Phi-3-mini-128k-instruct"
    adapter_path: Optional[str] = '/mnt/hdd0/Kaggle/arc24/models/20240724_first_trainings/15_continue_training_phi3_4e5/checkpoint-22800' # Set it to None to train lora from scratch
    # LoRA variants, read only when a new adapter is created (adapter_path is None).
    # Declared here so that setting adapter_path to None does not fail with AttributeError.
    use_rslora = False
    use_dora = False
    # Datasets
    train_dataset = '/mnt/hdd0/Kaggle/arc24/data/learn_to_count/learn_to_count_100000.json'
    val_dataset = '/mnt/hdd0/Kaggle/arc24/data/learn_to_count/learn_to_count_1000.json'
    # Checkpoints, cfg.json and wandb files are written here
    output_dir = '/mnt/hdd0/Kaggle/arc24/models/20240724_first_trainings/22_random_question_lr1e-4_1e5dataset'
    # Prompts with this many tokens or more are dropped when building the datasets
    max_seq_len = 512
    # Training schedule
    epochs = 1
    eval_steps = 100
    warmup_ratio = 0.1
    learning_rate = 1e-4

In [None]:
class cfg:
    # Run configuration for the Llama 3.1 experiment. Every attribute that does not start
    # with '__' is dumped to cfg.json in output_dir.

    # model_path = "/home/gbarbadillo/data/llama-3-transformers-8b-chat-hf-v1"
    model_path = '/home/gbarbadillo/data/llama-3.1-transformers-8b-instruct-v1'
    # None trains a new LoRA adapter from scratch; a path resumes from that checkpoint
    adapter_path: Optional[str] = None
    # adapter_path: Optional[str] = '/mnt/hdd0/Kaggle/arc24/models/20240724_first_trainings/12_llama_lr_2e-4_1e4dataset_r32/checkpoint-600'
    # Plain booleans: a trailing comma here would create the tuple (False,), which is truthy
    use_rslora = False
    use_dora = False
    train_dataset = '/mnt/hdd0/Kaggle/arc24/data/learn_to_count/learn_to_count_100000.json'
    val_dataset = '/mnt/hdd0/Kaggle/arc24/data/learn_to_count/learn_to_count_1000.json'
    # NOTE(review): the directory name says lr1e-4 but learning_rate below is 2e-4 - confirm which is intended
    output_dir = '/mnt/hdd0/Kaggle/arc24/models/20240724_first_trainings/22_llama31_lr1e-4_1e5dataset_r32'
    # Prompts with this many tokens or more are dropped when building the datasets
    max_seq_len = 640
    epochs = 1
    eval_steps = 100
    warmup_ratio = 0.1
    learning_rate = 2e-4

In [None]:
# Store a JSON snapshot of the configuration next to the training outputs.
# Dunder entries of the class namespace (e.g. __module__, __annotations__) are left out.
os.makedirs(cfg.output_dir, exist_ok=True)
with open(os.path.join(cfg.output_dir, 'cfg.json'), 'w') as f:
    json.dump(
        {name: setting for name, setting in vars(cfg).items() if not name.startswith('__')},
        f,
        indent=4,
    )

## Model

In [None]:
if 'llama' in cfg.model_path:
    # Manual placement of the 32 decoder layers over two devices: the embeddings and
    # layers 0-16 go to device 0; layers 17-31, the final norm, the rotary embeddings
    # and the LM head go to device 1.
    device_map = {
        'model.embed_tokens': 0,
        **{f'model.layers.{layer_idx}': 0 if layer_idx <= 16 else 1 for layer_idx in range(32)},
        'model.norm': 1,
        'model.rotary_emb': 1,
        'lm_head': 1,
    }
else:
    device_map = 'balanced'

# device_map = 'balanced'

# Load the base causal LM from cfg.model_path in bfloat16 with the FlashAttention-2
# attention implementation, placing its modules according to device_map.
model = AutoModelForCausalLM.from_pretrained(
    cfg.model_path,
    #quantization_config=bnb_config,
    device_map=device_map,
    # max_memory={0: '9GB', 1: '8GB'},
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    )

In [None]:
# Load the tokenizer that matches the base model.
tokenizer = AutoTokenizer.from_pretrained(
    cfg.model_path,
    trust_remote_code=True)
if 'llama' in cfg.model_path:
    # For Llama models register a dedicated <|pad|> token, resize the embedding matrix
    # to the new vocabulary size and pad on the right.
    print('Adding <|pad|> token to tokenizer')
    tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
    model.resize_token_embeddings(len(tokenizer))
    tokenizer.padding_side = 'right'
# Displayed as the cell output to check the special tokens
tokenizer.special_tokens_map

In [None]:
def print_gpu_memory():
    """Print the current and peak allocated memory of every visible CUDA device.

    Values are reported in GB, computed as bytes / 1024**3.
    """
    bytes_per_gb = 1024**3
    n_devices = torch.cuda.device_count()
    for device in range(n_devices):
        allocated = torch.cuda.memory_allocated(device)/bytes_per_gb
        peak = torch.cuda.max_memory_allocated(device)/bytes_per_gb
        print(f'GPU {device} memory allocated: {allocated:.1f} GB, max memory allocated: {peak:.1f} GB')
print_gpu_memory()

## Data

### Grid encoders

In [None]:
class GridEncoder(ABC):
    """Interface for converting a grid to text and back.

    Subclasses must implement both directions.
    """

    @abstractmethod
    def to_text(self, grid):
        """Return the text representation of `grid`."""

    @abstractmethod
    def to_grid(self, text):
        """Return the grid parsed from `text`."""

In [None]:
# 3x3 identity matrix as nested lists, used as a small fixture for the encoder checks
sample_grid = np.eye(3, dtype=int).tolist()

def test_translator(translator):
    """Check that `translator` round-trips `sample_grid` and print its text encoding.

    Args:
        translator: object exposing `to_text(grid)` and `to_grid(text)`.

    Raises:
        AssertionError: if decoding the encoded sample grid does not give it back.
    """
    assert sample_grid == translator.to_grid(translator.to_text(sample_grid))
    print(translator.to_text(sample_grid))

In [None]:
class MinimalGridEncoder(GridEncoder):
    """Encode a grid as one line per row, with the cell values written back to back."""

    @staticmethod
    def to_text(grid):
        """Return `grid` as newline-separated rows of concatenated cell values."""
        rows = (''.join(map(str, row)) for row in grid)
        return '\n'.join(rows)

    @staticmethod
    def to_grid(text):
        """Parse `text` back into a list of rows, reading every character as one int cell."""
        return [[int(char) for char in row] for row in text.strip().splitlines()]

test_translator(MinimalGridEncoder())

In [None]:
class GridWithSeparationEncoder(GridEncoder):
    """Encode a grid as one line per row, with `split_symbol` between the cell values."""

    def __init__(self, split_symbol):
        self.split_symbol = split_symbol

    def to_text(self, grid):
        """Return `grid` as newline-separated rows whose cells are joined by the split symbol."""
        rows = [self.split_symbol.join(map(str, row)) for row in grid]
        return '\n'.join(rows)

    def to_grid(self, text):
        """Parse `text` back into a list of rows of ints, splitting each line on the split symbol."""
        grid = []
        for row in text.strip().splitlines():
            grid.append([int(cell) for cell in row.split(self.split_symbol)])
        return grid

test_translator(GridWithSeparationEncoder('|'))

In [None]:
class GridCodeBlockEncoder(GridEncoder):
    """Wrap the output of another encoder in a ```grid fenced block."""

    def __init__(self, base_encoder):
        # Encoder that produces/parses the text placed inside the fence
        self.encoder = base_encoder

    def to_text(self, grid):
        """Return the base encoding of `grid` surrounded by a ```grid fence."""
        body = self.encoder.to_text(grid)
        return f'```grid\n{body}\n```'

    def to_grid(self, text):
        """Extract the first ```grid block from `text` and decode it with the base encoder.

        Raises IndexError when `text` contains no opening fence.
        """
        after_opening_fence = text.split('```grid\n')[1]
        grid_text = after_opening_fence.split('\n```')[0]
        return self.encoder.to_grid(grid_text)

test_translator(GridCodeBlockEncoder(MinimalGridEncoder()))

test_translator(GridCodeBlockEncoder(GridWithSeparationEncoder('|')))

### Format data

In [None]:
def create_dataset(filepath, grid_encoder, shuffle_question_order=True):
    """Build a text dataset of chat-formatted prompts from a JSON file of samples.

    Each sample is turned into a conversation, rendered with the tokenizer chat
    template, and the prompts are shuffled. One prompt is printed and a histogram of
    the prompt lengths is shown. Only prompts with fewer than cfg.max_seq_len tokens
    are kept.

    Args:
        filepath: path to a JSON file holding a dict of samples.
        grid_encoder: encoder used to write the grid of each sample as text.
        shuffle_question_order: whether the questions of each sample are shuffled.

    Returns:
        A Dataset with a single 'text' column.
    """
    with open(filepath, 'r') as f:
        data = json.load(f)

    prompts = []
    for sample in tqdm(data.values(), total=len(data)):
        messages = create_messages_from_sample(
            sample, grid_encoder, shuffle_question_order=shuffle_question_order)
        prompts.append(tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False))

    np.random.shuffle(prompts)
    pretty_print_prompt(prompts[0])

    # Token count per prompt, used both for the histogram and for the length filter
    prompt_lengths = [len(tokenizer.encode(prompt)) for prompt in tqdm(prompts)]
    plt.hist(prompt_lengths, bins=100)
    plt.title('Prompt length distribution')
    plt.xlabel('Number of tokens')
    plt.show()

    kept_prompts = [
        prompt for prompt, n_tokens in zip(prompts, prompt_lengths)
        if n_tokens < cfg.max_seq_len
    ]
    print(f'Using {len(kept_prompts)} prompts after removing those longer than {cfg.max_seq_len} tokens')

    return Dataset.from_dict({'text': kept_prompts})


def create_messages_from_sample(sample, grid_encoder, shuffle_question_order=False):
    """Turn a sample into a chat conversation of question/answer turns.

    The conversation starts with a system message. The encoded grid is prepended to
    the first question only; the following user turns contain just the question.
    Answers are converted to str.

    Args:
        sample: dict with a 'grid' and a 'questions' dict mapping question -> answer.
        grid_encoder: object whose `to_text(grid)` gives the grid as text.
        shuffle_question_order: shuffle the questions with np.random when True.

    Returns:
        List of {'role', 'content'} message dicts.
    """
    messages = [{'role': 'system', 'content': 'You are a helpful AI assistant'}]
    questions = list(sample['questions'].keys())
    if shuffle_question_order:
        np.random.shuffle(questions)
    for turn_idx, question in enumerate(questions):
        if turn_idx == 0:
            user_content = grid_encoder.to_text(sample['grid']) + '\n' + question
        else:
            user_content = question
        messages.extend([
            {'role': 'user', 'content': user_content},
            {'role': 'assistant', 'content': str(sample['questions'][question])},
        ])
    return messages


def pretty_print_prompt(text, default_color='white'):
    """Print a prompt line by line, colored by the role that is speaking.

    A line starting with <|assistant|>, <|user|> or <|system|> switches the color to
    blue, `default_color` or green respectively, and that color stays active for the
    following lines. Lines starting with '<' are printed in bold.
    """
    role_colors = {
        '<|assistant|>': 'blue',
        '<|user|>': default_color,
        '<|system|>': 'green',
    }
    color = default_color
    for line in text.splitlines():
        for role_tag, role_color in role_colors.items():
            if line.startswith(role_tag):
                color = role_color
                break
        attrs = ['bold'] if line.startswith('<') else None
        print(colored(line, color, attrs=attrs))

In [None]:
# For Llama we need to add separation between numbers in the grid; otherwise the
# minimal encoding is used. Both are wrapped in a ```grid code block.
grid_encoder = GridCodeBlockEncoder(
    GridWithSeparationEncoder('|') if 'llama' in cfg.model_path else MinimalGridEncoder()
)
train_dataset = create_dataset(cfg.train_dataset, grid_encoder, shuffle_question_order=True)

In [None]:
# Build the validation dataset with the same encoder; question order is shuffled here as well
val_dataset = create_dataset(cfg.val_dataset, grid_encoder, shuffle_question_order=True)

## Train

In [None]:
# NOTE(review): a bare `raise` with no active exception fails with RuntimeError.
# Presumably a deliberate stop so that running all cells does not start training - confirm.
raise

In [None]:
if cfg.adapter_path is not None:
    # Continue training an existing adapter: it is attached to the model here, so no
    # peft_config is handed to the trainer.
    print(f'Loading adapter from {cfg.adapter_path}')
    peft_config = None
    model = PeftModel.from_pretrained(model, cfg.adapter_path, is_trainable=True)
else:
    # Train a new LoRA adapter from scratch.
    peft_config = LoraConfig(
        # r: the rank of the update matrices, expressed in int. Lower rank results in
        # smaller update matrices with fewer trainable parameters.
        r=32, #16
        # lora_alpha: LoRA scaling factor.
        lora_alpha=64, #64,
        lora_dropout=0.1, # 0.1, although Vaca suggested to use 0.05 for big models
        bias="none",
        task_type="CAUSAL_LM",
        # target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
        target_modules=['k_proj', 'q_proj', 'v_proj', 'o_proj'],
        # Both flags are expected to be plain booleans
        use_rslora=cfg.use_rslora,
        use_dora=cfg.use_dora,
    )

In [None]:
# Per-model batch settings. Effective train batch per device:
# Llama 3 x 5 accumulation steps = 15, otherwise 8 x 2 = 16.
if 'llama' in cfg.model_path:
    batch_size_kwargs = {
        'per_device_train_batch_size': 3, # 4-16 should be fine for lora.
        'gradient_accumulation_steps': 5,
        'per_device_eval_batch_size': 4,
    }
else:
    batch_size_kwargs = {
        'per_device_train_batch_size': 8, # 4-16 should be fine for lora.
        'gradient_accumulation_steps': 2,
        'per_device_eval_batch_size': 8,
    }

# Trainer settings: schedule and optimizer come from cfg, evaluation and checkpointing
# both happen every cfg.eval_steps steps, and the batch sizes depend on the model.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` - confirm against the installed version.
training_arguments = TrainingArguments(
        output_dir=cfg.output_dir,
        num_train_epochs=cfg.epochs,
        warmup_ratio=cfg.warmup_ratio,
        learning_rate=cfg.learning_rate,
        lr_scheduler_type="linear",
        optim="paged_adamw_8bit",

        do_eval=True,
        evaluation_strategy="steps",
        save_steps=cfg.eval_steps,  # save a checkpoint at the same cadence as evaluation
        logging_steps=10, #50,
        eval_steps=cfg.eval_steps,
        log_level="debug",

        **batch_size_kwargs
)

In [None]:
# Completion-only collator: the instruction/response templates are the markers that
# open a user turn and an assistant turn in each model's chat format.
if 'llama' in cfg.model_path:
    # Llama chat format headers
    data_collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer,
        instruction_template='<|start_header_id|>user<|end_header_id|>',
        response_template='<|start_header_id|>assistant<|end_header_id|>',
    )
else:
    # Phi-3 chat format tags
    data_collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer,
        instruction_template='<|user|>',
        response_template='<|assistant|>'
    )

In [None]:
# Start a W&B run named after the output folder; the project is the parent folder name.
w = wandb.init(reinit=True,
               dir=cfg.output_dir,
               project=os.path.basename(os.path.dirname(cfg.output_dir)),
               name=os.path.basename(cfg.output_dir))
try:
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=cfg.max_seq_len,
        data_collator=data_collator,
        args=training_arguments,
        # packing=True, # ValueError: You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument.
    )

    trainer.train()
except BaseException:
    # Also covers KeyboardInterrupt: close the run as failed instead of leaving it
    # open, then let the original error propagate.
    w.finish(exit_code=1)
    raise
else:
    w.finish()

- Evaluation for 1k samples is taking around 1 minute.
- One epoch has taken around 30 minutes, with around 11 evaluations. Thus 1/3 of the time was spent evaluating. (550 steps, eval every 50 steps)
- Training for 10 epochs took 5 hours. I set the temperature of the AC to 27ºC and the room was at 22ºC, probably 28ºC is fine.
- Eval loss was 0.1665, without signs of overfit

300 steps without flash attention: 16:17 minutes, with flash attention 14 min

25 minutes to do 620 steps, around 10k samples

In [None]:
# NOTE(review): a bare `raise` with no active exception fails with RuntimeError.
# Presumably a deliberate stop between training and the evaluation cells - confirm.
raise

## Evaluation

- https://huggingface.co/docs/transformers/en/peft
- https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.merge_and_unload

### Code

In [None]:
# Load the raw validation samples and keep their ids so samples can be addressed by position
with open(cfg.val_dataset, 'r') as f:
    data = json.load(f)
val_samples_ids = list(data)

def ask_question_to_model(sample_idx, pipe, question_idx=0, arbitrary_question=None):
    """Show a validation grid, ask the model one question about it and print the answer.

    Args:
        sample_idx: position of the sample in `val_samples_ids`.
        pipe: text-generation pipeline used to produce the answer.
        question_idx: position of the dataset question to ask; ignored when
            `arbitrary_question` is given.
        arbitrary_question: free-form question to ask instead of a dataset question;
            its expected answer is printed as an empty string.

    Prints the question, then the generated text followed by the expected answer in
    parentheses.
    """
    sample = data[val_samples_ids[sample_idx]]

    if arbitrary_question is not None:
        selected_questions = {arbitrary_question: ''}
    else:
        selected_questions = {
            question: answer
            for position, (question, answer) in enumerate(sample['questions'].items())
            if position == question_idx
        }
    # New dict, so the sample stored in `data` is left untouched
    sample_with_one_question = {**sample, 'questions': selected_questions}

    messages = create_messages_from_sample(sample_with_one_question, grid_encoder)
    # Keep only the system message and the first user turn, then open the assistant turn
    prompt = tokenizer.apply_chat_template(
        messages[:2], tokenize=False, add_generation_prompt=True)
    plot_grid(sample['grid'])
    plt.show()
    # pretty_print_prompt(prompt)

    output = pipe(
        prompt,
        max_new_tokens=50,
        return_full_text=False,
        do_sample=False,
    )
    asked_question = list(selected_questions.keys())[0]
    print(asked_question)
    expected_answer = list(selected_questions.values())[0]
    print(f">{output[0]['generated_text']} ({expected_answer})")

def plot_grid(grid):
    """Draw `grid` as an image with a fixed 10-color palette and write each cell value on it.

    Values are mapped to colors over the range 0-9. Light grey grid lines are drawn at
    the cell borders and the tick labels are hidden.
    """
    grid = np.array(grid)
    palette = [
        '#000000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
        '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25',
    ]
    plt.imshow(grid,
               cmap=colors.ListedColormap(palette),
               norm=colors.Normalize(vmin=0, vmax=9))
    plt.grid(True, which='both', color='lightgrey', linewidth=0.5)

    n_cols = grid.shape[1]
    plt.xticks(np.arange(-0.5, n_cols), [])
    n_rows = grid.shape[0]
    plt.yticks(np.arange(-0.5, n_rows), [])
    plt.xlim(-0.5, n_cols-0.5)

    # Text coordinates are (x, y) = (column, row)
    for i, j in np.ndindex(n_rows, n_cols):
        plt.text(j, i, grid[i, j], ha='center', va='center')

### Experiments

In [None]:
# Candidate checkpoints from previous experiments. Each assignment overwrites the
# previous one, so only the last adapter_path is actually loaded.
adapter_path = '/mnt/hdd0/Kaggle/arc24/models/20240724_first_trainings/11_lr_4e-4_1e5dataset_r32/checkpoint-12400'
adapter_path = '/mnt/hdd0/Kaggle/arc24/models/20240724_first_trainings/09_lr_1e-3_1e4dataset_r32/checkpoint-600/'
adapter_path = '/mnt/hdd0/Kaggle/arc24/models/20240724_first_trainings/14_llama31/checkpoint-333'
adapter_path = '/mnt/hdd0/Kaggle/arc24/models/20240724_first_trainings/15_continue_training_phi3_4e5/checkpoint-22800'
adapter_path = '/mnt/hdd0/Kaggle/arc24/models/20240724_first_trainings/22_random_question_lr1e-4_1e5dataset/checkpoint-6228'
adapter_path = '/mnt/hdd0/Kaggle/arc24/models/20240724_first_trainings/22_llama31_lr1e-4_1e5dataset_r32/checkpoint-6553'
# The path is passed twice; presumably the second argument is the adapter name - confirm
model.load_adapter(adapter_path, adapter_path)
model.eval();

In [None]:
# Text-generation pipeline around the current model and tokenizer, used by ask_question_to_model
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

I want to visualize the grid and ask some random question, see how well it does.

Compare the responses with and without the adapter.

In [None]:
# Ask the second dataset question (question_idx=1) of validation sample 750
ask_question_to_model(sample_idx=750, question_idx=1, pipe=pipe)

In [None]:
# Free-form questions that are not part of the dataset; earlier tries are kept commented out
# ask_question_to_model(sample_idx=300, arbitrary_question='Please describe the grid, saying how many objects are there, their color and area.', pipe=pipe)
# ask_question_to_model(sample_idx=300, arbitrary_question='What is the shape of the grid? (nxn)', pipe=pipe)
ask_question_to_model(sample_idx=550, arbitrary_question='Describe the objects in the grid', pipe=pipe)

## TODO

- [x] Fixed val dataset
- [x] The pad_token_id and eos_token_id values of this tokenizer are identical. If you are planning for multi-turn training, it can result in the model continuously generating questions and answers without eos token. To avoid this, set the pad_token_id to a different value. **I have evaluated phi-3 and it works correctly**
- [x] Flash attention? Yes, it is slightly faster.
- [x] Batch size
- [x] Better WANDB configuration
- [x] Training parameters?
  - [x] Learning rate
  - [x] Is the linear decay working properly? Yes
- [x] Better switch between Llama and Phi