In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
ls

[0m[01;34mdrive[0m/  [01;34msample_data[0m/


In [3]:
%cd drive/MyDrive/DL_project/llm_finetuning/notebooks/

/content/drive/MyDrive/DL_project/llm_finetuning/notebooks


In [4]:
!pip install -q transformers accelerate bitsandbytes datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.6/297.6 kB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.8/119.8 MB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m49.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m388.9/388.9 kB[0m [31m39.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [8]:
import torch
import numpy as np
from torch.nn.functional import kl_div, softmax, log_softmax
from torch.optim import AdamW
from datasets import load_dataset, ClassLabel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from tqdm import tqdm


torch.cuda.empty_cache()

# for reproducibility
np.random.seed(42)

torch.manual_seed(42)

if torch.cuda.is_available():
  torch.cuda.manual_seed_all(42)


data = load_dataset("glue", "mnli")
hans_data = load_dataset("hans")
#print(hans_data.keys())

#Below method is refenced from: https://github.com/uds-lsv/llmft/blob/main/notebooks/majority_baseline.ipynb
def binarize_mnli(dataset, remove_neutral=True):
    if remove_neutral:
        # neutral class has label 1
        dataset = dataset.filter(lambda example: example["label"] != 1)

    # change labels of contradiction examples from 2 to 1
    def change_label(example):
        # convert labels 2 into labels 1. this merges the neutral and contradiction class
        example["label"] = 1 if example["label"] == 2 else example["label"]
        return example

    # change labels
    dataset = dataset.map(change_label)

    # change features to reflect the new labels
    features = dataset["train"].features.copy()
    features["label"] = ClassLabel(num_classes=2, names=['entailment', 'contradiction'], id=None)
    dataset = dataset.cast(features)  # overwrite old features

    return dataset

data = binarize_mnli(data, remove_neutral=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the original model
original_model = AutoModelForSequenceClassification.from_pretrained("facebook/opt-125m")
original_model.to(device)  # Move the model to the device (gpu if available)

# Load the student model
model = AutoModelForSequenceClassification.from_pretrained("facebook/opt-125m")
model.to(device)  # Move the model to the device (gpu if available)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Define the fixed context
fixed_context = "Given the premise, does the hypothesis hold true? "

##Prepare the inputs with the fixed context
def manipulate_inputs(batch):
    encoding = tokenizer([f'{fixed_context} Premise: {premise} Hypothesis: {hypothesis}'
                          for premise, hypothesis in zip(batch["premise"], batch["hypothesis"])],
                          truncation=True, padding="max_length", max_length=128, return_tensors='pt')
    batch["input_ids"] = encoding["input_ids"].squeeze()
    batch["attention_mask"] = encoding["attention_mask"].squeeze()
    return batch

data = data.map(manipulate_inputs, batched=True)
hans_data = hans_data.map(manipulate_inputs, batched=True)

# print("************")
# print((data["train"][0]))

# Define a custom loss function
def custom_loss(model_probs, original_model_probs):
    return kl_div(model_probs.log_softmax(dim=-1), original_model_probs, reduction='batchmean')

# Define the optimizer
optimizer = AdamW(model.parameters(), lr=1e-5)

#training and eval datasets
train_dataset = data["train"]
in_domain_eval_dataset = data["validation_matched"]
out_of_domain_eval_dataset = hans_data["validation"]



# print(len(train_dataset))

from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    premises = [item['premise'] for item in batch]
    hypotheses = [item['hypothesis'] for item in batch]
    labels = torch.tensor([item['label'] for item in batch])
    input_ids = pad_sequence([torch.tensor(item['input_ids']) for item in batch], batch_first=True)
    attention_mask = pad_sequence([torch.tensor(item['attention_mask']) for item in batch], batch_first=True)

    return {'premise': premises, 'hypothesis': hypotheses, 'label': labels, 'input_ids': input_ids, 'attention_mask': attention_mask}

train_dataloader = DataLoader(train_dataset, batch_size=32, collate_fn=collate_fn)
# Define the dataloaders for evaluation
in_domain_dataloader = DataLoader(in_domain_eval_dataset, batch_size=32, collate_fn=collate_fn)
out_of_domain_dataloader = DataLoader(out_of_domain_eval_dataset, batch_size=32, collate_fn=collate_fn)


# print(type(train_dataloader))
# # Convert the DataLoader to an iterator
# train_iter = iter(train_dataloader)

# # Get the first element
# first_element = next(train_iter)

# # Print the first element
# print(first_element)

# print("**************************")
# exit(1)


# Define the original task loss function
task_loss = CrossEntropyLoss()

# Initialize the lists to store the accuracies
in_domain_accuracies = []
out_of_domain_accuracies = []

# Custom training loop

for epoch in range(3):  # num_train_epochs
  model.train()
  batch_count = 0
  for batch in tqdm(train_dataloader, desc="Training"):
    #print("training batch num: ", batch_count, "of epoch: ", epoch)
    # Move the batch tensors to the same device as the model
    #print(batch["input_ids"].shape, batch["label"].shape)
    # print("hhhhhhhhhhhhhhhhhh")
    # print(len(batch["premise"]))
    # print(len(batch["input_ids"]))
    # print(len(batch["label"]))
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["label"].to(device)

    # Forward pass through the original model
    with torch.no_grad():
      outputs_original = original_model(input_ids=input_ids, attention_mask=attention_mask)
      original_model_probs = outputs_original.logits

    # Forward pass through the student model
    outputs_student = model(input_ids=input_ids, attention_mask=attention_mask)
    model_probs = outputs_student.logits

    # Compute the KL divergence loss
    distillation_loss = custom_loss(model_probs, original_model_probs)


    # Compute the task loss
    classification_loss = task_loss(outputs_student.logits, labels)

    # Combine the losses
    loss = 0.5 * distillation_loss + 0.5 * classification_loss

    # Backward pass and optimization
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    batch_count = batch_count + 1

  # Evaluation loop for in-domain accuracy
  model.eval()
  correct_predictions = 0
  total_predictions = 0

  in_domain_batch_count = 0
  with torch.no_grad():
    for batch in tqdm(in_domain_dataloader, desc="Evaluating in-domain"):
      #print("eval in-domain batch num: ", in_domain_batch_count, "of epoch: ", epoch)
      # Move batch to device
      input_ids = batch["input_ids"].to(device)
      attention_mask = batch["attention_mask"].to(device)
      labels = batch["label"].to(device)

      # Forward pass through the fine-tuned model
      outputs = model(input_ids=input_ids, attention_mask=attention_mask)
      predictions = torch.argmax(outputs.logits, dim=-1)

      # Update counters
      correct_predictions += (predictions == labels).sum().item()
      total_predictions += len(labels)
      in_domain_batch_count = in_domain_batch_count + 1

  in_domain_accuracy = correct_predictions / total_predictions
  print(f"In-domain accuracy: {in_domain_accuracy}")

  # Evaluation loop for out-of-domain accuracy
  correct_predictions = 0
  total_predictions = 0

  out_of_domain_batch_count = 0
  with torch.no_grad():
    for batch in tqdm(out_of_domain_dataloader, desc="Evaluating out-domain"):
      #print("eval out-of-domain batch num: ", out_of_domain_batch_count, "of epoch: ", epoch)
      # Move batch to device
      input_ids = batch["input_ids"].to(device)
      attention_mask = batch["attention_mask"].to(device)
      labels = batch["label"].to(device)

      # Forward pass through the fine-tuned model
      outputs = model(input_ids=input_ids, attention_mask=attention_mask)
      predictions = torch.argmax(outputs.logits, dim=-1)

      # Update counters
      correct_predictions += (predictions == labels).sum().item()
      total_predictions += len(labels)
      out_of_domain_batch_count = out_of_domain_batch_count + 1

  out_of_domain_accuracy = correct_predictions / total_predictions
  print(f"Out-of-domain accuracy: {out_of_domain_accuracy}")

  # Add the accuracies to the lists
  in_domain_accuracies.append(in_domain_accuracy)
  out_of_domain_accuracies.append(out_of_domain_accuracy)

# Compute the maximum and average accuracies
max_in_domain_accuracy = max(in_domain_accuracies)
average_in_domain_accuracy = sum(in_domain_accuracies) / len(in_domain_accuracies)

max_out_of_domain_accuracy = max(out_of_domain_accuracies)
average_out_of_domain_accuracy = sum(out_of_domain_accuracies) / len(out_of_domain_accuracies)

# Print the maximum and average accuracies
print(f"Maximum in-domain accuracy: {max_in_domain_accuracy}")
print(f"Average in-domain accuracy: {average_in_domain_accuracy}")
print(f"Maximum out-of-domain accuracy: {max_out_of_domain_accuracy}")
print(f"Average out-of-domain accuracy: {average_out_of_domain_accuracy}")

# # Save the accuracies to a CSV file
# results_df = pd.DataFrame({
#     "in_domain_accuracy": in_domain_accuracies,
#     "out_of_domain_accuracy": out_of_domain_accuracies
# })
# results_df.to_csv("../Results/context_distillation_mnli.csv", index=False)


Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-125m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-125m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/6692 [00:00<?, ? examples/s]

Training: 100%|██████████| 8182/8182 [1:52:04<00:00,  1.22it/s]
Evaluating in-domain: 100%|██████████| 210/210 [00:42<00:00,  4.93it/s]


In-domain accuracy: 0.6313508667065152


Evaluating out-domain: 100%|██████████| 938/938 [03:09<00:00,  4.96it/s]


Out-of-domain accuracy: 0.5003666666666666


Training: 100%|██████████| 8182/8182 [1:52:05<00:00,  1.22it/s]
Evaluating in-domain: 100%|██████████| 210/210 [00:42<00:00,  4.92it/s]


In-domain accuracy: 0.6676628810520024


Evaluating out-domain: 100%|██████████| 938/938 [03:08<00:00,  4.97it/s]


Out-of-domain accuracy: 0.5249666666666667


Training: 100%|██████████| 8182/8182 [1:51:59<00:00,  1.22it/s]
Evaluating in-domain: 100%|██████████| 210/210 [00:42<00:00,  4.94it/s]


In-domain accuracy: 0.7208607292289301


Evaluating out-domain: 100%|██████████| 938/938 [03:08<00:00,  4.97it/s]

Out-of-domain accuracy: 0.49746666666666667
Maximum in-domain accuracy: 0.7208607292289301
Average in-domain accuracy: 0.6732914923291493
Maximum out-of-domain accuracy: 0.5249666666666667
Average out-of-domain accuracy: 0.5075999999999999





NameError: name 'pd' is not defined

In [9]:
import pandas as pd
# Save the accuracies to a CSV file
results_df = pd.DataFrame({
    "in_domain_accuracy": in_domain_accuracies,
    "out_of_domain_accuracy": out_of_domain_accuracies
})
results_df.to_csv("../Results/context_distillation_mnli.csv", index=False)
