In [1]:
import os
import re
import torch
from typing import List, Dict, Any, Union

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM

In [2]:
# --- Data location -----------------------------------------------------------
DATASETS_DIR = "../datasets"
DATASET_NAME = "self_generated"
TABLE_EXTENSION = "csv"  # extension of the table files read by get_table
SHUFFLE_SEED = 42

# --- Model -------------------------------------------------------------------
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
MAX_SEQ_LENGTH = 1024  # token budget used for truncation and by the SFT config

# --- Prompt construction -----------------------------------------------------
SYSTEM_PROMPT = "You are a helpful assistant that answers questions about the table. You only answer the question right after 'Answer: '"
ASSISTANT_PROMPT = "Answer: "  # prefix prepended to every gold answer
USER_PROMPT_ORDER = ["table", "question"]  # example fields joined into the user turn, in order


In [3]:
# Load the tokenizer that matches MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the end-of-sequence token as the padding token so batched
# tokenization with padding=True has a pad token to use.
tokenizer.pad_token = tokenizer.eos_token

In [4]:
dataset_path = os.path.join(DATASETS_DIR, DATASET_NAME)

# Each split lives in its own CSV file under <dataset_path>/data/.
split_files = {"train": "train.csv", "test": "test.csv", "validation": "val.csv"}
data_files = {
    split: os.path.join(dataset_path, "data", file_name)
    for split, file_name in split_files.items()
}
dataset = load_dataset("csv", data_files=data_files)

In [5]:
# Randomly shuffle the dataset (currently disabled, so each split keeps its file order)
# dataset = dataset.shuffle(seed=SHUFFLE_SEED)

# Keep only the first N examples of each split.
# N is capped at the split length so a smaller dataset does not raise an
# IndexError inside `select`.
SPLIT_SIZES = {"train": 480, "validation": 80, "test": 80}
for split_name, max_examples in SPLIT_SIZES.items():
    n_keep = min(max_examples, len(dataset[split_name]))
    dataset[split_name] = dataset[split_name].select(range(n_keep))

In [None]:
def get_table(context: str) -> str:
    """Return the raw text of the table file referenced by ``context``.

    Args:
        context: Path of the table relative to ``dataset_path``. A trailing
            ``.csv`` is stripped before ``TABLE_EXTENSION`` is appended, so the
            value may be given with or without that suffix.

    Returns:
        The full contents of the table file, decoded as UTF-8.

    Raises:
        OSError: If the resolved table file cannot be opened.
    """
    # Strip only a literal ".csv" suffix. A regex such as ".csv$" leaves the dot
    # unescaped and would also strip e.g. "xcsv" from a name like "tablexcsv".
    csv_suffix = ".csv"
    if context.endswith(csv_suffix):
        context = context[: -len(csv_suffix)]

    table_path = os.path.join(dataset_path, context + "." + TABLE_EXTENSION)
    with open(table_path, "r", encoding="utf-8") as f:
        return f.read()

def preprocess_single_example_to_string(example):
    """Render one dataset row as a chat-templated training string.

    Reads the table named by ``example["context"]``, stores its text under
    ``example["table"]``, builds a system/user/assistant conversation and
    stores the templated text under ``example["input_string"]``.

    The user turn is the fields listed in ``USER_PROMPT_ORDER`` joined by a
    newline; the assistant turn is ``ASSISTANT_PROMPT`` followed by the answer.

    Returns:
        The same ``example`` mapping with the two new keys added.
    """
    example["table"] = get_table(example["context"])

    # Assemble the user turn from the configured fields, in order.
    prompt_parts = []
    for field_name in USER_PROMPT_ORDER:
        prompt_parts.append(example[field_name])
    user_turn = "\n".join(prompt_parts)

    # The answer may not be a string (e.g. numeric), hence the str() cast.
    assistant_turn = ASSISTANT_PROMPT + str(example["answer"])

    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_turn},
        {"role": "assistant", "content": assistant_turn},
    ]

    # tokenize=False returns the templated text rather than token ids.
    example["input_string"] = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
    return example

# Add the "table" and chat-formatted "input_string" fields to every example, one row at a time.
dataset = dataset.map(preprocess_single_example_to_string, batched=False)

In [7]:
# Request 8-bit weight loading for the base model.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# device_map="auto" lets the loader decide where to place the model.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
    device_map="auto",
)

In [8]:
def format_example(example):
    """Return the pre-rendered chat string stored under ``example["input_string"]``."""
    formatted_text = example["input_string"]
    return formatted_text

def last_occurrence_indices(tensor, X):
    """Find, for every row, the index of the last element equal to ``X``.

    Args:
        tensor: 2D tensor of shape B x L.
        X: The integer value to look for.

    Returns:
        An integer tensor of shape B x 1 holding the last matching column
        index per row, or -1 for rows that do not contain ``X``.
    """
    hits = tensor.eq(X)  # B x L, True where the value matches
    seq_len = tensor.size(1)

    # On the flipped mask, argmax yields the position of the first True, i.e.
    # how far the last match sits from the right edge. (Cast to float because
    # the original relies on argmax over 1.0/0.0 values.)
    offset_from_end = torch.flip(hits, dims=[1]).float().argmax(dim=1)  # B
    positions = (seq_len - 1) - offset_from_end  # B

    # argmax returns 0 for an all-False row, which would look like a match in
    # the last column; overwrite those rows with the -1 sentinel.
    row_has_hit = hits.any(dim=1)  # B
    positions = torch.where(row_has_hit, positions, torch.full_like(positions, -1))

    return positions.unsqueeze(1)  # B x 1

def data_collator_for_assistant_completion_only(examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
    """Tokenize a batch and build labels that cover only the assistant reply.

    Each example must carry an ``"input_string"`` field. The strings are
    tokenized together (padded to the longest, truncated to
    ``MAX_SEQ_LENGTH``) and ``labels`` is a copy of ``input_ids`` in which
    every position that should not contribute to the loss is set to -100:

    * everything up to and including the last ``<|end_header_id|>`` token and
      the single token that follows it,
    * every padded position (``attention_mask == 0``),
    * the whole row when no ``<|end_header_id|>`` is present (e.g. it was
      truncated away), so the prompt is never trained on by accident.

    Args:
        examples: Dataset rows, each a mapping with an ``"input_string"`` key.

    Returns:
        The tokenizer output with an added ``"labels"`` tensor of shape B x L.
    """
    # Token id of <|end_header_id|> for this tokenizer.
    end_header_token_id = 128007
    # Tokens to skip starting at the header token: the header token itself
    # plus the one token that follows it in the rendered template.
    header_suffix_len = 2
    ignore_index = -100

    input_strings = [example["input_string"] for example in examples]
    batch = tokenizer(input_strings, return_tensors="pt", padding=True, truncation=True, max_length=MAX_SEQ_LENGTH, add_special_tokens=False)
    labels = batch["input_ids"].clone()
    seq_len = labels.size(1)

    # B x 1 tensor: index of the last header token per row, -1 if absent.
    last_end_header_id_indices = last_occurrence_indices(batch["input_ids"], end_header_token_id)

    # First position that should be kept as a label, per row. Rows without a
    # header get seq_len so that the whole row is masked.
    completion_start = last_end_header_id_indices + header_suffix_len  # B x 1
    completion_start[last_end_header_id_indices < 0] = seq_len

    # Broadcast a per-row boundary against column positions. A slice bound
    # cannot take a B x 1 tensor once the batch has more than one row.
    positions = torch.arange(seq_len, device=labels.device).unsqueeze(0)  # 1 x L
    labels[positions < completion_start] = ignore_index

    # The pad token is the EOS token, so padding is masked through the
    # attention mask; the real end-of-turn token stays in the labels.
    labels[batch["attention_mask"] == 0] = ignore_index

    batch["labels"] = labels
    return batch


In [9]:
# LoRA adapter settings. The base model is loaded with AutoModelForCausalLM,
# so the PEFT task type must be CAUSAL_LM (not SEQ_2_SEQ_LM, which wraps the
# model with the encoder-decoder PEFT class).
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
)

sft_config = SFTConfig(
    max_seq_length=MAX_SEQ_LENGTH,
    output_dir="../outputs/sft_trainer_collator",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    logging_strategy="steps",
    logging_steps=10,
    bf16=True,
    # Keep the raw columns: the custom collator reads "input_string" itself.
    remove_unused_columns=False,
)

sft_trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=sft_config,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    peft_config=peft_config,
    data_collator=data_collator_for_assistant_completion_only,
    formatting_func=format_example,
)




Map:   0%|          | 0/480 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

In [10]:
# Run fine-tuning; the returned TrainOutput is shown as the cell result.
sft_trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mb08097[0m. Use [1m`wandb login --relogin`[0m to force relogin




Step,Training Loss
10,1.162
20,0.8917
30,1.3828
40,0.553
50,1.3252
60,1.3234
70,0.615
80,0.5711
90,0.8079
100,1.4384


TrainOutput(global_step=480, training_loss=0.7567129435638587, metrics={'train_runtime': 48.4682, 'train_samples_per_second': 9.903, 'train_steps_per_second': 9.903, 'total_flos': 708291403776000.0, 'train_loss': 0.7567129435638587, 'epoch': 1.0})

In [11]:
# Inspect the token ids of the assistant header (the collator keys on the <|end_header_id|> id).
tokenizer(["<|start_header_id|>assistant<|end_header_id|>"], add_special_tokens=False)

{'input_ids': [[128006, 78191, 128007]], 'attention_mask': [[1, 1, 1]]}

In [12]:
# List the special tokens registered on the tokenizer.
tokenizer.all_special_tokens

['<|begin_of_text|>', '<|eot_id|>']

In [13]:
# Tokenize the first formatted training example, the same way the collator does, to inspect its ids.
tokenizer([dataset["train"][0]["input_string"]], add_special_tokens=False)

{'input_ids': [[128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 605, 4723, 220, 2366, 19, 271, 2675, 527, 264, 11190, 18328, 430, 11503, 4860, 922, 279, 2007, 13, 1472, 1193, 4320, 279, 3488, 1314, 1306, 364, 16533, 25, 364, 128009, 128006, 882, 128007, 271, 11, 6255, 220, 16, 11, 6255, 220, 17, 11, 6255, 220, 18, 11, 6255, 220, 19, 198, 3179, 220, 16, 11, 3971, 11, 6083, 11, 975, 11, 6028, 198, 3179, 220, 17, 11, 1399, 11, 508, 11, 6086, 11, 4218, 198, 3179, 220, 18, 11, 5728, 11, 5728, 11, 4044, 11, 1484, 198, 3179, 220, 19, 11, 1419, 11, 17, 11, 1691, 11, 4103, 271, 3923, 374, 279, 7340, 315, 279, 2819, 304, 11035, 220, 16, 30, 128009, 128006, 78191, 128007, 271, 16533, 25, 220, 6083, 128009]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [14]:
# Show the chat-templated text of the first training example.
dataset["train"][0]["input_string"]

"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 10 Nov 2024\n\nYou are a helpful assistant that answers questions about the table. You only answer the question right after 'Answer: '<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n,Col 1,Col 2,Col 3,Col 4\nRow 1,51,92,14,71\nRow 2,60,20,82,86\nRow 3,74,74,87,99\nRow 4,23,2,21,52\n\nWhat is the maximum of the values in Row 1?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nAnswer: 92<|eot_id|>"

In [15]:
# Tokenize the assistant header followed by the "Answer: " prefix to see which ids precede the answer.
tokenizer(["<|start_header_id|>assistant<|end_header_id|>\n\nAnswer: "], add_special_tokens=False)

{'input_ids': [[128006, 78191, 128007, 271, 16533, 25, 220]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1]]}

In [16]:
# Decode a hand-picked sequence of token ids back to text.
tokenizer.decode([ 16533,     25,  95695,   9259, 128009])

'Answer: Horton Smith<|eot_id|>'