In [2]:
import os

In [3]:
# Move the working directory one level up from wherever the kernel currently is.
# NOTE(review): presumably this is so the `src` package imported below resolves
# from the project root — confirm against the repository layout.
# The path is relative, so re-running this cell moves up another level each time;
# run it once per kernel session.
os.chdir("../")

In [5]:
import tempfile
import torch
import numpy as np
from livelossplot import PlotLosses
from livelossplot.outputs import MatplotlibPlot
from tqdm import tqdm
from src.helper import after_subplot

In [6]:
def train_one_epoch(train_dataloader, model, optimizer, loss):
    """
    Run one optimisation pass over the training data.

    For every batch the gradients are reset, the model is evaluated, the loss
    is back-propagated and the optimizer takes one step.

    Args:
        train_dataloader: sized iterable (it is passed to both ``enumerate``
            and ``len``) yielding ``(data, target)`` pairs.
        model: the network to train; it is switched to training mode and,
            when CUDA is available, moved to the GPU.
        optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
        loss: callable ``loss(output, target)`` returning a scalar tensor.

    Returns:
        float: running mean of the per-batch loss values, or ``0.0`` when the
        dataloader yields no batches.
    """
    gpu_available = torch.cuda.is_available()

    if gpu_available:
        # Transfer the model to the GPU
        model = model.cuda()

    # Enable training-time behaviour of the model
    model.train()

    running_loss = 0.0
    progress = tqdm(
        enumerate(train_dataloader),
        desc="Training",
        total=len(train_dataloader),
        leave=True,
        ncols=80,
    )

    for batch_idx, (data, target) in progress:
        if gpu_available:
            # Inputs must live on the same device as the model
            data = data.cuda()
            target = target.cuda()

        # Gradients accumulate by default, so reset them before each batch
        optimizer.zero_grad()

        # Forward pass and loss
        batch_loss = loss(model(data), target)

        # Backward pass followed by the parameter update
        batch_loss.backward()
        optimizer.step()

        # Incremental mean: new_mean = old_mean + (x - old_mean) / n
        weight = 1 / (batch_idx + 1)
        running_loss = running_loss + (weight * (batch_loss.item() - running_loss))

    return running_loss

In [7]:
def valid_one_epoch(valid_dataloader, model, loss):
    """
    Measure the mean loss over the validation data without updating the model.

    Args:
        valid_dataloader: sized iterable (it is passed to both ``enumerate``
            and ``len``) yielding ``(data, target)`` pairs.
        model: the network to evaluate; it is switched to evaluation mode and,
            when CUDA is available, moved to the GPU.
        loss: callable ``loss(output, target)`` returning a scalar tensor.

    Returns:
        float: running mean of the per-batch loss values, or ``0.0`` when the
        dataloader yields no batches.
    """
    gpu_available = torch.cuda.is_available()

    # Gradient tracking is disabled for the whole pass
    with torch.no_grad():
        # Enable inference-time behaviour of the model
        model.eval()

        if gpu_available:
            model.cuda()

        running_loss = 0.0
        progress = tqdm(
            enumerate(valid_dataloader),
            desc="Validating",
            total=len(valid_dataloader),
            leave=True,
            ncols=80,
        )

        for batch_idx, (data, target) in progress:
            if gpu_available:
                # Inputs must live on the same device as the model
                data = data.cuda()
                target = target.cuda()

            # Forward pass and loss
            batch_loss = loss(model(data), target)

            # Incremental mean: new_mean = old_mean + (x - old_mean) / n
            weight = 1 / (batch_idx + 1)
            running_loss = running_loss + (weight * (batch_loss.item() - running_loss))

    return running_loss

In [8]:
def optimize(
    data_loaders,
    model,
    optimizer,
    loss,
    n_epochs,
    save_path,
    interactive_tracking=False,
    early_stop_threshold=0.5,
    min_rel_improvement=0.01,
    scheduler_patience=10,
):
    """
    Train ``model`` for up to ``n_epochs`` epochs, checkpointing the best weights.

    Each epoch runs ``train_one_epoch`` on ``data_loaders["train"]`` and
    ``valid_one_epoch`` on ``data_loaders["valid"]``, prints both losses, and
    then:

    * stops early (after saving the weights) when the validation loss drops
      below ``early_stop_threshold``;
    * otherwise saves the weights when the validation loss improved on the best
      value so far by more than ``min_rel_improvement`` (relative);
    * steps a ``ReduceLROnPlateau`` scheduler with the validation loss.

    Args:
        data_loaders: mapping with at least the keys ``"train"`` and ``"valid"``.
        model: network to train; its ``state_dict()`` is what gets saved.
        optimizer: optimizer driving the updates; also wrapped by the scheduler.
        loss: loss callable forwarded to the per-epoch helpers.
        n_epochs: maximum number of epochs to run.
        save_path: destination passed to ``torch.save`` for the checkpoint.
        interactive_tracking: when true, losses and the learning rate of the
            first parameter group are sent to a ``livelossplot`` plot after
            every epoch (including the epoch that triggers early stopping).
        early_stop_threshold: validation loss below which training stops.
            ``None`` disables early stopping. Defaults to ``0.5``.
        min_rel_improvement: minimum relative decrease of the validation loss
            (with respect to the best value so far) required to overwrite the
            checkpoint. Defaults to ``0.01`` (1%).
        scheduler_patience: ``patience`` given to ``ReduceLROnPlateau``.
            Defaults to ``10``.

    Returns:
        None. The outputs are the checkpoint written to ``save_path`` and the
        printed/plotted progress.
    """
    # Optional live plot of the losses
    if interactive_tracking:
        liveloss = PlotLosses(outputs=[MatplotlibPlot(after_subplot=after_subplot)])
    else:
        liveloss = None

    # Best validation loss seen so far (None until the first epoch completes)
    valid_loss_min = None

    # Reduce the learning rate when the validation loss reaches a plateau
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=scheduler_patience
    )

    for epoch in range(1, n_epochs + 1):

        # Training phase
        train_loss = train_one_epoch(data_loaders["train"], model, optimizer, loss)

        # Validation phase
        valid_loss = valid_one_epoch(data_loaders["valid"], model, loss)

        # Print training/validation statistics
        print(
            "Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}".format(
                epoch, train_loss, valid_loss
            )
        )

        stop_early = (
            early_stop_threshold is not None and valid_loss < early_stop_threshold
        )

        if stop_early:
            # Good enough: keep these weights and finish after logging below
            torch.save(model.state_dict(), save_path)
            print(
                f"Early stopping as validation loss is less than {early_stop_threshold}"
            )
        else:
            # Division-free form of (min - loss) / min > margin, so a best loss
            # of exactly 0 cannot raise and a negative one does not flip the test.
            improved = valid_loss_min is None or (
                (valid_loss_min - valid_loss)
                > min_rel_improvement * abs(valid_loss_min)
            )
            if improved:
                print(f"New minimum validation loss: {valid_loss:.6f}. Saving model ...")
                torch.save(model.state_dict(), save_path)
                valid_loss_min = valid_loss

            # Update learning rate using the scheduler
            scheduler.step(valid_loss)

        # Log the losses and the current learning rate. This runs before the
        # early-stop exit so the final epoch also appears in the plot.
        if interactive_tracking:
            logs = {
                "loss": train_loss,
                "val_loss": valid_loss,
                "lr": optimizer.param_groups[0]["lr"],
            }
            liveloss.update(logs)
            liveloss.send()

        if stop_early:
            break



In [10]:
def one_epoch_test(test_dataloader, model, loss):
    """
    Evaluate ``model`` on the test data and print the loss and accuracy.

    The predicted class of each sample is the index of the largest value along
    dimension 1 of the model output.

    Args:
        test_dataloader: sized iterable (it is passed to both ``enumerate``
            and ``len``) yielding ``(data, target)`` pairs.
        model: the network to evaluate; it is switched to evaluation mode and,
            when CUDA is available, moved to the GPU.
        loss: callable ``loss(logits, target)`` returning a scalar tensor.

    Returns:
        float: running mean of the per-batch loss values, or ``0.0`` when the
        dataloader yields no batches. In that case the accuracy is reported
        as 0% instead of raising ``ZeroDivisionError``.
    """
    # Monitor test loss and accuracy. The counters are plain ints so the
    # summary prints "(correct/total)" as whole numbers.
    test_loss = 0.0
    correct = 0
    total = 0

    gpu_available = torch.cuda.is_available()

    with torch.no_grad():
        # Set the model to evaluation mode
        model.eval()

        if gpu_available:
            model = model.cuda()

        for batch_idx, (data, target) in tqdm(
            enumerate(test_dataloader),
            desc='Testing',
            total=len(test_dataloader),
            leave=True,
            ncols=80
        ):
            # Move data to GPU
            if gpu_available:
                data, target = data.cuda(), target.cuda()

            # 1. Forward pass: compute predicted outputs by passing inputs to the model
            logits = model(data)

            # 2. Calculate the loss
            loss_value = loss(logits, target)

            # Update average test loss (incremental mean)
            test_loss = test_loss + ((1 / (batch_idx + 1)) * (loss_value.item() - test_loss))

            # Convert logits to predicted class (index of the max along dim 1)
            pred = logits.max(1, keepdim=True)[1]

            # Compare predictions to true label
            correct += int(pred.eq(target.view_as(pred)).sum().item())
            total += int(data.size(0))

    # Guard against an empty dataloader
    accuracy = 100. * correct / total if total else 0.0

    print(f'Test Loss: {test_loss:.6f}\n')
    print(f'\nTest Accuracy: {accuracy:.2f}% ({correct}/{total})')

    return test_loss
