In [1]:
# Copyright (C) 2022 Insitro, Inc. This software and any derivative works are licensed under the 
# terms of the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC 4.0), 
# accessible at https://creativecommons.org/licenses/by-nc/4.0/legalcode

In [1]:
import pytorch_lightning as pl

In [2]:
from datamodules import DataModule
from pyro_models import PyroModel

In [3]:
# Build the data module. Construction itself reads and featurizes the data
# (see the progress output below), so this cell is slow to re-run.
datamodule = DataModule(
    # Training data: JACS dataset with counts, precomputed CNN features,
    # and fixed train/val/test split indices loaded from an .npz file.
    source_data='jacs_counts',
    dataset_csv_fname='JACS_full.csv',
    cnn_feats_train_fname='cnn_feats_JACS_full.pt',
    splits_fname='splits_jacs_full.npz',
    # Held-out evaluation data and its precomputed CNN features.
    # NOTE(review): `source_eval='caix'` presumably selects the eval target — confirm in DataModule.
    source_eval='caix',
    df_eval_fname='df_eval_data.csv',
    cnn_feats_eval_fname='cnn_feats_hca_ChEMBL.pt',
    # Batch sizes and number of poses (20, matching the "Getting graph poses" output).
    batch_size=64,
    test_batch_size=64,
    poses=20,
)

Number of datapoints: 108528; using JACS dataset with counts
Getting CNN feats from : ../notebooks/cnn_feats_JACS_full.pt)


Calculating Getting graph data fingerprints:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph poses:   0%|          | 0/20 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

Getting graph data:   0%|          | 0/108528 [00:00<?, ?it/s]

train/val/test split: 0.7/0.1/0.2
Using eval data: 3324 samples
Getting CNN feats from : ../notebooks/cnn_feats_hca_ChEMBL.pt)


Calculating Getting eval graph data fingerprints:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting graph poses:   0%|          | 0/20 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Getting eval graph data:   0%|          | 0/3324 [00:00<?, ?it/s]

Using splits from: splits_jacs_full.npz


In [4]:
import os
import wandb

In [5]:
# W&B naming: `exp_name` is the W&B project and `v_name` the run name inside it.
# Both are also combined to form the local checkpoint directory.
exp_name = 'model_experiments'
v_name = 'model_run'

In [6]:
# Seed the global RNGs through Lightning for a reproducible run
# (42**3 - 42**2 == 72324, the value reported in the output below).
pl.seed_everything(42 ** 3 - 42 ** 2)

# Model hyper-parameters; `poses` is taken from the datamodule so the two stay in sync.
model = PyroModel(hidden_dim=128,
                  poses=datamodule.poses,
                  final_act='exp',
                  learning_rate=1e-4,
                  weight_decay=0.0,
                  use_smiles=2048,
                  pose_reduce='attn_gated',
                  use_cnn_feats=224,
                  clip_norm=1e-1,
                  lrd_num_steps=1250,
                  lrd_gamma=0.1,
                  n_layers=2,
                  dropout=0.5)

# W&B logger: upload checkpoints as artifacts (log_model="all") under project `exp_name`, run `v_name`.
logger = pl.loggers.WandbLogger(log_model="all", name=v_name, project=exp_name)

try:
    logger.watch(model, log="all")

    # Keep the checkpoint with the lowest EMA validation loss on the first val dataloader.
    # With several dataloaders Lightning suffixes metric keys with `/dataloader_idx_N`,
    # so a bare `{val_loss}` in the filename could never be resolved (and is not the
    # monitored metric anyway); Lightning would silently write 0 for it. Only
    # epoch/step are interpolated into the filename.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor="val_EMA_loss/dataloader_idx_0",
        mode="min",
        dirpath=os.path.join(exp_name + '_wandb', v_name),
        filename='best-{epoch:02d}-{step}',
        save_last=True)

    trainer = pl.Trainer(gpus=1,
                         max_epochs=8,
                         logger=logger,
                         callbacks=[checkpoint_callback])

    trainer.fit(model, datamodule)
    # NOTE(review): passing `model` evaluates the in-memory (final-epoch) weights, not the
    # best checkpoint tracked above; pass `ckpt_path="best"` if the best one is intended.
    trainer.test(model, datamodule)
    trainer.validate(model, datamodule)
except BaseException:
    # Close the W&B run even when something fails, marked as failed, so re-running this
    # cell does not attach to a still-open run; then surface the original error.
    wandb.finish(exit_code=1)
    raise
wandb.finish()

Global seed set to 72324
[34m[1mwandb[0m: Currently logged in as: [33mkirillshmilovich[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                  | Params
----------------------------------------------------------
0 | cnn_embed       | PyroSequential        | 189 K 
1 | smiles_embed    | PyroSequential        | 656 K 
2 | pose_attn_tanh  | PyroSequential        | 32.8 K
3 | pose_attn_sig   | PyroSequential        | 32.8 K
4 | pose_attn       | PyroLinear            | 128   
5 | post_add_layer  | PyroResidualNLayerMLP | 131 K 
6 | enrichment_head | PyroSequential        | 257   
7 | matrix_head     | PyroSequential        | 257   
----------------------------------------------------------
1.0 M     Trainable params
0         Non-trainable params
1.0 M     Total params
4.175     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0             DataLoader 1
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
test_EMA_eval_ki_pearson   -0.038394860143770915    -0.038394860143770915
test_EMA_eval_ki_spearman  -0.32768327688295357     -0.32768327688295357
      test_EMA_loss         1.4221405982971191
        test_loss           1.4220893383026123
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        Validate metric                 DataLoader 0                   DataLoader 1
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
    VAL_EMA_eval_ki_pearson         -0.038394860143770915          -0.038394860143770915
   VAL_EMA_eval_ki_spearman         -0.32768327688295357           -0.32768327688295357
val_EMA_eval_ki_pearson_subset      -0.11437773148524351           -0.11437773148524351
val_EMA_eval_ki_spearman_subset     -0.18470103181591158           -0.18470103181591158
         val_EMA_loss                1.4231353998184204
           val_loss                  1.4230819940567017
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


VBox(children=(Label(value='107.843 MB of 107.843 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
VAL_EMA_eval_ki_pearson,▁████████
VAL_EMA_eval_ki_spearman,▁▇███████
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇█
global_step,▁▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇████████
test_EMA_eval_ki_pearson,▁
test_EMA_eval_ki_spearman,▁
test_EMA_loss/dataloader_idx_0,▁
test_loss/dataloader_idx_0,▁
train_loss,█▃▃▇▂▃▃▂▂▂▂▂▂▂▃▁▃▄▃▂▂▃▂▃▂▂▂▂▂▁▃▅▃▅▂▄▃▄▃▆
trainer/global_step,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████

0,1
VAL_EMA_eval_ki_pearson,-0.03839
VAL_EMA_eval_ki_spearman,-0.32768
epoch,8.0
global_step,10856.0
test_EMA_eval_ki_pearson,-0.03839
test_EMA_eval_ki_spearman,-0.32768
test_EMA_loss/dataloader_idx_0,1.42214
test_loss/dataloader_idx_0,1.42209
train_loss,1.42737
trainer/global_step,10856.0
