In [113]:
import autoencoder
import utils
import mrrmse

import pandas as pd
import torch

from sklearn.model_selection import KFold, train_test_split
import numpy as np
import tqdm

## Prepare data:
#### Read joined data (pre + post treatment)

In [5]:
lincs_joined_df = pd.read_parquet("data/lincs_pretreatment.parquet")
kaggle_joined_df = pd.read_parquet("data/kaggle_pretreatment.parquet")
test_joined_df = pd.read_parquet("data/test_pretreatment.parquet")
print(f"lincs_joined_df = {lincs_joined_df.shape}\nkaggle_joined_df = {kaggle_joined_df.shape}\ntest_joined_df = {test_joined_df.shape}")

lincs_joined_df = (107404, 1842)
kaggle_joined_df = (602, 1841)
test_joined_df = (255, 921)


#### Kaggle provided data

In [6]:
de_train = pd.read_parquet('data/de_train.parquet')
id_map = pd.read_csv('data/id_map.csv',index_col='id')
print(f"de_train = {de_train.shape}\nid_map = {id_map.shape}")

de_train = (614, 18216)
id_map = (255, 2)


#### Define features of interest and sort data accordingly.

In [7]:
features = ['cell_type', 'sm_name']
multiindex_features = [("label",'cell_type'),("label",'sm_name')]

transcriptome_cols = de_train.columns[5:]
landmark_cols = kaggle_joined_df["post_treatment"].columns
print(f"transcriptome_cols = {transcriptome_cols.shape}\nlandmark_cols = {landmark_cols.shape}")

transcriptome_cols = (18211,)
landmark_cols = (918,)


In [8]:
unique_sm_name = pd.concat([lincs_joined_df[("label","sm_name")],kaggle_joined_df[("label","sm_name")]]).drop_duplicates().reset_index(drop=True)
unique_cell_type = pd.concat([lincs_joined_df[("label","cell_type")],kaggle_joined_df[("label","cell_type")]]).drop_duplicates().reset_index(drop=True)
print(f"Number of unique molecules = {len(unique_sm_name)}.\nNumber of unique cell types = {len(unique_cell_type)}.")

Number of unique molecules = 1896.
Number of unique cell types = 36.


In [9]:
# We only need to sort these two dataframes because they represent the same underlying dataset.
de_train = de_train.query("~control").sort_values(features)
kaggle_joined_df = kaggle_joined_df.sort_values(multiindex_features)
# Sanity check that these dfs align.
genes_align = (kaggle_joined_df["post_treatment"] == de_train[landmark_cols]).all(axis=None)
labels_align = (kaggle_joined_df["label"][features] == de_train[features]).all(axis=None)
genes_align and labels_align

True

#### CV splits

In [10]:
eval_cells_only_df = kaggle_joined_df[kaggle_joined_df["label"]["cell_type"].isin(["B cells", "Myeloid cells"])][multiindex_features]
len(eval_cells_only_df)

30

In [11]:
fold_to_eval_df = {}
skf = KFold(n_splits=3, random_state=42, shuffle=True)
for i,(t,v) in enumerate(skf.split(eval_cells_only_df)):
    fold_to_eval_df[i] = eval_cells_only_df.iloc[v]

for i, df in fold_to_eval_df.items():
    print(f"fold = {i} of shape {df.shape}")

fold = 0 of shape (10, 2)
fold = 1 of shape (10, 2)
fold = 2 of shape (10, 2)


In [12]:
def make_mask(fold):
    val = fold_to_eval_df[fold]
    return kaggle_joined_df[("label","sm_name")].isin(val[("label","sm_name")]) & kaggle_joined_df[("label","cell_type")].isin(val[("label","cell_type")])

print("Using fold 0 as validation set:")
print(f"Train data = {pd.concat([kaggle_joined_df[~make_mask(0)],lincs_joined_df]).shape}")
print(f"Validation data = {kaggle_joined_df[make_mask(0)].shape}")

Using fold 0 as validation set:
Train data = (107994, 1843)
Validation data = (12, 1841)


In [13]:
class Translator(torch.nn.Module):
    def __init__(self,config):
        super(Translator,self).__init__()
        # This will eventually be changed to a GNN
        self.smiles_embed = torch.nn.Embedding(len(unique_sm_name), config["sm_emb_size"])

        # This needs to be able to handle out of dictionary
        self.cell_embed = torch.nn.Embedding(len(unique_cell_type), config["cell_emb_size"])

        self.config = config
        input_dim = config["sm_emb_size"] + config["cell_emb_size"] + config["latent_dim"]
        self.translation = utils.make_sequential(input_dim,config["hidden_dim"],config["latent_dim"],config["dropout"])

    def forward(self,inp,z):
        sm_emb = self.smiles_embed(inp["sm_name"])
        ct_emb = self.cell_embed(inp["cell_type"])
        x = torch.cat((sm_emb, ct_emb, z), dim=1)
        return self.translation(x)

In [110]:
class RNVAE(torch.nn.Module):
    cell_type_map = {v: k for k,v in unique_cell_type.to_dict().items()}
    sm_name_map = {v: k for k,v in unique_sm_name.to_dict().items()}
    
    def __init__(self,config):
        super(RNVAE,self).__init__()
        self.vae = autoencoder.AutoEncoder(target_dim=len(landmark_cols),config=config)
        self.translator = Translator(config)

    @classmethod
    def make_input(cls, df, disabletqdm=True):
        ct = df[("label","cell_type")].map(cls.cell_type_map)
        sm = df[("label","sm_name")].map(cls.sm_name_map)
        return [{"cell_type":torch.tensor(ct.iloc[i]),
                "sm_name":torch.tensor(sm.iloc[i]),
                "pre_treatment":torch.tensor(df["pre_treatment"].iloc[i].to_numpy(),dtype=torch.float),
                "post_treatment":torch.tensor(df["post_treatment"].iloc[i].to_numpy(),dtype=torch.float)} for i in tqdm.tqdm(range(len(df)),disable=disabletqdm)]
        
    def forward(self,inp):
        assert inp["pre_treatment"].shape == inp["post_treatment"].shape
        latent = self.vae.latent(inp["pre_treatment"])
        z_prime = self.translator(inp,latent["z"])
        x_hat = self.vae.decode(z_prime)
        return {"x_hat":x_hat, "mu": latent["mu"], "log_var":latent["log_var"]}

    def loss_function(self,fwd,inp):
        return self.vae.loss_function(fwd,inp["post_treatment"])

In [107]:
class Imputer(torch.nn.Module):
    def __init__(self,config,rnvae):
        super(Imputer,self).__init__()
        self.impute_loss_weight = config["impute_loss_weight"]
        self.imp = utils.make_sequential(len(landmark_cols),config["hidden_dim"],len(transcriptome_cols),config["dropout"])
        self.rnvae = rnvae

    @classmethod
    def make_input(cls, mask):
        kg_df = kaggle_joined_df[mask]
        trn_df = de_train[mask]
        rninp = RNVAE.make_input(kg_df)
        trm = trn_df[transcriptome_cols]
        for i,inp in enumerate(rninp):
            inp["transcriptome"] = torch.tensor(trm.iloc[i].to_numpy(), dtype=torch.float)
        return rninp

    def forward(self,inp):
        fwd = self.rnvae(inp)
        trm = self.imp(fwd["x_hat"])
        fwd["transcriptome"] = trm
        return fwd

    def loss_function(self,fwd,inp):
        trm_loss = torch.nn.functional.mse_loss(fwd["transcriptome"], inp["transcriptome"])
        lossdict = self.rnvae.loss_function(fwd,inp)
        lossdict["loss"] += self.impute_loss_weight*trm_loss
        lossdict["Transcriptome_Loss"] = trm_loss.detach()
        return lossdict


In [111]:
example_config = {
    "sm_emb_size": 64,
    "cell_emb_size": 32,
    "latent_dim": 256,
    "hidden_dim": 512,
    "dropout": .1,
    "kld_weight": 1,
    "impute_loss_weight": 2,
}

# inp = RNVAE.make_input(kaggle_joined_df[make_mask(0)])
# rnvae = RNVAE(example_config)
# fwd = rnvae(inp)
# display(fwd["x_hat"].shape)
# rnvae.loss_function(fwd,inp)

# inp = Imputer.make_input(make_mask(0))
# imputer = Imputer(example_config,rnvae)
# imputer.loss_function(fwd,inp)

In [126]:
lincs_subset = lincs_joined_df.sample(10000)
rnvae_inp = RNVAE.make_input(lincs_subset, disabletqdm=False)
rnvae_loader = torch.utils.data.DataLoader(rnvae_inp, batch_size=32)

100%|████████████████████████████████████| 10000/10000 [00:46<00:00, 215.04it/s]


In [129]:
imputer_inp = Imputer.make_input(~make_mask(0))
imputer_loader = torch.utils.data.DataLoader(imputer_inp, batch_size=32)

In [130]:
optimizer = torch.optim.Adam(imputer.parameters(), lr=1e-2)
for _ in range(1):
    for batch in rnvae_loader:
        fwd = rnvae(batch)
        print("RNVAE",rnvae.loss_function(fwd,batch))

    for batch in imputer_loader:
        fwd = imputer(batch)
        print("IMPUTER",imputer.loss_function(fwd,batch))

RNVAE {'loss': tensor(3.1455, grad_fn=<AddBackward0>), 'Reconstruction_Loss': tensor(2.7204), 'KLD': tensor(-0.4251)}
RNVAE {'loss': tensor(2.3337, grad_fn=<AddBackward0>), 'Reconstruction_Loss': tensor(1.9090), 'KLD': tensor(-0.4246)}
RNVAE {'loss': tensor(2.7411, grad_fn=<AddBackward0>), 'Reconstruction_Loss': tensor(2.3110), 'KLD': tensor(-0.4301)}
RNVAE {'loss': tensor(1.9576, grad_fn=<AddBackward0>), 'Reconstruction_Loss': tensor(1.5331), 'KLD': tensor(-0.4245)}
RNVAE {'loss': tensor(2.8466, grad_fn=<AddBackward0>), 'Reconstruction_Loss': tensor(2.4312), 'KLD': tensor(-0.4153)}
RNVAE {'loss': tensor(1.5231, grad_fn=<AddBackward0>), 'Reconstruction_Loss': tensor(1.1002), 'KLD': tensor(-0.4230)}
RNVAE {'loss': tensor(2.1744, grad_fn=<AddBackward0>), 'Reconstruction_Loss': tensor(1.7501), 'KLD': tensor(-0.4243)}
RNVAE {'loss': tensor(1.6524, grad_fn=<AddBackward0>), 'Reconstruction_Loss': tensor(1.2315), 'KLD': tensor(-0.4209)}
RNVAE {'loss': tensor(2.9606, grad_fn=<AddBackward0>), '