In [1]:
import torch
from data import get_dataloader
import numpy as np

from functions import *
from PT_plus_MAP import *
from architecture import LR

In [2]:
# --- Model selection ---
model_name = 'LR'

# --- Optimisation hyper-parameters ---
n_epoch = 20
lr = 0.1
lr_gamma = 0.1
weight_decay = 0.0001

# --- Few-shot episode layout ---
# n_way / n_shot / n_query are forwarded to the episode sampler; n_way also sizes the LR model.
n_way, n_shot, n_query = 5, 1, 15
n_episode = 10_000  # number of few-shot tasks generated for the evaluation

In [3]:
# Where the data lives and how it is loaded.
# NOTE(review): IBC_path is an absolute, machine-specific location - adjust it for your environment.
IBC_path = '/bigdisk2/nilearn_data/neurovault/collection_6618/'
split_dir = '../dataset/split/'
parcel = True
batch_size = 128

# Model instance handed to the feature-extraction step; it has one output per class of an episode (n_way).
model = LR(n_way)

# Results directory, derived from the model's name.
save_path = './results/{}/'.format(model.name)

# Train an LR on the training examples of each few-shot task and evaluate it on that task's query samples.

In [4]:
# Loader
train_loader = get_dataloader('train', IBC_path, parcel, split_dir, meta=False, batch_size=1, sampler=None)
test_loader = get_dataloader('test', IBC_path, parcel, split_dir, meta=False, batch_size=1, sampler=None)

# `data` is what the few-shot episodes are sampled from; `base_mean` is the
# vector subtracted from every example by the 'CL2N' normalization below.
base_mean, data = extract_feature(train_loader, test_loader, model, 'best')

# Per-example normalization applied to every episode:
#   'CL2N' - subtract base_mean, then divide each example by its L2 norm (same as SimpleShot);
#   'L2N'  - only divide each example by its L2 norm;
#   any other value (e.g. None) - leave the features untouched.
# TODO: compare the results without normalization, with normalization only,
# and with centering + normalization.
norm_type = None


def _normalize_features(features, norm_type, center):
    """Return ``features`` (2-D numpy array, one example per row) normalized per ``norm_type``.

    'CL2N' centers with ``center`` and then L2-normalizes each row, 'L2N' only
    L2-normalizes each row, and any other value returns ``features`` unchanged.
    """
    if norm_type == 'CL2N':
        features = features - center
    elif norm_type != 'L2N':
        return features
    # np.linalg is used explicitly so the cell does not rely on an `LA` alias
    # that is only in scope through a star import.
    return features / np.linalg.norm(features, ord=2, axis=1)[:, None]


def _relabel_consecutive(train_labels, test_labels):
    """Map the class ids found in ``train_labels`` onto 0..k-1 (ascending order).

    The loss expects class indices starting at 0. New tensors are returned and
    every comparison is made against the *original* labels, so a freshly
    written id can never be mistaken for a not-yet-processed original id.
    Test labels that do not occur in ``train_labels`` are left unchanged.
    """
    new_train = train_labels.clone()
    new_test = test_labels.clone()
    for new_label, old_label in enumerate(torch.unique(train_labels, sorted=True)):
        new_train[train_labels == old_label] = new_label
        new_test[test_labels == old_label] = new_label
    return new_train, new_test


def _run_episode(train_data, train_labels, test_data, test_labels, n_way, lr, n_epoch):
    """Fit a fresh LR on one episode's training set and score it on its query set.

    Returns ``(query_accuracy_percent, losses, train_accs)`` where the two lists
    hold the loss and the training accuracy (in %) recorded at every step.
    """
    # A new classifier per episode, kept local so that the notebook-level
    # `model` (the one given to extract_feature) is never overwritten.
    classifier = LR(n_way)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer, scheduler = get_optimizer(classifier, lr, n_epoch)

    classifier.train()
    losses = []
    train_accs = []
    # NOTE(review): this performs n_epoch + 1 full-batch steps (kept from the
    # original code) although get_optimizer is given n_epoch - confirm intended.
    for _ in range(n_epoch + 1):
        optimizer.zero_grad()
        outputs = classifier(train_data)
        loss = criterion(outputs, train_labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        # compute_accuracy's result is divided by the number of examples, so it
        # is presumably a count of correct predictions - TODO confirm.
        n_correct = compute_accuracy(outputs.detach(), train_labels)
        losses.append(loss.item())
        train_accs.append(n_correct * 100 / train_labels.shape[0])

    classifier.eval()
    with torch.no_grad():
        n_correct = compute_accuracy(classifier(test_data), test_labels)
    return n_correct * 100 / test_labels.shape[0], losses, train_accs


# Train an LR on n_episode few-shot problems and evaluate each one.
# NOTE(review): no random seed is set in this notebook, so the sampled episodes
# and the reported numbers vary slightly from run to run.
acc_list = []
for episode in range(n_episode):
    # Draw one few-shot task (examples of the same class follow each other).
    train_data, test_data, train_labels, test_labels = sample_case(data, n_shot, n_way, n_query)

    # (Optional) normalization, done on the numpy arrays before conversion to tensors.
    train_data = torch.from_numpy(_normalize_features(train_data, norm_type, base_mean))
    test_data = torch.from_numpy(_normalize_features(test_data, norm_type, base_mean))

    # Rename the labels for the criterion.
    train_labels, test_labels = _relabel_consecutive(
        torch.from_numpy(train_labels), torch.from_numpy(test_labels)
    )

    episode_acc, losses, train_accs = _run_episode(
        train_data, train_labels, test_data, test_labels, n_way, lr, n_epoch
    )
    print('Acc on episode {} : {:.2f}.'.format(episode, episode_acc), end='\r')
    acc_list.append(episode_acc)

acc_mean, acc_conf = compute_confidence_interval(acc_list)
print('The baseline has an average accuracy of {:.2f}% over {} tasks with 95% confidence interval {:.2f}.'.format(np.round(acc_mean, 2), n_episode, np.round(acc_conf, 2)))

The baseline has an average accuracy of 52.94% over 10000 tasks with 95% confidence interval 0.19.
