# Example CV with PyTorch

Here's a short example of K-fold cross-validation (CV) with PyTorch. The idea is to pre-determine the K-fold splits using sklearn's `StratifiedKFold`, then loop over the train/valid indices for each split, training and evaluating a fresh model on each one. I've commented the code below so it should hopefully be fairly intuitive; let me know if not!
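To get a feel for what the splitter yields before plugging it into PyTorch, here is a minimal sketch on a toy label array. The labels (10 samples of class 0, 5 of class 1) and the `fold_summaries` list are made up purely for illustration; only `StratifiedKFold` and its `split` method come from the example itself.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy labels (hypothetical): 10 samples of class 0 and 5 of class 1
y = np.array([0] * 10 + [1] * 5)

splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)

fold_summaries = []
# X is only used for its length, so a dummy array of zeros is enough
for fold, (train_idx, valid_idx) in enumerate(splitter.split(np.zeros(len(y)), y), 1):
    # Count how many samples of each class landed in this validation fold
    valid_counts = np.bincount(y[valid_idx], minlength=2).tolist()
    fold_summaries.append((len(train_idx), len(valid_idx), valid_counts))
    print(f"Fold {fold}: {len(train_idx)} train / {len(valid_idx)} valid, "
          f"valid class counts = {valid_counts}")
```

Each validation fold keeps the 2:1 class ratio of the full label array, which is the point of stratifying: every fold sees roughly the same class balance as the whole dataset.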

## Imports

In [7]:
import torch
import numpy as np
from torch.utils.data import SubsetRandomSampler, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from sklearn.model_selection import StratifiedKFold

## Data Preparation

In [8]:
TRAIN_DIR = "cropped_data/"  # note that I removed the annotated subdirectory before running this code
IMAGE_SIZE = 56
BATCH_SIZE = 8
torch.manual_seed(123)

# Prepare dataset (I'm just using some random transforms here)
train_transforms = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor()
])
valid_transforms = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor()
])

# Create datasets
# Note we are creating two "datasets": one with augmentation transforms for training, and one without for validation
train_dataset = ImageFolder(root=TRAIN_DIR, transform=train_transforms)
valid_dataset = ImageFolder(root=TRAIN_DIR, transform=valid_transforms)

# Prepare folds
n = len(train_dataset)  # total number of samples
kfold_splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)
# Only y is needed to generate stratified splits, so X is just a dummy array of 0's (its values don't matter)
splits = kfold_splitter.split(X=np.zeros(n), y=train_dataset.targets)

# Train/evaluate model via CV
for fold, (train_idx, valid_idx) in enumerate(splits, 1):
    print(f"Running CV fold: {fold}...")
    # Loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(train_idx))
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, sampler=SubsetRandomSampler(valid_idx))
    # Train model
    # model = Model()  # make a fresh instance of your model class here (`Model` is a placeholder name)
    # trainer()        # train the model as usual here, outputting whatever metrics you like
print("Finished CV!")

Running CV fold: 1...
Running CV fold: 2...
Running CV fold: 3...
Running CV fold: 4...
Running CV fold: 5...
Finished CV!
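The loop above leaves the model and training step as commented-out placeholders. As a rough sketch of how they might be filled in, the block below runs the same fold loop end to end. Everything specific here is an assumption for illustration: the random tensors stand in for the `ImageFolder` data, and `make_model`, `train_one_epoch` and `accuracy` are hypothetical helpers, not part of the original code.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
from sklearn.model_selection import StratifiedKFold

torch.manual_seed(123)

# Synthetic stand-in for the image data (assumption): 100 grayscale 56x56 "images", 2 classes
images = torch.randn(100, 1, 56, 56)
labels = torch.tensor([0, 1] * 50)
dataset = TensorDataset(images, labels)


def make_model():
    """Hypothetical tiny classifier; a fresh copy is built for every fold."""
    return nn.Sequential(nn.Flatten(), nn.Linear(56 * 56, 2))


def train_one_epoch(model, loader, optimizer, criterion):
    """Run one pass over the training loader, updating the model weights."""
    model.train()
    for X, y in loader:
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()


def accuracy(model, loader):
    """Fraction of correctly classified samples in the loader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X, y in loader:
            correct += (model(X).argmax(dim=1) == y).sum().item()
            total += y.size(0)
    return correct / total


splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=123)
fold_accuracies = []
for fold, (train_idx, valid_idx) in enumerate(
    splitter.split(np.zeros(len(dataset)), labels.numpy()), 1
):
    train_loader = DataLoader(dataset, batch_size=8, sampler=SubsetRandomSampler(train_idx.tolist()))
    valid_loader = DataLoader(dataset, batch_size=8, sampler=SubsetRandomSampler(valid_idx.tolist()))
    # Re-create the model and optimizer so no weights leak between folds
    model = make_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(2):
        train_one_epoch(model, train_loader, optimizer, criterion)
    fold_accuracies.append(accuracy(model, valid_loader))
    print(f"Fold {fold} validation accuracy: {fold_accuracies[-1]:.2f}")

print(f"Mean CV accuracy: {np.mean(fold_accuracies):.2f}")
```

The important design choice is building a new model and optimizer inside the loop: reusing a model across folds would let it see the current validation samples during an earlier fold's training, which inflates the CV score. Since the data here is random noise, the accuracies themselves are meaningless; the sketch only shows the structure.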
