In [None]:
# %history -g -f "history.py"

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.model_selection import StratifiedKFold

from model_training import *
from utils import *
from models.CNN import Model1, Model2, Model3, Model4

In [None]:
# Load the run configuration from the YAML settings file via read_params
config = read_params("settings.yaml")

In [None]:
# Training hyper-parameters taken from the configuration
batch_size, num_epochs = config['batch_size'], config['number_epochs']
lr = config['lr']

# Learning-rate schedule settings (handed to Scheduler further down)
transition_steps, gamma = config['transition_steps'], config['gamma_value']

In [None]:
# Extract the training, validation and testing data from the compressed archive
compressed_data_path = config['compressed_data_path']
data = decompress_data(compressed_data_path)

# Build the data loaders (and class information) for the configured batch size.
# The archive is not re-opened with np.load here: the resulting object was never
# read before being overwritten by the cross-validation section, and it kept an
# extra file handle open.
data_loaders_and_classes = get_loaders_and_classes(data, batch_size)


In [None]:
# Instantiate the network to be trained (Model4, imported from models.CNN)
model = Model4()

In [None]:
# Total number of scalar entries across all of the model's parameter tensors
model_total_params = sum(
    tensor.numel() for tensor in model.parameters()
)

In [None]:
device = get_device()

# Class-weighted cross-entropy; the weights are computed by get_class_weights
# from the training labels.
class_weights = get_class_weights(data['y_train'], device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Set optimiser. The learning rate comes from settings.yaml (`lr`) instead of a
# hard-coded 1e-4, so that editing the configuration actually changes the run.
optimizer = optim.Adam(model.parameters(), lr=lr)
# optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

# Scheduling parameters: project Scheduler wrapper built from the configured
# transition steps and gamma.
scheduler = Scheduler(optimizer, transition_steps, gamma)
lr_scheduler = scheduler.get_MultiStepLR()

# Create a model trainer object
model_trainer = ModelTrainer(model, criterion, data_loaders_and_classes)
model_trainer.scheduler = lr_scheduler  # Include scheduler

# Initiate training for the configured number of epochs (previously a
# hard-coded 100). The value is read straight from `config` because the global
# `num_epochs` is reassigned later in this notebook, which would silently change
# this cell's behaviour if it were re-run out of order.
train_loss_res, train_accuracy_res, validation_loss_res, validation_accuracy_res = model_trainer.train_model(
    config['number_epochs'], optimizer
)

In [None]:
# Load minimum validation loss model for evaluation.
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU still loads on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load("model_ckpt/checkpoint.pt", map_location=device))
model_on_device = model.to(device)  # Move model to the selected device

# Evaluate on the validation loader; only the second returned value (accuracy)
# is kept.
_, accuracy, _, _, _ = model_trainer.evaluate_model(model_on_device, data_loaders_and_classes['val_loader'], True)
print(accuracy)

## Training using the cross validation approach

### Adapted from https://www.machinecurve.com/index.php/2021/02/03/how-to-use-k-fold-cross-validation-with-pytorch/

In [None]:
# Extract the cross-validation training and testing arrays
compressed_data_path = config['compressed_cv_data_path']

# np.load on an .npz archive returns an NpzFile that keeps the file open and
# reads each array on access; the context manager closes the archive once the
# four arrays have been materialised.
with np.load(compressed_data_path) as processed_data:
    x_train_cv = processed_data["x_train"]
    y_train_cv = processed_data["y_train"]

    x_test_cv = processed_data["x_test"]
    y_test_cv = processed_data["y_test"]

In [None]:
# Number of cross-validation folds requested in the configuration
num_folds = config['num_folds']

# Target device and a fresh Model4 instance
device = get_device()
network = Model4()

# Class-weighted cross-entropy; weights are computed by get_class_weights
# from the cross-validation training labels
class_weights = get_class_weights(y_train_cv, device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Wrap the train / test arrays in the project's TrainTestData containers
train_data = TrainTestData(x_train_cv, y_train_cv)
test_data = TrainTestData(x_test_cv, y_test_cv)


In [None]:
# Build the training dataset object, then display its `enc.classes_` as a plain
# list (presumably the class labels held by a fitted label encoder - confirm).
train_dataset = train_data.get_dataset()
train_dataset.enc.classes_.tolist()

In [None]:
# Configuration options
device = get_device()
# Take the fold count from settings.yaml rather than a hard-coded 10, so the
# configured `num_folds` is actually honoured.
k_folds = config['num_folds']
# Upper bound on epochs per fold; early stopping in the training loop may end a
# fold sooner.
num_epochs = 110

# Class-weighted cross-entropy; weights are computed by get_class_weights from
# the cross-validation training labels.
class_weights = get_class_weights(y_train_cv, device)
loss_function = nn.CrossEntropyLoss(weight=class_weights)

# For fold results
results = {}

# Set fixed random number seed
torch.manual_seed(42)

# Only conversion to tensors is applied to the samples
transform = transforms.Compose([transforms.ToTensor()])

# Create the datasets; only the training part is split into folds below
dataset_train_part = HAVSDataset(x_train_cv, y_train_cv, transform=transform)
dataset_test_part = HAVSDataset(x_test_cv, y_test_cv, transform=transform)
dataset = dataset_train_part

# Define the K-fold cross validator. With shuffle=False the fold membership is
# determined purely by sample order and class labels.
kfold = StratifiedKFold(n_splits=k_folds, shuffle=False)


In [None]:
# Per-fold accuracy (in percent), keyed by fold index
results = {}

# Start print
print('--------------------------------')

# K-fold Cross Validation model evaluation.
# NOTE(review): the held-out fold drives early stopping and is also used for the
# reported accuracy, so the reported figures are likely optimistic - consider a
# separate validation split inside each training fold.
for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset, y_train_cv)):

    # Print
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Sample elements randomly from a given list of ids, no replacement.
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)

    # Define data loaders for training and testing data in this fold; both read
    # from the same dataset and differ only in the indices they sample.
    trainloader = torch.utils.data.DataLoader(
        dataset, batch_size=32, sampler=train_subsampler)
    testloader = torch.utils.data.DataLoader(
        dataset, batch_size=32, sampler=test_subsampler)

    # Init a fresh network for this fold so no weights carry over between folds
    network = Model4()
    network_on_device = network.to(device)  # Move model to the current device
    network_on_device.apply(reset_weights)

    # Initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)

    # Initialize optimizer (Adam with its default hyper-parameters)
    optimizer = torch.optim.Adam(network_on_device.parameters())

    # Run the training loop for at most num_epochs epochs
    for epoch in range(num_epochs):

        # Print epoch
        print(f'Starting epoch {epoch+1}')

        # evaluate_model is a project helper whose mode handling is not visible
        # here; explicitly (re-)enable training mode every epoch in case it left
        # the network in eval mode.
        network_on_device.train()

        # Running loss since the last progress print
        current_loss = 0.0

        # Iterate over the DataLoader for training data. The loop variable is
        # named `batch` (not `data`) so it does not overwrite the notebook-level
        # `data` dictionary used by the earlier training cell.
        for i, batch in enumerate(trainloader):

            # Get inputs
            inputs, targets = batch
            inputs = inputs.to(device, dtype=torch.float)
            targets = targets.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass, loss, backward pass, parameter update
            outputs = network_on_device(inputs)
            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()

            # Print statistics every 500 mini-batches
            current_loss += loss.item()
            if i % 500 == 499:
                print('Loss after mini-batch %5d: %.3f' %
                      (i + 1, current_loss / 500))
                current_loss = 0.0

        # Check early stopping on the held-out fold
        test_loss, _, _, _, _ = evaluate_model(testloader, device, network_on_device, loss_function)

        early_stopping(test_loss, network_on_device)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Process is complete.
    print('Training process has finished. Saving trained model.')

    # Print about testing
    print('Starting testing')

    # Load last checkpoint (best results) - presumably written by EarlyStopping;
    # map_location keeps the load working regardless of the device it was saved from.
    network.load_state_dict(torch.load("checkpoint.pt", map_location=device))
    network_on_device = network.to(device)

    # Saving the model
    save_path = f'./model-fold-{fold}_3.pth'
    torch.save(network_on_device.state_dict(), save_path)

    # Evaluation for this fold: switch to inference mode so that layers such as
    # dropout / batch-norm (if the model has any) behave deterministically.
    network_on_device.eval()
    correct, total = 0, 0
    with torch.no_grad():

        # Iterate over the test data and generate predictions
        for inputs, targets in testloader:
            inputs = inputs.to(device, dtype=torch.float)
            targets = targets.to(device)

            # Generate outputs
            outputs = network_on_device(inputs)

            # Predicted class = index of the largest output along dim 1
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    # Print accuracy
    print('Accuracy for fold %d: %d %%' % (fold, 100.0 * correct / total))
    print('--------------------------------')
    results[fold] = 100.0 * (correct / total)

# Print fold results once, after all folds have finished. The accumulator is
# deliberately not called `sum`: rebinding that name would shadow the builtin
# for every other cell in the notebook.
print(f'K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS')
print('--------------------------------')
accuracy_sum = 0.0
for key, value in results.items():
    print(f'Fold {key}: {value} %')
    accuracy_sum += value
print(f'Average: {accuracy_sum/len(results.items())} %')