In [1]:
import sys 
sys.path.append('../')

In [2]:
from dataset import LincsDataset
from torch_geometric.loader import DataLoader
from model import BaseModel
from aae import AAE
from model_utils import get_params
from pytorch_lightning import Trainer
from torch.utils.data import ConcatDataset
from datetime import datetime
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from model_utils import transfer_trained_weights
import argparse
import torch
import pandas as pd
import numpy as np
import itertools


2023-05-03 21:32:39.987137: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Load a pretrained VAE

## Set up arguments

In [19]:
# args
batch_size = 1
NUM_WORKERS = 4
train_split1 = "train_0"
valid_split = "valid_0"

layer_type = "FiLMConv"
model_architecture = 'vae'
gradient_clip_val = 1.0
max_lr = 1e-5
gen_step_drop_probability = 0
use_oclr_scheduler = True
using_cyclical_anneal = False
use_clamp_log_var = False

raw_moler_trace_dataset_parent_folder = "/data/ongh0068/guacamol/trace_dir"
# raw_moler_trace_dataset_parent_folder = "/data/ongh0068/l1000/TRACE_DIR"
output_pyg_trace_dataset_parent_folder = (
    "/data/ongh0068/l1000/already_batched"
)
    

In [20]:
train_dataset = LincsDataset(
    root="/data/ongh0068",
    raw_moler_trace_dataset_parent_folder=raw_moler_trace_dataset_parent_folder,  # "/data/ongh0068/l1000/trace_playground",
    output_pyg_trace_dataset_parent_folder=output_pyg_trace_dataset_parent_folder,
    gene_exp_controls_file_path="/data/ongh0068/l1000/lincs/robust_normalized_controls.npz",
    gene_exp_tumour_file_path="/data/ongh0068/l1000/lincs/robust_normalized_tumors.npz",
    lincs_csv_file_path="/data/ongh0068/l1000/lincs/experiments_filtered.csv",
    split=train_split1,
    gen_step_drop_probability=gen_step_drop_probability,
)
train_dataset

Loading controls gene expression...
Loading tumour gene expression...
Loading csv...


LincsDataset(794)

In [5]:
valid_dataset = LincsDataset(
        root="/data/ongh0068",
        raw_moler_trace_dataset_parent_folder=raw_moler_trace_dataset_parent_folder,  # "/data/ongh0068/l1000/trace_playground",
        output_pyg_trace_dataset_parent_folder=output_pyg_trace_dataset_parent_folder,
        gene_exp_controls_file_path="/data/ongh0068/l1000/lincs/robust_normalized_controls.npz",
        gene_exp_tumour_file_path="/data/ongh0068/l1000/lincs/robust_normalized_tumors.npz",
        lincs_csv_file_path="/data/ongh0068/l1000/lincs/experiments_filtered.csv",
        split=valid_split,
    )

Loading controls gene expression...
Loading tumour gene expression...
Loading csv...


In [27]:
train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        # sampler=train_sampler,
        follow_batch=[
            "correct_edge_choices",
            "correct_edge_types",
            "valid_edge_choices",
            "valid_attachment_point_choices",
            "correct_attachment_point_choice",
            "correct_node_type_choices",
            "original_graph_x",
            "correct_first_node_type_choices",
        ],
        num_workers=NUM_WORKERS,
        # prefetch_factor=0,
    )

In [7]:
valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        # sampler=valid_sampler,
        follow_batch=[
            "correct_edge_choices",
            "correct_edge_types",
            "valid_edge_choices",
            "valid_attachment_point_choices",
            "correct_attachment_point_choice",
            "correct_node_type_choices",
            "original_graph_x",
            "correct_first_node_type_choices",
        ],
        num_workers=NUM_WORKERS,
        # prefetch_factor=0,
    )

In [28]:
one = next(iter(train_dataloader))
one

MolerDataBatch(x=[16413, 59], edge_index=[2, 33312], original_graph_edge_features=[65434], original_graph_node_categorical_features=[29966], focus_node=[1000], partial_graph_edge_features=[33312], edge_features=[7327, 3], correct_edge_choices=[7327], correct_edge_choices_batch=[7327], correct_edge_choices_ptr=[1001], num_correct_edge_choices=[1000], stop_node_label=[1000], valid_edge_choices=[7327, 2], valid_edge_choices_batch=[7327], valid_edge_choices_ptr=[1001], valid_edge_types=[512, 3], correct_edge_types=[512, 3], correct_edge_types_batch=[512], correct_edge_types_ptr=[1001], partial_node_categorical_features=[16413], correct_attachment_point_choice=[58], correct_attachment_point_choice_batch=[58], correct_attachment_point_choice_ptr=[1001], correct_node_type_choices=[468, 166], correct_node_type_choices_batch=[468], correct_node_type_choices_ptr=[1001], correct_first_node_type_choices=[1000, 166], correct_first_node_type_choices_batch=[1000], correct_first_node_type_choices_ptr=

In [29]:
one['dose'].size(0) == 1000

True

In [30]:
for batch_id, batch in enumerate(train_dataloader):
    if batch['dose'].size(0) == 1000:
        continue
    else:
        print(batch_id)
        print(batch)

791
MolerDataBatch(x=[7921, 59], edge_index=[2, 16016], original_graph_edge_features=[30916], original_graph_node_categorical_features=[14157], focus_node=[520], partial_graph_edge_features=[16016], edge_features=[3265, 3], correct_edge_choices=[3265], correct_edge_choices_batch=[3265], correct_edge_choices_ptr=[521], num_correct_edge_choices=[520], stop_node_label=[520], valid_edge_choices=[3265, 2], valid_edge_choices_batch=[3265], valid_edge_choices_ptr=[521], valid_edge_types=[259, 3], correct_edge_types=[259, 3], correct_edge_types_batch=[259], correct_edge_types_ptr=[521], partial_node_categorical_features=[7921], correct_attachment_point_choice=[30], correct_attachment_point_choice_batch=[30], correct_attachment_point_choice_ptr=[521], correct_node_type_choices=[247, 166], correct_node_type_choices_batch=[247], correct_node_type_choices_ptr=[521], correct_first_node_type_choices=[520, 166], correct_first_node_type_choices_batch=[520], correct_first_node_type_choices_ptr=[521], s

In [31]:
len(train_dataloader)

794

In [9]:
torch.cat((one['gene_expressions'], one['dose'].unsqueeze(-1)), dim=-1)

tensor([[ 0.1534,  0.1941, -0.1085,  ...,  1.1979,  0.8452,  1.0735],
        [ 0.2131,  0.0942,  0.1915,  ...,  1.0251,  0.8387,  1.0735],
        [ 0.0908,  0.1827, -0.1887,  ...,  1.2491,  0.8054,  0.7423],
        ...,
        [-0.0321,  0.0563, -0.1736,  ...,  0.9675,  0.6692,  1.0000],
        [ 0.2147,  0.1086, -0.0809,  ...,  1.1219,  0.7681,  1.0000],
        [ 0.7786,  0.4293,  0.8607,  ...,  0.9577,  0.3489,  1.0000]])

In [10]:
params = get_params(dataset=train_dataset)  # train_dataset)
params

{'full_graph_encoder': {'input_feature_dim': 59,
  'atom_or_motif_vocab_size': 166,
  'aggr_layer_type': 'MoLeRAggregation',
  'total_num_moler_aggr_heads': 32,
  'layer_type': 'FiLMConv'},
 'partial_graph_encoder': {'input_feature_dim': 59,
  'atom_or_motif_vocab_size': 166,
  'aggr_layer_type': 'MoLeRAggregation',
  'total_num_moler_aggr_heads': 16,
  'layer_type': 'FiLMConv'},
 'mean_log_var_mlp': {'input_feature_dim': 832,
  'output_size': 1024,
  'hidden_layer_dims': [],
  'use_bias': False},
 'decoder': {'node_type_selector': {'input_feature_dim': 1344,
   'output_size': 167},
  'use_node_type_loss_weights': True,
  'node_type_loss_weights': tensor([10.0000,  0.1000,  3.6015,  0.1000,  0.1000,  0.4439,  0.7549,  0.4416,
          10.0000,  2.7939,  3.3916, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000, 10.0000,
          10.0000,

In [11]:
params["full_graph_encoder"]["layer_type"] = layer_type
params["partial_graph_encoder"]["layer_type"] = layer_type
params["use_oclr_scheduler"] = use_oclr_scheduler
params["using_cyclical_anneal"] = using_cyclical_anneal
model_architecture = model_architecture
params["max_lr"] = max_lr

In [13]:
# ckpt_path = '/data/ongh0068/2023-03-05_14_24_55.916122/epoch=24-val_loss=0.29.ckpt'
ckpt_path = '../first_stage_models/2023-03-03_09_30_01.589479/epoch=12-val_loss=0.46.ckpt'
checkpoint = torch.load(ckpt_path)
checkpoint.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])

In [14]:
checkpoint['state_dict']

OrderedDict([('_full_graph_encoder._dummy_param', tensor([], device='cuda:2')),
             ('_full_graph_encoder._embed.weight',
              tensor([[ 1.3196,  0.2931,  0.8071,  ...,  0.2062, -0.6496, -1.1868],
                      [-1.8402,  0.1664,  0.0733,  ...,  0.1846, -1.6061, -0.7596],
                      [ 0.2631,  0.0088, -0.3884,  ...,  0.7362,  0.6730, -0.1531],
                      ...,
                      [ 0.6153, -0.4552, -2.0542,  ..., -1.0122, -1.8221, -1.5664],
                      [ 1.1537, -0.1217,  0.2818,  ..., -1.2707,  0.8064, -0.1662],
                      [-2.0246,  1.2176,  0.4897,  ..., -0.9578, -0.6310,  0.6666]],
                     device='cuda:2')),
             ('_full_graph_encoder._model._first_layer.lins.0.weight',
              tensor([[-0.0877,  0.0317,  0.2789,  ...,  0.0152,  0.1249,  0.0469],
                      [-0.0665,  0.0976,  0.0030,  ...,  0.0678, -0.0192,  0.0892],
                      [-0.0352,  0.0101,  0.1400,  ...,  0

## Load from ckpt directly

In [15]:
# model = BaseModel(
#     params,
#     valid_dataset,
#     using_lincs=True,
#     num_train_batches=len(train_dataloader),
#     batch_size=batch_size,
#     use_clamp_log_var = True if use_clamp_log_var is not None else False
# )
# model

model = BaseModel.load_from_checkpoint(ckpt_path, params = params, dataset = train_dataset, use_lincs = False)
model

BaseModel(
  (_full_graph_encoder): GraphEncoder(
    (_embed): Embedding(166, 64)
    (_model): GenericGraphEncoder(
      (_first_layer): FiLMConv(123, 64, num_relations=3)
      (_encoder_layers): ModuleList(
        (0): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (1): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (2): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (3): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (4): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (5): Sequential(
          (0): LayerNorm(64, mode=graph)
          (1): FiLMConv(64, 64, num_relations=3)
        )
        (6): Sequential(
          (0)

In [42]:
# model.load_state_dict(checkpoint['state_dict'])

In [16]:
def sanitise(row):
    """Specifically for the L1000 csv"""
    control_indices = (
        row["ControlIndices"]
        .replace("[", "")
        .replace("]", "")
        .replace("\n", "")
        .split(" ")
    )
    control_indices = [idx for idx in control_indices if idx != ""]
    row["ControlIndices"] = np.asarray(control_indices, dtype=np.int32)
    tumour_indices = (
        row["TumourIndices"]
        .replace("[", "")
        .replace("]", "")
        .replace("\n", "")
        .split(" ")
    )
    tumour_indices = [idx for idx in tumour_indices if idx != ""]
    row["TumourIndices"] = np.asarray(tumour_indices, dtype=np.int32)
    return row

In [17]:
test_set = pd.read_csv("/data/ongh0068/l1000/INPUT_DIR/test.csv")
test_set = test_set.apply(lambda x: sanitise(x), axis=1)


reference_smiles = test_set.SMILES.to_list()
control_idxes = test_set.ControlIndices.values
tumour_idxes = test_set.TumourIndices.values
original_idxes = test_set.original_idx.to_list()

control_idxes

array([array([17160, 17161, 17162, 17163, 17164, 17165, 17166, 17167, 17168,
              17169, 17170, 17171, 17172, 17173, 17174, 17175, 17176, 17177,
              17178, 17179, 17180, 17181, 17182, 17183, 17184, 17185, 17186,
              17187, 17188, 17189, 17190, 17191, 17192, 17193, 17194, 17195,
              17196, 17197], dtype=int32)                                   ,
       array([25749, 25750, 25751, 25752, 25753, 25754, 25755, 25756, 25757,
              25758, 25759, 25760, 25761, 25762, 25763, 25764, 25765, 25766,
              25767, 25768, 25769, 25770, 25771, 25772, 25773, 25774, 25775,
              25776, 25777, 25778, 25779, 25780, 25781, 25782, 25783, 25784,
              25785, 25786, 25787, 25788, 25789, 25790, 25791], dtype=int32),
       array([12594, 12595, 12596, 12597, 12598, 12599, 12600, 12601, 12602,
              12603, 12604, 12605, 12606, 12607, 12608, 12609, 12610, 12611,
              12612, 12613, 12614, 12615, 12616, 12617, 12618, 12619, 1262

In [18]:
control_idx = control_idxes[0]
tumour_idx = tumour_idxes[0]
original_idx = original_idxes[0]
control_idx

array([17160, 17161, 17162, 17163, 17164, 17165, 17166, 17167, 17168,
       17169, 17170, 17171, 17172, 17173, 17174, 17175, 17176, 17177,
       17178, 17179, 17180, 17181, 17182, 17183, 17184, 17185, 17186,
       17187, 17188, 17189, 17190, 17191, 17192, 17193, 17194, 17195,
       17196, 17197], dtype=int32)

In [19]:
possible_pairs = np.array(list(itertools.product(control_idx, tumour_idx)))

control_idx_batched = possible_pairs[:, 0]
tumour_idx_batched = possible_pairs[:, 1]

In [20]:
num_rand_vectors_required = 20
rand_vect_dim = 512
control_gene_exp_batched = train_dataset._gene_exp_controls[control_idx_batched]
tumour_gene_exp_batched = train_dataset._gene_exp_tumour[tumour_idx_batched]
difference_gene_exp_batched = tumour_gene_exp_batched - control_gene_exp_batched
difference_gene_exp_batched

array([[-3.3737407 ,  0.23622662, -0.3157369 , ...,  0.01545614,
        -0.3174927 , -0.03243435],
       [ 0.03139508,  0.06642026, -0.2858467 , ..., -0.04326904,
         0.42397094,  0.30014277],
       [ 0.34698838, -0.11357671, -0.345627  , ...,  0.00950205,
         0.58586156, -0.14038225],
       ...,
       [-3.9856606 ,  0.29206398, -0.00577831, ...,  0.07723993,
        -0.77160954,  0.11332911],
       [-0.5805249 ,  0.12225762,  0.02411187, ...,  0.01851475,
        -0.03014591,  0.44590622],
       [-0.26493162, -0.05773935, -0.03566843, ...,  0.07128584,
         0.13174471,  0.00538122]], dtype=float32)

In [22]:
num_samples = num_rand_vectors_required
difference_gene_exp_batched = torch.tensor(
    difference_gene_exp_batched[:num_samples, :]
)
random_vectors = torch.randn(
    num_rand_vectors_required, rand_vect_dim
)

random_vectors.size()

torch.Size([20, 512])

In [23]:
dose_batched = (
    torch.from_numpy(
        np.repeat(
            train_dataset._experiment_idx_to_dose[original_idx], (random_vectors.shape[0])
        )
    )
    .float()
)

dose_batched.size()

torch.Size([20])

In [24]:
conditioned_random_vectors = model.condition_on_gene_expression(
    latent_representation=random_vectors,
    gene_expressions=difference_gene_exp_batched,
    dose=dose_batched,
)

AttributeError: 'BaseModel' object has no attribute '_gene_exp_condition_mlp'

# Try loading VAE in LDM model

In [17]:
from omegaconf import OmegaConf
import importlib


In [32]:
config_file = 'config/ddim_vae_uncon.yml'
config = OmegaConf.load(config_file)
config

{'model': {'base_learning_rate': 5e-05, 'target': 'ldm.models.diffusion.ddpm.LatentDiffusion', 'params': {'linear_start': 0.00085, 'linear_end': 0.012, 'num_timesteps_cond': 1, 'log_every_t': 200, 'timesteps': 1000, 'first_stage_key': 'image', 'cond_stage_key': 'caption', 'image_size': 32, 'channels': 4, 'cond_stage_trainable': True, 'conditioning_key': 'crossattn', 'monitor': 'val/loss_simple_ema', 'scale_factor': 0.18215, 'use_ema': False}, 'first_stage_config': {'target': 'model.BaseModel', 'model_type': 'vae'}}}

In [33]:
# params contain tensors - unsupportive
# config['model']['first_stage_config']['params'] = params
config['model']['first_stage_config']

{'target': 'model.BaseModel', 'model_type': 'vae'}

In [21]:
def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    print(cls)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

In [22]:
m = get_obj_from_str(config['model']['first_stage_config']['target'])
m

BaseModel


model.BaseModel

In [26]:
def instantiate_first_stage_model(config, ckpt, **kwargs):
    # vae kwargs: params, dataset, using_lincs, include_predict_gene_exp_mlp = False, num_train_batches=1, batch_size=1, use_clamp_log_var = False
    if not "target" in config:
        if config == '__is_first_stage__':
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    model = get_obj_from_str(config["target"])
    model.load_from_checkpoint(ckpt, **kwargs)
    
    return model

In [31]:
ckpt_path = "/data/ongh0068/l1000/2023-03-11_23_33_36.921147/epoch=07-val_loss=0.60.ckpt"

first_stage_m = instantiate_first_stage_model(config['model']['first_stage_config'], ckpt_path, params=params, dataset=valid_dataset, using_lincs=True)
first_stage_m

BaseModel


FileNotFoundError: [Errno 2] No such file or directory: '/data/ongh0068/l1000/2023-03-11_23_33_36.921147/epoch=07-val_loss=0.60.ckpt'

In [25]:
first_stage_m.state_dict()

OrderedDict([('_full_graph_encoder._dummy_param', tensor([])),
             ('_full_graph_encoder._embed.weight',
              tensor([[ 0.3401, -0.2626,  2.0737,  ...,  1.1110, -0.5459, -0.7753],
                      [ 0.5659, -0.2552,  2.9655,  ..., -0.7322,  0.9360, -3.2468],
                      [-0.0572, -0.4171,  1.3939,  ...,  1.3022, -0.4804, -0.9878],
                      ...,
                      [ 0.8405,  2.1865,  1.8139,  ..., -1.0245,  0.3250, -0.9826],
                      [ 2.9078,  0.1768,  2.0735,  ...,  1.3644,  0.7263, -1.0662],
                      [ 1.8306, -2.4170, -0.9222,  ...,  0.5893, -0.9238,  0.4638]])),
             ('_full_graph_encoder._model._first_layer.lins.0.weight',
              tensor([[-0.0235,  0.0517, -0.0179,  ..., -0.0861,  0.0156,  0.0049],
                      [ 0.0612, -0.0610, -0.0026,  ...,  0.0882, -0.0012, -0.0140],
                      [-0.0618,  0.0394, -0.0762,  ..., -0.0060, -0.0734,  0.0807],
                      ...,
  