<a href="https://colab.research.google.com/github/joshuwaifo/A-Bible-Pre-trained-Transformer-Model/blob/main/Affinity_SelfAttention_Scaled_BibleGPT_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Previously on BibleGPT 1-4

15 seconds run time CPU Google Colab

In [2]:
!wget https://raw.githubusercontent.com/tushortz/variety-bible-text/master/bibles/nasb.txt

import torch
import torch.nn as nn
from torch.nn import functional as F

batch_size = 32
block_size = 8
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

with open('nasb.txt', 'r', encoding='utf-8') as f:
  text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)
n_train = int(0.64*len(data))
n_val = int(0.8*len(data))
train_data = data[:n_train]
val_data = data[n_train:n_val]
test_data = data[n_val:]

def get_batch(split):
  data = train_data if split == 'train' else val_data if split == 'val' else test_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x, y = x.to(device), y.to(device)
  return x,y

@torch.no_grad()
def estimate_loss():
  out = {}
  model.eval()
  for split in ['train', 'val']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X, Y = get_batch(split)
      logits, loss = model(X, Y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out

class BigramLanguageModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)


  def forward(self, idx, targets=None):
    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx)
    pos_emb = self.position_embedding_table(torch.arange(T, device=device))
    x = tok_emb + pos_emb
    logits = self.lm_head(x)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    for _ in range(max_new_tokens):
      logits, loss = self(idx)
      logits = logits[:, -1, :]
      probs = F.softmax(logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat((idx, idx_next), dim=1)
    return idx

model = BigramLanguageModel()
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

  if iter % eval_interval == 0:
    losses = estimate_loss()

  xb, yb = get_batch('train')

  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()

context = torch.zeros((1,1), dtype=torch.long, device=device)


torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

k = key(x)
q = query(x)
wei = q @ k.transpose(-2, -1)


tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ x


--2024-08-09 04:43:54--  https://raw.githubusercontent.com/tushortz/variety-bible-text/master/bibles/nasb.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4685837 (4.5M) [text/plain]
Saving to: ‘nasb.txt.1’


2024-08-09 04:43:55 (16.9 MB/s) - ‘nasb.txt.1’ saved [4685837/4685837]



Today on BibleGPT 5

In [3]:
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)

tril = torch.tril(torch.ones(T, T))

# initialisation done and resolved now
# wei = torch.zeros((T,T))

# this is masking
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ x

For every row of batch elements, we're going to now have a t-square matrix giving us the affinities

These now become the weights

The weighted aggregation is now a function in a data dependent manner between the keys and the queries of these nodes/tokens

In [4]:
# now every single batch element has different types of weights being aggregated
wei

# this is because every single batch element contains
# different tokens at different positions

# can deduce which tokens had high affinity to which other token
# for example in the first batch
# token 2 -> token 7(see 7th row)
# token 4 -> token 8 (see 8th row)

# remember that the direction is past/current -> current
# not current <- future

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
         [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
         [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
         [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1687, 0.8313, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2477, 0.0514, 0.7008, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4410, 0.0957, 0.3747, 0.0887, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0069, 0.0456, 0.0300, 0.7748, 0.1427, 0.0000, 0.0000, 0.0000],
         [0.0660, 0.089

Let's see what the weights look like without masking and softmax

In [5]:
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)

tril = torch.tril(torch.ones(T, T))

# initialisation done and resolved now
# wei = torch.zeros((T,T))

# this is masking
# let's try and erase it
# wei = wei.masked_fill(tril == 0, float('-inf'))
# wei = F.softmax(wei, dim=-1)
out = wei @ x

outputs of the dot products

In [7]:
# raw outputs
# take on values from -inf to +inf
# raw interaction and raw affinities between all the nodes/tokens
wei[0]

# I would be interested to see how well this works when explaining models

tensor([[-1.7629, -1.3011,  0.5652,  2.1616, -1.0674,  1.9632,  1.0765, -0.4530],
        [-3.3334, -1.6556,  0.1040,  3.3782, -2.1825,  1.0415, -0.0557,  0.2927],
        [-1.0226, -1.2606,  0.0762, -0.3813, -0.9843, -1.4303,  0.0749, -0.9547],
        [ 0.7836, -0.8014, -0.3368, -0.8496, -0.5602, -1.1701, -1.2927, -1.0260],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,  0.8638,  0.3719,  0.9258],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,  1.4187,  1.2196],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,  0.8048],
        [-1.8044, -0.4126, -0.8306,  0.5899, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

In [8]:
# we use the upper triangular masking to prevent communication between future tokens and the current token
wei = wei.masked_fill(tril == 0, float('-inf'))
wei[0]

tensor([[[-1.7629e+00,        -inf,        -inf,        -inf,        -inf,
                 -inf,        -inf,        -inf],
         [-3.3334e+00, -1.6556e+00,        -inf,        -inf,        -inf,
                 -inf,        -inf,        -inf],
         [-1.0226e+00, -1.2606e+00,  7.6228e-02,        -inf,        -inf,
                 -inf,        -inf,        -inf],
         [ 7.8359e-01, -8.0143e-01, -3.3680e-01, -8.4963e-01,        -inf,
                 -inf,        -inf,        -inf],
         [-1.2566e+00,  1.8719e-02, -7.8797e-01, -1.3204e+00,  2.0363e+00,
                 -inf,        -inf,        -inf],
         [-3.1262e-01,  2.4152e+00, -1.1058e-01, -9.9305e-01,  3.3449e+00,
          -2.5229e+00,        -inf,        -inf],
         [ 1.0876e+00,  1.9652e+00, -2.6213e-01, -3.1579e-01,  6.0905e-01,
           1.2616e+00, -5.4841e-01,        -inf],
         [-1.8044e+00, -4.1260e-01, -8.3061e-01,  5.8985e-01, -7.9869e-01,
          -5.8560e-01,  6.4332e-01,  6.3028e-01]],

In [10]:
# now we want to have a nice distribution
# so we exponentiate and normalise by using softmax
# this results in a nice distribution that sums to 1 and between 0 and 1 for every element
wei = F.softmax(wei, dim=-1)
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

This tells us in a data dependent manner how much of the information to aggregate from any of these tokens in the past

One more part to a single attention head

When we do the aggregation, we produce one more value before finishing

we call that the value and is produced in a similar way to the way the key and query is produced


In [None]:
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)

tril = torch.tril(torch.ones(T, T))


wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

# instead of aggregating x
# out = wei @ x

# calculate a v, achieved by propagating the value linear map
# v is the vector that we aggregrate instead of the raw x now
v = value(x)

# this then makes the output of the single head 4 by 8 by 16 (head_size) instead of 4 by 8 by 32
out = wei @ v
out.shape
#

In [11]:
# think of x as a private information to this token
# information is "kept" in vector x

# analogy:
# here's what I have
# if you find me interesting
# here's what I will communicate to you (stored in v)

# v is the thing that gets aggregated

Notes:

Attention is a communication mechanism

You have a number of nodes in a directed graph

Every node has some vector of information

It gets to aggregate information via a weighted sum (from all the nodes that point to it)

This is done in a data dependent manner

So depending on whatever data is actually stored at each node at any point in time

Graph here looks like:

8 nodes (because the block size is 8 and always 8 tokens)

First node only pointed to by itself

Second node pointed to by itself and the first node and so on

In [None]:
# In principle, attention can be applied to any arbitrary directed graph


Notes:

No notion of space

Attention simply acts over a set of vectors in this graph

By default these nodes have no idea where they are positioned in a space

This is why we need to encode them positionally (give them sort of some information that is anchored to a specific position, so they sort of know where they are)

This differs from convolution (layout of space intuitively included)

Attention, set of vectors out there in space that communicate with each other and if you want them to have a notion of space it needs to be specifically added (as we have done with the positional encoding)



In [None]:
# Elements across batch dimension which are independent examples
# Never talk to each other
# Batched matrix multiply
# apply the matrix multiplication in parallel across the batch dimension

# think of it batch being a separate pool where each pool has eight nodes in this example

Notes:

In the case of language modelling, have this specific structure of directed graph

Where future tokens will not communicate to the past tokens

But this doesn't have to be the case

In some cases you may want to have all the nodes talk to each other

For example in sentiment analysis with a transformer

In that case we would use an encoder block of self-attention. All this means is that there is no triangular matrix masking

This allows all the nodes to completely talk to each other (encoder block of self-attention)

In [None]:
# For BibleGPT I'd like to see the effect of not having the masking


In [None]:
# self attention types
# with masking: decoder block (implemented here)
# no masking: encoder block

# it's called decoder as its decoding language (recovering hidden language)
# in a autoregressive format
# so that nodes from the future never talk to the past
# so they never give away the answer

Attention vs Self-Attention vs Cross-attention

Self-attention: keys, queries and values are all coming from the same source (x)

So these nodes are "self-attending"

In principle, attention is more general than that

For example in the encoder-decoder transformer

queries are produced from x

but keys and (to be clarified, values) come from an external source

Sometimes even from encoder blocks that encode some context that we'd like to condition on


You can think of it as nodes on the side producing queries

And we're reading off information from the side

Cross attention is where there is a separate source of nodes that we'd like to pull information from into our nodes



In [None]:
# attention is all you need paper 2017
# dividing by square root of the head size d_k
# called scaled attention

k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)

wei = q @ k.transpose(-2, -1) * head_size**0.5
# important normalisation

# context:
# if you have unit gaussian input
# ie zero mean, unit variance

# important as wei feeds into softmax
# so at initialisation, wei(ghts) need to be fairly diffused

# it diffuses the values of the weights/logits before the softmax operation
# preventing softmax from converging to 1-hot vectors

# diffuse (getting closer to zero) versus sharpening effect (getting away from zero)

# prevents values from being too extreme (the normalisation)

# this in practice prevents every node just aggregating information from a single node
# scaling is used to control the variance at initialisation

Let's fully integrate a single-self attention block to the network

In [None]:
# Head Module

class Head(nn.Module):
  """ one head of self-attention """

  # give it a head size
  def __init__(self, head_size):
    super().__init__()

    # create the key, query and value linear layers
    # typically people don't use biases in these
    # these serve as the linear projections that we are going to apply to all of our nodes/tokens
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)

    # create a tril variable
    # this is not a parameter of the module
    # in python convention (pythonic) this is called a buffer
    # have to call and assign it to the module using register buffer
    # this creates the lower triangular matrix
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))



  # when we are given the input x
  def forward(self, x):
    B,T,C = x.shape

    # calculate the keys and queryies
    k = self.key(x) # (B,T,C)
    q = self.key(x) # (B,T,C)

    # compute attention scores ("affinities")
    # normalise it aka scaled attention
    wei = q @ k.transpose(-2, -1) * C**0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)

    # make sure future doesn't communicate with the present
    # thereby making it a decoder block by applying masking
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)

    # softmax
    wei = F.softmax(wei, dim=-1) # (B, T, T)

    # aggregate
    # perform the weighted aggregation of the values
    v = self.value(x) # (B, T, C)
    out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
    return out


# Updated Language Model
class BigramLanguageModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd)
    # create a head in the constructor
    # self attention head
    self.sa_head = Head(n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)


  def forward(self, idx, targets=None):
    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx)
    pos_emb = self.position_embedding_table(torch.arange(T, device=device))
    x = tok_emb + pos_emb

    # once we've encoded the information with the token embeddings and position embeddings

    # feed in the self-attention head
    x = self.sa_head(x) # apply one head of self-attention. (B,T,C)

    # output goes into the decoder language modelling head and creates the logits
    logits = self.lm_head(x)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # have to make sure that the idx that we feed into the model
    # due to now using a positional embedding
    # we can never have more than block size coming in
    # because if idx is more than block size
    # then the position embedding table is going to run out of scope
    # because it only has embeddings up to block size
    # add some code to crop the context

    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
      # crop idx to the last block_size tokens
      idx_cond = idx[:, -block_size:]
      # feed into self so we never pass in more than block size elements
      # and get the predictions
      logits, loss = self(idx_cond)
      # focus only on the last time step
      logits = logits[:, -1, :] # becomes (B, C)
      # apply softmax to get probabilities
      probs = F.softmax(logits, dim=-1) # (B, C)
      # sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
      # append sampled index to the running sequence
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx



More helpful for me to think of the order as query -> key -> value for ease of intuition

I think of questions: query

I try to get all the possible answer: key

I use the similarility between the answers and questions to improve the result via weighted aggregation: query


In [None]:
# Simplest way above to plug in the self-attention component into the network

# let's train the network



!wget https://raw.githubusercontent.com/tushortz/variety-bible-text/master/bibles/nasb.txt

import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

# increased the number of iterations

max_iters = 5000
eval_interval = 300
# decreased the learning rate
# as the self attention can't tolerate very very high learning rates
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

with open('nasb.txt', 'r', encoding='utf-8') as f:
  text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

data = torch.tensor(encode(text), dtype=torch.long)
n_train = int(0.64*len(data))
n_val = int(0.8*len(data))
train_data = data[:n_train]
val_data = data[n_train:n_val]
test_data = data[n_val:]

def get_batch(split):
  data = train_data if split == 'train' else val_data if split == 'val' else test_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  x, y = x.to(device), y.to(device)
  return x,y

@torch.no_grad()
def estimate_loss():
  out = {}
  model.eval()
  for split in ['train', 'val']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X, Y = get_batch(split)
      logits, loss = model(X, Y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out


# Head Module

class Head(nn.Module):
  """ one head of self-attention """

  # give it a head size
  def __init__(self, head_size):
    super().__init__()

    # create the key, query and value linear layers
    # typically people don't use biases in these
    # these serve as the linear projections that we are going to apply to all of our nodes/tokens
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)

    # create a tril variable
    # this is not a parameter of the module
    # in python convention (pythonic) this is called a buffer
    # have to call and assign it to the module using register buffer
    # this creates the lower triangular matrix
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))



  # when we are given the input x
  def forward(self, x):
    B,T,C = x.shape

    # calculate the keys and queryies
    k = self.key(x) # (B,T,C)
    q = self.key(x) # (B,T,C)

    # compute attention scores ("affinities")
    # normalise it aka scaled attention
    wei = q @ k.transpose(-2, -1) * C**0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)

    # make sure future doesn't communicate with the present
    # thereby making it a decoder block by applying masking
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)

    # softmax
    wei = F.softmax(wei, dim=-1) # (B, T, T)

    # aggregate
    # perform the weighted aggregation of the values
    v = self.value(x) # (B, T, C)
    out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
    return out


# Updated Language Model
class BigramLanguageModel(nn.Module):

  def __init__(self):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd)
    # create a head in the constructor
    # self attention head
    self.sa_head = Head(n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)


  def forward(self, idx, targets=None):
    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx)
    pos_emb = self.position_embedding_table(torch.arange(T, device=device))
    x = tok_emb + pos_emb

    # once we've encoded the information with the token embeddings and position embeddings

    # feed in the self-attention head
    x = self.sa_head(x) # apply one head of self-attention. (B,T,C)

    # output goes into the decoder language modelling head and creates the logits
    logits = self.lm_head(x)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # have to make sure that the idx that we feed into the model
    # due to now using a positional embedding
    # we can never have more than block size coming in
    # because if idx is more than block size
    # then the position embedding table is going to run out of scope
    # because it only has embeddings up to block size
    # add some code to crop the context

    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
      # crop idx to the last block_size tokens
      idx_cond = idx[:, -block_size:]
      # feed into self so we never pass in more than block size elements
      # and get the predictions
      logits, loss = self(idx_cond)
      # focus only on the last time step
      logits = logits[:, -1, :] # becomes (B, C)
      # apply softmax to get probabilities
      probs = F.softmax(logits, dim=-1) # (B, C)
      # sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
      # append sampled index to the running sequence
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx


model = BigramLanguageModel()
m = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

  if iter % eval_interval == 0:
    losses = estimate_loss()

  xb, yb = get_batch('train')

  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()

context = torch.zeros((1,1), dtype=torch.long, device=device)
