## Boilerplate code

In [None]:
# Functions for capturing elapsed wall-clock time and peak GPU memory
import time, gc

import torch

# Timing utilities
start_time = None  # set by start_timer(), read by end_timer_and_print()


def _on_cuda():
    """Return True when the notebook-level ``device`` refers to a CUDA device.

    Accepts both strings ("cuda", "cuda:0") and ``torch.device`` objects.
    NOTE(review): ``device`` is not defined in this cell; it is expected to be
    set in another cell before the timers are used -- confirm.
    """
    return str(device).startswith("cuda")


def start_timer():
    """Start the wall-clock timer.

    Runs the garbage collector first and, on CUDA, frees cached blocks, resets
    the peak-memory counters and waits for queued kernels so that neither
    earlier work nor earlier allocations are attributed to the timed region.
    """
    global start_time
    gc.collect()
    if _on_cuda():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is deprecated; this resets the same peak counter.
        torch.cuda.reset_peak_memory_stats()
        # Kernels launch asynchronously; wait so earlier work is not timed.
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()


def end_timer_and_print(local_msg):
    """Stop the timer started by start_timer() and print a short report.

    Args:
        local_msg: Heading printed before the timing figures.

    Raises:
        RuntimeError: if start_timer() has not been called yet.
    """
    if start_time is None:
        raise RuntimeError("start_timer() must be called before end_timer_and_print()")
    if _on_cuda():
        # Make sure all queued GPU work is included in the measurement.
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    if _on_cuda():
        # Peak CUDA memory is only meaningful when running on a CUDA device.
        print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

In [None]:
import torch, datetime, os

# Essential packages for training an image classifier in PyTorch
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.cuda import amp

import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Seed every RNG the notebook may touch so that runs are reproducible.
import random

SEED = 43
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Prefer deterministic cuDNN convolution algorithms over autotuned (faster,
# but run-to-run variable) ones.
cudnn.deterministic = True
cudnn.benchmark = False

In [None]:
# import and instantiate tensorboard for monitoring model performance
from torch.utils.tensorboard import SummaryWriter

Set up the training configuration for this Jupyter notebook.
In a Python-script version of this code, these values should be parsed from command-line arguments.

In [None]:
# Training configuration. In a script version these would be command-line arguments.
nodes = 1               # presumably number of machines -- not used in this chunk
gpus = 0                # presumably GPUs per node -- not used in this chunk
num_workers = 8         # DataLoader worker processes
batch_size = 64         # samples per mini-batch (both loaders)
epochs = 2              # presumably number of training epochs -- confirm
lr = 1e-3               # learning rate -- presumably for SGD, confirm
momentum = 0.9          # presumably SGD momentum -- confirm
weight_decay = 5e-4     # presumably optimizer weight decay -- confirm
print_interval = 100    # presumably iterations between progress prints -- confirm

## Miscellaneous utility functions

In [None]:
def accuracy(outputs, labels):
    """Count how many predictions in a batch match their labels.

    Despite its name, this returns the number of correct predictions (a Python
    int), not a fraction; divide by the number of samples to obtain accuracy.

    Args:
        outputs: Tensor of per-class scores; the predicted class is the argmax over dim 1.
        labels: Tensor of target class indices, compared element-wise with the predictions.

    Returns:
        int: Number of positions where the predicted class equals the label.
    """
    predicted = torch.argmax(outputs, dim=1)
    n_correct = (predicted == labels).sum()
    return n_correct.item()

## DataLoader
This section loads and transforms the data.
It manages not only where the data lives (the `DATA_DIR` environment variable) but also how it is loaded into memory (transforms, batching, shuffling and worker processes).

In [None]:
# Prepare training data
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
train_transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
    ])


val_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
    ])

datadir=os.environ['DATA_DIR']
trainset = torchvision.datasets.ImageFolder(root=os.path.join(datadir,'train'),
                                                transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=num_workers,
                                          pin_memory=True,
                                          drop_last=False)

valset = torchvision.datasets.ImageFolder(root=os.path.join(datadir,'val'),
                                              transform=val_transform)
valloader = torch.utils.data.DataLoader(valset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             pin_memory=True,
                                             drop_last=False)

## Choose a Neural Network architecture

In [None]:
# Choose between training from scratch ("pre-training") and fine-tuning
# ImageNet weights ("transfer learning").
use_pretrained = False
num_classes = 200  # classes in the dataset -- TODO confirm it matches len(trainset.classes)

if use_pretrained:
    # Transfer learning: the ImageNet-1k weights come with a 1000-way classifier,
    # so replace the final fully connected layer to match this dataset.
    net = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2)
    net.fc = nn.Linear(net.fc.in_features, num_classes)
else:
    # Pre-training: random initialisation with a head sized for this dataset.
    net = torchvision.models.resnet50(weights=None, num_classes=num_classes)

## Define a Loss function and optimizer
Let's use a classification cross-entropy loss and SGD with momentum.
If training on GPUs, we can move the loss-function object to GPU memory as well.

