## Build the dataset

In [3]:
!pip install convokit

Collecting convokit
  Downloading convokit-3.0.0.tar.gz (183 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.2/183.2 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting msgpack-numpy>=0.4.3.2 (from convokit)
  Downloading msgpack_numpy-0.4.8-py2.py3-none-any.whl (6.9 kB)
Collecting dill>=0.2.9 (from convokit)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
Collecting clean-text>=0.6.0 (from convokit)
  Downloading clean_text-0.6.0-py3-none-any.whl (11 kB)
Collecting unidecode>=1.1.1 (from convokit)
  Downloading Unidecode-1.3.8-py3-none-any.whl (235 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m235.5/235.5 kB[0m [31m10.0 MB/s[0m eta [3

In [4]:
import os
import torch
from convokit import Corpus, download

# filename = "~/.convokit/downloads/friends-corpus"
# corpus = Corpus(filename=os.path.expanduser(filename))
corpus = Corpus(download('friends-corpus'))

utterance = corpus.get_utterance('s07_e14_c01_u018')
print(utterance.speaker.id)
print(utterance.text)

Downloading friends-corpus to /root/.convokit/downloads/friends-corpus
Downloading friends-corpus from http://zissou.infosci.cornell.edu/convokit/datasets/friends-corpus/friends-corpus.zip (6.1MB)... Done
No configuration file found at /root/.convokit/config.yml; writing with contents: 
# Default Backend Parameters
db_host: localhost:27017
data_directory: ~/.convokit/saved-corpora
default_backend: mem
Rachel Green
Well, can I keep the presents and still be 29?


In [5]:
import re
re_pattern = "[^0-9a-zA-Z,.?!' ]"

all_utterance = []

for utterance in corpus.iter_utterances():
    speaker = utterance.speaker.id
    if speaker == "TRANSCRIPT_NOTE":
        # Only interested in conversations
        continue
    speaker = re.sub(re_pattern, '', speaker)
    text = re.sub(re_pattern, '', utterance.text)
    all_utterance.append(f"{speaker}\n{text}")

n = int(len(all_utterance) * 0.9)
train_data_text = '\n\n'.join(all_utterance[:n])
val_data_text = '\n\n'.join(all_utterance[n:])

print(train_data_text[:500])

Monica Geller
There's nothing to tell! He's just some guy I work with!

Joey Tribbiani
C'mon, you're going out with the guy! There's gotta be something wrong with him!

Chandler Bing
All right Joey, be nice. So does he have a hump? A hump and a hairpiece?

Phoebe Buffay
Wait, does he eat chalk?

Phoebe Buffay
Just, 'cause, I don't want her to go through what I went through with Carl oh!

Monica Geller
Okay, everybody relax. This is not even a date. It's just two people going out to dinner and no


In [6]:
all_characters = sorted(list(set(train_data_text)))
stoi = {s:i for i, s in enumerate(sorted(all_characters))}
itos = {i:s for s, i in stoi.items()}

print("Dictionary size:", len(stoi))
print(stoi)

Dictionary size: 69
{'\n': 0, ' ': 1, '!': 2, "'": 3, ',': 4, '.': 5, '0': 6, '1': 7, '2': 8, '3': 9, '4': 10, '5': 11, '6': 12, '7': 13, '8': 14, '9': 15, '?': 16, 'A': 17, 'B': 18, 'C': 19, 'D': 20, 'E': 21, 'F': 22, 'G': 23, 'H': 24, 'I': 25, 'J': 26, 'K': 27, 'L': 28, 'M': 29, 'N': 30, 'O': 31, 'P': 32, 'Q': 33, 'R': 34, 'S': 35, 'T': 36, 'U': 37, 'V': 38, 'W': 39, 'X': 40, 'Y': 41, 'Z': 42, 'a': 43, 'b': 44, 'c': 45, 'd': 46, 'e': 47, 'f': 48, 'g': 49, 'h': 50, 'i': 51, 'j': 52, 'k': 53, 'l': 54, 'm': 55, 'n': 56, 'o': 57, 'p': 58, 'q': 59, 'r': 60, 's': 61, 't': 62, 'u': 63, 'v': 64, 'w': 65, 'x': 66, 'y': 67, 'z': 68}


In [7]:
encode = lambda s: torch.tensor([stoi[c] for c in s])
decode = lambda c: ''.join([itos[v.item()] for v in c])

train_data = encode(train_data_text)
val_data = encode(val_data_text)

print(encode("Hello world"))
print(decode(encode("Hello world")))

tensor([24, 47, 54, 54, 57,  1, 65, 57, 60, 54, 46])
Hello world


In [8]:
torch.manual_seed(100)
batch_size = 4
block_size = 8
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", device)

def get_batch(split):
    if split == 'train':
        data = train_data
    elif split == 'val':
        data = val_data

    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return (x, y)

x, y = get_batch('train')
print(x[0], '->', y[0])
print()
for t in range(block_size):
    print(f"{x[0, :t+1]} -> {y[0, t]}")
    print(f"{decode(x[0, :t+1])} -> {decode(y[0, t].view(-1))}")

Device: cuda
tensor([52, 43, 45, 53, 47, 62,  5,  0], device='cuda:0') -> tensor([43, 45, 53, 47, 62,  5,  0,  0], device='cuda:0')

tensor([52], device='cuda:0') -> 43
j -> a
tensor([52, 43], device='cuda:0') -> 45
ja -> c
tensor([52, 43, 45], device='cuda:0') -> 53
jac -> k
tensor([52, 43, 45, 53], device='cuda:0') -> 47
jack -> e
tensor([52, 43, 45, 53, 47], device='cuda:0') -> 62
jacke -> t
tensor([52, 43, 45, 53, 47, 62], device='cuda:0') -> 5
jacket -> .
tensor([52, 43, 45, 53, 47, 62,  5], device='cuda:0') -> 0
jacket. -> 

tensor([52, 43, 45, 53, 47, 62,  5,  0], device='cuda:0') -> 0
jacket.
 -> 



# A simple MLP

In [9]:
from torch import nn
from torch.nn import functional as F

class FeedForward(nn.Module):
    def __init__(self, d_emb, d_ff, dropout):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_emb, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_emb),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ff(x)

In [10]:
d_vocab = len(stoi)
d_emb = 64
d_ff = 128
dropout = 0.1


# v0.1
class SimpleLanguageModel(nn.Module):
    '''Simple Bigram language model with a single feedforward layer'''
    def __init__(self, d_vocab, d_emb, d_ff, dropout):
        super().__init__()
        self.emb = nn.Embedding(d_vocab, d_emb)  # B, T -> B, T, d_emb
        self.ffwd = FeedForward(d_emb, d_ff, dropout)  # B, T, d_emb -> B, T, d_emb
        self.lm_head = nn.Linear(d_emb, d_vocab)  # B, T, d_emb -> B, T, d_vocab

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, targets=None):
        # B, T = x.shape
        x = self.emb(x) # B, T, d_emb
        logits = self.lm_head(x) # B, T, d_vocab

        loss = None
        if targets is not None:
            B, T, d_vocab = logits.shape
            logits = logits.view(-1, d_vocab)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, x, n_tokens):
        # B, T = x.shape
        for _ in range(n_tokens):
            x_cond = x[:, -block_size:]
            logits, loss = self(x_cond)  # logits: B, T, d_vocab
            logits = logits[:, -1, :]  # B, d_vocab
            probs = F.softmax(logits, dim=-1)  # B, d_vocab
            x_next = torch.multinomial(probs, num_samples=1)  # B, 1
            x = torch.cat((x, x_next), dim=1)  # B, T+1
        return x

In [11]:
model = SimpleLanguageModel(d_vocab, d_emb, d_ff, dropout).to(device)

In [12]:
xb, yb = get_batch('train')
xb, yb = xb[:1], yb[:1]
logit, loss = model(xb, yb)
print(logit.shape)
print(loss)

torch.Size([8, 69])
tensor(4.2356, device='cuda:0', grad_fn=<NllLossBackward0>)


In [13]:
# v0.2 - with the generate function
class SimpleLanguageModel(nn.Module):
    '''Simple Bigram language model with a single feedforward layer'''
    def __init__(self, d_vocab, d_emb, d_ff, dropout):
        super().__init__()
        self.emb = nn.Embedding(d_vocab, d_emb)  # B, T -> B, T, d_emb
        self.ffwd = FeedForward(d_emb, d_ff, dropout)  # B, T, d_emb -> B, T, d_emb
        self.lm_head = nn.Linear(d_emb, d_vocab)  # B, T, d_emb -> B, T, d_vocab

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, targets=None):
        # B, T = x.shape
        x = self.emb(x) # B, T, d_emb
        logits = self.lm_head(x) # B, T, d_vocab

        loss = None
        if targets is not None:
            B, T, d_vocab = logits.shape
            logits = logits.view(-1, d_vocab)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, x, n_tokens):
        # B, T = x.shape
        for _ in range(n_tokens):
            x_cond = x[:, -block_size:]
            logits, loss = self(x_cond)  # logits: B, T, d_vocab
            logits = logits[:, -1, :]  # B, d_vocab
            probs = F.softmax(logits, dim=-1)  # B, d_vocab
            x_next = torch.multinomial(probs, num_samples=1)  # B, 1
            x = torch.cat((x, x_next), dim=1)  # B, T+1
        return x

In [183]:
model = SimpleLanguageModel(d_vocab, d_emb, d_ff, dropout).to(device)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))
gen_text = model.generate(x=torch.zeros((1, 1), dtype=torch.long).to(device), n_tokens=100)
print(gen_text)
print(decode(gen_text[0]))

25477
tensor([[ 0,  9, 28,  1, 11,  1, 38, 40, 35, 11, 21, 63, 22, 31, 67,  9,  2, 19,
         26, 28, 37, 32, 45, 48, 50,  8, 51, 35, 52, 27, 58, 50, 31, 21, 15, 55,
          0, 42, 14, 12,  5, 35, 35, 17, 28,  7,  1, 62, 56, 48, 23, 20, 20, 60,
          4, 48,  6, 31, 35,  9, 50,  5, 13, 17, 13, 52,  3, 26, 12, 63, 15, 15,
         15, 64, 24, 34, 13, 14, 23, 44, 68, 56, 14, 35, 16, 21, 43, 20, 47, 22,
         32,  4, 49, 58, 34,  8,  0,  8, 65,  8, 58]], device='cuda:0')

3L 5 VXS5EuFOy3!CJLUPcfh2iSjKphOE9m
Z86.SSAL1 tnfGDDr,f0OS3h.7A7j'J6u999vHR78Gbzn8S?EaDeFP,gpR2
2w2p


# Training & Evaluation

In [179]:
@torch.no_grad()
def estimate_loss(model, eval_iters=200):
    model.eval()
    total_loss = 0.
    for i in range(eval_iters):
        x, y = get_batch('val')
        logits, loss = model(x, y)
        total_loss += loss.item()
    return total_loss / eval_iters

estimate_loss(model)

4.233867290019989

In [180]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [181]:
batch_size = 32
for steps in range(100000):
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % 5000 == 0:
        print(steps, loss.item())
        if steps % 20000 == 0:
            print("Validation loss:", estimate_loss(model))

    # comment out after confirmed it works
    # if steps >= 5000:
    #     break

0 4.233941078186035
Validation loss: 4.23217838048935
5000 2.3934381008148193
10000 2.3754165172576904
15000 2.2746267318725586
20000 2.4064085483551025
Validation loss: 2.4029186367988586
25000 2.3725385665893555
30000 2.3118181228637695
35000 2.3217997550964355
40000 2.395559787750244
Validation loss: 2.3980125439167024
45000 2.3658084869384766
50000 2.5170881748199463
55000 2.466986656188965
60000 2.441925525665283
Validation loss: 2.394424878358841
65000 2.2491915225982666
70000 2.4393274784088135
75000 2.30452823638916
80000 2.3951332569122314
Validation loss: 2.399096369743347
85000 2.3104934692382812
90000 2.4537477493286133
95000 2.4453818798065186


In [184]:
# Generation
print(decode(
    model.generate(x=torch.zeros((1, 1), dtype=torch.long).to(device), n_tokens=2000)[0]
))


Docherengowand the Cha I'the s ing.



I y I'm Geveba Thenke p ang s wid ana Jo I'tou we knoy'sus leeaure Any bifowinge wh ang
Moufon matme, Rind yonewa Caksshe Th Thanonagou'ra g m ss chandat w I tr ateasebelit d hicastufand. g...

Ohey ouher yo inewosh wss Cho mey su see'sit anyo, wn, s hey Bitooutale'st ff walle





Try th, s Bufithr feat cy inost'cksenore ly as!
Y'ves Han



Jol I y choway?
Uhtha h. d n rin hthe s Tr Bat.. Bite I m. Fos ene wangh, t k l.
Many l her

Rotoooust knd Gerenonibbbbbelerelait'vin! s s t e, in 's cayowrryou, alf d e Hine Ge Tr
Ravalee'd Bililesot ost Buyo, bethang
Seboey, thas l! d nin
Row! t I t.

Whelllle tehatoer
Oheld y?
Soiverechellin outiomeaver llaffonngoularts?

Wenk. Phi
Yabacan.. wouth. meros Ran'knn
Joull t 'h, a.
Rfor sst n'mebanyoo qu y kan Eheyor hon's s?
Eaurethay?! copar
Rag waheyore g.
Four I'tre owam dom d y pl Ph. t?

Troey Bume g


Th I'sioealy! tre y mbe wan'leelllorr s. ite llp lerin uf ju Bitho tele?
Roo Chatryont Budnghaly?
Nibene

## Self attention

Self attention is a communication mechanism.

In [20]:
torch.manual_seed(415)
values = torch.randint(0, 10, (3, 2)).float()

# Version 1: naive for-loop
first_k_avg = torch.empty_like(values)
for i in range(first_k_avg.shape[0]):
    for j in range(first_k_avg.shape[1]):
        first_k_avg[i, j] = values[:i + 1, j].mean()

print("values = \n", values)
print('-' * 10)
print("first_k_avg = \n", first_k_avg)

values = 
 tensor([[5., 8.],
        [0., 5.],
        [3., 4.]])
----------
first_k_avg = 
 tensor([[5.0000, 8.0000],
        [2.5000, 6.5000],
        [2.6667, 5.6667]])


In [21]:
# Version 2: mat mul
num_rows = values.shape[0]
wei = torch.tril(torch.ones(num_rows, num_rows))
wei = wei / torch.sum(wei, 1, keepdim=True)
print("wei = \n", wei)
print('-' * 10)
print("first_k_avg = \n", wei @ values)

wei = 
 tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
----------
first_k_avg = 
 tensor([[5.0000, 8.0000],
        [2.5000, 6.5000],
        [2.6667, 5.6667]])


In [22]:
# Version 3: softmax
wei = torch.zeros((num_rows, num_rows))
tril = torch.tril(torch.ones_like(wei))
wei = wei.masked_fill(tril == 0, float('-inf'))
print("wei before softmax:\n", wei)
print('-' * 10)
wei = F.softmax(wei, dim=1)
print("wei after softmax:\n", wei)
print('-' * 10)
print("first_k_avg = \n", wei @ values)

wei before softmax:
 tensor([[0., -inf, -inf],
        [0., 0., -inf],
        [0., 0., 0.]])
----------
wei after softmax:
 tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
----------
first_k_avg = 
 tensor([[5.0000, 8.0000],
        [2.5000, 6.5000],
        [2.6667, 5.6667]])


In [23]:
# Version 4: serlf attention

B, T, C = 4, 4, 32  # batch, time, channels (embeddings)
x = torch.randn(B, T, C)

# Single head self attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)  # (B, T, 16)
q = query(x)  # (B, T, 16)
wei = q @ k.transpose(-2, -1) * C**-0.5  # (B, T, 16) @ (B, 16, T) -> (B, T, T)

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v

print("wei:\n", wei)
print('-' * 10)
print("weight vector for third prediction:\n", wei[0, 2])
print("out shape:", out.shape)

wei:
 tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4607, 0.5393, 0.0000, 0.0000],
         [0.3991, 0.2089, 0.3920, 0.0000],
         [0.2587, 0.2108, 0.2907, 0.2397]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4038, 0.5962, 0.0000, 0.0000],
         [0.2823, 0.3892, 0.3285, 0.0000],
         [0.2289, 0.2493, 0.2322, 0.2896]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.3789, 0.6211, 0.0000, 0.0000],
         [0.3268, 0.3199, 0.3533, 0.0000],
         [0.1970, 0.2184, 0.2262, 0.3584]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4430, 0.5570, 0.0000, 0.0000],
         [0.2956, 0.3503, 0.3541, 0.0000],
         [0.2012, 0.2451, 0.2963, 0.2574]]], grad_fn=<SoftmaxBackward0>)
----------
weight vector for third prediction:
 tensor([0.3991, 0.2089, 0.3920, 0.0000], grad_fn=<SelectBackward0>)
out shape: torch.Size([4, 4, 16])


## Notes on Key, Query, Value in self attention

Attention is a communication mechanism. It works as a directed graph, passing information along some direction. In text generation, it often involves passing information from past tokens to future tokens.

There are three components in self attention mechanism.
* Query - Q(x) projects what information x is seeking
* Key - K(x) projects what information x contains
* Value - V(x) determines what information should be aggregated for the purpose of this single attention head

To understand Q, K, V intuitively,
* Think of x like private information or private key of a token, it is then projected into the Query, Key, and Value handled by the attention head.
* The output of self attention is a weighted sum of the projection V(x), not the tokens themselves. Why?
    * It enables us to simultaneously consider various aspects of tokens in different heads after we introduce multi-head attention mechanism next.
    * For example, in processing the word "cat" within a sentence, different attention heads might aggregate information with regards to its grammatical role (noun), its conceptual meaning as an animal, or its syntactic function as a subject or object. This diversity allows for a richer, more nuanced understanding of text.


In [170]:
class Head(nn.Module):
    '''Single head self attention'''

    def __init__(self, d_emb, d_head):
        super().__init__()
        self.d_head = d_head
        self.key = nn.Linear(d_emb, d_head, bias=False)
        self.query = nn.Linear(d_emb, d_head, bias=False)
        self.value = nn.Linear(d_emb, d_head, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        # x: B, T, d_emb
        B, T, C = x.shape
        k = self.key(x)  # B, T, dh
        q = self.query(x)  # B, T, dh
        v = self.value(x)  # B, T, dh
        wei = q @ k.transpose(-2, -1) * (self.d_head ** -0.5)  # B, T, T
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        out = wei @ v  # B, T, dh
        return out


class MultiHead(nn.Module):
    '''Multi head self attention'''

    def __init__(self, num_heads, d_emb, d_head):
        super().__init__()
        self.heads = nn.ModuleList([Head(d_emb, d_head) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads * d_head, d_emb)

    def forward(self, x):
        # x: B, T, d_emb
        head_out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(head_out)
        return out

## Final Notes on Attention Mechanism

1. **Position Encoding**: Attention mechanisms inherently lack a notion of order, unlike convolutions. Therefore, inputs to the attention mechanism should include positional information to maintain sequence context.
2. **Scaled Attention**: To prevent the softmax function from collapsing into a one-hot vector, it's crucial that the weights (Q @ K) are diffused appropriately, hence the need for scaling attention (divided by d_head**0.5).
3. **Batch Isolation**: Examples within the same batch do not interact; each instance is processed independently.
4. **Transformer Architecture Variations**:
    - **Decoder Block**: Restricts information flow to prevent future tokens from influencing the output, typically used in output generation phases.
    - **Encoder Block**: Allows free communication among all nodes, fully utilizing context, typically used in input interpretation phases.
    - **Application in the Transformer paper**: In context of machine translation, the original paper uses Encoder Blocks for source language text, encoding full contextual understanding, and uses Decoder Blocks for target language text, ensuring generated content is influenced only by preceding text and the source content.

In addition to self-attention, a transformer block comprises:

- **Computation Layer**: a feedforward network computes over the aggregated information, but on per token basis (no communication between two tokens at this step)
- **Optimization Techniques for Deep Networks**:
    - **Residual Connection**: Facilitates learning by creating shortcuts for gradients, acting as a "super-highway" for backpropagation.
    - **Layer Normalization**: Standardizes the inputs to each layer, ensuring consistent scale and aiding in stable training.
    - **Dropout**: Randomly omits a subset of features at each layer to prevent overfitting and encourage generalized representations.

## Attention Block

In [171]:
class LayerNorm1d(nn.Module):
    '''Layer normalization over the last dimension'''
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = torch.ones(dim, device=device)
        self.beta = torch.zeros(dim, device=device)

    def forward(self, x):
        # In batch norm, we aggregate over columns
        # In layer norm, we aggregate over rows
        xmean = x.mean(-1, keepdim=True)
        xvar = x.var(-1, keepdim=True)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        out = self.gamma * xhat + self.beta
        return out

In [172]:
class TransformerBlock(nn.Module):
    '''Transformer block with multi-head attention and feedforward'''

    def __init__(self, n_emb, n_head):
        super().__init__()
        d_head = n_emb // n_head
        self.attn = MultiHead(n_head, n_emb, d_head)
        self.ffwd = FeedForward(n_emb, n_emb, dropout=0.1)
        self.ln1 = LayerNorm1d(n_emb)
        self.ln2 = LayerNorm1d(n_emb)

    def forward(self, x):
        # x: B, T, n_emb
        # with residual connection
        x = x + self.attn(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [173]:
# v0.3 - with the transformer block
class SimpleLanguageModelWithTransformer(nn.Module):
    '''Simple Bigram language model with transformer layers'''
    def __init__(self, d_vocab, d_emb, num_heads, n_layers):
        super().__init__()
        self.token_emb = nn.Embedding(d_vocab, d_emb)  # B, T -> B, T, d_emb
        self.position_emb = nn.Embedding(block_size, d_emb)  # B, T -> B, T, d_emb
        # Note: Changed here
        # self.ffwd = FeedForward(d_emb, d_ff, dropout)  # B, T, d_emb -> B, T, d_emb
        self.blocks = nn.Sequential(
            *[TransformerBlock(d_emb, num_heads) for _ in range(n_layers)]
        )
        self.ln_final = LayerNorm1d(d_emb)
        self.lm_head = nn.Linear(d_emb, d_vocab)  # B, T, d_emb -> B, T, d_vocab

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, targets=None):
        _, T = x.shape
        token_emb = self.token_emb(x) # B, T, d_emb
        pos_emb = self.position_emb(torch.arange(T, device=device))
        x = token_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_final(x)
        logits = self.lm_head(x) # B, T, d_vocab

        loss = None
        if targets is not None:
            B, T, d_vocab = logits.shape
            logits = logits.view(-1, d_vocab)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, x, n_tokens):
        # B, T = x.shape
        for _ in range(n_tokens):
            x_cond = x[:, -block_size:]
            logits, loss = self(x_cond)  # logits: B, T, d_vocab
            logits = logits[:, -1, :]  # B, d_vocab
            probs = F.softmax(logits, dim=-1)  # B, d_vocab
            x_next = torch.multinomial(probs, num_samples=1)  # B, 1
            x = torch.cat((x, x_next), dim=1)  # B, T+1
        return x

In [174]:
n_layer = 2
num_heads = 4
model = SimpleLanguageModelWithTransformer(d_vocab, d_emb, num_heads, n_layer)
model = model.to(device)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

58949


In [175]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
batch_size = 32
for steps in range(100000):
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if steps % 5000 == 0:
        print(steps, loss.item())
        if steps % 20000 == 0:
            print("Validation loss:", estimate_loss(model))

    # comment out after confirmed it works
    # if steps >= 5000:
    #     break
print("Validation loss:", estimate_loss(model))

0 4.246626377105713
Validation loss: 4.116809515953064
5000 1.7877894639968872
10000 1.5437992811203003
15000 1.7735244035720825
20000 1.457893967628479
Validation loss: 1.5588035702705383
25000 1.652052879333496
30000 1.3510609865188599
35000 1.6204665899276733
40000 1.3851908445358276
Validation loss: 1.5322303700447082
45000 1.4144151210784912
50000 1.4845095872879028
55000 1.551680088043213
60000 1.5847973823547363
Validation loss: 1.5175394719839097
65000 1.1751842498779297
70000 1.6589553356170654
75000 1.348968744277954
80000 1.4910104274749756
Validation loss: 1.517003682255745
85000 1.4565149545669556
90000 1.7095304727554321
95000 1.4697622060775757
Validation loss: 1.4878948092460633


In [177]:
# Generation
print(decode(
    model.generate(x=torch.zeros((1, 1), dtype=torch.long, device=device), n_tokens=2000)[0][8:]
))

 Geller
It want so listen! What are you talks.

Chright.

Rachel Green
Hey.

Ross Geller
Talk I doub uhy chel, like with Mart! I'm gonna go on the ladid!

Rachel Green
Come look that our do! Till bus! I cause even go me!! Y'know' the come won, some open by I'm the buass that here, I have again Chandler God, tereson. Look. I'm ask the self! It'll exep!

Rachel Green
Y'know, no think you like I'm somebody want to sance that ask somethrooman out even, who keganbody doesn't be for Chandler Bing
Come.

Joey Tribbiani
You've beenononononey on a wonderonce and a uh!

Ross Geller
Willior janna peoping bate a look' seem would call alacrazing... you not.

Chandler Bing
And wlamost because
NNo, I have get at that's going? Unto see, don't him than you've pisuck.

Chandler!! And uh, when you big out back you can how did it.

Rachel Green
Well a get me, but is, someone.

No you one.

Ross Geller
Yeah, I toping people find over good! And in to but, you if with you say because bood.

Joey Tribbiani
So