In [1]:
import autoencoder
import utils
import mrrmse

import pandas as pd
import torch

from sklearn.model_selection import KFold, train_test_split
import numpy as np
import tqdm
import random

from hyperopt import hp
from hyperopt.pyll import scope
from ray import train, tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.hyperopt import HyperOptSearch

## Prepare data:
#### Read joined data (pre + post treatment)

In [2]:
lincs_joined_df = pd.read_parquet("data/lincs_pretreatment.parquet")
kaggle_joined_df = pd.read_parquet("data/kaggle_pretreatment.parquet")
test_joined_df = pd.read_parquet("data/test_pretreatment.parquet")
print(f"lincs_joined_df = {lincs_joined_df.shape}\nkaggle_joined_df = {kaggle_joined_df.shape}\ntest_joined_df = {test_joined_df.shape}")

lincs_joined_df = (107404, 1842)
kaggle_joined_df = (602, 1841)
test_joined_df = (255, 921)


#### Kaggle provided data

In [3]:
de_train = pd.read_parquet('data/de_train.parquet')
id_map = pd.read_csv('data/id_map.csv',index_col='id')
print(f"de_train = {de_train.shape}\nid_map = {id_map.shape}")

de_train = (614, 18216)
id_map = (255, 2)


#### Define features of interest and sort data accordingly.

In [4]:
features = ['cell_type', 'sm_name']
multiindex_features = [("label",'cell_type'),("label",'sm_name')]

transcriptome_cols = de_train.columns[5:]
landmark_cols = kaggle_joined_df["post_treatment"].columns
print(f"transcriptome_cols = {transcriptome_cols.shape}\nlandmark_cols = {landmark_cols.shape}")

transcriptome_cols = (18211,)
landmark_cols = (918,)


In [5]:
unique_sm_name = pd.concat([lincs_joined_df[("label","sm_name")],kaggle_joined_df[("label","sm_name")]]).drop_duplicates().reset_index(drop=True)
unique_cell_type = pd.concat([lincs_joined_df[("label","cell_type")],kaggle_joined_df[("label","cell_type")]]).drop_duplicates().reset_index(drop=True)
print(f"Number of unique molecules = {len(unique_sm_name)}.\nNumber of unique cell types = {len(unique_cell_type)}.")

Number of unique molecules = 1896.
Number of unique cell types = 36.


In [6]:
# We only need to sort these two dataframes because they represent the same underlying dataset.
de_train = de_train.query("~control").sort_values(features)
kaggle_joined_df = kaggle_joined_df.sort_values(multiindex_features)
# Sanity check that these dfs align.
genes_align = (kaggle_joined_df["post_treatment"] == de_train[landmark_cols]).all(axis=None)
labels_align = (kaggle_joined_df["label"][features] == de_train[features]).all(axis=None)
genes_align and labels_align

True

#### CV splits

In [7]:
eval_cells_only_df = kaggle_joined_df[kaggle_joined_df["label"]["cell_type"].isin(["B cells", "Myeloid cells"])][multiindex_features]
len(eval_cells_only_df)

30

In [8]:
fold_to_eval_df = {}
skf = KFold(n_splits=3, random_state=42, shuffle=True)
for i,(t,v) in enumerate(skf.split(eval_cells_only_df)):
    fold_to_eval_df[i] = eval_cells_only_df.iloc[v]

for i, df in fold_to_eval_df.items():
    print(f"fold = {i} of shape {df.shape}")

fold = 0 of shape (10, 2)
fold = 1 of shape (10, 2)
fold = 2 of shape (10, 2)


In [9]:
def make_mask(fold):
    val = fold_to_eval_df[fold]
    return kaggle_joined_df[("label","sm_name")].isin(val[("label","sm_name")]) & kaggle_joined_df[("label","cell_type")].isin(val[("label","cell_type")])

print("Using fold 0 as validation set:")
print(f"Train data = {pd.concat([kaggle_joined_df[~make_mask(0)],lincs_joined_df]).shape}")
print(f"Validation data = {kaggle_joined_df[make_mask(0)].shape}")

Using fold 0 as validation set:
Train data = (107994, 1843)
Validation data = (12, 1841)


In [10]:
class Translator(torch.nn.Module):
    def __init__(self,config):
        super(Translator,self).__init__()
        # This will eventually be changed to a GNN
        self.smiles_embed = torch.nn.Embedding(len(unique_sm_name), config["sm_emb_size"])

        # This needs to be able to handle out of dictionary
        self.cell_embed = torch.nn.Embedding(len(unique_cell_type), config["cell_emb_size"])

        self.config = config
        input_dim = config["sm_emb_size"] + config["cell_emb_size"] + config["latent_dim"]
        self.translation = utils.make_sequential(input_dim,config["hidden_dim"],config["latent_dim"],config["dropout"])

    def forward(self,inp,z):
        sm_emb = self.smiles_embed(inp["sm_name"])
        ct_emb = self.cell_embed(inp["cell_type"])
        x = torch.cat((sm_emb, ct_emb, z), dim=1)
        return self.translation(x)

In [11]:
class RNVAE(torch.nn.Module):
    cell_type_map = {v: k for k,v in unique_cell_type.to_dict().items()}
    sm_name_map = {v: k for k,v in unique_sm_name.to_dict().items()}
    
    def __init__(self,config):
        super(RNVAE,self).__init__()
        self.vae = autoencoder.AutoEncoder(target_dim=len(landmark_cols),config=config)
        self.translator = Translator(config)

    # This is super slow because we are iterating.
    @classmethod
    def make_input(cls, df, disabletqdm=True):
        ct = df[("label","cell_type")].map(cls.cell_type_map)
        sm = df[("label","sm_name")].map(cls.sm_name_map)
        return [{"cell_type":torch.tensor(ct.iloc[i]),
                "sm_name":torch.tensor(sm.iloc[i]),
                "pre_treatment":torch.tensor(df["pre_treatment"].iloc[i].to_numpy(),dtype=torch.float),
                "post_treatment":torch.tensor(df["post_treatment"].iloc[i].to_numpy(),dtype=torch.float)} for i in tqdm.tqdm(range(len(df)),disable=disabletqdm)]

    @classmethod
    def make_input_new(cls, df):
        ct = torch.tensor(df[("label","cell_type")].map(cls.cell_type_map).to_numpy())
        sm = torch.tensor(df[("label","sm_name")].map(cls.sm_name_map).to_numpy())
        pre = torch.tensor(df["pre_treatment"].to_numpy(),dtype=torch.float32)
        post = torch.tensor(df["post_treatment"].to_numpy(),dtype=torch.float32)
        
        return [{"cell_type":ct[i],
                "sm_name":sm[i],
                "pre_treatment":pre[i],
                "post_treatment":post[i]} for i in range(len(df))]

    @classmethod
    def make_test(cls,df):
        ct = torch.tensor(df[("label","cell_type")].map(cls.cell_type_map).to_numpy())
        sm = torch.tensor(df[("label","sm_name")].map(cls.sm_name_map).to_numpy())
        pre = torch.tensor(df["pre_treatment"].to_numpy(),dtype=torch.float32)
        
        return [{"cell_type":ct[i],
                "sm_name":sm[i],
                "pre_treatment":pre[i]} for i in range(len(df))]
    
    def forward(self,inp):
        latent = self.vae.latent(inp["pre_treatment"])
        z_prime = self.translator(inp,latent["z"])
        x_hat = self.vae.decode(z_prime)
        return {"x_hat":x_hat, "mu": latent["mu"], "log_var":latent["log_var"]}

    def loss_function(self,fwd,inp):
        return self.vae.loss_function(fwd,inp["post_treatment"])

In [12]:
class Imputer(torch.nn.Module):
    def __init__(self,config,rnvae):
        super(Imputer,self).__init__()
        self.impute_loss_weight = config["impute_loss_weight"]
        self.imp = utils.make_sequential(len(landmark_cols),config["hidden_dim"],len(transcriptome_cols),config["dropout"])
        self.rnvae = rnvae

    @classmethod
    def make_input(cls, mask):
        kg_df = kaggle_joined_df[mask]
        trn_df = de_train[mask]
        rninp = RNVAE.make_input(kg_df)
        trm = trn_df[transcriptome_cols]
        for i,inp in enumerate(rninp):
            inp["transcriptome"] = torch.tensor(trm.iloc[i].to_numpy(), dtype=torch.float)
        return rninp

    def forward(self,inp):
        fwd = self.rnvae(inp)
        trm = self.imp(fwd["x_hat"])
        fwd["transcriptome"] = trm
        return fwd

    def loss_function(self,fwd,inp):
        trm_loss = torch.nn.functional.mse_loss(fwd["transcriptome"], inp["transcriptome"])
        lossdict = self.rnvae.loss_function(fwd,inp)
        lossdict["loss"] += self.impute_loss_weight*trm_loss
        lossdict["Transcriptome_Loss"] = trm_loss.detach()
        return lossdict

In [None]:
bsz = 512
lincs_sample = lincs_joined_df.sample(10000)
rnvae_inp_new = RNVAE.make_input_new(lincs_sample)
rnvae_inp = RNVAE.make_input(lincs_sample,disabletqdm=False)
print(len(rnvae_inp_new[0]["pre_treatment"]),len(rnvae_inp[0]["pre_treatment"]))
print((rnvae_inp_new[0]["pre_treatment"] == rnvae_inp[0]["pre_treatment"]).all())

for k, v in rnvae_inp_new[0].items():
    print(k,v.dtype,v.shape)

for k, v in rnvae_inp[0].items():
    print(k,v.dtype,v.shape)

rnvae_loader = torch.utils.data.DataLoader(rnvae_inp, batch_size=bsz)

  6%|██▏                                   | 566/10000 [00:02<00:54, 173.26it/s]

In [14]:
train_loaders = []
eval_loaders = []
for fold in fold_to_eval_df:
    traind = Imputer.make_input(~make_mask(fold))
    for k, v in traind[0].items():
        print(k,v.dtype)
    print()
    train_loaders.append(torch.utils.data.DataLoader(traind, batch_size=bsz))
    
    evald = Imputer.make_input(make_mask(fold))
    for k, v in evald[0].items():
        print(k,v.dtype)
    eval_loaders.append(torch.utils.data.DataLoader(evald, batch_size=len(evald)))

cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32

cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32
cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32

cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32
cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32

cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32


In [41]:
def epoch(models):
    def _epoch(model,opt,loader):
        for batch in loader:
            opt.zero_grad()
            fwd = model(batch)
            loss = model.loss_function(fwd,batch)["loss"]
            if torch.isnan(loss):
                return loss.detach()
            loss.backward()
            opt.step()
        
        return loss.detach()


    loss = _epoch(models["rnvae"],models["rnvae_opt"],models["rnvae_loader"])
    if torch.isnan(loss):
        return loss

    loss = _epoch(models["imputer"],models["impute_opt"],models["train_loader"])
    if torch.isnan(loss):
        return loss

    imputer = models["imputer"]
    with torch.no_grad():
        eval = next(iter(models["eval_loader"]))
        fwd = imputer(eval)
        # The eval loss we wish to optimize is how well the model
        # predicts the full transcriptome.
        return imputer.loss_function(fwd,eval)["Transcriptome_Loss"]

def make_models(config, input_data, fold):
    rnvae = RNVAE(config)
    imputer = Imputer(config,rnvae)
    return {
        "rnvae": rnvae,
        "imputer": imputer,
        "rnvae_opt": torch.optim.Adam(rnvae.parameters(), lr=config["lr_rnvae"]),
        "impute_opt": torch.optim.Adam(imputer.parameters(), lr=config["lr_imputer"]),
        "rnvae_loader": input_data["rnvae_loader"], # There is just one rnvae_loader shared across all folds
        "train_loader": input_data["train_loaders"][fold],
        "eval_loader": input_data["eval_loaders"][fold]
    }
    
def train_model(config, input_data):    
    def report(epoch,result):
        train.report(result)
        # if epoch % 10 == 0:
        #     print(epoch,result)
    
    all_models = []
    for fold in input_data["fold_to_eval_df"]:
        all_models.append(make_models(config, input_data, fold))

    for i in range(input_data["epochs"]):
        losses = []
        for fold in input_data["fold_to_eval_df"]:
            losses.append(epoch(all_models[fold]))
        
        if np.any(np.isnan(losses)):
            report(i,{input_data["metric"]: np.nan, "done": True})
        else:
            report(i,{input_data["metric"]: np.mean(losses)})

In [43]:
num_samples = 25
epochs = 25
metric = "mse"

input_data = {
    "rnvae_loader": rnvae_loader,
    "train_loaders": train_loaders,
    "eval_loaders": eval_loaders,
    "fold_to_eval_df": fold_to_eval_df,
    "epochs": epochs,
    "metric": metric
}

example_config = {
    "lr_rnvae": 1e-3,
    "lr_imputer": 1e-4,
    "dropout": .1,
    "sm_emb_size": 64,
    "cell_emb_size": 32,
    "latent_dim": 256,
    "hidden_dim": 512,
    "kld_weight": 1,
    "impute_loss_weight": 2,
}

space = {
    "lr_rnvae": hp.loguniform("lr_rnvae", -10, -1),
    "lr_imputer": hp.loguniform("lr_imputer", -10, -1),
    "dropout": hp.uniform("dropout", 0, 1),
    "sm_emb_size": scope.int(hp.qloguniform("sm_emb_size", 0, 3, 1)),
    "cell_emb_size": scope.int(hp.qloguniform("cell_emb_size", 0, 3, 1)),
    "latent_dim": scope.int(hp.qloguniform("latent_dim", 0, 7, 1)),
    "hidden_dim": scope.int(hp.qloguniform("hidden_dim", 0, 7, 1)),
    "kld_weight": hp.loguniform("kld_weight", -2, 2),
    "impute_loss_weight": hp.loguniform("impute_loss_weight", -2, 2),
}

train_model(example_config,input_data)
mode = "min"
hyperopt_search = HyperOptSearch(space, metric="mse", mode=mode)
scheduler = ASHAScheduler(metric="mse", grace_period=5, mode=mode, max_t=epochs)
tuner = tune.Tuner(
    tune.with_parameters(train_model, input_data=input_data),
    tune_config=tune.TuneConfig(
        num_samples=num_samples,
        search_alg=hyperopt_search,
        scheduler=scheduler
    ),
    run_config=train.RunConfig(
        failure_config=train.FailureConfig(fail_fast=False))
)
results = tuner.fit()

best_result = results.get_best_result(metric, mode=mode)
print(best_result.path)
print("CONFIG:", best_result.config)
print("METRICS:", best_result.metrics)

0,1
Current time:,2023-10-21 15:26:35
Running for:,00:02:45.27
Memory:,4.7/8.0 GiB

Trial name,status,loc,cell_emb_size,dropout,hidden_dim,impute_loss_weight,kld_weight,latent_dim,lr_imputer,lr_rnvae,sm_emb_size,iter,total time (s),mse
train_model_96b7866f,TERMINATED,127.0.0.1:27725,14,0.325083,1,0.756603,3.0455,288,4.97523e-05,0.0258026,2,6,2.57376,
train_model_37a0f3c7,TERMINATED,127.0.0.1:27725,7,0.0481424,468,0.142023,0.224891,47,0.150954,6.51287e-05,17,1,3.04511,
train_model_b29663eb,TERMINATED,127.0.0.1:27725,1,0.0846026,466,0.546392,0.176804,4,0.00014273,0.00155269,6,25,145.703,17.7377
train_model_3442c86b,TERMINATED,127.0.0.1:27747,4,0.952439,2,3.62338,2.20191,4,0.00897108,0.000823534,9,1,1.40722,
train_model_b39c8b7d,TERMINATED,127.0.0.1:27747,16,0.317211,226,0.521017,6.50339,40,0.0315213,0.00651967,5,5,12.2709,1395970.0
train_model_7d962c2d,TERMINATED,127.0.0.1:27758,2,0.321315,1,4.68073,6.58762,305,0.000691152,0.115826,3,1,2.23891,
train_model_6b6dd4c5,TERMINATED,127.0.0.1:27758,8,0.934079,11,3.31605,0.543498,12,0.00295545,0.162567,10,1,2.24743,
train_model_6cfd59f2,TERMINATED,127.0.0.1:27758,8,0.276024,31,0.197947,0.65754,12,0.0020115,7.31638e-05,10,25,29.1612,18.1587
train_model_2e128526,TERMINATED,127.0.0.1:27747,10,0.384028,1,0.51023,1.3125,2,9.98266e-05,0.0818821,8,5,3.55792,19.523
train_model_c08ef518,TERMINATED,127.0.0.1:27747,17,0.675803,2,6.2698,1.41903,8,0.000103325,0.00579406,15,5,5.12363,19.3459


2023-10-21 15:26:35,748	INFO tune.py:1143 -- Total run time: 165.48 seconds (165.27 seconds for the tuning loop).


/Users/laurasisson/ray_results/train_model_2023-10-21_15-23-50/train_model_b29663eb_3_cell_emb_size=1,dropout=0.0846,hidden_dim=466,impute_loss_weight=0.5464,kld_weight=0.1768,latent_dim=4,lr_i_2023-10-21_15-23-59
CONFIG: {'cell_emb_size': 1, 'dropout': 0.08460261005467273, 'hidden_dim': 466, 'impute_loss_weight': 0.5463921138401946, 'kld_weight': 0.17680367176040052, 'latent_dim': 4, 'lr_imputer': 0.00014272996901395086, 'lr_rnvae': 0.001552694731400991, 'sm_emb_size': 6}
METRICS: {'mse': 17.737665, 'timestamp': 1697916395, 'done': True, 'training_iteration': 25, 'trial_id': 'b29663eb', 'date': '2023-10-21_15-26-35', 'time_this_iter_s': 1.8166158199310303, 'time_total_s': 145.70338678359985, 'pid': 27725, 'hostname': 'Lauras-Air', 'node_ip': '127.0.0.1', 'config': {'cell_emb_size': 1, 'dropout': 0.08460261005467273, 'hidden_dim': 466, 'impute_loss_weight': 0.5463921138401946, 'kld_weight': 0.17680367176040052, 'latent_dim': 4, 'lr_imputer': 0.00014272996901395086, 'lr_rnvae': 0.001552