# 0) Prepare Colab environment

In [None]:
# -------------------------------------------------------------------
# 📦 1) Clone your repo & cd into it
# -------------------------------------------------------------------
!git clone https://github.com/mahmoudibrahim98/icu-autodiff.git
%cd icu-autodiff

# -------------------------------------------------------------------
# 📦 2) Install  project dependencies
# -------------------------------------------------------------------

!pip install -r colab-compatible-requirements.txt

# -------------------------------------------------------------------
# 📦 3) (If your code lives in subfolders) add them to PYTHONPATH
# -------------------------------------------------------------------
import sys
sys.path.append('.')         # or 'src', etc., depending on your layout

# -------------------------------------------------------------------
# 📦 4) (Optional) Mount Drive for large data
# -------------------------------------------------------------------
#from google.colab import drive
#drive.mount('/content/drive')
#DATA_DIR = '/content/drive/MyDrive/yourproject/data'


# 1) Data Preparation (Needed once)

In [None]:

import sys
import os
# Add the project root directory to the Python path
parent_dir = os.path.dirname(os.path.abspath(''))
sys.path.append(parent_dir)
                            
from datetime import datetime, timedelta
import pandas as pd
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt
import os
import numpy as np
import torch
from sklearn.model_selection import train_test_split
# Select GPU 0 when CUDA is present. On a CPU-only runtime, report that instead
# of raising: `torch.cuda.set_device` / `torch.cuda.current_device` fail without
# a CUDA device, while the training cells already fall back to the CPU via
# `torch.device('cuda' if torch.cuda.is_available() else 'cpu')`.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

    # Verify the current device
    current_device = torch.cuda.current_device()
    print(f"Current CUDA device: {current_device}")
else:
    current_device = None  # no CUDA device to report
    print("CUDA is not available; running on CPU.")
print(torch.cuda.is_available())

import data_access.base_loader as base_loader
import data_access.ricu_loader as ricu_loader
from absl import flags






In [None]:
# Configuration for loading the raw dataset through RicuLoader and splitting it
# into train / val / test. Leaves `X_dict_tf`, `y_dict` and `static` in the
# notebook namespace for the training cells.

# splitting parameters
train_fraction = 0.45
val_fraction = 0.1
oracle_fraction = 0
oracle_min = 100
intersectional_min_threshold = 100
intersectional_max_threshold = 1000


# # data parameters
data_name = 'mimic' # 'mimic' 'eicu'
task_name = 'mortality24' # 'aki' 'kidney_function' 'los' 'los_24' 'mortality24' 
static_var = 'ethnicity'
features = None
ricu_dataset_path = f'../raw_data/{task_name}/{data_name}'
processed_output_path = f'outputs/{task_name}/{data_name}/processed/'
intermed_output_path = f'outputs/{task_name}/{data_name}/intermed/'

seed = 0

simple_imputation = True
mode = 'raw'
intermed_data_timestamp = None

standardize = False
save_intermed_data = True
save_processed_data = True


split = True # will split into train, val, test, stratified on outcome and demographics_to_stratify_on
stratify =  False
intersectional = False

if split == False:
    split_text = 'No Split'

# Only mode == 'raw' is handled by this cell. Fail before building the loader,
# and do not advertise modes that would be rejected here anyway.
if mode != 'raw':
    raise ValueError(f"Unsupported mode {mode!r}: this notebook only implements mode='raw'.")

loader = ricu_loader.RicuLoader(seed, task_name, data_name,static_var,ricu_dataset_path,simple_imputation,
                                    features, processed_output_path,intermed_output_path)

# Create directories if they do not exist
if save_intermed_data:
    os.makedirs(intermed_output_path, exist_ok=True)
if save_processed_data:
    os.makedirs(processed_output_path, exist_ok=True)

X_dict_tf, y_dict, static = loader.get_data(mode='raw', train_fraction=train_fraction, val_fraction=val_fraction, oracle_fraction=oracle_fraction, 
                                            oracle_min=oracle_min, intersectional_min_threshold=intersectional_min_threshold,
                                            intersectional_max_threshold=intersectional_max_threshold,
                                            standardize=standardize,
                                            stratify=stratify, intersectional=intersectional, split = split,
                                            save_intermed_data=save_intermed_data, save_processed_data=save_processed_data,
                                            demographics_to_stratify_on = ['age_group','ethnicity','gender'])

# When the loader returns something other than plain dicts, it is expected to
# expose `.files` plus item access; materialise both results into dicts so the
# later cells can index them uniformly.
if not isinstance(X_dict_tf, dict):
    X_dict_tf = {file: X_dict_tf[file] for file in X_dict_tf.files}
    y_dict = {file: y_dict[file] for file in y_dict.files}

X_dict_tf.keys()

# 2) Training

In [None]:
# Add the project root directory to the Python path
import sys
import os
parent_dir = os.path.dirname(os.path.abspath(''))
sys.path.append(parent_dir)

import numpy as np
import torch 

import data_access.base_loader as base_loader
import data_access.ricu_loader as ricu_loader
from datetime import datetime
import wandb
import ast
import logging
import json

import timeautodiff.processing_simple as processing
import timeautodiff.helper_simple as tdf_helper
import timeautodiff.timeautodiff_v4_efficient_simple as timeautodiff

## 2.1) Data Preparation

In [None]:

# Pull the per-split arrays out of the loader dictionaries. "holdout" names the
# test split and "holdout_val" the validation split.

# most_important_features = [19, 27, 17, 35, 22, 44, 42, 43, 37, 26]
X_train, X_holdout, X_holdout_val = (
    X_dict_tf['X_imputed_train'][:, :, :],
    X_dict_tf['X_imputed_test'][:, :, :],
    X_dict_tf['X_imputed_val'][:, :, :],
)

# `m_*` arrays share the split layout of X (presumably missingness masks — TODO confirm).
m_train, m_holdout, m_holdout_val = (
    X_dict_tf['m_train'][:, :, :],
    X_dict_tf['m_test'][:, :, :],
    X_dict_tf['m_val'][:, :, :],
)

feature_names = X_dict_tf['feature_names'][:]

y_train, y_holdout, y_holdout_val = (
    y_dict['y_train'][:],
    y_dict['y_test'][:],
    y_dict['y_val'][:],
)

# Columns of the static table to keep, looked up by name and kept in ascending
# column order.
static_feature_to_include = ['ethnicity','gender','age_group']
_static_names = y_dict['feature_names'].tolist()
static_features_to_include_indices = sorted(
    [_static_names.index(include) for include in static_feature_to_include]
)

c_train, c_holdout, c_holdout_val = (
    y_dict['c_train'][:, static_features_to_include_indices],
    y_dict['c_test'][:, static_features_to_include_indices],
    y_dict['c_val'][:, static_features_to_include_indices],
)

cond_names = y_dict['feature_names'][static_features_to_include_indices]

# Hard-coded feature-axis indices of the selected feature subsets.
top10_important_features = [19, 27, 17, 35, 22, 44, 42, 43, 37, 26]
top3_important_features = [44,42,43]
top6_important_features = [42, 22, 27, 35, 43, 17]

important_features_names = X_dict_tf['feature_names'][top10_important_features]

# Normalised copy of the training inputs restricted to the top-10 features.
X_train_10 = processing.normalize_and_reshape(X_train)[:, :, top10_important_features]

# Report the shape of every split.
for _label, _array in (
    ('Shape of X train:', X_train),
    ('Shape of X Holdout:', X_holdout),
    ('Shape of X Holdout val:', X_holdout_val),
    ('Shape of y train:', y_train),
    ('Shape of y Holdout:', y_holdout),
    ('Shape of y Holdout val:', y_holdout_val),
    ('Shape of c train:', c_train),
    ('Shape of c Holdout:', c_holdout),
    ('Shape of c Holdout val:', c_holdout_val),
):
    print(_label, _array.shape)


In [None]:

################################################################################################################
                                                    # Prepare Data for Training #
################################################################################################################
# Builds the run directory, converts the training split into the tensors the
# generator consumes, and writes the run's metadata.json.

# Short tag identifying dataset and task; only used in the run directory name.
# Kept separate from the `metadata` dict below so one name is not reused for
# two different kinds of value.
run_tag = f"{data_name}_{task_name}"

process_data = True
load_data = False
train_models = True
train_auto = True
train_diff = True
load_model = False
# processed_data_timestamp = '20241203_130537_10features'

model_version = 'v4_efficient_simple'


EXP_PATH = os.path.join(os.getcwd(), 'outputs')
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
gen_model = 'TimeAutoDiff'
output_dir = f'outputs/{task_name}/{data_name}/{gen_model}/{timestamp}_{len(important_features_names)}features_{model_version}_{run_tag}'
os.makedirs(output_dir, exist_ok=True)
numerical_processing = 'normalize'


# process data for training of generators
processed_X, processed_y, processed_c, time_info = processing.process_data_for_synthesizer(X_train, y_train, c_train, top10_important_features)
# Conditioning tensor: static covariates followed by the outcome along axis 2.
# NOTE(review): unlike the sampling cell, `cond` is not cast to float here —
# confirm that train_diffusion handles the dtype.
cond = torch.concatenate((processed_c, processed_y), axis=2)
response = processed_X
response = response.float()
time_info = time_info.float()


# `data_params` was referenced below without ever being defined in this
# notebook (NameError on a fresh kernel). Build it from the data-preparation
# settings so they are stored with the run.
# NOTE(review): the originally intended key set is not visible in this
# notebook — confirm against any downstream reader of metadata.json.
data_params = {
    'data_name': data_name,
    'task_name': task_name,
    'static_var': static_var,
    'train_fraction': train_fraction,
    'val_fraction': val_fraction,
    'oracle_fraction': oracle_fraction,
    'simple_imputation': simple_imputation,
    'standardize': standardize,
    'stratify': stratify,
    'intersectional': intersectional,
    'split': split,
}

metadata = {
    'model_version': model_version,
    'genmodel_timestamp': timestamp,
    'important_features_names': important_features_names.tolist(),
    'number of features': len(important_features_names),
    'seq_len': processed_X.shape[1],
    'seed': seed,
    'patient_length': processed_X.shape[0],
    'numerical_processing': numerical_processing,
}
metadata.update(data_params)
metadata_path = os.path.join(output_dir, 'metadata.json')
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=4)


################################################################################################################
# Checking Processed Data #
################################################################################################################

print(f"Shape of the response data: {processed_X.shape}")
print(f"Shape of the condition data: {cond.shape}")




## 2.2) Training Autoencoder

In [None]:
################################################################################################################
                                                    # Training #
################################################################################################################
# Defines the hyper-parameters for both models, appends them to the run's
# metadata.json, optionally starts a wandb run, and trains (or reloads) the
# autoencoder. The diffusion cell below reuses `n_epochs`, `hidden_dim`,
# `num_layers` and `diffusion_steps` from here.
efficient = True
auto_mmd_weight = 0
auto_consistency_weight = 0
diff_mmd_weight = 0
diff_consistency_weight = 0
full_metadata = f'auto_mmd_{auto_mmd_weight}_auto_cons_{auto_consistency_weight}_diff_mmd_{diff_mmd_weight}_diff_cons_{diff_consistency_weight}'
# metadata = f'{id}'

use_wandb = False

################################################################################################################
# Defining Model Parameters #
################################################################################################################
if train_models:
    VAE_training = 200      # autoencoder epochs
    diff_training = 200     # diffusion epochs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ###### Auto-encoder Parameters ######
    eps = 1e-5
    weight_decay = 1e-6
    lr = 2e-4
    hidden_size = 200
    num_layers = 2
    batch_size = 100
    channels = 64
    min_beta = 1e-5
    max_beta = 0.1
    emb_dim = 128
    time_dim = time_info.shape[2]
    lat_dim = response.shape[2]
    threshold = 1

    if lat_dim > response.shape[2]:
        raise ValueError("lat_dim should be less than the number of important features.")

    ###### Diffusion Parameters ######
    # `n_epochs` is the epoch count consumed by the diffusion cell. The
    # autoencoder call below uses `VAE_training` directly: previously it read
    # `n_epochs` after it had been overwritten with `diff_training`.
    n_epochs = diff_training
    hidden_dim = 200
    num_layers = 2
    diffusion_steps = 100


    new_params = {
        "VAE_training": VAE_training,
        "diff_training": diff_training,
        "device": str(device),
        "imputation strategy": "randomly select from imputed patients.", # "drop missing values"
        "eps" : eps,
        "auto_weight_decay" : weight_decay,
        "auto_lr" : lr,
        "auto_hidden_size" : hidden_size,
        "auto_num_layers" : num_layers,
        "auto_batch_size" : batch_size,
        "auto_channels" : channels,
        "auto_min_beta" : min_beta,
        "auto_max_beta" : max_beta,
        "auto_emb_dim" : emb_dim,
        "auto_time_dim" : time_dim,
        "auto_lat_dim" : lat_dim,
        "auto_threshold" : threshold,
        "diff_hidden_dim" : hidden_dim,
        "diffusion_steps" : diffusion_steps,
        "diff_num_layers" : num_layers,
        "auto_mmd_weight" : auto_mmd_weight,
        "auto_consistency_weight" : auto_consistency_weight,
        "diff_mmd_weight" : diff_mmd_weight,
        "diff_consistency_weight" : diff_consistency_weight    
    }   

    # Persist the hyper-parameters next to the run's existing metadata
    tdf_helper.append_new_params_to_metadata(output_dir, new_params)

    # Read the metadata file back so `metadata` reflects what is on disk
    metadata_path = os.path.join(output_dir, 'metadata.json')
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)

    # Extract the parameters reported to wandb
    patient_length = metadata.get('patient_length')
    imputation_strategy = metadata.get('imputation strategy')
    number_of_features = metadata.get('number of features')

    ################################################################################################################
    # WANDB Initialization #
    ################################################################################################################
    if use_wandb:
        config = dict(
            model = "TimeAutoDiff",
            patient_length = patient_length,
            imputation_strategy = imputation_strategy,
            number_of_features = number_of_features,
            epochs_VAE = VAE_training,
            epochs_diffusion = diff_training,
            pred_task = task_name,
            data_name = data_name,
        )

        use_cuda = torch.cuda.is_available()
        wandb.init(
            project = 'TimeAutoDiff',
            config = config,
            name = output_dir.split('/')[-1],
        )

################################################################################################################
# Auto-encoder Training #
################################################################################################################
# NOTE(review): this branch reads hyper-parameters that are only defined when
# `train_models` is True.
if train_auto:
    torch.cuda.empty_cache()
    if not efficient:
        # Only the efficient trainer is imported here; fail with a clear message
        # instead of a NameError on `ds` further down.
        raise NotImplementedError("Only the efficient autoencoder trainer is available in this notebook; set efficient = True.")
    ds = timeautodiff.train_autoencoder(response, channels, hidden_size, num_layers, lr, weight_decay, VAE_training,
                                        batch_size, min_beta, max_beta, emb_dim, time_dim, lat_dim, device, output_dir, checkpoints=True,
                                        mmd_weight = auto_mmd_weight, consistency_weight = auto_consistency_weight, use_wandb=use_wandb)
    # Save Autoencoder
    # NOTE(review): saved as 'autoencoder' but reloaded below as 'autoencoder.pt' —
    # confirm that save_model appends the extension.
    ae = ds[0]
    ae.save_model(os.path.join(output_dir, 'autoencoder'))
    # Save latent features
    latent_features = ds[1]
    processing.save_tensor(latent_features,output_dir, 'latent_features.pt')
    print("Latent features saved successfully.")
else:
    # Reuse the artefacts of a previous autoencoder run stored in `output_dir`.
    latent_features = torch.load(os.path.join(output_dir, 'latent_features.pt'))
    ae = timeautodiff.DeapStack.load_model(os.path.join(output_dir, 'autoencoder.pt'))


## 2.3) Training Diffusion Model

In [None]:
################################################################################################################
# Diffusion Training #
################################################################################################################
# Trains the diffusion model on the autoencoder's latent features and then logs
# a one-line summary of the run. Relies on the hyper-parameters defined in the
# autoencoder cell.
if train_diff:
    # len(latent_features) is what the trainer receives as `num_classes`.
    num_classes = len(latent_features)

    # Store the value with the other hyper-parameters of this run.
    new_params = {"diff_num_classes": num_classes}
    tdf_helper.append_new_params_to_metadata(output_dir, new_params)

    diff = timeautodiff.train_diffusion(
        latent_features,
        cond,
        time_info,
        hidden_dim,
        num_layers,
        diffusion_steps,
        n_epochs,
        output_dir,
        checkpoints=True,
        num_classes=num_classes,
        mmd_weight=diff_mmd_weight,
        consistency_weight=diff_consistency_weight,
        use_wandb=use_wandb,
    )

# Appended on every execution of the cell, whether or not training ran.
with open('output_metadata.txt', 'a') as f:
    f.write(f"Metadata: {metadata}, Full Metadata: {full_metadata}, Output Directory: {output_dir}\n")

# 3) Generating Synthetic Data

## 3.1) Model Loading (In case not loaded)

In [None]:
################################################################################################################
# Model Evaluation
################################################################################################################
# Loads the most recent TimeAutoDiff run for the current task and dataset.
# NOTE(review): this rebinds `output_dir` from the run directory created during
# training to its parent folder; re-running a training cell afterwards would
# write into the parent. Kept as is because later cells may rely on it.
output_dir = f'outputs/{task_name}/{data_name}/TimeAutoDiff/'

# Run folders start with a YYYYmmdd_HHMMSS timestamp, so the lexicographically
# last one is the newest. Consider directories only, and fail with a clear
# message when there is nothing to load instead of a bare IndexError.
run_dirs = sorted(
    entry for entry in os.listdir(output_dir)
    if os.path.isdir(os.path.join(output_dir, entry))
)
if not run_dirs:
    raise FileNotFoundError(f"No trained runs found under {output_dir!r}; run the training cells first.")
latest_diffusion_timestamp = run_dirs[-1]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"############ Evaluating timestamp {latest_diffusion_timestamp}: ############")

model = tdf_helper.load_models_only(latest_diffusion_timestamp, task_name, data_name)



## 3.2) Sampling from Model

In [None]:
# Rebuild the generator inputs from the training split: the response tensor, the
# time information, and the conditioning tensor (static covariates followed by
# the outcome along axis 2), all cast to float.
response_train, outcome_train, static_train, time_info_train = processing.process_data_for_synthesizer(
    X_train, y_train, c_train, top10_important_features
)
response_train = response_train.float()
time_info_train = time_info_train.float()
cond_train = torch.concatenate((static_train, outcome_train), axis=2).float()


In [None]:
# Draw `n_generations` synthetic datasets from the loaded model, each
# conditioned on the training conditions and time information, and collect them
# (with the matching outcome labels) as NumPy arrays.

# `tqdm` is not imported by any earlier cell; without this import the loop
# header raises NameError.
import tqdm.notebook

synth_data_list = []
synth_data_y_list = []

n_generations = 2
for i in tqdm.notebook.tqdm(range(n_generations), desc="Generating Synthetic Data", leave=True):
    _synth_data = tdf_helper.generate_synthetic_data_in_batches(model, cond_train, time_info_train, 
                                                                       batch_size = 10000)
    # Outcome label: last conditioning channel, read at the first time step.
    _synth_data_y = cond_train[:, 0, -1]
    synth_data_list.append(_synth_data.cpu().numpy())
    synth_data_y_list.append(_synth_data_y.cpu().numpy().reshape(-1,))


