# Transformer for human number sequence prediction

[Fake college football](https://www.reddit.com/r/FakeCollegeFootball/) is a game where two teams, an offense and a defense, select numbers 1-1500. If those numbers are close together, the offense has a good play and if those numbers are far apart, the defense has a good play (note that 1 and 1500 are 1 apart, as the numbers are in a circle). The game takes around 100 numbers to complete.

Given the sequential nature of the game, I constructed a transformer to predict the probability that the next number submitted by a player falls within each of 30 groups of 50 numbers (1-50, 51-100, etc.).

In [None]:
import pandas as pd
import requests
import torch
import torch.nn as nn
from torch.nn import functional as F

# Coach usernames whose play histories are requested from the plays API below.
usernames = [
    "nusm", "belikemefr", "mycatsananalasshole", "therealslimcampbell",
    "ihasmagyk", "prankishpoet", "bearded_wildcard", "mile114",
    "vrain_19", "crabface5", "wazzup44", "callmefoofking",
    "erduck96", "bighoppy75", "metsareawesome5", "yeeting_man",
    "egglton", "mississippimadness", "rph2003", "snasty728",
    "steelersforlife11", "skeleton-with-skin10", "tehmoofish", "l3ach13",
    "6g0d757", "rickypop", "davy_grolton", "broo_lynn",
    "klarge_", "alternateshapes", "spanxc", "3rantgp",
    "creatxrcreator", "dj_bradlezzz", "naragog1", "minimum_junket9199",
    "psujosh", "baetaro", "sonerandomuser", "birdmanthethird",
    "jakester1238", "americansasquatch_24", "qc_undercover", "despacitoritzcracker",
    "kelbo11", "the_hoff34", "torchedpineapple", "rocco5w",
    "horribelspelling", "hobbes_t_hero", "scum-phoenix", "scsprinter13",
    "astockusername", "callofmc", "_toughscene", "dark197",
    "natestate", "jsteele1423", "hman1500", "sexy-chicagoan-1837",
    "thebattlersprince", "bearinthewoods1", "matt__17", "ethed99",
    "edgerocks2", "door_nav", "leruul", "tee_jay9",
    "zachfischer2528", "shoonipatooti", "inatro", "oogadebob",
    "fyre87",
]
# Download every coach's play history and flatten it into one row per play.
# Each row pairs the situation of a play (down, distance, quarter, clock,
# offense/defense) with the numbers both coaches submitted on the PREVIOUS
# play of the same game; the first play of a game uses -1 for both numbers.
plays_data = []
for username in usernames:
    # API endpoint URL
    url = "https://api.1212.one/plays/coach/" + username
    try:
        response = requests.get(url, timeout=10)
    except requests.exceptions.RequestException as e:
        print("Error:", e)
        print(f"Failed on username: {username}")
        continue

    if response.status_code != 200:
        # Request failed
        print(f"Request failed with status code: {response.status_code}")
        print(f"Failed on username: {username}")
        continue

    try:
        data = response.json()  # Parse the response as JSON
    except ValueError as e:
        # A 200 response whose body is not valid JSON: skip this coach
        # instead of aborting the whole download.
        print("Error:", e)
        print(f"Failed on username: {username}")
        continue
    print(f"Successfully fetched data for username: {username}")

    # Numbers submitted on the previous play; -1 means "no previous play"
    playerNumber = -1
    opponentNumber = -1

    for play in data:
        # If too early, then not all data available including play number. Ignore those games
        if "playNumber" not in play:
            continue
        if play['playNumber'] == 1:
            # New game: forget the numbers carried over from the last one
            playerNumber = -1
            opponentNumber = -1

        # A play without a submitted number leaves None in the carried-over
        # values. The row is skipped and, because the carried values are not
        # refreshed either, so is the rest of that game (until playNumber == 1).
        if opponentNumber is None or playerNumber is None:
            continue

        play_data = {
            'gameId': play['game']['gameId'],
            'playNumber': play['playNumber'],
            'playerNumber': float(playerNumber),
            'opponentNumber': float(opponentNumber),
            'down': play['down'],
            'distance': play['distance'],
            'quarter': play['quarter'],
            'clock': play['clock'],
            'onOffense': int(play['coachIsOffense'])
        }

        # Shift the numbers down one row: remember what was submitted on THIS
        # play so that the next row carries it as its "previous" numbers.
        if int(play['coachIsOffense']) == 1:
            # Player is on offense
            playerNumber = play['offense']['number']
            opponentNumber = play['defense']['number']
        else:
            # Player is on defense
            playerNumber = play['defense']['number']
            opponentNumber = play['offense']['number']
        plays_data.append(play_data)

df = pd.DataFrame(plays_data)

# Remove overtime stuff (quarter 5 and above)
df = df[~(df['quarter'] >= 5)]
df.head(15)

In [None]:
import numpy as np

# Bucket the raw numbers into 30 half-open bins of width 50
# ([1, 51), [51, 101), ..., [1451, inf)), labelled 0-29. Everything below 1,
# in particular the -1 "no previous number" placeholder, gets the label -1.
bin_starts = list(range(1, 1452, 50))
bins = [-np.inf] + bin_starts + [np.inf]
labels = [-1] + list(range(0, 30))

# Replace both number columns with their category labels
for number_column in ('playerNumber', 'opponentNumber'):
    df[number_column] = pd.cut(df[number_column], bins=bins, labels=labels, right=False)

In [None]:
# Create a new column to indicate the start of each game
df['new_game'] = df['gameId'] != df['gameId'].shift(1)

# Feature columns, in the order the later cells index them by position
# (0 playNumber, 1 playerNumber, 2 opponentNumber, 3 down, 4 distance,
# 5 quarter, 6 clock, 7 onOffense). Naming them explicitly, rather than
# slicing df.columns[1:-1], keeps this layout stable if columns are ever
# added to or removed from the DataFrame.
feature_columns = [
    'playNumber',
    'playerNumber',
    'opponentNumber',
    'down',
    'distance',
    'quarter',
    'clock',
    'onOffense',
]

# One list of 8 feature values per play, converted in a single pass instead
# of iterating over the DataFrame row by row.
sequences = df[feature_columns].to_numpy(dtype=object).tolist()

data = torch.tensor(sequences)
data[0:10]

# Now scale and reorder data
### Order is now:
1. Play (numeric)
2. Yards to Go (numeric)
3. Clock (numeric)
4. Your number (31 categories)
5. Opponent number (31 categories)
6. Down (4 categories)
7. Quarter (4 categories)
8. Offense (2 categories)

In [None]:
from sklearn.preprocessing import OneHotEncoder

# One-hot encode the categorical features and rescale the numeric ones.
# Column positions refer to the 8-feature layout of `data` at this point.
numeric_columns = [0, 4, 6]
categorical_columns = [1, 2, 3, 5, 7]

# Allowed values per categorical column, in categorical_columns order:
# two number-bin columns (-1 placeholder plus bins 0..29), two columns with
# values 1-4, and one binary flag.
number_bins = list(range(-1, 30))
one_to_four = [1, 2, 3, 4]
categories = [number_bins, list(number_bins), one_to_four, list(one_to_four), [0, 1]]

# Encode the categorical columns
encoder = OneHotEncoder(categories=categories, sparse_output=False)
one_hot_encoded = encoder.fit_transform(data[:, categorical_columns])

# Numeric columns first, one-hot blocks after them, back as a float tensor
remaining_data = data[:, numeric_columns]
data = torch.tensor(np.hstack((remaining_data, one_hot_encoded)), dtype=torch.float)

# Scale the three numeric columns by fixed constants
for column, scale in enumerate((200, 10, 420)):
    data[:, column] = data[:, column] / scale

# Remove the -1 category columns: the first column of each number one-hot block
first_one_hot = len(numeric_columns)
columns_to_remove = [first_one_hot, first_one_hot + len(number_bins)]
mask = [column for column in range(data.shape[1]) if column not in columns_to_remove]
data = data[:, mask]
data[6:8]


# Transformer code

Transformer custom built for number sequence prediction

In [None]:

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# from utils import DEVICE


class AttentionHead(nn.Module):
    """
    One causal (masked) head of the self-attention layer.

    Args:
        head_size: width of the key/query/value projections of this head.
        num_embed: width of the incoming embeddings.
        block_size: longest sequence the causal mask is built for.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, num_embed, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(num_embed, head_size, bias=False)
        self.query = nn.Linear(num_embed, head_size, bias=False)
        self.value = nn.Linear(num_embed, head_size, bias=False)
        # tril is a lower triangular matrix. it is not a parameter
        # of the model, so we assign it to the module using register_buffer
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

        # dropout on the attention weights
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over x of shape (B, T, num_embed); returns (B, T, head_size)."""
        _, T, _ = x.shape
        k = self.key(x)  # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Scaled dot-product scores: (B, T, hs) @ (B, hs, T) -> (B, T, T).
        # The scale is 1/sqrt(d_k) with d_k the key/query width (head_size),
        # not the width of the incoming embedding.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # The lower triangular matrix masks future positions (set to -inf)
        # so that each position only attends to itself and earlier positions
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))  # (B,T,T)

        wei = F.softmax(wei, dim=-1)  # (B,T,T)
        wei = self.dropout(wei)
        # weighted aggregation of the values
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v  # (B,T,T) @ (B,T,hs) ---> (B,T,hs)
        return out


class MultiHeadAttention(nn.Module):
    """
    Multiple heads of self-attention in parallel.

    The head outputs are concatenated along the feature axis and projected
    back to num_embed.

    Args:
        num_heads: number of parallel AttentionHead modules.
        head_size: output width of each head.
        num_embed: width of the incoming embeddings and of the output.
        block_size: longest sequence the causal masks are built for.
        dropout: dropout probability, used inside the heads and after the
            output projection.
    """

    def __init__(self, num_heads, head_size, num_embed, block_size, dropout):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                AttentionHead(
                    head_size=head_size,
                    num_embed=num_embed,
                    block_size=block_size,
                    dropout=dropout,
                )
                for _ in range(num_heads)
            ]
        )
        # The concatenated heads are num_heads * head_size wide. That equals
        # num_embed only when num_embed is divisible by num_heads, so size the
        # projection from the heads themselves.
        self.proj = nn.Linear(num_heads * head_size, num_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # run every head on the same input and concatenate on the feature axis
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        # project back to num_embed, then apply dropout
        out = self.dropout(self.proj(out))
        return out


class FeedForward(nn.Module):
    """
    Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout.
    """

    def __init__(self, num_embed, dropout):
        super().__init__()
        # "Attention is All You Need" uses a 2048-wide hidden layer for a
        # 512-wide model; the same factor of 4 is applied here.
        hidden_size = 4 * num_embed
        layers = [
            nn.Linear(num_embed, hidden_size),
            nn.ReLU(),
            # project back down to the embedding width
            nn.Linear(hidden_size, num_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class TransformerBlock(nn.Module):
    """
    One transformer layer: multi-head self-attention followed by a
    feed-forward network, each wrapped in a residual connection with
    layer normalization applied to its input.
    """

    def __init__(self, num_heads, block_size, num_embed, dropout):
        super().__init__()
        # the embedding width is split across the heads
        per_head = num_embed // num_heads
        self.sa = MultiHeadAttention(
            num_heads=num_heads,
            head_size=per_head,
            num_embed=num_embed,
            block_size=block_size,
            dropout=dropout,
        )
        self.ffwd = FeedForward(num_embed=num_embed, dropout=dropout)
        # one layer norm in front of each sub-layer
        self.ln1 = nn.LayerNorm(num_embed)
        self.ln2 = nn.LayerNorm(num_embed)

    def forward(self, x):
        # Pre-norm residual layout (a reshuffle from the original paper):
        # normalize, apply the sub-layer, then add the result to the
        # untouched input, which helps with optimization.
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed


class Transformer(nn.Module):
    """
    Decoder-style transformer over sequences of per-play feature vectors.

    Every time step is a feature vector of width ``input_size``; the model
    outputs ``output_size`` logits per step.

    Keyword Args:
        input_size (int): width of each input feature vector. Default 75.
        output_size (int): number of output classes. Default 30.
        num_embed (int): embedding width. Default 32.
        block_size (int): maximum sequence length. Default 8.
        num_heads (int): attention heads per block. Default 4.
        num_layers (int): number of transformer blocks. Default 4.
        dropout (float): dropout probability. Default 0.2.
        min_context (int | None): number of leading positions excluded from
            the loss. When omitted, the module-level ``MIN_CONTEXT`` is read
            at call time.
        target_offset (int): first column of the one-hot class block inside
            the ``targets`` feature vectors. Default 3.

    Note:
        ``ln_f`` is applied before ``lm_head`` and attention scores are
        scaled per head, so checkpoints trained with an earlier revision of
        this code will produce different activations.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.input_size = kwargs.get("input_size", 75)
        self.output_size = kwargs.get("output_size", 30)
        self.num_embed = kwargs.get("num_embed", 32)
        self.block_size = kwargs.get("block_size", 8)
        self.num_heads = kwargs.get("num_heads", 4)
        self.num_layers = kwargs.get("num_layers", 4)
        self.dropout = kwargs.get("dropout", 0.2)
        self.min_context = kwargs.get("min_context", None)
        self.target_offset = kwargs.get("target_offset", 3)
        # inputs are dense feature vectors, so a linear layer (rather than an
        # embedding lookup table) maps them to the embedding width
        self.embedding_creator = nn.Linear(self.input_size, self.num_embed)
        # each position from 0 to block_size-1 will get its embedding
        self.position_embedding_table = nn.Embedding(self.block_size, self.num_embed)
        self.blocks = nn.Sequential(
            *[
                TransformerBlock(
                    num_heads=self.num_heads,
                    block_size=self.block_size,
                    num_embed=self.num_embed,
                    dropout=self.dropout,
                )
                for _ in range(self.num_layers)
            ]
        )
        # final layer norm, applied right before the output projection
        self.ln_f = nn.LayerNorm(self.num_embed)
        self.lm_head = nn.Linear(self.num_embed, self.output_size)  # output logits for each output category

    def forward(self, idx, targets=None):
        """
        Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            idx: float tensor of shape (B, T, input_size).
            targets: optional tensor shaped like idx. The class label of each
                step is the argmax over the ``output_size`` columns starting
                at ``target_offset``.

        Returns:
            (logits, loss). Without targets, logits is (B, T, output_size)
            and loss is None. With targets, the first ``min_context``
            positions are dropped and logits is flattened to
            (B * (T - min_context), output_size).
        """
        B, T, _ = idx.shape

        # (B, T, num_embed)
        token_emb = self.embedding_creator(idx)
        # (T, num_embed); created on the input's device so the model does not
        # depend on a module-level device constant
        posit_emb = self.position_embedding_table(torch.arange(T, device=idx.device))

        x = token_emb + posit_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        # (B, T, output_size)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        # Class index of each step, read from the one-hot block of the targets
        first = self.target_offset
        target_classes = torch.argmax(targets[:, :, first:first + self.output_size], dim=2)

        # Only score predictions that had at least min_context earlier steps
        # in the window; predictions made with less context are ignored.
        min_context = MIN_CONTEXT if self.min_context is None else self.min_context
        logits = logits[:, min_context:, :]
        target_classes = target_classes[:, min_context:]
        B, T, C = logits.shape

        # cross_entropy expects (N, num_classes) logits and (N,) class indices
        logits = torch.reshape(logits, (B * T, C))
        target_classes = torch.reshape(target_classes, (B * T,))

        loss = F.cross_entropy(logits, target_classes)
        return logits, loss

    def generate(self, idx: torch.Tensor, max_new_tokens: int, block_size: int):
        """
        Sample max_new_tokens class indices autoregressively.

        NOTE(review): this looks carried over from a token-level language
        model. forward() expects (B, T, input_size) feature tensors, and the
        sampled (B, 1) class index cannot be concatenated onto such a tensor,
        so this method appears unusable as written; confirm before relying
        on it.
        """
        for _ in range(max_new_tokens):
            # crop the context to the last block_size steps
            idx_crop = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_crop)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution with probabilities probs
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


from datetime import datetime
import os



def encode(text_seq: str, tokenizer: any) -> torch.Tensor:
    """
    Turn a text string into a 1-D long tensor of token ids.

    The tokenizer must provide ``tokenize(text)`` and
    ``convert_tokens_to_ids(tokens)``.
    """
    pieces = tokenizer.tokenize(text_seq)
    ids = tokenizer.convert_tokens_to_ids(pieces)
    return torch.tensor(ids, dtype=torch.long)


def decode(enc_sec: torch.Tensor, tokenizer: any) -> str:
    """
    Turn a tensor of token ids back into text.

    The ids are converted to a plain Python list and handed to
    ``tokenizer.decode``.
    """
    return tokenizer.decode(enc_sec.tolist())


def get_batch(data: list[str], block_size: int, batch_size: int, valid_indices):
    """
    Draw a random batch of (input, target) windows from data.

    batch_size window starts are sampled with replacement from
    valid_indices. For a start i the input is data[i : i + block_size] and
    the target is the same window moved one step later,
    data[i + 1 : i + block_size + 1].

    Parameters:
    data: sequence to cut windows from (sliced and stacked as a tensor here,
        despite the list[str] annotation)
    block_size (int): window / context length
    batch_size (int): number of windows processed in parallel
    valid_indices: tensor of allowed window start positions

    Returns:
    x, y: the stacked input and target windows, both moved to DEVICE
    """
    picks = torch.randint(len(valid_indices), (batch_size,))
    starts = valid_indices[picks]

    inputs = []
    shifted = []
    for start in starts:
        inputs.append(data[start : start + block_size])
        # the target window starts one position later than the input window
        shifted.append(data[start + 1 : start + block_size + 1])

    x = torch.stack(inputs).to(DEVICE)
    y = torch.stack(shifted).to(DEVICE)
    return x, y


@torch.no_grad()
def estimate_loss(
    data: torch.Tensor,
    model: torch.nn.Module,
    block_size: int,
    batch_size: int,
    eval_iters: int = 10,
    valid_indices=None,
):
    """
    Estimate the model's loss as the mean over eval_iters random batches.

    The model is switched to eval mode for the measurement and afterwards
    put back into whatever mode it was in before the call.

    Args:
        data: tensor the batches are drawn from.
        model: module whose call returns a (logits, loss) pair.
        block_size: window length of each sampled sequence.
        batch_size: number of sequences per batch.
        eval_iters: number of batches to average over.
        valid_indices: tensor of allowed window start positions (required by
            get_batch even though it defaults to None here).

    Returns:
        0-d tensor holding the mean loss.
    """
    was_training = model.training
    model.eval()
    losses = torch.zeros(eval_iters)
    try:
        for k in range(eval_iters):
            X, Y = get_batch(data=data, block_size=block_size, batch_size=batch_size, valid_indices=valid_indices)
            # call the module itself (not .forward) so registered hooks run
            _, loss = model(X, Y)
            losses[k] = loss.item()
    finally:
        # restore the caller's mode even if a batch fails
        model.train(was_training)
    return losses.mean()


@torch.no_grad()
def get_predictions(
    data: torch.Tensor,
    model: torch.nn.Module,
    block_size: int,
    batch_size: int,
    eval_iters: int = 10,
    valid_indices=None,
):
    """
    Return the model's logits for one randomly drawn batch.

    The model is switched to eval mode for the forward pass and afterwards
    put back into whatever mode it was in before the call.

    Args:
        data: tensor the batch is drawn from.
        model: module whose call returns a (logits, loss) pair.
        block_size: window length of each sampled sequence.
        batch_size: number of sequences in the batch.
        eval_iters: unused; kept so existing callers keep working.
        valid_indices: tensor of allowed window start positions (required by
            get_batch even though it defaults to None here).

    Returns:
        The logits returned by the model for the sampled batch.
    """
    was_training = model.training
    model.eval()
    try:
        X, Y = get_batch(data=data, block_size=block_size, batch_size=batch_size, valid_indices=valid_indices)
        # call the module itself (not .forward) so registered hooks run
        logits, _ = model(X, Y)
    finally:
        # restore the caller's mode even if the forward pass fails
        model.train(was_training)
    return logits


def load_model_from_checkpoint(
    model_class: type[torch.nn.Module],
    path_to_checkpoint: str = "checkpoints/state_dict_model.pt",
    **kwargs,
) -> torch.nn.Module:
    """
    Build a model and fill it with the weights stored in a checkpoint.

    Args:
        model_class: class to instantiate (called as ``model_class(**kwargs)``).
        path_to_checkpoint: file holding a saved state_dict.
        **kwargs: constructor arguments forwarded to model_class.

    Returns:
        The instantiated model with the checkpoint weights loaded.

    Raises:
        Whatever torch.load raises when the checkpoint cannot be read (for
        example FileNotFoundError). The error is printed and re-raised
        rather than swallowed, which previously surfaced as an unrelated
        UnboundLocalError further down.
    """
    try:
        # Load onto the CPU so checkpoints saved on a GPU can be read on a
        # CPU-only machine; load_state_dict copies the values into the
        # model's own parameters afterwards.
        state_dict = torch.load(path_to_checkpoint, map_location="cpu")
    except Exception as e:
        print(f"Error loading the model from the checkpoint. {e}")
        raise

    model = model_class(**kwargs)
    # load the state_dict into the model
    model.load_state_dict(state_dict)
    print("Successfully loaded model from the checkpoint")
    return model


def save_model_to_chekpoint(
    model: torch.nn.Module, path_to_checkpoint: str = "checkpoints", epoch: int = 0
):
    """
    Write the model's state_dict to a timestamped file in path_to_checkpoint.

    The directory is created when it does not exist. The file is named
    ``checkpoint_epoch-<epoch>_<dd.mm.YYYY_HH:MM:SS>.pt``. A failure while
    saving is printed, not raised.

    NOTE(review): the timestamp contains ':' characters, which some file
    systems do not accept in file names.
    """
    # make sure the target directory is there
    directory_missing = not os.path.exists(path_to_checkpoint)
    if directory_missing:
        os.makedirs(path_to_checkpoint)

    # current date and time as dd.mm.YYYY_H:M:S
    timestamp = datetime.now().strftime("%d.%m.%Y_%H:%M:%S")
    file_name = "".join(["checkpoint_epoch-", str(epoch), "_", timestamp, ".pt"])
    full_path = os.path.join(path_to_checkpoint, file_name)
    try:
        weights = model.state_dict()
        torch.save(weights, full_path)
        print("Successfully saved the model to {}".format(full_path))
    except Exception as e:
        print(f"Error saving the model to checkpoint. {e}")

# Custom loss function

Calculates the distance between the two numbers (where 1 and 1500 are 1 apart).

This method is slow and not clearly better than cross entropy.

In [None]:
# Expected circular distance between the predicted class and the target class
def reddit_football_loss(logits, target_index):
    """
    Custom loss: the expected circular distance to the target class.

    The classes are treated as points on a circle, so with 30 classes,
    class 0 and class 29 are 1 apart. Every class probability is weighted
    by its circular distance to the target class, the weighted
    probabilities are summed per sample, and the result is averaged over
    the batch.

    Computed with tensor operations on the device of logits (no Python
    loop over samples or classes). Both classes at the maximum distance
    are counted for an odd number of classes as well.

    Args:
        logits (torch.Tensor): Tensor of shape (batch_size, num_classes) with logits.
        target_index (torch.Tensor): Tensor of shape (batch_size,) with the
            correct class index of each sample; indices wrap modulo num_classes.

    Returns:
        torch.Tensor: 0-d tensor with the mean loss over the batch.
    """
    # Convert logits to probabilities using softmax
    probabilities = F.softmax(logits, dim=-1)

    # Get batch size and number of classes
    batch_size, num_classes = probabilities.shape

    class_ids = torch.arange(num_classes, device=logits.device)
    targets = torch.as_tensor(target_index, device=logits.device).reshape(-1, 1)

    # Steps from the target to every class going one way round the circle...
    forward_steps = (class_ids.unsqueeze(0) - targets) % num_classes  # (B, C)
    # ...and the shorter of the two ways round is the circular distance
    distance = torch.minimum(forward_steps, num_classes - forward_steps)

    per_sample = (probabilities * distance.to(probabilities.dtype)).sum(dim=-1)

    # Return the mean loss over the batch
    return per_sample.mean()

# Example usage: all-zero logits give a uniform distribution over the 30
# classes for each of the 3 samples, i.e. the loss of a blind guess.
logits = torch.zeros(3, 30)
target_index = torch.tensor([1, 0, 0])  # one target class per sample

loss = reddit_football_loss(logits, target_index)
print("Random loss: ", loss.item())

# Training loop

In [None]:
import torch
from tqdm import tqdm
import numpy as np
# from model import Transformer
from transformers import AutoTokenizer  # pip install transformers
# --- Training hyperparameters ---
BATCH_SIZE = 32  # how many independent sequences will we process in parallel?
BLOCK_SIZE = 12  # what is the maximum context length for predictions?
MIN_CONTEXT = 6  # Minimum context length for training, aka only do backprop on guesses where you had at least MIN_CONTEXT other things in context window
MAX_ITER = 10000  # number of training iterations
EVAL_INTER = 100  # evaluate train/val loss every EVAL_INTER training steps
LEARNING_RATE = 1e-5  # learning rate handed to the AdamW optimizer
NUM_HEAD = 4  # attention heads per transformer block
NUM_EMBED = NUM_HEAD * 128  # embedding width; a multiple of NUM_HEAD so it splits evenly across the heads
NUM_LAYER = 4  # number of stacked transformer blocks
INPUT_SIZE = data.shape[1]  # number of feature columns in the preprocessed `data` tensor
OUTPUT_SIZE = 30 # Number of categories
DROPOUT = 0  # dropout probability (0 disables dropout)

# Seed torch's random number generator so repeated runs draw the same random numbers
torch.manual_seed(1)

# load model from checkpoint
# m = load_model_from_checkpoint(Transformer,vocab_size=vocab_size)

# example to decode sequence
# enc_sec = m.generate(idx=torch.zeros((1,1), dtype=torch.long),
# max_new_tokens=20)[0].tolist()
# print(decode(vocab=vocab, enc_sec=enc_sec))


# Train and val data: first 70% of the rows for training, the rest for validation
train_data = data[0:int(0.7*data.shape[0])]
val_data = data[int(0.7*data.shape[0]):]


def _valid_window_starts(split, block_size):
    """
    Return the window start rows of split that can be used for batches.

    A start i is kept when the value in column 0 (the scaled play number) at
    row i + block_size is larger than at row i. A window that runs into the
    next game normally fails this test, because the play number starts over.

    NOTE(review): only the two end rows of the window are compared, so a
    window could still cross a game boundary if the next game has already
    progressed past the play number at the window start; confirm whether a
    stricter check is needed.
    """
    if block_size <= 0 or split.shape[0] <= block_size:
        return torch.empty(0, dtype=torch.long)
    still_increasing = split[block_size:, 0] > split[:-block_size, 0]
    return torch.nonzero(still_increasing).flatten()


# Get valid indices to use in the training loop (Cant use indices where you will flip into the next game)
valid_indices_train = _valid_window_starts(train_data, BLOCK_SIZE)
valid_indices_val = _valid_window_starts(val_data, BLOCK_SIZE)



# train a new model
model = Transformer(
    output_size=OUTPUT_SIZE,
    input_size=INPUT_SIZE,
    num_embed=NUM_EMBED,
    block_size=BLOCK_SIZE,
    num_heads=NUM_HEAD,
    num_layers=NUM_LAYER,
    dropout=DROPOUT,
)
# load model to GPU if available
m = model.to(DEVICE)
# print the number of parameters in the model
print(
    "Model with {:.2f}M parameters".format(sum(p.numel() for p in m.parameters()) / 1e6)
)
# AdamW updates the model's parameters from their gradients
optimizer = torch.optim.AdamW(m.parameters(), lr=LEARNING_RATE)

# loss estimates collected every EVAL_INTER steps (and on the last step)
train_loss = []
val_loss = []

for step in tqdm(range(MAX_ITER)):

    # every EVAL_INTER evaluate the loss on train and val sets
    if step % EVAL_INTER == 0 or step == MAX_ITER - 1:
        loss_train = estimate_loss(
            data=train_data, model=m, block_size=BLOCK_SIZE, batch_size=BATCH_SIZE, valid_indices=valid_indices_train, eval_iters=50
        )
        train_loss.append(loss_train)

        loss_val = estimate_loss(
            data=val_data, model=m, block_size=BLOCK_SIZE, batch_size=BATCH_SIZE, valid_indices=valid_indices_val, eval_iters=50
        )
        val_loss.append(loss_val)

        print("step {:10} | train loss {:6.4f} | val loss {:6.4f}".format(step, loss_train, loss_val))

    # sample a batch of data
    xb, yb = get_batch(data=train_data, block_size=BLOCK_SIZE, batch_size=BATCH_SIZE, valid_indices=valid_indices_train)
    # call the module itself (not .forward) so registered hooks run
    logits, loss = m(xb, yb)
    # clear the gradients left over from the previous step
    optimizer.zero_grad(set_to_none=True)
    # compute the gradients of the loss with respect to the parameters
    loss.backward()
    # update the parameters using the computed gradients
    optimizer.step()

# `step` still holds the index of the last iteration here
save_model_to_chekpoint(model=m, path_to_checkpoint="checkpoints", epoch=step)


In [None]:
# Plot the recorded loss curves (the first recorded value of each is skipped)
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_loss[1:], label='Training Loss', marker='o')
ax.plot(val_loss[1:], label='Validation Loss', marker='x')

# Title, axis labels and legend
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Steps')
ax.set_ylabel('Loss')
ax.legend()

# Grid on, then display
ax.grid(True)
plt.show()