In [None]:
%reload_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path

dir_path = Path(os.getcwd()).absolute()
module_path = str(dir_path.parent.parent.parent) # repo root
if module_path not in sys.path:
    sys.path.append(module_path)
    
import numpy as np
import pandas as pd
import torch, lifelines

import pytorch_lightning as pl
pl.utilities.seed.seed_everything(seed=42)

import matplotlib.pyplot as plt
from pytorch_lightning import Trainer

# OWN MODULES
from experiments.data.data_module import UKRegDataModule
from organsync.models.organsync_network import OrganSync_Network
from organsync.models.organite_network import  OrganITE_Network
from organsync.models.inference import Inference_OrganITE, Inference_OrganSync

import warnings
warnings.filterwarnings('ignore')


In [None]:
# SETUP DATA

batch_size = 256

# The processed UKReg dataset is expected at ../datasets/processed_UKReg
# (relative to the current working directory); replace `data_dir` with your
# own path to the processed UKReg dataset if it lives elsewhere.
root_data_dir = Path("../datasets").absolute()
data_dir = root_data_dir / "processed_UKReg"

dm = UKRegDataModule(data_dir, batch_size=batch_size, control=False)
dm.prepare_data()
dm.setup(stage='fit')

# Preview only the first rows of the processed training split instead of
# rendering the whole frame (presumably patient-level registry records, which
# should not be dumped into saved notebook output).
dm._train_processed.head()

# OrganITE

In [None]:
# OrganITE: hyperparameters, model construction, training and test evaluation.

# Network shape (input width comes from the data module).
input_dim = dm.size(1)
hidden_dim = 16
num_hidden_layers = 4
output_dim = 6
activation_type = "leaky_relu"
dropout_prob = 0.11

# Optimisation and loss weights, forwarded unchanged to OrganITE_Network.
lr = 0.007
gamma = 0.79
lambd = 0.15
kappa = 0.15
weight_decay = 0.0006

# Number of organ clusters the network is configured with.
n_clusters = 15

epochs = 100

# Build the network and cast its parameters to float64.
organite_model = OrganITE_Network(
    input_dim=input_dim, hidden_dim=hidden_dim,
    num_hidden_layers=num_hidden_layers, output_dim=output_dim,
    activation_type=activation_type, dropout_prob=dropout_prob,
    lr=lr, gamma=gamma, lambd=lambd, kappa=kappa,
    weight_decay=weight_decay, n_clusters=n_clusters,
).double()

# Train.
dm.setup(stage='fit')
organite_trainer = Trainer(max_epochs=epochs, callbacks=[])
organite_trainer.fit(organite_model, datamodule=dm)

# Evaluate on the test split. No model is passed explicitly: the trainer uses
# the module it just fitted (on some Lightning versions this also reloads the
# best checkpoint first -- keep this call form to preserve that behaviour).
dm.setup(stage='test')
organite_trainer.test(datamodule=dm)

In [None]:
import cloudpickle

# Wrap the trained OrganITE model, together with the data module's
# normalisation statistics (dm.mean / dm.std), in an inference object and
# serialise it with cloudpickle.
organite_model.eval()
inference_oite = Inference_OrganITE(model=organite_model, mean=dm.mean, std=dm.std)

# Detach the Lightning Trainer so it is not pickled along with the model.
# Plain assignment is enough; `del organite_model.trainer` is avoided because in
# recent Lightning versions `trainer` is a property without a deleter, so `del`
# raises AttributeError.
organite_model.trainer = None

# open(..., "wb") does not create missing directories.
Path("models").mkdir(parents=True, exist_ok=True)
with open("models/organite_inference.p", "wb") as f:
    cloudpickle.dump(inference_oite, f)


# OrganSync

In [None]:
# OrganSync: hyperparameters, model construction and training.

# Optimisation settings.
lr = .0095
weight_decay = 0.00001
gamma = 0.87  # lr decay
lambd = .1
epochs = 100
# NOTE(review): this value is never passed anywhere -- `dm` is reused as built in
# the data-setup cell, so it keeps its original batch size and this line does not
# affect training. TODO confirm whether OrganSync was meant to use batch size 128.
batch_size = 128

# Network shape.
num_hidden_layers = 3
hidden_dim = 50
output_dim = 19
dropout_prob = 0.006
activation_type = 'leaky_relu'

# Build the network (input width from the data module) and cast it to float64.
input_dim = dm.size(1)
organsync_model_with_organ = OrganSync_Network(
    input_dim=input_dim, hidden_dim=hidden_dim,
    num_hidden_layers=num_hidden_layers, output_dim=output_dim,
    activation_type=activation_type, dropout_prob=dropout_prob,
    lr=lr, gamma=gamma, lambd=lambd, weight_decay=weight_decay,
).double()

# Train on the shared data module.
trainer_with_organ = Trainer(max_epochs=epochs, callbacks=[])
trainer_with_organ.fit(organsync_model_with_organ, datamodule=dm)

In [None]:
import cloudpickle

# Serialise an inference wrapper around the trained OrganSync model, bundled
# with the data module's normalisation statistics (dm.mean / dm.std).
organsync_model_with_organ.eval()

# Detach the Lightning Trainer so it is not pickled along with the model.
# Plain assignment is enough; `del ...trainer` is avoided because in recent
# Lightning versions `trainer` is a property without a deleter, so `del` raises
# AttributeError.
organsync_model_with_organ.trainer = None

inference_os = Inference_OrganSync(model=organsync_model_with_organ, mean=dm.mean, std=dm.std)

# open(..., "wb") does not create missing directories.
Path("models").mkdir(parents=True, exist_ok=True)
with open("models/organsync_inference.p", "wb") as f:
    cloudpickle.dump(inference_os, f)

inference_os

# OrganITE organ clustering

In [None]:
# Organ clustering: take the organ-feature columns (dm.o_cols) from the
# processed *training* split; these rows are used for the cluster-frequency
# computation. (Previously named `test_df`, which was misleading -- this is
# training data, not the test split.)
train_df = dm._train_processed

organs = train_df[list(dm.o_cols)]


In [None]:
# Percentage of training organs assigned to each OrganITE cluster.
# Assumes `cluster.predict` returns integer labels in [0, n_clusters) -- TODO
# confirm. `minlength` guarantees one entry per cluster: without it np.bincount
# truncates the result when the highest-indexed clusters receive no organs.
cluster_ids = organite_model.cluster.predict(organs)
probs = np.bincount(cluster_ids, minlength=n_clusters) / len(organs) * 100

probs

In [None]:
import cloudpickle

# Persist the per-cluster percentages and the organ feature table.
# open(..., "wb") / to_csv do not create missing directories.
Path("models").mkdir(parents=True, exist_ok=True)

with open("models/organite_clustering_probs.p", "wb") as f:
    cloudpickle.dump(probs, f)

# NOTE: despite the .csv extension this file is gzip-compressed. pandas infers
# compression from the file extension, so readers must pass compression='gzip'
# explicitly (e.g. pd.read_csv("models/organs.csv", compression="gzip")).
organs.to_csv("models/organs.csv", compression='gzip', index=False)
    