In [27]:
import torch
import transformers
import numpy as np
from datasets import load_dataset, ClassLabel
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
mnli = load_dataset("glue", "mnli")
hans = load_dataset("hans")

In [3]:
def binarize_mnli(dataset, remove_neutral=True):
    if remove_neutral:
        # neutral class has label 1
        dataset = dataset.filter(lambda example: example["label"] != 1)

    # change labels of contradiction examples from 2 to 1
    def change_label(example):
        # convert labels 2 into labels 1. this merges the neutral and contradiction class
        example["label"] = 1 if example["label"] == 2 else example["label"]
        return example
        
    # change labels
    dataset = dataset.map(change_label)

    # change features to reflect the new labels
    features = dataset["train"].features.copy()
    features["label"] = ClassLabel(num_classes=2, names=['entailment', 'contradiction'], id=None)
    dataset = dataset.cast(features)  # overwrite old features
        
    return dataset

In [4]:
mnli = binarize_mnli(mnli)

In [5]:
model_name = "facebook/opt-125m"
teacher_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

seed = 100


In [6]:
yes_id = tokenizer.convert_tokens_to_ids("yes")
no_id = tokenizer.convert_tokens_to_ids("no")
print(f"'yes' token id: {yes_id}")
print(f"'no' token id: {no_id}")

'yes' token id: 10932
'no' token id: 2362


In [7]:
task_context = "Given the premise and hypothesis: reply yes if the premise entails the hypothesis, or no otherwise"

def generate_mnli_prompts(examples):
    # teacher model receives task context + premise + hypothesis
    examples["teacher_prompt"] = f"{task_context}\nPremise: {examples['premise']}\nHypothesis: {examples['hypothesis']}"
    # student model only receives premise + hypothesis
    examples["student_prompt"] = f"Premise: {examples['premise']}\nHypothesis: {examples['hypothesis']}"
    return examples

mnli = mnli.map(generate_mnli_prompts)


Map:   0%|          | 0/261802 [00:00<?, ? examples/s]

Map:   0%|          | 0/6692 [00:00<?, ? examples/s]

Map:   0%|          | 0/6703 [00:00<?, ? examples/s]

Map:   0%|          | 0/9796 [00:00<?, ? examples/s]

Map:   0%|          | 0/9847 [00:00<?, ? examples/s]

In [76]:
train_size = 20
np.random.seed(seed)
selected_idx = np.random.choice(1000, train_size)

mnli_train = mnli["train"].select(selected_idx)
labels_train = torch.tensor(mnli_train["label"])

In [77]:
def tokenize_teacher(data):
    tokens = tokenizer(data["teacher_prompt"], padding=True, truncation=True, return_tensors="pt")
    return tokens


tokenized_teacher_mnli_train = mnli_train.map(tokenize_teacher, batched=True)
tokenized_teacher_mnli_train.set_format(type="torch", columns=["input_ids", "attention_mask"])

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

In [78]:

input_ids = tokenized_teacher_mnli_train["input_ids"]
attention_mask = tokenized_teacher_mnli_train["attention_mask"]
teacher_model.eval()

outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)

logits = outputs.logits
    
teacher_logits = logits[:, -1, [yes_id, no_id]]
teacher_pred = logits[:, -1, [yes_id, no_id]].argmax(dim=-1)
teacher_acc = (teacher_pred == labels_train).float().mean().item()
teacher_acc

0.5

In [112]:
def context_distillation_loss(labels, teacher_logits, student_logits, alpha=0.5):
    with torch.no_grad():
        teacher_logprob = torch.nn.functional.softmax(teacher_logits, dim=-1)
    student_prob = torch.nn.functional.log_softmax(student_logits, dim=-1)
    kl_loss = torch.nn.functional.kl_div(student_prob, teacher_logprob, reduction="batchmean")
    ce_loss = torch.nn.functional.cross_entropy(student_logits, labels)
    beta = 1 - alpha
    cd_loss = alpha * kl_loss + beta * ce_loss
    return cd_loss

In [80]:
def tokenize_student(data):
    tokens = tokenizer(data["student_prompt"], padding=True, truncation=True, return_tensors="pt")
    return tokens


tokenized_student_mnli_train = mnli_train.map(tokenize_student, batched=True)
tokenized_student_mnli_train.set_format(type="torch", columns=["input_ids", "attention_mask"])

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

In [87]:
def evaluate_model(model, inputs, labels):

    
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    model.eval()
    
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    
    logits = outputs.logits

    target_logits = logits[:, -1, [yes_id, no_id]]
    pred = target_logits.argmax(dim=-1)


    ce_loss = torch.nn.functional.cross_entropy(target_logits, labels).item()

    acc = (pred == labels).float().mean().item()
    return ce_loss, acc
    



In [82]:
eval_size = 20
np.random.seed(seed + 1)
selected_idx_eval = np.random.choice(1000, eval_size)

mnli_eval = mnli["train"].select(selected_idx_eval)
labels_eval = torch.tensor(mnli_eval["label"])
tokenized_student_mnli_eval = mnli_eval.map(tokenize_student, batched=True)
tokenized_student_mnli_eval.set_format(type="torch", columns=["input_ids", "attention_mask"])

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

In [113]:
def train_student_model(student_model, lr, epochs, train_tokenized, train_labels, validation_tokenized, validation_labels, teacher_logits, target_tokens):
    optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)

    for epoch in range(epochs):
        
        input_ids = train_tokenized["input_ids"]
        attention_mask = train_tokenized["attention_mask"]

        outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)

        logits = outputs.logits
        student_target_logits = logits[:, -1, target_tokens]

        optimizer.zero_grad()
        cd_loss = context_distillation_loss(train_labels, teacher_logits, student_target_logits)
        cd_loss.backward()
        optimizer.step()

        pred = student_target_logits.argmax(dim=-1)
        train_acc = (pred == train_labels).float().mean().item()
        
        val_loss, val_acc = evaluate_model(student_model, validation_tokenized, validation_labels)


        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"\tTraining Loss: {cd_loss.item():.4f}\t\tTraining Accuracy: {train_acc:.4f}")
        print(f"\tValidation Loss: {val_loss:.4f}\t\tValidation Accuracy: {val_acc:.4f}")

    train_loss, train_acc = evaluate_model(student_model, train_tokenized, train_labels)
    val_loss, val_acc = evaluate_model(student_model, validation_tokenized, validation_labels)
    print("Final model")
    print(f"\tTraining Loss: {train_loss:.4f}\t\tTraining Accuracy: {train_acc:.4f}")
    print(f"\tValidation Loss: {val_loss:.4f}\t\tValidation Accuracy: {val_acc:.4f}")

In [92]:
student_model = AutoModelForCausalLM.from_pretrained(model_name)
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-5)
epochs = 30

for epoch in range(epochs):
    
    input_ids = tokenized_student_mnli_train["input_ids"]
    attention_mask = tokenized_student_mnli_train["attention_mask"]

    outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)

    logits = outputs.logits
    student_logits = logits[:, -1, [yes_id, no_id]]

    optimizer.zero_grad()
    cd_loss = context_distillation_loss(labels_train, teacher_logits, student_logits)
    cd_loss.backward()
    optimizer.step()

    pred = student_logits.argmax(dim=-1)
    train_acc = (pred == labels_train).float().mean().item()
    
    val_loss, val_acc = evaluate_model(student_model, tokenized_student_mnli_eval, labels_eval)




    print(f"Epoch [{epoch + 1}/{epochs}]")
    print(f"\tTraining Loss: {cd_loss.item():.4f}\t\tTraining Accuracy: {train_acc:.4f}")
    print(f"\tValidation Loss: {val_loss:.4f}\t\tValidation Accuracy: {val_acc:.4f}")

train_loss, train_acc = evaluate_model(student_model, tokenized_student_mnli_train, labels_train)
val_loss, val_acc = evaluate_model(student_model, tokenized_student_mnli_eval, labels_eval)
print("Final model")
print(f"\tTraining Loss: {train_loss:.4f}\t\tTraining Accuracy: {train_acc:.4f}")
print(f"\tValidation Loss: {val_loss:.4f}\t\tValidation Accuracy: {val_acc:.4f}")




Epoch [1/30]
	Training Loss: 0.3568		Training Accuracy: 0.5500
	Validation Loss: 0.4873		Validation Accuracy: 0.7500
Epoch [2/30]
	Training Loss: 0.3647		Training Accuracy: 0.5500
	Validation Loss: 1.0290		Validation Accuracy: 0.2500
Epoch [3/30]
	Training Loss: 0.5094		Training Accuracy: 0.7500
	Validation Loss: 0.5738		Validation Accuracy: 0.8500
Epoch [4/30]
	Training Loss: 0.2948		Training Accuracy: 0.9000
	Validation Loss: 0.4685		Validation Accuracy: 0.7500
Epoch [5/30]
	Training Loss: 0.3497		Training Accuracy: 0.5000
	Validation Loss: 0.4726		Validation Accuracy: 0.7500
Epoch [6/30]
	Training Loss: 0.3339		Training Accuracy: 0.5000
	Validation Loss: 0.5181		Validation Accuracy: 0.8500
Epoch [7/30]
	Training Loss: 0.2722		Training Accuracy: 0.7500
	Validation Loss: 0.6347		Validation Accuracy: 0.7500
Epoch [8/30]
	Training Loss: 0.2701		Training Accuracy: 1.0000
	Validation Loss: 0.7193		Validation Accuracy: 0.3500
Epoch [9/30]
	Training Loss: 0.2929		Training Accuracy: 1.0000
	

In [100]:
indomain_size = 100
np.random.seed(seed)
indomain_idx = np.random.choice(mnli["validation_mismatched"].num_rows, indomain_size)
indomain = mnli["validation_mismatched"].select(indomain_idx)
indomain_labels = torch.tensor(indomain["label"])

tokenized_indomain = indomain.map(tokenize_student, batched=True)
tokenized_indomain.set_format(type="torch", columns=["input_ids", "attention_mask"])
loss, acc = evaluate_model(student_model, tokenized_indomain, indomain_labels)
print(f"In-domain Accuracy: {acc:.4f}")

In-domain Accuracy: 0.5800


In [102]:
# preprocess HANS dataset 


# add student prompt
def generate_hans_prompts(examples):
    # student model only receives premise + hypothesis
    examples["student_prompt"] = f"Premise: {examples['premise']}\nHypothesis: {examples['hypothesis']}"
    return examples

hans = hans.map(generate_hans_prompts)


Map:   0%|          | 0/30000 [00:00<?, ? examples/s]

Map:   0%|          | 0/30000 [00:00<?, ? examples/s]

In [111]:
outdomain_size = 100
np.random.seed(seed)
outdomain_idx = np.random.choice(hans["validation"].num_rows, outdomain_size)
outdomain = hans["validation"].select(outdomain_idx)
outdomain_labels = torch.tensor(outdomain["label"])

tokenized_outdomain = outdomain.map(tokenize_student, batched=True)
tokenized_outdomain.set_format(type="torch", columns=["input_ids", "attention_mask"])
loss, acc = evaluate_model(student_model, tokenized_outdomain, outdomain_labels)
print(f"Out-domain Accuracy: {acc:.4f}")

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Out-domain Accuracy: 0.4800


In [114]:
# train model and compute in-domain and out-domain accuracies
def compute_model_performance(model_name, train_dataset, validation_dataset, indomain_dataset, outdomain_dataset, 
                              epochs, lr):
    print(f"Model: {model_name}")
    teacher_model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # compare yes/no token logit/probability for context distillation
    yes_id = tokenizer.convert_tokens_to_ids("yes")
    no_id = tokenizer.convert_tokens_to_ids("no")
    target_token_ids = [yes_id, no_id]
    
    # Pre-computing teacher logits to train student
    tokenized_teacher_train = train_dataset.map(tokenize_teacher, batched=True)
    tokenized_teacher_train.set_format(type="torch", columns=["input_ids", "attention_mask"])

    teacher_input_ids = tokenized_teacher_train["input_ids"]
    teacher_attention_mask = tokenized_teacher_train["attention_mask"]
    
    # disable gradients
    teacher_model.eval()
    teacher_logits = teacher_model(input_ids=teacher_input_ids, attention_mask=teacher_attention_mask).logits
        
    # extract yes/no logits
    teacher_target_logits = teacher_logits[:, -1, target_token_ids]
    teacher_pred = teacher_target_logits.argmax(dim=-1)
    teacher_acc = (teacher_pred == labels_train).float().mean().item()
    print(f"Teacher Training Accuracy: {teacher_acc}")

    # train student model
    tokenized_student_train = train_dataset.map(tokenize_student, batched=True)
    tokenized_student_train.set_format(type="torch", columns=["input_ids", "attention_mask"])

    tokenized_student_val = validation_dataset.map(tokenize_student, batched=True)
    tokenized_student_val.set_format(type="torch", columns=["input_ids", "attention_mask"])

    print("Training student model:")
    student_model = AutoModelForCausalLM.from_pretrained(model_name)
    train_labels = torch.tensor(train_dataset["label"])
    validation_labels = torch.tensor(validation_dataset["label"])
    train_student_model(student_model, lr, epochs, tokenized_student_train, train_labels, tokenized_student_val, validation_labels, teacher_target_logits, target_token_ids)

    # compute in-domain accuracy
    tokenized_indomain = indomain_dataset.map(tokenize_student, batched=True)
    tokenized_indomain.set_format(type="torch", columns=["input_ids", "attention_mask"])
    indomain_labels = torch.tensor(indomain_dataset["label"])
    id_loss, id_acc = evaluate_model(student_model, tokenized_indomain, indomain_labels)
    print(f"In-domain Accuracy: {id_acc:.4f}")

    # compute out-domain accuracy
    tokenized_outdomain = outdomain_dataset.map(tokenize_student, batched=True)
    tokenized_outdomain.set_format(type="torch", columns=["input_ids", "attention_mask"])
    outdomain_labels = torch.tensor(outdomain_dataset["label"])
    od_loss, od_acc = evaluate_model(student_model, tokenized_outdomain, outdomain_labels)
    print(f"Out-domain Accuracy: {od_acc:.4f}")

    return {
        "In-domain Accuracy": id_acc,
        "Out-domain Accuracy": od_acc,
    }

In [None]:
mnli = load_dataset("glue", "mnli")
hans = load_dataset("hans")

In [None]:
# preprocessing MNLI dataset


# convert MNLI into binary classification
def binarize_mnli(dataset, remove_neutral=True):
    if remove_neutral:
        # neutral class has label 1
        dataset = dataset.filter(lambda example: example["label"] != 1)

    # change labels of contradiction examples from 2 to 1
    def change_label(example):
        # convert labels 2 into labels 1. this merges the neutral and contradiction class
        example["label"] = 1 if example["label"] == 2 else example["label"]
        return example
        
    # change labels
    dataset = dataset.map(change_label)

    # change features to reflect the new labels
    features = dataset["train"].features.copy()
    features["label"] = ClassLabel(num_classes=2, names=['entailment', 'contradiction'], id=None)
    dataset = dataset.cast(features)  # overwrite old features
        
    return dataset

# add teacher and student prompts
task_context = "Given the premise and hypothesis: reply yes if the premise entails the hypothesis, or no otherwise"

def generate_mnli_prompts(examples):
    # teacher model receives task context + premise + hypothesis
    examples["teacher_prompt"] = f"{task_context}\nPremise: {examples['premise']}\nHypothesis: {examples['hypothesis']}"
    # student model only receives premise + hypothesis
    examples["student_prompt"] = f"Premise: {examples['premise']}\nHypothesis: {examples['hypothesis']}"
    return examples


mnli = binarize_mnli(mnli)
mnli = mnli.map(generate_mnli_prompts)

In [116]:
# generate train, validation, indomain, outdomain datasets

seed = 100
np.random.seed(seed)


# train dataset (MNLI train)
train_size = 20     # few shot examples
train_idx = np.random.choice(mnli["train"].num_rows, train_size)
train_dataset = mnli["train"].select(train_idx)

# validation dataset (MNLI validation matched)
val_size = 20
val_idx = np.random.choice(mnli["validation_matched"].num_rows, val_size)
validation_dataset = mnli["validation_matched"].select(val_idx)

# indomain dataset (MNLI validation mismatched)
indomain_size = 100
indomain_idx = np.random.choice(mnli["validation_mismatched"].num_rows, indomain_size)
indomain_dataset = mnli["validation_mismatched"].select(indomain_idx)

# outdomain (HANS validation)
outdomain_size = 100
outdomain_idx = np.random.choice(hans["validation"].num_rows, outdomain_size)
outdomain_dataset = hans["validation"].select(outdomain_idx)


In [117]:
model_name = "facebook/opt-125m"
epochs = 30
lr = 1e-5

compute_model_performance(model_name=model_name, train_dataset=train_dataset, validation_dataset=validation_dataset, 
                          indomain_dataset=indomain_dataset, outdomain_dataset=outdomain_dataset, epochs=epochs, lr=lr)

Model: facebook/opt-125m


Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Teacher Training Accuracy: 0.44999998807907104


Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Training student model:
Epoch [1/30]
	Training Loss: 0.4500		Training Accuracy: 0.4000
	Validation Loss: 0.8251		Validation Accuracy: 0.4500
Epoch [2/30]
	Training Loss: 0.3942		Training Accuracy: 0.5000
	Validation Loss: 0.8014		Validation Accuracy: 0.5000
Epoch [3/30]
	Training Loss: 0.4022		Training Accuracy: 0.5500
	Validation Loss: 0.7920		Validation Accuracy: 0.5000
Epoch [4/30]
	Training Loss: 0.3891		Training Accuracy: 0.5500
	Validation Loss: 0.7870		Validation Accuracy: 0.5000
Epoch [5/30]
	Training Loss: 0.3770		Training Accuracy: 0.5500
	Validation Loss: 0.7785		Validation Accuracy: 0.4500
Epoch [6/30]
	Training Loss: 0.3749		Training Accuracy: 0.5000
	Validation Loss: 0.7634		Validation Accuracy: 0.5000
Epoch [7/30]
	Training Loss: 0.3672		Training Accuracy: 0.5000
	Validation Loss: 0.7445		Validation Accuracy: 0.5000
Epoch [8/30]
	Training Loss: 0.3583		Training Accuracy: 0.5500
	Validation Loss: 0.7257		Validation Accuracy: 0.5000
Epoch [9/30]
	Training Loss: 0.3531		Tra

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In-domain Accuracy: 0.5300


Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Out-domain Accuracy: 0.5500


{'In-domain Accuracy': 0.5299999713897705,
 'Out-domain Accuracy': 0.550000011920929}