# Test `prediction.py` model loading
Michael Nolan   2021.01.15

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

import os
import sys
sys.path.append(os.getcwd())
import prediction # this is where the pt-l model code is

import h5py
import argparse

import matplotlib.pyplot as plt

In [None]:
# Hyperparameters for this run, registered with wandb (which runs here in
# "disabled" mode). TODO: this config is large; standardize it into a shared
# config file instead of an inline literal.
# NOTE(review): 'l2_con_scale' / 'l2_gen_scale' are defined twice: inside
# 'loss_weight_dict' (0 / 2000) and at the top level (0.9 / 0.9). Confirm which
# copy prediction.Lfads actually reads.
lfads_run_config = dict(
    # model-agnostic hyperparameters
    data_file_path="D:\\Users\\mickey\\Data\\datasets\\ecog\\goose_wireless\\gw_250",
    batch_size=1000,
    sequence_length=50,
    data_suffix='ecog',
    objective_function='mse',
    learning_rate=0.001,
    learning_rate_factor=0.9,
    device='cuda',
    # model-specific hyperparameters
    g_encoder_size=10,
    c_encoder_size=0,
    g_latent_size=10,
    u_latent_size=0,
    controller_size=0,
    generator_size=10,
    factor_size=10,
    prior=dict(
        g0=dict(
            mean=dict(value=0.0, learnable=True),
            var=dict(value=0.1, learnable=True),
        ),
        u=dict(
            mean=dict(value=0.0, learnable=False),
            var=dict(value=0.1, learnable=True),
            tau=dict(value=10, learnable=True),
        ),
    ),
    clip_val=2.0,
    max_norm=5.0,
    do_normalize_factors=True,
    factor_bias=False,
    loss_weight_dict=dict(
        kl=dict(weight=0.0, min=0.0, max=1.0, schedule_dur=1600, schedule_start=0),
        l2=dict(weight=0.0, min=0.0, max=1.0, schedule_dur=1600, schedule_start=0.0),
        l2_con_scale=0,
        l2_gen_scale=2000,
    ),
    l2_gen_scale=0.9,
    l2_con_scale=0.9,
    dropout=0.0,
)
wandb.init(config=lfads_run_config, mode="disabled")

In [None]:
class LfadsModel_ECoG(prediction.Lfads):
    """LFADS sequence-prediction model for ECoG data stored in one HDF5 file.

    This subclass does three things:

    1. Copies hyperparameters from ``config`` onto the instance.
    2. Builds source/target datasets from the HDF5 file.
    3. Serves those datasets through the Lightning dataloader hooks.

    The network itself and the training/validation steps are inherited from
    ``prediction.Lfads``.
    """

    def __init__(self, config):
        """Read hyperparameters, load the data, then initialise ``prediction.Lfads``.

        Args:
            config: attribute-access config object (``wandb.config`` in this
                notebook) exposing the keys read below, plus whatever
                ``prediction.Lfads`` itself reads from it.
        """
        self.data_file_path         = config.data_file_path
        self.batch_size             = config.batch_size # check this - set by wandb in hparam sweeps
        self.seq_len                = config.sequence_length
        self.data_suffix            = config.data_suffix
        self.objective_function     = config.objective_function
        self.prepare_data()
        # Input size is the last axis of the test split. prepare_data() has
        # already loaded it, so take the shape from memory instead of
        # reopening the HDF5 file; the 3-way unpack still enforces rank 3.
        _, _, self.input_size = self.test_dataset.tensor.shape
        self.generator_size         = config.generator_size
        self.g_encoder_size         = config.g_encoder_size
        self.g_latent_size          = config.g_latent_size
        self.controller_size        = config.controller_size
        self.c_encoder_size         = config.c_encoder_size
        self.u_latent_size          = config.u_latent_size
        self.factor_size            = config.factor_size

        self.prior                  = config.prior

        self.clip_val               = config.clip_val
        self.factor_bias            = config.factor_bias

        # The base initialiser runs last, after every attribute above is set.
        # NOTE(review): presumably prediction.Lfads reads some of these
        # attributes while building its submodules; confirm before reordering.
        super().__init__(config)

    def prepare_data(self):
        """Load the HDF5 file and build the train/valid/test datasets.

        Reads the ``train_<suffix>``, ``valid_<suffix>`` and ``test_<suffix>``
        entries, where suffix is ``self.data_suffix``. Each entry is wrapped in
        an ``EcogSrcTrgDataset`` with window length ``self.seq_len``.

        NOTE(review): Lightning's Trainer also calls the ``prepare_data`` hook
        during ``fit``, so the file is read a second time there.
        """
        data_dict = self.read_h5(self.data_file_path)
        self.train_dataset  = EcogSrcTrgDataset(
            data_dict[f"train_{self.data_suffix}"],
            self.seq_len
            )
        self.valid_dataset  = EcogSrcTrgDataset(
            data_dict[f"valid_{self.data_suffix}"],
            self.seq_len
            )
        self.test_dataset   = EcogSrcTrgDataset(
            data_dict[f"test_{self.data_suffix}"],
            self.seq_len
            )

    # training_step / validation_step are inherited from prediction.Lfads.

    def train_dataloader(self):
        """Return a DataLoader over the training split, in stored order.

        NOTE(review): no ``shuffle=True``; confirm unshuffled training is intended.
        """
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return a DataLoader over the validation split."""
        return DataLoader(self.valid_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a DataLoader over the test split."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

    @staticmethod
    def read_h5(data_file_path):
        """Load every top-level HDF5 dataset in ``data_file_path`` into a tensor.

        Returns:
            dict mapping each key in the file to a ``torch.tensor`` of its contents.

        Raises:
            OSError: if the file cannot be opened (a message is printed first).
        """
        try:
            hf = h5py.File(data_file_path, 'r')
        except IOError:
            print(f'Cannot open data file {data_file_path}.')
            raise
        with hf:
            # ``Dataset.value`` was removed in h5py 3.0. ``dset[()]`` reads the
            # whole dataset into memory and also works on h5py 2.x.
            return {k: torch.tensor(hf[k][()]) for k in hf.keys()}

    @staticmethod
    def add_model_specific_arguments(parent_parser):
        """Return a parser that extends ``parent_parser`` with this model's options.

        Option names mirror the config keys read in ``__init__``. The nested
        ``prior`` dict is not exposed on the command line.

        Args:
            parent_parser: ``argparse.ArgumentParser`` whose arguments are inherited.

        Returns:
            A new ``argparse.ArgumentParser`` (``add_help=False``).
        """
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--data_file_path', type=str)
        parser.add_argument('--batch_size', type=int)
        parser.add_argument('--sequence_length', type=int)
        parser.add_argument('--data_suffix', type=str)
        parser.add_argument('--objective_function', type=str)
        parser.add_argument('--generator_size', type=int)
        parser.add_argument('--g_encoder_size', type=int)
        parser.add_argument('--g_latent_size', type=int)
        parser.add_argument('--controller_size', type=int)
        parser.add_argument('--c_encoder_size', type=int)
        parser.add_argument('--u_latent_size', type=int)
        parser.add_argument('--factor_size', type=int)
        parser.add_argument('--clip_val', type=float)
        parser.add_argument('--factor_bias', action='store_true')
        return parser

class EcogSrcTrgDataset(Dataset):
    """Split each trial into a source window and the target window that follows it.

    Item ``i`` is ``(src, trg)``, where ``src = data[i, :seq_len, :]`` and
    ``trg = data[i, seq_len:2*seq_len, :]``. Both are float32 tensors.
    """

    def __init__(self, tensor, seq_len):
        """Store a float32 copy of ``tensor`` and the window length.

        Args:
            tensor: torch.Tensor or array-like indexed as (trial, time, ...).
                Presumably (trial, time, channel); TODO confirm. It is copied,
                so later in-place edits to the input do not leak into the dataset.
            seq_len: number of time steps in each of the src and trg windows.

        Raises:
            ValueError: if ``tensor`` has fewer than 3 dimensions, or if its
                time axis is shorter than ``2 * seq_len``.
        """
        if isinstance(tensor, torch.Tensor):
            # Copy explicitly: torch.tensor(<Tensor>) also copies, but emits a
            # UserWarning recommending clone().detach().
            data = tensor.detach().to(dtype=torch.float32, copy=True)
        else:
            data = torch.tensor(tensor, dtype=torch.float32)
        if data.dim() < 3:
            # __getitem__ indexes three axes
            raise ValueError(
                f"expected a tensor with at least 3 dims, got shape {tuple(data.shape)}"
            )
        if data.shape[1] < 2 * seq_len:
            raise ValueError(
                f"sequence length cannot be longer than 1/2 data sample length ({data.shape[1]})"
            )
        self.tensor = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        src = self.tensor[index, :self.seq_len, :]
        trg = self.tensor[index, self.seq_len:2 * self.seq_len, :]
        return (src, trg)

    def __len__(self):
        return self.tensor.shape[0]

In [None]:
# Build the LightningModule from the run config, then construct its training
# DataLoader once as a quick smoke check (Jupyter echoes the loader below).
prediction_model_shell = LfadsModel_ECoG(config=wandb.config)
train_loader_check = prediction_model_shell.train_dataloader()
train_loader_check

In [None]:
# Weights & Biases logger handed to the Lightning Trainer.
wandb_logger = WandbLogger(project='GW_ECoG-Prediction', name='LFADS-wandbtest')

In [None]:
# Checkpointing and early stopping both track the same validation metric.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor = 'avg_valid_loss',
    dirpath = 'D:\\Users\\mickey\\Data\\models\\pytorch-lightning\\',
    # The filename must format the metric that is actually logged/monitored.
    # The previous '{val_loss}' is presumably never logged, and Lightning
    # substitutes 0 for unknown filename metrics, so every checkpoint name
    # showed a loss of 0.000.
    filename = 'lfads-{epoch:03d}-{avg_valid_loss:.3f}',
)
early_stopping_callback = pl.callbacks.early_stopping.EarlyStopping(
    monitor ='avg_valid_loss'
)
trainer = pl.Trainer(max_epochs=100, 
                    logger = wandb_logger, 
                    gpus=1, 
                    callbacks=[checkpoint_callback, early_stopping_callback])

In [None]:
# Train the model with the trainer and callbacks configured in the previous cell.
trainer.fit(model=prediction_model_shell)