In [None]:
import csv

import numpy as np
np.random.seed(10)

import warnings
warnings.filterwarnings('ignore')

import torch
from sklearn.metrics import accuracy_score
from transformers import AutoConfig, AutoTokenizer, GPTNeoForCausalLM, GPTNeoModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# All randomness in this notebook comes from torch (token sampling with
# torch.multinomial and the torch.randn_like noise added to attention maps),
# so seeding NumPy alone does not make runs repeatable. Seed torch with the
# same value as the NumPy seed above.
torch.manual_seed(10)

# Hugging Face Hub identifier of the checkpoint used throughout the notebook
model_card = 'EleutherAI/gpt-neo-125M'

In [None]:
# load only the configuration object for the checkpoint named by `model_card`
config = AutoConfig.from_pretrained(model_card)

# number of attention heads reported by the config (displayed as the cell
# output); the evaluation loop later iterates over range(config.num_heads)
config.num_heads

config.json:   0%|          | 0.00/1.01k [00:00<?, ?B/s]

12

In [None]:
# Load the causal-LM checkpoint with attention maps enabled in its outputs,
# then move it to the selected device.
model = GPTNeoForCausalLM.from_pretrained(model_card, output_attentions=True)
model = model.to(device)

# Tokenizer matching the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_card)

model.safetensors:   0%|          | 0.00/526M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

In [None]:
# Read the dataset: the first CSV record is kept in `header`, every following
# record is collected (as a list of strings) into `rows`.
rows = []
with open('/content/IMDB_100.csv', newline='') as csvfile:
    reader = csv.reader(csvfile)
    header = next(reader)
    rows.extend(reader)


from tqdm import tqdm

# indices (into the freshly loaded `rows`) of the reviews used as few-shot
# demonstrations
examples = [0, 3, 4, 7]
# prompt with examples for few-shot in-context learning
context = "Classify the review into positive or negative.\n"
for i in examples:
    context += f"\nReview: {rows[i][0]}\nSentiment: {rows[i][1]}##"

# Remove the demonstrations from the rows that will be evaluated.
# Delete from the highest index down: deleting in ascending order shifts the
# remaining rows left after every `del`, so the later indices would remove the
# wrong reviews and leave some demonstrations in the evaluation set.
# `set` guards against an index being listed twice.
for i in sorted(set(examples), reverse=True):
    del rows[i]

def prompt(row):
  """Build the full prompt for one review.

  The module-level few-shot `context` is followed by the review text taken
  from `row[0]` and an open "Sentiment:" slot for the model to complete.
  """
  query = f"\nReview: {row[0]}\nSentiment:"
  return context + query


In [None]:
def decode(last_state):
  """Sample one next token from the given hidden states and return its text.

  `last_state` is projected through `model.lm_head`; only the final sequence
  position is used. The token is *sampled* from the softmax distribution
  (not arg-maxed), so the result is stochastic. `.item()` requires a single
  sampled id, i.e. a batch of one.
  """
  # logits for the last position only
  next_token_logits = model.lm_head(last_state)[:, -1, :]
  probabilities = torch.softmax(next_token_logits, dim=-1)
  # draw a single token id from the distribution
  sampled_id = torch.multinomial(probabilities, 1).item()
  return tokenizer.decode(sampled_id)


def predict_sentiment(prompt, head_ix=None, noise_scale=0.00001):
  """Run the model on `prompt` and return the next decoded token, stripped.

  Args:
    prompt: full prompt text (this parameter shadows the module-level
      `prompt` function inside this function only).
    head_ix: if None, the model's own final hidden states are decoded.
      Otherwise the last layer's attention map of head `head_ix` is perturbed
      with Gaussian noise and multiplied with the final hidden states, and
      that product is decoded instead.
    noise_scale: factor applied to the standard-normal noise added to the
      selected attention map.

  Returns:
    The token produced by `decode`, with surrounding whitespace removed.
  """
  # tokenize
  input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

  # Inference only: without no_grad every call would record an autograd graph
  # for the whole forward pass, wasting memory and time.
  with torch.no_grad():
    outputs = model(input_ids, output_attentions=True, output_hidden_states=True)
    last_state = outputs.hidden_states[-1]

    if head_ix is not None:
      # attention maps of the last layer; [0] selects the first batch element
      # NOTE(review): presumably (seq, seq) per head — confirm against the
      # transformers output documentation
      head_weights = outputs['attentions'][-1][0][head_ix]

      # randn_like allocates on the same device as head_weights, so no
      # explicit device transfer is needed
      noise = torch.randn_like(head_weights) * noise_scale
      head_weights = head_weights + noise
      # propagation of noisy head
      last_state = torch.matmul(head_weights, last_state)

    token = decode(last_state)

  return token.strip()


def compute_metric(true, pred):
  """Score predictions `pred` against labels `true` with sklearn's accuracy_score."""
  score = accuracy_score(y_true=true, y_pred=pred)
  return score


def run(head_ix=None):
  """Evaluate the model on every entry of the module-level `rows`.

  Each row is turned into a prompt, a token is predicted for it (passing
  `head_ix` through to `predict_sentiment`), and the predictions are scored
  against the labels in `row[1]` with `compute_metric`.
  """
  gold_labels = []
  predicted_labels = []

  for row in tqdm(rows):
    prediction = predict_sentiment(prompt(row), head_ix)
    gold_labels.append(row[1])
    predicted_labels.append(prediction)

  return compute_metric(gold_labels, predicted_labels)

In [None]:
# head index -> (baseline accuracy - accuracy with that head perturbed)
metrics = {}

# baseline: no head is perturbed
# NOTE(review): decode() samples the token, so every accuracy below is a
# single stochastic draw rather than a deterministic score
acc_i = run(head_ix=None)
print('default metric:', acc_i)

# perturb one last-layer head at a time and re-evaluate on the same rows
for i in range(config.num_heads):
  acc = run(head_ix=i)
  print(f"{i}: {acc}")

  # positive value = accuracy dropped when head i was perturbed
  metrics[i] = acc_i - acc

100%|██████████| 95/95 [00:17<00:00,  5.37it/s]


default metric: 0.5684210526315789


100%|██████████| 95/95 [00:18<00:00,  5.27it/s]


0: 0.0


100%|██████████| 95/95 [00:18<00:00,  5.10it/s]


1: 0.0


100%|██████████| 95/95 [00:18<00:00,  5.08it/s]


2: 0.47368421052631576


100%|██████████| 95/95 [00:18<00:00,  5.00it/s]


3: 0.010526315789473684


100%|██████████| 95/95 [00:19<00:00,  4.95it/s]


4: 0.5052631578947369


100%|██████████| 95/95 [00:18<00:00,  5.03it/s]


5: 0.0


100%|██████████| 95/95 [00:18<00:00,  5.04it/s]


6: 0.0


100%|██████████| 95/95 [00:20<00:00,  4.69it/s]


7: 0.0


100%|██████████| 95/95 [00:19<00:00,  4.99it/s]


8: 0.0


100%|██████████| 95/95 [00:18<00:00,  5.00it/s]


9: 0.0


100%|██████████| 95/95 [00:19<00:00,  4.98it/s]


10: 0.0


100%|██████████| 95/95 [00:19<00:00,  4.95it/s]

11: 0.0





In [None]:
# Rank head indices from the largest to the smallest accuracy drop and print
# each with its drop rounded to three decimals.
met = sorted(metrics, key=lambda head: metrics[head], reverse=True)
for r in met:
    print(r, round(metrics[r], 3))

0 0.568
1 0.568
5 0.568
6 0.568
7 0.568
8 0.568
9 0.568
10 0.568
11 0.568
3 0.558
2 0.095
4 0.063
