In [2]:
import os 
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [14]:
# Import dependencies
import math
from collections import OrderedDict
from tqdm.auto import tqdm

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
import torch.nn.init as init

import random
import numpy as np
import psutil
from torchsummary import summary
import matplotlib.pyplot as plt

import warnings 
warnings.filterwarnings("ignore")

In [4]:
def set_experiment_seed(seed=42):
    """
    Set random seeds and CUDA-related flags for experiment reproducibility.
    """
    # Set random seeds for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # If using GPU, set random seed for CUDA operations
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False  # Set to False for reproducibility

    # Optionally log seed information
    print(f"Random seed set to {seed}")

set_experiment_seed()

Random seed set to 42


In [5]:
# Download MNIST dataset
data_train = MNIST(
    "./data/mnist", download=True, 
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.GaussianBlur(kernel_size=1),
        transforms.ToTensor()
    ])
)
        
data_test = MNIST(
    "./data/mnist", download=True, train=False,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ])
)

In [32]:
# Define hyperparameters
N_CLASSES = 10
EPOCHS = 50
BATCH_SIZE = 512
DROPOUT_PROB = 0.2
USE_BATCHNORM = True
LABEL_SMOOTHING = 0.1
LR = 1.34E-03
WARMUP_PROPORTION = 0.1

In [7]:
# plt.figure(figsize=(10,10))
# for i in range(25):
#     plt.subplot(5,5,i+1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(data_train[i][0].permute(1,2,0).numpy(), cmap=plt.cm.binary)
#     plt.xlabel(data_train[i][1])
# plt.show()

`torch.utils.data.DataLoader` supports asynchronous data loading and data augmentation in separate worker subprocesses. The default setting for DataLoader is `num_workers=0`, which means that the data loading is synchronous and done in the main process. As a result the main training process has to wait for the data to be available to continue the execution.

Settin`g num_workers >` 0 enables asynchronous data loading and overlap between the training and data loading. num_workers should be tuned depending on the workload, CPU, GPU, and location of training data.`

DataLo`ader acce`pts pin_me`mory argument, which defaults` to F`alse. When using a GPU it’s better to `set pin_memory=`True, this instru`cts DataLo`ader to use pinned memory and enables faster and asynchronous memory copy from the host to the GPU.

In [8]:
def calculate_num_workers():
    # Calculate the number of CPU cores
    num_cpu_cores = os.cpu_count()

    # Set a safe maximum multiplier value (e.g., 0.5) to avoid using all available resources
    max_multiplier = 0.5

    # Calculate the number of workers based on the available resources
    multiplier = min(sum(psutil.cpu_percent(interval=1, percpu=True)) / 100.0, max_multiplier)
    
    # Ensure that num_workers is at least 1
    num_workers = max(1, int(num_cpu_cores * multiplier))

    return num_workers

In [9]:
# Define train dataloader
train_dataloader = DataLoader(
    data_train, batch_size=BATCH_SIZE, 
    sampler=RandomSampler(data_train), 
    pin_memory=True, num_workers=calculate_num_workers()
)

# Define test dataloader
test_dataloader = DataLoader(
    data_test, batch_size=BATCH_SIZE, 
    sampler=SequentialSampler(data_test),
    pin_memory=True, num_workers=calculate_num_workers()
)

`torch.nn.Conv2d()` has `bias` parameter which defaults to `True` (the same is true for `Conv1d` and `Conv3d` ).

If a `nn.Conv2d` layer is directly followed by a `nn.BatchNorm2d` layer, then the bias in the convolution is not needed, instead use `nn.Conv2d(..., bias=False, ....)`. Bias is not needed because in the first step `BatchNorm` subtracts the mean, which effectively cancels out the effect of bias.

This is also applicable to 1d and 3d convolutions as long as `BatchNorm`(or other normalization layer) normalizes on the same dimension as convolution’s bias.

In [10]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
                
        self.features1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=(5, 5), bias=not USE_BATCHNORM),
            self._get_norm_layer(6),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Dropout2d(p=DROPOUT_PROB)
        )
        
        self.features2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=(5, 5), bias=not USE_BATCHNORM),
            self._get_norm_layer(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Dropout2d(p=DROPOUT_PROB)
        )

        self.features3 = nn.Sequential(
            nn.Conv2d(16, 120, kernel_size=(5, 5), bias=not USE_BATCHNORM),
            self._get_norm_layer(120),
            nn.ReLU(),
            nn.Dropout2d(p=DROPOUT_PROB)
        )

        self.classifier = nn.Sequential(
            nn.Linear(120, 84, bias=True),
            nn.ReLU(),
            nn.Dropout(p=DROPOUT_PROB),
            nn.Linear(84, 10, bias=True)
        )

        # Initialize layers
        self.apply(self._initialize_weights)

    def _get_norm_layer(self, channels):
        if USE_BATCHNORM:
            return nn.BatchNorm2d(channels)
        else:
            return nn.Identity()

    def _initialize_weights(self, m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
            init.kaiming_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        
    def forward(self, x):
        x = self.features1(x)
        x = self.features2(x)
        x = self.features3(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
# Instantiate the model
model = LeNet()
model = model.to(device)

# Test with a random input
summary(model, input_size=(1, 32, 32), batch_size=BATCH_SIZE)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [512, 6, 28, 28]             150
       BatchNorm2d-2           [512, 6, 28, 28]              12
              ReLU-3           [512, 6, 28, 28]               0
         MaxPool2d-4           [512, 6, 14, 14]               0
         Dropout2d-5           [512, 6, 14, 14]               0
            Conv2d-6          [512, 16, 10, 10]           2,400
       BatchNorm2d-7          [512, 16, 10, 10]              32
              ReLU-8          [512, 16, 10, 10]               0
         MaxPool2d-9            [512, 16, 5, 5]               0
        Dropout2d-10            [512, 16, 5, 5]               0
           Conv2d-11           [512, 120, 1, 1]          48,000
      BatchNorm2d-12           [512, 120, 1, 1]             240
             ReLU-13           [512, 120, 1, 1]               0
        Dropout2d-14           [512, 12

#### 1. Cross Entropy Loss
The standard cross-entropy loss for classification tasks is given by:

$$ \text{Traditional Cross Entropy Loss: } H(y, \hat{y}) = - \sum_i y_i \log(\hat{y}_i) $$

 - $y_i$ is a binary indicator of whether class $i$ is the correct classification.  
 - $p_i$ is the predicted probability of class $i$.

#### 2. Label Smoothed Cross Entropy
Label Smoothing Cross Entropy Loss introduces a modification to the target distribution:

 $$ \text{Label Smoothed Cross Entropy Loss} = - \sum_i \left( (1 - \text{smoothing}) \cdot 1_{\{y_i\}} + \frac{\text{smoothing}}{C-1} \cdot 1_{\{1 - y_i\}} \right) \cdot \log(p_i) $$

Where:
- $1_{\{y_i\}}$ is a binary indicator of whether class $i$ is the correct classification.
- $p_i$ is the predicted probability of class $i$.
- $C$ is the number of classes.
- $\text{smoothing}$  is the smoothing factor.



In [13]:
class LabelSmoothedCrossEntropy(nn.Module):
    def __init__(self, num_classes, smoothing):
        super(LabelSmoothedCrossEntropy, self).__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, input_logits, target):
        log_probs = F.log_softmax(input_logits, dim=-1)
        true_dist = torch.zeros_like(log_probs)
        true_dist.fill_(self.smoothing / (self.num_classes - 1))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        loss = -torch.sum(true_dist * log_probs) / input_logits.size(0)
        return loss

In [68]:
# Initialize loss
# criterion = nn.CrossEntropyLoss()
criterion = LabelSmoothedCrossEntropy(num_classes=N_CLASSES, smoothing=LABEL_SMOOTHING)

# Initialize optimizer
optimizer = optim.Adam(model.parameters(), lr=LR)

#### Cosine Annealing With Warmup

In [69]:
def get_lr_lambda(initial_lr, warmup_steps, total_steps):
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return initial_lr + (1.0 - initial_lr) * float(current_step) / float(max(1, warmup_steps))
        else:
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * (current_step - warmup_steps) / float(total_steps - warmup_steps))))
    return lr_lambda

In [70]:
len_dataloader = len(train_dataloader)
num_update_steps_per_epoch = max(len_dataloader, 1) 
num_examples = len(train_dataloader.dataset)
max_steps = math.ceil(EPOCHS * num_update_steps_per_epoch)
num_warmup_steps = math.ceil(max_steps * WARMUP_PROPORTION)
lr_lambda = get_lr_lambda(LR, num_warmup_steps, max_steps)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

#### LR Finder

In [71]:
class TrainDataLoaderIter:
    """
    Iterator for a PyTorch DataLoader, facilitating easy iteration over batches of data.

    Args:
        dataloader (torch.utils.data.DataLoader): The PyTorch DataLoader providing batches of data.
        auto_reset (bool, optional): If True, the iterator resets to the beginning of the DataLoader
            when it reaches the end. If False, a StopIteration exception is raised at the end.
            Default is True.

    Methods:
        inputs_labels_from_batch(batch_data):
            Extracts inputs and labels from a batch of data.

        __next__():
            Gets the next batch of data from the DataLoader, automatically resetting if necessary.

        __iter__():
            Returns the iterator object.

    Usage:
        # Create a TrainDataLoaderIter object
        train_iter = TrainDataLoaderIter(train_dataloader)

        # Iterate over batches of data
        for inputs, labels in train_iter:
            # Perform training using the current batch

    Example with custom inputs_labels_from_batch:
        class CustomTrainDataLoaderIter(TrainDataLoaderIter):
            def inputs_labels_from_batch(self, batch_data):
                # Custom logic to extract inputs and labels from the batch_data
                inputs, labels = batch_data['data'], batch_data['label']
                return inputs, labels
    """
    def __init__(self, dataloader, auto_reset=True):
        """
        Initialize the TrainDataLoaderIter.

        Args:
            dataloader (torch.utils.data.DataLoader): The PyTorch DataLoader providing batches of data.
            auto_reset (bool, optional): If True, the iterator resets to the beginning of the DataLoader
                when it reaches the end. If False, a StopIteration exception is raised at the end.
                Default is True.
        """
        self.dataloader = dataloader
        self._iterator = iter(dataloader)
        self.auto_reset = auto_reset

    def inputs_labels_from_batch(self, batch_data):
        """
        Extract inputs and labels from a batch of data.

        Args:
            batch_data: Batch of data returned by the DataLoader.

        Returns:
            Tuple: A tuple containing inputs and labels extracted from the batch_data.
        """
        inputs, labels = batch_data
        return inputs, labels

    def __next__(self):
        """
        Get the next batch of data from the DataLoader, automatically resetting if necessary.

        Returns:
            Tuple: A tuple containing inputs and labels for the next batch of data.

        Raises:
            StopIteration: If auto_reset is False and the end of the DataLoader is reached.
        """
        try:
            batch = next(self._iterator)
            inputs, labels = self.inputs_labels_from_batch(batch)
        except StopIteration:
            if self.auto_reset:
                self._iterator = iter(self.dataloader)
                batch = next(self._iterator)
                inputs, labels = self.inputs_labels_from_batch(batch)
            else:
                raise

        return inputs, labels

    def __iter__(self):
        """
        Returns the iterator object.

        Returns:
            self: The iterator object.
        """
        return self

In [72]:
class LinearLR(torch.optim.lr_scheduler._LRScheduler):
    """
    Linear learning rate scheduler.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer to adjust the learning rate for.
        end_lr (float): The final learning rate after the specified number of iterations.
        num_iter (int): The total number of iterations for the learning rate schedule.
        last_epoch (int, optional): The index of the last epoch. Default is -1.

    Methods:
        get_lr():
            Calculate the linearly adjusted learning rates based on the current iteration.

    Usage:
        # Create a LinearLR scheduler
        linear_scheduler = LinearLR(optimizer, end_lr=0.001, num_iter=100)

        # Inside the training loop
        for epoch in range(num_epochs):
            for batch in data_loader:
                # Perform training using the current batch

                # Update the learning rate
                linear_scheduler.step()

    Note: Ensure that the scheduler is called within the training loop after each training step.

    Example:
        linear_scheduler = LinearLR(optimizer, end_lr=0.001, num_iter=100)
        for epoch in range(num_epochs):
            for batch in data_loader:
                # Perform training using the current batch

                # Update the learning rate
                linear_scheduler.step()
    """

    def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
        """
        Initialize the LinearLR scheduler.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to adjust the learning rate for.
            end_lr (float): The final learning rate after the specified number of iterations.
            num_iter (int): The total number of iterations for the learning rate schedule.
            last_epoch (int, optional): The index of the last epoch. Default is -1.
        """
        self.end_lr = end_lr
        self.num_iter = num_iter
        super(LinearLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """
        Calculate the linearly adjusted learning rates based on the current iteration.

        Returns:
            list: List of adjusted learning rates for each parameter group.
        """
        r = self.last_epoch / (self.num_iter - 1)
        return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]

    
class ExponentialLR(torch.optim.lr_scheduler._LRScheduler):
    """
    Exponential learning rate scheduler.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer to adjust the learning rate for.
        end_lr (float): The final learning rate after the specified number of iterations.
        num_iter (int): The total number of iterations for the learning rate schedule.
        last_epoch (int, optional): The index of the last epoch. Default is -1.

    Methods:
        get_lr():
            Calculate the exponentially adjusted learning rates based on the current iteration.

    Usage:
        # Create an ExponentialLR scheduler
        exponential_scheduler = ExponentialLR(optimizer, end_lr=0.001, num_iter=100)

        # Inside the training loop
        for epoch in range(num_epochs):
            for batch in data_loader:
                # Perform training using the current batch

                # Update the learning rate
                exponential_scheduler.step()

    Note: Ensure that the scheduler is called within the training loop after each training step.

    Example:
        exponential_scheduler = ExponentialLR(optimizer, end_lr=0.001, num_iter=100)
        for epoch in range(num_epochs):
            for batch in data_loader:
                # Perform training using the current batch

                # Update the learning rate
                exponential_scheduler.step()
    """

    def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
        """
        Initialize the ExponentialLR scheduler.

        Args:
            optimizer (torch.optim.Optimizer): Optimizer to adjust the learning rate for.
            end_lr (float): The final learning rate after the specified number of iterations.
            num_iter (int): The total number of iterations for the learning rate schedule.
            last_epoch (int, optional): The index of the last epoch. Default is -1.
        """
        self.end_lr = end_lr
        self.num_iter = num_iter
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """
        Calculate the exponentially adjusted learning rates based on the current iteration.

        Returns:
            list: List of adjusted learning rates for each parameter group.
        """
        r = self.last_epoch / (self.num_iter - 1)
        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]


In [73]:
class LearningRateFinder:
    """
    Learning rate finder class for determining an optimal learning rate during model training.

    Args:
        model (torch.nn.Module): The neural network model to train.
        criterion: The loss function used for training.
        optimizer: The optimizer used for updating model parameters.
        device (torch.device): The device (CPU or GPU) on which to perform training.

    Attributes:
        model (torch.nn.Module): The neural network model to train.
        criterion: The loss function used for training.
        optimizer: The optimizer used for updating model parameters.
        device (torch.device): The device (CPU or GPU) on which to perform training.
        history (dict): A dictionary to store the learning rate and loss history during search.
            Keys: 'lr' for learning rates, 'loss' for corresponding losses.
        best_loss: The best observed loss during the learning rate search.

    Methods:
        find_lr(train_dataloader, start_lr=None, end_lr=10, num_iter=100, step_mode="exp",
                smooth_f=0.05, diverge_th=5, early_stop_patience=None):
            Search for the optimal learning rate within the specified range.

        _train_batch(train_iter):
            Train a single batch and update the model parameters.

        set_learning_rate(start_lrs):
            Set the learning rate(s) for the optimizer.

        plot(skip_start=10, skip_end=5, log_lr=True, show_lr=None):
            Plot the learning rate search results and suggest an optimal learning rate.

    Usage:
        # Create a LearningRateFinder instance
        lr_finder = LearningRateFinder(model, criterion, optimizer, device)

        # Find the optimal learning rate
        lr_finder.find_lr(train_dataloader, start_lr=1e-5, end_lr=10, num_iter=100)

        # Plot the learning rate search results
        lr_finder.plot()
    """
    def __init__(self, model, criterion, optimizer, device):
        """
        Initialize the LearningRateFinder.

        Args:
            model (torch.nn.Module): The neural network model to train.
            criterion: The loss function used for training.
            optimizer: The optimizer used for updating model parameters.
            device (torch.device): The device (CPU or GPU) on which to perform training.
        """
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.history = {"lr": [], "loss": []}
        self.best_loss = None
        
        self.model.to(device)
        
    def find_lr(
        self, train_dataloader, 
        start_lr=None, end_lr=10, 
        num_iter=100, step_mode="exp", 
        smooth_f=0.05, diverge_th=5,
        early_stop_patience=None
    ):
        """
        Search for the optimal learning rate within the specified range.

        Args:
            train_dataloader: The dataloader providing training batches.
            start_lr (float or List[float], optional): The starting learning rate(s).
                If None, uses the learning rate(s) from the optimizer. Default is None.
            end_lr (float): The ending learning rate for the search.
            num_iter (int): The number of iterations for the search.
            step_mode (str, optional): The mode for adjusting learning rates ("exp" or "linear").
                Default is "exp".
            smooth_f (float, optional): The smoothing factor for loss values. Default is 0.05.
            diverge_th (float, optional): The threshold for loss divergence. Default is 5.
            early_stop_patience (int, optional): The patience for early stopping. Default is None.

        Returns:
            None
        """
        # Initialize variables to track learning rate search history
        self.history = {"lr": [], "loss": []}
        self.best_loss = None # Best loss during the learning rate search
        early_stop_counter = 0 # Counter for early stopping criteria

        # Set initial learning rate if provided
        if start_lr:
            self.set_learning_rate(start_lr)

        # Choose a learning rate scheduler based on the specified step_mode
        if step_mode == "exp":
            lr_schedule = ExponentialLR(self.optimizer, end_lr=end_lr, num_iter=num_iter) 
        elif step_mode == "linear":
            lr_schedule = LinearLR(self.optimizer, end_lr=end_lr, num_iter=num_iter)

        # Create an iterator for the training dataloader
        train_iter = TrainDataLoaderIter(train_dataloader)

        # Iterate over the specified number of iterations to find the optimal learning rate
        for iteration in tqdm(range(num_iter)):
            # Train a batch and obtain the loss
            loss = self._train_batch(
                train_iter
            )
            # Record the current learning rate in the history
            self.history["lr"].append(lr_schedule.get_lr()[0])
            lr_schedule.step() # Update the learning rate according to the scheduler

            # Update the best_loss if this is the first iteration, or if a lower loss is encountered
            if iteration == 0:
                self.best_loss = loss
            else:
                # Apply smoothing to the loss if specified
                if smooth_f > 0:
                    loss = smooth_f * loss + (1 - smooth_f) * self.history["loss"][-1]

                # Update best_loss and early stopping counter if a lower loss is not encountered
                if loss < self.best_loss:
                    self.best_loss = loss
                else:
                    # Check for early stopping based on patience
                    early_stop_counter += 1
                    if early_stop_patience and early_stop_counter > early_stop_patience:
                        print(f"Early stopping: Loss has not improved for {early_stop_patience} iterations.")
                        break

            # Record the current loss in the history
            self.history["loss"].append(loss)
            # Check for divergence in loss and stop the search if criteria are met
            if loss > diverge_th * self.best_loss:
                print("Stopping early, the loss has diverged.")
                break

        # Print a message indicating the completion of the learning rate search
        print("Learning rate search finished.")
        
    def _train_batch(self, train_iter):
        """
        Train a single batch and update the model parameters.

        Args:
            train_iter: Iterator providing training batches.

        Returns:
            float: The loss value for the current batch.
        """
        # Set the model in training mode and zero the gradients
        self.model.train()
        self.optimizer.zero_grad()

        # Retrieve inputs and labels for the current batch
        inputs, labels = next(train_iter)
        inputs, labels = inputs.to(self.device), labels.to(self.device)

        # Forward pass: compute model predictions and calculate the loss
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)

        # Backward pass: compute gradients and update model parameters
        loss.backward()
        self.optimizer.step()
        
        # Return the loss value as a float
        return loss.item()
    
    def set_learning_rate(self, start_lrs):
        """
        Set the learning rate(s) for the optimizer.

        Args:
            start_lrs (float or List[float]): The initial learning rate(s) to be set.

        Returns:
            None
        """
        # Ensure start_lrs is a list to handle single and multiple parameter groups
        if not isinstance(start_lrs, list):
            start_lrs = [start_lrs] * len(self.optimizer.param_groups)

        # Update the learning rate for each parameter group
        for param_group, start_lr in zip(self.optimizer.param_groups, start_lrs):
            param_group['lr'] = start_lr
            
    def plot(self, skip_start = 10, skip_end=5, log_lr=True, show_lr=None):
        """
        Plot the learning rate search results and suggest an optimal learning rate.

        Args:
            skip_start (int, optional): Number of initial iterations to skip in the plot. Default is 10.
            skip_end (int, optional): Number of final iterations to skip in the plot. Default is 5.
            log_lr (bool, optional): Whether to use a logarithmic scale for the learning rate axis.
                Default is True.
            show_lr (float, optional): The learning rate to highlight on the plot.

        Returns:
            tuple: A tuple containing the plot axis and the suggested optimal learning rate.
        """
        lrs = self.history["lr"]
        losses = self.history["loss"]
        if skip_end == 0:
            lrs = lrs[skip_start:]
            losses = losses[skip_start:]
        else:
            lrs = lrs[skip_start:-skip_end]
            losses = losses[skip_start:-skip_end]
            
        fig, ax = plt.subplots()    
        ax.plot(lrs, losses)
        
        print("LR suggestion: steepest gradient")
        min_grad_idx = (np.gradient(np.array(losses))).argmin()
        print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
        ax.scatter(
            lrs[min_grad_idx],
            losses[min_grad_idx],
            s=75,
            marker="o",
            color="red",
            zorder=3,
            label="steepest gradient",
        )
        ax.legend()
        
        if log_lr:
            ax.set_xscale("log")
        ax.set_xlabel("Learning rate")
        ax.set_ylabel("Loss")
        
        if show_lr is not None:
            ax.axvline(x=show_lr, color="red")
            
        plt.show()
            
        return ax, lrs[min_grad_idx]

In [74]:
# lr_finder = LearningRateFinder(model, criterion, optimizer, device)
# lr_finder.find_lr(train_dataloader, num_iter=500, start_lr=1e-3)
# lr_finder.plot()

#### Train

In [75]:
def train(epoch):
    model.train()
    total_loss = 0.0

    with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}/Training", unit="batch") as pbar:
        for i, (images, labels) in enumerate(train_dataloader):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            total_loss += loss.item()
            pbar.update(1)
            pbar.set_postfix(loss=total_loss / (i+1), lr=f"{optimizer.param_groups[0]['lr']:.5f}")

    avg_loss = total_loss / len(train_dataloader.dataset)
    print(f"[Train][Epoch {epoch}] Average Loss: {avg_loss:.5f}, Updated Learning Rate: {optimizer.param_groups[0]['lr']}")

#### Eval

In [76]:
def test(epoch):
    model.eval()
    total_correct = 0
    total_loss = 0.0

    with tqdm(total=len(test_dataloader), desc=f"Epoch {epoch}/Testing", unit="batch") as pbar: 
        with torch.no_grad():
            for i, (images, labels) in enumerate(test_dataloader):
                images, labels = images.to(device), labels.to(device)
                
                output = model(images)                
                loss = criterion(output, labels)

                total_loss += loss.item()
                pbar.update(1)
                pbar.set_postfix(loss=total_loss / (i+1))
                
                pred = output.detach().max(1)[1]
                total_correct += pred.eq(labels.view_as(pred)).sum()
            
    avg_loss = total_loss / len(test_dataloader.dataset)
    accuracy = total_correct / len(test_dataloader.dataset)
    
    print(f"[Test][Epoch {epoch}] Loss: {avg_loss:.5f}, Accuracy: {accuracy:3f}")
    return avg_loss

#### Run

In [77]:
def run():
    for epoch in range(EPOCHS):
        train(epoch)
        val_loss = test(epoch)

In [78]:
run()

Epoch 0/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 0] Average Loss: 0.00513, Updated Learning Rate: 0.00026943648


Epoch 0/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 0] Loss: 0.00363, Accuracy: 0.492800


Epoch 1/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 1] Average Loss: 0.00400, Updated Learning Rate: 0.00053707736


Epoch 1/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 1] Loss: 0.00279, Accuracy: 0.698400


Epoch 2/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 2] Average Loss: 0.00337, Updated Learning Rate: 0.0008047182400000001


Epoch 2/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 2] Loss: 0.00212, Accuracy: 0.835500


Epoch 3/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 3] Average Loss: 0.00284, Updated Learning Rate: 0.0010723591200000001


Epoch 3/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 3] Loss: 0.00175, Accuracy: 0.902600


Epoch 4/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 4] Average Loss: 0.00250, Updated Learning Rate: 0.00134


Epoch 4/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 4] Loss: 0.00161, Accuracy: 0.923100


Epoch 5/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 5] Average Loss: 0.00229, Updated Learning Rate: 0.0013383679136740822


Epoch 5/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 5] Loss: 0.00153, Accuracy: 0.937000


Epoch 6/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 6] Average Loss: 0.00215, Updated Learning Rate: 0.001333479606056852


Epoch 6/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 6] Loss: 0.00149, Accuracy: 0.943900


Epoch 7/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 7] Average Loss: 0.00206, Updated Learning Rate: 0.0013253588924916498


Epoch 7/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 7] Loss: 0.00145, Accuracy: 0.950800


Epoch 8/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 8] Average Loss: 0.00200, Updated Learning Rate: 0.0013140453362786738


Epoch 8/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 8] Loss: 0.00144, Accuracy: 0.949500


Epoch 9/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 9] Average Loss: 0.00194, Updated Learning Rate: 0.0012995940559265588


Epoch 9/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 9] Loss: 0.00141, Accuracy: 0.955900


Epoch 10/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 10] Average Loss: 0.00190, Updated Learning Rate: 0.0012820754566205425


Epoch 10/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 10] Loss: 0.00140, Accuracy: 0.955900


Epoch 11/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 11] Average Loss: 0.00184, Updated Learning Rate: 0.001261574887215481


Epoch 11/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 11] Loss: 0.00138, Accuracy: 0.959300


Epoch 12/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 12] Average Loss: 0.00182, Updated Learning Rate: 0.0012381922244248053


Epoch 12/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 12] Loss: 0.00136, Accuracy: 0.960200


Epoch 13/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 13] Average Loss: 0.00179, Updated Learning Rate: 0.0012120413862312148


Epoch 13/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 13] Loss: 0.00136, Accuracy: 0.961000


Epoch 14/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 14] Average Loss: 0.00176, Updated Learning Rate: 0.0011832497768897153


Epoch 14/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 14] Loss: 0.00133, Accuracy: 0.965500


Epoch 15/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 15] Average Loss: 0.00174, Updated Learning Rate: 0.0011519576662268962


Epoch 15/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 15] Loss: 0.00133, Accuracy: 0.967300


Epoch 16/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 16] Average Loss: 0.00172, Updated Learning Rate: 0.001118317506260435


Epoch 16/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 16] Loss: 0.00132, Accuracy: 0.965400


Epoch 17/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 17] Average Loss: 0.00170, Updated Learning Rate: 0.001082493188468191


Epoch 17/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 17] Loss: 0.00132, Accuracy: 0.967300


Epoch 18/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 18] Average Loss: 0.00169, Updated Learning Rate: 0.0010446592453254005


Epoch 18/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 18] Loss: 0.00131, Accuracy: 0.968600


Epoch 19/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 19] Average Loss: 0.00168, Updated Learning Rate: 0.001005


Epoch 19/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 19] Loss: 0.00130, Accuracy: 0.967900


Epoch 20/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 20] Average Loss: 0.00166, Updated Learning Rate: 0.0009637086683486819


Epoch 20/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 20] Loss: 0.00129, Accuracy: 0.970600


Epoch 21/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 21] Average Loss: 0.00165, Updated Learning Rate: 0.0009209864175886612


Epoch 21/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 21] Loss: 0.00129, Accuracy: 0.971200


Epoch 22/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 22] Average Loss: 0.00164, Updated Learning Rate: 0.0008770413862312148


Epoch 22/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 22] Loss: 0.00128, Accuracy: 0.972900


Epoch 23/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 23] Average Loss: 0.00163, Updated Learning Rate: 0.0008320876700517773


Epoch 23/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 23] Loss: 0.00129, Accuracy: 0.969000


Epoch 24/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 24] Average Loss: 0.00162, Updated Learning Rate: 0.0007863442790368435


Epoch 24/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 24] Loss: 0.00128, Accuracy: 0.972100


Epoch 25/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 25] Average Loss: 0.00162, Updated Learning Rate: 0.0007400340703893279


Epoch 25/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 25] Loss: 0.00128, Accuracy: 0.973400


Epoch 26/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 26] Average Loss: 0.00161, Updated Learning Rate: 0.0006933826627906758


Epoch 26/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 26] Loss: 0.00128, Accuracy: 0.973100


Epoch 27/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 27] Average Loss: 0.00160, Updated Learning Rate: 0.0006466173372093244


Epoch 27/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 27] Loss: 0.00127, Accuracy: 0.973100


Epoch 28/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 28] Average Loss: 0.00160, Updated Learning Rate: 0.0005999659296106721


Epoch 28/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 28] Loss: 0.00126, Accuracy: 0.975400


Epoch 29/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 29] Average Loss: 0.00158, Updated Learning Rate: 0.0005536557209631567


Epoch 29/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 29] Loss: 0.00126, Accuracy: 0.974200


Epoch 30/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 30] Average Loss: 0.00159, Updated Learning Rate: 0.0005079123299482226


Epoch 30/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 30] Loss: 0.00126, Accuracy: 0.974000


Epoch 31/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 31] Average Loss: 0.00158, Updated Learning Rate: 0.0004629586137687853


Epoch 31/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 31] Loss: 0.00127, Accuracy: 0.973000


Epoch 32/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 32] Average Loss: 0.00157, Updated Learning Rate: 0.0004190135824113389


Epoch 32/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 32] Loss: 0.00126, Accuracy: 0.973700


Epoch 33/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 33] Average Loss: 0.00157, Updated Learning Rate: 0.0003762913316513183


Epoch 33/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 33] Loss: 0.00126, Accuracy: 0.975500


Epoch 34/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 34] Average Loss: 0.00157, Updated Learning Rate: 0.0003350000000000002


Epoch 34/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 34] Loss: 0.00126, Accuracy: 0.973400


Epoch 35/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 35] Average Loss: 0.00156, Updated Learning Rate: 0.0002953407546745997


Epoch 35/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 35] Loss: 0.00126, Accuracy: 0.973700


Epoch 36/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 36] Average Loss: 0.00156, Updated Learning Rate: 0.00025750681153180894


Epoch 36/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 36] Loss: 0.00125, Accuracy: 0.974600


Epoch 37/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 37] Average Loss: 0.00155, Updated Learning Rate: 0.00022168249373956498


Epoch 37/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 37] Loss: 0.00125, Accuracy: 0.975300


Epoch 38/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 38] Average Loss: 0.00156, Updated Learning Rate: 0.00018804233377310385


Epoch 38/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 38] Loss: 0.00125, Accuracy: 0.975100


Epoch 39/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 39] Average Loss: 0.00155, Updated Learning Rate: 0.00015675022311028482


Epoch 39/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 39] Loss: 0.00126, Accuracy: 0.974600


Epoch 40/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 40] Average Loss: 0.00155, Updated Learning Rate: 0.0001279586137687853


Epoch 40/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 40] Loss: 0.00125, Accuracy: 0.975300


Epoch 41/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 41] Average Loss: 0.00155, Updated Learning Rate: 0.00010180777557519461


Epoch 41/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 41] Loss: 0.00125, Accuracy: 0.974900


Epoch 42/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 42] Average Loss: 0.00155, Updated Learning Rate: 7.842511278451892e-05


Epoch 42/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 42] Loss: 0.00125, Accuracy: 0.975100


Epoch 43/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 43] Average Loss: 0.00155, Updated Learning Rate: 5.7924543379457494e-05


Epoch 43/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 43] Loss: 0.00125, Accuracy: 0.975700


Epoch 44/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 44] Average Loss: 0.00155, Updated Learning Rate: 4.040594407344143e-05


Epoch 44/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 44] Loss: 0.00125, Accuracy: 0.975500


Epoch 45/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 45] Average Loss: 0.00155, Updated Learning Rate: 2.5954663721326342e-05


Epoch 45/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 45] Loss: 0.00125, Accuracy: 0.975700


Epoch 46/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 46] Average Loss: 0.00154, Updated Learning Rate: 1.464110750835019e-05


Epoch 46/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 46] Loss: 0.00125, Accuracy: 0.976000


Epoch 47/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 47] Average Loss: 0.00155, Updated Learning Rate: 6.520393943147932e-06


Epoch 47/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 47] Loss: 0.00125, Accuracy: 0.975700


Epoch 48/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 48] Average Loss: 0.00155, Updated Learning Rate: 1.6320863259177876e-06


Epoch 48/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 48] Loss: 0.00125, Accuracy: 0.976300


Epoch 49/Training:   0%|          | 0/118 [00:00<?, ?batch/s]

[Train][Epoch 49] Average Loss: 0.00155, Updated Learning Rate: 0.0


Epoch 49/Testing:   0%|          | 0/20 [00:00<?, ?batch/s]

[Test][Epoch 49] Loss: 0.00125, Accuracy: 0.975600
