## Methods

In this notebook, we explore different few-shot `methods` on the `SwissProt` data.

### Setup

---

In [1]:
# Add path to load local modules
import sys
sys.path.insert(0, '..')  # add the parent directory to the module search path

In [2]:
# ruff: noqa: E402
# Reload modules automatically
%reload_ext autoreload
%autoreload 2

# Standard library imports
import time

# External imports
import numpy as np
import seaborn as sns
import torch

# Custom module imports
from datasets.prot.swissprot import SPSimpleDataset, SPSetDataset  # noqa

  EXISTS: go-basic.obo
go-basic.obo: fmt(1.2) rel(2023-11-15) 46,228 Terms; optional_attrs(relationship)


In [3]:
# Set styles
sns.set_style("whitegrid")

### Data Loading

---

In this section, we load the data from the `SwissProt` database, which is the smaller of the two datasets. We use both the regular dataloader for standard few-shot fine-tuning and the episodic dataloader for episodic fine-tuning.
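To make the difference between the two loaders concrete, here is a minimal NumPy sketch on made-up data. A regular batch draws `batch_size` samples regardless of class, while an episode first picks `n_way` classes and then `n_support + n_query` samples from each of them. The helpers `regular_batch` and `episode` are hypothetical and only illustrate the idea. They are not the `SPSimpleDataset`/`SPSetDataset` implementation.

```python
import numpy as np

# Hypothetical toy data: 60 "proteins" with 8-dim embeddings, 6 classes of 10 each
rng = np.random.default_rng(0)
features = rng.normal(size=(60, 8))
labels = np.repeat(np.arange(6), 10)


def regular_batch(batch_size):
    """Standard minibatch: draw samples uniformly, ignoring class structure."""
    idx = rng.choice(len(features), size=batch_size, replace=False)
    return features[idx], labels[idx]


def episode(n_way, n_support, n_query):
    """Episodic batch: pick n_way classes, then n_support + n_query samples per class."""
    classes = rng.choice(np.unique(labels), size=n_way, replace=False)
    per_class = []
    for c in classes:
        idx = rng.choice(np.flatnonzero(labels == c), size=n_support + n_query, replace=False)
        per_class.append(features[idx])
    # Shape: (n_way, n_support + n_query, feature_dim)
    return np.stack(per_class), classes


xb, yb = regular_batch(batch_size=10)
xe, ce = episode(n_way=5, n_support=3, n_query=3)
print(xb.shape, xe.shape)  # (10, 8) (5, 6, 8)
```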

#### Regular Dataloader

In [7]:
# Setup the loading parameters
root = "../data"
batch_size = 10
min_samples = 6

kwargs = {
    "root": root,
    "batch_size": batch_size,
    "min_samples": min_samples,
}

# Load SPSimpleDataset for each mode
modes = ["train", "val", "test"]
r_datasets = [SPSimpleDataset(**kwargs, mode=mode) for mode in modes]
r_train, r_val, r_test = [
    dataset.get_data_loader(num_workers=0, pin_memory=False) for dataset in r_datasets
]

# Get some basic statistics about each of the splits
for split, mode in zip(r_datasets, modes):
    print(f"ℹ️ {mode} split has {len(split)} samples")
    print(f"ℹ️ Each sample is an encoded protein sequence of length {split.dim}")
    print(f"ℹ️ {mode} split has {len(np.unique([smp.annot for smp in split.samples]))} classes.")
    print()

ℹ️ train split has 11722 samples
ℹ️ Each sample is an encoded protein sequence of length 1280
ℹ️ train split has 182 classes.

ℹ️ val split has 600 samples
ℹ️ Each sample is an encoded protein sequence of length 1280
ℹ️ val split has 26 classes.

ℹ️ test split has 652 samples
ℹ️ Each sample is an encoded protein sequence of length 1280
ℹ️ test split has 12 classes.
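The regular loader is built with `min_samples = 6`. We assume (this is not confirmed by the code shown here) that it drops classes with fewer than `min_samples` annotated proteins. Under that assumption, the filter could look like the hypothetical `filter_rare_classes` below, which reuses the same `np.unique` call as the class count above.

```python
import numpy as np


def filter_rare_classes(annotations, min_samples):
    """Keep only samples whose class occurs at least `min_samples` times.

    Assumption: this mirrors the role of `min_samples` in the dataset classes;
    it is not taken from their source code.
    """
    annotations = np.asarray(annotations)
    classes, counts = np.unique(annotations, return_counts=True)
    kept = classes[counts >= min_samples]
    mask = np.isin(annotations, kept)
    return annotations[mask], kept


# Toy annotations (hypothetical labels)
annots = ["A"] * 7 + ["B"] * 6 + ["C"] * 2
filtered, kept = filter_rare_classes(annots, min_samples=6)
print(len(filtered), kept.tolist())  # 13 ['A', 'B']
```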



#### Episodic Dataloader

In [8]:
# Setup the loading parameters
root = "../data"
n_way = 5
n_support = 3
n_query = 3
subset = 1.0  # Load the full dataset

kwargs = {
    "n_way": n_way,
    "n_support": n_support,
    "n_query": n_query,
    "root": root,
    "subset": subset,
}

# Load SPSetDataset for each mode
modes = ["train", "val", "test"]
e_datasets = [SPSetDataset(**kwargs, mode=mode) for mode in modes]
e_train, e_val, e_test = [dataset.get_data_loader(num_workers=0, pin_memory=False) for dataset in e_datasets]

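# Get some basic statistics about each of the splits
# Note: the counts printed here match the number of classes of the regular
# splits, so for the episodic dataset len(split) appears to count classes
# (one sub-dataset per class) rather than individual proteins.
for split, mode in zip(e_datasets, modes):
    print(f"ℹ️ {mode} split has {len(split)} samples")
    print(f"ℹ️ Each sample is an encoded protein sequence of length {split.dim}")
    print()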

ℹ️ train split has 182 samples
ℹ️ Each sample is an encoded protein sequence of length 1280

ℹ️ val split has 26 samples
ℹ️ Each sample is an encoded protein sequence of length 1280

ℹ️ test split has 12 samples
ℹ️ Each sample is an encoded protein sequence of length 1280
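Each episode contains `n_way * (n_support + n_query) = 5 * (3 + 3) = 30` proteins. Assuming an episode arrives as a tensor of shape `(n_way, n_support + n_query, dim)` (the usual layout in episodic few-shot code, not verified against `SPSetDataset`), it can be split into support and query sets with episode-local labels as follows. The tensor `x` here is random stand-in data.

```python
import torch

n_way, n_support, n_query, dim = 5, 3, 3, 1280

# Stand-in for one episode; assumed layout (n_way, n_support + n_query, dim)
x = torch.randn(n_way, n_support + n_query, dim)

# The first n_support items of each class form the support set, the rest the query set
support = x[:, :n_support].reshape(n_way * n_support, dim)
query = x[:, n_support:].reshape(n_way * n_query, dim)

# Episode-local labels 0..n_way-1; the original class ids are not needed inside an episode
y_support = torch.arange(n_way).repeat_interleave(n_support)
y_query = torch.arange(n_way).repeat_interleave(n_query)

print(support.shape, query.shape)  # torch.Size([15, 1280]) torch.Size([15, 1280])
print(y_query.tolist())  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]
```

If `e_train` yields `(x, y)` pairs in this layout, the same slicing would apply to `x` from `next(iter(e_train))`. `reshape` is used rather than `view` because the slices along the second axis are not contiguous.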



### Baseline

---

In [6]:
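# NOTE: this cell is disabled. Running it requires importing the backbone and
# method modules (`fcnet`, `baseline`) and `torch.optim as optim`.

# # Choose backbone (input dimension = embedding size of each sample)
# bm1 = fcnet.FCNet(x_dim=r_datasets[0].dim)
#
# # Define the method
# m1 = baseline.Baseline(
#     backbone=bm1,
#     n_way=3,
#     n_support=15,
#     n_classes=len(np.unique([smp.annot for smp in r_datasets[0].samples])),
#     loss="softmax",
#     type="classification",
# )
#
# # Define the optimizer used for training
# optimizer = optim.AdamW(m1.parameters(), lr=0.001)
#
# # Train for two epochs on the regular training loader
# for epoch in range(2):
#     m1.train_loop(epoch, r_train, optimizer)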
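The Baseline method is commonly described as follows: train a backbone and classifier on the base classes, then, for each test episode, keep the backbone frozen and fit a new linear classifier on the support set. The sketch below shows only that fine-tuning step on random stand-in features. `head`, the SGD settings and the step count are illustrative choices, and `baseline.Baseline` in this repository may implement the step differently.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_way, n_support, n_query, dim = 5, 3, 3, 1280

# Stand-ins for frozen backbone features of one episode (random, hypothetical data)
z_support = torch.randn(n_way * n_support, dim)
z_query = torch.randn(n_way * n_query, dim)
y_support = torch.arange(n_way).repeat_interleave(n_support)

# Fit a fresh linear head on the support features only; no backbone is updated here
head = nn.Linear(dim, n_way)
optimizer = torch.optim.SGD(head.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

losses = []
for step in range(100):
    optimizer.zero_grad()
    loss = loss_fn(head(z_support), y_support)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# Classify the query samples with the fine-tuned head
with torch.no_grad():
    pred = head(z_query).argmax(dim=1)
    support_acc = (head(z_support).argmax(dim=1) == y_support).float().mean().item()
print(pred.shape, f"support accuracy: {support_acc:.2f}")
```

In the notebook, `z_support` and `z_query` would instead come from passing the support and query parts of an episode through the trained backbone.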