In [1]:
## need older version of lightning for this notebook to run

import sys
sys.path.append('learn2learn/') # the version on pypi doesn't have LightningMAML
%config Completer.use_jedi = False

In [2]:
import pandas as pd
import numpy as np
import torch

import pyximport
pyximport.install()

import learn2learn as l2l
from learn2learn.data import TaskDataset
from learn2learn.algorithms.lightning import LightningMAML
from learn2learn.utils.lightning import EpisodicBatcher

from baseline.pytorch_models import LitFFNN
from baseline.training_functions import *
from baseline.training_functions import make_split

import pytorch_lightning as pl

from datasets import load_abundance_data, get_shared_taxa_dfs
from datasets import MicroDataset, Dataset
import torch.nn.functional as F
dfs = load_abundance_data()
all_datasets = get_shared_taxa_dfs(dfs)

def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    return np.eye(num_classes, dtype='uint8')[y]

In [3]:
pd.read_csv('results/baseline/ffnn_summary.csv')

Unnamed: 0,Model Type,Data Type,Dataset,hyperparams,AUC
0,FFNN,Abundance,Quin_gut_liver_cirrhosis,"[1024, 128, 128, 0.1, 0.5]",0.746377
1,FFNN,Abundance,WT2D,"[128, 64, 32, 0.001, 0.5]",0.59596
2,FFNN,Abundance,Zeller_fecal_colorectal_cancer,"[256, 256, 32, 0.1, 0.1]",0.7
3,FFNN,Abundance,Chatelier_gut_obesity,"[1024, 128, 32, 0.01, 0.1]",0.690311
4,FFNN,Abundance,metahit,"[1024, 64, 128, 0.1, 0.5]",0.623529
5,FFNN,Abundance,t2dmeta_long,"[256, 256, 128, 0.1, 0.1]",0.605882


In [4]:
datasets = [b for a,b in all_datasets.items()]

In [5]:
splits = [make_split(df) for df in datasets]

#redo the split for the dataset of interest (the one at idx 0)
# so we include a test set
train, test = make_split(datasets[0])
train, valid = make_split(train)

splits[0] = (train, valid)

trains = [s[0] for s in splits]
vals = [s[1] for s in splits]

In [6]:
class MetaMicroDataset(Dataset):
    """Dataset class for column dataset.
    Args:
       cats (list of str): List of the name of columns contain
                           categorical variables.
       conts (list of str): List of the name of columns which 
                           contain continuous variables.
       y (Tensor, optional): Target variables.
       is_reg (bool): If the task is regression, set ``True``, 
                      otherwise (classification) ``False``.
       is_multi (bool): If the task is multi-label classification, 
                        set ``True``.
    """
    def __init__(self, df, is_marker = False):
        df = df.sample(frac=1)
        if not is_marker:
            self.taxa_cols = df.columns[df.columns.str.contains('k__')] 
        else:
            self.taxa_cols = df.columns[df.columns.str.contains('gi[|]')] 
        self.matrix = torch.Tensor( df[self.taxa_cols].astype(float).values ).float()
        
        #scale the dataset to be in relative abundance space
        self.matrix=F.softmax(self.matrix, requires_grad=True)
        
        df.loc[df.disease=='ibd_crohn_disease'] = 'ibd_ulcerative_colitis'
        self.y=torch.Tensor( pd.Categorical(df.disease).codes ).long()
        self.n_samples, self.n_taxa= self.matrix.shape
        
    def __len__(self): return len(self.y)
    def __getitem__(self, idx):
        return [self.matrix[idx], self.y[idx]]

In [7]:
def collate(a):
    """
    collate function to simplify the tasks -- only want to ever distinguish between class 0 and 1
    Each class can represent any positive/negative group from any dataset
    6 datasets ==> 12 classes ==> 132 distinct metalearning tasks
    """
    idx = max([b[1] for b in a])
    #print(a[1])
    return(torch.cat( [b[0].unsqueeze(0) for b in a] ),
           torch.Tensor( [int(b[1]==idx)  for b in a]) )#.long() )

def collate(a):
    """
    collate function to simplify the tasks -- only want to ever distinguish between class 0 and 1
    Each class can represent any positive/negative group from any dataset
    6 datasets ==> 12 classes ==> 132 distinct metalearning tasks
    """
    #idx = max([b[1] for b in a])
    q = torch.Tensor([b[1] for b in a])
    return(torch.cat( [b[0].unsqueeze(0) for b in a] ),
           ( q == q.max() ).long() )



In [33]:
def build_taskset(datasets):
    MetaDS = l2l.data.UnionMetaDataset( [l2l.data.MetaDataset( MicroDataset(t) ) for t in datasets] )
    dataset = l2l.data.MetaDataset(MetaDS)
    transforms = [
        l2l.data.transforms.NWays(dataset, n=2),
        l2l.data.transforms.KShots(dataset, k=8),
        l2l.data.transforms.LoadData(dataset)
    ]
    return( TaskDataset(dataset, transforms,
                        num_tasks=2*( len(datasets) * 2 * (len(datasets) * 2 - 1) ), 
                        task_collate=collate)
          )
    
train_set = build_taskset(trains)
val_set = build_taskset(vals)



In [34]:
train_set = build_taskset(trains)
val_set = build_taskset(vals)


In [35]:
# %%time

# # MetaDS = MetaMicroDataset(trains)
# MetaDS = l2l.data.UnionMetaDataset( [l2l.data.MetaDataset( MicroDataset(t) ) for t in datasets] )
# dataset = l2l.data.MetaDataset(MetaDS)
# transforms = [
#     l2l.data.transforms.NWays(dataset, n=2),
#     l2l.data.transforms.KShots(dataset, k=10),
#     l2l.data.transforms.LoadData(dataset)
# ]
# train_set = TaskDataset(dataset, transforms,
#                         num_tasks=( len(datasets) * 2 * (len(datasets) * 2 - 1) ), 
#                         task_collate=collate)
# for task in taskset:
#     X, y = task

In [36]:
class FFNN(torch.nn.Module):
    """
    a deep FFNN network -- assuming it's going to show improvements
    in the transfer learning exploration
    probably won't be so good in standard learning approach
    """
    def __init__(self, 
                 dataset, 
                 layer_sizes = [128, 64, 32], 
                 dropout = .2, 
                 n_labels = 2):
        super(FFNN, self).__init__()
        
        linear_layers = [ nn.Linear( dataset.n_taxa, layer_sizes[0]), 
                          nn.BatchNorm1d(layer_sizes[0]), 
                          nn.Dropout(dropout), 
                          nn.GELU()]
        self.n_labels=n_labels
        for i in range(len(layer_sizes)-1):
            linear_layers += [ nn.Linear(layer_sizes[i], layer_sizes[i+1]), 
                               nn.BatchNorm1d(layer_sizes[i+1]), 
                               nn.Dropout(dropout), 
                               nn.GELU()]
            
        linear_layers += [ nn.Linear(layer_sizes[-1], n_labels), 
                           nn.Softmax() ]
        
        self.linear_net = nn.Sequential(*linear_layers)
        
    def forward(self, x):
        out = self.linear_net(x.float())
        return(out)

In [37]:
import torch.nn as nn

In [38]:
# np.arange(2) * (maml.train_queries + maml.train_shots)

In [39]:
i=0
for task in train_set:
    X, y = task
    i+=1
    

In [40]:
X.shape



torch.Size([16, 771])

In [41]:
# for _ in range(maml.train_shots):
#     qqq=np.arange(2) * (maml.train_queries + maml.train_shots) + _

In [64]:
k=2

In [65]:
def build_taskset(datasets):
    MetaDS = l2l.data.UnionMetaDataset( [l2l.data.MetaDataset( MicroDataset(t) ) for t in datasets] )
    dataset = l2l.data.MetaDataset(MetaDS)
    transforms = [
        l2l.data.transforms.NWays(dataset, n=2),
        l2l.data.transforms.KShots(dataset, k=k),
        l2l.data.transforms.LoadData(dataset)
    ]
    return( TaskDataset(dataset, transforms,
                        num_tasks=2*( len(datasets) * 2 * (len(datasets) * 2 - 1) ), 
                        task_collate=collate)
          )
    
train_set = build_taskset(trains)
val_set = build_taskset(vals)




In [66]:
model = FFNN(dataset=MicroDataset(pd.concat(trains).reset_index() ))
maml = LightningMAML(model, adaptation_lr=0.1, lr = .002,
                    train_ways=2, 
                    test_ways=2,
                    train_shots = k//2,
                    test_shots = k//2,
                    train_queries = k//2,
                    test_queries = k//2,
                    adaptation_steps=k//2)#, loss=torch.nn.MSELoss())
# , allow_nograd=True, 
#                     allow_unused=True)
episodic_data = EpisodicBatcher(train_set, train_set)#, taskset)
#EpisodicBatcher(tasksets.train, tasksets.validation, tasksets.test)

#setr up the trainer/logger/callbacks
checkpoint_callback=ModelCheckpoint(
                filepath = 'checkpoint_dir',
                save_top_k=1,
                verbose=False,
                monitor='valid_loss',
                mode='min'
                )

tube_logger = TestTubeLogger('checkpoint_dir', 
                            name='test_tube_logger')

trainer = pl.Trainer(max_epochs = 30,
#              num_sanity_val_steps=0,
             progress_bar_refresh_rate=1,
             weights_summary='full',
             check_val_every_n_epoch=1,
             checkpoint_callback=checkpoint_callback,
            callbacks=[EarlyStopping(monitor='valid_loss', 
                                    patience=50)]) 

trainer.fit(maml, episodic_data)
print('success!')

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

   | Name                       | Type             | Params
-----------------------------------------------------------------
0  | loss                       | CrossEntropyLoss | 0     
1  | model                      | MAML             | 109 K 
2  | model.module               | FFNN             | 109 K 
3  | model.module.linear_net    | Sequential       | 109 K 
4  | model.module.linear_net.0  | Linear           | 98 K  
5  | model.module.linear_net.1  | BatchNorm1d      | 256   
6  | model.module.linear_net.2  | Dropout          | 0     
7  | model.module.linear_net.3  | GELU             | 0     
8  | model.module.linear_net.4  | Linear           | 8 K   
9  | model.module.linear_net.5  | BatchNorm1d      | 128   
10 | model.module.linear_net.6  | Dropout          | 0     
11 | model.module.linear_net.7  | GELU             | 0     
12 | model.module.linear_net.8  | Linear           | 2 K   
13 | model.module.

Validation sanity check: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Training: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])
[False False False False]
tensor([ True, False,  True, False])
tensor([1, 0])
tensor([1, 0])


Validating: 0it [00:00, ?it/s]

[False False False False]
tensor([ True, False,  True, False])
tensor([0, 1])
tensor([0, 1])
success!


In [62]:
qqq=torch.load(checkpoint_callback.best_model_path)

FileNotFoundError: [Errno 2] No such file or directory: ''

In [19]:
#load best_model_path's params
maml.load_state_dict(torch.load(checkpoint_callback.best_model_path)['state_dict'])

#set up standard model ==> for lightning training
lightning = LitFFNN(train, 
                    valid)
lightning.model.load_state_dict(maml.model.module.state_dict())

<All keys matched successfully>

In [20]:
lightning.model

FFNN(
  (linear_net): Sequential(
    (0): Linear(in_features=771, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): GELU()
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Dropout(p=0.2, inplace=False)
    (7): GELU()
    (8): Linear(in_features=64, out_features=32, bias=True)
    (9): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): Dropout(p=0.2, inplace=False)
    (11): GELU()
    (12): Linear(in_features=32, out_features=2, bias=True)
    (13): Softmax(dim=None)
  )
)

In [21]:
checkpoint_callback.best_model_score
lightning = LitFFNN(train, 
                    valid)
lightning.model.load_state_dict()

TypeError: load_state_dict() missing 1 required positional argument: 'state_dict'

In [None]:
pl.Trainer(max_epochs = 100,
             logger=tube_logger,
             progress_bar_refresh_rate=0,
             weights_summary=None,
             check_val_every_n_epoch=1,
             checkpoint_callback=checkpoint_callback,
            callbacks=[EarlyStopping(monitor='val_loss', 
                                    patience=20)]) 

In [None]:
MetaDS = MetaMicroDataset(vals)
dataset = l2l.data.MetaDataset(MetaDS)
# transforms = [
#     l2l.data.transforms.NWays(dataset, n=5),
#     l2l.data.transforms.KShots(dataset, k=1),
#     l2l.data.transforms.LoadData(dataset),
# ]
# val_taskset = TaskDataset(dataset, transforms, num_tasks=20)
# for task in val_taskset:
#     X, y = task

In [None]:
taskset

In [None]:
from baseline.training_functions import *

In [None]:

"""
Example for running few-shot algorithms with the PyTorch Lightning wrappers.
"""

import learn2learn as l2l
import pytorch_lightning as pl
from argparse import ArgumentParser
from learn2learn.algorithms import (
    LightningPrototypicalNetworks,
    LightningMetaOptNet,
    LightningMAML,
    LightningANIL,
)
from learn2learn.utils.lightning import EpisodicBatcher


def main():
    parser = ArgumentParser(conflict_handler="resolve", add_help=True)
    # add model and trainer specific args
    parser = LightningPrototypicalNetworks.add_model_specific_args(parser)
    parser = LightningMetaOptNet.add_model_specific_args(parser)
    parser = LightningMAML.add_model_specific_args(parser)
    parser = LightningANIL.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)

    # add script-specific args
    parser.add_argument("--algorithm", type=str, default="protonet")
    parser.add_argument("--dataset", type=str, default="mini-imagenet")
    parser.add_argument("--root", type=str, default="~/data")
    parser.add_argument("--meta_batch_size", type=int, default=16)
    parser.add_argument("--seed", type=int, default=42)
    return(parser)
    args = parser.parse_args()
    dict_args = vars(args)

    pl.seed_everything(args.seed)
    print('hi')
    print(args)
    # Create tasksets using the benchmark interface
    if False and args.dataset in ["mini-imagenet", "tiered-imagenet"]:
        data_augmentation = "lee2019"
    else:
        data_augmentation = "normalize"
    tasksets = l2l.vision.benchmarks.get_tasksets(
        name=args.dataset,
        train_samples=args.train_queries + args.train_shots,
        train_ways=args.train_ways,
        test_samples=args.test_queries + args.test_shots,
        test_ways=args.test_ways,
        root=args.root,
        data_augmentation=data_augmentation,
    )
    episodic_data = EpisodicBatcher(
        tasksets.train,
        tasksets.validation,
        tasksets.test,
        epoch_length=args.meta_batch_size * 10,
    )

    # init model
    if args.dataset in ["mini-imagenet", "tiered-imagenet"]:
        model = l2l.vision.models.ResNet12(output_size=args.train_ways)
    else:  # CIFAR-FS, FC100
        model = l2l.vision.models.CNN4(
            output_size=args.train_ways,
            hidden_size=64,
            embedding_size=64*4,
        )
    features = model.features
    classifier = model.classifier

    # init algorithm
    if args.algorithm == "protonet":
        algorithm = LightningPrototypicalNetworks(features=features, **dict_args)
    elif args.algorithm == "maml":
        algorithm = LightningMAML(model, **dict_args)
    elif args.algorithm == "anil":
        algorithm = LightningANIL(features, classifier, **dict_args)
    elif args.algorithm == "metaoptnet":
        algorithm = LightningMetaOptNet(features, **dict_args)

    trainer = pl.Trainer.from_argparse_args(
        args,
        gpus=1,
        accumulate_grad_batches=args.meta_batch_size,
        callbacks=[
            l2l.utils.lightning.TrackTestAccuracyCallback(),
            l2l.utils.lightning.NoLeaveProgressBar(),
        ],
    )
    trainer.fit(model=algorithm, datamodule=episodic_data)
    trainer.test(ckpt_path="best")


# if __name__ == "__main__":
#     main()

In [None]:
parser = ArgumentParser(conflict_handler="resolve", add_help=True)
# add model and trainer specific args
# parser = LightningPrototypicalNetworks.add_model_specific_args(parser)
# parser = LightningMetaOptNet.add_model_specific_args(parser)
# parser = LightningMAML.add_model_specific_args(parser)
# # parser = LightningANIL.add_model_specific_args(parser)
# # parser = pl.Trainer.add_argparse_args(parser)

In [None]:
parser.parse_args()

In [None]:
parser = main()

In [None]:
parser.parse_args()

In [None]:
lightning = LitFFNN(trains, 
                    vals, 
                    layer_1_dim = 256,#hyperparams['layer_1_size'],
                    layer_2_dim = 256,#, hyperparams['layer_2_size'],
                    layer_3_dim = 32, #hyperparams['layer_3_size'],
                    learning_rate = .1,#hyperparams['learning_rate'],
                    batch_size=50, 
                    dropout=.2#hyperparams['dropout']
                    )

#setr up the trainer/logger/callbacks
checkpoint_callback=ModelCheckpoint(
                dirpath = 'checkpoint_dir',
                save_top_k=1,
                verbose=False,
                monitor='val_loss',
                mode='min'
                )

tube_logger = TestTubeLogger('checkpoint_dir', 
                            name='test_tube_logger')

trainer = pl.Trainer(max_epochs = 500,
                     logger=tube_logger,
                     progress_bar_refresh_rate=0,
                     weights_summary=None,
                     check_val_every_n_epoch=1,
                     checkpoint_callback=checkpoint_callback,
                    callbacks=[EarlyStopping(monitor='val_loss', 
                                            patience=20)]) #the patience of 20 is mentioned in the DeepMicro paper



In [None]:
import torch

In [None]:
qqq = LightningMAML(lightning.model, loss = torch.nn.MSELoss())

In [None]:
#from learn2learn.data import MetaDataset

In [None]:
class MetaLitFFNN(LitFFNN):
    
    def init_MAML_loss(self, ...):
        lr=0.001
        maml_lr=0.0005
        #iterations=1000
        ways=2
        shots=1
        # tps=16
        tps = 5

        fas=3
        mu = 1

        model = CompleteMHAPacking(proc_field, visit_field, emb_dim = 64,
                                   hidden_size = 32, head_count = 2, dropout = .3)
        model.to(device)

        meta_model = l2l.algorithms.MAML(model, lr=maml_lr)
        opt = optim.AdamW(meta_model.parameters(), lr=lr)

        #loss_func = nn.CrossEntropyLoss(reduction='mean')
        # bias = np.log( ( np.array(ny) / np.max(np.array(ny)) ))
        # linear = model.linear.state_dict()
        # linear['bias'] = torch.tensor( bias )
        # model.linear.load_state_dict(linear)


        lr = 0.001
        #  ClassBalancedFocalLoss(.95, 2), (.95, 5) also gave similarly poor results
        #  (please tell me if im setting up this loss function incorrectly)


        # ==> focus on Classbalanced
        # loss_func = ClassBalancedFocalLoss(.95, ny, 2)
        # loss_func.to(device)

        loss_func = nn.BCELoss()


        #we have 4 tasks to meta-learn from
        # most published approaches would not use the RA data in the metalearning step
        #                         --> this doesn't make sense to me for this context
        #              --> it makes more sense if we are stricly isolating how much carryover information we can get
        #   --> to get best results i don't see why eliminating RA is necessary, but this is an easy thing to change in the future
        num_tasks = 4

        #determine number of datapoints to sample at each iteration
        #samp_size = 100
        iterations = 500
    with torch.backends.cudnn.flags(enabled=False):
        for iteration in range(iterations):

            for _ in tqdm.tqdm( range(tps) ):
                learner = self.model.clone( allow_unused=True )

                #get data from a randomly selected task
                        #for the 'experiment' currently only looking at source tasks for the metalearning
                task = all_tasks[np.random.choice(4, 1)[0]]

                for step in range(fas):
                    # Compute validation loss
                    learner.zero_grad()
                    batch = next(task)
                    preds = learner(batch)
                    loss = loss_func(preds, to_categorical(batch.ird) )
                    learner.adapt(loss)
                    del batch, preds, loss

                batch = get_batch(task)

                source_loss = loss_func( learner(batch, device), to_categorical( batch.ird ) )

                batch = next( all_tasks[0].__iter__() )
                #get the loss with the fas-updated learner
                preds = learner(batch, device)       

                # Take the meta-learning step
                opt.zero_grad()
                val_error = loss_func(preds, to_categorical(batch.ird) ) #torch.as_tensor(batch.ird).long() )
                #val_error.backward()

                meta_pred_loss = val_error + mu * source_loss
                meta_pred_loss.backward()
                opt.step()

                del batch, preds, val_error, learner

            if iteration%5 == 0:
                print(iteration)