<h1>Setup</h1>

In [None]:
!pip install -U trl peft bitsandbytes transformers datasets accelerate evaluate bert-score rouge_score --quiet


<h1>Imports & Login</h1>

In [None]:
import os

from huggingface_hub import login

# Authenticate against the Hugging Face Hub with the token stored in the HF_TOKEN
# environment variable. os.getenv returns None when the variable is unset, so export
# HF_TOKEN before running this cell instead of pasting a token into the notebook.
login(token=os.getenv("HF_TOKEN"))


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
import torch

# Hub identifier of the checkpoint that gets quantized and fine-tuned.
base_model = "mistralai/Mistral-7B-Instruct-v0.2"

# 4-bit NF4 quantization with double quantization. The compute dtype is set
# explicitly to float16 so the de-quantized matmuls run in half precision, in line
# with the fp16 training configured further down; when omitted it falls back to
# float32, which is the slow path for half-precision activations.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Load the quantized model (device placement chosen automatically) and its tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(base_model)


<h1>Model & Dataset</h1>


In [None]:
from datasets import load_dataset

# Fixed instruction that format_chat places in front of every question.
instruction = (
    "You are a Law Assistant. "
    "Please answer the following question."
)

# Fetch the "train" split of the legal Q&A dataset from the Hub.
dataset = load_dataset(path="haistudy/en_law_qa", split="train")

def format_chat(row):
    """Add an ``[INST] ... [/INST]``-delimited training string to ``row``.

    The string is built from the module-level ``instruction``, the row's
    ``"Question"`` and its ``"Answer"``, and is stored under ``row["text"]``.

    Args:
        row: Mapping with ``"Question"`` and ``"Answer"`` entries.

    Returns:
        The same ``row`` object, mutated in place with the new ``"text"`` key.
    """
    question, answer = row["Question"], row["Answer"]
    prompt = f"[INST] {instruction} {question} [/INST]"
    row["text"] = f"{prompt} {answer}"
    return row

# Add the formatted "text" column to every row.
dataset = dataset.map(format_chat)

# Seeded 90/10 split: the "train" part is used for fine-tuning, the "test" part for evaluation.
split_dataset = dataset.train_test_split(seed=42, test_size=0.1)
train_dataset, eval_dataset = split_dataset["train"], split_dataset["test"]


<h1>LoRA Setup</h1>

In [None]:
import bitsandbytes as bnb

def find_all_linear_names(model):
    """List the leaf names of all 4-bit linear layers in ``model``.

    Every sub-module that is a ``bnb.nn.Linear4bit`` contributes the last
    dot-separated component of its qualified name; duplicates collapse and
    ``"lm_head"`` is excluded.

    Args:
        model: Object exposing ``named_modules()`` yielding ``(name, module)`` pairs.

    Returns:
        A list of unique leaf names, in no particular order.
    """
    leaf_names = {
        qualified_name.rsplit(".", 1)[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, bnb.nn.Linear4bit)
    }
    return list(leaf_names.difference({"lm_head"}))

# Adapt every 4-bit linear layer reported by find_all_linear_names (lm_head is excluded there).
target_modules = find_all_linear_names(model)

# LoRA hyper-parameters for causal language modelling; bias terms are left untouched.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=target_modules,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the quantized base model; from here on `model` refers to the PEFT-wrapped model.
model = get_peft_model(model, lora_config)


<h1>Training</h1>

In [None]:
from trl import SFTTrainer, SFTConfig

# Tokenizer limits: cap inputs at 1024 tokens and truncate from the left side.
tokenizer.model_max_length = 1024
tokenizer.truncation_side = "left"

training_args = SFTConfig(
    # Where checkpoints are written; no external experiment tracker.
    output_dir="mistral-law-finetuned",
    report_to="none",
    # Optimisation: a single epoch with paged AdamW under fp16 mixed precision.
    num_train_epochs=1,
    learning_rate=2e-4,
    optim="paged_adamw_32bit",
    fp16=True,
    # Batching: 4 samples x 2 accumulation steps per optimizer update on each device.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=1,
    group_by_length=True,
    # Maximum sequence length passed to the trainer.
    max_length=512,
)

# `model` has already been wrapped with the LoRA adapters by get_peft_model, so the
# adapter config is deliberately NOT passed again. Handing SFTTrainer a PeftModel
# together with a `peft_config` asks it to apply the adapter a second time; depending
# on the installed TRL release that request is ignored, causes the existing adapter to
# be merged and re-created, or is rejected outright.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)



In [None]:
# Launch fine-tuning with the trainer built in the previous cell.
trainer.train()

<h1>Evaluation</h1>

In [None]:
from evaluate import load
from bert_score import score as bert_score
from tqdm import tqdm
import gc

def simple_f1(pred, label):
    """Whitespace-token F1 between a prediction and a reference string.

    The overlap is a multiset intersection: a token counts once for every
    occurrence shared by both strings. This keeps the numerator consistent
    with the token-count denominators, so identical strings score 1.0 even
    when they contain repeated words.

    Args:
        pred: Predicted text.
        label: Reference text.

    Returns:
        The harmonic mean of token precision and recall as a float in [0, 1];
        0.0 when the two strings share no token (including empty input).
    """
    from collections import Counter

    pred_tokens = pred.split()
    label_tokens = label.split()

    # Counter & Counter keeps, per token, the smaller of the two counts.
    overlap = sum((Counter(pred_tokens) & Counter(label_tokens)).values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(label_tokens)
    return 2 * precision * recall / (precision + recall)

def _prompt_only(text, answer):
    """Return ``text`` with the trailing reference ``answer`` removed, leaving only the prompt."""
    if answer and text.endswith(answer):
        return text[: -len(answer)].rstrip()
    # Fallback: keep everything up to and including the first closing instruction tag.
    head, tag, _ = text.partition("[/INST]")
    return head + tag if tag else text


def evaluate_model(model, tokenizer, dataset, max_new_tokens=60, batch_size=8):
    """Generate answers for ``dataset`` and score them against the references.

    Each row's ``"text"`` holds the prompt followed by the reference answer.
    The answer is stripped before generation so the model only sees the
    prompt, and only the newly generated tokens are decoded as the prediction.
    Feeding the full ``"text"`` would hand the model the answer it is scored
    against and copy it into every prediction.

    Side effects: sets ``tokenizer.padding_side`` to ``"left"`` and, when the
    tokenizer has no pad token, sets it to the EOS token. Prints the three
    aggregate scores.

    Args:
        model: Causal LM exposing ``eval()``, ``device`` and ``generate()``.
        tokenizer: Tokenizer matching ``model``.
        dataset: Sliceable dataset whose slices map ``"text"`` and ``"Answer"``
            to lists of strings.
        max_new_tokens: Generation budget per example.
        batch_size: Number of examples generated per forward pass.

    Returns:
        dict with ``"f1"`` (mean token F1), ``"rougeL"``, ``"bertscore_f1"``
        and ``"results"`` (one ``{"prediction", "reference", "f1"}`` dict per example).

    Raises:
        ValueError: If ``dataset`` is empty (no scores can be averaged).
    """
    if len(dataset) == 0:
        raise ValueError("evaluate_model() requires a non-empty dataset")

    model.eval()
    rouge = load("rouge")
    results, f1s, preds, labels = [], [], [], []

    # Loop-invariant tokenizer setup. Left padding keeps every prompt flush against
    # the position where generation starts, which batched decoder-only generation needs.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    for i in tqdm(range(0, len(dataset), batch_size)):
        batch = dataset[i:i+batch_size]
        decoded_labels = batch["Answer"]
        prompts = [_prompt_only(text, answer)
                   for text, answer in zip(batch["text"], decoded_labels)]

        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
                           max_length=tokenizer.model_max_length).to(model.device)

        with torch.no_grad():
            outputs = model.generate(input_ids=inputs["input_ids"],
                                     attention_mask=inputs["attention_mask"],
                                     max_new_tokens=max_new_tokens, do_sample=False,
                                     pad_token_id=tokenizer.pad_token_id)

        # All prompts share one padded length, so slicing it off leaves only new tokens.
        prompt_len = inputs["input_ids"].shape[1]
        decoded_preds = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)

        for pred, label in zip(decoded_preds, decoded_labels):
            pred, label = pred.strip(), label.strip()
            preds.append(pred)
            labels.append(label)
            f1s.append(simple_f1(pred, label))
            results.append({"prediction": pred, "reference": label, "f1": f1s[-1]})

        # Release per-batch tensors before the next batch is generated.
        del inputs, outputs
        torch.cuda.empty_cache()
        gc.collect()

    mean_f1 = sum(f1s) / len(f1s)
    rougeL = rouge.compute(predictions=preds, references=labels)["rougeL"]
    # BERTScore is run on CPU, as in the original setup.
    _, _, bert_F1 = bert_score(preds, labels, lang="en", device="cpu")
    mean_bert_f1 = bert_F1.mean().item()

    print(f"\nAverage F1 Score:  {mean_f1:.4f}")
    print(f"ROUGE-L Score:     {rougeL:.4f}")
    print(f"BERTScore (F1):    {mean_bert_f1:.4f}")

    return {
        "f1": mean_f1,
        "rougeL": rougeL,
        "bertscore_f1": mean_bert_f1,
        "results": results
    }

# Run evaluation on the held-out split; `metrics` keeps the aggregate scores
# and the per-example results returned by evaluate_model.
metrics = evaluate_model(model, tokenizer, eval_dataset)

# Scores noted down from a previous run (presumably of this cell). They are kept
# as comments so the cell does not echo them as if they were fresh output; they
# are not refreshed automatically and should be replaced after re-running.
#   Average F1 Score:  0.6408
#   ROUGE-L Score:     0.6966
#   BERTScore (F1):    0.9337
