# Learn2Learn Prototypical Network Implementation for BRACOL Dataset


***Install learn2learn and EfficientNet***

In [None]:
!pip install learn2learn
!pip install efficientnet_pytorch

from IPython.display import clear_output 
clear_output()
print('Done!')

***Some imports***

In [5]:
import sys
# Put the project's protonet code directory first on the import path so the
# local `models` and `engine` modules below can be imported.
# Change this to your own protonet files directory.
sys.path.insert(0, '/content/drive/MyDrive/pg/protonet/')

# Star import: the backbone classes used later (e.g. MobileNetv2) presumably
# come from here -- confirm against the `models` module.
from models import *
from engine import create_task_pool, run_train_dataloader, run_test_dataloader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

ModuleNotFoundError: No module named 'learn2learn'

### Training on Dataset

***Define the hyperparameters***

In [None]:
# ---- Few-shot task configuration -------------------------------------------
ways = 5  # number of classes in each few-shot task
shot = 1  # number of support examples per class in each task

# ---- Model, optimizer and learning-rate schedule ---------------------------
model = MobileNetv2()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.5 every 20 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

# ---- Dataset location (point this at your own copy) ------------------------
path_data = '/content/drive/MyDrive/pg/dataset/'

# ---- Image preprocessing ---------------------------------------------------
# Every image is resized to a fixed square and normalised with the ImageNet
# per-channel mean / std.
_image_size = (224, 224)
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, then random flips / rotation / colour jitter as
# augmentation, then tensor conversion and normalisation.
train_transforms = transforms.Compose([
    transforms.Resize(_image_size),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomApply([transforms.RandomRotation(10)], 0.25),
    transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])

# Validation pipeline: the same resize + normalisation, without any random
# augmentation.
val_transforms = transforms.Compose([
    transforms.Resize(_image_size),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])

***Define the training and validation datasets and the task pool***

In [None]:
# Datasets use the torchvision ImageFolder layout: one sub-folder per class
# under <path_data>/train/ and <path_data>/val/.
train_dataset = torchvision.datasets.ImageFolder(root=path_data + 'train/', transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

val_dataset = torchvision.datasets.ImageFolder(root=path_data + 'val/', transform=val_transforms)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True)

# Pool of few-shot tasks built from the training images. Pass the `ways` /
# `shot` hyperparameters instead of repeating the literals 5 and 1. This way,
# changing them in the hyperparameter cell also changes the task pool, and the
# pool stays consistent with the values handed to run_train_dataloader.
# NOTE(review): num_tasks=-1 presumably means "no limit" in
# engine.create_task_pool -- confirm.
# Building the task pool might take a while.
task_pool = create_task_pool(dataset=train_dataset, num_tasks=-1, ways=ways, shot=shot)

***Run training***

In [None]:
# Launch training for 100 epochs with the loaders, task pool, model, optimizer
# and scheduler defined above. save_path names the checkpoint file; the
# inference section below loads 'model_final.pth' again.
_train_kwargs = {
    'n_epochs': 100,
    'train_loader': train_loader,
    'val_loader': val_loader,
    'task_pool': task_pool,
    'model': model,
    'optimizer': optimizer,
    'lr_scheduler': lr_scheduler,
    'ways': ways,
    'shot': shot,
    'save_path': 'model_final.pth',
}
run_train_dataloader(**_train_kwargs)

### Inference on Dataset

***Define the model and load the trained weights***

In [None]:
# tasks params
ways = 5
shot = 1
path_data = '/content/drive/MyDrive/pg/dataset/'


# define model
model = MobileNetv2()
model.load_state_dict(torch.load('model_final.pth'))
model.eval()

***Define test dataset***

In [None]:
# Evaluation preprocessing: deterministic resize + ImageNet normalisation,
# with no random augmentation.
trans = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _load_split(split_dir):
    """Return the ImageFolder dataset rooted at path_data + split_dir, using `trans`."""
    return torchvision.datasets.ImageFolder(root=path_data + split_dir, transform=trans)


# Test images, one per batch.
# NOTE(review): shuffling is not needed for a full test pass -- confirm whether
# shuffle=True is intentional.
test_dataset = _load_split('test/')
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# The task pool is drawn from the training split, preprocessed the same way as
# the test images. Building it might take a while.
train_dataset = _load_split('train/')
task_pool = create_task_pool(dataset=train_dataset, num_tasks=-1, ways=ways, shot=shot)

***Run inference on test dataset***

In [None]:
# Run the restored model over test_loader, using tasks from the pool built
# above. The confusion-matrix cell below reads the returned dict's 'real' and
# 'predicted' entries.
_test_kwargs = dict(
    model=model,
    test_loader=test_loader,
    task_pool=task_pool,
    ways=ways,
    shot=shot,
)
results_dict = run_test_dataloader(**_test_kwargs)

***Plot confusion matrix***

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn

# Ground-truth and predicted labels returned by run_test_dataloader.
y_true = results_dict['real']
y_pred = results_dict['predicted']

# Display names for the label indices, listed in label order.
# NOTE(review): ImageFolder assigns label indices by sorting the class folder
# names alphabetically. This list must follow that same order (compare with
# test_dataset.classes), otherwise the axes are mislabelled.
class_names = ['Healthy', 'Miner', 'Rust', 'Phoma', 'Cercospora']

# Row-normalised confusion matrix: each row sums to 1, so the diagonal holds
# the per-class recall.
cm = confusion_matrix(y_true, y_pred, normalize='true')
if cm.shape[0] != len(class_names):
    # confusion_matrix only includes labels present in y_true or y_pred, so a
    # class that never appears shrinks the matrix; fail with a clear message
    # instead of a pandas shape error.
    raise ValueError(
        f'Confusion matrix covers {cm.shape[0]} classes but '
        f'{len(class_names)} class names were given; a class may be absent '
        'from both the true and predicted labels.'
    )
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)

plt.figure(figsize=(12, 9))
sn.heatmap(df_cm, annot=True, cmap="Blues", vmin=0, vmax=1)
plt.ylabel('True Label', fontweight='bold')
plt.xlabel('Predicted Label', fontweight='bold')
plt.title('Confusion Matrix', fontsize=18, fontweight='bold')
plt.show()

# Macro averaging weights every class equally, regardless of its size.
print('Precision (macro):', precision_score(y_true, y_pred, average='macro'))
print('Recall (macro):', recall_score(y_true, y_pred, average='macro'))
print('F1 (macro):', f1_score(y_true, y_pred, average='macro'))

### Run training script for multiple backbones

In [None]:
# Run the project's training script for several backbones in one go.
import sys

_protonet_dir = '/content/drive/MyDrive/pg/protonet'
sys.path.insert(0, _protonet_dir)  # make run_protonet_models importable

from run_protonet_models import run_protonet_models

run_protonet_models(
    ways=5,
    shot=1,
    path_data='/content/drive/MyDrive/pg/dataset/',
)

***Plot t-SNE***

In [None]:
# Produce the t-SNE plot with the project's generate_tSNE helper.
import sys

_protonet_dir = '/content/drive/MyDrive/pg/protonet'
sys.path.insert(0, _protonet_dir)  # make run_protonet_models importable

from run_protonet_models import generate_tSNE

# NOTE(review): shot=5 here, whereas the training and inference cells use
# shot=1 -- confirm this difference is intentional.
generate_tSNE(
    ways=5,
    shot=5,
    path_data='/content/drive/MyDrive/pg/dataset/',
)
