In [None]:
#======GPU assign=====#

import os

# These variables are read when CUDA is initialised, so they must be set before torch touches the GPU.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"      # order GPUs by PCI bus id
os.environ["CUDA_VISIBLE_DEVICES"] = '0'            # expose only physical GPU 0 to this process
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"   # presumably for deterministic cuBLAS (reproducibility)

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
# current_device() raises when CUDA is unavailable, so only query it on a GPU machine;
# device_count() safely returns 0 on CPU.
if torch.cuda.is_available():
    print('Current cuda device: ', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

In [None]:
#====Packages====#

from datasets import load_dataset
from transformers import (
    AutoModel,
    AutoModelForCausalLM, #(Automatically loads a model for causal language modeling)
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
    TrainingArguments,
    Trainer,
    pipeline,
    logging,
    TrainerCallback
    
)
from peft import LoraConfig, PeftModel, get_peft_model, PeftConfig
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import os
import csv
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch


In [None]:
#======Check compatibility for precision training=======#

# Report the GPU and decide whether fp16 mixed precision can be enabled.
# Rule of thumb from the original note: set fp16=True in TrainingArguments when the
# major compute capability is 7 or higher (7.x qualifies, not only values above 7).
# The CUDA queries raise on a CPU-only machine, so they are guarded.
fp16_supported = False
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    compute_capability = torch.cuda.get_device_capability(0)
    print(f"GPU Name: {gpu_name}")
    print(f"Compute Capability: {compute_capability}")
    fp16_supported = compute_capability[0] >= 7
print(f"fp16 supported: {fp16_supported}")

In [None]:
#======Model info======#

model_name = "roberta-base"  # RoBERTa for MNLI task
task_name = "mnli"
num_labels = 3  # MNLI labels: entailment (0), neutral (1), contradiction (2)

# Root directory for Trainer checkpoints. Set ROBERTA_LORA_OUTPUT_DIR to relocate it
# instead of editing a user-specific absolute path.
output_dir = os.environ.get("ROBERTA_LORA_OUTPUT_DIR", "/home/himani/Roberta_LoRA_mnli")
# Saved adapter of an earlier run, used later for weight analysis and evaluation.
fine_tuned_model_path = os.path.join(output_dir, "Experiment_3_epoch_3")
plot_file_name = "Experiment_4_similarity_matrices"  # sub-directory (of the metrics log dir) for plots
csv_name = "Experiment_4_intruders.csv"              # intruder-count CSV (in the metrics log dir)

In [None]:
#=======Load dataset=======#

# Load MNLI dataset
dataset = load_dataset("glue", "mnli")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Limit training to a fixed, seeded 50k-example subset for faster experiments.
train_dataset = dataset["train"].shuffle(seed=42).select(range(50000))
# Note: this split is also reported later as the in-distribution benchmark.
validation_dataset = dataset["validation_matched"]


def preprocess_function(examples):
    """Tokenize a batch of MNLI premise/hypothesis pairs, truncated/padded to 128 tokens."""
    return tokenizer(
        examples["premise"],     # MNLI premise
        examples["hypothesis"],  # MNLI hypothesis
        truncation=True,
        padding="max_length",
        max_length=128,
    )


def _prepare_for_trainer(split):
    """Tokenize a split, drop raw text/id columns, rename `label` -> `labels`, and return torch tensors."""
    split = split.map(preprocess_function, batched=True)
    split = split.remove_columns(["idx", "premise", "hypothesis"])  # unused by the model
    split = split.rename_column("label", "labels")                   # name the model's loss expects
    # Keep tensors on the CPU: the Trainer moves each batch to the model's device itself,
    # and its DataLoader pins memory by default, which only works for CPU tensors.
    split.set_format("torch")
    return split


train_dataset = _prepare_for_trainer(train_dataset)
validation_dataset = _prepare_for_trainer(validation_dataset)


In [None]:
#======Load LLM======#

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,  # MNLI: entailment, neutral, contradiction
)

# Tokenizer for RoBERTa. It already has its own pad token (the datasets above were
# padded with it), so it is NOT replaced by eos here: that GPT-style workaround would
# make the saved tokenizer disagree with how the training data was padded.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"

# Configure LoRA for sequence classification.
peft_config = LoraConfig(
    r=1,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    # Explicit so that it matches the weight analysis below, which only inspects q/v.
    target_modules=["query", "value"],
    task_type="SEQ_CLS",  # sequence classification
)

# PEFT applies delta_W = (lora_alpha / r) * lora_B @ lora_A. Derive the factor from the
# config so it cannot drift from it (the previous hard-coded 1/256 did not match; this gives 16).
scaling_factor = peft_config.lora_alpha / peft_config.r

model = get_peft_model(model, peft_config)
model = model.to(device)  # `device` from the GPU cell: CUDA when available, else CPU

print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")


In [None]:
#=======Set training parameters=======#

training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    optim="adamw_hf",
    save_strategy="epoch",
    # NOTE(review): newer transformers releases name this `eval_strategy`; switch if rejected.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    fp16=False,
    max_grad_norm=1.0,  # 0.3 if needed
    max_steps=-1,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
    logging_dir="./logs",
    load_best_model_at_end=True,
    save_total_limit=None
)


class CustomTrainer(Trainer):
    """Trainer that moves every tensor in a batch to the training device before each step."""

    def training_step(self, model, inputs, *args, **kwargs):
        """Move tensor inputs to ``self.args.device``, then defer to ``Trainer.training_step``.

        Extra positional/keyword arguments (newer transformers versions pass e.g.
        ``num_items_in_batch``) are forwarded unchanged, so the override works across
        versions. Non-tensor values are passed through untouched.
        """
        target = self.args.device
        inputs = {
            key: value.to(target) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }
        return super().training_step(model, inputs, *args, **kwargs)



In [None]:
# ======Start Training======#

# Everything the trainer needs, gathered in one place.
trainer_kwargs = dict(
    model=model,                      # PEFT (LoRA) wrapped model to be fine-tuned
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,  # evaluated at the end of every epoch
    tokenizer=tokenizer,
)
trainer = CustomTrainer(**trainer_kwargs)

trainer.train()

In [None]:
#======Extract lora_A and lora_B for delta weight======#

def extract_delta_weights(base_model_name, model_path, scaling_factor=None, num_labels=3, device=None):
    """Compute the LoRA update ``W_delta = scaling_factor * lora_B @ lora_A`` for every adapted layer.

    The adapter checkpoint at ``model_path`` only holds the adapter weights, so the base
    model ``base_model_name`` is loaded first and the adapter is attached on top of it.

    Args:
        base_model_name: Hugging Face id of the base model the adapter was trained on.
        model_path: directory of the saved (fine-tuned) PEFT adapter.
        scaling_factor: multiplier for ``lora_B @ lora_A``. When None it is read from the
            adapter's own config as ``lora_alpha / r`` (``lora_alpha / sqrt(r)`` if rsLoRA
            is enabled), i.e. the factor PEFT uses when merging the adapter.
        num_labels: size of the classification head; must match the adapter.
        device: device for the returned tensors; defaults to CUDA when available, else CPU.

    Returns:
        list of delta-weight tensors, one per ``lora_A`` parameter, in
        ``model.named_parameters()`` order. Layers without a matching ``lora_B`` are
        reported and skipped.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    base_model = AutoModelForSequenceClassification.from_pretrained(base_model_name, num_labels=num_labels)
    model = PeftModel.from_pretrained(base_model, model_path)

    if scaling_factor is None:
        lora_config = model.peft_config["default"]
        use_rslora = getattr(lora_config, "use_rslora", False)
        denominator = lora_config.r ** 0.5 if use_rslora else lora_config.r
        scaling_factor = lora_config.lora_alpha / denominator

    # Build the state dict once; rebuilding it for every layer walks the whole model each time.
    state = model.state_dict()

    delta_weights = []
    for name, param in model.named_parameters():
        if "lora_A" not in name:
            continue
        lora_b_name = name.replace(".lora_A.default.weight", "") + ".lora_B.default.weight"
        if lora_b_name not in state:
            print(f"Missing corresponding LoRA B weight for {name}")
            continue
        lora_a = param.detach().to(device)
        lora_b = state[lora_b_name].detach().to(device)
        delta_weights.append(scaling_factor * (lora_b @ lora_a))

    print(f"Extracted delta weights for {len(delta_weights)} layers.")
    return delta_weights


# scaling_factor=None reads lora_alpha / r from the saved adapter, so the factor always
# matches the checkpoint being analysed rather than whatever config is in this session.
delta_weights = extract_delta_weights(model_name, fine_tuned_model_path, scaling_factor=None)


In [None]:
#========Extract weight from the basemodel i.e. W0========#

def extract_weights_from_base_model(model_name):
    """Collect the pre-trained query and value projection weights (W0) of every layer.

    Only q and v are extracted, because those are the projections LoRA adapts here.

    Returns:
        (weights, names): parallel lists of detached weight tensors and their parameter
        names; the names are later used for plot titles.
    """
    wanted = ("attention.self.query.weight", "attention.self.value.weight")
    pretrained = AutoModel.from_pretrained(model_name)

    weights, names = [], []
    for param_name, tensor in pretrained.named_parameters():
        if not any(fragment in param_name for fragment in wanted):
            continue
        weights.append(tensor.clone().detach())
        names.append(param_name)

    return weights, names


base_weights, layer_names = extract_weights_from_base_model(model_name)



In [None]:
#======Check if we have extracted tensor weights======#

# print(f"Base weights type: {[type(weight) for weight in base_weights]}")
# print(f"Type of each extracted weight: {[type(w) for w in delta_weights]}")

In [None]:
#=======Metric Functions=======#

def cosine_similarity_matrix(W1, W2):
    """Pairwise cosine similarities between the left singular vectors of two matrices.

    Entry ``[i, j]`` is the dot product of the i-th left singular vector of ``W1`` with
    the j-th left singular vector of ``W2``. Singular vectors have unit length, so these
    dot products are cosine similarities. Their sign is arbitrary (SVD fixes each vector
    only up to a factor of -1), so callers comparing directions should take ``abs``.

    Args:
        W1, W2: 2-D tensors with the same number of rows.

    Returns:
        Tensor of shape ``(min(W1.shape), min(W2.shape))``.
    """
    # torch.svd is deprecated; torch.linalg.svd(full_matrices=False) gives the same reduced U.
    U1 = torch.linalg.svd(W1, full_matrices=False).U
    U2 = torch.linalg.svd(W2, full_matrices=False).U
    return U1.T @ U2


def count_intruder_dimensions(similarity_matrix, epsilon=0.6):
    """Count singular vectors that have no close counterpart in the other basis.

    For each row of ``similarity_matrix``, the best match is the largest *absolute*
    cosine similarity across that row. The row counts as an intruder when its best match
    is below ``epsilon``. Absolute values are required because singular vectors are only
    defined up to sign: a perfectly matching but sign-flipped vector has similarity -1
    and must not be counted as an intruder.

    NOTE(review): with ``cosine_similarity_matrix(W0, W_tuned)`` the rows are W0's
    singular vectors. Intruder dimensions are usually defined over W_tuned's vectors
    (the columns), so confirm the intended axis.

    Args:
        similarity_matrix: 2-D ``torch.Tensor`` or ``np.ndarray`` of cosine similarities.
        epsilon: threshold below which a row's best match makes it an intruder.

    Returns:
        Number of intruder rows (a NumPy integer).

    Raises:
        ValueError: if ``similarity_matrix`` is neither a tensor nor an ndarray.
    """
    if isinstance(similarity_matrix, torch.Tensor):
        similarity_matrix = similarity_matrix.detach().cpu().numpy()
    elif not isinstance(similarity_matrix, np.ndarray):
        raise ValueError("similarity_matrix must be a torch.Tensor or a NumPy array.")

    best_match = np.abs(similarity_matrix).max(axis=1)
    return np.sum(best_match < epsilon)


def plot_max_similarity_checkerboard(W1, W2, save_path=None, epoch=None, layer=None, top_k=50):
    """Plot each leading singular vector's best match between two weight matrices.

    Rows are the left singular vectors of ``W1`` (labelled W_0) and columns those of
    ``W2`` (labelled W_tuned), both in descending singular-value order and cut to the
    first ``top_k``. In every row only the entry with the largest |cosine similarity| is
    kept (drawn as its magnitude); all other cells are zero.

    Args:
        W1, W2: 2-D tensors with the same number of rows. The SVDs run on ``W1``'s
            device (``W2`` is moved there), so both CPU and CUDA tensors work.
        save_path: file to write the figure to; when falsy the figure is shown instead.
        epoch, layer: only used in the plot title.
        top_k: number of leading singular vectors to compare.
    """
    W2 = W2.to(W1.device)

    # torch.linalg.svd returns singular values in descending order, so the leading
    # columns of U are already the top singular vectors and no extra sort is needed.
    U1 = torch.linalg.svd(W1, full_matrices=False).U
    U2 = torch.linalg.svd(W2, full_matrices=False).U
    similarity_matrix = (U1.T @ U2)[:top_k, :top_k]

    # Singular vectors are only defined up to sign, so pick each row's best match by magnitude.
    rows = torch.arange(similarity_matrix.size(0), device=similarity_matrix.device)
    best_cols = similarity_matrix.abs().argmax(dim=1)
    max_similarity_matrix = torch.zeros_like(similarity_matrix)
    max_similarity_matrix[rows, best_cols] = similarity_matrix[rows, best_cols].abs()
    max_similarity_matrix_np = max_similarity_matrix.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 6))
    image = ax.imshow(max_similarity_matrix_np, cmap="Blues", interpolation="nearest")
    fig.colorbar(image, ax=ax, label="|Cosine Similarity|")
    ax.set_xlabel("Singular Vectors in $W_{tuned}$")
    ax.set_ylabel("Singular Vectors in $W_{0}$")
    ax.set_title(f"Similarity Matrix: Epoch{epoch}_{layer}")

    # Save or show the plot, then release the figure (many figures are made in a loop).
    if save_path:
        fig.savefig(save_path)
        print(f"Saved similarity plot to: {save_path}")
    else:
        plt.show()
    plt.close(fig)


In [None]:
#======Metrics Class=======#

class MetricsEvaluator:
    """Compare pre-trained weights W0 with fine-tuned weights W0 + delta_W, layer by layer.

    For every layer it counts intruder dimensions, appends the count to a CSV log, and
    saves a similarity-matrix plot.
    """

    def __init__(self, base_weights, delta_weights, layer_names, plot_file_name, csv_name,
                 log_dir="./logs", device=None):
        """
        Args:
            base_weights: pre-trained weight tensors (W0), one per layer.
            delta_weights: LoRA updates, aligned index-by-index with ``base_weights``.
            layer_names: parameter names aligned with the weights; the "encoder." prefix
                is stripped to give shorter names for plot titles and the CSV.
            plot_file_name: sub-directory of ``log_dir`` that receives the plots.
            csv_name: file name (inside ``log_dir``) of the intruder-count CSV.
            log_dir: root directory for all outputs; created if missing.
            device: where the weights (and therefore the SVDs) live; defaults to CUDA
                when available, else CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.base_weights = [weight.to(device) for weight in base_weights]
        self.delta_weights = [weight.to(device) for weight in delta_weights]
        self.layer_names = [name.replace("encoder.", "") for name in layer_names]
        self.log_dir = log_dir

        self.similarity_matrices_dir = os.path.join(log_dir, plot_file_name)
        self.log_file = os.path.join(log_dir, csv_name)

        os.makedirs(self.similarity_matrices_dir, exist_ok=True)
        # Write the header only for a new file. Existing logs are appended to, so
        # re-running the same epoch adds duplicate rows.
        if not os.path.exists(self.log_file):
            with open(self.log_file, "w", newline="") as f:
                csv.writer(f).writerow(["epoch", "layer", "intruder_count"])

    def calculate_similarity_matrix(self, base_weight, delta_weight):
        """Similarity matrix between the singular vectors of W0 and of W0 + delta_W."""
        fine_tuned_weight = base_weight + delta_weight
        return cosine_similarity_matrix(base_weight, fine_tuned_weight)

    def count_intruders(self, similarity_matrix, epsilon=0.6):
        """Number of intruder dimensions in ``similarity_matrix`` at threshold ``epsilon``."""
        return count_intruder_dimensions(similarity_matrix, epsilon)

    def plot_similarity(self, base_weight, delta_weight, epoch, layer_name):
        """Save the similarity plot for one layer as ``epoch_<epoch>_<layer_name>.png``."""
        plot_file = os.path.join(self.similarity_matrices_dir, f"epoch_{epoch}_{layer_name}.png")
        fine_tuned_weight = base_weight + delta_weight

        plot_max_similarity_checkerboard(
            base_weight,
            fine_tuned_weight,
            save_path=plot_file,
            epoch=epoch,
            layer=layer_name,
        )
        print(f"Saved similarity plot for {layer_name} at epoch {epoch} to {plot_file}")

    def evaluate(self, epoch, epsilon=0.6):
        """Process every layer: log its intruder count to the CSV and save its plot.

        Args:
            epoch: epoch of the saved checkpoint being analysed; written to the CSV.
            epsilon: intruder threshold passed to ``count_intruders``.

        A failure in one layer is reported and the remaining layers are still processed.
        """
        assert len(self.base_weights) == len(self.delta_weights) == len(self.layer_names), (
            "Base weights, delta weights, and layer names must all have the same length!"
        )

        with open(self.log_file, "a", newline="") as log:
            writer = csv.writer(log)
            for layer_name, base_weight, delta_weight in zip(
                self.layer_names, self.base_weights, self.delta_weights
            ):
                try:
                    similarity_matrix = self.calculate_similarity_matrix(base_weight, delta_weight)
                    intruder_count = self.count_intruders(similarity_matrix, epsilon)

                    writer.writerow([epoch, layer_name, intruder_count])
                    log.flush()  # keep the CSV current even if a later layer takes the kernel down

                    self.plot_similarity(base_weight, delta_weight, epoch, layer_name)
                    print(f"Processed {layer_name}: Intruder Count = {intruder_count}")
                except Exception as e:
                    # Best effort: report the failing layer and continue with the rest.
                    print(f"Error processing {layer_name}: {e}")


In [None]:
#======Get the metrics======#

# Compare W0 with W0 + delta_W for every LoRA-adapted layer and log/plot the results.
metrics_evaluator = MetricsEvaluator(
    base_weights,
    delta_weights,
    layer_names,
    plot_file_name=plot_file_name,
    csv_name=csv_name,
    log_dir="./logs",
)
metrics_evaluator.evaluate(epoch=3)  # epoch of the saved checkpoint being analysed


In [None]:
#========Load benchmarks=======#

# (benchmark name, dataset path, config name, split). MNLI-matched is the
# in-distribution benchmark; the others are treated as out-of-distribution.
_benchmark_specs = [
    ("mnli_matched", "glue", "mnli", "validation_matched"),
    ("mnli_mismatched", "glue", "mnli", "validation_mismatched"),
    ("snli", "snli", None, "test"),
    ("hans", "hans", None, "validation"),
]
benchmarks = {
    bench_name: load_dataset(path, config, split=split)
    for bench_name, path, config, split in _benchmark_specs
}

# ======== Preprocessing Function ======== #
def preprocess_eval_function(examples, dataset_name, tokenizer):
    """Tokenize one batch of a benchmark dataset.

    NLI benchmarks (both MNLI splits, SNLI, HANS) are encoded as premise/hypothesis
    pairs; any other dataset is assumed to have a single ``text`` column. Every
    sequence is truncated/padded to exactly 128 tokens.
    """
    encode_kwargs = {"truncation": True, "padding": "max_length", "max_length": 128}
    is_sentence_pair = dataset_name in ("mnli_matched", "mnli_mismatched", "snli", "hans")
    if not is_sentence_pair:
        return tokenizer(examples["text"], **encode_kwargs)
    return tokenizer(examples["premise"], examples["hypothesis"], **encode_kwargs)


# ======== Evaluation Function ======== #

def _drop_unlabelled(dataset):
    """Remove rows whose gold label is negative (e.g. SNLI's -1 "no gold label" rows)."""
    if "label" not in dataset.column_names:
        return dataset
    return dataset.filter(lambda batch: [label >= 0 for label in batch["label"]], batched=True)


def _to_benchmark_labels(dataset_name, predictions):
    """Map 3-way MNLI predictions (numpy ints) onto the label space of ``dataset_name``.

    HANS is binary (0 = entailment, 1 = non-entailment), so the model's neutral (1) and
    contradiction (2) predictions both collapse to 1. Other benchmarks share MNLI's
    3-way labels and are returned unchanged.
    """
    if dataset_name == "hans":
        return (predictions != 0).astype(np.int64)
    return predictions


def _predict(model, data_loader):
    """Run ``model`` over ``data_loader``; return (gold labels, argmax predictions) as numpy arrays."""
    labels, predictions = [], []
    with torch.no_grad():
        for batch in data_loader:
            logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
            predictions.append(logits.argmax(dim=-1).cpu().numpy())
            labels.append(batch["labels"].cpu().numpy())
    if not labels:
        empty = np.array([], dtype=np.int64)
        return empty, empty
    return np.concatenate(labels), np.concatenate(predictions)


def evaluate_model_from_path(model_path, benchmarks, batch_size=32):
    """Evaluate a saved LoRA adapter on several NLI benchmarks.

    Loads the adapter at ``model_path`` on top of the base model recorded in its config,
    computes accuracy on every benchmark, and averages the scores separately for the
    in-distribution (MNLI-matched) and out-of-distribution (MNLI-mismatched, SNLI, HANS)
    groups. Benchmarks with other names are evaluated but excluded from both averages.
    Rows without a gold label are dropped. HANS predictions are collapsed to its binary
    entailment / non-entailment labels.

    Args:
        model_path: directory of the saved PEFT adapter.
        benchmarks: mapping of benchmark name -> Hugging Face ``Dataset``.
        batch_size: evaluation batch size.

    Returns:
        dict with ``avg_in_distribution_accuracy`` and ``avg_out_of_distribution_accuracy``
        (0.0 for a group with no benchmarks).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load LoRA adapter configuration, the base model it was trained on, and the tokenizer.
    adapter_config = PeftConfig.from_pretrained(model_path)
    base_model = AutoModelForSequenceClassification.from_pretrained(
        adapter_config.base_model_name_or_path,
        num_labels=3,  # the adapter's classification head is 3-way (MNLI)
    )
    model = PeftModel.from_pretrained(base_model, model_path).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(adapter_config.base_model_name_or_path)

    in_distribution = ["mnli_matched"]
    out_of_distribution = ["mnli_mismatched", "snli", "hans"]
    in_accs, out_accs = [], []

    def collate(examples):
        # Stack per-example tensors into a batch on the evaluation device.
        return {
            key: torch.stack([example[key] for example in examples]).to(device)
            for key in ("input_ids", "attention_mask", "labels")
        }

    for dataset_name, dataset in benchmarks.items():
        print(f"Evaluating on {dataset_name}...")

        labelled = _drop_unlabelled(dataset)
        tokenized_dataset = labelled.map(
            preprocess_eval_function,
            batched=True,
            fn_kwargs={"dataset_name": dataset_name, "tokenizer": tokenizer},
            remove_columns=[col for col in labelled.column_names if col != "label"],
        )
        if "label" in tokenized_dataset.column_names:
            tokenized_dataset = tokenized_dataset.rename_column("label", "labels")
        tokenized_dataset.set_format("torch")

        data_loader = torch.utils.data.DataLoader(
            tokenized_dataset, batch_size=batch_size, collate_fn=collate
        )
        labels, predictions = _predict(model, data_loader)
        acc = accuracy_score(labels, _to_benchmark_labels(dataset_name, predictions))
        print(f"{dataset_name} accuracy: {acc:.4f}")

        if dataset_name in in_distribution:
            in_accs.append(acc)
        elif dataset_name in out_of_distribution:
            out_accs.append(acc)

    avg_in_acc = sum(in_accs) / len(in_accs) if in_accs else 0.0
    avg_out_acc = sum(out_accs) / len(out_accs) if out_accs else 0.0

    print(f"Average In-Distribution Accuracy: {avg_in_acc:.4f}")
    print(f"Average Out-of-Distribution Accuracy: {avg_out_acc:.4f}")

    return {
        "avg_in_distribution_accuracy": avg_in_acc,
        "avg_out_of_distribution_accuracy": avg_out_acc
    }


In [None]:
eval_results = evaluate_model_from_path(
    model_path=fine_tuned_model_path,
    benchmarks=benchmarks,
    batch_size=32,
)
eval_results  # display the averaged accuracies as the cell output