# Few-Shot Learning - Practical 3
___________________________________________________

In this notebook, we'll go through the basics of applying a few-shot learning model to a text dataset. We'll use the NShot package to load a default text encoder, BERT, and apply it to a dataset of Amazon product reviews. We'll walk through a short training run to see how training works. Because compute is limited, we can't fully train the model on a large dataset here, so we'll then load an existing model and see how it performs.


In [None]:
import nshot as ns

import pandas as pd
import json
from collections import Counter

import matplotlib
%matplotlib inline

In [None]:
encoder_device = 'cpu'
model_device   = 'cpu'

# Specify paths
data_dir = '/home/shared/amazon_reviews/tutorial/'
train_path = data_dir + 'train/'
val_path   = data_dir + 'val/'
test_path  = data_dir + 'test/'

## Part 1. The Dataset
___


For this training example, we will use the Amazon product review dataset. You can find more information about this dataset here: https://jmcauley.ucsd.edu/data/amazon/

In this dataset, product reviews are given for a number of different product categories, from which we extract the raw text. The task is to determine which product category each review was written about.


In [None]:
# Since we are planning on training, we load the training, validation, and test data

train_data      = ns.TextData(train_path, max_classes=50, max_per_class=100, min_per_class=10)
validation_data = ns.TextData(val_path, max_classes=20, max_per_class=100, min_per_class=10)
test_data       = ns.TextData(test_path, max_classes=20, max_per_class=100, min_per_class=10)

Before we start the training loop, let's take a look at some of the numbers related to the datasets we just loaded. First we print the number of datapoints and the number of unique classes in each dataset.

In [None]:
dataset_info = pd.DataFrame(data={
    'Training':[
        len(train_data), 
        len(set(train_data.labels))
    ],
    'Validation':[
        len(validation_data),
        len(set(validation_data.labels))
    ],
    'Test':[
        len(test_data),
        len(set(test_data.labels))
    ]
}).rename(index={0:'Datapoints',1:'Classes'})

dataset_info

### Class structure

Product categories on Amazon are hierarchical. A higher-level category, e.g. "Arts, Crafts & Sewing" may have subcategories, e.g. "Craft Supplies". All of the few-shot categories are based on the lowest tier of the hierarchy. We keep track of the structure of the hierarchy by adding '+' between tiers.
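To make the '+' convention concrete, here is a minimal sketch of splitting a hierarchical label back into its tiers. The helper `split_tiers` and the example label are hypothetical, built only from the category names mentioned above; they are not part of NShot.

```python
def split_tiers(label, sep="+"):
    """Split a hierarchical label into its tiers, top level first."""
    return [tier.strip() for tier in label.split(sep)]


# Hypothetical label following the '+' convention described above.
example_label = "Arts, Crafts & Sewing+Craft Supplies"

tiers = split_tiers(example_label)
top_level, leaf = tiers[0], tiers[-1]

print("Top-level category:", top_level)
print("Few-shot class (lowest tier):", leaf)
```

Splitting on '+' rather than on commas matters here, because category names such as "Arts, Crafts & Sewing" contain commas themselves.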

In [None]:
train_labels = sorted(set(train_data.labels))
val_labels = sorted(set(validation_data.labels))
test_labels = sorted(set(test_data.labels))

class_info = pd.DataFrame(data={'Training': train_labels,
                               'Validation': val_labels + ['' for _ in range(len(train_labels)-len(val_labels))],
                               'Test': test_labels + ['' for _ in range(len(train_labels)-len(test_labels))]
                               })

class_info

Remember that since we are working in the few-shot paradigm, the classes in the training, validation, and test sets are entirely disjoint.
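Disjointness is easy to verify with set operations. The sketch below uses a hypothetical helper, `shared_classes`, and toy label lists so that it runs on its own; in this notebook you could pass it the `train_labels`, `val_labels` and `test_labels` lists built above instead.

```python
def shared_classes(*label_lists):
    """Return every label that appears in more than one of the given splits."""
    seen, shared = set(), set()
    for labels in label_lists:
        unique = set(labels)
        shared |= seen & unique   # labels already seen in an earlier split
        seen |= unique
    return shared


# Toy, made-up labels standing in for the real splits.
toy_train = ["Books+Fiction", "Electronics+Headphones"]
toy_val = ["Toys+Puzzles"]
toy_test = ["Garden+Planters"]

overlap = shared_classes(toy_train, toy_val, toy_test)
print("Overlapping classes:", sorted(overlap))
```

An empty result means no class leaks between splits, which is what a few-shot evaluation relies on.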

The last thing to look at in the dataset is the distribution of datapoints across classes. This is important to check, because failing to properly handle imbalanced data can result in poor model performance on real-world data. For this dataset, every class has exactly 101 datapoints. We can easily check this as shown below.

In [None]:
# This gives an object counting each unique value in the given list 
counter = Counter(train_data.labels)

# counter.most_common() sorts our labels by which are the most frequent. 
for label, cnt in counter.most_common():
    print(str(cnt) + ' \t ' + label)
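For datasets with many classes, scanning a printed list gets tedious. A small summary of the class counts can flag imbalance at a glance. This is a sketch with a hypothetical helper, `balance_summary`, run on toy labels; it could be applied to `train_data.labels` in the same way.

```python
from collections import Counter


def balance_summary(labels):
    """Summarise how evenly datapoints are spread across classes."""
    counts = Counter(labels)
    smallest = min(counts.values())
    largest = max(counts.values())
    return {
        "classes": len(counts),
        "min": smallest,
        "max": largest,
        # 1.0 means perfectly balanced; larger values mean more imbalance.
        "imbalance_ratio": largest / smallest,
    }


# Toy labels: two classes with 4 datapoints each, one class with only 2.
toy_labels = ["a"] * 4 + ["b"] * 4 + ["c"] * 2

summary = balance_summary(toy_labels)
print(summary)
```

A ratio of exactly 1.0 confirms that every class has the same number of datapoints.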

### Individual data points

NShot's dataset objects arbitrarily order datapoints into a list, so we can look at individual data points by choosing an index. Indexing returns two objects: the first is the raw text and the second is the label.

In [None]:
idx = 0

datapoint = train_data[idx]
text, label = datapoint

print(label.upper() + ": " + text)


## Part 2. The Model

For this notebook, we'll be leveraging BERT as an encoder. NShot provides us with a wrapper around a basic BERT model. By default, this uses the 'bert-base-uncased' model.

In [None]:
encoder = ns.BertEncoder(sequence_length=64, device=encoder_device)
model   = ns.RCNet(encoder, fc_dim=64, n_blocks=2, device=model_device)

#### The Training Run

Note: Training text models requires a heavy preprocessing step. This can take 1-30 minutes depending on the CPU power available to the notebook.

In practice, a CPU is not going to be enough to train this model very well. We may expect to hit ~48% accuracy on 5-shot, 5-way episodes, but only if we train long enough.
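To make "5-shot, 5-way episode" concrete: each episode draws 5 classes, with 5 labelled support examples per class, and asks the model to classify held-out query examples from those same classes. The sketch below is a generic episode sampler in plain Python, not NShot's implementation. The function `sample_episode`, its parameters, and the reading of `query_size` as "queries per class" are all assumptions made for illustration.

```python
import random
from collections import defaultdict


def sample_episode(texts, labels, n_way=5, k_shot=5, query_size=8, seed=None):
    """Sample one N-way, K-shot episode as (support, query) lists of (text, label)."""
    rng = random.Random(seed)

    # Group texts by class.
    by_class = defaultdict(list)
    for text, label in zip(texts, labels):
        by_class[label].append(text)

    # Only classes with enough examples for both support and query can be used.
    needed = k_shot + query_size
    eligible = [c for c, items in by_class.items() if len(items) >= needed]
    if len(eligible) < n_way:
        raise ValueError("not enough classes with sufficient examples")

    support, query = [], []
    for c in rng.sample(eligible, n_way):
        picked = rng.sample(by_class[c], needed)   # sampled without replacement
        support += [(t, c) for t in picked[:k_shot]]
        query += [(t, c) for t in picked[k_shot:]]
    return support, query


# Toy data: 6 made-up classes with 20 placeholder reviews each.
toy_labels = [f"class_{i}" for i in range(6) for _ in range(20)]
toy_texts = [f"review {j}" for j in range(len(toy_labels))]

support, query = sample_episode(toy_texts, toy_labels, seed=0)
print(len(support), "support examples,", len(query), "query examples")
print("Chance accuracy for 5-way:", 1 / 5)
```

Random guessing on a 5-way episode gives 20% accuracy, which is the baseline against which the ~48% figure above should be read.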

In [None]:
output = ns.train_model(
    model, 
    train_data, 
    validation_data,
    freeze_encoder_weights=True, 
    max_iterations=1,
    train_with_negatives=False,
    distractors=False,
    query_size=8,
    logging_period=25, 
    episodes_per_iteration=25,
    log_dir='logs/text_logs'
)

#### Investigate the Results
As training progresses you will see several directories and files created. The log directory should have a structure something like this:

```
logs/text_logs/
  |  logs.json
  |  parameters.json
  |  plots.png
  \--weights
       |  best_weights.pt
       |  iteration_#_weights.pt
       |         :
       |         :
       \  iteration_#_weights.pt
```

The file 'parameters.json' contains some of the important training parameters. Near the end of the file, it also lists 'Validation accuracy' which is the highest validation performance reached in training. This updates at the end of every logging period and is one way to get an idea of what is going on during training.

Once the training run is complete, we can use standard JSON tools to do deeper analysis of the machine-readable logs.json file, which contains information from throughout the training run.

Note: This is the most basic format for handling logging. In practice you might rely on more sophisticated experiment management tools such as MLflow or Weights & Biases.

In [None]:
# Load the training parameters (including the best validation accuracy so far)
with open('logs/text_logs/parameters.json') as f:
    parameters = json.load(f)

parameters

The plots.png file saves loss and accuracy curves during training. The blue line represents performance on the training data, while the orange line represents validation data.

A challenge for few-shot models is that they often require low learning rates, so we may not see a lot of movement for our short training run.

In [None]:
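# We can also look at the training curves.
# Import the submodules explicitly: `import matplotlib` alone does not guarantee they are loaded.
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

training_curve_img = mpimg.imread('logs/text_logs/plots.png')

plt.figure(figsize=(16, 12))
plt.imshow(training_curve_img)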


## Part 3. Loading an existing model
___

Our current model is constrained in two ways. First, it hasn't been given enough time to train. Second, it has seen only a small number of training classes: this tutorial dataset has just 50. Because of this second issue, performance on validation data would still be quite poor even if you let the model train for a long time. Let's load a model that was both trained for long enough and trained on 800 classes. We'll see that its performance is much better.

In [None]:
encoder = ns.BertEncoder(sequence_length=128, device=encoder_device)
model   = ns.RCNet(encoder, fc_dim=128, n_blocks=4, device=model_device)

state_dict_file = '/home/shared/practical_3_logs/weights/best_weights.pt'


Next we use the single command `load_weights` to get our previously learned weights loaded into our model.

In [None]:
model.load_weights(state_dict_file)

In [None]:
model

Note the warning that is thrown when loading weights. A PyTorch model can be in one of two modes, `train` or `eval`. Evaluation mode changes how layers such as dropout and batch normalization behave. It does not by itself stop gradients from being tracked, which is controlled separately (for example with `torch.no_grad()`). Since we are about to evaluate the model on the test set, evaluation mode is just fine.



#### Evaluating the Model
The evaluation function takes many of the same arguments as the training function. We still need to specify how many classes per episode, how many examples per class, and how many episodes per iteration. The big difference is that we only do one iteration and no weights are updated in the process.

In [None]:
results = ns.run.evaluate(
    model, 
    test_data, 
    positives_per_class=5,
    number_of_episodes=5,
    distractors=False,
    negatives_per_class=0
    )

In [None]:
test_loss, test_accuracy = results

In [None]:
print("Final test accuracy: {}".format(test_accuracy))