In [3]:
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
# Seed PyTorch's global RNG so weight init and sampling below are reproducible.
# The trailing ';' suppresses the Generator repr Jupyter would otherwise display.
torch.manual_seed(1337);

<torch._C.Generator at 0x136792790>

In [19]:
# Load the raw corpus (expects input.txt in the working directory).
with open('input.txt', 'r', encoding='utf-8') as fh:
    text = fh.read()

# Vocabulary = every distinct character of the corpus, sorted
# (65 unique chars/tokens/ints for this input).
chars = sorted(set(text))
vocab_size = len(chars)

In [16]:
# Character-level tokenizer: map each vocabulary character to an integer id and back,
# then store the encoded corpus in a torch.tensor.
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> int id
itos = {i: ch for i, ch in enumerate(chars)}  # int id -> char


def encode(s):
    """Encode string ``s`` into a list of integer token ids (via ``stoi``)."""
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable of integer token ids ``l`` back into a string (via ``itos``)."""
    return ''.join([itos[i] for i in l])


data = torch.tensor(encode(text), dtype=torch.long)  # whole corpus as 1-D int64 ids
print(data[:1000])  # the first 1000 chars, as the model "sees" them
# A tensor is a multi-dimensional array; later dimensions will be batch size,
# sequence length (time dimension) and embedding size.

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 

In [17]:
# The first 90% of the token stream is for training; the last 10% is held out for validation.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

In [20]:
# Training uses random fixed-length chunks of train_data.
# block_size is the maximum context length (time dimension T). One chunk of
# block_size + 1 tokens yields block_size (context -> target) examples, e.g.
#   [18]                              -> 47
#   [18, 47]                          -> 56
#   ...
#   [18, 47, 56, 57, 58, 1, 15, 47]   -> 58
# so the model learns to predict from every context length 1..block_size.
block_size = 8
batch_size = 4  # independent sequences per batch (batch dimension B)

def get_batch(split):
    """Sample a random batch of (inputs, targets).

    ``split == 'train'`` samples from ``train_data``; any other value uses ``val_data``.
    Returns two (batch_size, block_size) long tensors; the targets are the inputs
    shifted one position ahead, so ``y[b, t]`` is the label for context ``x[b, :t+1]``.
    """
    source = train_data if split == 'train' else val_data
    # start offsets leave room for the block_size + 1 tokens each chunk needs
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return inputs, targets

# Draw one training batch and inspect it.
xb, yb = get_batch('train')  # both (batch_size, block_size)
for heading, batch in (('inputs/contexts:', xb), ('targets:', yb)):
    print(heading)
    print(batch.shape)
    print(batch)

print('----')

# Unroll the batch: each (b, t) pair is a separate training example whose
# context is the prefix xb[b, :t+1] and whose label is yb[b, t].
for b in range(batch_size):  # batch dimension
    row_x, row_y = xb[b], yb[b]
    for t in range(block_size):  # time dimension
        context, target = row_x[:t + 1], row_y[t]
        print(f"when input is {context.tolist()} the target: {target}")

# During training, xb is fed to the model and yb supplies the targets it must predict.

inputs/contexts:
torch.Size([4, 8])
tensor([[57, 43, 60, 43, 52,  1, 63, 43],
        [60, 43, 42,  8,  0, 25, 63,  1],
        [56, 42,  5, 57,  1, 57, 39, 49],
        [43, 57, 58, 63,  6,  1, 58, 46]])
targets:
torch.Size([4, 8])
tensor([[43, 60, 43, 52,  1, 63, 43, 39],
        [43, 42,  8,  0, 25, 63,  1, 45],
        [42,  5, 57,  1, 57, 39, 49, 43],
        [57, 58, 63,  6,  1, 58, 46, 47]])
----
when input is [57] the target: 43
when input is [57, 43] the target: 60
when input is [57, 43, 60] the target: 43
when input is [57, 43, 60, 43] the target: 52
when input is [57, 43, 60, 43, 52] the target: 1
when input is [57, 43, 60, 43, 52, 1] the target: 63
when input is [57, 43, 60, 43, 52, 1, 63] the target: 43
when input is [57, 43, 60, 43, 52, 1, 63, 43] the target: 39
when input is [60] the target: 43
when input is [60, 43] the target: 42
when input is [60, 43, 42] the target: 8
when input is [60, 43, 42, 8] the target: 0
when input is [60, 43, 42, 8, 0] the target: 25
when inp

In [23]:
# Feed into neural network: bigram language model
# inputs (context indices/characters), e.g. a (B=4, T=8) batch:
# tensor([[57, 43, 60, 43, 52,  1, 63, 43],
#        [60, 43, 42,  8,  0, 25, 63,  1],
#        [56, 42,  5, 57,  1, 57, 39, 49],
#        [43, 57, 58, 63,  6,  1, 58, 46]])

class BigramLanguageModel(nn.Module):
    """Bigram language model: the next token is predicted from the current token only.

    Each token id indexes a row of an (n_vocab x n_vocab) table, and that row is
    used directly as the logits (unnormalized scores) for the next token.
    """

    def __init__(self, n_vocab=None):
        """Create the next-token lookup table.

        Args:
            n_vocab: vocabulary size. Defaults to the notebook-global ``vocab_size``,
                so ``BigramLanguageModel()`` keeps working as before.
        """
        super().__init__()
        if n_vocab is None:
            n_vocab = vocab_size
        # Row i holds the next-token scores for token i (token 57 plucks out row 57).
        # These are not semantic embeddings. Channel j is the score that the *next*
        # token is j, which is why the table needs exactly n_vocab channels. Softmax
        # turns the scores into a probability distribution over next tokens.
        self.token_embedding_table = nn.Embedding(n_vocab, n_vocab)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) long tensor of token ids.
            targets: optional (B, T) long tensor of ground-truth next-token ids.

        Returns:
            (logits, loss). Without targets: logits is (B, T, C) and loss is None.
            With targets: logits is flattened to (B*T, C) and loss is the mean
            cross-entropy between the predicted distributions and the targets.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # F.cross_entropy expects (N, C) scores and (N,) class indices
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling only: no need to record an autograd graph
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after the context ``idx``.

        Args:
            idx: (B, T) long tensor holding the current context of each sequence.
            max_new_tokens: number of tokens to append along the time dimension.

        Returns:
            (B, T + max_new_tokens) long tensor: the context followed by the samples.
        """
        for _ in range(max_new_tokens):
            # No targets are passed, so no loss is computed. The whole sequence is
            # fed in, although a bigram model only uses the last position.
            logits, _ = self(idx)  # (B, T, C)
            logits = logits[:, -1, :]  # keep only the last time step -> (B, C)
            probs = F.softmax(logits, dim=-1)  # (B, C), each row sums to 1
            # draw one next-token id for each sequence in the batch
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

# Build the model and score one batch before any training.
m = BigramLanguageModel()
logits, loss = m(xb, yb)  # targets given -> logits come back flattened to (B*T, C)
for item in (logits.shape, logits, loss):
    print(item)
# An untrained model should be near uniform: -ln(1/65) ~= 4.17.

# Start generation from a single token 0 (the newline character): B=1, T=1.
idx = torch.zeros((1, 1), dtype=torch.long)
sampled_ids = m.generate(idx, max_new_tokens=100)[0].tolist()  # row 0 = the only sequence
print(decode(sampled_ids))

# Note: generate() feeds the entire sequence every step, although a bigram model only
# needs the last token; later models will look further back into the history.

torch.Size([32, 65])
tensor([[ 0.3302,  1.4595, -1.7275,  ..., -1.4876, -0.7216, -1.6158],
        [ 0.0537, -0.4753,  1.6139,  ..., -0.5359,  0.4087,  1.5254],
        [-0.1348, -0.2903, -0.5741,  ...,  0.5060,  0.4991,  0.0977],
        ...,
        [-0.1262,  1.2708, -0.0055,  ..., -1.0670, -0.9107,  0.2090],
        [ 0.5353,  0.7397, -1.3648,  ..., -0.6719,  0.9182,  1.0367],
        [ 0.3575, -1.9182, -0.7526,  ...,  1.3856, -0.9983,  0.3111]],
       grad_fn=<ViewBackward0>)
tensor(4.5792, grad_fn=<NllLossBackward0>)

KDwe-zBAhLLcTEOSRS$jEKnYwfaD'-ErPM!nquxVkKeERb,Yo.p,zAdq'Ua$.mpdvt-cm :gH?VjIfV3KRnoZQRK,nADQAO3hE?y


In [29]:
# Instead of printing the (noisy) loss of a single batch, estimate an average loss.
eval_iters = 200  # random batches averaged per split


@torch.no_grad()
def estimate_loss(model=None):
    """Estimate the mean train and validation loss over ``eval_iters`` random batches each.

    Args:
        model: module to evaluate. Defaults to the notebook-global ``m``, so
            existing ``estimate_loss()`` calls keep working.

    Returns:
        dict mapping 'train' and 'val' to a 0-d tensor holding the mean loss.
    """
    if model is None:
        model = m
    was_training = model.training
    model.eval()  # disable training-only behaviour (e.g. dropout) while measuring
    out = {}
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                logits, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)  # restore whatever mode the caller had
    return out

In [31]:
### Train the model:
# AdamW takes the gradients and updates the parameters (here the next-token scores
# stored in the embedding table). Plain SGD would also work; lr=1e-3 suits this small model.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32  # NOTE: rebinds the global that get_batch() / estimate_loss() read
max_iters = 5000
eval_interval = 100
for step in range(max_iters):  # `step`, not `iter`, to avoid shadowing the builtin

    # Every eval_interval optimizer steps (not epochs: each step sees one random
    # batch), estimate the loss on the train and val sets.
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # forward pass, then backprop and a parameter update
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)  # clear gradients from the previous step
    loss.backward()  # gradients of the loss w.r.t. every parameter
    optimizer.step()  # apply the update

# Report the loss after the final update too (the loop only evaluates before updates).
losses = estimate_loss()
print(f"step {max_iters}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
# A single batch's loss.item() is noisy; estimate_loss() averages many batches instead.

### generate text from the model:
idx = torch.zeros((1, 1), dtype=torch.long)  # start from token 0 (newline)
print(decode(m.generate(idx, max_new_tokens=500)[0].tolist()))

# Up to this point we use the simplest possible bigram model:
#   tokens do not talk to each other; only the last char is used to predict the next.
#   A transformer lets tokens gather context from the whole history.

epoch 0: train loss 2.4519, val loss 2.4920
epoch 100: train loss 2.4479, val loss 2.4904
epoch 200: train loss 2.4522, val loss 2.4941
epoch 300: train loss 2.4538, val loss 2.4898
epoch 400: train loss 2.4521, val loss 2.4790
epoch 500: train loss 2.4521, val loss 2.4983
epoch 600: train loss 2.4500, val loss 2.4832
epoch 700: train loss 2.4572, val loss 2.4946
epoch 800: train loss 2.4544, val loss 2.4939
epoch 900: train loss 2.4604, val loss 2.4990
epoch 1000: train loss 2.4448, val loss 2.4985
epoch 1100: train loss 2.4538, val loss 2.4877
epoch 1200: train loss 2.4526, val loss 2.4944
epoch 1300: train loss 2.4561, val loss 2.4902
epoch 1400: train loss 2.4503, val loss 2.4960
epoch 1500: train loss 2.4478, val loss 2.4861
epoch 1600: train loss 2.4486, val loss 2.4895
epoch 1700: train loss 2.4541, val loss 2.4905
epoch 1800: train loss 2.4532, val loss 2.4937
epoch 1900: train loss 2.4557, val loss 2.4904
epoch 2000: train loss 2.4478, val loss 2.4872
epoch 2100: train loss 2.

In [44]:
# The mathematical trick in self-attention, on a toy example.
torch.manual_seed(1337)
toy_shape = (4, 8, 2)  # batch, time, channels (information at each point in the sequence)
B, T, C = toy_shape
x = torch.randn(*toy_shape)  # random (B, T, C) tensor
x

tensor([[[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]],

        [[ 1.3488, -0.1396],
         [ 0.2858,  0.9651],
         [-2.0371,  0.4931],
         [ 1.4870,  0.5910],
         [ 0.1260, -1.5627],
         [-1.1601, -0.3348],
         [ 0.4478, -0.8016],
         [ 1.5236,  2.5086]],

        [[-0.6631, -0.2513],
         [ 1.0101,  0.1215],
         [ 0.1584,  1.1340],
         [-1.1539, -0.2984],
         [-0.5075, -0.9239],
         [ 0.5467, -1.4948],
         [-1.2057,  0.5718],
         [-0.5974, -0.6937]],

        [[ 1.6455, -0.8030],
         [ 1.3514, -0.2759],
         [-1.5108,  2.1048],
         [ 2.7630, -1.7465],
         [ 1.4516, -1.5103],
         [ 0.8212, -0.2115],
         [ 0.7789,  1.5333],
         [ 1.6097, -0.4032]]])

In [55]:
# version 1: causal averaging with explicit loops ("bag of words").
# The T tokens of a sequence do not talk to each other yet. Token t may look only at
# itself and the past (positions 0..t), never at the future. The simplest way to
# communicate with the past is to average:  xbow[b, t] = mean_{i<=t} x[b, i].
# The result summarizes each token in the context of its history. It discards a lot
# of information (e.g. ordering), so it is a weak form of communication.
xbow = torch.zeros((B, T, C))  # bag of words
for b, seq in enumerate(x):  # seq is one (T, C) sequence of the batch
    for t in range(T):
        # average the current token and all earlier ones over the time axis -> (C,)
        xbow[b, t] = torch.mean(seq[:t + 1], 0)
print(x[0])
print(xbow[0])  # e.g. xbow[0][4] is the average of x[0][0] .. x[0][4]


tensor([[ 0.1808, -0.0700],
        [ 0.1723, -0.0832],
        [ 0.1737, -0.0819],
        [ 0.1738, -0.0819],
        [ 0.1737, -0.0818],
        [ 0.1737, -0.0819],
        [ 0.1737, -0.0819],
        [ 0.1737, -0.0819]])
tensor([[ 0.1808, -0.0700],
        [ 0.1765, -0.0766],
        [ 0.1756, -0.0784],
        [ 0.1751, -0.0792],
        [ 0.1749, -0.0798],
        [ 0.1747, -0.0801],
        [ 0.1745, -0.0804],
        [ 0.1744, -0.0805]])


In [65]:
# The loops above are slow. The same causal averaging is one matrix multiply with a
# lower-triangular weight matrix whose rows sum to 1 (a "weighted aggregation").
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))  # (3, 3): ones on and below the diagonal
a = a / a.sum(dim=1, keepdim=True)  # normalize so every row sums to 1
b = torch.randint(0, 10, (3, 2)).float()  # (3, 2) random integer-valued matrix
c = a @ b  # (3, 2): row i is the average of rows 0..i of b

# Matrix multiplication: c[i, j] = dot(a row i, b col j), e.g.
#   c[0, 0] = a row 0 . b col 0     c[0, 1] = a row 0 . b col 1
#   c[1, 0] = a row 1 . b col 0     c[1, 1] = a row 1 . b col 1   ...

for label, mat in (('a=', a), ('b=', b), ('c=', c)):
    if label != 'a=':
        print('--')
    print(label)
    print(mat)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [69]:
# version 2: vectorize version 1 with a single matrix multiply.
# Row t of wei gives weight 1/(t+1) to positions 0..t and 0 to the future.
causal_ones = torch.tril(torch.ones(T, T))
wei = causal_ones / causal_ones.sum(1, keepdim=True)
print("weights:")
print(wei)
# (T, T) @ (B, T, C): PyTorch broadcasts wei over the batch -> (B, T, C),
# matching the loop-based xbow.
xbow2 = torch.matmul(wei, x)

weights:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


In [73]:
# version 3: the same averaging, built with softmax -- the form used in self-attention.
tril = torch.tril(torch.ones(T, T))  # (T, T) lower triangle of ones
# wei = interaction strengths / affinities between tokens (how much of each token to
# aggregate). All zeros means equal interest in every visible token.
future = tril == 0  # True strictly above the diagonal
# Block the future: those entries become -inf, which softmax turns into exactly 0,
# so tokens cannot communicate with later tokens.
wei = torch.zeros((T, T)).masked_fill(future, float('-inf'))
wei = F.softmax(wei, dim=-1)  # normalize each row to sum to 1 (uniform over 0..t here)
# e.g. row 4 (the 5th token) is [w1, w2, w3, w4, w5, 0, 0, 0]: how much attention
# token 5 pays to itself and to each earlier token.
xbow3 = wei @ x  # aggregate the values x according to those affinities -> (B, T, C)

# In self-attention the zeros become data-dependent affinities: token 5 might, say,
# find token 1 especially interesting and weight it more heavily.
# xbow3 is (4, 8, 2): for each token, a weighted aggregation of the channels of itself
# and its past, i.e. a context-aware representation of that token.
# In general, multiplying by a lower-triangular weight matrix performs a weighted
# aggregation of past elements; entry [t, i] says how much of element i is fused into
# position t.

# MATH BEHIND SELF-ATTENTION, SMALL EXAMPLE (T=3, C=2, batch dimension omitted):
# wei =                                   affinities (T, T)
# tensor([[1.0000, 0.0000, 0.0000],       token 1 attends only to itself
#         [0.5000, 0.5000, 0.0000],       token 2: 0.5 to token 1, 0.5 to itself
#         [0.3333, 0.3333, 0.3333]])      token 3: 1/3 each to tokens 1, 2 and itself
#
# x =                                     values (T, C)
# tensor([[2., 7.],                       token 1 value
#         [6., 4.],                       token 2 value
#         [6., 5.]])                      token 3 value
#
# xbow = wei @ x                          weighted aggregation (T, C)
# tensor([[2.0000, 7.0000],               token 1: just its own value
#         [4.0000, 5.5000],               token 2: average of tokens 1 and 2
#         [4.6667, 5.3333]])              token 3: average of tokens 1, 2 and 3


In [4]:
# version 4: self-attention!
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)  # each token's identity + position as a 32-channel vector

# Self-attention: every token emits two vectors
#   query -- "what am I looking for?"
#   key   -- "what do I contain?"
# The affinity of one token to another is the dot product of its query with the
# other's key; aligned query and key give a strong interaction. These data-dependent
# affinities replace the constant zeros of version 3.

# A single head of self-attention.
head_size = 16
# Three separate projections from C=32 channels to head_size=16 (randomly initialized).
# They are created in this order because each one draws from the seeded RNG.
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k, q = key(x), query(x)  # (B, T, 16) each: traits I contain / traits I am looking for
v = value(x)  # (B, T, 16): what this token communicates to tokens that attend to it

# (B, T, 16) @ (B, 16, T) -> (B, T, T): one (T, T) affinity matrix per sequence,
# e.g. wei[b, 7, 3] = how well the 8th token's query matches the 4th token's key.
# Scaling by 1/sqrt(head_size) keeps the variance near 1 so softmax stays diffuse.
wei = (q @ k.transpose(-2, -1)) * head_size ** -0.5
print(wei[0])  # raw, data-dependent affinities of the first sequence

tril = torch.tril(torch.ones(T, T))
# mask out the future, then normalize each row into weights that sum to 1
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)
print(wei[0])
out = wei @ v  # (B, T, head_size) = (4, 8, 16)


tensor([[-0.4407, -0.3253,  0.1413,  0.5404, -0.2668,  0.4908,  0.2691, -0.1132],
        [-0.8334, -0.4139,  0.0260,  0.8446, -0.5456,  0.2604, -0.0139,  0.0732],
        [-0.2557, -0.3152,  0.0191, -0.0953, -0.2461, -0.3576,  0.0187, -0.2387],
        [ 0.1959, -0.2004, -0.0842, -0.2124, -0.1401, -0.2925, -0.3232, -0.2565],
        [-0.3142,  0.0047, -0.1970, -0.3301,  0.5091,  0.2160,  0.0930,  0.2314],
        [-0.0782,  0.6038, -0.0276, -0.2483,  0.8362, -0.6307,  0.3547,  0.3049],
        [ 0.2719,  0.4913, -0.0655, -0.0789,  0.1523,  0.3154, -0.1371,  0.2012],
        [-0.4511, -0.1031, -0.2077,  0.1475, -0.1997, -0.1464,  0.1608,  0.1576]],
       grad_fn=<SelectBackward0>)
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3966, 0.6034, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3069, 0.2892, 0.4039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3233, 0.2175, 0.2443, 0.2149, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.14

The 8th token knows its identity/content and its position, and from them it creates a query describing the traits it is looking for. Every token also emits a key describing the traits it has. If one key channel encodes a particular trait, a token with that trait has a high number in that channel. When a query that is also high in that channel is dot-producted with that key, the two "find each other" and produce a high affinity.

For example, let's say we have a token key vector that represents the token's traits. The key vector will have a high number in 4th channel (a vowel trait). If we have another token query vector that also has a high number in the 4th channel, the dot product between the two vectors will be higher, so there is a higher affinity

UNDER THE HOOD: SINGLE HEAD OF SELF-ATTENTION MECHANISM EXAMPLE
B,T,C=4,8,32
head_size=16
Let's look at batch 0, a sequence of 8 tokens.
---------------------------------------------------------------------------------------------------------------------------------------------------------
q=                                                                                   queries (B,T,16) -- each query is traits I am looking for
tensor([[-0.6567,  0.0283,  0.0094, -0.6995, -0.3604,  0.8376, -0.4446,  0.1228,
          0.6276, -0.6222,  0.3483,  0.2411,  0.5409, -0.2605,  0.3612, -0.0436],
        [-0.3932,  0.8220, -0.7027,  0.0954, -0.1222, -0.1518, -0.5024, -0.4636,
          0.1176,  1.4282, -0.5812,  0.1401,  0.9604,  0.0410, -0.6214, -0.6347],
        [ 0.2157, -0.3507,  0.0022,  0.4232, -0.2284, -0.0732, -0.3412,  0.9647,
         -0.5178,  0.0921, -0.5043,  0.8388,  0.6149, -0.0109, -0.5569,  0.5820],
        [ 0.9000, -0.1272,  0.5458,  0.4254, -0.4513, -0.0212,  0.1711,  0.2599,
         -0.9978,  0.4890,  0.1737, -0.0700, -0.3113,  0.3748, -0.1848, -0.6379],
        [ 0.0332,  0.5886, -0.4437,  0.3775, -0.6826, -0.2775,  0.4673, -1.2956,
          0.6603,  0.1633, -1.7573, -0.6582, -0.2302, -0.0862, -0.0060,  0.7573],
        [ 0.2098,  0.0439, -0.0702,  0.0727, -0.2012, -1.7539,  1.0369,  0.1163,
          0.2956,  0.3231,  0.5052,  0.7011, -0.2844, -0.7844,  0.4782, -0.5170],
        [ 0.6100, -0.3284, -0.8557,  0.8543,  0.7805, -0.4023, -0.8183, -0.0554,
          0.1873,  0.2706, -0.7066, -0.8637,  0.6998, -0.0670,  0.2551,  0.2149],
        [ 0.1459,  0.1349, -0.2335, -0.0417,  0.2928, -0.5080,  0.1177,  0.1861,
          0.1455,  0.0292, -0.8470,  0.6116,  1.2445,  0.1909,  0.3694, -0.0027]],   8th token query- 8th token looking for 12th channel trait, high 13th
---------------------------------------------------------------------------------------------------------------------------------------------------------
k=                                                                                   (B,T,16) keys -- each key is traits I contain
tensor([[ 0.1196, -0.3013,  0.3629,  1.1771,  1.1385, -0.2554,  0.1454, -0.2944,
         -0.7020, -1.0308,  0.7436, -0.8098, -0.6669,  0.0912, -0.0061,  0.1983],
        [-0.5423, -0.5558, -0.0761,  1.2929,  0.8653, -1.1998,  0.3878,  0.1939,
          0.7024, -0.8225,  0.2348, -0.8499, -0.3813, -0.2991,  0.0102, -0.5545],
        [-0.3736, -0.4678, -0.2156, -0.8034, -0.3715, -0.5443, -0.9146, -0.0559,
         -0.3290, -0.2102,  0.1166, -0.1798, -0.2820, -0.3320, -0.4596, -0.1325],
        [-0.3146,  0.0845, -0.1235, -0.7058, -0.1802,  0.5492, -0.8980, -0.4938,
          0.6791,  0.8827,  0.4911,  0.5190,  0.9011,  0.0913, -0.1933, -0.6770],     4th token key- high 12th channel trait, high 13th
        [ 0.0239,  0.0998, -0.1871, -0.0860, -0.4881, -1.6765,  0.2413,  0.7361,
          0.4608, -0.8722, -0.4259, -1.1347, -1.0571, -0.9401,  0.1343, -0.0157],
        [-0.2362, -0.7873, -0.3802,  0.5815, -0.3722,  1.2405, -0.7004, -1.4917,
          0.7678,  0.3584,  0.6120, -0.0794,  0.5983,  0.2635,  0.6490,  0.0709],
        [-0.7941, -0.1660, -0.2810, -0.1021, -0.7352, -0.7518, -0.1276, -0.0051,
          0.3325, -0.3374,  0.1678,  0.3105,  0.2258,  0.1243,  0.4617,  0.2016],
        [ 0.1651, -0.1599, -0.5717, -0.3957,  0.3930, -0.8567,  0.3390, -0.7977,
          0.2213, -0.5161,  0.1850, -0.2105,  0.3779,  0.0482, -0.4744, -0.0504]],
---------------------------------------------------------------------------------------------------------------------------------------------------------
wei=q @ k.transpose(-2, -1) * head_size**-0.5=                                       affinities (B,T,T)- how much attention to myself and previous tokens
8 queries dot products 8 keys gives 8 token affinities
tensor([[-1.7629, -1.3011,  0.5652,  2.1616, -1.0674,  1.9632,  1.0765, -0.4530],
        [-3.3334, -1.6556,  0.1040,  3.3782, -2.1825,  1.0415, -0.0557,  0.2927],
        [-1.0226, -1.2606,  0.0762, -0.3813, -0.9843, -1.4303,  0.0749, -0.9547],
        [ 0.7836, -0.8014, -0.3368, -0.8496, -0.5602, -1.1701, -1.2927, -1.0260],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,  0.8638,  0.3719,  0.9258],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,  1.4187,  1.2196],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,  0.8048],
        [-1.8044, -0.4126, -0.8306,  0.5898, -0.7987, -0.5856,  0.6433,  0.6303]],   8th token affinities -- 8th token query vector dot products 4th
                                                                                        token key vector = 0.5898 so 8th token has high affinity with 4th token
masking and softmax normalization
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],          
         [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
         [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
         [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
         [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],         8th token has 0.2391 affinity w itself, high 0.2297 affinity w 4th
---------------------------------------------------------------------------------------------------------------------------------------------------------
v=                                                                                  values (B,T,16) - token identity & position embedding vectors
tensor([[-0.1571,  0.8801,  0.1615, -0.7824, -0.1429,  0.7468,  0.1007, -0.5239,
         -0.8873,  0.1907,  0.1762, -0.5943, -0.4812, -0.4860,  0.2862,  0.5710],
        [ 0.8321, -0.8144, -0.3242,  0.5191, -0.1252, -0.4898, -0.5287, -0.0314,
          0.1072,  0.8269,  0.8132, -0.0271,  0.4775,  0.4980, -0.1377,  1.4025],
        [ 0.6035, -0.2500, -0.6159,  0.4068,  0.3328, -0.3910,  0.1312,  0.2172,
         -0.1299, -0.8828,  0.1724,  0.4652, -0.4271, -0.0768, -0.2852,  1.3875],
        [ 0.6657, -0.7096, -0.6099,  0.4348,  0.8975, -0.9298,  0.0683,  0.1863,
          0.5400,  0.2427, -0.6923,  0.4977,  0.4850,  0.6608,  0.8767,  0.0746],
        [ 0.1536,  1.0439,  0.8457,  0.2388,  0.3005,  1.0516,  0.7637,  0.4517,
         -0.7426, -1.4395, -0.4941, -0.3709, -1.1819,  0.1000, -0.1806,  0.5129],
        [-0.8920,  0.0578, -0.3350,  0.8477,  0.3876,  0.1664, -0.4587, -0.5974,
          0.4961,  0.6548,  0.0548,  0.9468,  0.4511,  0.1200,  1.0573, -0.2257],
        [-0.4849,  0.1655, -0.2221, -0.1345, -0.0864, -0.6628, -0.0936,  0.1050,
         -0.2612,  0.1854,  0.3171, -0.1393,  0.5486, -0.4086, -0.3851,  0.7106],
        [ 0.2042,  0.3772, -1.1255,  0.3995,  0.1489,  0.3590, -0.1791,  1.3732,
          0.1588, -0.2320,  0.1651,  0.7604,  0.3521, -1.0864, -0.7939, -0.3025]],  8th token semantic meaning/identity & position
---------------------------------------------------------------------------------------------------------------------------------------------------------
out = wei @ v                                                                       (B,T,16) weighted aggregation - contextual rep of each token
semantic meaning & position of tokens given the weights (value depending on affinity with itself and each other token - attention)
tensor([[-0.1571,  0.8801,  0.1615, -0.7824, -0.1429,  0.7468,  0.1007, -0.5239,
         -0.8873,  0.1907,  0.1762, -0.5943, -0.4812, -0.4860,  0.2862,  0.5710],
        [ 0.6764, -0.5477, -0.2478,  0.3143, -0.1280, -0.2952, -0.4296, -0.1089,
         -0.0493,  0.7268,  0.7130, -0.1164,  0.3266,  0.3431, -0.0710,  1.2716],
        [ 0.4823, -0.1069, -0.4055,  0.1770,  0.1581, -0.1697,  0.0162,  0.0215,
         -0.2490, -0.3773,  0.2787,  0.1629, -0.2895, -0.0676, -0.1416,  1.2194],
        [ 0.1971,  0.2856, -0.1303, -0.2655,  0.0668,  0.1954,  0.0281, -0.2451,
         -0.4647,  0.0693,  0.1528, -0.2032, -0.2479, -0.1621,  0.1947,  0.7678],
        [ 0.2510,  0.7346,  0.5939,  0.2516,  0.2606,  0.7582,  0.5595,  0.3539,
         -0.5934, -1.0807, -0.3111, -0.2781, -0.9054,  0.1318, -0.1382,  0.6371],
        [ 0.3428,  0.4960,  0.4725,  0.3028,  0.1844,  0.5814,  0.3824,  0.2952,
         -0.4897, -0.7705, -0.1172, -0.2541, -0.6892,  0.1979, -0.1513,  0.7666],
        [ 0.1866, -0.0964, -0.1430,  0.3059,  0.0834, -0.0069, -0.2047, -0.1535,
         -0.0762,  0.3269,  0.3090,  0.0766,  0.0992,  0.1656,  0.1975,  0.7625],
        [ 0.1301, -0.0328, -0.4965,  0.2865,  0.2704, -0.2636, -0.0738,  0.3786,
          0.0746,  0.0338,  0.0147,  0.3194,  0.2993, -0.1653, -0.0386,  0.3375]],  8th token contextual representation -- meaning given communication w
                                                                                        sequence tokens

Notes:

Attention is a communication mechanism. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
Each example across the batch dimension is processed completely independently; examples never "talk" to each other.
In an "encoder" attention block just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
"self-attention" just means that the keys and values are produced from the same source (x) as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module in encoder-decoder transformer) for context we'd like to condition on
"Scaled" attention additional divides wei by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much.
Multi-head attention runs multiple attention heads in parallel and concatenates their results.

In [9]:
# Final Transformer Script
import torch
import torch.nn as nn
from torch.nn import functional as F

batch_size = 16      # B: number of independent sequences (chunks) processed in parallel
block_size = 32      # T: maximum context length in characters (time dimension)
n_embd = 64          # C: embedding dimension
learning_rate = 1e-3
n_head = 4           # head_size will be n_embd // n_head = 16 dimensional
n_layer = 4          # number of stacked transformer Blocks
dropout = 0.0        # dropout probability (0.0 disables dropout)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(1337)

# read in the training corpus
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# vocabulary: the sorted set of unique characters (65 unique chars/tokens for this corpus)
chars = sorted(set(text))
vocab_size = len(chars)

# character-level tokenizer: map each character to an integer id and back
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> int
itos = {i: ch for i, ch in enumerate(chars)}  # int -> char


def encode(s):
    """Encode a string into a list of integer token ids (KeyError on characters not in the vocabulary)."""
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable of integer token ids back into a string."""
    return ''.join(itos[i] for i in l)


# whole corpus as a 1-D torch.long tensor of token ids
# (tensors are multi-dimensional arrays; later dims are batch, time and embedding)
data = torch.tensor(encode(text), dtype=torch.long)

# split data into train (90%) and validation (10%) sets
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# plug in training data into transformer - train on random batches of chunks/blocks of train_data
def get_batch(split):
    """Sample a random batch of (input, target) chunks from one data split.

    split: 'train' selects train_data; any other value selects val_data.
    Returns (x, y), two (batch_size, block_size) tensors moved to `device`;
    y[:, t] is the token that follows x[:, t] in the corpus (the next-token target).
    """
    source = train_data if split == 'train' else val_data
    # batch_size random start offsets, leaving room for block_size + 1 tokens after each
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts])         # contexts
    targets = torch.stack([source[s + 1:s + 1 + block_size] for s in starts])  # shifted by one
    return inputs.to(device), targets.to(device)

# xb input tensor is fed into transformer, transformer processes contexts, looks up correct target to predict in yb

class Head(nn.Module):
    ''' one head of causal (masked) self-attention '''

    def __init__(self, head_size):
        super().__init__()
        # three bias-free projections from n_embd channels down to head_size channels
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # lower-triangular ones matrix registered as a buffer (not a trainable parameter);
        # it masks out future positions so position t only attends to positions <= t
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (B, T, n_embd) -> returns (B, T, head_size)
        _, seq_len, _ = x.shape
        keys = self.key(x)       # (B, T, head_size)
        queries = self.query(x)  # (B, T, head_size)
        scale = keys.shape[-1] ** -0.5
        # scaled dot-product affinities: (B, T, hs) @ (B, hs, T) -> (B, T, T)
        scores = (queries @ keys.transpose(-2, -1)) * scale
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))  # rows sum to 1 over allowed positions
        values = self.value(x)   # (B, T, head_size)
        return attn @ values     # (B, T, T) @ (B, T, hs) -> (B, T, hs)

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel

    Runs num_heads independent Head modules on the same input, concatenates their
    (B, T, head_size) outputs along the channel axis into (B, T, num_heads * head_size),
    then projects the result to n_embd channels so it can be added back into the
    residual pathway.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # the projection's input is the concatenated head width; that equals n_embd only
        # when n_embd is divisible by num_heads, so size it from the actual head width
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        out = self.dropout(self.proj(out))  # (B, T, n_embd), back into the residual pathway
        return out

class FeedForward(nn.Module):
    """ position-wise MLP: Linear(n_embd -> 4*n_embd) -> ReLU -> Linear(4*n_embd -> n_embd) -> Dropout """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd  # inner layer is 4x wider than the embedding
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),  # project back to n_embd for the residual pathway
            nn.Dropout(dropout),        # regularization against overfitting
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # nn.Linear acts on the last dimension, so every (batch, time) position is
        # transformed independently; the trailing n_embd dimension is preserved
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication (self-attention) followed by computation (feed-forward).

    Each sub-layer reads a LayerNorm-ed copy of x and its output is added back onto x
    (pre-norm residual connections).
    """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: number of attention heads
        super().__init__()
        head_size = n_embd // n_head  # e.g. 64 // 4 = 16 channels per head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # residual branch 1: tokens communicate via self-attention
        attended = self.sa(self.ln1(x))
        x = x + attended
        # residual branch 2: each token computes on what it gathered
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

# Feed into neural network: bigram language model
# inputs (context indices/characters):
# tensor([[57, 43, 60, 43, 52,  1, 63, 43],
#        [60, 43, 42,  8,  0, 25, 63,  1],
#        [56, 42,  5, 57,  1, 57, 39, 49],
#        [43, 57, 58, 63,  6,  1, 58, 46]])

class BigramLanguageModel(nn.Module):
    """ Character-level transformer language model.

    The name is historical: the model began as a bigram model whose embedding table
    held next-token logits directly. It now embeds tokens and positions into n_embd
    channels, runs them through n_layer transformer Blocks and a final LayerNorm, and
    maps the result to vocab_size next-token logits with lm_head.
    """

    def __init__(self):
        super().__init__()
        # token id -> n_embd-dimensional embedding, learned during training.
        # (In the bigram version this table was (vocab_size, vocab_size) and each row *was*
        # the next-token logits; with self-attention the embedding is an intermediate
        # representation and lm_head produces the logits.)
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        # n_embd channels -> vocab_size logits, e.g. (16, 32, 64) -> (16, 32, 65)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # position 0..block_size-1 -> n_embd-dimensional embedding, so each token also
        # carries where it occurs in the context
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the cross-entropy loss.

        idx: (B, T) tensor of token ids; T must not exceed block_size (the number of
            rows in the position embedding table).
        targets: optional (B, T) tensor of next-token ids.
        Returns (logits, loss). Without targets, logits is (B, T, vocab_size) and loss is
        None; with targets, logits is flattened to (B*T, vocab_size) and loss is the mean
        cross-entropy.
        """
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd) token embeddings
        # create positions on the input's device (not the global `device`) so the model
        # keeps working wherever it and its input have been moved
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = tok_emb + pos_emb     # (B, T, n_embd): identity + position, broadcast over B
        x = self.blocks(x)        # (B, T, n_embd)
        x = self.ln_f(x)          # (B, T, n_embd)
        logits = self.lm_head(x)  # (B, T, vocab_size) next-token scores

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, C) logits and (N,) class indices
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    # sampling is not differentiable, so skip building autograd graphs while generating
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively extend idx by sampling max_new_tokens new tokens.

        idx: (B, T) tensor of token ids giving the starting context of each sequence.
        Returns a (B, T + max_new_tokens) tensor: the input followed by the sampled tokens.
        Each step conditions on at most the last block_size tokens and samples from the
        softmax of the last position's logits.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]         # crop context to the last block_size tokens
            logits, _ = self(idx_cond)              # no targets -> logits (B, T, vocab_size)
            logits = logits[:, -1, :]               # keep only the last time step -> (B, vocab_size)
            probs = F.softmax(logits, dim=-1)       # (B, vocab_size) probability distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1): one sampled id per row
            idx = torch.cat((idx, idx_next), dim=1)  # append along time -> (B, T+1)
        return idx
        
model = BigramLanguageModel()  # construct the language model (a small transformer despite its name)

# A single batch's loss is noisy, so report the mean loss over eval_iters random batches instead.
eval_iters = 200


@torch.no_grad()
def estimate_loss():
    """Return {'train': mean_loss, 'val': mean_loss}, each averaged over eval_iters batches.

    The model is put in eval mode while measuring and switched back to train mode
    afterwards. Each mean is a 0-dim tensor.
    """
    model.eval()
    out = {}
    for split in ('train', 'val'):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb_eval, yb_eval = get_batch(split)
            _, batch_loss = model(xb_eval, yb_eval)
            losses[k] = batch_loss.item()
        out[split] = losses.mean()
    model.train()
    return out

### Train the model:
m = model.to(device)
# print the number of parameters in the model (in millions)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
# AdamW optimizer: uses the gradients computed by loss.backward() to update the model parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
max_iters = 5000
eval_interval = 100
for step in range(max_iters):  # `step`, not `iter`, to avoid shadowing the builtin

    # every eval_interval steps, and once more at the final step, estimate mean train/val loss
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # forward pass, backpropagate, update the parameters
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)  # clear previous iteration's gradients
    loss.backward()                        # gradients of the loss w.r.t. every parameter
    optimizer.step()

### generate text from the model:
# kick off generation from token 0 (the newline character), B=1, T=1; the context must live
# on the same device as the model, otherwise the embedding lookup fails on CUDA
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(idx, max_new_tokens=500)[0].tolist()))  # [0] selects the single generated sequence

0.209729 M parameters
epoch 0: train loss 4.4399, val loss 4.4423
epoch 100: train loss 2.6634, val loss 2.6812
epoch 200: train loss 2.5023, val loss 2.4999
epoch 300: train loss 2.3979, val loss 2.4116
epoch 400: train loss 2.3234, val loss 2.3312
epoch 500: train loss 2.2713, val loss 2.2845
epoch 600: train loss 2.2138, val loss 2.2239
epoch 700: train loss 2.1722, val loss 2.1902
epoch 800: train loss 2.1308, val loss 2.1587
epoch 900: train loss 2.0940, val loss 2.1309
epoch 1000: train loss 2.0682, val loss 2.1032
epoch 1100: train loss 2.0416, val loss 2.1022
epoch 1200: train loss 2.0047, val loss 2.0534
epoch 1300: train loss 1.9934, val loss 2.0460
epoch 1400: train loss 1.9581, val loss 2.0193
epoch 1500: train loss 1.9422, val loss 2.0139
epoch 1600: train loss 1.9211, val loss 2.0184
epoch 1700: train loss 1.9046, val loss 1.9889
epoch 1800: train loss 1.8716, val loss 1.9805
epoch 1900: train loss 1.8725, val loss 1.9623
epoch 2000: train loss 1.8432, val loss 1.9675
epo

Side Notes:

The nn.Linear module in PyTorch takes a tensor of input features and applies a linear (affine) transformation to its last dimension, resulting in a tensor of output features. The input features are multiplied by a weight matrix and then a bias term is added.

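In the case of the token embeddings (these notes use an earlier configuration: batch_size=4, block_size=8, n_embd=32), the input features are the 32-dimensional embedding vector of each token. PyTorch stores the weight matrix with shape (out_features, in_features) = (65, 32): one row per possible next token and one column per embedding dimension. It computes x @ W.T + b. The bias term has 65 elements.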

The linear transformation is then applied to each of the 8 token embeddings in a sequence, resulting in 8 vectors of 65 dimensions each. Each vector holds the scores (logits) for the next token at that position, with one score for each of the 65 possible tokens.

The output of the nn.Linear module is therefore a tensor of shape (4, 8, 65). It has the same leading (batch, time) dimensions as the input token embeddings of shape (4, 8, 32), but 65 output features instead of 32.