In [1]:
import autoencoder
import utils
import mrrmse

import pandas as pd
import torch

from sklearn.model_selection import KFold, train_test_split
import numpy as np
import tqdm
import random

from hyperopt import hp
from hyperopt.pyll import scope
from ray import train, tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.hyperopt import HyperOptSearch

## Prepare data:
#### Read joined data (pre + post treatment)

In [2]:
lincs_joined_df = pd.read_parquet("data/lincs_pretreatment.parquet")
kaggle_joined_df = pd.read_parquet("data/kaggle_pretreatment.parquet")
test_joined_df = pd.read_parquet("data/test_pretreatment.parquet")
print(f"lincs_joined_df = {lincs_joined_df.shape}\nkaggle_joined_df = {kaggle_joined_df.shape}\ntest_joined_df = {test_joined_df.shape}")

lincs_joined_df = (107404, 1842)
kaggle_joined_df = (602, 1841)
test_joined_df = (255, 921)


#### Kaggle provided data

In [3]:
de_train = pd.read_parquet('data/de_train.parquet')
id_map = pd.read_csv('data/id_map.csv',index_col='id')
print(f"de_train = {de_train.shape}\nid_map = {id_map.shape}")

de_train = (614, 18216)
id_map = (255, 2)


#### Define features of interest and sort data accordingly.

In [4]:
features = ['cell_type', 'sm_name']
multiindex_features = [("label",'cell_type'),("label",'sm_name')]

transcriptome_cols = de_train.columns[5:]
landmark_cols = kaggle_joined_df["post_treatment"].columns
print(f"transcriptome_cols = {transcriptome_cols.shape}\nlandmark_cols = {landmark_cols.shape}")

transcriptome_cols = (18211,)
landmark_cols = (918,)


In [5]:
unique_sm_name = pd.concat([lincs_joined_df[("label","sm_name")],kaggle_joined_df[("label","sm_name")]]).drop_duplicates().reset_index(drop=True)
unique_cell_type = pd.concat([lincs_joined_df[("label","cell_type")],kaggle_joined_df[("label","cell_type")]]).drop_duplicates().reset_index(drop=True)
print(f"Number of unique molecules = {len(unique_sm_name)}.\nNumber of unique cell types = {len(unique_cell_type)}.")

Number of unique molecules = 1896.
Number of unique cell types = 36.


In [6]:
# We only need to sort these two dataframes because they represent the same underlying dataset.
de_train = de_train.query("~control").sort_values(features)
kaggle_joined_df = kaggle_joined_df.sort_values(multiindex_features)
# Sanity check that these dfs align.
genes_align = (kaggle_joined_df["post_treatment"] == de_train[landmark_cols]).all(axis=None)
labels_align = (kaggle_joined_df["label"][features] == de_train[features]).all(axis=None)
genes_align and labels_align

True

#### CV splits

In [7]:
eval_cells_only_df = kaggle_joined_df[kaggle_joined_df["label"]["cell_type"].isin(["B cells", "Myeloid cells"])][multiindex_features]
len(eval_cells_only_df)

30

In [8]:
fold_to_eval_df = {}
skf = KFold(n_splits=3, random_state=42, shuffle=True)
for i,(t,v) in enumerate(skf.split(eval_cells_only_df)):
    fold_to_eval_df[i] = eval_cells_only_df.iloc[v]

for i, df in fold_to_eval_df.items():
    print(f"fold = {i} of shape {df.shape}")

fold = 0 of shape (10, 2)
fold = 1 of shape (10, 2)
fold = 2 of shape (10, 2)


In [9]:
def make_mask(fold):
    val = fold_to_eval_df[fold]
    return kaggle_joined_df[("label","sm_name")].isin(val[("label","sm_name")]) & kaggle_joined_df[("label","cell_type")].isin(val[("label","cell_type")])

print("Using fold 0 as validation set:")
print(f"Train data = {pd.concat([kaggle_joined_df[~make_mask(0)],lincs_joined_df]).shape}")
print(f"Validation data = {kaggle_joined_df[make_mask(0)].shape}")

Using fold 0 as validation set:
Train data = (107994, 1843)
Validation data = (12, 1841)


In [10]:
class Translator(torch.nn.Module):
    def __init__(self,config):
        super(Translator,self).__init__()
        # This will eventually be changed to a GNN
        self.smiles_embed = torch.nn.Embedding(len(unique_sm_name), config["sm_emb_size"])

        # This needs to be able to handle out of dictionary
        self.cell_embed = torch.nn.Embedding(len(unique_cell_type), config["cell_emb_size"])

        self.config = config
        input_dim = config["sm_emb_size"] + config["cell_emb_size"] + config["latent_dim"]
        self.translation = utils.make_sequential(input_dim,config["hidden_dim"],config["latent_dim"],config["dropout"])

    def forward(self,inp,z):
        sm_emb = self.smiles_embed(inp["sm_name"])
        ct_emb = self.cell_embed(inp["cell_type"])
        x = torch.cat((sm_emb, ct_emb, z), dim=1)
        return self.translation(x)

In [11]:
class RNVAE(torch.nn.Module):
    cell_type_map = {v: k for k,v in unique_cell_type.to_dict().items()}
    sm_name_map = {v: k for k,v in unique_sm_name.to_dict().items()}
    
    def __init__(self,config):
        super(RNVAE,self).__init__()
        self.vae = autoencoder.AutoEncoder(target_dim=len(landmark_cols),config=config)
        self.translator = Translator(config)

    # This is super slow because we are iterating.
    @classmethod
    def make_input(cls, df, disabletqdm=True):
        ct = df[("label","cell_type")].map(cls.cell_type_map)
        sm = df[("label","sm_name")].map(cls.sm_name_map)
        return [{"cell_type":torch.tensor(ct.iloc[i]),
                "sm_name":torch.tensor(sm.iloc[i]),
                "pre_treatment":torch.tensor(df["pre_treatment"].iloc[i].to_numpy(),dtype=torch.float),
                "post_treatment":torch.tensor(df["post_treatment"].iloc[i].to_numpy(),dtype=torch.float)} for i in tqdm.tqdm(range(len(df)),disable=disabletqdm)]

    @classmethod
    def make_input_new(cls, df):
        ct = torch.tensor(df[("label","cell_type")].map(cls.cell_type_map).to_numpy())
        sm = torch.tensor(df[("label","sm_name")].map(cls.sm_name_map).to_numpy())
        pre = torch.tensor(df["pre_treatment"].to_numpy(),dtype=torch.float32)
        post = torch.tensor(df["post_treatment"].to_numpy(),dtype=torch.float32)
        
        return [{"cell_type":ct[i],
                "sm_name":sm[i],
                "pre_treatment":pre[i],
                "post_treatment":post[i]} for i in range(len(df))]

    @classmethod
    def make_test(cls,df):
        ct = torch.tensor(df[("label","cell_type")].map(cls.cell_type_map).to_numpy())
        sm = torch.tensor(df[("label","sm_name")].map(cls.sm_name_map).to_numpy())
        pre = torch.tensor(df["pre_treatment"].to_numpy(),dtype=torch.float32)
        
        return [{"cell_type":ct[i],
                "sm_name":sm[i],
                "pre_treatment":pre[i]} for i in range(len(df))]
    
    def forward(self,inp):
        latent = self.vae.latent(inp["pre_treatment"])
        z_prime = self.translator(inp,latent["z"])
        x_hat = self.vae.decode(z_prime)
        return {"x_hat":x_hat, "mu": latent["mu"], "log_var":latent["log_var"]}

    def loss_function(self,fwd,inp):
        return self.vae.loss_function(fwd,inp["post_treatment"])

In [12]:
class Imputer(torch.nn.Module):
    def __init__(self,config,rnvae):
        super(Imputer,self).__init__()
        self.impute_loss_weight = config["impute_loss_weight"]
        self.imp = utils.make_sequential(len(landmark_cols),config["hidden_dim"],len(transcriptome_cols),config["dropout"])
        self.rnvae = rnvae

    @classmethod
    def make_input(cls, mask):
        kg_df = kaggle_joined_df[mask]
        trn_df = de_train[mask]
        rninp = RNVAE.make_input(kg_df)
        trm = trn_df[transcriptome_cols]
        for i,inp in enumerate(rninp):
            inp["transcriptome"] = torch.tensor(trm.iloc[i].to_numpy(), dtype=torch.float)
        return rninp

    def forward(self,inp):
        fwd = self.rnvae(inp)
        trm = self.imp(fwd["x_hat"])
        fwd["transcriptome"] = trm
        return fwd

    def loss_function(self,fwd,inp):
        trm_loss = torch.nn.functional.mse_loss(fwd["transcriptome"], inp["transcriptome"])
        lossdict = self.rnvae.loss_function(fwd,inp)
        lossdict["loss"] += self.impute_loss_weight*trm_loss
        lossdict["Transcriptome_Loss"] = trm_loss.detach()
        return lossdict

In [44]:
bsz = 512
lincs_sample = lincs_joined_df.sample(10000)
rnvae_inp_new = RNVAE.make_input_new(lincs_sample)
rnvae_inp = RNVAE.make_input(lincs_sample,disabletqdm=False)
print(len(rnvae_inp_new[0]["pre_treatment"]),len(rnvae_inp[0]["pre_treatment"]))
print((rnvae_inp_new[0]["pre_treatment"] == rnvae_inp[0]["pre_treatment"]).all())

for k, v in rnvae_inp_new[0].items():
    print(k,v.dtype,v.shape)

for k, v in rnvae_inp[0].items():
    print(k,v.dtype,v.shape)

rnvae_loader = torch.utils.data.DataLoader(rnvae_inp, batch_size=bsz)

100%|████████████████████████████████████| 10000/10000 [00:54<00:00, 183.48it/s]

918 918
tensor(True)
cell_type torch.int64 torch.Size([])
sm_name torch.int64 torch.Size([])
pre_treatment torch.float32 torch.Size([918])
post_treatment torch.float32 torch.Size([918])
cell_type torch.int64 torch.Size([])
sm_name torch.int64 torch.Size([])
pre_treatment torch.float32 torch.Size([918])
post_treatment torch.float32 torch.Size([918])





In [14]:
train_loaders = []
eval_loaders = []
for fold in fold_to_eval_df:
    traind = Imputer.make_input(~make_mask(fold))
    for k, v in traind[0].items():
        print(k,v.dtype)
    print()
    train_loaders.append(torch.utils.data.DataLoader(traind, batch_size=bsz))
    
    evald = Imputer.make_input(make_mask(fold))
    for k, v in evald[0].items():
        print(k,v.dtype)
    eval_loaders.append(torch.utils.data.DataLoader(evald, batch_size=len(evald)))

cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32

cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32
cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32

cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32
cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32

cell_type torch.int64
sm_name torch.int64
pre_treatment torch.float32
post_treatment torch.float32
transcriptome torch.float32


In [46]:
def epoch(models):
    def _epoch(model,opt,loader):
        for batch in loader:
            opt.zero_grad()
            fwd = model(batch)
            loss = model.loss_function(fwd,batch)["loss"]
            if torch.isnan(loss):
                return loss.detach()
            loss.backward()
            opt.step()
        
        return loss.detach()


    loss = _epoch(models["rnvae"],models["rnvae_opt"],models["rnvae_loader"])
    if torch.isnan(loss):
        return loss

    loss = _epoch(models["imputer"],models["impute_opt"],models["train_loader"])
    if torch.isnan(loss):
        return loss

    imputer = models["imputer"]
    with torch.no_grad():
        eval = next(iter(models["eval_loader"]))
        fwd = imputer(eval)
        # The eval loss we wish to optimize is how well the model
        # predicts the full transcriptome.
        return imputer.loss_function(fwd,eval)["Transcriptome_Loss"]

def make_models(config, input_data, fold):
    rnvae = RNVAE(config)
    imputer = Imputer(config,rnvae)
    return {
        "rnvae": rnvae,
        "imputer": imputer,
        "rnvae_opt": torch.optim.Adam(rnvae.parameters(), lr=config["lr_rnvae"]),
        "impute_opt": torch.optim.Adam(imputer.parameters(), lr=config["lr_imputer"]),
        "rnvae_loader": input_data["rnvae_loader"], # There is just one rnvae_loader shared across all folds
        "train_loader": input_data["train_loaders"][fold],
        "eval_loader": input_data["eval_loaders"][fold]
    }
    
def train_model(config, input_data):    
    def report(epoch,result):
        train.report(result)
        if epoch % 10 == 0:
            print(epoch,result)
    
    all_models = []
    for fold in input_data["fold_to_eval_df"]:
        all_models.append(make_models(config, input_data, fold))

    for i in range(input_data["epochs"]):
        losses = []
        for fold in input_data["fold_to_eval_df"]:
            losses.append(epoch(all_models[fold]))
        
        if np.any(np.isnan(losses)):
            report(i,{input_data["metric"]: np.nan, "done": True})
        else:
            report(i,{input_data["metric"]: np.mean(losses)})

In [53]:
num_samples = 25
epochs = 100
metric = "mse"

input_data = {
    "rnvae_loader": rnvae_loader,
    "train_loaders": train_loaders,
    "eval_loaders": eval_loaders,
    "fold_to_eval_df": fold_to_eval_df,
    "epochs": epochs,
    "metric": metric
}

example_config = {
    "lr_rnvae": 1e-3,
    "lr_imputer": 1e-4,
    "dropout": .1,
    "sm_emb_size": 64,
    "cell_emb_size": 32,
    "latent_dim": 256,
    "hidden_dim": 512,
    "kld_weight": 1,
    "impute_loss_weight": 2,
}

space = {
    "lr_rnvae": hp.loguniform("lr_rnvae", -10, -1),
    "lr_imputer": hp.loguniform("lr_imputer", -10, -1),
    "dropout": hp.uniform("dropout", 0, 1),
    "sm_emb_size": scope.int(hp.qloguniform("sm_emb_size", 0, 3, 1)),
    "cell_emb_size": scope.int(hp.qloguniform("cell_emb_size", 0, 3, 1)),
    "latent_dim": scope.int(hp.qloguniform("latent_dim", 0, 7, 1)),
    "hidden_dim": scope.int(hp.qloguniform("hidden_dim", 0, 7, 1)),
    "kld_weight": hp.loguniform("kld_weight", -2, 2),
    "impute_loss_weight": hp.loguniform("impute_loss_weight", -2, 2),
}

train_model(example_config,input_data)
mode = "min"
hyperopt_search = HyperOptSearch(space, metric="mse", mode=mode)
scheduler = ASHAScheduler(metric="mse", grace_period=5, mode=mode, max_t=epochs)
tuner = tune.Tuner(
    tune.with_parameters(train_model, input_data=input_data),
    tune_config=tune.TuneConfig(
        num_samples=num_samples,
        search_alg=hyperopt_search,
        scheduler=scheduler
    ),
    run_config=train.RunConfig(
        failure_config=train.FailureConfig(fail_fast=False))
)
results = tuner.fit()

best_result = results.get_best_result(metric, mode=mode)
print(best_result.path)
print("CONFIG:", best_result.config)
print("METRICS:", best_result.metrics)

0,1
Current time:,2023-10-21 16:21:05
Running for:,00:13:08.96
Memory:,6.5/8.0 GiB

Trial name,# failures,error file
train_model_22845687,1,"/Users/laurasisson/ray_results/train_model_2023-10-21_16-07-56/train_model_22845687_11_cell_emb_size=1,dropout=0.6582,hidden_dim=4,impute_loss_weight=0.2249,kld_weight=0.8452,latent_dim=963,lr__2023-10-21_16-10-26/error.txt"
train_model_f9b5a998,1,"/Users/laurasisson/ray_results/train_model_2023-10-21_16-07-56/train_model_f9b5a998_12_cell_emb_size=10,dropout=0.3987,hidden_dim=59,impute_loss_weight=6.7682,kld_weight=3.9745,latent_dim=96,lr_2023-10-21_16-10-55/error.txt"
train_model_eeb3aaf5,1,"/Users/laurasisson/ray_results/train_model_2023-10-21_16-07-56/train_model_eeb3aaf5_14_cell_emb_size=1,dropout=0.4214,hidden_dim=274,impute_loss_weight=0.2076,kld_weight=0.9336,latent_dim=1075,_2023-10-21_16-11-47/error.txt"
train_model_d56f80a9,1,"/Users/laurasisson/ray_results/train_model_2023-10-21_16-07-56/train_model_d56f80a9_17_cell_emb_size=6,dropout=0.5781,hidden_dim=18,impute_loss_weight=7.1018,kld_weight=2.4764,latent_dim=1,lr_i_2023-10-21_16-12-35/error.txt"
train_model_4ce2b7eb,1,"/Users/laurasisson/ray_results/train_model_2023-10-21_16-07-56/train_model_4ce2b7eb_20_cell_emb_size=2,dropout=0.9585,hidden_dim=5,impute_loss_weight=2.0832,kld_weight=1.3076,latent_dim=4,lr_im_2023-10-21_16-13-24/error.txt"
train_model_577b24bb,1,"/Users/laurasisson/ray_results/train_model_2023-10-21_16-07-56/train_model_577b24bb_23_cell_emb_size=2,dropout=0.1050,hidden_dim=7,impute_loss_weight=0.8929,kld_weight=1.8859,latent_dim=22,lr_i_2023-10-21_16-14-18/error.txt"
train_model_11d45035,1,"/Users/laurasisson/ray_results/train_model_2023-10-21_16-07-56/train_model_11d45035_25_cell_emb_size=11,dropout=0.1056,hidden_dim=2,impute_loss_weight=0.7696,kld_weight=3.0119,latent_dim=423,lr_2023-10-21_16-14-45/error.txt"

Trial name,status,loc,cell_emb_size,dropout,hidden_dim,impute_loss_weight,kld_weight,latent_dim,lr_imputer,lr_rnvae,sm_emb_size,iter,total time (s),mse
train_model_56f3dfce,TERMINATED,127.0.0.1:28615,2,0.958979,25,4.27346,0.46721,5,0.000151456,9.6189e-05,5,100.0,384.782,19.1383
train_model_4982e562,TERMINATED,127.0.0.1:28619,1,0.589091,879,3.72412,0.2346,5,0.0198546,0.000392642,14,5.0,235.936,3010.87
train_model_02321c4c,TERMINATED,127.0.0.1:28624,1,0.174596,439,0.311833,4.06219,20,0.00129081,0.00107111,4,100.0,760.872,7.76841
train_model_b5963ae5,TERMINATED,127.0.0.1:28630,4,0.465926,55,1.5513,1.02055,172,0.00212294,0.0147067,6,1.0,8.15241,
train_model_461749ab,TERMINATED,127.0.0.1:28630,2,0.829947,290,1.35113,0.89025,98,0.00411452,0.0150033,7,1.0,12.5717,
train_model_4f60eaa4,TERMINATED,127.0.0.1:28630,1,0.758455,41,3.02716,0.155665,15,0.00430916,0.000839257,12,100.0,370.934,17.9885
train_model_9804af93,TERMINATED,127.0.0.1:28645,9,0.26718,750,0.153273,0.141419,35,6.41321e-05,0.000668739,3,5.0,276.943,19.1318
train_model_9e8282b0,TERMINATED,127.0.0.1:28649,4,0.450362,79,0.934815,0.137884,173,0.00357777,0.000309717,2,5.0,111.416,18.6728
train_model_81b719dc,TERMINATED,127.0.0.1:28654,3,0.263742,2,0.153157,0.436029,7,0.000520024,0.157971,1,1.0,29.9228,
train_model_447422d9,TERMINATED,127.0.0.1:28654,2,0.889989,8,0.477095,2.16713,220,4.92328e-05,0.0102716,1,1.0,25.2297,




[2m[36m(train_model pid=28615)[0m 0 {'mse': 19.236553}
[2m[36m(train_model pid=28615)[0m 10 {'mse': 19.445566}
[2m[36m(train_model pid=28619)[0m 0 {'mse': 477.25305}
[2m[36m(train_model pid=28615)[0m 20 {'mse': 22.055792}[32m [repeated 2x across cluster][0m
[2m[36m(train_model pid=28630)[0m 0 {'mse': 19.179174}
[2m[36m(train_model pid=28615)[0m 30 {'mse': 19.190908}
[2m[36m(train_model pid=28645)[0m 0 {'mse': 19.181307}
[2m[36m(train_model pid=28649)[0m 0 {'mse': 19.151356}
[2m[36m(train_model pid=28630)[0m 10 {'mse': 18.967073}


2023-10-21 16:10:51,195	ERROR tune_controller.py:2231 -- Could not re-use actor for trial train_model_22845687: Trainable runner reuse requires reset_config() to be implemented and return True.


[2m[36m(train_model pid=28615)[0m 40 {'mse': 19.18001}


2023-10-21 16:11:16,491	ERROR tune_controller.py:2231 -- Could not re-use actor for trial train_model_f9b5a998: Trainable runner reuse requires reset_config() to be implemented and return True.


[2m[36m(train_model pid=28615)[0m 50 {'mse': 19.317554}


2023-10-21 16:12:02,045	ERROR tune_controller.py:2231 -- Could not re-use actor for trial train_model_eeb3aaf5: Trainable runner reuse requires reset_config() to be implemented and return True.


[2m[36m(train_model pid=28630)[0m 20 {'mse': 18.778696}


2023-10-21 16:12:39,909	ERROR tune_controller.py:2231 -- Could not re-use actor for trial train_model_d56f80a9: Trainable runner reuse requires reset_config() to be implemented and return True.


[2m[36m(train_model pid=28615)[0m 60 {'mse': 19.166344}
[2m[36m(train_model pid=28624)[0m 10 {'mse': 17.958761}
[2m[36m(train_model pid=28630)[0m 30 {'mse': 18.615685}
[2m[36m(train_model pid=28813)[0m 0 {'mse': 19.060581}


2023-10-21 16:13:42,124	ERROR tune_controller.py:2231 -- Could not re-use actor for trial train_model_4ce2b7eb: Trainable runner reuse requires reset_config() to be implemented and return True.


[2m[36m(train_model pid=28850)[0m 0 {'mse': 19.521143}
[2m[36m(train_model pid=28615)[0m 70 {'mse': 19.152044}


2023-10-21 16:14:40,146	ERROR tune_controller.py:2231 -- Could not re-use actor for trial train_model_577b24bb: Trainable runner reuse requires reset_config() to be implemented and return True.


[2m[36m(train_model pid=28630)[0m 40 {'mse': 18.4758}


2023-10-21 16:14:59,276	ERROR tune_controller.py:2231 -- Could not re-use actor for trial train_model_11d45035: Trainable runner reuse requires reset_config() to be implemented and return True.


[2m[36m(train_model pid=28615)[0m 80 {'mse': 19.145899}
[2m[36m(train_model pid=28615)[0m 90 {'mse': 19.15601}[32m [repeated 2x across cluster][0m
[2m[36m(train_model pid=28630)[0m 60 {'mse': 18.254393}[32m [repeated 2x across cluster][0m
[2m[36m(train_model pid=28813)[0m 20 {'mse': 18.163446}
[2m[36m(train_model pid=28630)[0m 70 {'mse': 18.16812}
[2m[36m(train_model pid=28630)[0m 80 {'mse': 18.095419}[32m [repeated 2x across cluster][0m
[2m[36m(train_model pid=28813)[0m 30 {'mse': 17.912172}
[2m[36m(train_model pid=28630)[0m 90 {'mse': 18.034513}
[2m[36m(train_model pid=28813)[0m 40 {'mse': 17.848322}
[2m[36m(train_model pid=28813)[0m 50 {'mse': 17.817028}
[2m[36m(train_model pid=28813)[0m 60 {'mse': 17.806108}[32m [repeated 2x across cluster][0m
[2m[36m(train_model pid=28813)[0m 70 {'mse': 17.781229}
[2m[36m(train_model pid=28624)[0m 40 {'mse': 14.459207}
[2m[36m(train_model pid=28813)[0m 90 {'mse': 17.768427}[32m [repeated 2x across

2023-10-21 16:21:05,417	ERROR tune.py:1139 -- Trials did not complete: [train_model_22845687, train_model_f9b5a998, train_model_eeb3aaf5, train_model_d56f80a9, train_model_4ce2b7eb, train_model_577b24bb, train_model_11d45035]
2023-10-21 16:21:05,419	INFO tune.py:1143 -- Total run time: 789.29 seconds (788.96 seconds for the tuning loop).
- train_model_22845687: FileNotFoundError('Could not fetch metrics for train_model_22845687: both result.json and progress.csv were not found at /Users/laurasisson/ray_results/train_model_2023-10-21_16-07-56/train_model_22845687_11_cell_emb_size=1,dropout=0.6582,hidden_dim=4,impute_loss_weight=0.2249,kld_weight=0.8452,latent_dim=963,lr__2023-10-21_16-10-26')
- train_model_f9b5a998: FileNotFoundError('Could not fetch metrics for train_model_f9b5a998: both result.json and progress.csv were not found at /Users/laurasisson/ray_results/train_model_2023-10-21_16-07-56/train_model_f9b5a998_12_cell_emb_size=10,dropout=0.3987,hidden_dim=59,impute_loss_weight=6.

/Users/laurasisson/ray_results/train_model_2023-10-21_16-07-56/train_model_02321c4c_3_cell_emb_size=1,dropout=0.1746,hidden_dim=439,impute_loss_weight=0.3118,kld_weight=4.0622,latent_dim=20,lr__2023-10-21_16-08-08
CONFIG: {'cell_emb_size': 1, 'dropout': 0.17459607590516624, 'hidden_dim': 439, 'impute_loss_weight': 0.3118326491123, 'kld_weight': 4.062189584858828, 'latent_dim': 20, 'lr_imputer': 0.001290806826201169, 'lr_rnvae': 0.0010711102093205497, 'sm_emb_size': 4}
METRICS: {'mse': 7.7684097, 'timestamp': 1697919665, 'done': True, 'training_iteration': 100, 'trial_id': '02321c4c', 'date': '2023-10-21_16-21-05', 'time_this_iter_s': 3.382596969604492, 'time_total_s': 760.8720269203186, 'pid': 28624, 'hostname': 'Lauras-Air', 'node_ip': '127.0.0.1', 'config': {'cell_emb_size': 1, 'dropout': 0.17459607590516624, 'hidden_dim': 439, 'impute_loss_weight': 0.3118326491123, 'kld_weight': 4.062189584858828, 'latent_dim': 20, 'lr_imputer': 0.001290806826201169, 'lr_rnvae': 0.001071110209320549

In [49]:
all_mask = make_mask(0) | True
all_train = Imputer.make_input(all_mask)
all_loader = torch.utils.data.DataLoader(all_train, batch_size=32)

submit_data = RNVAE.make_test(test_joined_df)
submit_loader = torch.utils.data.DataLoader(submit_data, batch_size=len(submit_data))

In [54]:
best_input_data = {
    "rnvae_loader": rnvae_loader,
    "train_loaders": [all_loader],
    "eval_loaders": [all_loader],
    "fold_to_eval_df": fold_to_eval_df,
}

best_models = make_models(best_result.config,best_input_data,0)
print(best_result.config)
# Because we trained the models on a cross-validation split, we want to train one final model
# across all data available.

loss = 0
for _ in tqdm.tqdm(range(best_result.metrics["training_iteration"])):
    loss = epoch(best_models)
print(loss)

with torch.no_grad():
    submitbatch = next(iter(submit_loader))
    # This is the most elegant line of python ever written.
    y_pred = best_models["imputer"](submitbatch)["transcriptome"]


submission = pd.DataFrame(y_pred, columns=transcriptome_cols, index=id_map.index)
display(submission)
submission.to_csv('submissions/rnvae.csv')

{'cell_emb_size': 1, 'dropout': 0.17459607590516624, 'hidden_dim': 439, 'impute_loss_weight': 0.3118326491123, 'kld_weight': 4.062189584858828, 'latent_dim': 20, 'lr_imputer': 0.001290806826201169, 'lr_rnvae': 0.0010711102093205497, 'sm_emb_size': 4}


100%|█████████████████████████████████████████| 100/100 [02:26<00:00,  1.46s/it]

tensor(8.1736)





Unnamed: 0_level_0,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,AAK1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,0.528672,0.318549,0.521696,0.771293,1.164450,0.884213,0.039088,0.527970,-0.104824,0.180595,...,-0.107548,0.202047,0.168787,0.420231,0.739610,0.506831,0.240892,0.206056,-0.084824,-0.098783
1,-0.030574,-0.025990,-0.019649,0.041366,-0.017743,-0.081546,-0.123232,0.013412,-0.080594,0.112765,...,-0.074695,-0.067768,-0.154886,0.006338,0.156259,0.106656,0.057608,0.048692,0.077000,-0.189186
2,0.523027,0.228987,0.273987,0.232100,0.788171,1.206717,0.020542,0.381870,-0.083118,0.142458,...,0.011714,0.137847,0.052237,0.385553,0.558164,0.430852,0.284110,0.172620,-0.097646,-0.050526
3,0.178113,0.071593,0.302333,0.282079,0.409346,0.122760,-0.009732,0.240176,-0.141487,0.052558,...,-0.128297,0.098732,-0.013200,0.107470,0.324498,0.194461,0.168248,0.127013,-0.060557,-0.041977
4,0.000687,-0.073101,-0.075086,0.064975,0.122777,-0.040922,-0.125111,0.013967,-0.094423,0.113349,...,-0.145611,-0.079895,-0.190313,-0.027026,0.200238,0.141642,-0.009138,0.023725,0.097237,-0.227018
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
250,1.628841,0.780313,1.896007,2.646558,4.904253,3.166702,0.464417,1.569543,0.001372,-0.110998,...,0.030858,0.838427,0.812316,1.221225,2.348658,1.466492,0.850378,0.304324,-0.248923,0.104876
251,0.310764,0.046431,0.510666,0.601503,1.854848,0.558109,0.044957,0.061032,-0.282235,0.111371,...,-0.311963,-0.007312,-0.150231,-0.264160,0.619928,0.109334,0.122755,-0.029426,-0.130031,-0.010530
252,1.764767,0.504572,0.873535,2.365860,10.343506,6.479660,0.502747,1.515960,0.444920,-0.819880,...,0.551924,0.135300,1.401836,0.654832,3.497616,1.634827,0.872020,0.341745,-0.901023,0.585574
253,3.101876,1.583784,2.687714,4.606878,12.430525,7.394593,0.538450,2.536995,-0.288898,-0.163105,...,-0.201641,1.429180,1.243499,2.115982,5.056717,2.590979,1.249429,0.134950,-0.653481,0.165899
