### Setup

In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip, Grayscale
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import h5py

In [2]:
# Use the GPU when CUDA is available; otherwise run everything on the CPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Data Importing

In [3]:
class Data(Dataset):
    """Image/label dataset backed by the `camelyonpatch_level_2_split_*` HDF5 files.

    The selected split is read fully into memory as NumPy arrays on
    construction. Items are returned as `(img, cls)` where `img` is a
    channels-first float tensor scaled with `/128 - 1`.

    Args:
        train: Split selector: 0 = train, 1 = test, 2 = valid. Random
            horizontal/vertical flips are applied only for the train split.
        gray: When True, every returned image is passed through
            torchvision's `Grayscale()` transform.

    Raises:
        ValueError: If `train` is not one of 0, 1 or 2.
    """

    # Maps the integer `train` selector to the split name used in the file names.
    _SPLITS = {0: 'train', 1: 'test', 2: 'valid'}

    def __init__(self, train=0, gray=True):
        # Fail fast on an unknown selector instead of leaving self.x / self.y
        # undefined and crashing later in __len__ / __getitem__.
        if train not in self._SPLITS:
            raise ValueError(
                f"train must be one of {sorted(self._SPLITS)} "
                f"(0=train, 1=test, 2=valid), got {train!r}"
            )
        split = self._SPLITS[train]

        # Load data
        self.train = (train == 0)
        self.gray = gray
        self.x = self.load_h5_as_numpy(f'camelyonpatch_level_2_split_{split}_x.h5', 'x')
        # Labels are stored as a 4-D array; keep one scalar label per sample.
        self.y = self.load_h5_as_numpy(f'camelyonpatch_level_2_split_{split}_y.h5', 'y')[:, 0, 0, 0]

        # Prepare transforms
        self.t1 = RandomHorizontalFlip(p=0.5)
        self.t2 = RandomVerticalFlip(p=0.5)
        self.g = Grayscale()

    def __len__(self):
        """Return the number of samples (size of the first axis of `x`)."""
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return `(img, cls)` for a single index or for a batch of indices.

        Args:
            idx: An integer (or scalar tensor) for one sample, or a list,
                tensor, slice or index array for several samples.

        Returns:
            For a single index: `img` with the channel axis first and `cls`
            as the stored label scalar. For several indices: `img` with a
            leading batch axis and `cls` as a 1-D tensor.
        """
        # Get images
        if torch.is_tensor(idx):
            idx = idx.tolist()
        pixels = torch.from_numpy(self.x[idx].astype(np.float32))

        # Decide between "one sample" and "batch" from the rank of the selected
        # data rather than from type(idx), so slices and index arrays work too.
        if pixels.ndim == 4:
            # Channels-last -> channels-first, keeping the batch axis in front.
            img = pixels.permute(0, 3, 1, 2) / 128 - 1
            # ascontiguousarray: from_numpy rejects negatively-strided views
            # (e.g. the result of a reversed slice).
            cls = torch.from_numpy(np.ascontiguousarray(self.y[idx]))
        else:
            img = pixels.permute(2, 0, 1) / 128 - 1
            cls = self.y[idx]

        # Transforms
        if self.train:
            img = self.t1(self.t2(img))
        if self.gray:
            img = self.g(img)

        # Return
        return (img, cls)

    def load_h5_as_numpy(self, file_name, key):
        """Read dataset `key` from HDF5 file `file_name` fully into memory.

        The file is closed before returning; `[:]` copies the whole dataset
        out of the file first.
        """
        with h5py.File(file_name, 'r') as h5_file:
            data = h5_file[key][:]
        return data

### Models

In [4]:
class CustomModel(nn.Module):
    """Convolutional classifier that produces two output logits.

    The network is three stages of (conv, ReLU, conv, ReLU, max-pool) followed
    by a flatten and three linear layers. Two pipelines share the same layers:
    `train_model` inserts dropout before each linear layer, `test_model` omits
    it. `forward` picks one based on `self.training`.

    Args:
        dropout_rate: Probability passed to the shared `nn.Dropout`.
        gray: When True the first convolution takes 1 input channel,
            otherwise 3.
    """

    def __init__(self, dropout_rate=0.2, gray=False):
        super().__init__()

        # Feature extractor: each stage holds two unpadded 3x3 convolutions at
        # the stage width, then a 2x2 max-pool.
        in_channels = 1 if gray else 3
        feature_layers = []
        for out_channels in (16, 32, 64):
            feature_layers.extend([
                nn.Conv2d(in_channels, out_channels, 3, padding='valid'),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, 3, padding='valid'),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
            in_channels = out_channels
        feature_layers.append(nn.Flatten())
        self.conv = nn.Sequential(*feature_layers)

        self.drop = nn.Dropout(dropout_rate)

        # Classifier head. 4096 = 64 channels * 8 * 8, which is what the conv
        # stack yields for 96x96 inputs -- presumably the patch size; confirm.
        self.lin1 = nn.Sequential(nn.Linear(4096, 256), nn.ReLU())
        self.lin2 = nn.Sequential(nn.Linear(256, 128), nn.ReLU())
        self.lin3 = nn.Linear(128, 2)

        # Both pipelines reference the very same layer objects, so their
        # weights are shared.
        self.train_model = nn.Sequential(
            self.conv,
            self.drop, self.lin1,
            self.drop, self.lin2,
            self.drop, self.lin3,
        )
        self.test_model = nn.Sequential(self.conv, self.lin1, self.lin2, self.lin3)

    def forward(self, x):
        """Run `x` through the dropout pipeline in training mode, else the plain one."""
        pipeline = self.train_model if self.training else self.test_model
        return pipeline(x)
        

### Training / Testing Loop

In [5]:
class Log:
    """Accumulates messages in a single string while echoing each to stdout."""

    def __init__(self):
        # Everything logged so far, one message per line.
        self.log_text = ""

    def log(self, text):
        """Append `text` plus a newline to the stored log, then print `text`."""
        entry = text + "\n"
        self.log_text = self.log_text + entry
        print(text)

    def get_log(self):
        """Return the full accumulated log text."""
        return self.log_text

In [6]:
def training(model, loss_fn, optimiser, epoch, train_name, save=False, patience=1, factor=0.4):
    """Train `model` for `epoch` epochs, validating after every epoch.

    Uses the notebook-level globals `train_dataloader`, `val_dataloader` and
    `dev`, which must be defined before this function is called. After the
    last epoch the per-epoch training and validation losses are plotted.

    Args:
        model: Network to optimise; switched between train and eval mode here.
        loss_fn: Callable `loss_fn(predictions, labels)` returning a scalar loss.
        optimiser: Optimiser already bound to `model`'s parameters.
        epoch: Number of epochs to run.
        train_name: Prefix for the files written when `save` is True.
        save: When True, write `{train_name}.model` (state dict),
            `{train_name}_train.png` (loss plot) and `{train_name}_train.txt`
            (validation-accuracy log).
        patience: `patience` for `ReduceLROnPlateau`.
        factor: `factor` for `ReduceLROnPlateau`.
    """
    # Setup
    logger = Log()
    # mode="max" because the scheduler is stepped with validation accuracy.
    # The scheduler's `verbose` flag is deprecated, so learning-rate changes
    # are reported explicitly after each scheduler step below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser, mode="max", patience=patience, factor=factor)
    train_loss = np.zeros(epoch)
    valid_loss = np.zeros(epoch)

    # Main epoch loop
    for i in range(epoch):

        # Training
        model.train()
        for inputs, labels in tqdm(train_dataloader, mininterval=1):
            inputs, labels = inputs.to(dev), labels.to(dev)
            loss = loss_fn(model(inputs), labels)
            train_loss[i] += loss.item()
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

        # Validating: no gradients are needed, so skip building autograd graphs.
        model.eval()
        correct = 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_dataloader, mininterval=1):
                inputs, labels = inputs.to(dev), labels.to(dev)
                y_pred = model(inputs)
                correct += (torch.argmax(y_pred, 1) == labels).sum().item()
                valid_loss[i] += loss_fn(y_pred, labels).item()
        acc = 100 * correct / len(val_dataloader.dataset)
        logger.log(f"Epoch {i+1}: validation accuracy {round(acc, 2)}")

        # Step the scheduler and report any learning-rate reduction it made.
        previous_lrs = [group["lr"] for group in optimiser.param_groups]
        scheduler.step(acc)
        for k, (previous, group) in enumerate(zip(previous_lrs, optimiser.param_groups)):
            if group["lr"] != previous:
                print(f"Epoch {i+1}: reducing learning rate of group {k} to {group['lr']:.4e}.")

    # Normalise loss data: mean loss per batch, scaled by 100
    train_loss = 100 * train_loss / len(train_dataloader)
    valid_loss = 100 * valid_loss / len(val_dataloader)

    # Create the loss plot on its own figure so repeated calls do not draw
    # over curves left on the implicit pyplot figure by an earlier run.
    epochs = np.arange(1, epoch+1)
    fig, ax = plt.subplots()
    ax.plot(epochs, train_loss)
    ax.plot(epochs, valid_loss)
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend(['Training Loss','Validation Loss'])

    # Save data
    if save:
        torch.save(model.state_dict(), f'{train_name}.model')
        fig.savefig(f'{train_name}_train.png')
        with open(f'{train_name}_train.txt', "w") as text_file:
            text_file.write(logger.get_log())

In [7]:
def testing(model, test_name, save=False):
    """Evaluate `model` on the test set and log its accuracy in percent.

    Uses the notebook-level globals `test_dataloader` and `dev`, which must be
    defined before this function is called.

    Args:
        model: Network to evaluate; it is put into eval mode.
        test_name: Prefix for the log file written when `save` is True.
        save: When True, write the log to `{test_name}_test.txt`.
    """
    # Setup
    logger = Log()

    # Testing: inference only, so disable autograd to avoid building graphs.
    model.eval()
    correct = 0
    with torch.no_grad():
        for inputs, labels in tqdm(test_dataloader, mininterval=1):
            y_pred = model(inputs.to(dev))
            correct += (torch.argmax(y_pred, 1) == labels.to(dev)).sum().item()
    acc = 100 * correct / len(test_dataloader.dataset)
    logger.log(f"Test accuracy {round(acc, 2)}")

    # Saving
    if save:
        with open(f'{test_name}_test.txt', "w") as text_file:
            text_file.write(logger.get_log())

### Running

In [8]:
# --- Input pipeline ---
gray = True            # handed to both Data and CustomModel (single-channel images when True)
batch_size = 32        # samples per batch for every DataLoader

# --- Network ---
dropout_rate = 0.35    # dropout probability given to CustomModel

# --- Optimisation ---
epochs = 30            # number of training epochs
momentum = 0.92        # SGD momentum
lr = 3                 # NOTE: an exponent, not the rate itself -- the optimiser gets 10**(-lr)

# --- ReduceLROnPlateau ---
factor = 0.4           # `factor` argument of the scheduler
patience = 2           # `patience` argument of the scheduler

# Run identifier, used as the prefix of every file saved by training/testing
name = f"{epochs}_SGD_{lr}_{momentum}_{dropout_rate}_{batch_size}_d_t{'g' if gray else ''}"

In [9]:
# Data's `train` argument selects the split: 0 = train, 1 = test, 2 = valid.
# Only the training loader is shuffled, so SGD sees the samples in a new order
# each epoch; evaluation order does not affect the reported accuracy.
train_dataloader = DataLoader(Data(train=0, gray=gray), batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(Data(train=1, gray=gray), batch_size=batch_size)
val_dataloader = DataLoader(Data(train=2, gray=gray), batch_size=batch_size)

In [None]:
# Build the network on the selected device and pair it with a cross-entropy loss.
loss_fn = nn.CrossEntropyLoss()
model = CustomModel(dropout_rate=dropout_rate, gray=gray)
model = model.to(dev)

# `lr` holds the negative base-10 exponent of the learning rate.
optimiser = torch.optim.SGD(model.parameters(), lr=10 ** -lr, momentum=momentum)
training(
    model,
    loss_fn,
    optimiser,
    epochs,
    name,
    save=True,
    patience=patience,
    factor=factor,
)

In [None]:
# Evaluate the trained model on the test split and save the result under the run name.
testing(model, test_name=name, save=True)