In [None]:
import numpy as np
import pandas as pd
import random
import warnings

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from easyfsl.samplers import TaskSampler
from easyfsl.utils import evaluate
from easyfsl.methods import FewShotClassifier, PrototypicalNetworks

from ray import tune

from statistics import mean

from get_processed_data import get_processed_data
from FSLMethods import form_datasets, training_epoch, evaluate_model
from FSLDataset import FSLDataset
from FSLNetworks import DummyNetwork
from FSLTrainer import fsl_trainer, fsl_tuner

warnings.filterwarnings('ignore')

### Model training (meta-learning / episodic training)

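Episodic training trains a prototypical network under the same conditions it will face at test time. The training data is organized into episodes, and each episode is a small few-shot task: a support set of labelled examples per class, plus a query set to classify.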

Setup: seed the NumPy, PyTorch and Python random number generators for reproducibility.

In [None]:
random_seed = 0  # one seed shared by every RNG this notebook touches

# Seed NumPy's global RNG, PyTorch and Python's `random` module, in that order.
# NOTE(review): Ray Tune trials may execute in separate worker processes that do
# not inherit these seeds; confirm that `fsl_trainer` seeds itself if needed.
for _seed_rng in (np.random.seed, torch.manual_seed, random.seed):
    _seed_rng(random_seed)
del _seed_rng

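Train and tune the model with Ray Tune. The search covers the number of shots per class (`n_shot`) and the embedding size, with recall as the tuning metric.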

In [None]:
# Hyperparameter search space for Ray Tune.
# `tune.choice` samples uniformly from an explicit list. That is the same
# distribution the previous `sample_from` lambdas produced
# (2 * randint(1, 6) and 2 ** randint(2, 6)). Unlike an opaque lambda, it also
# exposes the space to Tune search algorithms and yields plain Python values.
config = {
    'n_shot': tune.choice([2, 4, 6, 8, 10]),        # even numbers in [2, 10]
    'embedding_size': tune.choice([4, 8, 16, 32]),  # powers of two from 4 to 32
}

# Launch tuning. `fsl_tuner` (from FSLTrainer) receives the trainer function,
# the search space and the metric name; presumably that metric is used to rank
# trials — confirm in FSLTrainer.
model_tuner = fsl_tuner(fsl_trainer, config, metric='recall')

### Model evaluation

In [None]:
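# TODO: implement model evaluation. `evaluate` is already imported, but `model`
# and `test_loader` are not defined in this notebook yet. Once they are, run:
# evaluate(model, test_loader)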