# Imports

In [1]:
import numpy as np
import datetime
import torch
import torch.optim as optim
import torch.nn as nn
import torch.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('fivethirtyeight')

2024-05-31 13:55:38.168547: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-31 13:55:38.171748: I external/local_tsl/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-05-31 13:55:38.218714: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# The StepByStep Class

In [None]:
class StepByStep():
    def __init__(self, model, loss_func, optimizer):
        # Here we define the attributes of our class
        # We start by storing the arguments as attributes
        # to use them later
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Let's send the model to the specified device right away
        self.model.to(self.device)
        
        # These attributes are defined here, but since they are
        # not available at the moment of creation, we keep them None
        self.train_loader = None
        self.val_loader = None
        self.writer = None
        
        # These attributes are going to be computed internally
        self.losses = []
        self.val_losses = []
        self.total_epochs = 0
        
        # Creates the train_step function for our model,
        # loss function and optimizer
        # Note: there are NO ARGS there! It makes use of the class
        # attributes directly
        self.train_step_fn = self._make_train_step_fn()
        
        # Creates the val_step function for our model and loss
        self.val_step_fn = self._make_val_step_fn()
        
    def to(self, device):
        # This method allows the user to specify a different device
        # It sets the corresponding attribute (to be used later in
        # the mini-batches) and sends the model to the device
        try:
            self.device = device
            self.model.to(self.device)
        except RuntimeError:
            self.device = ('cuda' if torch.cuda.is_available()
            else 'cpu')
            print(f"Couldn't send it to {device}, \
            sending it to {self.device} instead.")
            self.model.to(self.device)
            
    def set_loaders(self, train_loader, val_loader=None):
        # This method allows the user to define which train_loader
        # (and val_loader, optionally) to use
        # Both loaders are then assigned to attributes of the class
        # So they can be referred to later
        self.train_loader = train_loader
        self.val_loader = val_loader
        
    def set_tensorboard(self, name, folder='runs'):
        """Create a SummaryWriter to interface with TensorBoard.

        The writer is stored as ``self.writer`` and is given the path
        ``<folder>/<name>_<timestamp>``, where the timestamp is the
        current local time formatted as ``YYYYmmddHHMMSS``.

        Args:
            name: Prefix of the run directory name.
            folder: Parent folder of the run directory.
        """
        now = datetime.datetime.now()
        suffix = now.strftime('%Y%m%d%H%M%S')
        log_dir = f'{folder}/{name}_{suffix}'
        self.writer = SummaryWriter(log_dir)
        
    def _make_train_step_fn(self):
        # This method does not need ARGS... it can use directly
        # the attributes: self.model, self.loss_fn and self.optimizer
        
        # Builds function that performs a step in the train loop
        def perform_train_step_fn(x, y):
            # set the model to TRAIN mode
            self.model.train()
            
            # Step 1: computes model's predicted output - forward pass
            y_hat = self.model(x)
            
            # Step 2: computes the loss
            loss = self.loss_fn(y_hat, y)
            
            # Step 3: computes gradients w.r.t model params
            loss.backward()
            
            # Step 4: Updates params using gradients and the learning rate
            self.optimizer.step()
            self.optimizer.zero_grad()
            
            # returns the loss
            return loss.item()
        
        # Returns the function that will be called inside the train loop
        return perform_train_step_fn
    
    def _make_val_step_fn(self):
        # Builds function that performs a step in the validation loop
        def perform_val_step_fn(x, y):
            # set the model to EVAL mode
            self.model.eval()
            
            # Step 1: computes model's predicted output - forward pass
            y_hat = self.model(x)
            
            # Step 2: computes the loss
            loss = self.loss_fn(y_hat, y)
                        
            # returns the loss
            return loss.item()
        
        # Returns the function that will be called inside the validation loop
        return perform_val_step_fn