In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import tiktoken
from pprint import pprint
import numpy as np
from load_gpt_weights import load_params

# Print tensors and arrays in plain decimal notation instead of scientific
# notation so small values stay readable in cell outputs.
torch.set_printoptions(sci_mode=False)
np.set_printoptions(suppress=True)

# Keep the version as the last expression of the cell so the notebook actually
# displays it; a bare expression in the middle of a cell is evaluated and dropped.
torch.__version__

Токенизатор (BPE-энкодер) для GPT-2

In [None]:
# Build the tokenizer that tiktoken associates with the "gpt2" model name.
enc = tiktoken.encoding_for_model("gpt2")
# Sanity check: encoding and then decoding a sample string must round-trip.
assert enc.decode(enc.encode("hello world")) == "hello world"

In [None]:
# load_params comes from the local load_gpt_weights module. It returns a pair:
# `hparams` is indexed by hyperparameter name below, and `params` is the nested
# weight tree that is later copied into the model.
hparams, params = load_params()

Гиперпараметры модели

In [None]:
# Pull the architecture sizes out of `hparams` into module-level names; the
# model classes below read these globals when they are constructed:
#   n_head  - number of attention heads the embedding is split into
#   n_embd  - embedding / hidden width
#   n_ctx   - number of rows in the position-embedding table
#   n_vocab - number of rows in the token-embedding table
#   n_layer - number of transformer blocks
n_head, n_embd, n_ctx, n_vocab, n_layer = (
    hparams[name] for name in ('n_head', 'n_embd', 'n_ctx', 'n_vocab', 'n_layer')
)
pprint(hparams)

In [None]:
def shape_tree(d):
    """Return a structure mirroring ``d`` with every array replaced by its shape.

    Args:
        d: A ``numpy.ndarray``, a ``torch.Tensor``, or an arbitrarily nested
            ``list`` / ``dict`` whose leaves are such arrays.

    Returns:
        For an array or tensor, its shape as a list of ints; for a list, a list
        of converted elements; for a dict, a dict with the same keys and
        converted values.

    Raises:
        ValueError: If ``d`` or any nested leaf has an unsupported type.
    """
    if isinstance(d, (np.ndarray, torch.Tensor)):
        return list(d.shape)
    if isinstance(d, list):
        return [shape_tree(v) for v in d]
    if isinstance(d, dict):
        return {k: shape_tree(v) for k, v in d.items()}
    # Fail loudly: quietly returning None here would hide unexpected leaves in
    # the printed tree.
    raise ValueError(f"uh oh: unsupported type {type(d).__name__}")

# Pretty-print the nested shapes of the loaded weights to inspect their layout.
pprint(shape_tree(params))

In [None]:
# def gelu(x):
#     return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def attention(q, k, v):
    """Causal scaled dot-product attention for a single head.

    Args:
        q: Query tensor of shape ``(..., seq, d)``.
        k: Key tensor of shape ``(..., seq, d)``; same sequence length as ``q``.
        v: Value tensor of shape ``(..., seq, d_v)``.

    Returns:
        Tensor of shape ``(..., seq, d_v)`` in which position ``i`` is a
        softmax-weighted mix of the values at positions ``0..i``.
    """
    d = np.sqrt(q.shape[-1])
    # transpose(-2, -1) swaps only the last two axes, so optional leading
    # batch dimensions work; ``k.T`` would reverse every axis.
    scores = q @ k.transpose(-2, -1) / d

    n = scores.shape[-1]
    # Ones strictly above the diagonal mark "future" keys; scaling them by a
    # large negative number drives their softmax weight to zero. The mask is
    # built with the scores' dtype and device so it also works off-CPU.
    casual_mask = torch.triu(
        torch.ones(n, n, dtype=scores.dtype, device=scores.device), diagonal=1
    ) * -1e10
    scores = scores + casual_mask
    return torch.softmax(scores, dim=-1) @ v


class MultiHeadCasualSelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and an output projection.

    Layer widths come from the module-level ``n_embd`` and ``n_head``; every
    head is evaluated with the module-level ``attention`` function.
    (The spelling "Casual" is kept because other code refers to this name.)
    """

    def __init__(self):
        super().__init__()
        # A single linear layer produces queries, keys and values side by side.
        self.w_attn = nn.Linear(n_embd, 3 * n_embd)
        # Mixes the concatenated head outputs back into the model width.
        self.w_proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: Tensor whose last dimension has size ``n_embd``.

        Returns:
            Tensor with the same last-dimension size, after per-head attention
            and the output projection.
        """
        head_width = n_embd // n_head

        # Cut the fused projection into its three n_embd-wide parts.
        queries, keys, values = self.w_attn(x).split(n_embd, dim=-1)

        # Slice each part into per-head chunks and regroup them as (q, k, v) per head.
        per_head = zip(
            queries.split(head_width, dim=-1),
            keys.split(head_width, dim=-1),
            values.split(head_width, dim=-1),
        )
        head_outputs = [attention(hq, hk, hv) for hq, hk, hv in per_head]

        return self.w_proj(torch.cat(head_outputs, dim=-1))

class FeedForwardNetwork(nn.Module):
    """Position-wise MLP: expand to ``4 * n_embd``, apply GELU, project back.

    The width comes from the module-level ``n_embd``.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(n_embd, 4 * n_embd)
        self.proj = nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):
        """Apply the two-layer MLP.

        Args:
            x: Tensor whose last dimension has size ``n_embd``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # GPT-2 uses the tanh approximation of GELU:
        #   0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))
        # F.gelu defaults to the exact erf form, so request the approximation
        # explicitly to match the pretrained weights (needs torch >= 1.12).
        x = F.gelu(self.fc(x), approximate='tanh')
        x = self.proj(x)
        return x

class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each with a residual connection.

    Every sub-layer is preceded by its own LayerNorm over ``n_embd`` features.
    """

    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)             # normalises the attention input
        self.mha = MultiHeadCasualSelfAttention()
        self.ln2 = nn.LayerNorm(n_embd)             # normalises the MLP input
        self.ffn = FeedForwardNetwork()

    def forward(self, x):
        """Run the block.

        Args:
            x: Tensor whose last dimension has size ``n_embd``.

        Returns:
            ``x`` plus the attention output, plus the MLP output computed from
            that intermediate sum.
        """
        attn_out = self.mha(self.ln1(x))
        after_attn = x + attn_out

        ffn_out = self.ffn(self.ln2(after_attn))
        return after_attn + ffn_out

In [None]:
class GPT(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings.

    Sizes come from the module-level ``n_vocab``, ``n_embd``, ``n_ctx`` and
    ``n_layer``.
    """

    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(n_vocab, n_embd)  # token embeddings
        self.wpe = nn.Embedding(n_ctx, n_embd)    # position embeddings
        blocks = [
            Block() for _ in range(n_layer)
        ]
        self.blocks = nn.ModuleList(blocks)
        self.lnorm = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Compute next-token logits for every position.

        Args:
            x: Integer tensor of token ids whose last dimension is the
                sequence axis (this notebook passes 1-D tensors).

        Returns:
            Logits with the shape of ``x`` plus a trailing ``n_vocab`` axis.
        """
        # Create the position indices on the input's device so the embedding
        # lookup also works when the model does not live on the CPU.
        positions = torch.arange(x.shape[-1], device=x.device)
        x = self.wte(x) + self.wpe(positions)

        for block in self.blocks:
            x = block(x)

        x = self.lnorm(x)
        # Weight tying: the token-embedding matrix doubles as the output projection.
        x = x @ self.wte.weight.T
        return x

    def generate(self, inputs, n_tokens_to_generate, temperature=.1):
        """Autoregressively produce tokens, yielding each new id as soon as it is chosen.

        Args:
            inputs: 1-D integer tensor with the prompt token ids. It is copied,
                so the caller's tensor is not modified.
            n_tokens_to_generate: Number of tokens to yield.
            temperature: ``0`` selects the arg-max token (greedy decoding);
                any other value divides the logits before sampling from the
                softmax distribution.

        Yields:
            A tensor holding the next token id: 0-dim in the greedy case,
            shape ``(1,)`` when sampling.

        Raises:
            AssertionError: If the prompt plus the requested tokens do not fit
                into ``n_ctx``. Because this is a generator, the check runs on
                the first iteration rather than at call time.
        """
        assert len(inputs) + n_tokens_to_generate < n_ctx
        inputs = inputs.clone().detach()
        for _ in range(n_tokens_to_generate):
            # Inference only: skip autograd bookkeeping for the forward pass.
            # The context is closed before the yield so it never leaks into
            # the caller's code while the generator is suspended.
            with torch.no_grad():
                logits = self(inputs)[-1]
            if temperature == 0:
                next_id = torch.argmax(logits)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_id = torch.multinomial(probs, 1)
            inputs = torch.cat([inputs, next_id.view(1)])
            yield next_id


# Smoke test: push a short prompt through a freshly constructed (not yet
# loaded) model and check the tensor shapes.
x = enc.encode('Alan Turing theorized')
x = torch.tensor(x)
test_gpt = GPT()
# Only the shapes matter here, so skip autograd bookkeeping.
with torch.no_grad():
    y = test_gpt(x)
print(f'Input shape {x.shape}')
print(f'Output shape {y.shape}')

print()
print('Parameters:')
# The state_dict keys printed here are the names the weight-loading cell must fill in.
pprint(shape_tree(test_gpt.state_dict()))
# Drop the throwaway model so its weights do not stay alive in the kernel.
del test_gpt

Загружаем веса GPT2 в модель.

In [None]:
gpt = GPT()

# Top-level tensors: embeddings and the final LayerNorm.
state_dict = {
    'wpe.weight': params['wpe'],
    'wte.weight': params['wte'],
    'lnorm.weight': params['ln_f']['g'],
    'lnorm.bias': params['ln_f']['b'],
}

# Linear layers: module path inside a Block -> (group, layer) path inside a
# checkpoint block.
linear_layers = {
    'mha.w_attn': ('attn', 'c_attn'),
    'mha.w_proj': ('attn', 'c_proj'),
    'ffn.fc': ('mlp', 'c_fc'),
    'ffn.proj': ('mlp', 'c_proj'),
}
# LayerNorms: module name inside a Block -> name inside a checkpoint block.
layer_norms = {'ln1': 'ln_1', 'ln2': 'ln_2'}

for i, block_dict in enumerate(params['blocks']):
    for target, (group, layer) in linear_layers.items():
        source = block_dict[group][layer]
        state_dict[f'blocks.{i}.{target}.bias'] = source['b']
        # Transposed on the way in: nn.Linear keeps its weight as
        # (out_features, in_features); presumably the checkpoint stores the
        # opposite layout.
        state_dict[f'blocks.{i}.{target}.weight'] = source['w'].T

    for target, layer in layer_norms.items():
        source = block_dict[layer]
        state_dict[f'blocks.{i}.{target}.bias'] = source['b']
        state_dict[f'blocks.{i}.{target}.weight'] = source['g']

# Convert every array to a torch tensor before handing it to the module.
state_dict = {name: torch.tensor(value) for name, value in state_dict.items()}
gpt.load_state_dict(state_dict)

In [None]:
# Prompt that the loaded model is asked to continue.
prompt = 'Alan Turing theorized that computers would one day become'
# Tokenize the prompt; '<|endoftext|>' is permitted as a special token should it appear.
x = torch.tensor(enc.encode(prompt, allowed_special={'<|endoftext|>'}))
# Print the header without a trailing newline so the generated text follows on the same line.
print(f'prompt: {prompt}\nanswer: ', end='')
# Generate 8 tokens with temperature 0 (arg-max decoding) and print each one as it is yielded.
for next_id in gpt.generate(x, 8, temperature=0.0):
    next_token = enc.decode([next_id.item()])
    print(next_token, end='')