<a href="https://colab.research.google.com/github/gmshroff/metaLearning2022/blob/main/code/nb3_CNP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MODEL-BASED META-LEARNING USING 

# Conditional Neural Processes

In [None]:
# Colab setup - uncomment these lines to install the dependencies and fetch the course code.
# !pip install import_ipynb --quiet
# !pip install learn2learn --quiet
# NOTE(review): the Colab badge above points at gmshroff/metaLearning2022 while this clones
# gmshroff/metaLearning - confirm which repository is intended.
# !git clone https://github.com/gmshroff/metaLearning.git
# %cd metaLearning/code

In [None]:
import import_ipynb
import utils,models
import l2lutils

In [None]:
from IPython import display
import torch
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from l2lutils import KShotLoader
from IPython import display
import torch.nn as nn

This is the CNP class

In [None]:
from CNP import CNP

# Data Generation/Loading

In [None]:
#Generate data - euclidean
# Yields the meta-train and meta-test datasets plus `full_loader`, which is used below for
# ordinary supervised training. n_features=20 and n_classes=10 match the MLP dims defined next.
meta_train_ds, meta_test_ds, full_loader = utils.euclideanDataset(n_samples=10000,n_features=20,n_classes=10,batch_size=32)

In [None]:
# Define an MLP network. Note that input dimension has to be data dimension. For classification
# final dimension has to be number of classes; for regression one.
# Here: 20 input features and 10 output classes; 64 is presumably the hidden width - confirm in models.MLP.
#torch.manual_seed(10)
net = models.MLP(dims=[20,64,10])

In [None]:
# Train the network; note that network is trained in place so repeated calls further train it.
# Also returns `losses` and `accs` (presumably training histories - confirm in models.Train).
net,losses,accs=models.Train(net,full_loader,lr=1e-2,epochs=20,verbose=True)

In [None]:
# Training accuracy, computed on all samples and labels of the meta-training dataset.
models.accuracy(net,meta_train_ds.samples,meta_train_ds.labels,verbose=True)

In [None]:
# Test accuracy, computed on all samples and labels of the meta-test dataset.
models.accuracy(net,meta_test_ds.samples,meta_test_ds.labels)

# Meta-Learning: Tasks

Generate a k-shot n-way loader using the meta-training dataset

In [None]:
# Task sampler over the meta-training dataset, configured with shots=5 and ways=2.
meta_train_kloader=KShotLoader(meta_train_ds,shots=5,ways=2)

Sample a task - each task has a k-shot n-way training set and a similar test set

In [None]:
# Each split is indexed below as [0] for the inputs and [1] for the labels.
d_train,d_test=meta_train_kloader.get_task()

Let's try learning directly from the task's training set despite its small size: create a dataset and a loader, then continue training the earlier network on it with the Train function.

In [None]:
# Wrap the task's training inputs and labels in a dataset object for use with a DataLoader.
taskds = utils.MyDS(d_train[0],d_train[1])

In [None]:
# One example per batch, shuffled.
d_train_loader = torch.utils.data.DataLoader(dataset=taskds,batch_size=1,shuffle=True)

In [None]:
# Continues training the MLP fitted on the full dataset above (Train works in place), now on the
# handful of task examples and with a larger learning rate (1e-1 vs 1e-2).
# `losses` and `accs` from the earlier run are overwritten.
net,losses,accs=models.Train(net,d_train_loader,lr=1e-1,epochs=10,verbose=True)

How does it do on the test set of the sampled task?

In [None]:
# Accuracy on the test split of the same task the network was just trained on.
models.accuracy(net,d_test[0],d_test[1])

How does it do on the test set of a task sampled from the meta-test set?

In [None]:
# Task sampler over the meta-test dataset. Note ways=5 here, unlike the 2-way training task above.
meta_test_kloader=KShotLoader(meta_test_ds,shots=5,ways=5)

In [None]:
# Rebinds d_train/d_test to a task drawn from the meta-test loader.
d_train,d_test=meta_test_kloader.get_task()

In [None]:
# The network is not trained on this task's d_train; this only scores its current weights
# on the new task's test split.
models.accuracy(net,d_test[0],d_test[1])

# CNP-based  Meta-learning

In [None]:
# optimisers from torch
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Negative log-likelihood loss; torch's NLLLoss expects log-probabilities as its input.
# NOTE(review): presumably the CNP forward pass returns log-probabilities - confirm in CNP.py.
lossfn = torch.nn.NLLLoss()

Get a task dataset.

In [None]:
# Replaces the earlier meta-training loader; same shots/ways, now with num_tasks=1000
# (presumably the size of the task pool - confirm in l2lutils.KShotLoader).
meta_train_kloader=KShotLoader(meta_train_ds,shots=5,ways=2,num_tasks=1000)

In [None]:
# Draw one task to walk through the CNP API in the next cells.
d_train,d_test = meta_train_kloader.get_task()

In [None]:
# Rebinds `net` from the MLP above to a CNP model.
# NOTE(review): n_ways=5 matches the loader's shots (5), not its ways (2) - confirm what CNP expects.
net = CNP(n_features=20,dims=[32,64,32],n_ways=5,n_classes=2)

In [None]:
# Inspect the two sub-networks the CNP instance exposes as mlp1 and mlp2.
print(net.mlp1,net.mlp2)

In [None]:
# One-hot encode the task labels by using them to index rows of a 2x2 identity matrix.
torch.eye(2)[d_train[1]] 

In [None]:
# adapt takes the task's training inputs and labels and returns three values;
# r is what gets passed to the forward call as context below. Show their shapes.
r,m,R = net.adapt(d_train[0],d_train[1])
r.shape,m.shape,R.shape

In [None]:
# Display the second value returned by adapt.
m

In [None]:
# Forward pass on the task's test inputs, conditioned on the context r from adapt.
net(d_test[0],r)

# Putting it all together: CNP-based Meta-learning
Now let's put it together in a loop - CNP model-based meta-learning algorithm:

In [None]:
# Redefine the accuracy function so that it takes h - the dataset context - as input, since the net requires it.
def accuracy(Net,X_test,y_test,h,verbose=True):
    """Return the fraction of examples in X_test that Net classifies correctly.

    This function does not switch Net between train and eval mode; callers
    that need eval-mode behaviour must call Net.eval() / Net.train() themselves.

    Args:
        Net: callable invoked as Net(X_test, h). Must return a 2-D tensor of
            per-class scores with one row per example.
        X_test: tensor of inputs; its first dimension is the number of examples.
        y_test: tensor of integer class labels, one per example.
        h: context passed through to Net unchanged.
        verbose: when True, print the number of correct predictions followed
            by the number of examples.

    Returns:
        float: correct / total, or 0.0 when X_test contains no examples.
    """
    m = X_test.shape[0]
    if m == 0:
        # Nothing to score; avoid the ZeroDivisionError the division below would raise.
        if verbose: print(0.0,m)
        return 0.0
    # Only a plain float is returned, so there is no reason to record an autograd graph.
    with torch.no_grad():
        y_pred = Net(X_test,h)
        _, predicted = torch.max(y_pred, 1)
        correct = (predicted == y_test).float().sum().item()
    if verbose: print(correct,m)
    return correct/m

In [None]:
# Class-id split: 0-4 for meta-training and 5-9 for meta-testing. These lists are only
# referenced by the commented-out `classes=` arguments of the loaders below.
classes_train = list(range(5))
classes_test = list(range(5, 10))
classes_train, classes_test

In [None]:
import learn2learn as l2l
import torch.optim as optim
# Task geometry shared by the meta-training and meta-testing loaders.
shots,ways = 5,2
# Fresh CNP for the meta-learning run; n_classes follows `ways`. The training loop below
# steps `net.optimizer`, which presumably is built from the lr passed here - confirm in CNP.py.
# NOTE(review): n_ways=5 equals `shots`, not `ways` - confirm what CNP expects for n_ways.
net = CNP(n_features=20,n_classes=ways,dims=[32,64,32],lr=1e-4,n_ways=5)
# torch's NLLLoss expects log-probabilities; presumably CNP's forward returns them - TODO confirm.
lossfn = torch.nn.NLLLoss()
# Meta-training task sampler; the class restriction (classes_train) is left commented out.
meta_train_kloader=KShotLoader(meta_train_ds,shots=shots,ways=ways,num_tasks=1000)#,classes=classes_train)

In [None]:
#Meta-testing task loader for later.
# Uses the same shots/ways as meta-training; the class restriction (classes_test) is left commented out.
meta_test_kloader=KShotLoader(meta_test_ds,shots=shots,ways=ways)#,classes=classes_test)

In [None]:
# Meta-training: every epoch sums the test-split loss over `task_count` sampled tasks
# and then takes a single optimizer step on that summed loss.
n_epochs=100
task_count=50
for epoch in range(n_epochs):
    test_loss = 0.0
    test_acc = 0.0
    for task in range(task_count):
        # Sample a task and shuffle its training split
        d_train,d_test=meta_train_kloader.get_task()
        rp = torch.randperm(d_train[1].shape[0])
        d_train0=d_train[0][rp]
        d_train1=d_train[1][rp]
        # Encode the training split into the context h that conditions the forward pass
        h,_,_ = net.adapt(d_train0,d_train1)
        # Shuffle the test split and predict on it given h
        rp1 = torch.randperm(d_test[1].shape[0])
        d_test0=d_test[0][rp1]
        d_test1=d_test[1][rp1]
        test_preds = net(d_test0,h)
        # Accumulate the loss over tasks - only the test-split loss is included
        test_loss += lossfn(test_preds,d_test1)
        # Score accuracy in eval mode; no gradients are needed for this extra forward pass
        net.eval()
        with torch.no_grad():
            test_acc += accuracy(net,d_test0,d_test1,h,verbose=False)
        net.train()
    # Report the epoch averages; clear_output(wait=True) defers clearing until the next
    # output arrives, so each epoch's line replaces the previous one.
    print('Epoch  % 2d Loss: %2.5e Avg Acc: %2.5f'%(epoch,test_loss/task_count,test_acc/task_count))
    display.clear_output(wait=True)
    # Update the network weights using the loss summed over all tasks of this epoch
    net.optimizer.zero_grad()
    test_loss.backward()
    net.optimizer.step()
    

Now test the trained CNP network on tasks sampled from the meta_test_ds dataset:

In [None]:
# Meta-testing: average accuracy of the trained CNP over tasks drawn from the meta-test dataset.
meta_test_kloader=KShotLoader(meta_test_ds,shots=shots,ways=ways)
test_acc = 0.0
task_count = 50
# Evaluate in eval mode. The original cell restored train mode at the end but never left it,
# so the whole evaluation ran in training mode.
net.eval()
for task in range(task_count):
    # Sample a task, build its context from the training split, then score the test split
    d_train,d_test=meta_test_kloader.get_task()
    h,_,_=net.adapt(d_train[0],d_train[1])
    # accuracy() runs the forward pass itself; no gradients are needed here
    with torch.no_grad():
        test_acc += accuracy(net,d_test[0],d_test[1],h,verbose=False)
    # Done with a task
# Put the network back in training mode in case training is resumed afterwards
net.train()
print('Avg Acc: %2.5f'%(test_acc/task_count))