In [1]:
import torch
import transformers
import numpy as np

from collections import OrderedDict

import sys
sys.path.insert(0, '..')

import gpt2_model
from gpt2_model import GPT2ForSequenceClassification, GPT2Config
from decompose_gpt2 import GPT2ForSequenceClassificationDecomposed


In [2]:
# Reference implementation: the pretrained Hugging Face GPT-2 plus its tokenizer.
CHECKPOINT = "gpt2"
model_hf = transformers.AutoModelForSequenceClassification.from_pretrained(CHECKPOINT)
tokenizer = transformers.AutoTokenizer.from_pretrained(CHECKPOINT)


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Show the reference hyperparameters; the local model is built from this same config.
hf_config = model_hf.config
hf_config


GPT2Config {
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.45.2",
  "use_cache": true,
  "vocab_size": 50257
}

In [4]:
# Copy the HF hyperparameters into the local GPT2Config so the custom model is
# constructed with the same settings as the reference model.
local_config = GPT2Config.from_dict(model_hf.config.to_dict())
model = GPT2ForSequenceClassification(local_config)


In [5]:
# HF checkpoints name LayerNorm affine parameters `weight` / `bias`
# (ref : https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html),
# whereas the local LayerNorm expects `gamma` / `beta`. Rename just those keys so
# load_state_dict (strict by default) matches every parameter.
layer_norm_names = ("ln_1", "ln_2", "ln_f")
param_renames = (("weight", "gamma"), ("bias", "beta"))

new_state_dict = OrderedDict()
for hf_key, tensor in model_hf.state_dict().items():
    local_key = hf_key
    for ln_name in layer_norm_names:
        for hf_suffix, local_suffix in param_renames:
            local_key = local_key.replace(f"{ln_name}.{hf_suffix}", f"{ln_name}.{local_suffix}")
    new_state_dict[local_key] = tensor

model.load_state_dict(new_state_dict)
model.eval()


GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x Block(
        (ln_1): LayerNorm((768,), eps=1e-05)
        (attn): Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05)
        (mlp): MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05)
  )
  (score): Linear(in_features=768, out_features=2, bias=False)
)

In [6]:
# test equal activation functions
# The local GELU must match HF's activation bit-for-bit. Sample from a scaled
# standard normal so negative inputs and |x| > 1 are exercised as well;
# torch.rand would only cover [0, 1).

f_1 = model_hf.transformer.h[0].mlp.act
f_2 = gpt2_model.gelu

x = torch.randn((100, 100)) * 3
assert torch.equal(f_1(x), f_2(x)), "local gelu differs from the HF activation"


In [7]:
# Compare every transformer block of the local model with its HF counterpart
# (MLP, both LayerNorms, and attention), reporting which layer fails.
blocks_hf = model_hf.transformer.h
blocks_local = model.transformer.h
assert len(blocks_hf) == len(blocks_local), "models have a different number of blocks"

with torch.no_grad():
    for i, (block_hf, block_local) in enumerate(zip(blocks_hf, blocks_local)):
        # 768 == n_embd from the config; 100 rows of token features
        x = torch.rand((100, 768))

        # test equal mlp and layernorms
        for sub_name in ("mlp", "ln_1", "ln_2"):
            f_1 = getattr(block_hf, sub_name)
            f_2 = getattr(block_local, sub_name)
            assert torch.allclose(f_1(x), f_2(x), atol=1e-5), f"{sub_name} mismatch in block {i}"

        # test equal attention; input assumed to be (batch, seq, n_embd) — both
        # modules are indexed with [0] to take their primary output
        x = torch.rand((100, 100, 768))
        attn_hf = block_hf.attn(x)[0]
        attn_local = block_local.attn(x)[0]
        assert torch.allclose(attn_hf, attn_local, atol=1e-3), f"attn mismatch in block {i}"


In [8]:
# The GPT-2 config defines no pad token, so reuse EOS for padding and record its id
# in the HF config.
tokenizer.pad_token = tokenizer.eos_token
model_hf.config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
# NOTE(review): only model_hf.config receives pad_token_id; `model` was constructed
# from the config before this cell sets it. Its logits match below, but confirm how
# the local classifier selects the last non-padding token.

# Two identical sentences, right-padded to a fixed length of 128 tokens.
inputs = tokenizer(["Hello, my dog is cute", "Hello, my dog is cute"], return_tensors="pt",
                   truncation=True, padding="max_length", max_length=128
                   )

model_hf.eval()

with torch.no_grad():
    logits_hf = model_hf(**inputs)

logits_hf.logits


tensor([[ 4.1302, -2.2830],
        [ 4.1302, -2.2830]])

In [9]:
# Run the local model on the same batch and check that it reproduces the HF logits,
# instead of relying on a visual comparison of the printed tensors.
with torch.no_grad():
    outputs_local = model(**inputs)

assert torch.allclose(outputs_local["logits"], logits_hf.logits, atol=1e-3), \
    "local model logits differ from the HF reference"
outputs_local["logits"]


tensor([[ 4.1302, -2.2830],
        [ 4.1302, -2.2830]])

In [12]:
# Decomposed GPT-2 split into two contributions, with shapley_include_bias=False
# (presumably bias terms are left out of the Shapley attribution — see decompose_gpt2).
# debug=True appears to print per-layer reconstruction errors during the forward pass.
decomposed_kwargs = {
    "debug": True,
    "num_contributions": 2,
    "shapley_include_bias": False,
}
model_decomposed = GPT2ForSequenceClassificationDecomposed.from_pretrained("gpt2", **decomposed_kwargs)


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Random complementary token mask: row 0 marks the positions where token_mask == 1,
# row 1 the remaining positions, so the two rows sum to 1 at every position
# (presumably one row per contribution; num_contributions=2 above).
# A fixed seed makes the mask, and therefore the printed errors, reproducible.
# Batch and sequence sizes come from `inputs` instead of being hard-coded.
MASK_SEED = 0
batch_size, seq_len = inputs["input_ids"].shape

mask_generator = torch.Generator().manual_seed(MASK_SEED)
token_mask = (torch.rand(seq_len, generator=mask_generator) > 0.5).long()
beta_mask = torch.stack([token_mask, 1 - token_mask])  # (2, seq_len)
# one copy per example; assumed layout (batch, num_contributions, seq_len) — TODO confirm
beta_mask = beta_mask.repeat(batch_size, 1, 1)
print(beta_mask.shape)

with torch.no_grad():
    model_decomposed.model.eval()
    decomposed_logits = model_decomposed(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        beta_mask=beta_mask)


torch.Size([2, 2, 128])
GPT2 Block Layer 0 error: 3.1943662040110526e-15
GPT2 Block Layer 1 error: 3.72321620094784e-15
GPT2 Block Layer 2 error: 5.402918807774432e-15
GPT2 Block Layer 3 error: 6.624675560782511e-15
GPT2 Block Layer 4 error: 8.375004086873984e-15
GPT2 Block Layer 5 error: 9.722899297835816e-15
GPT2 Block Layer 6 error: 1.0484493767676955e-14
GPT2 Block Layer 7 error: 1.1153579231987063e-14
GPT2 Block Layer 8 error: 1.2169330210315444e-14
GPT2 Block Layer 9 error: 1.3611367102666473e-14
GPT2 Block Layer 10 error: 1.6635398838955063e-14
GPT2 Block Layer 11 error: 2.2772144047219033e-14
GPT2 Classifier logits error:  3.1086244689504383e-15


In [14]:
# Same decomposed GPT-2, now with shapley_include_bias=True (presumably bias terms are
# included in the Shapley attribution — see decompose_gpt2). debug=True appears to
# print per-layer reconstruction errors during the forward pass.
decomposed_kwargs = {
    "debug": True,
    "num_contributions": 2,
    "shapley_include_bias": True,
}
model_decomposed = GPT2ForSequenceClassificationDecomposed.from_pretrained("gpt2", **decomposed_kwargs)


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
# Random complementary token mask: row 0 marks the positions where token_mask == 1,
# row 1 the remaining positions, so the two rows sum to 1 at every position
# (presumably one row per contribution; num_contributions=2 above).
# A fixed seed makes the mask reproducible, and using the same seed as the
# bias-excluded run lets the two runs be compared on the same split.
# Batch and sequence sizes come from `inputs` instead of being hard-coded.
MASK_SEED = 0
batch_size, seq_len = inputs["input_ids"].shape

mask_generator = torch.Generator().manual_seed(MASK_SEED)
token_mask = (torch.rand(seq_len, generator=mask_generator) > 0.5).long()
beta_mask = torch.stack([token_mask, 1 - token_mask])  # (2, seq_len)
# one copy per example; assumed layout (batch, num_contributions, seq_len) — TODO confirm
beta_mask = beta_mask.repeat(batch_size, 1, 1)
print(beta_mask.shape)

with torch.no_grad():
    model_decomposed.model.eval()
    decomposed_logits = model_decomposed(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        beta_mask=beta_mask)


torch.Size([2, 2, 128])
GPT2 Block Layer 0 error: 3.636919796067302e-15
GPT2 Block Layer 1 error: 5.3430141231155455e-15
GPT2 Block Layer 2 error: 1.3888991084679685e-14
GPT2 Block Layer 3 error: 2.3928438597846576e-14
GPT2 Block Layer 4 error: 5.889747758498866e-14
GPT2 Block Layer 5 error: 1.3840804799311158e-13
GPT2 Block Layer 6 error: 3.164191756937345e-13
GPT2 Block Layer 7 error: 6.646665963167283e-13
GPT2 Block Layer 8 error: 1.3170356660071007e-12
GPT2 Block Layer 9 error: 2.5115868300669534e-12
GPT2 Block Layer 10 error: 3.6315302116823175e-12
GPT2 Block Layer 11 error: 6.205510303893382e-12
GPT2 Classifier logits error:  1.325828336007362e-12
