In [2]:
import os
import time

import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.optim import SGD

In [3]:
def multiple_train(model, lr, batch_size, n_trials, n_epochs, version='normal', seed=17):
    """Run ``run_train`` ``n_trials`` times and average the per-epoch curves.

    Trial ``i`` is seeded with ``i`` (i.e. seeds ``0 .. n_trials - 1``).

    Args:
        model: Model handed to ``run_train``. When ``version == 'normal'`` this
            argument is ignored and a fresh ``CNN()`` is built for every trial.
        lr: Learning rate forwarded to ``run_train``.
        batch_size: Batch size forwarded to ``run_train``.
        n_trials: Number of training runs to average over.
        n_epochs: Number of epochs per run.
        version: ``'normal'`` re-instantiates ``CNN`` for each trial; any other
            value keeps using the ``model`` argument.
        seed: Currently unused (the trial index is the seed). Kept so existing
            callers that pass it keep working.

    Returns:
        Tuple ``(train_acc, test_acc, tr_losses)`` of 1-D tensors of length
        ``n_epochs``, each averaged over the trials.
    """
    train_acc = torch.zeros(n_trials, n_epochs)
    test_acc = torch.zeros(n_trials, n_epochs)
    tr_losses = torch.zeros(n_trials, n_epochs)

    for trial in range(n_trials):
        print('Trial number : ' + str(trial))
        if version == 'normal':
            model = CNN()
        # NOTE(review): for any other `version` the same `model` object is
        # passed to every trial, so later trials presumably start from the
        # weights left by earlier ones -- confirm this is intended.
        train_acc[trial], test_acc[trial], tr_losses[trial] = run_train(
            model, lr, batch_size, n_epochs, trial)
        print('---')

    # Average over the trial dimension, keeping one value per epoch.
    return train_acc.mean(dim=0), test_acc.mean(dim=0), tr_losses.mean(dim=0)

In [4]:
def run_train(model, lr, batch_size, n_epochs, seed=17):
    """Seed torch, load the data, train ``model`` and print a short report.

    Args:
        model: Model to train; it is trained in place and left in eval mode.
        lr: Learning rate forwarded to ``train``.
        batch_size: Batch size forwarded to ``load_data``.
        n_epochs: Number of epochs forwarded to ``train``.
        seed: Value passed to ``torch.manual_seed`` before the data is loaded.

    Returns:
        Tuple ``(train_accuracy, test_accuracy, tr_losses)`` holding the
        per-epoch curves produced by ``train``.
    """
    # Seed before building the loaders, for reproducibility.
    torch.manual_seed(seed)
    train_loader, test_loader = load_data(batch_size)

    model.train()

    # Time the training loop only.
    t0 = time.time()
    losses, test_curve, train_curve = train(
        model, train_loader, test_loader, n_epochs, lr=lr)

    print('Batch_size %.2f ' % batch_size)
    print('Training time: %.2f s' % (time.time() - t0))

    # Switch to evaluation mode for the final accuracy measurements.
    model.eval()
    final_scores = (
        compute_accuracy(model, train_loader),
        compute_accuracy(model, test_loader),
    )
    print('Train accuracy is %.4f and Test accuracy is %.4f' % final_scores)

    return train_curve, test_curve, losses

In [5]:
def train(model, train_loader, test_loader, n_epochs, lr):
    """Train ``model`` with SGD (momentum 0.9) on a binary cross-entropy loss.

    Args:
        model: Module whose output is fed to ``nn.BCELoss``.
        train_loader: Iterable of ``(inputs, targets, extra)`` batches; the
            third element is not used here.
        test_loader: Loader passed to ``compute_accuracy`` after every epoch.
        n_epochs: Number of passes over ``train_loader``.
        lr: Learning rate of the optimizer.

    Returns:
        Tuple ``(tr_losses, test_accuracy, train_accuracy)`` of 1-D tensors of
        length ``n_epochs``. ``tr_losses[e]`` is the sum of the batch losses of
        epoch ``e``.
    """
    criterion = nn.BCELoss()
    sgd = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    tr_losses, test_accuracy, train_accuracy = (
        torch.zeros(n_epochs) for _ in range(3))

    for epoch in range(n_epochs):
        # Back to training mode at the start of every epoch.
        model.train()
        epoch_loss = 0

        for inputs, targets, _ in train_loader:
            loss = criterion(model(inputs), targets.float())
            epoch_loss += loss.item()

            # One optimisation step.
            sgd.zero_grad()
            loss.backward()
            sgd.step()

        # Record this epoch's summed loss and both accuracies.
        tr_losses[epoch] = epoch_loss
        test_accuracy[epoch] = compute_accuracy(model, test_loader)
        train_accuracy[epoch] = compute_accuracy(model, train_loader)

    return tr_losses, test_accuracy, train_accuracy

In [6]:
def compute_accuracy(model, data_loader):
    """Return the fraction of labels in ``data_loader`` that ``model`` predicts.

    A model output counts as a positive prediction when it exceeds 0.5.
    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Module called on each input batch.
        data_loader: Iterable of ``(X, y, extra)`` batches; the third element
            is ignored. ``y`` must be a tensor with as many elements as the
            model output for that batch.

    Returns:
        Number of correctly predicted labels divided by the number of labels,
        as a Python float.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no labels.
    """
    n_correct = 0.
    n_labels = 0
    model.eval()
    with torch.no_grad():
        for X, y, _ in data_loader:
            # Align the predictions with the labels' shape: comparing an
            # (N, 1) output with (N,) labels would otherwise broadcast to an
            # (N, N) matrix and inflate the number of "correct" entries.
            predicted = (model(X) > 0.5).reshape(y.shape)
            n_correct += (predicted == y).float().sum().item()
            # Count elements, matching the element-wise sum above.
            n_labels += y.numel()
    return n_correct / n_labels

In [7]:
def plot(net, tr_losses, n_epochs):
    """Plot the per-epoch training loss of ``net`` and save it under ``image/``.

    Args:
        net: Model; only ``net._get_name()`` is used, for the title and the
            file name.
        tr_losses: Either a 2-D tensor indexed ``[trial, epoch]``, which is
            averaged over the trials here, or a 1-D tensor of per-epoch losses
            (such as the already averaged losses returned by
            ``multiple_train``), which is plotted as is.
        n_epochs: Number of epochs to plot, starting from the first one.

    Side effects:
        Creates the ``image`` directory if it is missing, writes the figure to
        ``image/<model name>`` and prints a confirmation message.
    """
    suffix = ''

    epochs = torch.linspace(1, n_epochs, n_epochs)

    losses = torch.as_tensor(tr_losses)
    if losses.dim() > 1:
        # Collapse the trial dimension, leaving one mean loss per epoch.
        losses = losses.mean(dim=0)
    mean_losses = losses[:n_epochs]

    fname = net._get_name() + suffix

    fig, ax = plt.subplots(figsize=(6, 2))
    ax.plot(epochs.tolist(), mean_losses.tolist(), 'k--')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Training loss')
    ax.grid()
    ax.set_title(f'Model: {fname}')

    # savefig does not create missing directories.
    os.makedirs('image', exist_ok=True)
    fig.savefig('image/' + fname)
    print(f'Training loss plot saved under {fname}.png \n \n ')
    print('################################################################')