In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd

from transformers import AutoModelForCausalLM, AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns

import lm_eval
from lm_eval.models.huggingface import HFLM

In [2]:
device = torch.device('cuda:3')
torch.set_default_device(device)

In [3]:
model_name = 'microsoft/phi-2' 

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Divergence Plots

In [4]:
def sm_wassertain(logits1, logits2):
    prob1 = F.softmax(logits1, dim=-1)
    prob2 = F.softmax(logits2, dim=-1)
    
    cdf1 = torch.cumsum(prob1, dim=-1)
    cdf2 = torch.cumsum(prob2, dim=-1)
    
    wasserstein_dist = torch.sum(torch.abs(cdf1 - cdf2), dim=-1)
    
    return wasserstein_dist.mean()

In [5]:
def kl_div(p, q, epsilon=1e-10):
    p = p + epsilon
    q = q + epsilon
    return (p * (torch.log2(p) - torch.log2(q))).sum()

def sm_jsd(p, q):
    """Returns the Jensen-Shannon Divergence of softmax-ed logits"""
    p = F.softmax(p, dim=-1)
    q = F.softmax(q, dim=-1)
    m = 0.5 * (p + q)
    return 0.5 * kl_div(p, m) + 0.5 * kl_div(q, m)

In [6]:
lm_obj = HFLM(pretrained=model, tokenizer=tokenizer)
results = lm_eval.simple_evaluate(
    model=lm_obj,
    tasks=["truthfulqa"],
    num_fewshot=0,
    limit=10,
    log_samples=False)

2024-05-29:03:09:38,426 INFO     [evaluator.py:131] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234
2024-05-29:03:09:48,544 INFO     [evaluator.py:218] num_fewshot has been set to 0 for truthfulqa_mc1 in its config. Manual configuration will be ignored.
2024-05-29:03:09:48,545 INFO     [evaluator.py:218] num_fewshot has been set to 0 for truthfulqa_mc2 in its config. Manual configuration will be ignored.
2024-05-29:03:09:48,545 INFO     [evaluator.py:218] num_fewshot has been set to 0 for truthfulqa_gen in its config. Manual configuration will be ignored.
2024-05-29:03:09:48,549 INFO     [task.py:395] Building contexts for truthfulqa_mc1 on rank 0...
100%|██████████| 10/10 [00:00<00:00, 681.89it/s]
2024-05-29:03:09:48,567 INFO     [task.py:395] Building contexts for truthfulqa_mc2 on rank 0...
100%|██████████| 10/10 [00:00<00:00, 1055.25it/s]
2024-05-29:03:09:48,579 INFO     [task.py:395] Building contexts for truthfulqa_gen on rank 0...
100%|█

In [7]:
df = pd.DataFrame(results['results']).transpose()
df['Baseline Accuracy'] = df[['acc,none','acc_stderr,none']].apply(lambda x : '{} ± {}'.format(round(x[0],2), round(x[1], 4)), axis=1)
df[['Baseline Accuracy']]

  df['Baseline Accuracy'] = df[['acc,none','acc_stderr,none']].apply(lambda x : '{} ± {}'.format(round(x[0],2), round(x[1], 4)), axis=1)


Unnamed: 0,Baseline Accuracy
truthfulqa,0.51 ± 0.1097
truthfulqa_gen,nan ± nan
truthfulqa_mc1,0.4 ± 0.1633
truthfulqa_mc2,0.61 ± 0.1465


In [8]:
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset

# def calculate_max_length(questions, answer_choices, tokenizer):
#     max_length = 0
#     for question, choices in zip(questions, answer_choices):
#         for choice in choices:
#             combined_length = len(tokenizer.encode(question + " " + choice))
#             if combined_length > max_length:
#                 max_length = combined_length
#     return max_length

class QADataset(Dataset):
    def __init__(self, tokenizer, max_length):
        self.dataset = load_dataset("truthful_qa", "multiple_choice", split="validation")
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        question = self.dataset[idx]['question']
        correct_choice_idx = self.dataset[idx]['mc1_targets']['labels'].index(1)
        answer = self.dataset[idx]['mc1_targets']['choices'][correct_choice_idx]


        encoding = self.tokenizer(
            f"Instruct: {question}\nOutput:{answer}",
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

       
        return {
            'input_ids': encoding["input_ids"].squeeze(),
            'attention_mask': encoding["attention_mask"].squeeze(),
            'labels': encoding["input_ids"].squeeze(),
        }

In [14]:
from typing import Optional, Union, Tuple, List

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
from transformers.models.phi.modeling_phi import PhiForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import AutoTokenizer


class PhiDOLa(PhiForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.final_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # self.layer_weights = Parameter(torch.cat([torch.zeros(config.num_hidden_layers-1), torch.ones(1)]))
        self.layer_combnet = Parameter(32,1)

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
        ) -> Union[Tuple, CausalLMOutputWithPast]:

            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = True
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            hidden_states = outputs[0]
            logits = hidden_states 
            logits = self.lm_head(hidden_states)
            logits = logits.float()

            logits_list = []
            
            for hidden_state in outputs.hidden_states[:-1]:
                int_logits = self.lm_head(self.final_layernorm(hidden_state)).float()
                logits_list.append(int_logits.unsqueeze(-1))

            logits_tensor = torch.cat(logits_list, dim=-1)
            # print(logits_tensor.shape)

            # weighted_logits = torch.zeros_like(logits_list[0])
            # for i, logits_ in enumerate(logits_list):
            #     weighted_logits += self.layer_weights[i] * logits_

            # logits = weighted_logits

            logits = self.layer_combnet(logits_tensor).squeeze(-1)
            print(logits.shape)

            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

            if not return_dict:
                output = (logits,) + outputs[1:]
                return (loss,) + output if loss is not None else output

            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

def custom_loss_fn(log_probs, correct_choice_idx):
    correct_log_prob = log_probs[correct_choice_idx]
    loss = -correct_log_prob.mean()
    return loss


def freeze_model_except_layer_weights(model):
    for param in model.parameters():
        param.requires_grad = False
    model.layer_combnet.requires_grad = True


def fine_tune_layer_weights(model, dataloader, optimizer, num_epochs, tokenizer):
    model.train()
    for epoch in range(num_epochs):
        for data in dataloader:
            input_ids = data['input_ids']
            attention_mask = data['attention_mask']
            labels = data['labels']

    
            optimizer.zero_grad()
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            print(outputs.loss)
            loss = outputs.loss
            loss.backward()
            optimizer.step()




In [15]:

             
dataset = QADataset(tokenizer, max_length=64)
train_split, test_split = torch.utils.data.random_split(dataset, [0.8, 0.2], generator=torch.Generator(device=device))

train_loader = DataLoader(dataset, batch_size=8)
test_loader = DataLoader(dataset, batch_size=8)

model = PhiDOLa.from_pretrained(model_name)
freeze_model_except_layer_weights(model)
optimizer = torch.optim.AdamW([*model.layer_combnet.parameters()], lr=1e-3)

fine_tune_layer_weights(model, train_loader, optimizer, num_epochs=50, tokenizer=tokenizer)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of PhiDOLa were not initialized from the model checkpoint at microsoft/phi-2 and are newly initialized: ['final_layernorm.bias', 'final_layernorm.weight', 'layer_combnet.bias', 'layer_combnet.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([8, 64, 51200])
tensor(12.2787, device='cuda:3')


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
model.layer_weights

Parameter containing:
tensor([-4.0275e-02, -2.9973e-02, -1.6777e-02, -2.8931e-02, -3.0201e-02,
        -2.9434e-02, -2.8702e-02, -2.7871e-02, -2.7116e-02, -2.6579e-02,
        -2.4894e-02, -2.4658e-02, -2.4114e-02, -2.3455e-02, -2.2129e-02,
        -2.0788e-02, -1.9976e-02, -1.8693e-02, -1.6270e-02, -1.2697e-02,
        -9.8543e-03, -8.1205e-03, -5.9172e-03, -3.3296e-03, -9.3828e-04,
         1.8110e-03, -2.5356e-03, -7.1096e-04, -1.4312e-03, -3.8122e-03,
        -1.2069e-02,  9.8400e-01], device='cuda:3', requires_grad=True)

In [None]:
lm_obj = HFLM(pretrained=model, tokenizer=tokenizer)
results = lm_eval.simple_evaluate(
    model=lm_obj,
    tasks=["truthfulqa"],
    num_fewshot=0,
    limit=10)

2024-05-28:19:57:26,606 INFO     [evaluator.py:131] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234
2024-05-28:19:57:35,451 INFO     [evaluator.py:218] num_fewshot has been set to 0 for truthfulqa_mc1 in its config. Manual configuration will be ignored.
2024-05-28:19:57:35,452 INFO     [evaluator.py:218] num_fewshot has been set to 0 for truthfulqa_mc2 in its config. Manual configuration will be ignored.
2024-05-28:19:57:35,453 INFO     [evaluator.py:218] num_fewshot has been set to 0 for truthfulqa_gen in its config. Manual configuration will be ignored.
2024-05-28:19:57:35,455 INFO     [task.py:395] Building contexts for truthfulqa_mc1 on rank 0...
100%|██████████| 10/10 [00:00<00:00, 1049.86it/s]
2024-05-28:19:57:35,469 INFO     [task.py:395] Building contexts for truthfulqa_mc2 on rank 0...
100%|██████████| 10/10 [00:00<00:00, 1052.34it/s]
2024-05-28:19:57:35,482 INFO     [task.py:395] Building contexts for truthfulqa_gen on rank 0...
100%|

In [None]:
df_dola = pd.DataFrame(results['results']).transpose()
df['DOLa Accuracy'] = df_dola[['acc,none','acc_stderr,none']].apply(lambda x : '{} ± {}'.format(round(x[0],2), round(x[1], 4)), axis=1)
df[['Baseline Accuracy','DOLa Accuracy' ]]

  df['DOLa Accuracy'] = df_dola[['acc,none','acc_stderr,none']].apply(lambda x : '{} ± {}'.format(round(x[0],2), round(x[1], 4)), axis=1)


Unnamed: 0,Baseline Accuracy,DOLa Accuracy
truthfulqa,0.51 ± 0.1097,0.4 ± 0.1076
truthfulqa_gen,nan ± nan,nan ± nan
truthfulqa_mc1,0.4 ± 0.1633,0.4 ± 0.1633
truthfulqa_mc2,0.61 ± 0.1465,0.4 ± 0.1402
