## Place for all extra helper functions that are not part of the main training function

In [1]:
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50
from torchvision import datasets, models, transforms
import torch
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def makeAll(dataloader: DataLoader, model: models.ResNet, device) -> tuple[np.ndarray, np.ndarray]:
    '''
    Gets all labels and predictions for the images in the dataloader.

    The model is put into evaluation mode for the duration of the pass and its
    previous train/eval mode is restored afterwards, even if a batch fails.

    Args:
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        model: network called as ``model(inputs)``; the predicted class is the
            index of the largest value along dimension 1 of its output.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(labels, predictions)`` as two flat numpy arrays, in the order the
        dataloader produced them.
    '''
    # Seed both lists with an empty int64 CPU tensor so the final concatenation
    # is well-defined (and int64) even when the dataloader yields no batches.
    pred_chunks = [torch.zeros(0, dtype=torch.long, device='cpu')]
    label_chunks = [torch.zeros(0, dtype=torch.long, device='cpu')]

    was_training = model.training
    # Inference must not run with dropout active or update BatchNorm statistics.
    model.eval()
    try:
        with torch.no_grad():
            for inputs, classes in dataloader:
                outputs = model(inputs.to(device))
                _, preds = torch.max(outputs, 1)

                # Collect per-batch results; they are concatenated once at the
                # end instead of re-copying the growing tensor every batch.
                # Labels are only stored, so they never need to visit `device`.
                pred_chunks.append(preds.reshape(-1).cpu())
                label_chunks.append(classes.reshape(-1).cpu())
    finally:
        model.train(mode=was_training)

    return torch.cat(label_chunks).numpy(), torch.cat(pred_chunks).numpy()

In [None]:
%matplotlib inline
# taken from Lab 5
def imshow(inp, augment=True, title=None):
    """Plot a channel-first image tensor with matplotlib.

    ``inp`` is converted to a numpy array and its axes are reordered from
    (0, 1, 2) to (1, 2, 0) before plotting. When ``augment`` is true the
    per-channel normalization is reversed (``std * x + mean``) and the result
    is clipped to [0, 1]. ``title``, if given, is set as the plot title.
    """
    image = np.transpose(inp.numpy(), (1, 2, 0))

    if augment:
        # Reverse the normalization; these look like the usual ImageNet
        # channel statistics -- confirm against the transforms in use.
        channel_mean = np.array([0.485, 0.456, 0.406])
        channel_std = np.array([0.229, 0.224, 0.225])
        image = np.clip(channel_std * image + channel_mean, 0, 1)

    plt.imshow(image)
    if title is not None:
        plt.title(title)
    # Short pause after drawing, kept from the original implementation.
    plt.pause(0.001)


# Generic function to display predictions for a few images
def visualize_model(model: models.ResNet, dataloader: DataLoader, class_labels, device, num_images=6, augment=True):
    '''Shows some images and their predictions vs true label.

    Images are laid out on a two-row grid in a 12-wide figure. The model is
    evaluated in eval mode and its previous train/eval mode is restored on
    exit, including when an exception is raised.

    Args:
        model: network called as ``model(inputs)``; the predicted class is the
            index of the largest value along dimension 1 of its output.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        class_labels: lookup from integer class index to display name.
        device: device the inputs are moved to before the forward pass.
        num_images: maximum number of images to show; nothing is drawn when it
            is less than 1.
        augment: forwarded to ``imshow`` to control un-normalization.
    '''
    if num_images < 1:
        return

    # Round the column count up so an odd num_images (including 1) still fits
    # on the two-row grid; even values keep the original num_images // 2.
    num_cols = (num_images + 1) // 2

    # set_figwidth() returns None, so there is no handle worth keeping.
    plt.figure().set_figwidth(12)

    images_so_far = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                # Copy the batch back to the CPU once, not once per image.
                cpu_inputs = inputs.cpu().data

                for j in range(inputs.size()[0]):
                    images_so_far += 1
                    ax = plt.subplot(2, num_cols, images_so_far)
                    ax.axis('off')
                    # .item() yields plain ints so class_labels may be a list,
                    # tuple or dict keyed by class index.
                    ax.set_title('pred/true: {}/{}'.format(class_labels[preds[j].item()],
                                                           class_labels[labels[j].item()]))
                    imshow(cpu_inputs[j], augment=augment)

                    if images_so_far == num_images:
                        return
    finally:
        model.train(mode=was_training)