# Transformer Notes
This notebook is based on [this video](https://www.youtube.com/watch?v=kCc8FmEb1nY), which goes into depth on how to build the exact base model I need for this project. I will be copying a lot of the author's code and annotating it to help myself understand how the Transformer works.

In [3]:
# Imports
from collections import Counter

# pytorch functionality
import torch
from torch import Tensor

# data
from torchtext.vocab import vocab, Vocab
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer

device = "cpu"

## Get Data
For this example I am going to be using the WikiText2 dataset, but the work/model should generalize to any text.

In [4]:
train_iter, val_iter, test_iter = WikiText2()

### Tokenization 

For now I'm going to keep my tokenizer very simple. You can use a multitude of techniques for tokenizing your corpus. Here is a [library](https://github.com/openai/tiktoken) worth looking into at some point.

A character-level tokenizer gives very long sequences over a small vocabulary. A subword tokenizer such as the library above trades the other way: shorter sequences over a larger vocabulary.
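To make the sequence-length side of that trade-off concrete, here is a small sketch in plain Python. The sample sentence is made up, and a whitespace split stands in for a real word tokenizer, so treat it as an illustration only.

```python
# Made-up sample text; a real corpus would be far larger.
text = "the cat sat on the mat and the dog sat on the log"

char_tokens = list(text)     # one token per character (including spaces)
word_tokens = text.split()   # one token per whitespace-separated word

print("char-level sequence length:", len(char_tokens))
print("word-level sequence length:", len(word_tokens))
```

The vocabulary side only shows up on a real corpus: the character set stays small however much text you add, while the word vocabulary keeps growing with the corpus.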

#### Some Helper Functions for Data

In [5]:
def build_vocab(in_data, tokenizer):
  counter = Counter()
  for string in in_data:
    counter.update(tokenizer(string))

  return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])

def data_process(in_data, tokenizer, vocab: Vocab):
  raw_iter = iter(in_data)
  data = []
  for raw in raw_iter:
    tensor = torch.tensor([vocab[token] for token in tokenizer(raw)], dtype=torch.long)
    data.append(tensor)
    
  return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

##### Define Tokenizer

I am going to do something slightly different from the video: I'm using the tokenizer provided by torchtext to split on words rather than going character by character. Torch's tools support this kind of work better, but it will require some slight adjustments to the code.

In [6]:
tokenizer = get_tokenizer("basic_english")

#### Build Vocabulary

In [7]:
vocab = build_vocab(train_iter, tokenizer)

print("Train Vocab Size:", len(vocab))

Train Vocab Size: 28785


Short example of how encoding and decoding work with the Vocab object:

In [8]:
encoded_word = vocab.get_stoi()["there"]
decoded_word = vocab.get_itos()[encoded_word]

print("'There' after encoding:", encoded_word)
print("'There' after decoding", decoded_word)

'There' after encoding: 248
'There' after decoding there
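To see what the token-to-index and index-to-token mappings are doing, here is a minimal pure-Python sketch of the same idea. The two-sentence corpus and the `encode`/`decode` helpers are hypothetical, and falling back to `<unk>` for unknown words is a choice made for this sketch, not something the notebook's Vocab object is configured to do.

```python
from collections import Counter

specials = ["<unk>", "<pad>", "<bos>", "<eos>"]
corpus = ["there is a cat", "there is a dog"]  # made-up corpus

counter = Counter(tok for line in corpus for tok in line.split())
itos = specials + list(counter)                # index -> token, specials first
stoi = {tok: i for i, tok in enumerate(itos)}  # token -> index

def encode(sentence):
    # unknown words map to the <unk> index (an assumption of this sketch)
    return [stoi.get(tok, stoi["<unk>"]) for tok in sentence.split()]

def decode(ids):
    return " ".join(itos[i] for i in ids)

ids = encode("there is a bird")
print(ids)          # [4, 5, 6, 0]
print(decode(ids))  # there is a <unk>
```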


### Convert Data to Tensor Format
Using the data_process function above, build a torch tensor representation of each split based on the vocab.

In [9]:
train_data = data_process(train_iter, tokenizer, vocab)
val_data = data_process(val_iter, tokenizer, vocab)
test_data = data_process(test_iter, tokenizer, vocab)

print("Training Data Shape and Type:", train_data.shape, train_data.dtype)
print("Validation Data Shape and Type:", val_data.shape, val_data.dtype)
print("Testing Data Shape and Type:", test_data.shape, test_data.dtype)

Training Data Shape and Type: torch.Size([2049990]) torch.int64
Validation Data Shape and Type: torch.Size([214417]) torch.int64
Testing Data Shape and Type: torch.Size([241859]) torch.int64


#### Example showing how src and tgt work
The target for the source token at index i is the token at index i+1 of the data, so tgt is src shifted by one position.

In [10]:
block_size = 15
train_data[:block_size+1]

tensor([ 4,  5,  6,  7,  4,  8,  9,  5, 10,  0,  6, 11, 12, 13, 14, 15])

In [11]:
src = train_data[:block_size]
tgt = train_data[1:block_size+1]
for t in range(block_size):
    context = src[:t+1]
    target = tgt[t]
    print(f"when input is {context} the target is: {target}")

when input is tensor([4]) the target is: 5
when input is tensor([4, 5]) the target is: 6
when input is tensor([4, 5, 6]) the target is: 7
when input is tensor([4, 5, 6, 7]) the target is: 4
when input is tensor([4, 5, 6, 7, 4]) the target is: 8
when input is tensor([4, 5, 6, 7, 4, 8]) the target is: 9
when input is tensor([4, 5, 6, 7, 4, 8, 9]) the target is: 5
when input is tensor([4, 5, 6, 7, 4, 8, 9, 5]) the target is: 10
when input is tensor([ 4,  5,  6,  7,  4,  8,  9,  5, 10]) the target is: 0
when input is tensor([ 4,  5,  6,  7,  4,  8,  9,  5, 10,  0]) the target is: 6
when input is tensor([ 4,  5,  6,  7,  4,  8,  9,  5, 10,  0,  6]) the target is: 11
when input is tensor([ 4,  5,  6,  7,  4,  8,  9,  5, 10,  0,  6, 11]) the target is: 12
when input is tensor([ 4,  5,  6,  7,  4,  8,  9,  5, 10,  0,  6, 11, 12]) the target is: 13
when input is tensor([ 4,  5,  6,  7,  4,  8,  9,  5, 10,  0,  6, 11, 12, 13]) the target is: 14
when input is tensor([ 4,  5,  6,  7,  4,  8,  9,  5, 10,  0,  6, 11, 12, 13, 14]) the target is: 15

### Batch The Data

In [12]:
batch_size = 4 # how many independent sequences will we process in parallel
block_size = 8 # what is the maximum context length for predictions.

def get_batch(split):
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1: i + block_size + 1] for i in ix])

    # x, y = x.to(device), y.to(device)
    return x, y

xb, yb = get_batch("train")
print("inputs")
print(xb.shape)
print(xb)
print("targets")
print(yb.shape)
print(yb)

inputs
torch.Size([4, 8])
tensor([[   48,   131,  1696,    45,     0,    14,  1650, 12931],
        [   16,    18, 21257, 10220,    17,    18,  5553,  4046],
        [ 2887, 19396,  8651,    14,    38,   207,  5268,    23],
        [  119,  8908, 23126,    14,   310,  2597, 23564,    23]])
targets
torch.Size([4, 8])
tensor([[  131,  1696,    45,     0,    14,  1650, 12931,    14],
        [   18, 21257, 10220,    17,    18,  5553,  4046, 20663],
        [19396,  8651,    14,    38,   207,  5268,    23,  3724],
        [ 8908, 23126,    14,   310,  2597, 23564,    23,   273]])


In [13]:
for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context} the target is: {target}")

when input is tensor([48]) the target is: 131
when input is tensor([ 48, 131]) the target is: 1696
when input is tensor([  48,  131, 1696]) the target is: 45
when input is tensor([  48,  131, 1696,   45]) the target is: 0
when input is tensor([  48,  131, 1696,   45,    0]) the target is: 14
when input is tensor([  48,  131, 1696,   45,    0,   14]) the target is: 1650
when input is tensor([  48,  131, 1696,   45,    0,   14, 1650]) the target is: 12931
when input is tensor([   48,   131,  1696,    45,     0,    14,  1650, 12931]) the target is: 14
when input is tensor([16]) the target is: 18
when input is tensor([16, 18]) the target is: 21257
when input is tensor([   16,    18, 21257]) the target is: 10220
when input is tensor([   16,    18, 21257, 10220]) the target is: 17
when input is tensor([   16,    18, 21257, 10220,    17]) the target is: 18
when input is tensor([   16,    18, 21257, 10220,    17,    18]) the target is: 5553
when input is tensor([   16,    18, 21257, 10220,    

## Feed Data Into a Model

### Bigram Language Model
For understanding purposes and to get something working quickly, we're going to use a basic Bigram language model to work with our data. We'll build something bigger later.

In [14]:
import torch
import torch.nn as nn
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
  def __init__(self, vocab_size) -> None:
    super().__init__()
    # tokens and positions are each embedded into num_embedding_dim channels;
    # lm_head then projects the summed embeddings to logits over the vocabulary
    num_embedding_dim = 32 
    self.token_embedding_table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=num_embedding_dim)
    self.lm_head = nn.Linear(num_embedding_dim, vocab_size)
    self.position_embedding_table = nn.Embedding(block_size, num_embedding_dim)

  def forward(self, idx, targets=None):
    B, T = idx.shape
    token_embeddings = self.token_embedding_table(idx) # (B, T, C): batch by time by channel
    pos_embeddings = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T, C)
    x = token_embeddings + pos_embeddings # (B, T, C), positions broadcast over the batch
    logits = self.lm_head(x) # (B, T, vocab_size)

    if targets is None:
      loss = None
    else:
      # F.cross_entropy wants the class dimension second, i.e. (B, C, T),
      # so we flatten to (B*T, C) logits and (B*T,) targets instead
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)

      loss = F.cross_entropy(logits, targets)

    return logits, loss # logits are the scores for the next token at each position
  
  def generate(self, idx, max_new_tokens):
    # idx is a (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
      idx_cond = idx[:, -block_size:]   # the position table only covers block_size positions
      logits, loss = self(idx_cond)     # get the predictions 
      logits = logits[:, -1, :]         # focus only on last time step
      probs = F.softmax(logits, dim=-1) # apply softmax to get probabilities
      idx_next = torch.multinomial(probs, num_samples=1) # sample from the distribution
      idx = torch.cat((idx, idx_next), dim=1)
    
    return idx

In [15]:
vocab_size = len(vocab)

model = BigramLanguageModel(vocab_size) #.to(device)

logits, loss = model(xb, yb)
print(logits.shape) # (batch_size*block_size, number_of_tokens)
print(loss)

torch.Size([32, 28785])
tensor(10.8159, grad_fn=<NllLossBackward0>)
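As a sanity check on that number: a model that spreads probability uniformly over the vocabulary would have a cross-entropy of -ln(1/vocab_size). The sketch below just does that arithmetic for the vocab size printed earlier (28785).

```python
import math

vocab_size_check = 28785  # "Train Vocab Size" printed earlier in the notebook

# cross-entropy of a uniform guess over the whole vocabulary
uniform_loss = -math.log(1.0 / vocab_size_check)
print(f"uniform-guess loss: {uniform_loss:.4f}")  # about 10.27
```

The measured 10.8159 is a bit higher than this. That is plausible for a randomly initialized model, since its logits are not exactly uniform.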


#### Generate Text
We will use the generate function to produce 100 tokens starting from the token at index 0. Note that the output will be garbage because the model is still untrained (its weights are random).

In [16]:
# create 1 by 1 tensor which holds a zero 
idx = torch.zeros((1,1), dtype=torch.long) 

# kick off the generation with a zero
gen_idx = model.generate(idx, max_new_tokens=100)[0].tolist() # create list of words
decoded_gen_idx = [vocab.get_itos()[item] for item in gen_idx]
print(decoded_gen_idx)

['<unk>', 'trumpet', 'asphalt', 'exploits', 'org', 'litigators', 'condemned', 'birthplace', 'periodically', 'idyllwild', 'harmonix', 'nominated', 'wet', 'monstrous', 'hugs', 'estate', 'schlesinger', 'conceal', 'appears', 'commandment', 'rolling', 'sunshine', 'kubica', 'a320', 'dodge', 'shirts', 'mechanism', 'sloping', 'peat', 'coaster', 'collections', 'replayed', 'freely', 'determines', 'brenton', 'proves', 'conformist', 'seniors', 'bloody', 'neighborhood', 'sexually', 'command', 'coups', 'twinned', 'relatable', 'desktop', 'humphrey', 'wallabies', 'speculator', 'exit', 'declaration', 'chaotic', 'moravec', 'hunwick', 'est', 'commitment', 'smashed', 'weapons', 'essentially', 'dynamite', 'schopenhauer', 'beatty', '1518', 'winchester', 'themed', 'make', 'halliwell', 'void', 'virampattinam', 'obo', 'right', 'kesselring', 'cooperate', 'vipers', 'garter', 'dumps', 'smackdown', 'inaction', 'shimitsu', 'robbery', 'pleading', 'predications', 'criticisms', 'developer', 'hovers', 'italo', 'midnigh

In [17]:
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

In [18]:
batch_size = 32
for steps in range(1):
    # sample a batch of training data
    xb, yb = get_batch("train")

    # eval the loss 
    logits, loss = model(xb, yb)
    # opt.zero_grad(set_to_none=True)
    # loss.backward()
    # opt.step()

    # print(loss.item())

In [19]:
# @torch.no_grad() # tells pytorch nothing here will need to be called .backward() on
# def estimate_loss():
#   out = {}
#   model.eval()
#   for split in ["train", "val"]:
#     losses = torch.zeros(eval_iters)
#     for k in range(eval_iters):
#       X, y = get_batch(split)
#       logits, loss = model(X, y)
#       losses[k] = loss.item()
#     out[split] = losses.mean()
#   model.train()
#   return out
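The optimizer step and `estimate_loss` above are mostly commented out, so here is a self-contained sketch of how the pieces fit together. Everything in it is an assumption made for illustration: the toy data (the digits 0 to 9 repeated), the `TinyBigram` model, and the `_toy` names, which are chosen so nothing in the notebook gets overwritten.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)

toy_vocab_size = 10
toy_data = torch.arange(1000) % toy_vocab_size  # 0,1,...,9,0,1,... (next token is deterministic)
toy_block, toy_batch_size, toy_eval_iters = 8, 16, 20

def toy_batch():
    ix = torch.randint(len(toy_data) - toy_block, (toy_batch_size,))
    x = torch.stack([toy_data[i:i + toy_block] for i in ix])
    y = torch.stack([toy_data[i + 1:i + toy_block + 1] for i in ix])
    return x, y

class TinyBigram(nn.Module):
    def __init__(self):
        super().__init__()
        # each token reads the logits for the next token straight from a lookup table
        self.table = nn.Embedding(toy_vocab_size, toy_vocab_size)

    def forward(self, idx, targets):
        logits = self.table(idx)  # (B, T, vocab)
        loss = F.cross_entropy(logits.view(-1, toy_vocab_size), targets.view(-1))
        return logits, loss

@torch.no_grad()  # no gradients are needed while evaluating
def estimate_toy_loss(model):
    model.eval()
    losses = torch.zeros(toy_eval_iters)
    for k in range(toy_eval_iters):
        _, loss = model(*toy_batch())
        losses[k] = loss.item()
    model.train()
    return losses.mean().item()

model_toy = TinyBigram()
opt_toy = torch.optim.AdamW(model_toy.parameters(), lr=1e-1)

loss_before = estimate_toy_loss(model_toy)
for step in range(200):
    xb_toy, yb_toy = toy_batch()
    _, loss_toy = model_toy(xb_toy, yb_toy)
    opt_toy.zero_grad(set_to_none=True)  # clear old gradients
    loss_toy.backward()                  # backpropagate
    opt_toy.step()                       # update the parameters
loss_after = estimate_toy_loss(model_toy)

print("before:", loss_before, "after:", loss_after)
```

Averaging the loss over several batches, as `estimate_toy_loss` does, gives a less noisy number than printing the loss of a single batch.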

## The Mathematical Trick in Self-Attention
There should be some "communication" between the ith token and all previous tokens (the attention component)

In [20]:
B,T,C = 4, 8, 2 # batch, time component, channels (some info at each point in sequence)
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

Below, x and the averaged x_bag_of_words are printed for the first batch element.
- x[0] is the 0th batch element

The first row of x and of the bag of words are the same. Each later row t of the bag of words is the average of rows 0 through t of x.

Version 1:

In [26]:
x_bag_of_words = torch.zeros((B,T,C))
for b in range(B):
  for time in range(T):
    x_prev = x[b, :time+1] # (t, C) 
    x_bag_of_words[b,time] = torch.mean(x_prev, 0)

print("X:")
print(x[0])
print("X bag of words")
print(x_bag_of_words[0])

X:
tensor([[-0.6284,  1.9393],
        [ 0.9123, -0.6665],
        [-1.2152, -0.8152],
        [-0.5096,  0.6938],
        [ 0.7681, -1.3212],
        [ 0.0314, -1.0469],
        [-1.1389, -0.2334],
        [ 0.0226,  0.8392]])
X bag of words
tensor([[-0.6284,  1.9393],
        [ 0.1419,  0.6364],
        [-0.3105,  0.1526],
        [-0.3602,  0.2879],
        [-0.1346, -0.0339],
        [-0.1069, -0.2028],
        [-0.2543, -0.2072],
        [-0.2197, -0.0764]])


Version 2: We can do this more efficiently with matrix multiplication. Multiplying by a row-normalized lower-triangular weight matrix computes the same running averages as weighted sums.

In [28]:
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(1, keepdim=True)

x_bag_of_words_2 = weights @ x # (T, T) broadcast over the batch, matmul (B, T, C) --> (B, T, C)
print("Are Close:", torch.allclose(x_bag_of_words, x_bag_of_words_2))

print(weights)

Are Close: True
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


Version 3: Using softmax

In [34]:
tril = torch.tril(torch.ones(T, T))
weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float("-inf")) # mask the future: a token cannot attend to later tokens
print("masked fill:")
print(weights)

weights = F.softmax(weights, dim=-1) # normalize
x_bag_of_words_3 = weights @ x  # sum
print("Are Close:", torch.allclose(x_bag_of_words, x_bag_of_words_3))

print(weights)

masked fill:
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
Are Close: True
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

### Attention - Most important Part

NOTES
- It is a communication mechanism. It can be seen as nodes in a directed graph, with edges pointing between nodes. Every node has a vector of information, and it aggregates information (as a weighted sum) from all nodes that point to it.

- In our graph each token is pointed to by itself and each previous token.

- There is no notion of space: attention acts on a set of vectors. This is why we need positional encodings to tell the model about token order.

- The examples across the batch dimension are completely independent and never talk to each other.

- "Self-attention" is called that because the keys, queries, and values, and therefore all of the information that defines the attention *weights*, come from the same source *x*.

- "Scaled attention": the weights are multiplied by 1/sqrt(head_size). When the query Q and key K have unit variance, the scaled weights will have unit variance too, so softmax stays diffuse and does not saturate too much.
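The scaling note can be checked numerically. The sketch below uses random unit-variance queries and keys (made-up data, with `_demo` names so nothing in the notebook is overwritten) to compare the variance of the raw and scaled dot products, then shows how larger logits push softmax toward a one-hot output.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
head_size_demo = 16
q_demo = torch.randn(1000, head_size_demo)  # unit-variance queries
k_demo = torch.randn(1000, head_size_demo)  # unit-variance keys

raw = q_demo @ k_demo.T               # variance grows to roughly head_size
scaled = raw * head_size_demo**-0.5   # variance brought back to roughly 1
print("raw variance:   ", raw.var().item())
print("scaled variance:", scaled.var().item())

# softmax saturates when its inputs are large in magnitude
logits_demo = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
diffuse = F.softmax(logits_demo, dim=-1)
peaked = F.softmax(logits_demo * 8, dim=-1)
print("diffuse:", diffuse)
print("peaked: ", peaked)
```

Without the scaling, each token at initialization would mostly attend to a single other token instead of mixing information from several.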

Every single node/token will emit two vectors 
- query (what am I looking for)
- key (what do I contain)

Each query is dotted with every key to produce the weights (the affinities between tokens).

We are going to make one "head"

In [44]:
B,T,C = 4,8,32
x = torch.randn(B,T,C)

# lets see a single self-attention head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, 16) <- 16 is from head_size
q = query(x) # (B, T, 16)
v = value(x) # value is used for elements that we aggregate rather than x
weights = q @ k.transpose(-2, -1) # (B, T, 16) matmul (B, 16, T) ---> (B, T, T)
weights *= head_size**-0.5 # "scaled attention" - (1/sqrt(head_size))

tril = torch.tril(torch.ones(T, T))
# weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float("-inf")) # mask the future: a token cannot attend to later tokens

weights = F.softmax(weights, dim=-1) # normalize
x_bag_of_words_4 = weights @ v  # sum

print(weights[0])

print(x_bag_of_words_4.shape)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7179, 0.2821, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1932, 0.3257, 0.4811, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4252, 0.1471, 0.2120, 0.2157, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3869, 0.1577, 0.1527, 0.1365, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.2094, 0.1212, 0.1904, 0.1660, 0.1020, 0.2110, 0.0000, 0.0000],
        [0.1112, 0.1016, 0.1692, 0.2236, 0.0995, 0.1818, 0.1131, 0.0000],
        [0.0923, 0.1413, 0.1114, 0.1750, 0.1668, 0.0749, 0.1413, 0.0968]],
       grad_fn=<SelectBackward0>)
torch.Size([4, 8, 16])


### Redefine Hyper-parameters

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 64
block_size = 256
max_iters = 5000
eval_interval = 500
learning_rate = 3e-4
eval_iters = 200
num_embedding_dim = 384
num_attn_heads = 6
num_layers = 6 
dropout = 0.2 # randomly shuts off a subset of neurons during training, which helps prevent overfitting

if device == "cpu":
  # if not on GPU, drop the parameters down
  num_embedding_dim = 8
  num_attn_heads = 1

### Self-Attention Head Class

In [None]:
class Head(nn.Module):
  """One head of self-attention"""
  def __init__(self, head_size) -> None:
    super().__init__()
    # project from the embedding dimension (not the global C of the demo cell above)
    self.key = nn.Linear(num_embedding_dim, head_size, bias=False)
    self.query = nn.Linear(num_embedding_dim, head_size, bias=False)
    self.value = nn.Linear(num_embedding_dim, head_size, bias=False)
    # tril is a fixed mask rather than a parameter, so it is registered as a buffer
    self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
    self.dropout = nn.Dropout(dropout)
  
  def forward(self, x):
    B,T,C = x.shape
    k = self.key(x) # (B, T, head_size)
    q = self.query(x) # (B, T, head_size)
    v = self.value(x) # (B, T, head_size); value is used for elements that we aggregate rather than x

    weights = q @ k.transpose(-2, -1) # (B, T, head_size) matmul (B, head_size, T) ---> (B, T, T)
    weights = weights * k.shape[-1]**-0.5 # "scaled attention" - (1/sqrt(head_size))
    weights = weights.masked_fill(self.tril[:T, :T] == 0, float("-inf")) # mask the future: a token cannot attend to later tokens
    weights = F.softmax(weights, dim=-1) # normalize
    weights = self.dropout(weights)
    out = weights @ v  # weighted sum of the values, (B, T, head_size)
    return out

### Multi-head Attention
Multi-head attention runs several attention heads in parallel and concatenates their results along the channel dimension.

In [None]:
class MultiHeadAttention(nn.Module):
  def __init__(self, num_heads, head_size) -> None:
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size=head_size) for _ in range(num_heads)])
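    # the concatenated heads have num_heads * head_size channels; project them back to the embedding size
    self.proj = nn.Linear(num_heads * head_size, num_embedding_dim)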
    self.dropout = nn.Dropout(dropout)
  
  def forward(self, x):
    out = torch.cat([head(x) for head in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

#### Feed forward layers of transformer
The tokens looked at each other but didn't have time to "think" about what the others meant. 

In [45]:
class FeedForward(nn.Module):
  """A simple linear layer followed by a non-linearity"""
  def __init__(self, num_embedding_dim) -> None:
    super().__init__()

    # the inner layer is 4x wider than d_model (the input/output size), a 4:1 ratio
    self.net = nn.Sequential(
      nn.Linear(num_embedding_dim, 4 * num_embedding_dim),
      nn.ReLU(),
      nn.Linear(4 * num_embedding_dim, num_embedding_dim), # projection layer
      nn.Dropout(dropout)
    )

  def forward(self, x):
    return self.net(x)

#### Residual Connections

In [None]:
class Block(nn.Module):
  """Transformer block, communication followed by computation"""
  def __init__(self, num_embedding_dim, num_heads) -> None:
    super().__init__()
    head_size = num_embedding_dim // num_heads
    self.self_attn = MultiHeadAttention(
      num_heads=num_heads,
      head_size=head_size
    )
    self.ffwd = FeedForward(num_embedding_dim)
    self.layer_norm_1 = nn.LayerNorm(num_embedding_dim)
    self.layer_norm_2 = nn.LayerNorm(num_embedding_dim)

  def forward(self, x):
    x = x + self.self_attn(self.layer_norm_1(x))
    x = x + self.ffwd(self.layer_norm_2(x))
    return x

#### Layernorm
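Layer norm normalizes each token's feature vector (the rows of the input, across the channel dimension), unlike batch norm, which normalizes the columns across the batch. I'm using the provided PyTorch implementation rather than writing my own.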
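For reference, here is a minimal sketch of what layer norm computes for each token: subtract the mean and divide by the standard deviation over the channel dimension. It ignores the learnable scale and shift, which start out as 1 and 0 in `nn.LayerNorm`, so the two should agree on a freshly created layer. The `manual_layer_norm` helper and the `x_ln` tensor are made up for this check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x_ln = torch.randn(4, 8, 32)  # (B, T, C) dummy activations

def manual_layer_norm(x, eps=1e-5):
    # statistics are taken over the last (channel) dimension, separately for every token
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)

layer_norm = nn.LayerNorm(32)
same = torch.allclose(manual_layer_norm(x_ln), layer_norm(x_ln), atol=1e-5)
print("matches nn.LayerNorm:", same)
```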

### Updated Bigram Model

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

vocab_size = len(vocab)

class BigramLanguageModel(nn.Module):
  def __init__(self) -> None:
    super().__init__()
    # token and position embeddings feed the transformer blocks; lm_head maps the result to logits
    self.token_embedding_table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=num_embedding_dim)
    self.position_embedding_table = nn.Embedding(block_size, num_embedding_dim)
    
    self.blocks = nn.Sequential(
      Block(num_embedding_dim, num_heads=num_attn_heads),
      Block(num_embedding_dim, num_heads=num_attn_heads),
      Block(num_embedding_dim, num_heads=num_attn_heads),
      nn.LayerNorm(num_embedding_dim)
    )
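    # final projection to vocabulary logits (each Block already contains its own feed-forward layer)
    self.lm_head = nn.Linear(num_embedding_dim, vocab_size)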

  def forward(self, idx, targets=None):
    B, T = idx.shape
    token_embeddings = self.token_embedding_table(idx) # (B, T, C): batch by time by channel
    pos_embeddings = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T, C)
    x = token_embeddings + pos_embeddings # (B, T, C), positions broadcast over the batch
    x = self.blocks(x) # (B, T, C)
    logits = self.lm_head(x) # (B, T, vocab_size)

    if targets is None:
      loss = None
    else:
      # F.cross_entropy wants the class dimension second, i.e. (B, C, T),
      # so we flatten to (B*T, C) logits and (B*T,) targets instead
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)

      loss = F.cross_entropy(logits, targets)

    return logits, loss # logits are the scores for the next token at each position
  
  def generate(self, idx, max_new_tokens):
    # idx is a (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
      idx_cond = idx[:, -block_size:]   # crop idx to the last block_size tokens
      logits, loss = self(idx_cond)     # get the predictions 
      logits = logits[:, -1, :]         # focus only on last time step
      probs = F.softmax(logits, dim=-1) # apply softmax to get probabilities
      idx_next = torch.multinomial(probs, num_samples=1) # sample from the distribution
      idx = torch.cat((idx, idx_next), dim=1)
    
    return idx

## A Review of the code
We have written a decoder-only transformer. Since it has not been trained yet, it is only capable of producing random output.

Should we want to take in tokens on the fly as input, we will have to implement our own encoder block. This encoder wouldn't have the same -inf mask as the decoder; it would attend to every token. **This should be at around 1:40:00 in the video.**
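The only difference between the two kinds of attention weights is the mask, which the sketch below tries to show. The `attention_weights` helper and the `_enc` names are hypothetical, and the random tensors stand in for real queries and keys.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
T_enc, head_size_enc = 4, 8
q_enc = torch.randn(T_enc, head_size_enc)
k_enc = torch.randn(T_enc, head_size_enc)

def attention_weights(q, k, causal):
    scores = (q @ k.T) * k.shape[-1] ** -0.5
    if causal:
        # decoder style: position t may only look at positions <= t
        mask = torch.tril(torch.ones(q.shape[0], k.shape[0]))
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return F.softmax(scores, dim=-1)

decoder_w = attention_weights(q_enc, k_enc, causal=True)   # lower-triangular weights
encoder_w = attention_weights(q_enc, k_enc, causal=False)  # every token sees every token
print(decoder_w)
print(encoder_w)
```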

### Pretraining vs Fine Tuning

#### Pretraining
The GPT-3 paper contains a table with the hyperparameters they used. This step is the main bulk of the training, where a model with a massive number of parameters is trained on a very large text corpus.

#### Fine tuning
Rather than spitting out garbage in an attempt to finish sequences, you want to optimize the model for a certain task. The ChatGPT website gives a general 3-step outline for how this should be done.
1. Collect demonstration data and train a supervised policy.
2. Collect comparison data and train a reward model.
3. Optimize a policy against the reward model using the PPO reinforcement learning algorithm.
