In [1]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

import sys
import os
# Put the notebook's parent directory on the import path (presumably so the
# `project` package imported in a later cell resolves); skip if already present.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
from datetime import datetime
import pandas as pd
import numpy as np
import joblib
from pathlib import Path
from sklearn import model_selection
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping

In [3]:
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

from project.datasets import Dataset, CTRPDataModule
from project.film_model import FiLMNetwork, ConcatNetwork

In [4]:
import pyarrow.dataset as ds
import pyarrow.feather as feather

In [5]:
def prepare(exp, subset=True, data_path=Path("../../film-gex-data/processed/")):
    """Open the processed training table and load its feature column lists.

    Parameters
    ----------
    exp : str
        Conditioning mode. ``'id'`` conditions on the ``master_cpd_id`` column
        plus ``cpd_conc_umol``; any other value conditions on the columns listed
        in ``fp_cols.pkl`` plus ``cpd_conc_umol``.
    subset : bool, default True
        Open ``train_sub.feather`` when True, otherwise ``train.feather``.
    data_path : str or pathlib.Path, optional
        Directory holding the processed files. Defaults to the previously
        hard-coded location, which is relative to the current working directory.

    Returns
    -------
    dataset : pyarrow.dataset.Dataset
        Feather-format dataset over the selected training file.
    input_cols
        Column names loaded from ``gene_cols.pkl``.
    cond_cols : numpy.ndarray
        Condition column names; the last entry is always ``'cpd_conc_umol'``.
    """
    data_path = Path(data_path)
    # NOTE(review): joblib.load unpickles and can execute arbitrary code —
    # only point data_path at trusted files.
    input_cols = joblib.load(data_path / "gene_cols.pkl")

    if exp == 'id':
        cond_cols = np.array(["master_cpd_id", 'cpd_conc_umol'])
    else:
        fp_cols = joblib.load(data_path / "fp_cols.pkl")
        cond_cols = np.append(fp_cols, ['cpd_conc_umol'])

    train_file = "train_sub.feather" if subset else "train.feather"
    dataset = ds.dataset(data_path / train_file, format='feather')

    return dataset, input_cols, cond_cols


def cv(name, exp, gpus, nfolds, dataset, input_cols, cond_cols, batch_size,
       seed=2299, max_epochs=25, early_stopping=False):
    """Train one model per cross-validation fold, logging each run to TensorBoard.

    For each ``fold`` in ``range(nfolds)``, rows whose ``fold`` column equals
    ``fold`` form the validation set and all remaining rows the training set.

    Parameters
    ----------
    name : str
        Run name; combined with ``exp`` and the fold index into the logger version.
    exp : str
        ``'film'`` trains a ``FiLMNetwork``; any other value trains a ``ConcatNetwork``.
    gpus : int or list of int
        Forwarded unchanged to ``Trainer(gpus=...)`` (previously ignored in
        favour of a hard-coded ``[1, 3]``).
    nfolds : int
        Number of folds to run (folds ``0 .. nfolds - 1``).
    dataset : pyarrow.dataset.Dataset
        Must contain ``input_cols``, ``cond_cols``, ``'cpd_avg_pv'`` and ``'fold'``.
    input_cols, cond_cols : sequence of str
        Input-feature and condition column names.
    batch_size : int
        Batch size passed to ``CTRPDataModule``.
    seed : int, default 2299
        Seed given to ``seed_everything`` once, before the first fold.
    max_epochs : int, default 25
        Epoch cap for each fold's ``Trainer``.
    early_stopping : bool, default False
        When True, attach an ``EarlyStopping`` callback on ``'val_loss'``.
        Defaults to False to match the previous behaviour, where the callback
        was constructed but never handed to the Trainer.

    Returns
    -------
    None
    """
    target = 'cpd_avg_pv'
    # Seeded once for the whole loop, so later folds start from the RNG state
    # left by earlier ones.
    seed_everything(seed)
    cols = list(np.concatenate((input_cols, cond_cols, [target])))

    for fold in range(nfolds):
        fold_start = datetime.now()

        # Data: hold out `fold` for validation, train on everything else.
        train = dataset.to_table(columns=cols, filter=ds.field('fold') != fold).to_pandas()
        val = dataset.to_table(columns=cols, filter=ds.field('fold') == fold).to_pandas()
        dm = CTRPDataModule(train,
                            val,
                            input_cols,
                            cond_cols,
                            target=target,
                            batch_size=batch_size)
        print("Completed dataloading in {}".format(str(datetime.now() - fold_start)))

        # Model
        if exp == 'film':
            model = FiLMNetwork(len(input_cols), len(cond_cols))
        else:
            model = ConcatNetwork(len(input_cols), len(cond_cols))

        # Logger / callbacks
        logger = TensorBoardLogger(save_dir=os.getcwd(),
                                   version="{}_{}_fold_{}".format(name, exp, fold),
                                   name='lightning_logs')
        # NOTE(review): assumes the model logs 'val_loss' — confirm before
        # enabling, otherwise the callback has nothing to monitor.
        early_stop = (EarlyStopping(monitor='val_loss', min_delta=0.01)
                      if early_stopping else False)

        # Trainer
        trainer_start = datetime.now()
        trainer = Trainer(auto_lr_find=True,
                          auto_scale_batch_size=False,
                          max_epochs=max_epochs,
                          gpus=gpus,
                          logger=logger,
                          early_stop_callback=early_stop,
                          distributed_backend='dp')
        print("Completed loading in {}".format(str(datetime.now() - trainer_start)))
        trainer.fit(model, dm)
        # Measured from the start of the fold, so it includes data loading.
        print("Completed fold {} in {}".format(fold, str(datetime.now() - fold_start)))

    print("/done")

In [6]:
# Open the subset training table and load the input / condition column lists for the 'id' setup.
dataset, input_cols, cond_cols = prepare(exp='id', subset=True)

In [7]:
name = 'test'
# Passed to cv(): 'film' -> FiLMNetwork, anything else -> ConcatNetwork.
# NOTE(review): prepare() in the previous cell was called with a hard-coded 'id',
# not with this variable — keep the two in sync.
exp = 'id'
# Passed to cv() as `gpus`. [1, 3] matches the CUDA_VISIBLE_DEVICES of the logged
# run below; in PyTorch Lightning an int means "this many GPUs", so 3 would request
# three devices rather than the ones actually used.
gpus = [1, 3]
nfolds = 1

In [8]:
cv(name=name, exp=exp, gpus=gpus, nfolds=nfolds,
   dataset=dataset, input_cols=input_cols, cond_cols=cond_cols,
   batch_size=256)

Completed dataloading in 0:00:15.086992


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [1,3]


Completed loading in 0:00:00.314845
Completed scaling in 0:01:36.945705
Completed dataset creation in 0:00:04.603609



  | Name       | Type        | Params
-------------------------------------------
0 | metric     | R2Score     | 0     
1 | inputs_emb | LinearBlock | 677 K 
2 | conds_emb  | LinearBlock | 96    
3 | block_1    | LinearBlock | 1 K   
4 | block_2    | LinearBlock | 161   
Finding best initial lr:  98%|█████████▊| 98/100 [00:02<00:00, 44.70it/s]Saving latest checkpoint..
Learning rate set to 0.0630957344480193

  | Name       | Type        | Params
-------------------------------------------
0 | metric     | R2Score     | 0     
1 | inputs_emb | LinearBlock | 677 K 
2 | conds_emb  | LinearBlock | 96    
3 | block_1    | LinearBlock | 1 K   
4 | block_2    | LinearBlock | 161   


Epoch 1:  24%|██▎       | 533/2254 [00:14<00:56, 30.43it/s, loss=0.090, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [00:19<00:00, 44.70it/s]

Epoch 1:  80%|███████▉  | 1799/2254 [00:43<00:11, 39.38it/s, loss=0.093, v_num=ld_0]
Validating: 0it [00:00, ?it/s][A
Epoch 1:  80%|███████▉  | 1800/2254 [00:45<00:12, 37.09it/s, loss=0.093, v_num=ld_0]
Epoch 1:  80%|████████  | 1806/2254 [00:46<00:12, 37.11it/s, loss=0.093, v_num=ld_0]
Epoch 1:  80%|████████  | 1812/2254 [00:46<00:11, 37.15it/s, loss=0.093, v_num=ld_0]
Epoch 1:  81%|████████  | 1819/2254 [00:46<00:11, 37.22it/s, loss=0.093, v_num=ld_0]
Epoch 1:  81%|████████  | 1826/2254 [00:46<00:11, 37.27it/s, loss=0.093, v_num=ld_0]
Epoch 1:  81%|████████▏ | 1833/2254 [00:46<00:11, 37.31it/s, loss=0.093, v_num=ld_0]
Validating:   8%|▊         | 35/455 [00:03<02:19,  3.01it/s][A
Epoch 1:  82%|████████▏ | 1840/2254 [00:46<00:11, 37.35it/s, loss=0.093, v_num=ld_0]
Epoch 1:  82%|████████▏ | 1847/2254 [00:46<00:10, 37.41it/s, loss=0.093, v_num=ld_0]
Epoch 1:  82%|████████▏ | 1854/2254 [00:46<00:10, 37.46it/s, loss=0.093, v_num=ld_0]
Epoch 1:  83%|████████▎ | 1861/2254 [00:46<00:10, 37

Finding best initial lr: 100%|██████████| 100/100 [01:19<00:00,  1.25it/s]


Epoch 2:  36%|███▌      | 815/2254 [00:20<00:40, 35.43it/s, loss=0.094, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:19<00:00,  1.25it/s]


Epoch 2:  36%|███▌      | 817/2254 [00:20<00:40, 35.44it/s, loss=0.093, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:19<00:00,  1.25it/s]


Epoch 2:  36%|███▋      | 818/2254 [00:20<00:40, 35.44it/s, loss=0.095, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:19<00:00,  1.25it/s]
Finding best initial lr: 100%|██████████| 100/100 [01:19<00:00,  1.25it/s]


Epoch 2:  36%|███▋      | 819/2254 [00:20<00:40, 35.42it/s, loss=0.094, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:19<00:00,  1.25it/s]


Epoch 2:  36%|███▋      | 821/2254 [00:20<00:40, 35.42it/s, loss=0.094, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:19<00:00,  1.25it/s]


Epoch 2:  36%|███▋      | 822/2254 [00:20<00:40, 35.42it/s, loss=0.095, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:19<00:00,  1.25it/s]


Epoch 2:  37%|███▋      | 824/2254 [00:20<00:40, 35.44it/s, loss=0.096, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:20<00:00,  1.25it/s]
Finding best initial lr: 100%|██████████| 100/100 [01:20<00:00,  1.25it/s]


Epoch 2:  37%|███▋      | 825/2254 [00:20<00:40, 35.44it/s, loss=0.095, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:20<00:00,  1.25it/s]


Epoch 2:  37%|███▋      | 827/2254 [00:20<00:40, 35.43it/s, loss=0.093, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:20<00:00,  1.25it/s]
Finding best initial lr: 100%|██████████| 100/100 [01:20<00:00,  1.25it/s]


Epoch 2:  37%|███▋      | 828/2254 [00:20<00:40, 35.42it/s, loss=0.093, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:20<00:00,  1.25it/s]
Finding best initial lr: 100%|██████████| 100/100 [01:20<00:00,  1.25it/s]


Epoch 2:  37%|███▋      | 833/2254 [00:20<00:40, 35.44it/s, loss=0.092, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:20<00:00,  1.24it/s]

Epoch 2:  37%|███▋      | 839/2254 [00:21<00:40, 35.21it/s, loss=0.091, v_num=ld_0]




Epoch 2:  45%|████▍     | 1011/2254 [00:25<00:34, 36.41it/s, loss=0.092, v_num=ld_0]

Finding best initial lr: 100%|██████████| 100/100 [01:24<00:00,  1.18it/s]


Epoch 2:  80%|███████▉  | 1799/2254 [00:43<00:11, 39.28it/s, loss=0.093, v_num=ld_0]
Validating: 0it [00:00, ?it/s][A
Epoch 2:  80%|███████▉  | 1800/2254 [00:46<00:12, 36.86it/s, loss=0.093, v_num=ld_0]
Epoch 2:  80%|████████  | 1809/2254 [00:46<00:12, 36.91it/s, loss=0.093, v_num=ld_0]
Validating:   3%|▎         | 12/455 [00:03<10:24,  1.41s/it][A
Epoch 2:  81%|████████  | 1818/2254 [00:46<00:11, 36.99it/s, loss=0.093, v_num=ld_0]
Epoch 2:  81%|████████  | 1827/2254 [00:46<00:11, 37.07it/s, loss=0.093, v_num=ld_0]
Epoch 2:  81%|████████▏ | 1836/2254 [00:46<00:11, 37.14it/s, loss=0.093, v_num=ld_0]
Validating:   8%|▊         | 38/455 [00:03<02:26,  2.85it/s][A
Epoch 2:  82%|████████▏ | 1845/2254 [00:46<00:10, 37.20it/s, loss=0.093, v_num=ld_0]
Epoch 2:  82%|████████▏ | 1854/2254 [00:47<00:10, 37.27it/s, loss=0.093, v_num=ld_0]
Validating:  12%|█▏        | 56/455 [00:03<00:52,  7.57it/s][A
Epoch 2:  83%|████████▎ | 1863/2254 [00:47<00:10, 37.34it/s, loss=0.093, v_num=ld_0]
Epoch 2: 



In [9]:
# `torch` is not imported by any earlier cell visible here; import it explicitly
# instead of relying on leftover kernel state.
import torch

# NOTE(review): `val` and `model` are locals inside cv() and are not assigned at
# notebook scope in any visible cell, so this cell fails on a fresh kernel.
# TODO: rebuild them here (or have cv() return them) so Restart & Run All works.
# NOTE(review): only the condition columns are fed to forward(); confirm this is
# the input the model's forward() expects.
foo = torch.FloatTensor(val[cond_cols].to_numpy())
model.forward(foo)

tensor([[ -5.2721,  -5.0805,   1.7124,  ...,  -5.4074,   1.7200, -13.4092],
        [ -0.8518,   0.4557,  -1.1743,  ...,   0.7598,   1.3746,  -2.2434],
        [ -1.5418,   1.0390,  -1.1317,  ...,   1.3517,   2.7206,  -1.4814],
        ...,
        [ -1.4564,   1.6760,   0.2275,  ...,   1.7220,   1.8293,  -1.2928],
        [ -2.3295,  -1.4681,  -0.4503,  ...,  -2.5142,   2.1346,  -8.3439],
        [ -2.1074,  -1.6664,   0.6190,  ...,  -1.4646,   1.0616,  -6.0193]],
       grad_fn=<AddmmBackward>)

In [26]:
# NOTE(review): `data` is not defined in any cell visible in this notebook
# (execution jumps from In [9] to In [26]); it presumably came from a deleted or
# re-ordered cell. Define it explicitly above so Restart Kernel -> Run All
# reproduces this output.
data.shape

(57676, 1502)

In [28]:
# Shape of the rows assigned to fold 0 (the fold cv() holds out for validation
# when fold == 0 — assuming `data` matches the dataset passed to cv()).
# NOTE(review): `data` is not defined in any visible cell; see the execution-count gap.
data[data['fold']==0].shape

(11474, 1502)

In [29]:
# Fold-0 row count (11,474, from the cell above) divided by 512 — presumably a
# candidate batch size (cv() above was run with batch_size=256); NOTE(review): confirm intent.
11_474 / 512

22.41015625