In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch

#from architecture import *
from data import get_dataloader
from utils import create_sampler

from functions import *
from PT_plus_MAP import *

In [None]:
# Download the dataset first:
# run the notebooks in the `../dataset/` directory to download and prepare the data.

In [None]:
### Architecture ###
# Only the variables needed by the chosen architecture are used to create the model
# (they are passed to get_model_parameters in the next cell).
model_name = 'GNN'  # 'MLP', 'GNN' or 'Conv1d'
num_classes = 64  # Number of classes in the base dataset. To reproduce the article's results, use 64 for split1 and 65 for split2.
num_layers = 1  # Number of hidden layers.
num_feat = 128  # For 'MLP'/'GNN', number of features per hidden layer. For 'Conv1d', number of feature maps per hidden layer.
n_step = 1  # Only for 'GNN'. Number of times the input signal is diffused. See the 'GNN' architecture for more details.

### Optimization ###
lr = 0.1
# NOTE(review): weight_decay and lr_gamma are not passed to get_optimizer(model, lr, n_epoch)
# in this notebook, so they currently have no effect here — confirm against get_optimizer.
weight_decay = 1e-4
lr_gamma = 0.1
n_epoch = 90

### Reproducibility ###
# Seed the Python, NumPy and PyTorch RNGs so that repeated runs are comparable
# (some CUDA kernels may still be non-deterministic).
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Architecture
model_parameters = get_model_parameters(model_name, num_layers, num_feat, num_classes, n_step)
model = get_model(model_name, model_parameters)

# Data parameters
# The data locations can be overridden with the IBC_PATH / SPLIT_DIR environment
# variables instead of editing the notebook; the defaults are the original paths.
IBC_path = os.environ.get('IBC_PATH', '/bigdisk2/nilearn_data/neurovault/collection_6618/')
split_dir = os.environ.get('SPLIT_DIR', '../dataset/split/')
parcel, batch_size = get_variables(model.name)

# Episodes parameters (few-shot tasks sampled from the validation split)
n_episode = 500
n_way = 5
n_shot = 5
n_query = 15

# Save path: created up front so checkpoints can be written into it during training.
save_path = f'./results/{model.name}/'
os.makedirs(save_path, exist_ok=True)

# Sampler
train_sampler = create_sampler(IBC_path, split_dir, 'train')
val_sampler_infos = [n_episode, n_way, n_shot, n_query]

# Loader
train_loader = get_dataloader('train', IBC_path, parcel, split_dir, meta=False, batch_size=batch_size, sampler=train_sampler)
val_loader = get_dataloader('val', IBC_path, parcel, split_dir, meta=True, sampler_infos=val_sampler_infos)

# Train the model

In [None]:
# Optimization: cross-entropy loss; get_optimizer returns the optimizer and its LR scheduler.
# torch.nn is referenced explicitly rather than relying on `nn` leaking from a star import.
criterion = torch.nn.CrossEntropyLoss()
optimizer, scheduler = get_optimizer(model, lr, n_epoch)

# Move to CUDA only when a GPU is present, so the notebook also runs on CPU-only machines.
# Module.cuda() moves the parameters in place, so the optimizer created above stays valid.
use_cuda = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()

# Number of epochs between each evaluation on the validation dataset.
iter_val_acc = 5

In [None]:
# Train the model; evaluate it on the validation few-shot tasks every `iter_val_acc` epochs.
epochs_acc = []        # validation task accuracy (%), one entry per epoch
epochs_loss = []       # training loss, one entry per epoch
epochs_train_acc = []  # training accuracy (%), one entry per epoch

# Architecture description stored in every checkpoint; built from model_name so it
# matches the architecture actually trained (it must not say "MLP" for a GNN run).
arch_description = "{} - num_layers {} - num_feat {} - num_classes {}".format(
    model_name, num_layers, num_feat, num_classes)

best_acc = -1
# range(n_epoch + 1) runs n_epoch + 1 epochs (0..n_epoch); when n_epoch is a multiple of
# iter_val_acc, the final epoch is also an evaluation epoch.
# NOTE(review): confirm the extra epoch is intended w.r.t. the scheduler built with n_epoch.
for epoch in range(n_epoch + 1):
    # Train for one epoch.
    epoch_loss, epoch_train_acc = train(model, criterion, optimizer, train_loader, use_cuda)
    print("\rLoss at epoch {}: {:.2f}.".format(epoch+1, epoch_loss), end='')
    print("(Acc \t: {:.2f}).".format(epoch_train_acc*100), end='')

    # Evaluate on the few-shot tasks from the validation set. Epoch 0 always satisfies
    # the condition, so `epoch_acc` is defined before it is first recorded below.
    is_best = False
    if epoch % iter_val_acc == 0:
        optimizer.zero_grad()
        epoch_acc = episodic_evaluation(model, val_loader, val_sampler_infos, use_cuda)
        print("Acc on tasks at epoch {}: {:.2f}.".format(epoch+1, epoch_acc*100))
        # Only a freshly measured accuracy can produce a new best model.
        is_best = epoch_acc > best_acc
        best_acc = max(epoch_acc, best_acc)
    scheduler.step()

    # Record statistics and save the model. On non-evaluation epochs the last measured
    # validation accuracy is repeated, so the three lists stay aligned with the epochs.
    epochs_loss.append(epoch_loss)
    epochs_train_acc.append(epoch_train_acc*100)
    epochs_acc.append(epoch_acc*100)
    save_checkpoint({
                'epoch': epoch + 1,
                'arch': arch_description,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'epochs_acc': epochs_acc,
                'epochs_train_acc': epochs_train_acc,
                'epochs_loss': epochs_loss
            }, is_best, folder=save_path)


In [None]:
# Check that training went as expected by reloading the saved training statistics.
def load_statistics(save_path, type='last'):
    """Reload the training curves stored in a checkpoint folder.

    Parameters
    ----------
    save_path : str
        Folder containing `checkpoint.pth.tar` and `model_best.pth.tar`
        (presumably written by `save_checkpoint(..., folder=save_path)`).
    type : {'last', 'best'}
        'last' reads `checkpoint.pth.tar`, 'best' reads `model_best.pth.tar`.

    Returns
    -------
    tuple
        (epochs_acc, epochs_loss, epochs_train_acc, best_acc) as stored in the checkpoint.

    Raises
    ------
    ValueError
        If `type` is neither 'best' nor 'last' (an `assert` would be skipped under `python -O`).
    """
    filenames = {'best': 'model_best.pth.tar', 'last': 'checkpoint.pth.tar'}
    if type not in filenames:
        raise ValueError('type should be in [best, or last], but got {}'.format(type))
    # Only the statistics are needed, so load onto the CPU: this also works for
    # GPU-saved checkpoints on machines without a GPU.
    checkpoint = torch.load(os.path.join(save_path, filenames[type]), map_location='cpu')
    return (checkpoint['epochs_acc'], checkpoint['epochs_loss'],
            checkpoint['epochs_train_acc'], checkpoint['best_acc'])

# Reload the statistics of the last checkpoint (type='best' would read model_best.pth.tar instead).
(epochs_acc, epochs_loss,
 epochs_train_acc, best_acc) = load_statistics(save_path, type='last')

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.plot(epochs_train_acc, 'g', label='train')
plt.plot(epochs_acc, 'r', label='val')
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.show()
plt.plot(epochs_loss, 'g')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

print("The best accuracy on tasks on the validation set is: {:.2f}.".format(best_acc*100))

# Test the model

In [None]:
# Loaders for the final evaluation: no sampler, one sample per batch.
# NOTE: this rebinds `train_loader`, replacing the training loader built earlier;
# re-run the data cell before training again.
eval_loader_kwargs = dict(meta=False, batch_size=1, sampler=None)
train_loader = get_dataloader('train', IBC_path, parcel, split_dir, **eval_loader_kwargs)
test_loader = get_dataloader('test', IBC_path, parcel, split_dir, **eval_loader_kwargs)

In [None]:
# Simplified few-shot evaluation: 1-shot and 5-shot accuracies, each followed by a
# `conf` value (presumably a confidence interval — confirm in functions.py).
simplified_results = do_extract_and_evaluate_simplified(model, train_loader, test_loader, save_path)
acc_1_shot, conf1, acc_5_shot, conf5 = simplified_results
acc_1_shot, conf1, acc_5_shot, conf5

In [None]:
# PT+MAP evaluation on a freshly built test loader (same settings as the loader cell).
# NOTE: this overwrites acc_1_shot / conf1 / acc_5_shot / conf5 from the simplified evaluation.
test_loader = get_dataloader(
    'test', IBC_path, parcel, split_dir,
    meta=False, batch_size=1, sampler=None,
)
(acc_1_shot, conf1,
 acc_5_shot, conf5) = do_extract_and_evaluate_simplified_PT_plus_MAP(model, test_loader, save_path)

In [None]:
# Display the PT+MAP results: (1-shot accuracy, conf, 5-shot accuracy, conf).
(acc_1_shot, conf1, acc_5_shot, conf5)