In [None]:
# based on https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-use-k-fold-cross-validation-with-pytorch.md

In [None]:
from sklearn.model_selection import KFold
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader,  SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm, trange
import torch

In [None]:
def reset_weights(model: nn.Module):
  """Re-initialise the direct children of ``model`` that can reset themselves.

  Each immediate child exposing a ``reset_parameters`` attribute is announced
  on stdout and then reset. Only ``model.children()`` is inspected, so nested
  submodules are reached only when this function is invoked on their parent
  as well (for example through ``model.apply(reset_weights)``).

  Args:
    model: Module whose immediate children should be re-initialised.
  """
  # Lazy filter: children are checked, logged and reset one at a time.
  resettable = (
    child for child in model.children()
    if hasattr(child, 'reset_parameters')
  )
  for layer in resettable:
    print(f'Reset trainable parameters of layer {layer}')
    layer.reset_parameters()

In [None]:
class SimpleConvNet(nn.Module):
  """Small convolutional classifier: one conv layer plus a three-layer dense head.

  The first linear layer takes 26 * 26 * 10 flattened features, which is what
  the unpadded 3x3 convolution with 10 output channels produces for a
  single-channel 28x28 input. The network emits 10 raw scores per sample.
  """

  def __init__(self):
    super().__init__()
    conv_channels = 10
    flat_features = 26 * 26 * conv_channels
    # Positional Sequential: the indices below define the state_dict keys.
    stages = [
      nn.Conv2d(1, conv_channels, kernel_size=3),
      nn.ReLU(),
      nn.Flatten(),
      nn.Linear(flat_features, 50),
      nn.ReLU(),
      nn.Linear(50, 20),
      nn.ReLU(),
      nn.Linear(20, 10),
    ]
    self.layers = nn.Sequential(*stages)

  def forward(self, x):
    """Run ``x`` through the layer stack and return the output scores."""
    scores = self.layers(x)
    return scores

In [None]:
# Seed PyTorch's global RNG before any model or sampler is created.
torch.manual_seed(42)

# Experiment settings used by the cells below.
k_folds, num_epochs = 5, 2
dataset_dir = '../temp'
loss_function = nn.CrossEntropyLoss()

# Filled in by the cross-validation loop: fold index -> accuracy in percent.
results = dict()

In [None]:
# Fetch both official MNIST splits into dataset_dir (downloaded on first use).
train_dataset = MNIST(dataset_dir, download=True, transform=transforms.ToTensor(), train=True)
test_dataset = MNIST(dataset_dir, download=True, transform=transforms.ToTensor(), train=False)

# Merge the two splits into one dataset; the k-fold splitter below produces
# its own train/test partitions over the combined samples.
dataset = ConcatDataset([train_dataset, test_dataset])

In [None]:
# shuffle=True without random_state yields a different split on every run.
# Pin it to the same seed value passed to torch.manual_seed so that the fold
# membership is reproducible across kernel restarts.
kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)

In [None]:
# Run one complete train -> save -> evaluate cycle per cross-validation fold.
for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    print(f'Fold: {fold}')

    # Both loaders read from the same dataset; the samplers restrict each
    # loader to this fold's training indices or held-out indices.
    train_subsampler = SubsetRandomSampler(train_ids)
    test_subsampler = SubsetRandomSampler(test_ids)

    train_loader = DataLoader(dataset, batch_size=10, sampler=train_subsampler)
    test_loader = DataLoader(dataset, batch_size=10, sampler=test_subsampler)

    # A fresh network and optimizer per fold so nothing learned on one fold
    # carries over into the next.
    network = SimpleConvNet()
    network.apply(reset_weights)

    optimizer = torch.optim.Adam(network.parameters(), lr=1e-4)

    # ---- Training: all epochs for this fold ----
    network.train()
    for epoch in trange(num_epochs, desc="Epoch"):
        print(f'Starting epoch {epoch + 1}')
        current_loss = 0.0

        # Wrap train_loader with tqdm for a progress bar
        batches = tqdm(train_loader, desc=f"Training Fold {fold} Epoch {epoch+1}")
        for i, (inputs, targets) in enumerate(batches):
            optimizer.zero_grad()

            outputs = network(inputs)
            loss = loss_function(outputs, targets)

            loss.backward()
            optimizer.step()

            # Report the mean loss over each window of 500 mini-batches.
            current_loss += loss.item()
            if i % 500 == 499:
                print('Loss after mini-batch %5d: %.3f' %
                      (i + 1, current_loss / 500))
                current_loss = 0.0

    # ---- Save: once per fold, after the final epoch ----
    # Saving and evaluation used to sit inside the epoch loop, so the
    # checkpoint was rewritten and the whole test fold re-evaluated after every
    # epoch. Only the last epoch's values survived, so running them once here
    # gives the same saved file and the same results[fold].
    print('Training process has finished. Saving trained model.')
    save_path = f'../temp/k-fold-example-model-fold-{fold}.pth'
    torch.save(network.state_dict(), save_path)

    # ---- Evaluation on the held-out fold ----
    print('Starting testing')
    network.eval()
    correct, total = 0, 0
    with torch.no_grad():
        # Wrap test_loader with tqdm for a progress bar
        for inputs, targets in tqdm(test_loader, desc=f"Testing Fold {fold}"):
            outputs = network(inputs)

            # Index of the highest score along dim 1 is the predicted class.
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = 100.0 * (correct / total)
    # %.2f instead of %d: the integer format silently truncated the percentage.
    print('Accuracy for fold %d: %.2f %%' % (fold, accuracy))

    results[fold] = accuracy

In [None]:
# Per-fold accuracies followed by their mean.
# The running total is no longer kept in a variable called `sum`, which
# shadowed the builtin sum() for every cell executed afterwards.
for fold_id, fold_accuracy in results.items():
  print(f'Fold {fold_id}: {fold_accuracy} %')

if results:
  average_accuracy = sum(results.values(), 0.0) / len(results)
  print(f'Average: {average_accuracy} %')
else:
  # Nothing has been evaluated yet; avoid dividing by zero.
  print('Average: n/a (no folds evaluated)')