# ICS 635 - Assignment 4
> Derek Garcia

## Load and Normalize FashionMNIST

In [None]:
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms

# Arguments shared by both FashionMNIST splits
_dataset_kwargs = dict(
    root='./data',                      # images are stored under ./data
    download=True,                      # fetch the files if they are not already local
    transform=transforms.ToTensor(),    # convert images to tensors scaled to the range [0,1]
)
_batch_size = 64

# Official training split (train=True)
fashion_train_data = datasets.FashionMNIST(train=True, **_dataset_kwargs)

# Carve the official training split into 80% training / 20% validation
size = len(fashion_train_data)
train_size = int(0.8 * size)
validation_size = size - train_size
train, validation = random_split(fashion_train_data, [train_size, validation_size])

# Official testing split (train=False)
test = datasets.FashionMNIST(train=False, **_dataset_kwargs)

# Batch loaders: training and validation are shuffled, testing keeps its order
train_loader = DataLoader(train, batch_size=_batch_size, shuffle=True)
validation_loader = DataLoader(validation, batch_size=_batch_size, shuffle=True)
test_loader = DataLoader(test, batch_size=_batch_size, shuffle=False)

## Building the CNN Model

In [None]:
from torch import nn
import torch

# Tutorial: https://www.youtube.com/watch?v=pDdP0TFzsoQ
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# implement conv net
class CNN(nn.Module):
    """
    Convolutional network placeholder; the layers are not implemented yet.

    The class is already a valid ``nn.Module`` (it can be instantiated, moved
    between devices and queried for parameters), but calling it fails loudly
    until ``forward`` is written.
    """

    def __init__(self):
        # nn.Module must be initialised before any layer is registered and
        # before calls such as .to() or .parameters() can work
        super().__init__()
        # todo - define the layers

    def forward(self, x):
        """
        Compute the model output for a batch of inputs

        :param x: batch of input images
        :raises NotImplementedError: always, until the forward pass is implemented
        """
        # todo - fail explicitly instead of silently returning None
        raise NotImplementedError("CNN.forward is not implemented yet")
# criterion = nn.CrossEntropyLoss()   # loss function
# optimizer = torch.optim.SGD(model.parameters()) # optimizer

## Training the Model: Train and Test Methods

In [None]:
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from typing import Tuple, Literal


def iter_dataset(model: CNN, mode: Literal['train', 'eval'], loader: DataLoader,
                 criterion: CrossEntropyLoss, optimizer: SGD) -> Tuple[CNN, float, float]:
    """
    Run one full pass over a dataset and compute the mean loss and accuracy

    In 'train' mode the model weights are updated after every batch; in 'eval'
    mode gradients are disabled and the optimizer is never used.

    todo - different criterion and optimizer classes?
    :param model: Model to iter data over
    :param mode: Model in train or eval mode
    :param loader: loader of data
    :param criterion: loss function
    :param optimizer: optimizer (only stepped in 'train' mode)
    :return: the updated model, mean loss per sample, accuracy as a percentage
    :raises ValueError: if mode is not 'train' or 'eval', or the loader yields no samples
    """
    if mode not in ('train', 'eval'):
        raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")
    training = mode == 'train'

    # set model mode
    if training:
        model.train()
    else:
        model.eval()

    # init running totals
    running_loss = 0.0
    correct = 0
    total = 0

    # Gradients are only needed when updating weights. The context manager restores
    # the caller's previous grad mode on exit, even if a batch raises.
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            # designate cpu or gpu to process
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)   # calc loss

            # Backward pass if training
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # track loss; weighting by batch size assumes the criterion returns the
            # batch mean, so the final value is a per-sample average
            running_loss += loss.item() * batch_size
            total += batch_size

            # track accuracy: predicted class is the index of the highest score,
            # compared element-wise against the labels of the same batch
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute loss and accuracy")

    # calc and return the updated model and stats
    mean_loss = running_loss / total
    acc = (correct / total) * 100
    return model, mean_loss, acc


def train(criterion: CrossEntropyLoss, optimizer: SGD, epochs: int = 5,
          model: CNN | None = None) -> CNN:
    """
    Train a model and return it with the weights of the best epoch loaded

    The best epoch is the one with the lowest validation loss.

    :param criterion: loss function
    :param optimizer: optimizer; it only updates the parameters it was built with,
        so it must have been created from the parameters of the model being trained
    :param epochs: Number of training epochs (default: 5)
    :param model: model to train, whose parameters the optimizer was built with
        (default: None, which creates a new CNN)
    :return: Best model from training epochs
    """
    import warnings

    # init model
    if model is None:
        model = CNN()
    model = model.to(device)

    # An optimizer built from another model's parameters would step those parameters
    # instead, and the epochs below would leave this model untrained.
    optimized = {id(p) for group in optimizer.param_groups for p in group['params']}
    if not any(id(p) in optimized for p in model.parameters()):
        warnings.warn(
            "optimizer does not reference any parameter of the model being trained; "
            "pass the model the optimizer was built for via the 'model' argument",
            RuntimeWarning,
            stacklevel=2,
        )

    # init variables to track the best epoch
    best_val_loss = float('inf')    # set to infinity so first epoch is always best
    best_val_epoch = 0
    best_model_wts = None

    # training loop
    for epoch in range(epochs):
        # training pass
        model, train_loss, train_acc = iter_dataset(model, 'train', train_loader, criterion, optimizer)

        # validation pass
        model, val_loss, val_acc = iter_dataset(model, 'eval', validation_loader, criterion, optimizer)

        # print stats
        print(f"Epoch [{epoch+1}/{epochs}] - "
              f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}% - "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

        # Save model if beats previous
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_epoch = epoch
            # state_dict() holds references to the live tensors, so clone them;
            # otherwise later epochs would overwrite the saved "best" weights
            best_model_wts = {name: tensor.detach().clone()
                              for name, tensor in model.state_dict().items()}

    # after all epochs, use best model (nothing to restore if no epoch was recorded)
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
        print(f"Best model: epoch {best_val_epoch + 1}")

    return model

def test(model: CNN, criterion: CrossEntropyLoss, optimizer: SGD) -> None:
    """
    Evaluate the model on the test loader and print its loss and accuracy

    :param model: Model to test
    :param criterion: loss function
    :param optimizer: optimizer (forwarded to iter_dataset, which requires it)
    """
    # Single evaluation pass over the test data; the returned model is not needed here
    _, loss_value, accuracy = iter_dataset(model, 'eval', test_loader, criterion, optimizer)

    # Report the final results
    summary = f"Test Loss: {loss_value:.4f}, Test Accuracy: {accuracy:.2f}%"
    print(summary)