In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

import sys
import os
# Put the parent directory on sys.path (once), presumably so the `project`
# package imported further down can be resolved from this notebook.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
from datetime import datetime
import pandas as pd
import numpy as np
import joblib
from pathlib import Path
from sklearn import model_selection
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping

In [18]:
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

from project.datasets import Dataset, CTRPDataModule
from project.models import FiLMNetwork, ConcatNetwork, ConditionalNetwork

In [7]:
import pyarrow.dataset as ds
import pyarrow.feather as feather

In [8]:
def prepare(exp, subset=True):
    """Open the processed training data and assemble the column groups.

    Parameters
    ----------
    exp : str
        Experiment key. 'id' conditions on the compound id column plus the
        concentration column; any other value conditions on the fingerprint
        columns loaded from ``fp_cols.pkl`` plus the concentration column.
    subset : bool, default True
        When true, open ``train_sub.feather``; otherwise open ``train.feather``.

    Returns
    -------
    tuple
        ``(dataset, input_cols, cond_cols)`` where ``dataset`` is the
        pyarrow dataset opened in feather format, ``input_cols`` is whatever
        ``gene_cols.pkl`` holds, and ``cond_cols`` is a numpy array of
        conditioning column names.
    """
    processed_dir = Path("../../film-gex-data/processed/")
    input_cols = joblib.load(processed_dir / "gene_cols.pkl")

    conc_col = 'cpd_conc_umol'
    if exp == 'id':
        cond_cols = np.array(["master_cpd_id", conc_col])
    else:
        fingerprint_cols = joblib.load(processed_dir / "fp_cols.pkl")
        cond_cols = np.append(fingerprint_cols, [conc_col])

    # Only the file name differs between the subset and the full training data.
    feather_name = "train_sub.feather" if subset else "train.feather"
    dataset = ds.dataset(processed_dir / feather_name, format='feather')

    return dataset, input_cols, cond_cols


def cv(name, exp, gpus, nfolds, dataset, input_cols, cond_cols, batch_size):
    """Train one model per cross-validation fold, logging each run to TensorBoard.

    Parameters
    ----------
    name : str
        Run label; becomes part of the TensorBoard version string.
    exp : str
        Experiment key. 'film' builds a ``FiLMNetwork``; any other value
        builds a ``ConcatNetwork``. Also part of the version string.
    gpus : int or list of int
        Forwarded unchanged to ``Trainer(gpus=...)``.
    nfolds : int
        Number of folds. Fold ids ``0 .. nfolds-1`` are compared against the
        dataset's ``'fold'`` field to split training and validation rows.
    dataset : object
        Dataset exposing ``to_table(columns=..., filter=...)``, e.g. the one
        returned by ``prepare``.
    input_cols, cond_cols : sequence of str
        Input and conditioning column names.
    batch_size : int
        Batch size handed to ``CTRPDataModule``.

    Returns
    -------
    None
        Progress and timing are reported via ``print``.
    """
    seed_everything(2299)
    target = 'cpd_avg_pv'
    cols = list(np.concatenate((input_cols, cond_cols, [target])))
    fold_field = ds.field('fold')

    for fold in range(nfolds):
        fold_start = datetime.now()

        # DataModule: held-out fold is validation, everything else is training.
        train = dataset.to_table(columns=cols, filter=fold_field != fold).to_pandas()
        val = dataset.to_table(columns=cols, filter=fold_field == fold).to_pandas()
        dm = CTRPDataModule(train,
                            val,
                            input_cols,
                            cond_cols,
                            target=target,
                            batch_size=batch_size)
        print("Completed dataloading in {}".format(str(datetime.now() - fold_start)))

        # Model
        setup_start = datetime.now()
        if exp == 'film':
            model = FiLMNetwork(len(input_cols), len(cond_cols))
        else:
            model = ConcatNetwork(len(input_cols), len(cond_cols))

        # Logger: one TensorBoard version per (name, exp, fold).
        logger = TensorBoardLogger(save_dir=os.getcwd(),
                                   version="{}_{}_fold_{}".format(name, exp, fold),
                                   name='lightning_logs')

        # Trainer
        # Early stopping is switched off here.
        # NOTE(review): to enable it, an
        # EarlyStopping(monitor='val_loss', min_delta=0.01) instance would need
        # to be passed to the Trainer -- confirm the argument name against the
        # installed pytorch_lightning release.
        trainer = Trainer(auto_lr_find=True,
                          auto_scale_batch_size=False,
                          max_epochs=25,
                          gpus=gpus,  # honour the caller's device selection
                          logger=logger,
                          early_stop_callback=False,
                          distributed_backend='dp')
        print("Completed loading in {}".format(str(datetime.now() - setup_start)))

        trainer.fit(model, dm)
        # Timed from the start of the fold so data loading is included.
        print("Completed fold {} in {}".format(fold, str(datetime.now() - fold_start)))

    print("/done")

In [9]:
# Open the subset training data with the compound-id conditioning columns.
dataset, input_cols, cond_cols = prepare(exp='id', subset=True)

In [12]:
# Run configuration; these are the arguments of the commented-out cv(...) call below.
name, exp = 'test', 'id'   # run label, experiment key
gpus, nfolds = 3, 1        # GPU setting, number of folds

In [26]:
# Build the network for the current experiment key; the two sizes are the
# number of input columns and the number of conditioning columns.
model = ConditionalNetwork(
    exp,
    len(input_cols),
    len(cond_cols),
    batch_size=256,
)

In [27]:
# Display the hyperparameters recorded on the model.
model.hparams

"batch_size":    256
"conds_sz":      2
"exp":           id
"inputs_sz":     978
"learning_rate": 0.001
"metric":        <function r2_score at 0x7fa5c34c8e50>
"ps":            [0.2]

In [1]:
# Full cross-validation run, left disabled; uncomment to launch it.
#cv(name, exp, gpus, nfolds, dataset, input_cols, cond_cols, batch_size=256)

In [None]:
# Echo the current run label.
name

In [6]:
# Stand-alone logger used to inspect where Lightning writes its logs.
# NOTE: the version string is the raw template (no .format call), so the
# resulting log directory name contains the `{}` placeholders literally.
logger = TensorBoardLogger(
    save_dir=os.getcwd(),
    name='lightning_logs',
    version="{}_{}_fold_{}",
)

In [8]:
# Show the directory this logger will write to.
logger.log_dir

'/srv/home/wconnell/github/film-gex/notebooks/lightning_logs/{}_{}_fold_{}'

In [15]:
# Trainer with only the logger set; no GPUs are requested here, which matches
# the recorded "GPU available: True, used: False" message.
trainer = Trainer(logger=logger)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores


In [16]:
# Show the Trainer's default root directory.
trainer.default_root_dir

'/srv/home/wconnell/github/film-gex/notebooks'

In [17]:
# Log directory as seen through the Trainer's logger; the recorded value is
# the same path reported by logger.log_dir.
trainer.logger.log_dir

'/srv/home/wconnell/github/film-gex/notebooks/lightning_logs/{}_{}_fold_{}'

In [18]:
# NOTE(review): the recorded output below ("Object `trainer.log` not found.")
# names `trainer.log`, so it does not appear to come from this statement.
# This looks like leftover introspection -- confirm and consider removing the cell.
FiLMNetwork.log()

Object `trainer.log` not found.
