In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import numpy as np
import sys
from torchsummary import summary

from channel_wide_res_net import Channel_Wide_ResNet


def getData(batch_size, test_batch_size, val_percentage, data_root='/home/test/data'):
    '''
    Load the EMNIST 'balanced' split and return the data-loaders.

    Args:
        batch_size: batch size of the training and validation loaders.
        test_batch_size: batch size of the test loader.
        val_percentage: fraction (0..1) of the training split held out for
            validation.
        data_root: directory where the dataset is stored / downloaded.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: if val_percentage is outside [0, 1].
    '''
    if not 0.0 <= val_percentage <= 1.0:
        raise ValueError('val_percentage must be in [0, 1], got {}'.format(val_percentage))

    normalize = transforms.Normalize((0.5,), (0.5,))

    # Training pipeline: resize + augmentation + normalization.
    # NOTE(review): horizontal flips can change the identity of a character;
    # confirm this augmentation is intended for EMNIST.
    transform_train = transforms.Compose([
        transforms.Resize(32),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation pipeline: same resize/normalization, no augmentation.
    transform_test = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        normalize,
    ])

    # Download/Load data. The training split is opened twice so that the
    # validation subset can be read without the random augmentation.
    augmented_train = torchvision.datasets.EMNIST(
        data_root, split='balanced', train=True, transform=transform_train, download=True)
    plain_train = torchvision.datasets.EMNIST(
        data_root, split='balanced', train=True, transform=transform_test, download=True)
    # train=False selects the held-out test split; without it the "test" set
    # would silently be the training split again.
    test_data = torchvision.datasets.EMNIST(
        data_root, split='balanced', train=False, transform=transform_test, download=True)

    # Create train and validation splits from one random permutation so the
    # two subsets are disjoint. The min() keeps the "+1" rounding from asking
    # for more samples than exist (e.g. val_percentage == 0).
    num_samples = len(augmented_train)
    training_samples = min(num_samples, int((1 - val_percentage) * num_samples + 1))
    shuffled = torch.randperm(num_samples).tolist()
    training_data = torch.utils.data.Subset(augmented_train, shuffled[:training_samples])
    validation_data = torch.utils.data.Subset(plain_train, shuffled[training_samples:])

    # Initialize dataloaders
    train_loader = torch.utils.data.DataLoader(
        training_data, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(
        validation_data, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=2)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=test_batch_size, shuffle=False, drop_last=False, num_workers=2)

    return train_loader, val_loader, test_loader

def test(data_loader,net,cost_fun,device):
    '''
    Evaluate the network on a data-loader.

    Only the first element of the model output (net(inputs)[0]) is scored;
    predictions are the argmax along dimension 1.

    Args:
        data_loader: iterable yielding (inputs, targets) batches.
        net: model to evaluate; it is switched to eval mode.
        cost_fun: callable computing a scalar loss from (outputs, targets).
        device: device the batches are moved to.

    Returns:
        (loss, accuracy): the sum of the per-batch cost_fun values divided by
        the number of samples, and the accuracy in percent.
        NOTE(review): with a mean-reduced criterion this loss is the average
        batch loss divided by the batch size, not a true per-sample loss.
    '''
    net.eval()
    samples = 0.
    cumulative_loss = 0.
    cumulative_accuracy = 0.

    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for inputs, targets in data_loader:

            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = net(inputs)[0]
            loss = cost_fun(outputs,targets)

            # Metrics computation
            samples+=inputs.shape[0]
            cumulative_loss += loss.item()
            _, predicted = outputs.max(1)
            cumulative_accuracy += predicted.eq(targets).sum().item()

    return cumulative_loss/samples, cumulative_accuracy/samples*100


# Despite its name this is an evaluation helper, not a test case: keep
# pytest-style collectors from picking it up.
test.__test__ = False

def train(data_loader,net,cost_fun,device,optimizer):
    '''
    Train the network on the data for one epoch.

    Each batch is moved to `device`, the first element of the model output is
    fed to `cost_fun`, and one optimizer step is taken per batch.

    Returns:
        (loss, accuracy): the accumulated per-batch loss values divided by the
        number of samples seen, and the accuracy in percent.
    '''
    net.train()

    seen = 0.
    loss_total = 0.
    correct_total = 0.

    for batch_inputs, batch_targets in data_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # The model returns a sequence; only its first element is scored.
        logits = net(batch_inputs)[0]
        batch_loss = cost_fun(logits, batch_targets)

        # Back-propagate, update the weights, then clear the gradients so
        # the next batch starts from a clean slate.
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Running metrics
        seen += batch_inputs.shape[0]
        loss_total += batch_loss.item()
        predicted = torch.max(logits, 1)[1]
        correct_total += (predicted == batch_targets).sum().item()

    return loss_total/seen, correct_total/seen*100

def main(epochs, batch_size, test_batch_size,val_percentage,lr,test_freq, net_depth, net_width):
    '''
    Train a Channel_Wide_ResNet on EMNIST 'balanced', report the metrics and
    save the weights.

    The target device is read from the module-level `device` variable.

    Args:
        epochs: number of training epochs.
        batch_size: batch size for training and validation.
        test_batch_size: batch size for the test loader.
        val_percentage: fraction of the training split used for validation.
        lr: initial learning rate of the SGD optimizer.
        test_freq: evaluate on the test set and checkpoint every this many epochs.
        net_depth: depth argument passed to the network.
        net_width: width argument passed to the network.
    '''
    # Define cost function
    cost_function = torch.nn.CrossEntropyLoss()

    # The EMNIST 'balanced' split has 47 classes; building the network with
    # fewer outputs makes the targets fall out of range for the loss.
    # NOTE(review): the last constructor argument is taken to be num_classes,
    # as the signature comment below states -- confirm against
    # channel_wide_res_net.
    num_classes = 47

    # Create the network: Wide_ResNet(depth, width, dropout, num_classes)
    net = Channel_Wide_ResNet(1,net_depth,net_width,0,num_classes)
    net = net.to(device)
    #summary(net, input_size=(1, 28, 28))

    # Create the optimizer and the learning rate scheduler
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 
                    milestones=[int(epochs*0.3),int(epochs*0.6),int(epochs*0.8)], gamma=0.20)

    # Get the data
    train_loader, val_loader, test_loader = getData(batch_size,test_batch_size,val_percentage)
    
    save_filename = './EMNIST-teacher-' + str(net_depth) + '-' +  str(net_width) + '.pth'

    for e in range(epochs):
        # train() puts the network in training mode, test() in eval mode.
        train_loss, train_accuracy = train(train_loader,net,cost_function,device,optimizer)

        val_loss, val_accuracy = test(val_loader,net,cost_function,device)
        
        scheduler.step()

        print('Epoch: {:d}:'.format(e+1))
        print('\t Training loss: \t {:.6f}, \t Training accuracy \t {:.2f}'.format(train_loss, train_accuracy))
        print('\t Validation loss: \t {:.6f},\t Validation accuracy \t {:.2f}'.format(val_loss, val_accuracy))
        
        # Periodic test evaluation and checkpoint
        if((e+1) % test_freq) == 0:
            test_loss, test_accuracy = test(test_loader,net,cost_function,device)
            torch.save(net.state_dict(), save_filename)
            print('Test loss: \t {:.6f}, \t \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))

    print('After training:')
    train_loss, train_accuracy = test(train_loader,net,cost_function,device)
    val_loss, val_accuracy = test(val_loader,net,cost_function,device)
    test_loss, test_accuracy = test(test_loader,net,cost_function,device)

    print('\t Training loss: \t {:.6f}, \t Training accuracy \t {:.2f}'.format(train_loss, train_accuracy))
    print('\t Validation loss: \t {:.6f},\t Validation accuracy \t {:.2f}'.format(val_loss, val_accuracy))
    print('Test loss: \t {:.6f}, \t \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))
    
    torch.save(net.state_dict(), save_filename)

    # Sanity check of the checkpoint: rebuild a fresh network and load the
    # saved weights into it. The fresh instance itself must be moved to the
    # device; reusing `net` here would never exercise the saved file.
    net2 = Channel_Wide_ResNet(1,net_depth,net_width,0,num_classes)
    net2 = net2.to(device)
    net2.load_state_dict(torch.load(save_filename, map_location=device))
    
    print('loaded net test:')
    test_loss, test_accuracy = test(test_loader,net2,cost_function,device)
    print('\t Test loss: \t {:.6f}, \t Test accuracy \t {:.2f}'.format(test_loss, test_accuracy))
    

# Parameters
epochs = 2
batch_size = 128
test_batch_size = 128
val_percentage = 0.05
lr = 0.1
test_freq = 1
# main() reads this module-level name. Fall back to the CPU so the script
# still runs on machines without a CUDA device.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
net_depth = 40
net_width = 1
    
main(epochs, batch_size, test_batch_size,val_percentage,lr,test_freq, net_depth, net_width)

| Wide-Resnet 40x1
Downloading and extracting zip archive
Downloading https://cloudstor.aarnet.edu.au/plus/index.php/s/54h3OuGJhFLwAlQ/download to /home/test/data/EMNIST/raw/emnist.zip
Failed download. Trying https -> http instead. Downloading http://cloudstor.aarnet.edu.au/plus/index.php/s/54h3OuGJhFLwAlQ/download to /home/test/data/EMNIST/raw/emnist.zip


HTTPError: HTTP Error 503: Service Unavailable