In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import dotenv
from huggingface_hub import login
import os
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments
)
from src.training_utils import GenderLossTrainer
from src.utils import read_config
from peft import LoraConfig, get_peft_model
from src.data_utils import prepare_dataset_gender, prepare_dataset_gender_stories
from datasets import load_dataset, concatenate_datasets

# Load variables from a local .env file into the environment, then log in to the
# Hugging Face Hub with the value of the `huggingface_token` variable.
dotenv.load_dotenv()
login(token=os.getenv('huggingface_token'))


# Enable IPython autoreload (mode 2) so edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import wandb
# Log in to Weights & Biases with the key read from the WANDB_API_KEY environment variable.
wandb.login(key=os.getenv('WANDB_API_KEY'))


In [None]:
# Read the LLM settings file and print it so the active configuration is visible.
llm_configs = read_config('../configs/llm_config.yaml')
print(llm_configs)

In [None]:
# Load the tokenizer and the causal LM named by `local_generative_model_name` in the config.
tokenizer = AutoTokenizer.from_pretrained(llm_configs['local_generative_model_name'])
model = AutoModelForCausalLM.from_pretrained(llm_configs['local_generative_model_name'])
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Put the model in eval mode; as the cell's last expression this also displays the model.
model.eval()

# Custom gender dataset

In [None]:
# Build the gender-template dataset from the dataset config and the short
# profession templates file, printing dataset info along the way.
dataset_gender = prepare_dataset_gender(
    '../configs/dataset_config.yaml',
    '../data/short_profession_templates.txt',
    print_dataset_info=True,
    reduced_number_of_train_templates=35,
)


In [None]:
# Display the first record of the train split.
dataset_gender['train'][0]

In [38]:
# Sentinel value appended as the last label of every gender example; it is also
# passed to GenderLossTrainer (presumably so it can recognise these rows — confirm
# in src.training_utils).
gender_ds_extra_id = -777


def format_example(example):
    """
    Convert one gender-template record into model inputs.

    Args:
        example: Mapping with a "context" string and a "pronoun_list" of tokens.

    Returns:
        Dict with:
          "input_ids": token ids of the context (truncated to 256 tokens);
          "labels": the token id of each pronoun, followed by the number of
                    input tokens, followed by ``gender_ds_extra_id``.
    """
    context_ids = tokenizer(example["context"], truncation=True, max_length=256).input_ids
    pronoun_ids = list(map(tokenizer.convert_tokens_to_ids, example["pronoun_list"]))

    return {
        "input_ids": context_ids,
        "labels": [*pronoun_ids, len(context_ids), gender_ds_extra_id],
    }

In [None]:
# Tokenize every split with format_example, dropping the original columns so only
# "input_ids" and "labels" remain.
dataset_gender_train, dataset_gender_validation, dataset_gender_test = (
    dataset_gender[split].map(format_example, remove_columns=dataset_gender[split].column_names)
    for split in ('train', 'validation', 'test')
)

In [None]:
# Display the tokenized train split.
dataset_gender_train

# Custom stories dataset

In [None]:
# Build the stories dataset from ../data/stories (with
# reduced_number_of_stories_per_profession=17) and display it.
dataset_stories = prepare_dataset_gender_stories('../data/stories', reduced_number_of_stories_per_profession=17)
dataset_stories

In [None]:
# Display the first record of the stories train split.
dataset_stories['train'][0]

In [45]:
def format_example_stories(example, max_length=256):
    """
    Turn one story record into a prompt/response training example.

    The prompt is the instruction followed by a newline. Prompt positions are set
    to -100 in ``labels`` so only the response tokens contribute to the loss.

    Args:
        example: Mapping with "instruction" and "response" strings.
        max_length: Truncation limit applied to the prompt and to the response
            separately, so the joined sequence can be up to twice this long.

    Returns:
        Dict with "input_ids" (prompt ids followed by response ids) and "labels"
        (-100 for each prompt position, then the response ids).
    """
    prompt = f"{example['instruction']}\n"
    prompt_ids = tokenizer(prompt, truncation=True, max_length=max_length).input_ids
    # The response is a continuation of the prompt, so it must not receive its own
    # special tokens: a tokenizer that prepends BOS would otherwise leave a stray
    # BOS in the middle of the sequence and turn it into a supervised target.
    response_ids = tokenizer(
        example["response"],
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
    ).input_ids

    return {
        "input_ids": prompt_ids + response_ids,
        "labels": [-100] * len(prompt_ids) + response_ids,
    }

In [None]:
# Tokenize the stories train split, dropping its original columns.
dataset_stories_train = dataset_stories['train'].map(
    format_example_stories,
    remove_columns=dataset_stories['train'].column_names,
)

# Dolly dataset

In [49]:
def load_dolly_dataset(max_length=512, print_dataset_info=False, number_of_train_samples=None):
    """
    Load the Dolly 15k train split and tokenize it into prompt/response examples.

    Tokenization uses the notebook-level ``tokenizer``. Prompt positions are set
    to -100 in ``labels`` so only the response tokens contribute to the loss.

    Args:
        max_length: Truncation limit applied to the prompt and to the response
            separately, so the joined sequence can be up to twice this long.
        print_dataset_info: When True, print the dataset and its first row both
            before and after tokenization.
        number_of_train_samples: Keep only the first N rows of the train split.
            None (or 0) keeps every row; values above the split size are clamped.

    Returns:
        A single-split dataset with "input_ids" and "labels" columns.
    """
    # Always select the train split. Without this, the default
    # number_of_train_samples=None left a DatasetDict behind, for which both
    # `dolly[0]` and `remove_columns=dolly.column_names` below fail.
    dolly = load_dataset("databricks/databricks-dolly-15k")["train"]
    if number_of_train_samples:
        # Clamp so that asking for more rows than exist does not raise.
        dolly = dolly.select(range(min(number_of_train_samples, len(dolly))))
    if print_dataset_info:
        print(dolly)
        print(dolly[0])

    def preprocess_dolly(example):
        """Tokenize one Dolly row into prompt-masked input_ids/labels."""
        instruction = example["instruction"]
        context = example.get("context", "")
        response = example["response"]

        # The context section is only included when the row actually has one.
        if context:
            prompt = f"Instruction:\n{instruction}\n\nContext:\n{context}\n\nAnswer:"
        else:
            prompt = f"Instruction:\n{instruction}\n\nAnswer:"

        prompt_ids = tokenizer(prompt, truncation=True, max_length=max_length).input_ids
        # The response continues the prompt, so it gets no special tokens of its
        # own; otherwise a BOS-prepending tokenizer would insert a stray BOS
        # between prompt and response and supervise on it.
        response_ids = tokenizer(
            response,
            truncation=True,
            max_length=max_length,
            add_special_tokens=False,
        ).input_ids

        return {
            "input_ids": prompt_ids + response_ids,
            "labels": [-100] * len(prompt_ids) + response_ids,
        }

    dolly = dolly.map(preprocess_dolly, remove_columns=dolly.column_names)
    if print_dataset_info:
        print(dolly)
        print(dolly[0])

    return dolly

In [None]:
# Tokenize the first 1000 Dolly rows, printing the dataset before and after preprocessing.
dolly_dataset = load_dolly_dataset(print_dataset_info=True, number_of_train_samples=1000)

# Train

In [51]:
def data_collator(features):
    """
    Pad a list of examples into a batch of equally wide tensors.

    ``input_ids`` and ``labels`` may have different lengths within one example
    (rows produced by ``format_example`` do), so each is padded on its own and the
    narrower tensor is then widened to match the other.

    Args:
        features: Non-empty list of dicts with "input_ids" and "labels" integer lists.

    Returns:
        Dict with:
          "input_ids": (batch, width) tensor padded with the pad token id;
          "labels": (batch, width) tensor padded with -100 (ignored positions);
          "attention_mask": (batch, width) boolean tensor, True on real input tokens.
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        # Tokenizer without a pad token: fall back to EOS, then to 0. The exact
        # value is harmless because padded positions are masked out below.
        pad_id = getattr(tokenizer, "eos_token_id", None)
        if pad_id is None:
            pad_id = 0

    input_seqs = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
    label_seqs = [torch.tensor(f["labels"], dtype=torch.long) for f in features]
    # Remember the true lengths before padding so the mask does not depend on
    # token values.
    lengths = torch.tensor([seq.shape[0] for seq in input_seqs], dtype=torch.long)

    input_ids = torch.nn.utils.rnn.pad_sequence(input_seqs, batch_first=True, padding_value=pad_id)
    labels = torch.nn.utils.rnn.pad_sequence(label_seqs, batch_first=True, padding_value=-100)

    # Bring both tensors to a common width (padding by zero columns is a no-op).
    width = max(input_ids.shape[1], labels.shape[1])
    input_ids = torch.nn.functional.pad(input_ids, (0, width - input_ids.shape[1]), value=pad_id)
    labels = torch.nn.functional.pad(labels, (0, width - labels.shape[1]), value=-100)

    # Build the mask from the lengths rather than `input_ids != pad_id`: when the
    # pad id is also a genuine token (e.g. pad == eos), comparing values would
    # mask real tokens inside the sequence.
    attention_mask = torch.arange(width).unsqueeze(0) < lengths.unsqueeze(1)

    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}

In [None]:

# Combine the three tokenized training sets and shuffle them with a fixed seed so
# the mixing order is reproducible.
merged_train = concatenate_datasets(
    [dataset_gender_train, dataset_stories_train, dolly_dataset]
).shuffle(seed=42)
merged_train

In [55]:
train_config = read_config('../configs/train_config.yaml')
_lora_cfg = train_config['lora_train_config']

# One target-module path per (layer index, projection) pair. The layer range runs
# from `start` up to but not including `end`.
target_modules = [
    f"layers.{layer_idx}.{_lora_cfg['module_to_train']}.{proj}"
    for layer_idx in range(_lora_cfg['layer_numbers']['start'], _lora_cfg['layer_numbers']['end'])
    for proj in _lora_cfg['layers_to_train']
]

# All remaining LoRA hyperparameters come straight from the config file.
lora_config = LoraConfig(
    r=_lora_cfg['r'],
    lora_alpha=_lora_cfg['lora_alpha'],
    target_modules=target_modules,
    lora_dropout=_lora_cfg['lora_dropout'],
    bias=_lora_cfg['bias'],
    task_type=_lora_cfg['task_type'],
)

In [None]:
# Display the loaded training configuration.
train_config

In [57]:
# Wrap the base model with the LoRA adapters described by lora_config.
lora_model = get_peft_model(model, lora_config)

In [None]:
def count_trainable_params(model):
    """
    Count parameter elements in ``model``.

    Returns:
        Tuple ``(all_params, trainable_params)`` where the second entry only
        counts parameters whose ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable


# Report total vs. trainable parameter counts of the LoRA-adapted model, plus the
# trainable share as a percentage rounded to 4 decimal places.
all_params, trainable_params = count_trainable_params(lora_model)
print(f"All parameters: {all_params}")
print(f"Trainable parameters: {trainable_params}")
print(f"Trainable parameters share: {round(trainable_params / all_params * 100, 4)}%")


In [None]:

# Start a Weights & Biases run in the "gender-bias-llm" project, recording train_config as its config.
wandb.init(project="gender-bias-llm", config=train_config)

In [None]:
# Training hyperparameters: log every 50 steps, evaluate every 200 steps, and
# report metrics to Weights & Biases.
training_args = TrainingArguments(
    output_dir="../test/gender_only_ckpt",
    per_device_train_batch_size=8,
    num_train_epochs=2,
    logging_steps=50,
    eval_steps=200,
    evaluation_strategy="steps",
    logging_dir="../test/logs",
    report_to="wandb",
    # gradient_accumulation_steps=8
)

trainer = GenderLossTrainer(
    model=lora_model,
    args=training_args,
    train_dataset=merged_train,
    # Fix: the tokenized validation split created earlier is
    # `dataset_gender_validation`. The previous name `dataset_validation` is never
    # defined in this notebook and raised NameError on a fresh kernel.
    eval_dataset=dataset_gender_validation,
    data_collator=data_collator,
    lambda_gender=train_config['lambda_gender'],
    gender_ds_extra_id=gender_ds_extra_id,
    p_total_power=train_config['p_total_power']
)

trainer.train()

# Testing

## COPA

In [8]:
# Load COPA from SuperGLUE. The evaluation pool is the validation and train splits
# concatenated (despite the `_test` suffix in the variable name).
dataset_copa = load_dataset("super_glue", "copa")
dataset_copa_test = concatenate_datasets([dataset_copa['validation'], dataset_copa['train']])

In [None]:
import numpy as np
def evaluate_copa(model, tokenizer, dataset):
    """
    Zero-shot COPA accuracy, choosing the alternative with the lower LM loss.

    For each sample, both choices are rendered into the same
    premise/question/choice prompt. The model's loss over the whole prompt is
    computed, and the choice with the lower loss (higher negated loss) is the
    prediction.

    Args:
        model: Causal LM exposing ``.device`` and returning an object with ``.loss``
            when called with ``labels``.
        tokenizer: Callable returning model inputs for a prompt string.
        dataset: Iterable of dicts with "premise", "question", "choice1",
            "choice2" and "label" (0 or 1).

    Returns:
        Fraction of samples predicted correctly, or 0.0 for an empty dataset.
    """
    correct = 0
    total = 0

    for sample in dataset:
        premise = sample["premise"]
        question = sample["question"]
        correct_label = sample["label"]

        scores = []
        for choice in (sample["choice1"], sample["choice2"]):
            prompt = (f"Premise: {premise}\n"
                      f"Question: What is the most likely {question}?\n"
                      f"Choice: {choice}\n")

            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
            with torch.no_grad():
                outputs = model(**inputs, labels=inputs["input_ids"])

            # Negate so that a better-fitting prompt gets the larger score.
            scores.append(-outputs.loss.item())

        # argmax returns the first index on ties, i.e. choice1.
        if int(np.argmax(scores)) == correct_label:
            correct += 1
        total += 1

    # Guard against an empty dataset, matching evaluate_lambada_next_token_accuracy.
    return correct / total if total > 0 else 0.0

# dataset_copa_test is the validation and train splits combined, so the report
# names both rather than claiming a validation-only score.
accuracy = evaluate_copa(lora_model, tokenizer, dataset_copa_test)
print(f"COPA accuracy (validation + train): {accuracy*100:.2f}%")


## PIQA

In [None]:
piqa = load_dataset("piqa")
print(piqa)

# Evaluate on a fixed 500-example sample of the validation split (seeded shuffle,
# then the first 500 rows).
val_data = piqa["validation"].shuffle(seed=42).select(range(500))

In [22]:
def compute_choice_score(model, tokenizer, prompt, choice_text):
    """
    Score how well the model fits ``prompt`` followed by ``choice_text``.

    The two strings are joined with a single space and the causal-LM loss is
    taken over the whole joined sequence (prompt tokens included, not only the
    choice). The negated loss is returned, so a higher score means a better fit.

    Args:
        model: Causal LM; inputs are moved to the device of its first parameter.
        tokenizer: Callable returning "input_ids" and "attention_mask" tensors.
        prompt: Text shared by all candidate answers.
        choice_text: Candidate answer appended to the prompt.

    Returns:
        float: negated average cross-entropy of the joined sequence.
    """
    target_device = next(model.parameters()).device

    encoded = tokenizer(" ".join((prompt, choice_text)), return_tensors="pt")
    ids = encoded["input_ids"].to(target_device)
    mask = encoded["attention_mask"].to(target_device)

    # Passing the inputs as labels makes the model report its LM loss on them.
    with torch.no_grad():
        result = model(input_ids=ids, attention_mask=mask, labels=ids)

    return -result.loss.item()

In [None]:
def evaluate_piqa(model, tokenizer, dataset):
    """
    Zero-shot PIQA accuracy using ``compute_choice_score`` on both solutions.

    Args:
        model: Causal LM forwarded to ``compute_choice_score``.
        tokenizer: Tokenizer forwarded to ``compute_choice_score``.
        dataset: Iterable of dicts with "goal", "sol1", "sol2" and "label" (0 or 1).

    Returns:
        Fraction of examples predicted correctly, or 0.0 for an empty dataset.
    """
    correct = 0
    total = 0

    for example in dataset:
        # Both solutions are scored against the same question/answer prompt.
        prompt = f"Question: {example['goal']}\nAnswer:"
        score_sol1 = compute_choice_score(model, tokenizer, prompt, example["sol1"])
        score_sol2 = compute_choice_score(model, tokenizer, prompt, example["sol2"])

        # Solution 1 wins only with a strictly higher score; ties go to solution 2.
        pred_label = 0 if score_sol1 > score_sol2 else 1

        if pred_label == example["label"]:
            correct += 1
        total += 1

    # Guard against an empty dataset, matching evaluate_lambada_next_token_accuracy.
    return correct / total if total > 0 else 0.0

# Score the LoRA-adapted model on the sampled PIQA validation examples.
accuracy_val = evaluate_piqa(lora_model, tokenizer, val_data)
print(f"PIQA validation accuracy: {accuracy_val*100:.2f}%")

## LAMBADA

In [None]:
dataset = load_dataset("lambada")
print(dataset)

# Evaluate on a fixed 500-example sample of the validation split (seeded shuffle,
# then the first 500 rows).
validation_data = dataset["validation"].shuffle(seed=42).select(range(500))

In [None]:
# Display the first sampled LAMBADA validation record.
validation_data[0]

In [27]:
def evaluate_lambada_next_token_accuracy(model, tokenizer, dataset, max_eval_samples=None):
    """
    Top-1 accuracy at predicting the final token of each text.

    Each text is stripped and tokenized; every token except the last is fed to
    the model, and the argmax of the logits at the final context position is
    compared with the held-out last token. Texts with fewer than two tokens are
    skipped.

    Args:
        model: Causal LM returning an object with ``.logits`` of shape
            (batch, seq_len, vocab_size).
        tokenizer: Object with an ``encode(text)`` method returning token ids.
        dataset: Iterable of dicts with a "text" string.
        max_eval_samples: Stop once the example at this 1-based dataset position
            has been scored; None scans the whole dataset.

    Returns:
        correct / scored, or 0.0 when no example was scored.
    """
    device = next(model.parameters()).device

    hits = 0
    evaluated = 0

    for position, record in enumerate(dataset, start=1):
        token_ids = tokenizer.encode(record["text"].strip())
        if len(token_ids) < 2:
            # No context to condition on. Skipping here also bypasses the
            # sample-limit check at the bottom of the loop.
            continue

        gold_id = token_ids[-1]
        prefix = torch.tensor([token_ids[:-1]], dtype=torch.long, device=device)

        with torch.no_grad():
            # Logits at the last context position are the next-token distribution.
            last_logits = model(prefix).logits[:, -1, :]
            predicted = torch.argmax(last_logits, dim=-1)

        hits += int(predicted.item() == gold_id)
        evaluated += 1

        if max_eval_samples is not None and position >= max_eval_samples:
            break

    return hits / evaluated if evaluated > 0 else 0.0

In [None]:
# validation_data is a 500-example sample of the validation split, so the score is
# reported as validation accuracy (it was previously mislabelled as "test").
# max_eval_samples=1000 exceeds the sample size and therefore does not limit anything.
accuracy = evaluate_lambada_next_token_accuracy(lora_model, tokenizer, validation_data, max_eval_samples=1000)
print(f"LAMBADA validation accuracy (next-token): {accuracy*100:.2f}%")

