skip this for now, but later sections will refer to this cell, so feel free to come back to it later.

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
import numpy as np

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def train_model(model, train_loader, criterion, optimizer, num_epochs, disp_loss=False):
    model.train()
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            if disp_loss:
                print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}")

def evaluate_model(model, val_loader):
    model.eval()
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            all_labels.extend(labels.numpy())
            all_predictions.extend(predicted.numpy())
    accuracy = accuracy_score(all_labels, all_predictions)
    return accuracy

def cross_validate(dataset, num_folds, num_epochs, batch_size, learning_rate):
    labels = [label for _, label in dataset]
    skf = StratifiedKFold(n_splits=num_folds)

    accuracies = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(np.arange(len(dataset)), labels)):
        print(f"Fold {fold + 1}/{num_folds}")
        
        train_subset = torch.utils.data.Subset(dataset, train_idx)
        val_subset = torch.utils.data.Subset(dataset, val_idx)
        
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
        
        model = SimpleNN()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        
        train_model(model, train_loader, criterion, optimizer, num_epochs, disp_loss=True)
        
        accuracy = evaluate_model(model, val_loader)
        accuracies.append(accuracy)
        print(f"Validation Accuracy for Fold {fold + 1}: {accuracy:.4f}")

    return accuracies


# Step 0: Load Data
Data! Here we load the data (your data probably wont look like this). We want the matrices to be a tensor, because in PyTorch, everything needs to be a tensor.

In [12]:
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Step 1: Split data into train and validation sets

For this dataset, we already have a separate train and validation set. We can just load the data from the validation set. 

## Why a validation set?
We want to have data which has never, ever been seen by the model. Neural networks are notorious for overfitting data, which is bad because if you overfit what you have seen, you will not be able to generalize to things which you have not seen. 

## What if my dataset is not split already?
If you have a dataset that is not split already, you can use the `train_test_split` function from `sklearn.model_selection` to split the data into train and validation sets. Generally, you want to use 80% of the data for training and 20% for validation, but it can be somewhat more nuanced if you have a really small or large dataset.


In [13]:
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Step 2: Train the neural network (cross-validated)

Basically we want to train and evaluate the performance of our neural network.

So, how do you determine how well you did? Well, if you have so much data in your validation set that it really represents the full dataset, then you are set. Often however, especially in scientific data, you do not have an infinite supply of validation data. In this case, we want to be very careful about overfitting, so when we train our model, we want to try to get the settings right so that we do not overfit. 

## Cross-validation

The solution to this problem is cross-validation. The idea is that given a training set, we want to split it into separate training and test sets, ideally so that the different splits are as disjoint as possible, and all share the same distribution of the data. 

There are many different cross-validation schemes, but if you have considerable class imbalance, then you should use *stratified cross-validation*. This is a cross-validation scheme that ensures that the classes are distributed in a way that is representative of the true distribution of the data. 

The number of folds is also generally important, since you want to collect meaningful statistics about the performance of your model across folds. A typical choice is 10-fold cross-validation, however if your model takes a very long time to train, then you may want to use a smaller number of folds. 

In [14]:
num_epochs = 5
batch_size = 64
learning_rate = 0.001
num_folds = 5

cross_validate(train_dataset, num_folds, num_epochs, batch_size, learning_rate)

Fold 1/5
Epoch 1/5, Loss: 2.3083
Epoch 1/5, Loss: 2.2797
Epoch 1/5, Loss: 2.2949
Epoch 1/5, Loss: 2.2800
Epoch 1/5, Loss: 2.2351
Epoch 1/5, Loss: 2.2279
Epoch 1/5, Loss: 2.2238
Epoch 1/5, Loss: 2.1749
Epoch 1/5, Loss: 2.1525
Epoch 1/5, Loss: 2.1406
Epoch 1/5, Loss: 2.0987
Epoch 1/5, Loss: 2.0917
Epoch 1/5, Loss: 2.0222
Epoch 1/5, Loss: 2.0029
Epoch 1/5, Loss: 2.0003
Epoch 1/5, Loss: 1.9724
Epoch 1/5, Loss: 1.8772
Epoch 1/5, Loss: 1.8964
Epoch 1/5, Loss: 1.8340
Epoch 1/5, Loss: 1.7083
Epoch 1/5, Loss: 1.6651
Epoch 1/5, Loss: 1.6407
Epoch 1/5, Loss: 1.4547
Epoch 1/5, Loss: 1.4896
Epoch 1/5, Loss: 1.5598
Epoch 1/5, Loss: 1.5007
Epoch 1/5, Loss: 1.2369
Epoch 1/5, Loss: 1.4549
Epoch 1/5, Loss: 1.2986
Epoch 1/5, Loss: 1.3863
Epoch 1/5, Loss: 1.2845
Epoch 1/5, Loss: 1.1120
Epoch 1/5, Loss: 0.9700
Epoch 1/5, Loss: 1.1269
Epoch 1/5, Loss: 0.9998
Epoch 1/5, Loss: 1.0692
Epoch 1/5, Loss: 1.0510
Epoch 1/5, Loss: 0.9438
Epoch 1/5, Loss: 0.9340
Epoch 1/5, Loss: 1.0082
Epoch 1/5, Loss: 0.9553
Epoch 1

[0.9733333333333334,
 0.9685833333333334,
 0.9681666666666666,
 0.9605833333333333,
 0.9705833333333334]

# Also Step 2: Train the neural network (on the validation set)

Once you are satisfied that the neural network is performining well on the training set, you can train it on the validation set. This will give you a better idea of how well the neural network will perform on new data, and is ultimately what is going to be important. 

Here we just setup our simple neural network to train on the validation set, train it, and then evaluate it. Here we only care about the accuracy score, so we will just use the `accuracy_score` function from scikit-learn. 

## Also: the loss function

Usually, when training a neural network, it is done with respect to some loss function. This choice can be critical, and different loss functions have different flavors. For classification problems (like MNIST), you obviously cannot use a loss function which is exactly like how you would evaluate the performance, however for regression problems it is very feasible to use something like mean squared error (MSE), which will at least be proportional to your usual evaluation metrics. 

Also if your model is not working well you can look at the loss function to see how it is behaving (usually you want to plot it to see if there are any hurdles it is struggling with).

## In Practice

In practice, accuracy is not the only metric you would use to evaluate the performance of a neural network. Model performance is usually multi-faceted, and doing well on one metric can come with tradeoffs on other metrics. 

For classification problems, you might want to consider things like precision and recall (wikipedia is fine for this). For regression problems, you might want to consider things like mean absolute error (MAE), mean squared error (MSE), and mean absolute percentage error (MAPE).

In [15]:
final_train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
final_model = SimpleNN()
final_optimizer = optim.Adam(final_model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

train_model(final_model, final_train_loader, criterion, final_optimizer, num_epochs)

final_val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
final_accuracy = evaluate_model(final_model, final_val_loader)
print(f"Final Validation Accuracy: {final_accuracy:.4f}")

Final Validation Accuracy: 0.9760
