# Imports and environment settings

In [None]:
%pip install "cohere" "datasets" "transformers" "accelerate" "peft" "bitsandbytes"

In [1]:
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, DataCollatorForLanguageModeling,Trainer
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm

import numpy as np
import torch
import math
import cohere
import json
import os
# Load environment variables from a local .env file (if present), presumably
# including COHERE_API_KEY_PAY which is read from os.environ in the config cell;
# override=True lets values from .env replace variables already set in the environment.
load_dotenv(override=True)

  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
# --- Cohere credentials: read from the environment, empty string when unset ---
COHERE_API_KEY = os.environ.get("COHERE_API_KEY_PAY", "")

# --- Data: cleaned script inputs and the distilled JSONL files derived from them ---
INPUT_PATH_TRAIN = "fine_tune_data/bbt_train_cleaned.jsonl"
INPUT_PATH_VAL = "fine_tune_data/bbt_val_cleaned.jsonl"
OUTPUT_PATH_TRAIN = "fine_tune_data/bbt_train_distilled.jsonl"
OUTPUT_PATH_VAL = "fine_tune_data/bbt_val_distilled.jsonl"

# Split name -> file mapping for load_dataset("json", ...); training uses the distilled files.
data_files = dict(train=OUTPUT_PATH_TRAIN, validation=OUTPUT_PATH_VAL)

# --- Model / LoRA ---
# NOTE(review): MODEL_OUTPUT_PATH is not referenced by the cells in this notebook section.
MODEL_OUTPUT_PATH = r"distilled"
# LoRA adapts only the query/value projections; use
# ["q_proj", "k_proj", "v_proj", "o_proj"] to adapt the full attention block.
TARGET_MODULES = ["q_proj", "v_proj"]


# Distillation to improve the fine-tune

## Preparing the distilled dataset 

In [3]:
# Fail fast when no key is configured. With an empty key every co.chat call in
# distil_file would presumably be rejected, and distil_file catches those errors
# and silently falls back to the original script line, producing an
# "undistilled" dataset after hours of API calls.
if not COHERE_API_KEY:
    raise RuntimeError(
        "COHERE_API_KEY_PAY is not set; add it to your environment or .env file before distilling."
    )
co = cohere.ClientV2(COHERE_API_KEY)

# System-prompt persona per target speaker. Keys are matched against the
# "target_speaker" value of each record (via get_persona); "DEFAULT" is used for
# empty or unlisted speakers.
PERSONAS = {
    "Sheldon": (
        "You are Sheldon Cooper from The Big Bang Theory. "
        "You are a brilliant, pedantic theoretical physicist: literal, arrogant, and verbose. "
        "You speak in precise, formal, slightly condescending language and often reference science, physics, and your own intellect."
    ),
    "Leonard": (
        "You are Leonard Hofstadter from The Big Bang Theory. "
        "You are kind, self-conscious, often nervous, and try to keep the peace between your friends. "
        "You speak in a casual, slightly awkward but caring tone, and you often try to sound reasonable and supportive."
    ),
    "Penny": (
        "You are Penny from The Big Bang Theory. "
        "You are friendly, sarcastic, and down-to-earth, with good social intuition. "
        "You use casual everyday language, sometimes tease the guys, and react emotionally and humorously to their geeky behavior."
    ),
    "Howard": (
        "You are Howard Wolowitz from The Big Bang Theory. "
        "You are an aerospace engineer with an overconfident, sometimes creepy flirtatious style. "
        "You crack innuendo-filled jokes, brag about your accomplishments, and speak in a playful, comedic tone, especially about space and women."
    ),
    "Raj": (
        "You are Rajesh Koothrappali from The Big Bang Theory. "
        "You are sensitive, romantic, and somewhat socially awkward, with a love of pop culture and fantasy. "
        "You speak in an emotional, sometimes dramatic way, and you often talk about love, loneliness, and your interests like movies and comics."
    ),
    "Amy": (
        "You are Amy Farrah Fowler from The Big Bang Theory. "
        "You are a neurobiologist with a mix of scientific seriousness and socially awkward earnestness. "
        "You speak in a formal, analytical tone about emotions and relationships, and you are intensely devoted to Sheldon and your friends."
    ),
    "Bernadette": (
        "You are Bernadette Rostenkowski-Wolowitz from The Big Bang Theory. "
        "You have a sweet, high-pitched speaking style that can turn surprisingly strict or intimidating. "
        "You are practical, sometimes bossy, and you often mix cute phrasing with sharp, no-nonsense comments."
    ),
    # Generic fallback for non-central characters and any speaker not listed above
    "DEFAULT": (
        "You are a character from The Big Bang Theory. "
        "Respond in a style consistent with that character's personality and the show's comedic tone."
    ),
}

def get_persona(speaker: str) -> str:
    """Return the persona system prompt to use for ``speaker``.

    Empty/None speakers and speakers without an entry in PERSONAS both get
    the generic ``PERSONAS["DEFAULT"]`` description.
    """
    fallback = PERSONAS["DEFAULT"]
    return PERSONAS.get(speaker, fallback) if speaker else fallback

In [23]:
def distil_file(
    input_path: str,
    output_path: str,
    max_examples: int | None = None,
    resume: bool = False,
    model: str = "command-r7b-12-2024",
):
    """Write a copy of a cleaned BBT JSONL file with Cohere "teacher" replies added.

    Every input record is written to ``output_path`` with an extra
    ``teacher_target`` field: the next line of dialogue generated by a Cohere
    chat model that is prompted with the target speaker's persona. If the API
    call raises or returns an empty reply, the record's original ``target`` is
    used instead, so each processed input record yields exactly one output line.
    The number of such fallbacks is printed at the end so a failing key or
    exhausted quota is visible.

    Args:
        input_path: JSONL file; records are read for ``prompt`` and optionally
            ``target_speaker``, ``target``, ``ep`` and ``scene``.
        output_path: destination JSONL file.
        max_examples: only process the first N records (None = all).
        resume: if True and ``output_path`` already contains records from an
            earlier run that stopped cleanly (e.g. KeyboardInterrupt, which
            closes the file), skip that many input records and append the rest
            instead of starting over. Use the same ``max_examples`` as before.
        model: Cohere chat model used as the teacher.
    """
    with open(input_path, "r", encoding="utf-8") as fin:
        # Skip blank lines (e.g. a trailing empty line) that json.loads rejects.
        records = [json.loads(line) for line in fin if line.strip()]

    if max_examples is not None:
        records = records[:max_examples]

    n_done = 0
    if resume and os.path.exists(output_path):
        with open(output_path, "r", encoding="utf-8") as fprev:
            n_done = sum(1 for line in fprev if line.strip())

    # Loop-invariant part of the system message; the persona is prepended per record.
    instructions = (
        " You will be given the dialogue context. "
        "Continue the next line exactly as this character would speak. "
        "Respond with ONLY the next line of dialogue, no quotes, "
        "and do NOT add speaker tags."
    )

    pending = records[n_done:]
    n_fallback = 0
    with open(output_path, "a" if n_done else "w", encoding="utf-8") as fout:
        for ex in tqdm(pending, desc=f"Distilling {input_path}", initial=n_done, total=len(records)):
            prompt = ex.get("prompt", "")
            persona = get_persona(ex.get("target_speaker", ""))

            try:
                response = co.chat(
                    model=model,
                    messages=[
                        {"role": "system", "content": persona + instructions},
                        {"role": "user", "content": prompt},
                    ],
                    temperature=0.7,
                    max_tokens=96,
                )
                teacher_text = response.message.content[0].text.strip()
            except Exception as e:
                # Best effort: log and fall back below rather than abort a multi-hour run.
                print(f"Error on example with ep={ex.get('ep')} scene={ex.get('scene')}:", e)
                teacher_text = ""

            if not teacher_text:
                # Use the original script line; `or ""` guards against a null target.
                teacher_text = (ex.get("target") or "").strip()
                n_fallback += 1

            ex["teacher_target"] = teacher_text
            fout.write(json.dumps(ex, ensure_ascii=False) + "\n")

    print(f"{output_path}: wrote {len(pending)} records, {n_fallback} fell back to the original target.")

In [24]:
# Distil both splits in order; max_examples=None processes every record
# (lower it if the Cohere account's rate/usage limits require).
for src_path, dst_path in ((INPUT_PATH_TRAIN, OUTPUT_PATH_TRAIN), (INPUT_PATH_VAL, OUTPUT_PATH_VAL)):
    distil_file(src_path, dst_path, max_examples=None)

Distilling fine_tune_data/bbt_train_cleaned.jsonl:   1%|          | 329/32516 [03:48<6:12:23,  1.44it/s] 


KeyboardInterrupt: 

## Running the fine-tune again on the distilled targets

In [None]:
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Built with os.path.join so it works on every OS: a hard-coded backslash path
# would become a single directory literally named "Fine_Tune\outputs\..." on Linux/macOS.
output_dir = os.path.join("Fine_Tune", "outputs", "tinyllama_bbt_distilled_lora")

# Train/validation splits are the distilled JSONL files listed in `data_files`.
ds = load_dataset("json", data_files=data_files)
train_ds = ds["train"]
val_ds = ds["validation"]

tok = AutoTokenizer.from_pretrained(model_id)
if tok.pad_token is None:
    # No dedicated pad token: reuse EOS so batches can be padded.
    tok.pad_token = tok.eos_token

base = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)

base.config.pad_token_id = tok.pad_token_id

# Low-rank adapters on TARGET_MODULES; only these adapter weights are trained.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=TARGET_MODULES,
)

model = get_peft_model(base, lora_config)
model.print_trainable_parameters()

In [None]:
max_len = 512

def build_example(ex):
    """Tokenize one record into a causal-LM example supervised only on the reply.

    The model input is ``prompt + target`` followed by the EOS token, so the
    fine-tuned model learns where a line of dialogue ends. ``target`` is the
    Cohere ``teacher_target`` when present and non-empty, otherwise the original
    script ``target``. Prompt tokens get label -100 (ignored by the loss).
    Sequences longer than ``max_len`` tokens are cut from the right.
    """
    prompt = ex.get("prompt") or ""
    # `or` rather than a .get default: datasets fills a field missing from some
    # JSON rows with None, and an empty teacher reply would give an empty target.
    target = ex.get("teacher_target") or ex.get("target") or ""

    full_ids = tok(prompt + target)["input_ids"]
    eos_id = tok.eos_token_id
    if eos_id is not None and (not full_ids or full_ids[-1] != eos_id):
        full_ids = full_ids + [eos_id]
    input_ids = full_ids[:max_len]

    # The prompt is tokenized on its own to find how many leading tokens to mask.
    n_prompt = min(len(tok(prompt)["input_ids"]), len(input_ids))
    labels = [-100] * n_prompt + input_ids[n_prompt:]

    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


def has_target_tokens(ex):
    """True if at least one token is supervised (the prompt did not fill max_len)."""
    return any(label != -100 for label in ex["labels"])


cols = ["input_ids", "attention_mask", "labels"]

# Examples with no supervised token contribute nothing to the loss, so drop them.
train_tok = train_ds.map(
    build_example,
    remove_columns=train_ds.column_names,
).filter(has_target_tokens)
val_tok = val_ds.map(
    build_example,
    remove_columns=val_ds.column_names,
).filter(has_target_tokens)

train_tok.set_format(type="torch", columns=cols)
val_tok.set_format(type="torch", columns=cols)


In [None]:
# DataCollatorForLanguageModeling(mlm=False) rebuilds `labels` from `input_ids`.
# That would throw away the prompt masking prepared in the tokenization cell, and
# because pad_token == eos_token here it would also mask every EOS. This collator
# only pads, and keeps the precomputed labels.
def causal_lm_collator(features, pad_token_id=None, label_pad_id=-100):
    """Right-pad tokenized examples into one training batch.

    Args:
        features: examples with "input_ids", "attention_mask" and "labels"
            (lists or 1-D tensors).
        pad_token_id: id used to pad input_ids; defaults to tok.pad_token_id.
        label_pad_id: value used to pad labels; -100 is ignored by the loss.

    Returns:
        dict of LongTensors shaped (batch, longest sequence in the batch).
    """
    if pad_token_id is None:
        pad_token_id = tok.pad_token_id
    width = max(len(f["input_ids"]) for f in features)

    def right_pad(seq, value):
        seq = torch.as_tensor(seq, dtype=torch.long)
        return torch.cat([seq, seq.new_full((width - seq.numel(),), value)])

    return {
        "input_ids": torch.stack([right_pad(f["input_ids"], pad_token_id) for f in features]),
        "attention_mask": torch.stack([right_pad(f["attention_mask"], 0) for f in features]),
        "labels": torch.stack([right_pad(f["labels"], label_pad_id) for f in features]),
    }


data_collator = causal_lm_collator

# `evaluation_strategy` was renamed to `eval_strategy` (transformers 4.41) and the
# old name was later removed; use whichever field the installed version defines.
_eval_key = (
    "eval_strategy"
    if "eval_strategy" in TrainingArguments.__dataclass_fields__
    else "evaluation_strategy"
)

# Pick mixed precision from the hardware: bf16 where supported (e.g. A100),
# fp16 on other GPUs, full precision on CPU (fp16=True fails without a GPU).
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
use_fp16 = torch.cuda.is_available() and not use_bf16

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    eval_steps=500,
    logging_steps=100,
    save_steps=1000,
    save_total_limit=3,
    num_train_epochs=3,
    learning_rate=2e-4,
    warmup_ratio=0.03,
    fp16=use_fp16,
    bf16=use_bf16,
    # PEFT wrappers hide the base model's forward signature, so name the label
    # column explicitly for evaluation loss.
    label_names=["labels"],
    **{_eval_key: "steps"},
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    data_collator=data_collator,
)

In [None]:
# Fit the LoRA adapters, then write the tokenizer files and adapter weights into output_dir.
train_result = trainer.train()
tok.save_pretrained(output_dir)
trainer.save_model(output_dir)
print("Finished training + saved model + tokenizer.")

## Evaluation