In [3]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import sliding_average

from get_processed_data import get_processed_data

### Splitting data

In [4]:
## Unpack everything get_processed_data() returns: df plus the
## train / validation / test features (X_*) and labels (y_*)
(
    df,
    X_train, y_train,
    X_val, y_val,
    X_test, y_test,
) = get_processed_data()

Training set shape: (12335, 55) (12335,)
Validation set shape: (1542, 55) (1542,)
Test set shape: (1542, 55) (1542,)


### Prototypical network

In [None]:
class PrototypicalNetwork(nn.Module):
    """Few-shot classifier that scores queries by distance to class prototypes.

    A prototype is the mean embedding of all support examples that share a
    label; a query is scored against each prototype by negative Euclidean
    distance in the embedding space.

    Args:
        backbone: module that maps a batch of inputs to a batch of embeddings.
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

    def forward(self, support_data: torch.Tensor, support_labels: torch.Tensor,
                query_data: torch.Tensor) -> torch.Tensor:
        """Predict query labels using labelled support data.

        Args:
            support_data: batch of labelled support examples.
            support_labels: one label per support example. Labels must be the
                integers 0 .. n_way - 1, because prototypes are built by
                iterating over range(n_way); a label value in that range with
                no support example produces a NaN prototype.
            query_data: batch of examples to classify.

        Returns:
            Scores with one row per query and one column per class
            (column i belongs to label i); higher means closer.
        """
        ## Extract embeddings of support and query data.
        ## The backbone is called directly (not via .forward) so that any hooks
        ## registered on it still run.
        z_support = self.backbone(support_data)
        z_query = self.backbone(query_data)

        ## Infer no. of unique classes from support set labels
        n_way = len(torch.unique(support_labels))

        ## Construct prototypes
            ## Prototype i = Mean of embeddings of all support data with label i
        z_proto = torch.stack(
            [
                z_support[support_labels == label].mean(dim = 0)
                for label in range(n_way)
            ]
        )

        ## Euclidean distance from every query embedding to every prototype
        dists = torch.cdist(z_query, z_proto)

        ## Smaller distance -> Higher score
        return -dists

### Model training (meta-learning / episodic training)

Episodic training simulates the few-shot learning scenario to train a prototypical network. Training data is organized into episodes that resemble few-shot tasks.

Set up

In [None]:
N_WAY = 2 ## No. of classes per episode
N_SHOT = 5 ## No. of labelled support examples per class
N_QUERY = 10 ## No. of query examples per class

## No. of episodes (tasks) a sampler yields. With 0 the loader is empty and the
## training loop never runs. These are starting values only -- tune as needed.
N_TRAINING_EPISODES = 1000
N_VALIDATION_TASKS = 100


class EpisodicDataset (torch.utils.data.Dataset):
    """Feature/label pairs exposed in the form TaskSampler needs.

    TaskSampler groups example indices by class using the dataset's
    get_labels() method, so the dataset must provide it.

    NOTE(review): assumes `features` is a numeric 2-D array / DataFrame and
    `labels` holds integer-like class ids -- confirm against get_processed_data().
    """

    def __init__(self, features, labels):
        self.features = torch.as_tensor(self._as_array(features), dtype = torch.float32)
        ## Plain Python ints, so labels are hashable and comparable by value
        self.labels = [int(label) for label in self._as_array(labels)]

    @staticmethod
    def _as_array(values):
        """Return the underlying array for pandas objects, anything else unchanged."""
        return values.to_numpy() if hasattr(values, "to_numpy") else values

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index):
        return self.features[index], self.labels[index]

    def get_labels(self):
        """Return the label of every example, in dataset order."""
        return self.labels


train_set = EpisodicDataset(X_train, y_train)

## Each batch drawn by the sampler is one episode: N_WAY classes with
## N_SHOT + N_QUERY examples of each
train_sampler = TaskSampler(dataset = train_set, n_way = N_WAY, n_shot = N_SHOT,
                            n_query = N_QUERY, n_tasks = N_TRAINING_EPISODES)
## Loader generates an iterable given a dataset and a sampler
## Pinned memory only helps host-to-GPU copies, so request it only when CUDA exists
train_loader = DataLoader(dataset = train_set, batch_sampler = train_sampler,
                          pin_memory = torch.cuda.is_available(),
                          collate_fn = train_sampler.episodic_collate_fn)

Initializing optimizer, loss function and meta-training loop

In [None]:
## Loss fn
criterion = nn.CrossEntropyLoss()


## Model
    ## Small MLP backbone that embeds one row of tabular features.
    ## Layer sizes are starting values, not tuned.
HIDDEN_DIM = 64
EMBEDDING_DIM = 32
n_features = X_train.shape[1] ## Input width taken from the training features

backbone = nn.Sequential(
    nn.Linear(n_features, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, EMBEDDING_DIM),
)
model = PrototypicalNetwork(backbone)


## Optimizer
LEARNING_RATE = 0.001
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)


## Training step (one episode)
def fit (optimizer, criterion,
         support_data: torch.Tensor, support_labels: torch.Tensor,
         query_data: torch.Tensor, query_labels: torch.Tensor) -> float:
    """Run one optimisation step on a single classification task.

    Uses the module-level `model` to score the query set against the support
    set, computes the loss on the query labels, back-propagates and updates
    the parameters held by `optimizer`.

    Args:
        optimizer: optimizer whose gradients are reset and then stepped.
        criterion: loss function called as criterion(scores, query_labels).
        support_data: labelled examples of the task.
        support_labels: labels of the support examples.
        query_data: examples to classify.
        query_labels: ground-truth labels of the query examples.

    Returns:
        The loss of this step as a Python float.
    """
    optimizer.zero_grad()

    ## Call the model itself rather than .forward so registered hooks still run
    classification_scores = model(support_data, support_labels, query_data)

    loss = criterion(classification_scores, query_labels)
    loss.backward()
    optimizer.step()

    return loss.item()

Train the model

In [None]:
log_update_frequency = 10 ## Refresh the displayed loss every this many episodes

all_loss = []
model.train()
with tqdm(enumerate(train_loader), total = len(train_loader)) as tqdm_train:
    ## Each batch is one episode; its fifth element is not needed here
    for episode_index, (support_data, support_labels, query_data, query_labels, _) in tqdm_train:
        ## fit() expects the optimizer and loss function ahead of the episode tensors
        episode_loss = fit(optimizer, criterion,
                           support_data, support_labels, query_data, query_labels)
        all_loss.append(episode_loss)

        if episode_index % log_update_frequency == 0:
            ## Show a smoothed loss on the progress bar instead of the noisy per-episode value
            tqdm_train.set_postfix(loss = sliding_average(all_loss, log_update_frequency))

### Model evaluation

In [None]:
## Evaluate the trained model on held-out episodes
## NOTE(review): neither `evaluate` nor `test_loader` is defined in the visible
## cells, so this raises NameError on a fresh kernel until both exist.
evaluate(model, test_loader) ##TODO: Implement method