In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from attention import *

# Data

In [4]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
chars = sorted(list(set(text)))
vocab_dim = len(chars)

In [8]:
c_i = {c:i for i, c in enumerate(chars)}
i_c = {i:c for i, c in enumerate(chars)}
encode = lambda s: [c_i[c] for c in s]
decode = lambda l: ''.join([i_c[i] for i in l])

In [9]:
data = torch.tensor(encode(text), dtype=torch.int64)
data.shape

torch.Size([1115394])

In [10]:
n = int(.9*len(data))
data_train = data[:n]
data_val = data[n:]

In [11]:
len(data_train)

1003854

# Model

In [12]:
batch_size = 64 # N
sequence_dim = 100 # L, S
embed_dim = 78 # E
num_heads = 13 # H
assert embed_dim % num_heads == 0
train_steps = 5000
lr = 1e-3 # learning rate
torch.manual_seed(78)
device = torch.device('cuda')

In [13]:
def get_batch(data, N, L):
    idx = torch.randint(len(data) - L, (N,))
    x = torch.stack([data[i:i+L] for i in idx])
    y = torch.stack([data[i+1:i+L+1] for i in idx])
    return x, y

In [14]:
x, y = get_batch(data_train, batch_size, sequence_dim)

In [15]:
x.shape, y.shape

(torch.Size([64, 100]), torch.Size([64, 100]))

In [16]:
torch.randn(1, 2, 3, 4).split(4)

(tensor([[[[ 0.9706,  1.8401,  0.7425,  1.7492],
           [-0.4090, -0.7430,  1.5891, -0.6899],
           [-1.9549, -1.1546, -2.9000, -1.6289]],
 
          [[ 0.4538,  0.8432, -0.4011, -0.3256],
           [-2.4454,  0.1959,  0.3256,  0.2596],
           [-1.1855, -1.0788,  0.5622, -0.4791]]]]),)

In [17]:
print([decode(l) for l in x.numpy()])
print([decode(l) for l in y.numpy()])

[',\nAnd spurn upon thee, beggar, for thy boldness.\n\nLADY ANNE:\nWhat, do you tremble? are you all afrai', ' for,\nif thou beest capable of things serious, thou must\nknow the king is full of grief.\n\nShepard:\nS', ' and twenty nose-gays for\nthe shearers, three-man-song-men all, and very good\nones; but they are mos', ' a trueborn Englishman.\n\nKING RICHARD II:\nWe did observe. Cousin Aumerle,\nHow far brought you high H', "been gadding?\n\nJULIET:\nWhere I have learn'd me to repent the sin\nOf disobedient opposition\nTo you an", ":\nHe makes a July's day short as December,\nAnd with his varying childness cures in me\nThoughts that ", 'uld break a thousand oaths to reign one year.\n\nRICHARD:\nNo; God forbid your grace should be forsworn', 'hy with some little train, my Lord of Buckingham?\n\nBUCKINGHAM:\nMarry, my lord, lest, by a multitude,', 'on,\nAnd then be gone and trouble you no more.\nShall I obtain it?\n\nHENRY BOLINGBROKE:\nName it, fair c', "leather bottle.\nHis wonted 

In [18]:
class FeedForward(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, 4 * out_features), # TODO: 4x according to paper
            nn.ReLU(),
            nn.Linear(4 * in_features, out_features), # projection layer
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [19]:
class SelfAttentionBlock(nn.Module):
    def __init__(self, sequence_dim, embed_dim, num_heads, dropout=0.2): # L, E, H
        super().__init__()
        # self.register_buffer("attn_mask", torch.triu(torch.full((sequence_dim, sequence_dim), float('-inf')), diagonal=1)) # flavor 1 - pytorch
        self.register_buffer("attn_mask", torch.tril(torch.ones(sequence_dim, sequence_dim)) == 0) # flavor 2 - karpathy

        self.ln1 = nn.LayerNorm(embed_dim) # https://arxiv.org/pdf/2002.04745.pdf
        self.mha = MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ff = FeedForward(embed_dim, embed_dim, dropout=dropout)

    def forward(self, x):
        x = self.ln1(x)
        attn_output, attn_output_weights = self.mha(x, x, x, need_weights=True, attn_mask=self.attn_mask, is_causal=True) # self attend
        x = x + attn_output
        x = x + self.ff(self.ln2(x)) # think on data
        return x

In [20]:
class GPT(nn.Module):
    def __init__(self, vocab_dim, sequence_dim, embed_dim, num_heads):
        super().__init__()
        self.sequence_dim = sequence_dim
        
        self.token_embedding = nn.Embedding(vocab_dim, embed_dim)
        self.position_embedding = nn.Embedding(sequence_dim, embed_dim)
        self.blocks = nn.Sequential(
            SelfAttentionBlock(sequence_dim, embed_dim, num_heads),
            SelfAttentionBlock(sequence_dim, embed_dim, num_heads),
            SelfAttentionBlock(sequence_dim, embed_dim, num_heads),
            nn.LayerNorm(embed_dim),
        )
        self.linear = nn.Linear(embed_dim, vocab_dim)

    def forward(self, x, y=None):
        # N is batch, L is length of time series, E is embedding dim
        N, L = x.shape
        token_embeddings = self.token_embedding(x) # (N, L, E)
        position_embeddings = self.position_embedding(torch.arange(L).to(device)) # (L, E) # T <= sequence_dim
        x = token_embeddings + position_embeddings # (N, L, E) +  (-, L, E) -> (N, L, E)
        x = self.blocks(x)
        logits = self.linear(x) # pred
        if y is None:
            loss = None
        else:
            N, L, E = logits.shape
            logits = logits.view(N*L, E)
            y = y.view(N*L)
            loss = F.cross_entropy(logits, y)
        return logits, loss

    def generate(self, x, n_tokens):
        for i in range(n_tokens):
            x_cropped = x[:, -self.sequence_dim:] # crop s.t. it's <= sequence_dim
            logits, _ = self(x_cropped) # (N, L, E)
            logits = logits[:, -1, :] # (N, E)
            probs = F.softmax(logits, dim=-1) # (N, E)
            y_pred = torch.multinomial(probs, num_samples=1) # (N, 1)
            x = torch.cat((x, y_pred), dim=1) # (N, L) + (N, 1) = (N, L + 1)
        return x

In [21]:
model = GPT(vocab_dim, sequence_dim, embed_dim, num_heads).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

In [22]:
model

GPT(
  (token_embedding): Embedding(65, 78)
  (position_embedding): Embedding(100, 78)
  (blocks): Sequential(
    (0): SelfAttentionBlock(
      (ln1): LayerNorm((78,), eps=1e-05, elementwise_affine=True)
      (mha): MultiheadAttention(
        (query): Linear(in_features=78, out_features=78, bias=False)
        (key): Linear(in_features=78, out_features=78, bias=False)
        (value): Linear(in_features=78, out_features=78, bias=False)
        (dropout1): Dropout(p=0.2, inplace=False)
        (projection): Linear(in_features=78, out_features=78, bias=True)
      )
      (ln2): LayerNorm((78,), eps=1e-05, elementwise_affine=True)
      (ff): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=78, out_features=312, bias=True)
          (1): ReLU()
          (2): Linear(in_features=312, out_features=78, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
    )
    (1): SelfAttentionBlock(
      (ln1): LayerNorm((78,), eps=1e-05, elementwise_

In [23]:
# pre training
print(decode(model.generate(torch.zeros((sequence_dim, sequence_dim), dtype=torch.int64).to(device), 100).cpu().numpy()[0]))





































































































xbc$nPcNyCfW.y
o.z?,yLFzgsK'sD:3sY$bZdLeyLSpr3D-wvF,W.e.Pibgj.NDwFQBmY
LMhGVIiir&x3cbv&s'oaSyiv.t3jR


In [24]:
@torch.no_grad()
def estimate_loss(model, iters, device):
    out = []
    model.to(device)
    model.eval()
    losses = torch.zeros(iters)
    for data in [data_train, data_val]:
        for i in range(iters):
            x, y = get_batch(data, batch_size, sequence_dim)
            x = x.to(device)
            y = y.to(device)
            logits, loss = model(x, y)
            losses[i] = loss.item()
        out.append(losses.mean())
    model.train()
    return out

In [25]:
%%time
tenth = train_steps//10
for steps in range(train_steps):
    x, y = get_batch(data_train, batch_size, sequence_dim)
    x = x.to(device)
    y = y.to(device)
    logits, loss = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if steps % tenth == 0:
        train_loss, val_loss = estimate_loss(model, 100, device)
        print(train_loss, val_loss)

tensor(4.0399) tensor(4.0498)
tensor(2.1530) tensor(2.1911)
tensor(1.8653) tensor(1.9626)
tensor(1.7370) tensor(1.8729)
tensor(1.6652) tensor(1.8210)
tensor(1.6200) tensor(1.7882)
tensor(1.5879) tensor(1.7570)
tensor(1.5731) tensor(1.7515)
tensor(1.5505) tensor(1.7238)
tensor(1.5323) tensor(1.7157)
CPU times: total: 17.5 s
Wall time: 1min 56s


In [26]:
# post training
model.eval()
print(decode(model.generate(torch.zeros((sequence_dim, sequence_dim), dtype=torch.int64).to(device), 1000).cpu().numpy()[0]))





































































































FIRTIUS:
No, let him your agreed
The hear come I dear; I tell 'tis full treak.

ARCHOP,
So were to the orsong, ever crison's your good tell do me.

POLIXENES:
That king of deed.
To bood. What, and your jest discreed;
Or know your gaents Richard many to enjet,
That that, this chappet to steak pearle friends of thin
along is receive us amer! rangely the l'd cheely
youth peace. When brief, my fortune, for as so;
Thou shumbled upon his una was womanish,
Will to be yield. What you.

JOHN OF GOFORK:
Mord news, royal womd
Those in haste the receeqe's in with 
PANLIA:
I sto-content! which anothing the Romanation
As mone and in the rupposed, or wish peptater:
Here liecte ease thee tyranner:
These but be to affidge unto her more,
To father and fiar up of at justicion new!
Where shall did a comfort to her?
Montabst the arsabed, the det you earth; I keep,
He we lopking: my pacity deserved
Spishrot f