In [None]:
import warnings

from get_processed_data import get_processed_data
from FSLMethods import form_datasets
from FSLTrainer import FSLTrainer

## Install a blanket 'ignore' filter so that no warning category is reported
warnings.simplefilter('ignore')

### Preparing data

Train-Test-Validation split and sampling

In [None]:
## Fetch the processed data along with its train / validation / test splits
df, X_train, y_train, X_val, y_val, X_test, y_test = get_processed_data()

## Original author's note: datasets need to be a FewShotDataset / torch Dataset with .get_labels
train_set, validation_set, test_set = form_datasets(
    X_train, y_train,
    X_val, y_val,
    X_test, y_test,
    feature_selection=True,
    sampling_method='undersampling',
)


### Model training (meta-learning / episodic training)

Episodic training simulates the few-shot learning scenario to train a prototypical network. Training data is organized into episodes that resemble few-shot tasks.

Train and tune model

In [None]:
## Hyperparameter search space: each key lists the candidate values to try
config = {
    ## 2, 4, 8, 16, 32, 64 (original note: try even numbers in [2, 100])
    'n_shot': [2 ** exponent for exponent in range(1, 7)],
    ## powers of two from 4 to 32
    'embedding_size': [4, 8, 16, 32],
}

## Build the trainer from the three dataset splits and the `config` search space
trainer = FSLTrainer(train_set, validation_set, test_set, config)

In [None]:
## `results` maps (k, embedding_size) -> (metric, model_params);
## `best_config` is later used as a key into `results`. Recall is the tuning metric.
results, best_config = trainer.tune(metric='recall')

### Model evaluation

Find the smallest number of shots (k) whose tuned model reaches the performance threshold

In [None]:
## TODO: Get model that meets supervised learning performance
## Goal of this cell: find the smallest k (n_shot) for which a tuned model
## reaches `threshold` on the tuning metric.
threshold = 0.7
min_k, relevant_embedding_size, relevant_model_state = 0, 0, None

## Keep only the configurations whose score reaches the threshold.
## `results` is keyed by (k, embedding_size) with values (metric, model_params).
qualifying_configs = [
    (k, embedding_size, metric, model_params)
    for (k, embedding_size), (metric, model_params) in results.items()
    if metric >= threshold
]

if qualifying_configs:
    ## Rank by k first so the reported value really is the minimum k;
    ## among equal k, prefer the configuration with the higher score.
    ## (Sorting by metric alone would return the lowest-scoring qualifier,
    ## whose k is not necessarily the smallest.)
    selected = min(qualifying_configs, key=lambda entry: (entry[0], -entry[2]))
    min_k = selected[0]
    relevant_embedding_size = selected[1]
    relevant_model_state = selected[3]

## Identity check: `== None` can misbehave on array-like model states.
if relevant_model_state is None:
    print('Few shot learning classifier unable to match threshold, further tuning is required... ...')
else:
    print(f'Minimum k required to match performance threshold = {min_k}')

In [None]:
# evaluate(model, test_loader) ##TODO: Implement method


In [None]:
## Look up the model state stored for the best configuration found by tuning
_, best_model_state = results[best_config]

## best_config is indexed as (n_shot, embedding_size)
best_hyperparams = {
    'n_shot': best_config[0],
    'embedding_size': best_config[1],
}
actuals, predictions = trainer.test(best_model_state, best_hyperparams)