In [None]:
%%capture
!pip install wandb torch torchvision torchaudio transformers[torch] 'accelerate>=0.26.0' tokenizers datasets[s3]

# NOTE: restart the kernel after this install finishes, before running any of the cells below.

In [None]:
%%capture
!pip install tf-keras
# for datacrunch only

In [None]:
!wandb login <token>

In [None]:
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast, TrainingArguments, Trainer, DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
import torch
from tqdm import tqdm

# Path of the pretrained checkpoint that is loaded and fine-tuned below.
MODEL_PATH = "./fineweb-gpt2-final/"
# Tokenizer definition file handed to PreTrainedTokenizerFast.
TOKENIZER_FILE = "fineweb-10bt-tokenizer-bpe.json"
# Trainer output directory; logs are written to its "logs" subfolder.
# NOTE(review): the name ends in "neft-0p3" while NEFTUNE_NOISE_ALPHA below is 5.0 — confirm which value is intended.
OUTPUT_DIR = "./gpt2-chat-alpaca-dolly-2-neft-0p3"
# Dataset names passed to load_dataset.
ALPACA_DATASET_NAME = "tatsu-lab/alpaca"
DOLLY_DATASET_NAME = "databricks/databricks-dolly-15k"
# Seed shared by both train/validation splits and by TrainingArguments.
SEED = 3047
# Maximum tokenized conversation length; longer conversations are filtered out
# and the remaining ones are padded to this length.
MAX_LENGTH = 1024
# Noise scale (alpha) passed to add_neftune.
NEFTUNE_NOISE_ALPHA = 5.0

# Wrap the trained tokenizer file and declare the special tokens used by the
# chat template below (<|im_start|> opens a turn, <|im_end|> closes it).
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=TOKENIZER_FILE,
    bos_token="<|im_start|>",
    eos_token="<|im_end|>",
    pad_token="<pad>",
    unk_token="<unk>",
    mask_token="<mask>",
)

# Collect every special token that still has no id after construction.
special_tokens_to_add = []
if tokenizer.bos_token_id is None:
    special_tokens_to_add.append("<|im_start|>")
if tokenizer.eos_token_id is None:
    special_tokens_to_add.append("<|im_end|>")
if tokenizer.pad_token_id is None:
    # Fall back to the EOS token for padding when one is defined.
    if tokenizer.eos_token:
        tokenizer.pad_token = tokenizer.eos_token
    else:
        tokenizer.add_special_tokens({'pad_token': '<pad>'})
        special_tokens_to_add.append("<pad>")


if special_tokens_to_add:
    # dict.fromkeys de-duplicates while keeping insertion order. The previous
    # list(set(...)) ordering depends on string hash randomization, so the ids
    # given to newly added tokens could differ from one run to the next.
    tokenizer.add_special_tokens(
        {"additional_special_tokens": list(dict.fromkeys(special_tokens_to_add))}
    )


# ChatML-style template. Each message is rendered as
#   {bos_token}{role}
#   {content}{eos_token}
# and, when add_generation_prompt is set and the last message is not from the
# assistant, an opening "{bos_token}assistant" header is appended.
# (Exact newline placement depends on Jinja whitespace control.)
# Other prompt formats, such as the Alpaca format, could be used instead.
chat_template = """{% for message in messages -%}
{{ bos_token }}{{ message['role'] }}
{{ message['content'] }}{{ eos_token }}
{% endfor -%}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
{{ bos_token }}assistant
{% endif -%}"""
tokenizer.chat_template = chat_template

print("Loading model")
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
# Keep the embedding matrix in step with the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))

# Mirror the tokenizer's special-token ids onto the model config.
for _token_id_attr in ("bos_token_id", "eos_token_id", "pad_token_id"):
    setattr(model.config, _token_id_attr, getattr(tokenizer, _token_id_attr))

# Show both views side by side so a mismatch is easy to spot.
print(f"Tokenizer BOS: {tokenizer.bos_token} ({tokenizer.bos_token_id}), EOS: {tokenizer.eos_token} ({tokenizer.eos_token_id}), PAD: {tokenizer.pad_token} ({tokenizer.pad_token_id})")
print(f"Model BOS: {model.config.bos_token_id}, EOS: {model.config.eos_token_id}, PAD: {model.config.pad_token_id}")
assert model.config.pad_token_id is not None, "Model's pad_token_id is not set!"

# Alpaca: load the single "train" split and hold out 10% for validation.
print("Loading ALPACA…")
alpaca_full = load_dataset(ALPACA_DATASET_NAME, split="train")
alpaca_split = alpaca_full.train_test_split(test_size=0.1, seed=SEED, shuffle=True)
alpaca_train, alpaca_val = alpaca_split["train"], alpaca_split["test"]

def format_alpaca(ds_split):
    """Convert Alpaca records into two-turn user/assistant conversations.

    Args:
        ds_split: Iterable of mapping-like records with "instruction",
            "input" and "output" fields. A field that is missing or ``None``
            is treated as an empty string.

    Returns:
        A list of conversations. Each conversation is a list of two dicts
        with "role" and "content" keys (user first, then assistant). The user
        message is the instruction, followed by a blank line and the input
        when the input is non-empty. Records whose user message or response
        is empty after stripping are skipped.
    """
    convos = []
    for ex in ds_split:
        # `or ""` also covers keys that are present but hold None, which
        # `ex.get(key, "")` alone would pass straight to .strip().
        instr = (ex.get("instruction") or "").strip()
        ctx = (ex.get("input") or "").strip()
        response = (ex.get("output") or "").strip()

        user_msg_parts = [instr]
        if ctx:  # add the context only when it is not empty
            user_msg_parts.append(ctx)
        user_msg = "\n\n".join(user_msg_parts)

        if user_msg and response:
            convos.append([
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": response},
            ])
    return convos


train_alpaca_convos, val_alpaca_convos = format_alpaca(alpaca_train), format_alpaca(alpaca_val)
print(f"Alpaca -> {len(train_alpaca_convos)} train / {len(val_alpaca_convos)} val convos")

print("Loading Dolly 15k…")
dolly = load_dataset(DOLLY_DATASET_NAME)


def _has_instruction_and_response(ex):
    """Return True when both 'instruction' and 'response' are present and non-blank."""
    instruction, response = ex.get("instruction"), ex.get("response")
    return bool(instruction and response and instruction.strip() and response.strip())


# Keep only Dolly examples that carry a usable instruction and response.
dolly_filtered = dolly.filter(_has_instruction_and_response)

# Hold out 10% of the filtered "train" split for validation.
dolly_split = dolly_filtered["train"].train_test_split(test_size=0.1, seed=SEED, shuffle=True)
dolly_train, dolly_val = dolly_split["train"], dolly_split["test"]

def format_dolly(ds_split):
    """Convert Dolly records into two-turn user/assistant conversations.

    Args:
        ds_split: Iterable of mapping-like records with "instruction",
            "context" and "response" fields. A field that is missing or
            ``None`` is treated as an empty string.

    Returns:
        A list of conversations. Each conversation is a list of two dicts
        with "role" and "content" keys (user first, then assistant). The user
        message is the instruction, followed by a blank line and the context
        when the context is non-empty. Records whose user message or response
        is empty after stripping are skipped.
    """
    convos = []
    for ex in ds_split:
        # `or ""` also covers keys that are present but hold None, which
        # `ex.get(key, "")` alone would pass straight to .strip().
        instr = (ex.get("instruction") or "").strip()
        ctx = (ex.get("context") or "").strip()
        response = (ex.get("response") or "").strip()

        user_msg_parts = [instr]
        if ctx:  # add the context only when it is not empty
            user_msg_parts.append(ctx)
        user_msg = "\n\n".join(user_msg_parts)

        if user_msg and response:
            convos.append([
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": response},
            ])
    return convos

train_dolly_convos, val_dolly_convos = (
    format_dolly(split) for split in (dolly_train, dolly_val)
)
print(f"Dolly-15k -> {len(train_dolly_convos)} train / {len(val_dolly_convos)} val convos")

# Merge the two sources, Alpaca first and Dolly second.
all_train_convos = [*train_alpaca_convos, *train_dolly_convos]
all_val_convos = [*val_alpaca_convos, *val_dolly_convos]
print(f"Combined (raw) -> {len(all_train_convos)} train / {len(all_val_convos)} val convos")

# remove those > context length
def filter_and_prepare_conversations(convos, tokenizer_ref, max_len):
    """Keep only the conversations whose rendered prompt fits in ``max_len`` tokens.

    Args:
        convos: Iterable of conversations (lists of role/content dicts).
        tokenizer_ref: Tokenizer used to render the chat template and to
            count tokens.
        max_len: Largest token count that is still accepted.

    Returns:
        A list of ``{"conversations": convo}`` dicts, one per kept
        conversation. Empty conversations are skipped, and conversations that
        tokenize to zero tokens are reported with a printed message.
    """
    kept = []
    for convo in tqdm(convos, desc="Filtering long conversations"):
        if not convo:
            continue  # nothing to render

        rendered = tokenizer_ref.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False,
        )
        n_tokens = len(tokenizer_ref.encode(rendered, truncation=False))

        if n_tokens == 0:
            print(f"Empty conversation {convo}")
            continue
        if n_tokens <= max_len:
            kept.append({"conversations": convo})

    return kept

# Drop conversations that exceed the context length.
train_convos_filtered = filter_and_prepare_conversations(
    convos=all_train_convos, tokenizer_ref=tokenizer, max_len=MAX_LENGTH
)
val_convos_filtered = filter_and_prepare_conversations(
    convos=all_val_convos, tokenizer_ref=tokenizer, max_len=MAX_LENGTH
)

print(f"Filtered by length -> {len(train_convos_filtered)} train / {len(val_convos_filtered)} val convos")

# Wrap the surviving conversations in Dataset objects.
train_dataset, val_dataset = (
    Dataset.from_list(rows) for rows in (train_convos_filtered, val_convos_filtered)
)

# Role names as they appear in the conversation dicts.
USER_ROLE_NAME = "user"
ASSISTANT_ROLE_NAME = "assistant"

def formatting_func(example):
    """Tokenize one conversation and build loss labels for the last assistant reply.

    The conversation is rendered with the tokenizer's chat template, then
    tokenized and padded to ``MAX_LENGTH``. Labels are ``-100`` everywhere
    except on the content of the last assistant turn and on the EOS token that
    closes it, so the model is also trained to end its reply.

    Args:
        example: Record of the form
            ``{"conversations": [{"role": ..., "content": ...}, ...]}``.

    Returns:
        Dict with "input_ids", "attention_mask" and "labels" (equal-length
        lists). If the target turn cannot be located (no assistant turn, or a
        turn's BOS/EOS marker is missing from the token ids), the labels stay
        fully masked. The record is still returned, so every row produced by
        ``Dataset.map`` has the same columns.
    """
    conversation = example["conversations"]

    prompt_text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )

    tokenized_inputs = tokenizer(
        prompt_text,
        truncation=True,  # safety net only; over-long conversations were filtered out earlier
        max_length=MAX_LENGTH,
        return_attention_mask=True,
        padding="max_length",
    )
    input_ids = tokenized_inputs["input_ids"]

    # Start fully masked; only the target reply is unmasked below.
    labels = [-100] * len(input_ids)
    record = {
        "input_ids": input_ids,
        "attention_mask": tokenized_inputs["attention_mask"],
        "labels": labels,
    }

    assistant_turn_indices = [
        idx for idx, turn in enumerate(conversation)
        if turn["role"] == ASSISTANT_ROLE_NAME
    ]
    if not assistant_turn_indices:
        # Nothing to supervise (max() over an empty sequence would raise).
        return record
    last_assistant_idx = assistant_turn_indices[-1]

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id

    # Walk the turns in order; each one is delimited by the next BOS and the
    # first EOS after it.
    search_from = 0
    for turn_idx, turn in enumerate(conversation):
        try:
            bos_idx = input_ids.index(bos_id, search_from)
        except ValueError:
            break
        try:
            eos_idx = input_ids.index(eos_id, bos_idx + 1)
        except ValueError:
            print(f"Warning: Could not find EOS token for turn: {turn}")
            break

        if turn_idx == last_assistant_idx:
            # Skip the "<role>\n" header that follows BOS.
            # NOTE(review): assumes the header tokenizes the same in isolation
            # as it does in context — confirm for this tokenizer.
            header_ids = tokenizer.encode(f"{turn['role']}\n", add_special_tokens=False)
            content_start = bos_idx + 1 + len(header_ids)
            # Unmask the reply content and its closing EOS token.
            labels[content_start:eos_idx + 1] = input_ids[content_start:eos_idx + 1]
            break

        search_from = eos_idx + 1  # continue after this turn's EOS

    return record

# Tokenize every conversation and attach the masked labels built by formatting_func.
tokenized_train = train_dataset.map(formatting_func)
tokenized_val = val_dataset.map(formatting_func)

# Return only the model inputs, as torch tensors, when rows are accessed.
tokenized_train.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
tokenized_val.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

# DataCollatorForLanguageModeling(mlm=False) rebuilds `labels` from `input_ids`,
# which throws away the prompt masking computed in formatting_func.
# DataCollatorForSeq2Seq keeps the labels it is given and pads them with -100.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding=True,
    label_pad_token_id=-100,
)


In [None]:
# Spot-check one raw conversation from the training set (row 192 — presumably an arbitrary sample).
tokenized_train['conversations'][192]

In [None]:
import torch
import types

def add_neftune(embedding_layer, noise_alpha):
    """Patch an embedding layer so that it adds NEFTune noise while training.

    In training mode, uniform noise drawn from ``[-m, m]`` with
    ``m = noise_alpha / sqrt(L * d)`` is added to the layer's output, where
    ``L`` and ``d`` are the last two dimensions of that output (sequence
    length and embedding size). In eval mode the unmodified embeddings are
    returned.

    The patch is idempotent: the unpatched ``forward`` is remembered on the
    layer, so calling this function again (for example by re-running a
    notebook cell) replaces the noise scale instead of stacking a second
    noise layer on top of the first.

    Args:
        embedding_layer: Module whose ``forward(input_ids)`` returns the
            embeddings, e.g. ``model.get_input_embeddings()``.
        noise_alpha: NEFTune noise scale.
    """
    # Reuse the pristine forward if this layer has already been patched.
    original_forward = getattr(embedding_layer, "_neftune_original_forward", None)
    if original_forward is None:
        original_forward = embedding_layer.forward
        embedding_layer._neftune_original_forward = original_forward

    def new_forward(self, input_ids):
        embed_init = original_forward(input_ids)
        if not self.training:
            # Inference: standard embeddings without noise.
            return embed_init
        # Using the output's last two dims keeps (batch, L) input unchanged
        # and also supports unbatched (L,) input.
        dims = embed_init.size(-2) * embed_init.size(-1)
        mag_norm = noise_alpha / dims ** 0.5
        noise = torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
        return embed_init + noise

    # Bind the new forward method to this specific layer instance.
    embedding_layer.forward = types.MethodType(new_forward, embedding_layer)

In [None]:
from transformers import EarlyStoppingCallback

In [None]:
# Add NEFTune noise to the model's input embeddings for training.
embedding_layer = model.get_input_embeddings()
add_neftune(embedding_layer, noise_alpha=NEFTUNE_NOISE_ALPHA)

# Early stopping with a patience of 3.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

training_args = TrainingArguments(
    # --- run identity and outputs ---
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    logging_dir=f"{OUTPUT_DIR}/logs",
    report_to="wandb",
    run_name="gpt2-chat-0p3",
    seed=SEED,
    # --- optimization ---
    num_train_epochs=5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="linear",
    fp16=True,
    torch_compile=True,
    # --- logging, evaluation and checkpointing ---
    logging_steps=50,
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=6,
    load_best_model_at_end=True,
    # --- data handling ---
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    callbacks=[early_stopping],
)

print("Starting fine-tuning…")
trainer.train()

In [None]:
# Directory that receives the final fine-tuned model and its tokenizer.
FINAL_MODEL_PATH = './gpt2-instruct'
# Save the trainer's current model to FINAL_MODEL_PATH.
trainer.save_model(FINAL_MODEL_PATH)
# Save the tokenizer into the same directory as the model.
tokenizer.save_pretrained(FINAL_MODEL_PATH)