# Few-Shot Learning - Practical 2
___
In this notebook we will walk through the basics of training a few-shot model. We will define our own training code, take a look at the definition of one possible model architecture, and explore a few standard hyper-parameters. 

Since we are low on compute resources and want to be able to see _something_ train, we will use a toy problem that is largely solved at this point: OmniGlot. OmniGlot is essentially the MNIST of few-shot learning: a collection of handwritten characters from 50 different alphabets. The key distinction from MNIST is that we have only 20 examples per class but over 1,600 classes (1,623 characters in total). This lets us sample many subproblems from the dataset to form our episodes.
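
To make "sampling subproblems" concrete, here is a minimal sketch of N-way K-shot episode sampling in plain Python. The dictionary-of-lists dataset format and the function name `sample_episode` are hypothetical illustrations, not the interface of `utils.practical_2.get_omniglot`. Each episode picks N classes, draws K support and Q query examples per class, and relabels the chosen classes 0..N-1 so the model has to learn from the support set rather than memorize global class identities.

```python
import random


def sample_episode(dataset, n_way, k_shot, n_query, rng=None):
    """Sample one N-way K-shot episode.

    `dataset` is assumed to be a dict mapping a class name to a list of
    examples (a hypothetical format, not the course dataloader's).
    Returns (support, query) lists of (example, episode_label) pairs.
    """
    if rng is None:
        rng = random.Random()
    # Pick N distinct classes for this episode.
    classes = rng.sample(sorted(dataset), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        # Draw K + Q distinct examples of this class. Labels are re-indexed
        # to 0..N-1, so the model never sees the global class identity.
        examples = rng.sample(dataset[cls], k_shot + n_query)
        support += [(x, episode_label) for x in examples[:k_shot]]
        query += [(x, episode_label) for x in examples[k_shot:]]
    return support, query


# Toy stand-in with OmniGlot's shape: 1,623 classes x 20 examples each.
toy_dataset = {f"char_{c}": [f"char_{c}/img_{i}" for i in range(20)] for c in range(1623)}
support, query = sample_episode(toy_dataset, n_way=5, k_shot=1, n_query=15, rng=random.Random(0))
print(len(support), len(query))  # 5 75
```

With 1,623 classes, the number of distinct 5-class subsets alone is enormous, which is why a dataset with few examples per class can still supply a practically endless stream of training episodes.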

### Part 0: Imports
We start with what will become our traditional imports: NumPy, PyTorch, and our course utilities package.

In [None]:
import utils # Special package prepared for this course
import torch # Core auto-differentiation library 
import numpy # Standard Python linear algebra library

### Part 1: Models
We will use a standard 4-layer CNN architecture as our encoder and a simplified relation network (`ARelationalNet`) as our few-shot model. Both are defined in the utilities package under the submodule `practical_2`. Once you have gone through the training code, it can be instructive to copy them into this notebook, modify their definitions, and see how the changes affect model performance.

In [None]:
encoder = utils.practical_2.SimpleCNN(device='cpu')                         # Note that for this exercise everything will run on CPU
model = utils.practical_2.ARelationalNet(encoder, fc_dim=16, device='cpu')  # The device arguments show one paradigm for managing device placement
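
As a rough illustration of what a relation network does with an episode, the sketch below pools the support embeddings of each class, concatenates each class feature with a query embedding, and scores the pair with a relation function; the query is assigned to the highest-scoring class. This is a plain-Python sketch under stated assumptions, not the code of `ARelationalNet`: the embeddings are given directly instead of coming from the CNN encoder, support embeddings are averaged (the original Relation Network paper sums them), and `toy_relation` is a hypothetical hand-written stand-in for the relation module, which in a real model is a small network trained end to end.

```python
def pool_support(support_embeddings, support_labels, n_way):
    """Average the K support embeddings of each class into one class feature."""
    dim = len(support_embeddings[0])
    sums = [[0.0] * dim for _ in range(n_way)]
    counts = [0] * n_way
    for emb, label in zip(support_embeddings, support_labels):
        counts[label] += 1
        for j, value in enumerate(emb):
            sums[label][j] += value
    return [[value / counts[c] for value in sums[c]] for c in range(n_way)]


def toy_relation(pair):
    """Hypothetical stand-in for the learned relation module: maps the
    concatenated [class_feature, query] vector to a score in (0, 1]."""
    half = len(pair) // 2
    class_feat, query = pair[:half], pair[half:]
    sq_dist = sum((a - b) ** 2 for a, b in zip(class_feat, query))
    return 1.0 / (1.0 + sq_dist)


def classify(query_embedding, class_feats, relation=toy_relation):
    """Score the query against every class feature and pick the best class."""
    # List concatenation plays the role of feature concatenation.
    scores = [relation(cf + query_embedding) for cf in class_feats]
    prediction = max(range(len(scores)), key=scores.__getitem__)
    return prediction, scores


# A 2-way 2-shot episode with 2-D embeddings.
support_embeddings = [[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [4.0, 6.0]]
support_labels = [0, 0, 1, 1]
class_feats = pool_support(support_embeddings, support_labels, n_way=2)
pred, scores = classify([1.0, 1.0], class_feats)
print(pred, [round(s, 3) for s in scores])  # 0 [0.5, 0.038]
```

Because `toy_relation` only looks at distance, it behaves like a nearest-class-mean classifier; the point of a learned relation module is that it can learn a better comparison than a fixed metric.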

### Part 2: Training loop definition
In order to keep everything organized, we will define two functions to help with our experimental setup. The first, `train`, executes a single training iteration on the OmniGlot training data: it samples `num_episodes` episodes and takes one optimization step per episode. Commentary on the individual components is included as comments in the code below.

In [None]:
def train(model, lr=1e-5, num_episodes=10, optimizer=None):

    loss_list = []
    acc_list = []

    # We will bypass some of the complexity of loading and preprocessing
    # data by making use of a pre-built dataloader object for the OmniGlot
    # dataset. Note that building these dataloaders and data parsing tools
    # is often one of the most time consuming parts of a DL project and
    # has enough complexity to justify its own course.
    dataloader = utils.practical_2.get_omniglot(train=True, num_episodes=num_episodes)

    # As training loops are developed you will often be asked to try out 
    # new ideas. This is of course something we need to plan for and adapt
    # to as the requirements and technologies available shift. Because of
    # this it's always a good idea to write your code in a modular way and
    # be able to alter function behavior without changing the default behavior.
    # 
    # The below instantiation of our optimizer is a good example of this,
    # by default our optimizer is Adam. This is a reasonable choice to
    # start off with, but we might find that it's a poor choice by the end
    # of the project. Instead of changing the line of code in our
    # experiment scripts we want to write our scripts such that if we
    # re-run them exactly as before we get the old behavior but if we run
    # them with new keyword arguments then we get the desired adjusted
    # behavior, in this case maybe a better optimizer.
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        
    # Place the model into train mode so that layers such as dropout and
    # batch normalization use their training-time behavior.
    model.train()

    # Here we have the primary for loop for the training process. We will
    # be utilizing our pre-built dataloader to rapidly sample episodes from
    # the OmniGlot dataset and run them through the model. We will then
    # compute a loss value to measure our performance and compute the 
    # model gradients via backpropagation. Our final step (hehe, step and
    # gradient step) is to take a single optimization step using our 
    # predefined optimization algorithm. 
    #
    # This set of steps is a very standard algorithm configuration and is
    # at the heart of nearly all deep learning training routines.
    for episode, labels in dataloader:
        
        # Transfer labels to the correct device
        labels = labels.to(model.device)
        
        # -----------------------------------------------------------------
        # This might be an interesting place to insert some print statements
        # and take a look at the current form of the episode.
        # -----------------------------------------------------------------

        # Run the episode through the model
        logits = model(episode)

        # Compute loss
        loss = torch.nn.functional.cross_entropy(logits, labels)

        # Compute accuracy
        _, predictions = logits.max(1)
        total = labels.size(0)
        correct = (predictions == labels.long()).sum().item()
        acc = correct / total

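        # Clear the gradients from the previous episode. PyTorch accumulates
        # gradients in each parameter's .grad by default, so skipping this
        # would mix the gradients of all earlier episodes into every step.
        optimizer.zero_grad()

        # Compute the gradients via backpropagation
        loss.backward()

        # Update the model weights via the chosen optimizer
        optimizer.step()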
        
        loss_list.append(loss.item())
        acc_list.append(acc)
  
    loss = sum(loss_list) / len(loss_list)
    acc = sum(acc_list) / len(acc_list)
    
    return loss, acc
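
The structure of `train` (a default optimizer unless one is passed in, then forward pass, loss, backward pass and optimizer step for every episode) can be sketched without PyTorch. Below, `SoftmaxClassifier` and `SGD` are hypothetical pure-Python stand-ins for an `nn.Module` and a `torch.optim` optimizer, and the backward pass is written out by hand using the softmax cross-entropy gradient. Note the explicit `zero_grad()` at the start of each episode: like PyTorch, this toy accumulates gradients into `model.grad`, so without the reset each step would also include the gradients of every earlier episode.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class SoftmaxClassifier:
    """Hypothetical stand-in for an nn.Module: one weight row per class."""

    def __init__(self, num_features, num_classes):
        self.w = [[0.0] * num_features for _ in range(num_classes)]
        self.grad = [[0.0] * num_features for _ in range(num_classes)]

    def __call__(self, x):
        return [sum(w_j * x_j for w_j, x_j in zip(row, x)) for row in self.w]


class SGD:
    """Hypothetical stand-in for torch.optim.SGD (no momentum)."""

    def __init__(self, model, lr):
        self.model, self.lr = model, lr

    def zero_grad(self):
        for row in self.model.grad:
            row[:] = [0.0] * len(row)

    def step(self):
        for w_row, g_row in zip(self.model.w, self.model.grad):
            for j, g in enumerate(g_row):
                w_row[j] -= self.lr * g


def train(model, episodes, lr=0.5, optimizer=None):
    # Same pattern as the notebook: build a default optimizer unless one is passed in.
    if optimizer is None:
        optimizer = SGD(model, lr)
    loss_list, acc_list = [], []
    for xs, labels in episodes:
        optimizer.zero_grad()  # otherwise gradients accumulate across episodes
        loss, correct, n = 0.0, 0, len(xs)
        for x, y in zip(xs, labels):
            probs = softmax(model(x))
            loss += -math.log(probs[y]) / n  # mean cross-entropy over the episode
            correct += int(max(range(len(probs)), key=probs.__getitem__) == y)
            # Backward pass: d(loss)/d(logit_c) = (p_c - [c == y]) / n
            for c, p_c in enumerate(probs):
                coeff = (p_c - (1.0 if c == y else 0.0)) / n
                for j, x_j in enumerate(x):
                    model.grad[c][j] += coeff * x_j
        optimizer.step()
        loss_list.append(loss)
        acc_list.append(correct / n)
    return sum(loss_list) / len(loss_list), sum(acc_list) / len(acc_list)


# One linearly separable toy "episode" (the second feature is a constant bias).
episode = ([[-2.0, 1.0], [-1.0, 1.0], [1.0, 1.0], [2.0, 1.0]], [0, 0, 1, 1])
model = SoftmaxClassifier(num_features=2, num_classes=2)
optimizer = SGD(model, lr=0.5)  # created once and reused across calls
history = []
for iteration in range(2):
    loss, acc = train(model, [episode] * 5, optimizer=optimizer)
    history.append((loss, acc))
    print(f"Iteration {iteration}, training loss: {loss:.4f}, training acc: {acc:.4f}")
```

Passing one optimizer into repeated `train` calls keeps any optimizer state across calls. Plain SGD has no state, but Adam does (its running moment estimates), and those would be reset every time `train` builds a fresh optimizer.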

### Part 4: Evaluation loop definition

In [None]:
def evaluate(model, num_episodes=10, large_images=False):    
    # First thing we do is place the model in eval mode. Note that this
    # does not affect gradient computation; rather, it switches modules
    # such as dropout and batch normalization to their inference behavior
    # (dropout is disabled, and batch norm uses its stored running
    # statistics instead of updating them), so we don't carry information
    # from one episode to the next. For a more complete description see
    # this Stack Overflow post:
    # https://stackoverflow.com/questions/55627780/evaluating-pytorch-models-with-torch-no-grad-vs-model-eval
    model.eval()

    loss_list = []
    acc_list = []

    dataloader = utils.practical_2.get_omniglot(
        train=False,
        num_episodes=num_episodes,
        large_images=large_images
    )

    # Here we have a very similar for-loop to our training loop. The big
    # difference is the absence of the backpropagation and optimization
    # steps. We also have the added torch.no_grad() context manager, which
    # stops PyTorch from recording the computation graph. This saves
    # memory and computation, since we never call backward() here.
    for episode, labels in dataloader:
        
        # Transfer labels to the correct device
        labels = labels.to(model.device)

        # Run the episode through model without saving the computation graph
        with torch.no_grad():
            logits = model(episode)

        # Compute loss
        loss = torch.nn.functional.cross_entropy(logits, labels)

        # Compute accuracy
        _, predictions = logits.max(1)
        total = labels.size(0)
        correct = (predictions == labels.long()).sum().item()
        acc = correct / total

        # Save computed loss and accuracy
        loss_list.append(loss.item())
        acc_list.append(acc)

    loss = sum(loss_list) / len(loss_list)
    acc = sum(acc_list) / len(acc_list)

    return loss, acc
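
To see what "carrying information from one episode to the next" means, consider batch normalization, which keeps running statistics. The toy below is a hypothetical, mean-only analogue (real `BatchNorm` layers also track variance and have learned scale and shift parameters). In training mode it centers each batch with the batch mean and folds that mean into `running_mean` with the same momentum update PyTorch uses (`momentum=0.1` by default); in eval mode it centers with the stored estimate and leaves it untouched.

```python
class RunningMeanCenter:
    """Toy, mean-only analogue of batch norm's running statistics
    (hypothetical; not a PyTorch class)."""

    def __init__(self, momentum=0.1):
        self.momentum = momentum
        self.running_mean = 0.0
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def __call__(self, batch):
        if self.training:
            # Training mode: center with this batch's mean and fold it into
            # the running estimate, which persists into later batches.
            center = sum(batch) / len(batch)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * center
        else:
            # Eval mode: use the stored estimate and leave it unchanged, so
            # one evaluation episode cannot influence the next.
            center = self.running_mean
        return [x - center for x in batch]


layer = RunningMeanCenter()
print(layer([1.0, 3.0]))    # training mode: [-1.0, 1.0]; running_mean becomes 0.2
layer.eval()
print(layer([10.0, 10.0]))  # eval mode: centered on the stored 0.2
print(layer.running_mean)   # still 0.2
```

`torch.no_grad()` addresses a different concern (not building the autograd graph), which is why `evaluate` uses both.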

In [None]:
## WARNING: Executing this cell can take a very long time!

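num_train_episodes = 100
num_val_episodes = 50
lr = 1e-5

# Create the optimizer once so Adam's running moment estimates persist
# across iterations instead of being reset on every call to train()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for iteration in range(10):

    loss, acc = train(model, num_episodes=num_train_episodes, optimizer=optimizer)

    print(f'Iteration {iteration}, training loss: {loss:.4f}, training acc: {acc:.4f}')

    loss, acc = evaluate(model, num_episodes=num_val_episodes)

    print(f'Iteration {iteration}, validation loss: {loss:.4f}, validation acc: {acc:.4f}')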


### Part 5: Evaluation on a previously trained model
Since we have limited compute, let's also load some weights that were previously trained on another dataset. The architecture is very similar, but it was trained on a dataset very different from OmniGlot. It was also trained across many GPUs using a variety of best practices that we didn't implement in our example training loop.

In [None]:
import nshot

model = nshot.load_from_file_path(
    '/home/shared/weights/metadataset.nsm',
    device='cpu'
)

evaluate(model, num_episodes=25, large_images=True)

So, without ever having seen OmniGlot, this model gets 73% accuracy on brand-new character recognition tasks.
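
A single mean over 25 episodes is a noisy estimate, which is why few-shot results are commonly reported as a mean with a 95% confidence interval over episodes. The sketch below assumes you have the per-episode accuracies, for example by having `evaluate` also return its `acc_list` (a hypothetical change). The accuracies used here are made up for illustration and are not results from this model.

```python
import math
import statistics


def mean_and_ci95(per_episode_acc):
    """Mean accuracy and the half-width of a normal-approximation 95%
    confidence interval over independently sampled episodes."""
    mean = statistics.mean(per_episode_acc)
    half_width = 1.96 * statistics.stdev(per_episode_acc) / math.sqrt(len(per_episode_acc))
    return mean, half_width


# Hypothetical per-episode accuracies, for illustration only.
accs = [0.6, 0.8, 0.7, 0.9, 0.5]
mean, half_width = mean_and_ci95(accs)
print(f"{mean:.3f} +/- {half_width:.3f}")  # 0.700 +/- 0.139
```

The half-width shrinks with the square root of the number of episodes, so quadrupling `num_episodes` roughly halves the uncertainty of the reported accuracy.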