## Installs and imports

In [None]:
!pip install transformers torch accelerate

from transformers import pipeline, set_seed
import torch
import numpy as np
from matplotlib import pyplot as plt

## Set up transformers pipeline + global variables

In [None]:
# Text-generation pipeline around the pretrained GPT-2 checkpoint;
# device placement is delegated to `device_map="auto"`.
generator = pipeline(task='text-generation', model='gpt2', device_map="auto")

# Fix the random seed so repeated runs are comparable.
set_seed(42)

# Passed as `max_length` to the generation call further down.
MAX_LENGTH = 256

# Show the module tree of the transformer backbone.
print(generator.model.transformer)

In [None]:
# Show the pipeline object's printed representation.
print(generator)

## Custom Attention module to save activations

In [None]:
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from torch import nn
import json

class GPT2Attention_hooked(GPT2Attention):
  """GPT-2 attention that records its attention weights on every `_attn` call.

  The weights from the most recent call are kept in `self.activations`
  (keyed by `layer_idx`) and written to `attention<layer_idx>.json` in the
  current working directory.

  NOTE(review): recording only happens if the parent class routes its forward
  pass through `_attn` -- confirm for the installed transformers version.
  """

  def __init__(self, config, is_cross_attention=False, layer_idx=None):
    super().__init__(config, is_cross_attention, layer_idx)
    # layer_idx -> nested list holding the attention weights of the latest call.
    self.activations = {}

  def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    """Compute attention, record the attention weights, and return the result.

    Args:
      query, key, value: attention inputs; the second-to-last axis is used as
        the sequence length and the last axis of `value` for scaling
        (presumably (batch, heads, seq, head_dim) -- TODO confirm).
      attention_mask: optional additive mask applied before the softmax.
      head_mask: optional multiplicative mask applied after dropout.

    Returns:
      Tuple `(attn_output, attn_weights)`.
    """
    attn_weights = torch.matmul(query, key.transpose(-1, -2))

    if self.scale_attn_weights:
      attn_weights = attn_weights / torch.full(
          [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
      )

    # Layer-wise attention scaling
    if self.scale_attn_by_inverse_layer_idx:
      attn_weights = attn_weights / float(self.layer_idx + 1)

    if not self.is_cross_attention:
      # Only the "normal" (self-)attention layer applies the causal mask.
      query_length, key_length = query.size(-2), key.size(-2)
      causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
      # The fill value must be a tensor with the same dtype and device as
      # attn_weights, otherwise torch.where raises a dtype/device mismatch.
      mask_value = torch.full(
          [],
          torch.finfo(attn_weights.dtype).min,
          dtype=attn_weights.dtype,
          device=attn_weights.device,
      )
      attn_weights = torch.where(causal_mask, attn_weights, mask_value)

    if attention_mask is not None:
      # Apply the attention mask
      attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
    attn_weights = attn_weights.type(value.dtype)
    attn_weights = self.attn_dropout(attn_weights)

    # Mask heads if we want to
    if head_mask is not None:
      attn_weights = attn_weights * head_mask

    # Record the weights exactly as they are applied to `value` (after dropout
    # and head masking). Tensor.tolist() is used instead of going through
    # NumPy so that dtypes NumPy cannot represent (e.g. bfloat16) also work.
    self.activations[self.layer_idx] = attn_weights.detach().cpu().tolist()

    # The file is opened in "w" mode on every call, so it always holds only
    # the weights of the most recent call for this layer.
    with open(f"attention{self.layer_idx}.json", "w", encoding="utf-8") as outfile:
      json.dump(self.activations, outfile)

    attn_output = torch.matmul(attn_weights, value)

    return attn_output, attn_weights

We save the model weights, substitute our custom attention module into every transformer block, and then reload the saved weights into the modified model.

In [None]:
# Persist the pretrained weights first: the replacement attention modules are
# newly constructed and do not carry them.
transformer = generator.model.transformer
torch.save(transformer.state_dict(), "model.pth")

for layer_idx, block in enumerate(transformer.h):
  original_attn = block.attn
  # Keep the replacement on the same device as the module it replaces so the
  # model does not end up split across devices.
  block.attn = GPT2Attention_hooked(
      generator.model.config,
      original_attn.is_cross_attention,
      layer_idx,
  ).to(original_attn.c_attn.weight.device)

# Restore the pretrained weights into the modified model.
transformer.load_state_dict(torch.load("model.pth"))

# nn.Module instances are constructed in training mode. Without this call the
# new attention modules would apply dropout (`self.attn_dropout` in `_attn`)
# during generation, so the recorded weights would not be plain softmax output.
generator.model.eval()

generator_hooked = pipeline(
    'text-generation',
    tokenizer=generator.tokenizer,
    model=generator.model,
    device_map="auto",
)

In [None]:
import time

# Time one generation call through the hooked pipeline.
start = time.perf_counter()
prompt = generator_hooked("I simply love to eat free lunch", max_length=MAX_LENGTH)
end = time.perf_counter()

elapsed = end - start
print(elapsed, "seconds")
print(prompt)


In [None]:
# Bare expression: as the last statement of a notebook cell its value is displayed.
generator_hooked.tokenizer

## Generate attention graphics

In [None]:
LAYER = 4

# Read the attention weights stored for this layer.
with open(f"attention{LAYER}.json", 'r') as f:
  data = json.load(f)

# JSON object keys are strings, hence str(LAYER); size-1 axes are dropped.
act = np.squeeze(np.array(data[str(LAYER)]))

print(act.shape)

In [None]:
layer_idx = 5

# Read the attention weights stored for this layer.
with open(f"attention{layer_idx}.json", "r") as f:
  data = json.load(f)

# JSON object keys are strings, hence str(layer_idx); size-1 axes are dropped.
act = np.squeeze(np.array(data[str(layer_idx)]))

# 4 x 3 grid: one 20-bin histogram per entry along the first axis of `act`
# (assumes that axis enumerates 12 attention heads -- TODO confirm).
fig, axs = plt.subplots(4, 3, figsize=(18,16))

for i, ax in enumerate(axs.flat):
    ax.hist(act[i], 20)

# Build the title from layer_idx so it stays in sync with the data and the
# output file name when a different layer is selected.
fig.suptitle(f"Softmax scores of GPT-2 in layer {layer_idx} for all heads")
fig.savefig(f"distribution{layer_idx}.png")

In [None]:
"""
### Deprecated code
head = 2

gen_sentence = prompt[0]["generated_text"].replace("\n", " _newline ")

print(gen_sentence)

gen_sentence = gen_sentence.replace(r"?", r" ? ")
gen_sentence = gen_sentence.replace(r".", r" . ")
gen_sentence = gen_sentence.replace(r",", r" , ")
gen_sentence = gen_sentence.replace(r"!", r" ! ")
gen_sentence = gen_sentence.replace(r"'s", r" 's ")
gen_sentence = gen_sentence.replace(r":", r" : ")
gen_sentence = gen_sentence.replace(r"'re", r" ' re ")

gen_sentence = gen_sentence.split()

gen_sentence.append("eos_token")

print(gen_sentence)
print(len(gen_sentence))

plt.figure().set_figwidth(15)

plt.boxplot(act, widths=0.6)
plt.xticks(range(len(gen_sentence)), gen_sentence, rotation=90, fontsize=9)"""

"""fig, axs = plt.subplots(6, 2, figsize=(25,20))

for i, ax in enumerate(axs.flat):
    ax.bar(range(len(gen_sentence)), act[i], width=0.5)
    ax.set_xticks(range(len(gen_sentence)), gen_sentence, rotation=90, fontsize=9)
    #ax.hist(act[i], density=True)

fig.tight_layout()
fig.suptitle("Attention scores per head")

fig.savefig("attention_scores32.jpg")"""
"""plt.bar(gen_sentence, act[head])
plt.xticks(gen_sentence, rotation=90, fontsize=7)
plt.title("Layer 4, Head " + str(head))
#plt.savefig("attention_scores4.jpg")"""

In [None]:
# Raw strings: `\n` here is a literal backslash followed by "n", not a newline.
# The replace pads each such two-character sequence with spaces on both sides.
sample_text = r"\nThis is very nice"
print(sample_text.replace(r"\n", r" \n "))