In [1]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
import os
import json
import torch
import math
import torch.nn as nn
from einops import rearrange

In [2]:
# Load the pretrained language-model checkpoints: the base 'gpt2' and 'gpt2-large'.
gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
gpt2_large = GPT2LMHeadModel.from_pretrained('gpt2-large')
# Only the 'gpt2-large' tokenizer is loaded in this cell.
tokenizer_large = GPT2Tokenizer.from_pretrained('gpt2-large')

In [4]:
# Print the configuration object attached to each checkpoint.
# Size-related fields copied from the printed configs, for quick reference:
#
#   field         gpt2    gpt2-large
#   n_ctx         1024    1024
#   n_embd         768    1280
#   n_head          12      20
#   n_inner       null    null
#   n_layer         12      36
#   n_positions   1024    1024
print(gpt2.config)
print(gpt2_large.config)

GPT2Config {
  "_name_or_path": "gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.11.3",
  "use_cache": true,
  "vocab_size": 50257
}

GPT2Config {
  "_name_or_path": "gpt2-large",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdro

In [5]:
# List the names of the instance attributes stored on the gpt2-large model object.
print(vars(gpt2_large).keys())

dict_keys(['training', '_parameters', '_buffers', '_non_persistent_buffers_set', '_backward_hooks', '_is_full_backward_hook', '_forward_hooks', '_forward_pre_hooks', '_state_dict_hooks', '_load_state_dict_pre_hooks', '_modules', 'config', 'name_or_path', 'model_parallel', 'device_map'])


In [17]:
# Print the model's `_modules` mapping (name -> submodule); the repr is long.
print(gpt2_large._modules)

OrderedDict([('transformer', GPT2Model(
  (wte): Embedding(50257, 1280)
  (wpe): Embedding(1024, 1280)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0): GPT2Block(
      (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (1): GPT2Block(
      (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((1280,), eps=1e-05, elementwis

In [42]:
# From https://github.com/huggingface/transformers/issues/18282
# Is the token-embedding weight (`wte`) flagged as requiring gradients?
print(gpt2_large._modules['transformer'].wte.weight.requires_grad)
# Reach the transformer submodule two ways (via `_modules` and via attribute
# access) and print the type obtained from each.
print(type(gpt2_large._modules['transformer']))
print(type(gpt2_large.transformer))
# Number of entries in the transformer's block list `h`.
print(len(gpt2_large.transformer.h))
# The output head and whether its weight is flagged as requiring gradients.
print(gpt2_large.lm_head)
print(gpt2_large.lm_head.weight.requires_grad)

True
<class 'transformers.models.gpt2.modeling_gpt2.GPT2Model'>
<class 'transformers.models.gpt2.modeling_gpt2.GPT2Model'>
36
Linear(in_features=1280, out_features=50257, bias=False)
True


In [13]:
def _load_split_bykey(data_dir, source, split, key='text', n=np.inf):
    path = os.path.join(data_dir, f'{source}.{split}.jsonl')
    data = []
    for i, line in enumerate(open(path)):
        if i >= n:
            break
        data.append(json.loads(line)[key])
    return data

In [14]:
# webtext_train = _load_split_bykey('data/', 'webtext', 'train', key='text')
# Load only the 'length' field of every record in the webtext train split.
webtext_train_lens = _load_split_bykey('data/', 'webtext', 'train', key='length')

In [15]:
# Largest 'length' value found in the train split.
print(max(webtext_train_lens))
# No problem: the maximum is <= n_positions == 1024 (see the printed model configs).

1024


In [8]:
# Examine degen data
def _load_degen_data(data_dir: str, filename: str, n=np.inf, return_type='json'):
    path = os.path.join(data_dir, filename)
    data = []
    for i, line in enumerate(open(path, 'r')):
        if i >= n:
            break
        try:
            obj = json.loads(line)
        except Exception:
            print(line)
            raise

        if return_type == 'json':
            data.append(obj)
        else:
            data.append(obj['string'])

    return data

In [14]:
# Read only the first record (n=1) of the unconditional gold file and keep it as `line1`.
lines = _load_degen_data('data/data_degen/unconditional', 'unconditional_gold.jsonl', n=1)
line1 = lines[0]

In [18]:
# Number of entries in the record's 'tokens' field.
print(len(line1['tokens']))

82


In [19]:
# Tokenize the record's 'string' field with the gpt2-large tokenizer, returning PyTorch tensors.
line1_encoded = tokenizer_large(line1['string'], return_tensors='pt')

In [23]:
# Display the token ids produced by the tokenizer.
line1_encoded['input_ids']

tensor([[ 1212,  1772,   468,   407, 18141,  2687,  2073,   287,   428,  2393,
           198,   198,  1212,   318,   281,   555, 24421,  4126, 14335,   357,
         33210, 10313,   828,   543,   460,   307,  1043,   319,   262,  3084,
           287,  6401,  1703,   349, 16999,    13,   464, 11743,  2546,   318,
         36117,    87,  1238,  2780,   464, 14335,  2499,   329,  1111, 38054,
            13,   632,   635,  2499,   329, 49974,  1547,    11,   475,   407,
           329,  5311,  1228,   896,    13,  2949, 21671,  8364,  3017,    13,
         29668,    25, 31512, 12028,    25,   678, 25844,    25,   718, 11395,
            25,  5867]])

In [29]:
# Compare the token ids stored in the record with the ids produced by the tokenizer
# (squeeze() drops the size-1 dimensions of the tokenizer output before comparing).
torch.equal(torch.tensor(line1['tokens'], dtype=torch.long), line1_encoded['input_ids'].squeeze())
# line1['tokens'] and line1_encoded['input_ids'] are equal

True

In [30]:
# Run the model on the encoded record, passing the input ids as labels as well;
# the `loss` and `logits` of this output are read in the cells below.
line1_output = gpt2_large(**line1_encoded, labels=line1_encoded['input_ids'])

In [32]:
# Loss reported by the model for this record.
print(line1_output.loss)

tensor(2.8305, grad_fn=<NllLossBackward0>)


In [35]:
# Exponential of the loss, i.e. the perplexity if the loss is a mean per-token
# negative log-likelihood (TODO confirm against the model's loss definition).
math.exp(line1_output.loss.item())

16.953105148628296

In [41]:
# The labels are the input token ids themselves.
target = line1_encoded['input_ids']
# Swap the last two axes so the vocabulary axis sits at dim 1, the axis the
# LogSoftmax(dim=1) defined in a later cell normalizes over.
logits = rearrange(line1_output.logits, 'batch seq vocab -> batch vocab seq')

# Align each prediction with the token that follows it: drop the final
# position from the logits and the first position from the targets.
shift_target = target[:, 1:]
shift_logits = logits[:, :, :-1]

In [42]:
# Log-probabilities are normalized over dim 1 of the input.
log_softmax = nn.LogSoftmax(dim=1)
# reduction='none' keeps one loss value per target position instead of
# reducing them to a single scalar.
criterian = nn.NLLLoss(reduction='none')

In [45]:
# Per-position negative log-likelihood of the shifted targets, computed without
# tracking gradients.
with torch.no_grad():
    # log_softmax normalizes over dim 1; criterian uses reduction='none', so one
    # value per position is kept; squeeze() drops size-1 dimensions of the result.
    nll_loss = criterian(log_softmax(shift_logits), shift_target).squeeze()

In [48]:
# Show the per-position losses and the size of the loss tensor.
print(nll_loss)
print(nll_loss.size())

tensor([8.0242e+00, 6.2787e-01, 2.3978e-02, 1.6755e-01, 2.5531e-03, 7.2772e-04,
        8.0741e-06, 1.4212e-05, 2.7403e-04, 1.5045e+00, 1.7595e-04, 6.9639e+00,
        5.2225e+00, 2.3699e+00, 5.7105e+00, 1.0798e+01, 6.6186e-05, 6.7791e+00,
        5.4961e+00, 7.4312e+00, 1.8832e+00, 2.7035e+00, 2.4980e+00, 1.4197e+00,
        1.7844e-01, 1.3505e+00, 2.0760e+00, 1.0401e+00, 4.4163e+00, 7.2832e-01,
        6.0293e+00, 4.3941e+00, 4.6440e-03, 9.1549e+00, 6.7677e-01, 4.3946e+00,
        5.4922e+00, 4.8496e+00, 4.8915e-01, 2.7771e+00, 2.6787e-01, 2.3750e-01,
        2.8039e-03, 6.7883e+00, 2.2041e+00, 6.0548e+00, 3.0698e+00, 1.6665e+00,
        2.3861e+00, 1.3712e+00, 3.0961e+00, 3.1540e+00, 2.0761e+00, 1.2778e+00,
        4.2646e+00, 2.2633e-01, 1.1084e+00, 3.8290e+00, 2.6259e+00, 1.0037e+00,
        2.3126e-01, 4.9415e+00, 1.5021e+00, 4.5665e-01, 6.6110e+00, 7.2467e+00,
        2.0981e+00, 4.6058e+00, 7.7502e-01, 9.0728e+00, 8.0968e-01, 3.8939e+00,
        1.6324e+00, 5.4584e-01, 4.8240e+

In [47]:
# Values stored with the record under 'nll4tok' (presumably per-token NLLs, printed
# here to compare against nll_loss — TODO confirm how they were produced).
print(line1['nll4tok'])

[4.121309280395508, 7.294855117797852, 0.053824424743652344, 0.00131988525390625, 0.000835418701171875, 0.002300262451171875, 0.0014324188232421875, 1.52587890625e-05, 2.6702880859375e-05, 0.0002231597900390625, 0.029817581176757812, 0.00023555755615234375, 2.1089344024658203, 1.381875991821289, 2.4119396209716797, 4.944422721862793, 4.329188346862793, 0.0001735687255859375, 6.0095062255859375, 4.792625427246094, 7.932603359222412, 1.8370590209960938, 2.991702079772949, 2.628235340118408, 1.5519218444824219, 0.1807241439819336, 1.3865737915039062, 2.0877504348754883, 0.9163331985473633, 4.781112194061279, 0.7594795227050781, 6.138852119445801, 4.280647277832031, 0.00800323486328125, 9.22953987121582, 0.6406097412109375, 4.8862762451171875, 4.948276996612549, 4.3904948234558105, 0.42863941192626953, 2.5045251846313477, 0.24117660522460938, 0.28272533416748047, 0.0037631988525390625, 7.134640216827393, 1.9799861907958984, 5.771195411682129, 3.0094079971313477, 1.6340065002441406, 2.69001

In [49]:
with torch.no_grad():
    nll_loss2 = criterian(log_softmax(logits), target).squeeze()

In [50]:
# Show the per-position losses obtained without shifting logits and targets.
print(nll_loss2)

tensor([1.2056e+01, 1.0624e+01, 1.2891e+01, 1.4227e+01, 2.5744e+01, 1.8443e+01,
        1.4884e+01, 1.8720e+01, 2.0199e+01, 1.6006e+01, 1.7595e-04, 1.2120e+01,
        2.0256e+01, 1.1444e+01, 1.0426e+01, 2.2206e+01, 2.3018e+01, 1.8503e+01,
        1.1291e+01, 1.0120e+01, 9.1788e+00, 5.9369e+00, 1.2914e+01, 9.6444e+00,
        9.3654e+00, 9.5527e+00, 8.3844e+00, 9.4545e+00, 8.3958e+00, 9.6799e+00,
        9.3357e+00, 7.6487e+00, 2.2551e+01, 1.1530e+01, 7.2119e+00, 9.6576e+00,
        1.2383e+01, 6.8530e+00, 7.4405e+00, 9.1917e+00, 1.1599e+01, 1.3115e+01,
        1.5065e+01, 1.2199e+01, 1.3087e+01, 1.0424e+01, 9.7674e+00, 9.8880e+00,
        8.2775e+00, 1.2193e+01, 8.7850e+00, 1.3341e+01, 7.2318e+00, 1.0597e+01,
        9.1956e+00, 1.4798e+01, 1.3855e+01, 8.4157e+00, 9.6021e+00, 9.9664e+00,
        1.1122e+01, 1.8794e+01, 1.5913e+01, 1.4482e+01, 9.0713e+00, 9.2772e+00,
        8.7745e+00, 9.2786e+00, 9.0294e+00, 9.2107e+00, 8.2898e+00, 7.7623e+00,
        8.3909e+00, 7.8531e+00, 9.1712e+