In [2]:
!curl --output gpt2-pytorch_model.bin https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  522M  100  522M    0     0  3017k      0  0:02:57  0:02:57 --:--:-- 6198k0:00:03  0:05:42 1548k:25  0:00:06  0:03:19 3278k     0  0:03:17  0:00:17  0:03:00 2603k  2746k      0  0:03:14  0:01:32  0:01:42 3802k 4661k 0  0:03:03  0:02:11  0:00:52 3417k


In [4]:
!python main.py --text "Once when I was six years old"

Namespace(batch_size=-1, length=-1, nsamples=1, quiet=False, temperature=0.7, text='Once when I was six years old', top_k=40, unconditional=False)
Once when I was six years old
100%|█████████████████████████████████████████| 512/512 [00:23<00:00, 21.48it/s]
, our parents bought us a big white box of the most expensive, but surprisingly, most expensive, soap, which is called "Carmen's Kiss". "It's like a bathtub," I told myself.

I would like to think we used to be able to make soap, but we are not.

The world has changed in a profound way, according to a paper from the University of London called "L'Aquila." It says that the number of people who have been diagnosed with a condition called lupus-related anaemia, or Lupus, has been declining, due to the rising cost of soap and the increased availability of natural and synthetic oils.

"The lupus, of course, has a very severe impact on the health of patients," says Dr. Robert S. Schiraldi, the chief executive officer of the L'Aquila Found

In [None]:
import torch
import torch.nn.functional as F
from tqdm import trange

def top_k_logits(logits, k):
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1]
    return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)

def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for i in trange(length):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            log_probs = F.softmax(logits, dim=-1)
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)
            output = torch.cat((output, prev), dim=1)
    return output

In [7]:
import os
import sys
import torch
import random
import argparse
import numpy as np
from GPT2.model import (GPT2LMHeadModel)
from GPT2.utils import load_weight
from GPT2.config import GPT2Config
from GPT2.sample import sample_sequence
from GPT2.encoder import get_encoder


state_dict = torch.load('gpt2-pytorch_model.bin', map_location='cpu' if not torch.cuda.is_available() else None)
enc = get_encoder()
config = GPT2Config()
model = GPT2LMHeadModel(config)
model = load_weight(model, state_dict)
model.to('cpu')
model.eval()
print()




In [21]:
context_tokens = enc.encode("Hello world, I am a pilot")
context = torch.tensor(context_tokens, device='cpu', dtype=torch.long).unsqueeze(0).repeat(1, 1)
prev = context
output = context
past = None

In [23]:
logits, past = model(prev, past=past)

In [12]:
import torch.nn as nn

def hook_fn(m, i, o):
    visualisation[m] = o 

def get_all_layers(net):
    
    for name, layer in net._modules.items():
        #If it is a sequential, don't register a hook on it
        # but recursively register hook on all it's module children
        if isinstance(layer, nn.Sequential):
            get_all_layers(layer)
        else:
          # it's a non sequential. Register a hook
          layer.register_forward_hook(hook_fn)

In [13]:
net = model
visualisation = {}
get_all_layers(net)

In [29]:
context_tokens = enc.encode("Hello world, I am a pilot")
context = torch.tensor(context_tokens, device='cpu', dtype=torch.long).unsqueeze(0).repeat(1, 1)
prev = context
output = context
past = None

logits, past = model(prev, past=past)

In [39]:
# Just to check whether we got all layers

visualisation['wte'] #output includes sequential layers

KeyError: 'wte'

In [66]:
for item in visualisation.values():
    if type(item) is tuple:
        for i in item:
            print(len(i))

1
12


In [107]:
keys = list(visualisation.keys())

In [108]:
keys[0].named_parameters()

<generator object Module.named_parameters at 0x7fe6484d79e0>

In [106]:
for name, wt in zip(keys[0].named_parameters(), visualisation[keys[0]]):
    print(name[0], wt)

wte.weight tensor([[[-8.8119e-06, -1.4021e-01, -2.0845e-01,  ..., -1.5329e-01,
          -6.7826e-02, -1.9630e-01],
         [-1.6633e-01,  2.1910e-01,  4.4472e-02,  ..., -1.7681e-01,
          -1.6563e-01,  4.3342e-01],
         [ 2.6832e-01,  2.9127e-01,  2.1967e-01,  ..., -9.3709e-02,
           1.2303e-01,  8.8264e-02],
         ...,
         [-3.3483e-01, -5.2388e-01, -5.2560e-01,  ...,  1.4899e-02,
          -2.8531e-02,  4.2532e-01],
         [-4.2882e-06,  3.0885e-01,  5.1677e-01,  ..., -3.9239e-02,
          -3.5887e-01,  4.2202e-01],
         [ 7.4062e-02,  6.5349e-02, -1.4431e+00,  ...,  1.1329e-01,
           7.1232e-02,  3.6873e-01]]], grad_fn=<ViewBackward0>)
wpe.weight [tensor([[[[[-1.2526e+00,  2.3200e+00,  1.7218e-01,  ..., -1.0076e+00,
            -1.8970e-01,  1.3219e+00],
           [-2.6307e+00,  3.8931e+00,  1.3101e+00,  ..., -2.6921e-01,
            -2.4905e+00,  1.4687e+00],
           [-1.7495e+00,  2.9933e+00,  1.4383e+00,  ..., -9.3925e-01,
            -2.033

In [73]:
for name, param in model.named_parameters():
    print(name, param.size())

transformer.wte.weight torch.Size([50257, 768])
transformer.wpe.weight torch.Size([1024, 768])
transformer.h.0.ln_1.weight torch.Size([768])
transformer.h.0.ln_1.bias torch.Size([768])
transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.0.attn.c_attn.bias torch.Size([2304])
transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
transformer.h.0.attn.c_proj.bias torch.Size([768])
transformer.h.0.ln_2.weight torch.Size([768])
transformer.h.0.ln_2.bias torch.Size([768])
transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.0.mlp.c_fc.bias torch.Size([3072])
transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.0.mlp.c_proj.bias torch.Size([768])
transformer.h.1.ln_1.weight torch.Size([768])
transformer.h.1.ln_1.bias torch.Size([768])
transformer.h.1.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.1.attn.c_attn.bias torch.Size([2304])
transformer.h.1.attn.c_proj.weight torch.Size([768, 768])
transformer.h.1.attn.c_proj.bias 