# ICS 635 - Assignment 4
> Derek Garcia

## Load and Normalize FashionMNIST

In [None]:
from torch.utils.data import random_split, DataLoader
from torchvision import datasets, transforms

# Settings shared by every dataset / loader below
DATA_ROOT = './data'    # images are cached in this directory
BATCH_SIZE = 64
# ToTensor turns each image into a tensor scaled to the range [0, 1]
to_tensor = transforms.ToTensor()

# Official training split (downloaded into DATA_ROOT when not already present)
fashion_train_data = datasets.FashionMNIST(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=to_tensor,
)

# Hold out 20% of the training images: 80/20 training / validation split
size = len(fashion_train_data)
train_size = int(0.8 * size)
validation_size = size - train_size     # remainder, so the two sizes always sum to `size`
train, validation = random_split(
    fashion_train_data,
    [train_size, validation_size],
)

# Official testing split, used only for the final evaluation
test = datasets.FashionMNIST(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=to_tensor,
)

# Batched loaders; only the test loader iterates in a fixed order
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(validation, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

## Building the CNN Model

In [None]:
from torch import nn
import torch

# Tutorial: https://www.youtube.com/watch?v=pDdP0TFzsoQ
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Architecture:
        conv(1 -> 32, 3x3) -> ReLU -> max-pool(2)    # 28x28 -> 14x14
        conv(32 -> 64, 3x3) -> ReLU -> max-pool(2)   # 14x14 -> 7x7
        flatten -> linear(64*7*7 -> 128) -> ReLU -> linear(128 -> num_classes)

    Args:
        num_classes: number of output logits (defaults to 10).
    """

    def __init__(self, num_classes=10):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise .to(), .parameters() and state_dict() fail.
        super().__init__()

        # Convolutional feature extractor; padding=1 keeps the spatial size
        # through each 3x3 convolution so only the pooling layers halve it.
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Fully connected head operating on the flattened 64 x 7 x 7 features
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes).

        Args:
            x: image batch of shape (batch, 1, 28, 28).

        No softmax is applied here: nn.CrossEntropyLoss expects unnormalised
        logits and applies log-softmax internally.
        """
        return self.classifier(self.features(x))


## Training the Model

In [None]:
def _snapshot_weights(module):
    """Return an independent copy of ``module``'s state dict.

    ``state_dict()`` hands back references to the live parameter tensors, so
    without cloning the "best" snapshot would silently keep changing as
    training continues.
    """
    return {name: tensor.detach().clone()
            for name, tensor in module.state_dict().items()}


def _run_epoch(module, loader, loss_fn, optim=None):
    """Make one full pass over ``loader`` and return ``(avg_loss, accuracy_pct)``.

    When ``optim`` is given the module is put in training mode and its weights
    are updated after every batch; otherwise the module is put in evaluation
    mode and gradient tracking is disabled.
    """
    training = optim is not None
    module.train(training)

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    # Gradients are only needed when weights are being updated
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = module(images)
            loss = loss_fn(outputs, labels)

            # Backward pass
            if training:
                optim.zero_grad()
                loss.backward()
                optim.step()

            # loss is the batch mean, so weight it by the batch size
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            # Element-wise comparison of predicted class vs. label, then count
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_seen += batch_size

    return loss_sum / num_seen, (num_correct / num_seen) * 100


model = CNN().to(device)
criterion = nn.CrossEntropyLoss()   # loss function
# Learning rate is given explicitly: it is a required argument in older
# PyTorch releases, and the newer default is very small for a 5-epoch run.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # optimizer

# Initialize variables to track the best epoch
best_val_loss = float('inf')
best_val_epoch = 0
best_model_wts = _snapshot_weights(model)

# training loop
epochs = 5
for epoch in range(epochs):
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_acc = _run_epoch(model, validation_loader, criterion)

    print(f"Epoch [{epoch+1}/{epochs}] - "
          f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}% - "
          f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

    # Save model if beats previous
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_epoch = epoch
        best_model_wts = _snapshot_weights(model)

# after all epochs, use best model
model.load_state_dict(best_model_wts)
print(f"Best validation: epoch {best_val_epoch + 1}")

# Test best model
test_loss, test_acc = _run_epoch(model, test_loader, criterion)

print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")