In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import dotenv
from huggingface_hub import login
import os
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments
)
from src.training_utils import GenderLossTrainer
from src.utils import read_config
from peft import LoraConfig, get_peft_model
from src.data_utils import prepare_dataset
from datasets import load_dataset, concatenate_datasets

# Load environment variables via python-dotenv, then authenticate with the
# Hugging Face Hub using the 'huggingface_token' environment variable.
# os.getenv returns None when that variable is not set.
dotenv.load_dotenv()
login(token=os.getenv('huggingface_token'))


%load_ext autoreload
%autoreload 2

In [None]:
import wandb
# Authenticate with Weights & Biases using the WANDB_API_KEY environment variable.
wandb.login(key=os.getenv('WANDB_API_KEY'))


In [None]:
# Read the LLM configuration; its 'local_generative_model_name' entry selects the
# tokenizer and model loaded in the next cell. Printed for a quick visual check.
llm_configs = read_config('../configs/llm_config.yaml')
print(llm_configs)

In [None]:
# Load the tokenizer and causal LM named by 'local_generative_model_name'.
tokenizer = AutoTokenizer.from_pretrained(llm_configs['local_generative_model_name'])
# Some causal-LM tokenizers ship without a pad token. The batch collator pads
# with tokenizer.pad_token_id, so fall back to EOS instead of failing at
# collate time. This is a no-op when a pad token is already defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(llm_configs['local_generative_model_name'])
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Custom gender dataset

In [None]:
# Build the dataset from the dataset config and the profession-template file.
# Later cells index the result by 'train', 'validation' and 'test'.
dataset = prepare_dataset('../configs/dataset_config.yaml', '../data/short_profession_templates.txt', print_dataset_info=True)


In [None]:
# Peek at the first training example (format_example below reads its
# 'context' and 'pronoun_list' fields).
dataset['train'][0]

In [7]:
# for name, param in model.named_parameters():
#     print(name, param.size())

In [8]:
# Sentinel appended as the last label of every gender-dataset example; it is
# also handed to GenderLossTrainer so the two stay in sync.
gender_ds_extra_id = -777

def format_example(example):
    """Turn one gender-dataset row into ``input_ids`` / ``labels``.

    ``input_ids`` are the token ids of ``example["context"]`` (truncated to
    256 tokens). ``labels`` is NOT a per-token target: it is the token ids of
    the entries in ``example["pronoun_list"]``, followed by the number of
    prompt tokens, followed by ``gender_ds_extra_id``.
    """
    instruction = example["context"]
    pronoun_ids = list(map(tokenizer.convert_tokens_to_ids, example["pronoun_list"]))
    prompt_ids = tokenizer(instruction, truncation=True, max_length=256).input_ids

    # Pack the prompt length and the sentinel after the pronoun ids.
    packed_labels = pronoun_ids + [len(prompt_ids), gender_ds_extra_id]

    return {"input_ids": prompt_ids, "labels": packed_labels}

In [None]:
# Apply format_example to each split, dropping the split's original columns so
# only 'input_ids' and 'labels' remain.
dataset_train, dataset_validation, dataset_test = (
    dataset[split].map(format_example, remove_columns=dataset[split].column_names)
    for split in ('train', 'validation', 'test')
)

In [10]:
# Shuffle with a fixed seed and keep the first half of the training examples.
# NOTE: this rebinds dataset_train, so re-running the cell halves it again.
dataset_train = dataset_train.shuffle(seed=42).select(range(len(dataset_train) // 2))

In [None]:
# Display the sub-sampled training set.
dataset_train

# Dolly dataset

In [12]:
def load_dolly_dataset(max_length=512, print_dataset_info=False):
    """
    Load databricks-dolly-15k and tokenize it into ``input_ids`` / ``labels``.

    Each example becomes a prompt ("Instruction: ... [Context: ...] Answer:")
    followed by the response. Prompt positions are labelled -100, so only the
    response tokens carry labels.

    Tokenization uses the notebook-level ``tokenizer``.

    Args:
        max_length: Truncation limit applied to the prompt and to the response
            separately; the concatenated sequence can therefore be up to
            ``2 * max_length`` tokens long.
        print_dataset_info: If True, print the dataset and its first train
            example both before and after preprocessing.

    Returns:
        DatasetDict whose original columns are replaced by ``input_ids`` and
        ``labels``.
    """
    dolly = load_dataset("databricks/databricks-dolly-15k")
    if print_dataset_info:
        print(dolly)
        print(dolly["train"][0])

    def preprocess_dolly(example):
        """Build the prompt, tokenize prompt and response, and mask the prompt in labels."""
        instruction = example["instruction"]
        context = example.get("context", "")
        response = example["response"]

        if context:
            prompt = f"Instruction:\n{instruction}\n\nContext:\n{context}\n\nAnswer:"
        else:
            prompt = f"Instruction:\n{instruction}\n\nAnswer:"

        # The prompt keeps the tokenizer's default special tokens (e.g. a leading BOS).
        prompt_ids = tokenizer(prompt, truncation=True, max_length=max_length).input_ids
        # The response is a continuation of the same sequence, so it must not get
        # its own special tokens: with a BOS-prepending tokenizer the default call
        # would insert a BOS in the middle of the sequence and into the labels.
        response_ids = tokenizer(
            response,
            add_special_tokens=False,
            truncation=True,
            max_length=max_length,
        ).input_ids
        # NOTE(review): no EOS is appended after the response; confirm whether the
        # trainer or the model's generation settings rely on one.

        # Full model input: prompt followed by response.
        input_ids = prompt_ids + response_ids

        # Matching labels: ignore the prompt (-100), learn the response.
        labels = [-100] * len(prompt_ids) + response_ids

        return {
            "input_ids": input_ids,
            "labels": labels
        }

    dolly = dolly.map(preprocess_dolly, remove_columns=dolly["train"].column_names)
    if print_dataset_info:
        print(dolly)
        print(dolly["train"][0])

    return dolly

In [None]:
# Load and tokenize Dolly, printing the dataset before and after preprocessing.
dolly_dataset = load_dolly_dataset(print_dataset_info=True)

# Train

In [14]:
def data_collator(features, pad_token_id=None):
    """Collate ``input_ids`` / ``labels`` features into right-padded batch tensors.

    ``input_ids`` are padded with the pad token id and ``labels`` with -100.
    Both are then padded to a common width, because ``labels`` need not have
    the same length as ``input_ids``.

    The attention mask is derived from each example's real length rather than
    from ``input_ids != pad_token_id``; the latter would also mask genuine
    tokens whose id happens to equal the pad id (e.g. when pad == EOS).

    Args:
        features: List of dicts, each with ``"input_ids"`` and ``"labels"``
            integer sequences.
        pad_token_id: Padding id for ``input_ids``. Defaults to the
            notebook-level ``tokenizer.pad_token_id``.

    Returns:
        Dict with ``input_ids``, ``labels`` and a boolean ``attention_mask``,
        all of shape ``(batch, width)``.

    Raises:
        ValueError: If no pad token id is available.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        raise ValueError(
            "No pad token id available: set tokenizer.pad_token "
            "(e.g. to tokenizer.eos_token) or pass pad_token_id explicitly."
        )

    input_ids = [torch.tensor(f["input_ids"], dtype=torch.long) for f in features]
    labels = [torch.tensor(f["labels"], dtype=torch.long) for f in features]
    # Remember the true lengths before padding; they define the attention mask.
    lengths = torch.tensor([ids.shape[0] for ids in input_ids], dtype=torch.long)

    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=pad_token_id)
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)  # -100 for ignored tokens

    # Bring both tensors to the same width (a zero-width pad is a no-op).
    width = max(input_ids.shape[1], labels.shape[1])
    input_ids = torch.nn.functional.pad(input_ids, (0, width - input_ids.shape[1]), value=pad_token_id)
    labels = torch.nn.functional.pad(labels, (0, width - labels.shape[1]), value=-100)

    attention_mask = torch.arange(width).unsqueeze(0) < lengths.unsqueeze(1)
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}

In [None]:

# Dolly is currently left out of the training mix; to include it, use:
# merged_train = concatenate_datasets([dataset_train, dolly_dataset['train']])

# Train on the gender dataset only, reshuffled with a fixed seed.
merged_train = dataset_train.shuffle(seed=42)
merged_train

In [16]:
train_config = read_config('../configs/train_config.yaml')
lora_settings = train_config['lora_train_config']

# One target-module path per (layer index, projection) pair, for layer indices
# in [start, end) as given by the config.
target_modules = [
    f"layers.{layer_idx}.{lora_settings['module_to_train']}.{proj}"
    for layer_idx in range(
        lora_settings['layer_numbers']['start'],
        lora_settings['layer_numbers']['end'],
    )
    for proj in lora_settings['layers_to_train']
]

# All LoRA hyper-parameters come straight from the 'lora_train_config' section.
lora_config = LoraConfig(
    task_type=lora_settings['task_type'],
    target_modules=target_modules,
    r=lora_settings['r'],
    lora_alpha=lora_settings['lora_alpha'],
    lora_dropout=lora_settings['lora_dropout'],
    bias=lora_settings['bias'],
)

In [17]:
# Wrap the base model according to lora_config.
lora_model = get_peft_model(model, lora_config)

In [None]:
def count_trainable_params(model):
    """Return ``(all_params, trainable_params)`` element counts for ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable


# Report total vs. trainable parameter counts of the LoRA-adapted model.
all_params, trainable_params = count_trainable_params(lora_model)
print(
    f"All parameters: {all_params}",
    f"Trainable parameters: {trainable_params}",
    f"Trainable parameters share: {round(trainable_params / all_params * 100, 4)}%",
    sep="\n",
)


In [None]:

# Start a W&B run in the "gender-bias-llm" project, passing train_config as the run config.
wandb.init(project="gender-bias-llm", config=train_config)

In [None]:
training_args = TrainingArguments(
    # Where checkpoints and logs go, and which tracker receives metrics.
    output_dir="../test/gender_only_ckpt",
    logging_dir="../test/logs",
    report_to="wandb",
    # Training schedule.
    num_train_epochs=2,
    per_device_train_batch_size=4,
    # gradient_accumulation_steps=8
    # Logging / evaluation cadence, in optimizer steps.
    logging_steps=50,
    evaluation_strategy="steps",
    eval_steps=200,
)

# Project-specific trainer; it additionally receives the gender-loss weight
# from the train config and the sentinel id used in the gender-dataset labels.
trainer = GenderLossTrainer(
    model=lora_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=merged_train,
    eval_dataset=dataset_validation,
    gender_ds_extra_id=gender_ds_extra_id,
    lambda_gender=train_config['lambda_gender'],
)

trainer.train()

In [21]:
# Load SuperGLUE COPA and pool its validation and train splits (in that order)
# into a single evaluation set.
dataset_copa = load_dataset("super_glue", "copa")
dataset_copa_test = concatenate_datasets(
    [dataset_copa[split] for split in ("validation", "train")]
)

In [None]:
import numpy as np
def evaluate_copa(model, tokenizer, dataset):
    """Compute COPA accuracy by comparing the LM loss of the two choices.

    For every sample a prompt is built for each choice, and the model's
    language-modelling loss over that whole prompt is used as a score (lower
    loss wins; ties go to the first choice). The prediction is correct when
    the winning index equals ``sample["label"]``.

    The model is switched to eval mode for the duration of the call so that
    dropout (including LoRA dropout) cannot perturb the scores, e.g. when this
    runs right after training. The previous mode is restored afterwards.

    Args:
        model: Causal LM exposing ``device`` and returning an object with a
            ``loss`` attribute when called with ``labels``.
        tokenizer: Tokenizer called as ``tokenizer(prompt, return_tensors="pt")``.
        dataset: Iterable of dicts with ``premise``, ``question``, ``choice1``,
            ``choice2`` and ``label`` keys.

    Returns:
        Fraction of correctly predicted samples, as a float in [0, 1].

    Raises:
        ValueError: If ``dataset`` yields no samples.
    """
    was_training = getattr(model, "training", False)
    model.eval()

    correct = 0
    total = 0
    try:
        for sample in dataset:
            premise = sample["premise"]
            question = sample["question"]
            correct_label = sample["label"]

            scores = []
            for choice in (sample["choice1"], sample["choice2"]):
                prompt = (f"Premise: {premise}\n"
                          f"Question: What is the most likely {question}?\n"
                          f"Choice: {choice}\n")

                inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
                with torch.no_grad():
                    outputs = model(**inputs, labels=inputs["input_ids"])
                    loss = outputs.loss.item()

                # Higher score == lower loss == more likely continuation.
                scores.append(-loss)

            pred_label = int(np.argmax(scores))
            if pred_label == correct_label:
                correct += 1
            total += 1
    finally:
        # Leave the model in the mode the caller had it in.
        if was_training:
            model.train()

    if total == 0:
        raise ValueError("evaluate_copa received an empty dataset")

    return correct / total

# dataset_copa_test pools COPA's validation and train splits, so the metric is
# labelled accordingly rather than as "validation" accuracy.
accuracy = evaluate_copa(lora_model, tokenizer, dataset_copa_test)
print(f"COPA (validation + train) accuracy: {accuracy*100:.2f}%")
