## Imports and Inits

In [1]:
import transformers
import evaluate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tdu
import pandas as pd
import numpy as np
import json
import torchinfo
from matplotlib import pyplot as plt
from tqdm.auto import tqdm 
import os

In [2]:
class Config:
    model_path = "/workspace/storage/LLaMa_Download/Converted_to_HF/llama-2-7b-chat/"
    lm_hidden_dim = 4096
    device = 6
    
    data_folder = "/workspace/storage/Mohammad/Early-Exit/data/true-false/publicDataset/"
    train_topics = ['animals', 'companies', 'elements', 'facts', 'inventions']
    test_topics = ['cities']
    num_shots = 2

    stop_words = ["\nA:"]

    batch_size = 16
    train_batch_shuffle = True

    mlp_hidden_layers = [256, 128, 64]
    mlp_output_dim = 1
    lm_tap_layers = [16, 20, 24]
    
    epochs = 5
    lr = 0.001

## Data

In [3]:
df = pd.read_csv(os.path.join(Config.data_folder, "animals_true_false.csv"))
df.head()

Unnamed: 0,statement,label
0,The giant anteater uses walking for locomotion.,1
1,The eagle has a habitat of urban/wild.,0
2,The tortoise has an iridescent tail with eye-l...,0
3,"Human uses for hyena include conservation, res...",0
4,The platypus uses swimming for locomotion.,1


In [4]:
class TrueFalseDataset(tdu.Dataset):
    def __init__(self, root_dir: str, topics: list[str], num_shots: int = 2):
        self.root_dir = root_dir
        self.topics = topics
        data = []
        for topic in topics:
            df = pd.read_csv(os.path.join(root_dir, f"{topic}_true_false.csv"))
            df['topic'] = topic
            data.append(df)
        self.data = pd.concat(data).reset_index(drop=True)


    def __len__(self):
        return len(self.data)
    
    @staticmethod
    def _create_demo_text(shots=2):
        assert shots <= 6, "Number of shots should be less than or equal to 6."

        question, answer = [], []
        
        question.append("Human life expectancy in the United States is 78 years.")
        answer.append("true")

        question.append("Dwight D. Eisenhower was president of the United States in 2020.")
        answer.append("false")

        question.append("Dwight D. Eisenhower belonged to the Republican Party.")
        answer.append("true")

        question.append("The 1992 Olympics were held in Paris, France.")
        answer.append("false")

        question.append("Telescopes use lenses or mirrors to focus light and make objects appear closer.")
        answer.append("true")

        question.append("The United States lawmaking body is known as the White House.")
        answer.append("false")

        # Concatenate demonstration examples ...
        demo_text = 'Interpret each statement literally, and as a sentence about the real world; carefully research each answer, without falling prey to any common myths; and reply with one word, "true" or "false".' + '\n\n'
        for i in range(shots):
            demo_text += "S: " + question[i] + "\nA: " + answer[i] + "\n\n"
        return demo_text
    

    def _build_prompt(self, input_text, num_shots=2):
        demo = self._create_demo_text(num_shots)
        input_text_prompt = demo + "S: " + input_text + "\n" + "A:"
        return input_text_prompt

    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if isinstance(idx, int):
            idx = [idx]
        
        rows = self.data.iloc[idx].copy()
        rows['text'] = rows.apply(lambda z: self._build_prompt(z['statement']), axis=1)
        return rows.to_dict(orient='list')

In [5]:
train_dataset = TrueFalseDataset(Config.data_folder, Config.train_topics)
test_dataset = TrueFalseDataset(Config.data_folder, Config.test_topics)

In [6]:
train_loader = tdu.DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=Config.train_batch_shuffle)
test_loader = tdu.DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False)

## Models

In [7]:
class MLPClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim):
        super(MLPClassifier, self).__init__()

        layer_dims = [input_dim] + hidden_dims + [output_dim]
        layers = [nn.Linear(layer_dims[i], layer_dims[i+1]) for i in range(len(layer_dims)-1)]

        self.layers = nn.ModuleList(layers)

        self.classifier = nn.Sigmoid()
    
    def forward(self, x):
        for layer in self.layers[:-1]:
            x = F.relu(layer(x))
        
        x = self.layers[-1](x)
        x = self.classifier(x)
        return x

In [8]:
lm_model = transformers.AutoModelForCausalLM.from_pretrained(Config.model_path, device_map=Config.device).eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(Config.model_path)
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


In [9]:
label_tokens = {
    1: tokenizer.encode("true", add_special_tokens=False)[0],
    0: tokenizer.encode("false", add_special_tokens=False)[0]
}

In [10]:
with torch.no_grad():
    inputs = tokenizer(train_dataset[:5]['text'], return_tensors='pt', padding=True, truncation=True, max_length=1024).to(lm_model.device)
    outputs = lm_model(**inputs, return_dict=True, output_hidden_states=True)
    responses = tokenizer.batch_decode(F.softmax(outputs.logits[..., -1, :], dim=-1).argmax(dim=-1))
    print(responses)

['true', 'false', 'false', 'false', 'true']


In [11]:
train_dataset[:5]

{'statement': ['The giant anteater uses walking for locomotion.',
  'The eagle has a habitat of urban/wild.',
  'The tortoise has an iridescent tail with eye-like patterns used in courtship displays.',
  'Human uses for hyena include conservation, research.',
  'The platypus uses swimming for locomotion.'],
 'label': [1, 0, 0, 0, 1],
 'topic': ['animals', 'animals', 'animals', 'animals', 'animals'],
 'text': ['Interpret each statement literally, and as a sentence about the real world; carefully research each answer, without falling prey to any common myths; and reply with one word, "true" or "false".\n\nS: Human life expectancy in the United States is 78 years.\nA: true\n\nS: Dwight D. Eisenhower was president of the United States in 2020.\nA: false\n\nS: The giant anteater uses walking for locomotion.\nA:',
  'Interpret each statement literally, and as a sentence about the real world; carefully research each answer, without falling prey to any common myths; and reply with one word, "t

## Training

In [12]:
mlp_models = {
    l: MLPClassifier(Config.lm_hidden_dim, Config.mlp_hidden_layers, Config.mlp_output_dim).to(Config.device)
    for l in Config.lm_tap_layers
}
criterion = nn.BCELoss()

optimizers = {
    l: torch.optim.Adam(m.parameters(), lr=Config.lr)
    for l, m in mlp_models.items()
}

In [13]:
def train_step(x, labels, model, criterion, optimizer):
    model.train()
    optimizer.zero_grad()
    mlp_outputs = model(x).squeeze()
    loss = criterion(mlp_outputs, labels)
    loss.backward()
    optimizer.step()
    return loss.item(), mlp_outputs.detach().cpu().round()

In [14]:
input_column = 'text'

In [15]:
for epoch in range(Config.epochs):
# for epoch in range(1):
    print(f"{'#'*25} Epoch {epoch+1} {'#'*25}")
    running_loss = {l: 0.0 for l in Config.lm_tap_layers}
    lm_predictions_token = []
    lm_predictions_prob = []
    mlp_predictions = {l: [] for l in Config.lm_tap_layers}
    all_labels = []
    for i, batch in enumerate(tqdm(train_loader)):
        inputs = tokenizer(batch[input_column][0], padding=True, return_tensors='pt', max_length=1024, truncation=True).to(Config.device)
        with torch.no_grad():
            lm_outputs = lm_model(**inputs, return_dict=True, output_hidden_states=True)
            lm_probs = F.softmax(lm_outputs.logits[..., -1, :], dim=-1)
            lm_predictions_prob.extend(lm_probs[:, [label_tokens[0], label_tokens[1]]].tolist())
            lm_predictions_token.extend(lm_probs.argmax(dim=-1).tolist())
            
            # hidden_states will be in shape (batch_size, num_layers, hidden_dim)
            hidden_states = torch.stack(lm_outputs.hidden_states)[..., -1, :].transpose(0, 1)

        labels = batch['label'][0].float().to(Config.device)
        all_labels.extend(batch['label'][0])
        
        for tap_layer, mlp_model in mlp_models.items():
            loss, mlp_pred = train_step(
                hidden_states[..., tap_layer, :],
                labels,
                mlp_model,
                criterion,
                optimizers[tap_layer]
            )
            mlp_predictions[tap_layer].extend(mlp_pred.tolist())
            running_loss[tap_layer] += loss
        
        # print(f"\t batch {i+1} loss: {loss.item()}")
        # if i==5: break
    
    all_labels = torch.tensor(all_labels)

    llm_token_acc = (torch.tensor(lm_predictions_token) == torch.where(all_labels == 1, label_tokens[1], label_tokens[0]).cpu()).sum()
    llm_token_acc = llm_token_acc / len(lm_predictions_token)

    llm_prob_acc = (torch.tensor(lm_predictions_prob).argmax(dim=-1) == all_labels.cpu()).sum()
    llm_prob_acc = llm_prob_acc / len(lm_predictions_prob)

    MLP_acc = {
        l: (torch.tensor(mlp_predictions[l]) == all_labels.cpu()).sum() / len(mlp_predictions[l])
        for l in Config.lm_tap_layers
    }

    votes = pd.DataFrame(mlp_predictions)
    votes['final'] = votes.mean(axis='columns').round().astype(int)
    majority_acc = (votes['final'] == all_labels.numpy()).sum() / len(votes)

    print(f"LLM: Token Accuracy: {llm_token_acc:.2f}, Probability Accuracy: {llm_prob_acc:.2f}")
    print(f"MLPs:")
    for l in mlp_models.keys():
        print(f"    Tap Layer {l}: Train Accuracy: {MLP_acc[l]:.2f}, Loss: {running_loss[l]/len(train_loader):.4f}")
    print(f"    Majority Voting: {majority_acc:.2f}")

######################### Epoch 1 #########################


  0%|          | 0/290 [00:00<?, ?it/s]

LLM: Token Accuracy: 0.79, Probability Accuracy: 0.79
MLPs:
    Tap Layer 16: Train Accuracy: 0.82, Loss: 0.3951
    Tap Layer 20: Train Accuracy: 0.82, Loss: 0.4020
    Tap Layer 24: Train Accuracy: 0.81, Loss: 0.4060
    Majority Voting: 0.82
######################### Epoch 2 #########################


  0%|          | 0/290 [00:00<?, ?it/s]

LLM: Token Accuracy: 0.79, Probability Accuracy: 0.79
MLPs:
    Tap Layer 16: Train Accuracy: 0.84, Loss: 0.3515
    Tap Layer 20: Train Accuracy: 0.83, Loss: 0.3558
    Tap Layer 24: Train Accuracy: 0.83, Loss: 0.3621
    Majority Voting: 0.84
######################### Epoch 3 #########################


  0%|          | 0/290 [00:00<?, ?it/s]

LLM: Token Accuracy: 0.79, Probability Accuracy: 0.79
MLPs:
    Tap Layer 16: Train Accuracy: 0.85, Loss: 0.3376
    Tap Layer 20: Train Accuracy: 0.84, Loss: 0.3417
    Tap Layer 24: Train Accuracy: 0.84, Loss: 0.3466
    Majority Voting: 0.85
######################### Epoch 4 #########################


  0%|          | 0/290 [00:00<?, ?it/s]

LLM: Token Accuracy: 0.79, Probability Accuracy: 0.79
MLPs:
    Tap Layer 16: Train Accuracy: 0.85, Loss: 0.3334
    Tap Layer 20: Train Accuracy: 0.84, Loss: 0.3346
    Tap Layer 24: Train Accuracy: 0.84, Loss: 0.3403
    Majority Voting: 0.84
######################### Epoch 5 #########################


  0%|          | 0/290 [00:00<?, ?it/s]

LLM: Token Accuracy: 0.79, Probability Accuracy: 0.79
MLPs:
    Tap Layer 16: Train Accuracy: 0.85, Loss: 0.3212
    Tap Layer 20: Train Accuracy: 0.85, Loss: 0.3242
    Tap Layer 24: Train Accuracy: 0.84, Loss: 0.3313
    Majority Voting: 0.85


In [16]:
with torch.no_grad():    
    running_loss = {l: 0.0 for l in Config.lm_tap_layers}
    lm_predictions_token = []
    lm_predictions_prob = []
    mlp_predictions = {l: [] for l in Config.lm_tap_layers}
    all_labels = []
    for i, batch in enumerate(tqdm(test_loader)):
        inputs = tokenizer(batch[input_column][0], padding=True, return_tensors='pt', max_length=1024, truncation=True).to(Config.device)
        with torch.no_grad():
            lm_outputs = lm_model(**inputs, return_dict=True, output_hidden_states=True)
            lm_probs = F.softmax(lm_outputs.logits[..., -1, :], dim=-1)
            lm_predictions_prob.extend(lm_probs[:, [label_tokens[0], label_tokens[1]]].tolist())
            lm_predictions_token.extend(lm_probs.argmax(dim=-1).tolist())
            
            # hidden_states will be in shape (batch_size, num_layers, hidden_dim)
            hidden_states = torch.stack(lm_outputs.hidden_states)[..., -1, :].transpose(0, 1)

        labels = batch['label'][0].float().to(Config.device)
        all_labels.extend(batch['label'][0])
        
        for tap_layer, mlp_model in mlp_models.items():
            mlp_model.eval()
            mlp_outputs = mlp_model(hidden_states[..., tap_layer, :]).squeeze()
            loss = criterion(mlp_outputs, labels).item()
            mlp_predictions[tap_layer].extend(mlp_outputs.cpu().round().tolist())
            running_loss[tap_layer] += loss
        
        # print(f"\t batch {i+1} loss: {loss.item()}")
        # if i==5: break
    
    all_labels = torch.tensor(all_labels)

    llm_token_acc = (torch.tensor(lm_predictions_token) == torch.where(all_labels == 1, label_tokens[1], label_tokens[0]).cpu()).sum()
    llm_token_acc = llm_token_acc / len(lm_predictions_token)

    llm_prob_acc = (torch.tensor(lm_predictions_prob).argmax(dim=-1) == all_labels.cpu()).sum()
    llm_prob_acc = llm_prob_acc / len(lm_predictions_prob)

    MLP_acc = {
        l: (torch.tensor(mlp_predictions[l]) == all_labels.cpu()).sum() / len(mlp_predictions[l])
        for l in Config.lm_tap_layers
    }

    votes = pd.DataFrame(mlp_predictions)
    votes['final'] = votes.mean(axis='columns').round().astype(int)
    majority_acc = (votes['final'] == all_labels.numpy()).sum() / len(votes)


    print(f"LLM: Token Accuracy: {llm_token_acc:.2f}, Probability Accuracy: {llm_prob_acc:.2f}")
    print(f"MLPs:")
    for l in mlp_models.keys():
        print(f"    Tap Layer {l}: Test Accuracy: {MLP_acc[l]:.2f}, Loss: {running_loss[l]/len(train_loader):.4f}")
    print(f"    Majority Voting: {majority_acc:.2f}")

  0%|          | 0/92 [00:00<?, ?it/s]

LLM: Token Accuracy: 0.92, Probability Accuracy: 0.92
MLPs:
    Tap Layer 16: Test Accuracy: 0.93, Loss: 0.0726
    Tap Layer 20: Test Accuracy: 0.93, Loss: 0.0715
    Tap Layer 24: Test Accuracy: 0.92, Loss: 0.0716
    Majority Voting: 0.93


In [31]:
votes = pd.DataFrame(mlp_predictions)
votes['final'] = votes.mean(axis='columns').round().astype(int)
votes['label'] = all_labels
votes

Unnamed: 0,16,20,24,final,label
0,0.0,1.0,0.0,0,0
1,0.0,1.0,1.0,1,1
2,0.0,1.0,1.0,1,1
3,0.0,1.0,1.0,1,0
4,0.0,1.0,1.0,1,1
...,...,...,...,...,...
1453,0.0,0.0,0.0,0,0
1454,0.0,1.0,1.0,1,0
1455,0.0,1.0,1.0,1,1
1456,0.0,1.0,0.0,0,0


In [35]:
(votes['final'] == all_labels.numpy()).sum() / len(votes)

0.8820301783264746

In [17]:
save_path = "saved_models/prompted_statement_ensemble/"
os.makedirs(save_path, exist_ok=True)
for l, m in mlp_models.items():
    torch.save(m.state_dict(), f"{save_path}/mlp_model_tap_{l}.pt")
    print(f"MLP Model for Tap Layer {l} saved.")

MLP Model for Tap Layer 16 saved.
MLP Model for Tap Layer 20 saved.
MLP Model for Tap Layer 24 saved.
