In [None]:
import numpy as np
import torch 

import data_access.base_loader as base_loader
import data_access.ricu_loader as ricu_loader
import os
import datetime
import wandb
import ast
import logging
import json

import timeautodiff.processing_simple as processing
import timeautodiff.helper_simple as tdf_helper
import timeautodiff.timeautodiff_v4_efficient_simple as timeautodiff
import evaluation_framework.vis as vis


In [None]:
# Splitting parameters: fractions of the cohort assigned to each split.
train_fraction = 0.2
val_fraction = 0.05
oracle_fraction = 0
oracle_min = 100
intersectional_min_threshold = 100
intersectional_max_threshold = 1000
# The test split receives the remainder, so train + val must not exceed the whole cohort.
assert 0 <= train_fraction + val_fraction <= 1, "train_fraction + val_fraction must not exceed 1"


# Data parameters
data_name = 'eicu'  # one of: 'mimic', 'eicu'
task_name = 'mortality24'  # one of: 'aki', 'kidney_function', 'los', 'los_24', 'mortality24'
static_var = 'ethnicity'
features = None
ricu_dataset_path = f'../../real_data/raw/{task_name}/{data_name}'
# Alternative shared locations used previously:
#   ../../real_data/processed/{task_name}/{data_name} and ../../real_data/intermed/{task_name}/{data_name}
processed_output_path = f'outputs/{task_name}/{data_name}/processed/'
intermed_output_path = f'outputs/{task_name}/{data_name}/intermed/'
seed = 0

simple_imputation = True
mode = 'processed'  # forwarded to loader.get_data together with processed_data_timestamp
processed_data_timestamp = '20250501180302'  # alternative snapshot: '20250501180110'
intermed_data_timestamp = None

standardize = False
save_intermed_data = True
save_processed_data = True
split = True
stratify = False
intersectional = False

split_text = 'Split' if split else 'No Split'
data_params = {
    'processed_data_timestamp': processed_data_timestamp,
    'task_name': task_name,
    'data_name': data_name,
    'train_fraction': train_fraction,
    'val_fraction': val_fraction,
    'test_fraction': 1 - train_fraction - val_fraction,
    'oracle_fraction': oracle_fraction,
    'oracle_min': oracle_min,
    'intersectional_min_threshold': intersectional_min_threshold,
    'intersectional_max_threshold': intersectional_max_threshold,
    'split': split_text,
    'standardize': standardize,
}

loader = ricu_loader.RicuLoader(seed, task_name, data_name, static_var, ricu_dataset_path, simple_imputation,
                                features, processed_output_path, intermed_output_path)

X_dict_tf, y_dict, static = loader.get_data(
    mode=mode,
    train_fraction=train_fraction,
    val_fraction=val_fraction,
    oracle_fraction=oracle_fraction,
    oracle_min=oracle_min,
    intersectional_min_threshold=intersectional_min_threshold,
    intersectional_max_threshold=intersectional_max_threshold,
    stratify=stratify,
    intersectional=intersectional,
    # NOTE(review): the save_* config flags above are True, but False is passed here,
    # so this call does not request saving. TODO confirm which behaviour is intended.
    save_intermed_data=False,
    save_processed_data=False,
    demographics_to_stratify_on=['age_group', 'ethnicity', 'gender'],
    processed_timestamp=processed_data_timestamp,
)

# get_data may hand back NpzFile-like objects (exposing `.files`) instead of dicts.
# Convert each one independently so downstream cells can index plain dicts.
if not isinstance(X_dict_tf, dict):
    X_dict_tf = {key: X_dict_tf[key] for key in X_dict_tf.files}
if not isinstance(y_dict, dict):
    y_dict = {key: y_dict[key] for key in y_dict.files}

X_dict_tf.keys()

In [None]:

# Per-split arrays from the loaded data. The full-slice indexing `[:, :, :]` also
# fails fast if an array is unexpectedly not at least 3-D.
X_train = X_dict_tf['X_imputed_train'][:, :, :]
X_test = X_dict_tf['X_imputed_test'][:, :, :]
X_val = X_dict_tf['X_imputed_val'][:, :, :]

m_train = X_dict_tf['m_train'][:, :, :]
m_test = X_dict_tf['m_test'][:, :, :]
m_val = X_dict_tf['m_val'][:, :, :]

feature_names = X_dict_tf['feature_names'][:]
y_train = y_dict['y_train'][:]
y_test = y_dict['y_test'][:]
y_val = y_dict['y_val'][:]

# Static (conditioning) covariates. The column indices are sorted, so the column
# order of c_* follows y_dict['feature_names'], not the order of static_feature_names.
static_feature_names = ['ethnicity', 'gender', 'age_group']
y_feature_names = y_dict['feature_names'].tolist()
static_features_to_include_indices = sorted(y_feature_names.index(name) for name in static_feature_names)
c_train = y_dict['c_train'][:, static_features_to_include_indices]
c_test = y_dict['c_test'][:, static_features_to_include_indices]
c_val = y_dict['c_val'][:, static_features_to_include_indices]

# Derive the names from the same sorted indices, so cond_names[j] describes column j of c_*.
cond_names = [y_feature_names[idx] for idx in static_features_to_include_indices]

# Indices into the time-series feature axis ranked as most important.
# TODO: document where this ranking was computed.
top10_important_features = [19, 27, 17, 35, 22, 44, 42, 43, 37, 26]
top3_important_features = [44, 42, 43]
top6_important_features = [42, 22, 27, 35, 43, 17]

important_features_names = X_dict_tf['feature_names'][top10_important_features]
# A bare expression in the middle of a cell is not rendered, so print explicitly.
print('Top-10 important features:', list(important_features_names))

# NOTE(review): train and val are normalised in separate calls. If normalize_and_reshape
# derives its statistics per call, val is scaled with its own statistics rather than
# train's — confirm this is intended.
X_train_10 = processing.normalize_and_reshape(X_train)[:, :, top10_important_features]
y_train_10 = y_train

X_val_10 = processing.normalize_and_reshape(X_val)[:, :, top10_important_features]
y_val_10 = y_val

for _label, _split_arrays in (
    ('X', (X_train, X_test, X_val)),
    ('y', (y_train, y_test, y_val)),
    ('c', (c_train, c_test, c_val)),
):
    for _split_name, _array in zip(('train', 'test', 'val'), _split_arrays):
        print(f'Shape of {_label} {_split_name}:', _array.shape)


## Model Loading

Load the trained model for the selected run via `tdf_helper.load_models_only`.

In [None]:
################################################################################################################
# Model Evaluation
################################################################################################################
# Identifiers of trained runs (passed to tdf_helper.load_models_only); only the first is evaluated here.
# Previously `diff_timestamps` was never defined in this notebook, which raised NameError on a fresh kernel.
diff_timestamps = ['20250430_124212_10features_v4_efficient_simple_eicu_mortality24']
diff_timestamp = diff_timestamps[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"############ Evaluating timestamp {diff_timestamp}: ############")

model = tdf_helper.load_models_only(diff_timestamp, task_name, data_name)



In [None]:
# Turn the raw test split into synthesizer inputs restricted to the top-10 features.
# Then build the conditioning tensor: static covariates followed by the outcome, joined along dim 2.
(response_test,
 outcome_test,
 static_test,
 time_info_test) = processing.process_data_for_synthesizer(X_test, y_test, c_test, top10_important_features)
cond_test = torch.cat([static_test, outcome_test], dim=2).float()
response_test = response_test.float()
time_info_test = time_info_test.float()


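## Sampling

Generate `n_generations` synthetic versions of the test split, each conditioned on the real static covariates and outcomes.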

In [None]:
# Optional progress bar: fall back to a plain iterator when tqdm is not installed.
try:
    from tqdm.auto import tqdm
except ImportError:
    def tqdm(iterable, **_kwargs):
        return iterable


# Defined here because it was referenced without a definition in this notebook.
def single_sample_subgroups_mask(groups):
    """Return a boolean mask keeping rows whose subgroup occurs more than once.

    Parameters
    ----------
    groups : array-like, shape (n_samples, n_attributes)
        One row of demographic attributes per sample; identical rows form a subgroup.

    Returns
    -------
    numpy.ndarray of bool, shape (n_samples,)
        True for samples belonging to a subgroup with at least two members.
    """
    groups = np.asarray(groups)
    if groups.shape[0] == 0:
        return np.zeros(0, dtype=bool)
    _, inverse, counts = np.unique(groups, axis=0, return_inverse=True, return_counts=True)
    # reshape(-1): the inverse's shape differs across NumPy versions when `axis` is given.
    return counts[inverse.reshape(-1)] > 1


synth_data_list = []
synth_data_y_list = []

sample_set = 'test'
hybrid = False

if sample_set == 'test':
    conditioning = cond_test
    time = time_info_test
    real = response_test
else:
    # Only the test split is prepared in this notebook (the oracle split is not loaded).
    raise ValueError(f"Unsupported sample_set {sample_set!r}; only 'test' is available")

if hybrid:
    # Hybrid data = synthetic oracle samples + the remaining real test samples; not wired up here.
    raise NotImplementedError("hybrid sampling is not implemented in this notebook")

# Conditioning layout (built in the previous cell): static covariates first, outcome in the last column.
real_data_y = conditioning[:, 0, -1]
_synth_data_y = conditioning[:, 0, -1]
demographic_data = conditioning[:, 0, :-1]

# Remove samples whose demographic subgroup has a single member.
mask_test = single_sample_subgroups_mask(demographic_data.cpu().numpy())
_keep = torch.from_numpy(mask_test)
real_data_y, _synth_data_y, demographic_data, conditioning, time, real = (
    tensor[_keep.to(tensor.device)]
    for tensor in (real_data_y, _synth_data_y, demographic_data, conditioning, time, real)
)

# Seed once so the sequence of synthetic draws is reproducible across kernel restarts.
torch.manual_seed(seed)
# Labels are identical for every generation; compute them once.
_synth_y_np = _synth_data_y.cpu().numpy().reshape(-1)

n_generations = 2
for i in tqdm(range(n_generations), desc="Generating Synthetic Data", leave=True):
    # Use the filtered conditioning/time so each synthetic batch lines up row-for-row with the labels.
    _synth_data = tdf_helper.generate_synthetic_data_in_batches(model, conditioning, time,
                                                                batch_size=10000)
    synth_data_list.append(_synth_data.cpu().numpy())
    synth_data_y_list.append(_synth_y_np.copy())

