# PyTorch Demo

# Imports & Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms

from sklearn.metrics import confusion_matrix

### Check CUDA Availability

In [None]:
# Detect a GPU once; later cells read `cuda_available` to decide whether to
# move the model and each batch onto the GPU.
cuda_available = torch.cuda.is_available()
print('CUDA Available: ' + str(cuda_available))
if cuda_available:
    device_name = torch.cuda.get_device_name(0)
    print('CUDA device: ' + device_name)

# Data Loading & Preprocessing

### Dataset Transforms

In [None]:
def _drop_channel_dim(image):
    """Drop the leading channel axis of a (channels, height, width) image tensor.

    MNIST digits have a single channel, so only (height, width) is kept.
    """
    return image.view(image.size(1), image.size(2))


transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor shaped (channels, height, width)
    _drop_channel_dim,      # -> (height, width)
])

### Downloading, transforming, and loading the data

In [None]:
# Both MNIST splits share the same download location and preprocessing.
mnist_kwargs = dict(root="./", download=True, transform=transform)
mnist_train = MNIST(train=True, **mnist_kwargs)
mnist_test = MNIST(train=False, **mnist_kwargs)

# Large batches keep the demo quick; shuffle=True reshuffles on every pass.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=10000, shuffle=True)
    for dataset in (mnist_train, mnist_test)
)

### Example input

In [None]:
# Take the first batch the test loader yields and keep its first image/label pair.
images, labels = next(iter(test_loader))
sample_image, sample_label = images[0], labels[0]

for caption, value in (
    ('type:', type(sample_image)),
    ('shape:', sample_image.shape),
    ('raw data:', sample_image),
):
    print(caption, value)

### Example visual

In [None]:
# Display the sample digit in grayscale and print its label.
plt.imshow(sample_image, cmap="gray")

print('sample label:', sample_label) # Label tensor
plt.show()

### Reshaping example

In [None]:
# Total number of elements in the 2-D image, i.e. the length of the flat vector.
num_pixels = sample_image.size(0) * sample_image.size(1)

print('original shape:', sample_image.shape)
print('reshaped using `view`:', sample_image.view(num_pixels).shape)        # view: no copy, needs a compatible memory layout
print('reshaped using `reshape`:', sample_image.reshape(num_pixels).shape)  # reshape: a view when possible, otherwise a copy
print('reshaped using `flatten`:', sample_image.flatten().shape)            # flatten: collapse every dimension into one

# Neural Network Model

In [None]:
class Model(nn.Module):
    """Digit-classifier skeleton for the exercise.

    ``run_epoch`` feeds it batches flattened to (batch, 28 * 28) and reads the
    output as scores over 10 classes (argmax over dim 1). Until the layers
    below are filled in, ``forward`` returns ``None``.
    """

    def __init__(self):
        super().__init__()

        # Hidden layer initialization
        # Fill me in!

        # Output layer initialization
        # Fill me in!

        # Activation functions
        # Fill me in!

    def forward(self, x):
        """Map a flattened image batch ``x`` to class scores.

        Fill me in! For now this is a placeholder that ignores ``x``.
        """
        return None

### Model Initialization

In [None]:
model = Model()
# Move the parameters to the GPU when the setup cell found one
# (Module.cuda() moves in place and returns the same module).
if cuda_available:
    model = model.cuda()
print(model)

### Training Protocol Initialization

In [None]:
# Training protocol (left for the exercise):
#   criterion -- loss function, called by run_epoch as criterion(predictions, y)
#   optim     -- optimizer, on which run_epoch calls zero_grad() and step()
criterion, optim = None, None  # Fill me in!

# Model Training

### Defining the training function

In [None]:
# Run a single epoch
def run_epoch(train):
    """Run one full pass over the training or test set.

    Uses the notebook-level ``model``, ``train_loader``, ``test_loader``,
    ``criterion``, ``optim`` and ``cuda_available``.

    Args:
        train: True to put the model in training mode, iterate
            ``train_loader`` and update the weights with ``optim``. False to
            put it in eval mode and iterate ``test_loader`` with gradient
            tracking disabled.

    Returns:
        ``(loss, accuracy, balanced_accuracy)``:
        ``loss`` is the sum of batch losses weighted by batch size, divided
        by the dataset length. This is a per-sample mean, assuming
        ``criterion`` uses mean reduction (TODO confirm once it is filled
        in). ``accuracy`` is the fraction of correct predictions.
        ``balanced_accuracy`` is the mean per-class recall over the classes
        that occur in the labels.
    """
    if train:
        model.train()
        loader = train_loader
    else:
        model.eval()
        loader = test_loader

    # LOGGING
    total_loss = 0.                # Total epoch loss, weighted by batch size
    confusion = np.zeros((10, 10)) # Rows: true label, columns: predicted label (10 labels)

    # Autograd is only needed when we backpropagate. Building the graph during
    # evaluation just wastes memory and time.
    with torch.set_grad_enabled(bool(train)):
        for x, y in loader:
            if cuda_available: # Move tensors over to cuda if present
                x = x.cuda()
                y = y.cuda()

            # Data in the loader is 28x28, but our model expects a flattened tensor
            x = x.flatten(start_dim=1, end_dim=2)

            # Put data through model, get predictions. This calls the forward() function!
            predictions = model(x)

            # Calculate loss (how wrong our predictions are)
            loss = criterion(predictions, y)

            if train:
                optim.zero_grad() # Reset gradients
                loss.backward()   # Calculate new gradients
                optim.step()      # Run optimizer & update weights and biases

            # LOGGING
            total_loss += loss.item() * predictions.size(0)
            confusion += confusion_matrix(y.cpu(), predictions.argmax(dim=1).cpu(), labels=range(10))

    accuracy = confusion.diagonal().sum() / confusion.sum()

    # Balanced accuracy = mean recall per class. Classes with no samples in this
    # epoch are skipped so a 0 / 0 does not turn the whole result into NaN.
    support = confusion.sum(axis=1)
    present = support > 0
    balanced_accuracy = (confusion.diagonal()[present] / support[present]).mean()

    return total_loss / len(loader.dataset), accuracy, balanced_accuracy

### Training

In [None]:
# Number of passes over the training set, and how often (in epochs) to print metrics.
epochs, print_frequency = 100, 5

In [None]:
# Metric history, one entry per epoch; the plotting cell reads these lists.
train_losses, test_losses = [], []
train_accuracies, test_accuracies = [], []
train_balanced_accuracies, test_balanced_accuracies = [], []

separator = '----------------------------------------------------------------------------------------------------'
counter_width = len(str(epochs))  # right-align the epoch number to the width of the total

for epoch in range(1, epochs + 1):
    # Fit on the training split, then measure on the test split.
    train_loss, train_accuracy, train_balanced_accuracy = run_epoch(True)
    test_loss, test_accuracy, test_balanced_accuracy = run_epoch(False)

    # Performance tracking
    for history, value in (
        (train_losses, train_loss),
        (test_losses, test_loss),
        (train_accuracies, train_accuracy),
        (test_accuracies, test_accuracy),
        (train_balanced_accuracies, train_balanced_accuracy),
        (test_balanced_accuracies, test_balanced_accuracy),
    ):
        history.append(value)

    if epoch % print_frequency != 0:
        continue

    epoch_counter = f'Epoch: {epoch:>{counter_width}}/{epochs}'
    padding = ' ' * len(epoch_counter)
    print(separator)
    print(f'{epoch_counter} | Train Loss: {train_loss:.4f} | Train Accuracy: {train_accuracy:.4f} | Train Balanced Accuracy: {train_balanced_accuracy:.4f}')
    print(f'{padding} |  Test Loss: {test_loss:.4f} |  Test Accuracy: {test_accuracy:.4f} |  Test Balanced Accuracy: {test_balanced_accuracy:.4f}')


print(separator)

# Performance plots

In [None]:
fig, axs = plt.subplots(2, 2, figsize=(15, 10))

# One panel per tracked metric: (axis, training history, test history, title, y-axis label).
# Driving all panels from one spec keeps titles and axis labels in sync;
# the Accuracy panel was previously mislabelled 'Loss'.
panels = [
    (axs[0, 0], train_losses, test_losses,
     f'Model Loss ({criterion.__class__.__name__})', 'Loss'),
    (axs[0, 1], train_accuracies, test_accuracies,
     'Accuracy', 'Accuracy'),
    (axs[1, 1], train_balanced_accuracies, test_balanced_accuracies,
     'Balanced Accuracy', 'Balanced Accuracy'),
]

for ax, train_history, test_history, title, ylabel in panels:
    ax.plot(train_history, label='Training')
    ax.plot(test_history, label='Test')
    ax.legend()
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)

# Only three metrics are plotted, so remove the unused bottom-left panel.
fig.delaxes(axs[1, 0])

plt.show()