In [28]:
# stdlib
import sys
import warnings

# synthcity absolute
import synthcity.logger as log
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import TimeSeriesSurvivalDataLoader

# Register stderr as a synthcity log sink with level="INFO".
log.add(sink=sys.stderr, level="INFO")
# Suppress every Python warning for the rest of the session.
# NOTE(review): this blanket filter also hides deprecation warnings from dependencies.
warnings.filterwarnings("ignore")

In [56]:
import numpy as np
from synthcity.utils.datasets.time_series.pbc import PBCDataloader

# Load the PBC dataset: a static frame, a list of temporal frames, the
# observation times of each temporal frame, and the outcome pair.
static_surv, temporal_surv, temporal_surv_horizons, outcome_surv = (
    PBCDataloader().load()
)
# The outcome unpacks into two Series, named time_to_event and event in the
# outputs further down.
T, E = outcome_surv

# Evaluation horizons: the 25%, 50% and 75% quantiles of T, as plain floats.
horizons = [0.25, 0.5, 0.75]
time_horizons = np.quantile(T, horizons).tolist()

# Bundle everything into a single survival time-series loader.
loader = TimeSeriesSurvivalDataLoader(
    static_data=static_surv,
    temporal_data=temporal_surv,
    observation_times=temporal_surv_horizons,
    T=T,
    E=E,
    time_horizons=time_horizons,
)

# Long-format view of the loader contents (displayed as the cell output).
loader.dataframe()

Unnamed: 0,seq_id,seq_time_id,seq_static_sex,seq_temporal_SGOT,seq_temporal_age,seq_temporal_albumin,seq_temporal_alkaline,seq_temporal_ascites,seq_temporal_drug,seq_temporal_edema,seq_temporal_hepatomegaly,seq_temporal_histologic,seq_temporal_platelets,seq_temporal_prothrombin,seq_temporal_serBilir,seq_temporal_serChol,seq_temporal_spiders,seq_out_time_to_event,seq_out_event
0,0,0.569489,0.0,-1.485263,0.248058,-0.894575,0.195532,1.0,0.0,1.0,1.0,3.0,-0.529101,0.136768,3.281890,0.000000,1.0,0.569489,1.0
1,0,1.095170,0.0,0.195488,0.248058,-1.570646,0.285613,1.0,0.0,1.0,1.0,3.0,-0.456022,0.813132,2.015877,-0.469461,1.0,0.569489,1.0
2,1,5.319790,0.0,-0.442126,1.292856,-1.431455,-0.605844,1.0,0.0,1.0,1.0,2.0,-1.395605,0.339677,0.172710,-0.658914,1.0,14.152338,0.0
3,1,6.261636,0.0,-0.046806,1.292856,-1.172958,-0.512364,1.0,0.0,1.0,1.0,2.0,-1.259888,0.339677,-0.013468,-0.603657,1.0,14.152338,0.0
4,1,7.266455,0.0,0.293680,1.292856,-1.312149,-0.443529,1.0,0.0,1.0,1.0,2.0,-1.364286,0.339677,0.098239,0.000000,1.0,14.152338,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1940,311,1.045888,0.0,0.986129,-1.962482,0.059878,1.385274,0.0,1.0,2.0,0.0,2.0,-1.103291,1.624769,3.672865,3.319599,1.0,3.989158,0.0
1941,311,1.867265,0.0,0.641817,-1.962482,-0.099197,0.916176,0.0,1.0,2.0,0.0,1.0,-0.998892,1.354223,2.350998,2.901224,1.0,3.989158,0.0
1942,311,2.921367,0.0,0.552551,-1.962482,0.338261,0.327254,0.0,1.0,0.0,0.0,1.0,-0.894494,0.474950,0.694010,-0.066873,0.0,3.989158,0.0
1943,311,3.425145,0.0,0.016956,-1.962482,-0.377580,0.251620,0.0,1.0,0.0,0.0,1.0,-0.466462,-0.066141,0.340271,0.000000,0.0,3.989158,0.0


In [60]:
# Column names of the first temporal frame.
temporal_surv[0].columns

Index(['drug', 'ascites', 'hepatomegaly', 'spiders', 'edema', 'histologic',
       'serBilir', 'serChol', 'albumin', 'alkaline', 'SGOT', 'platelets',
       'prothrombin', 'age'],
      dtype='object')

In [58]:

# Column names of the loader's dataframe view, for comparison with the raw temporal columns.
loader.dataframe().columns

Index(['seq_id', 'seq_time_id', 'seq_static_sex', 'seq_temporal_SGOT',
       'seq_temporal_age', 'seq_temporal_albumin', 'seq_temporal_alkaline',
       'seq_temporal_ascites', 'seq_temporal_drug', 'seq_temporal_edema',
       'seq_temporal_hepatomegaly', 'seq_temporal_histologic',
       'seq_temporal_platelets', 'seq_temporal_prothrombin',
       'seq_temporal_serBilir', 'seq_temporal_serChol', 'seq_temporal_spiders',
       'seq_out_time_to_event', 'seq_out_event'],
      dtype='object')

In [30]:
# Display T, the first element unpacked from outcome_surv.
T

0       0.569489
1      14.152338
2       0.736502
3       0.276531
4       4.120578
         ...    
307     4.988501
308     4.553171
309     4.402585
310     4.128792
311     3.989158
Name: time_to_event, Length: 312, dtype: float64

In [31]:
# Display E, the second element unpacked from outcome_surv.
E

0      1
1      0
2      1
3      1
4      0
      ..
307    0
308    0
309    0
310    0
311    0
Name: event, Length: 312, dtype: int64

In [32]:
# Display T again.
# NOTE(review): this repeats an earlier `T` display cell; one of the two can be removed.
T

0       0.569489
1      14.152338
2       0.736502
3       0.276531
4       4.120578
         ...    
307     4.988501
308     4.553171
309     4.402585
310     4.128792
311     3.989158
Name: time_to_event, Length: 312, dtype: float64

In [33]:
# Display the static frame returned by PBCDataloader().load().
static_surv

Unnamed: 0,sex
0,0.0
1,0.0
2,1.0
3,0.0
4,0.0
...,...
307,0.0
308,0.0
309,0.0
310,0.0


In [34]:
# Print the length of each outcome component, then display the first one.
print(len(outcome_surv[0]))
print(len(outcome_surv[1]))
outcome_surv[0]

312
312


0       0.569489
1      14.152338
2       0.736502
3       0.276531
4       4.120578
         ...    
307     4.988501
308     4.553171
309     4.402585
310     4.128792
311     3.989158
Name: time_to_event, Length: 312, dtype: float64

In [54]:
# Observation times stored at index 1 of temporal_surv_horizons.
temporal_surv_horizons[1]

array([ 5.31978973,  6.26163618,  7.26645493,  8.26305991,  9.2514511 ,
       12.04961121, 13.15299529, 13.6540357 , 14.15233819])

In [35]:
# Number of evaluation horizons, followed by the shape of the first ten
# observation-time arrays.
print(len(time_horizons))
for idx in range(10):
    print(temporal_surv_horizons[idx].shape)

3
(2,)
(9,)
(4,)
(7,)
(6,)
(6,)
(7,)
(8,)
(7,)
(1,)


In [36]:
# Number of observation-time arrays, followed by the shape of the first ten.
print(len(temporal_surv_horizons))
for idx in range(10):
    print(temporal_surv_horizons[idx].shape)

312
(2,)
(9,)
(4,)
(7,)
(6,)
(6,)
(7,)
(8,)
(7,)
(1,)


In [37]:
# Show the columns of the last and of the first temporal frame so they can be compared.
# The first is printed explicitly because only the final expression of a cell is displayed.
print(temporal_surv[-1].columns)
temporal_surv[0].columns

Index(['drug', 'ascites', 'hepatomegaly', 'spiders', 'edema', 'histologic',
       'serBilir', 'serChol', 'albumin', 'alkaline', 'SGOT', 'platelets',
       'prothrombin', 'age'],
      dtype='object')


Index(['drug', 'ascites', 'hepatomegaly', 'spiders', 'edema', 'histologic',
       'serBilir', 'serChol', 'albumin', 'alkaline', 'SGOT', 'platelets',
       'prothrombin', 'age'],
      dtype='object')

In [38]:
# (rows, columns) of the first ten temporal frames.
for idx in range(10):
    print(temporal_surv[idx].shape)

(2, 14)
(9, 14)
(4, 14)
(7, 14)
(6, 14)
(6, 14)
(7, 14)
(8, 14)
(7, 14)
(1, 14)


In [39]:
# (rows, columns) of the last temporal frame.
temporal_surv[-1].shape

(5, 14)

In [40]:

# Number of temporal frames in the list.
len(temporal_surv)

312

In [69]:
# Type of the `static` object.
# NOTE(review): `static` is only bound by the GoogleStocksDataloader cell further down
# the notebook, so this cell fails with NameError on Restart & Run All.
type(static)

pandas.core.frame.DataFrame

In [76]:
# Display the temporal frame stored at index 1.
temporal[1]

Unnamed: 0,Open,High,Low,Close,Volume
21,0.557005,0.569527,0.626597,0.587297,0.261217
22,0.552061,0.569527,0.594394,0.605604,0.165112
23,0.510852,0.540625,0.587745,0.590535,0.220572
24,0.451786,0.459784,0.513299,0.52117,0.194089
25,0.421704,0.394788,0.44498,0.469988,0.140587
26,0.387225,0.411498,0.42738,0.455293,0.204145
27,0.345879,0.392523,0.408997,0.437733,0.234811
28,0.286951,0.267487,0.323859,0.367622,0.211399
29,0.332143,0.328802,0.390613,0.369489,0.321244
30,0.205906,0.282073,0.281487,0.332628,0.346344


In [77]:
# (rows, columns) of the first ten frames in `temporal`.
for idx in range(10):
    print(temporal[idx].shape)

(10, 5)
(10, 5)
(10, 5)
(10, 5)
(10, 5)
(10, 5)
(10, 5)
(10, 5)
(10, 5)
(10, 5)


In [85]:
# Length of the first ten entries of `horizons`.
# NOTE(review): `horizons` here is the value returned by GoogleStocksDataloader().load(),
# not the quantile list bound under the same name in the PBC cell.
for idx in range(10):
    print(len(horizons[idx]))

10
10
10
10
10
10
10
10
10
10


In [78]:
# Number of frames in `temporal`.
len(temporal)

50

In [80]:
# Number of entries in `horizons`.
len(horizons)

50

In [81]:
# Number of rows in `outcome`.
len(outcome)

50

In [83]:
# Display the outcome frame.
outcome

Unnamed: 0,Open_next
0,0.710852
1,0.756044
2,0.687362
3,0.642857
4,0.628297
5,0.671978
6,0.704808
7,0.684753
8,0.684753
9,0.607281


In [100]:
# NOTE(review): the value of this first expression is discarded, because a cell only
# displays its final expression.
len(temporal)
# Display the whole list of temporal frames (the output is long).
temporal



[        Open      High       Low     Close    Volume
 40  0.361401  0.378079  0.409257  0.378020  0.110766
 41  0.370879  0.331705  0.398696  0.426526  0.257110
 42  0.388325  0.344945  0.390613  0.408343  0.281098
 43  0.393819  0.402775  0.473012  0.449066  0.179121
 44  0.389149  0.348485  0.374967  0.430137  0.279255
 45  0.359753  0.359247  0.403651  0.449439  0.166763
 46  0.399038  0.373123  0.431552  0.410585  0.211437
 47  0.378984  0.384961  0.422686  0.433749  0.249088
 48  0.225962  0.355707  0.284146  0.427522  0.382690
 49  0.099863  0.165675  0.190483  0.276463  0.265592,
         Open      High       Low     Close    Volume
 21  0.557005  0.569527  0.626597  0.587297  0.261217
 22  0.552061  0.569527  0.594394  0.605604  0.165112
 23  0.510852  0.540625  0.587745  0.590535  0.220572
 24  0.451786  0.459784  0.513299  0.521170  0.194089
 25  0.421704  0.394788  0.444980  0.469988  0.140587
 26  0.387225  0.411498  0.427380  0.455293  0.204145
 27  0.345879  0.392523  0.

In [101]:
# Display the full nested list of observation times (the output is long).
horizons

[[0.35227272727274794,
  0.3409090909090935,
  0.32954545454546746,
  0.2840909090909065,
  0.2727272727272805,
  0.26136363636365445,
  0.25,
  0.23863636363637397,
  0.20454545454546746,
  0.19318181818184144],
 [0.6590909090909065,
  0.6477272727272805,
  0.6363636363636545,
  0.6022727272727479,
  0.5909090909090935,
  0.5795454545454675,
  0.5681818181818414,
  0.556818181818187,
  0.5227272727272805,
  0.5113636363636545],
 [0.19318181818184144,
  0.18181818181818699,
  0.17045454545456096,
  0.125,
  0.11363636363637397,
  0.10227272727274794,
  0.09090909090909349,
  0.04545454545456096,
  0.03409090909090651,
  0.02272727272728048],
 [0.2840909090909065,
  0.2727272727272805,
  0.26136363636365445,
  0.25,
  0.23863636363637397,
  0.20454545454546746,
  0.19318181818184144,
  0.18181818181818699,
  0.17045454545456096,
  0.125],
 [0.6022727272727479,
  0.5909090909090935,
  0.5795454545454675,
  0.5681818181818414,
  0.556818181818187,
  0.5227272727272805,
  0.511363636363654

Mi modelo

In [64]:
from synthcity.plugins import Plugins
from synthcity.utils.datasets.time_series.google_stocks import GoogleStocksDataloader
from synthcity.plugins.core.dataloader import TimeSeriesDataLoader
# Load the Google Stocks example data: static frame, list of temporal frames,
# per-frame observation times, and the outcome frame.
# NOTE(review): `horizons` and `loader` rebind names already used by the PBC survival cell.
static, temporal, horizons, outcome = GoogleStocksDataloader().load()

# Wrap the pieces in a generic (non-survival) time-series loader.
loader = TimeSeriesDataLoader(
    static_data=static,
    temporal_data=temporal,
    observation_times=horizons,
    outcome=outcome,
)

# Fit the "timegan" plugin with n_iter=50 on the loader.
plugin = Plugins().get("timegan", n_iter=50)
plugin.fit(loader)

# Request 10 synthetic samples; the result is displayed as the cell output.
plugin.generate(count=10)

[2024-02-29T17:15:22.706207-0500][2250433][CRITICAL] load failed: cannot import name '_centered' from 'scipy.signal.signaltools' (/gel/usr/cyyba/.local/lib/python3.8/site-packages/scipy/signal/signaltools.py)
[2024-02-29T17:15:22.706207-0500][2250433][CRITICAL] load failed: cannot import name '_centered' from 'scipy.signal.signaltools' (/gel/usr/cyyba/.local/lib/python3.8/site-packages/scipy/signal/signaltools.py)
[2024-02-29T17:15:22.706207-0500][2250433][CRITICAL] load failed: cannot import name '_centered' from 'scipy.signal.signaltools' (/gel/usr/cyyba/.local/lib/python3.8/site-packages/scipy/signal/signaltools.py)
[2024-02-29T17:15:22.708515-0500][2250433][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_bayesian_network' has no attribute 'plugin'
[2024-02-29T17:15:22.708515-0500][2250433][CRITICAL] load failed: module 'synthcity.plugins.generic.plugin_bayesian_network' has no attribute 'plugin'
[2024-02-29T17:15:22.708515-0500][2250433][CRITICAL] load failed: modul

Unnamed: 0,seq_id,seq_time_id,seq_temporal_Close,seq_temporal_High,seq_temporal_Low,seq_temporal_Open,seq_temporal_Volume,seq_out_Open_next
0,0,0.993770,0.747341,0.746894,0.683526,0.717553,0.193881,0.341376
1,0,0.998660,0.415571,0.655955,0.718293,0.683821,0.203200,0.341376
2,0,0.992888,0.721466,0.745312,0.427327,0.664797,0.208171,0.341376
3,0,0.767018,0.716643,0.701495,0.686124,0.665292,0.219903,0.341376
4,0,0.871275,0.408874,0.669304,0.708071,0.380994,0.196118,0.341376
...,...,...,...,...,...,...,...,...
85,9,0.983092,0.685961,0.669304,0.708071,0.380994,0.196118,0.341381
86,9,0.998092,0.722754,0.684393,0.424516,0.382395,0.194949,0.341381
87,9,0.942539,0.702297,0.689401,0.691115,0.642821,0.201439,0.341381
88,9,0.999705,0.706766,0.376431,0.710871,0.386410,0.232276,0.341381


In [92]:
import pandas as pd
# Read the one-hot encoded input table and the preprocessed admissions table.
df = pd.read_csv('generative_input/input_onehot_encoding.csv')
adm = pd.read_csv('./data/data_preprocess_nonfilteres.csv')

# Attach HOSPITAL_EXPIRE_FLAG to every input row. The right join keeps all rows
# of `df`; values left missing by the join (and any other NaN) become 0.
res = (
    adm[["HOSPITAL_EXPIRE_FLAG", "SUBJECT_ID", "HADM_ID"]]
    .merge(df, on=["SUBJECT_ID", "HADM_ID"], how='right')
    .fillna(0)
)

# Per-subject outcome: the maximum flag over all rows of that SUBJECT_ID.
# NOTE(review): this rebinds `outcome`, which the Google Stocks cell also assigns.
outcome = res.groupby('SUBJECT_ID')['HOSPITAL_EXPIRE_FLAG'].max()

print(outcome.shape)

(44952,)


In [93]:


# Shapes before the clean-up.
print(res.shape, adm.shape, df.shape)

# Remove every column whose name contains 'Unnamed' (case-sensitive match);
# presumably a saved index that was read back as a data column.
cols_to_drop = res.filter(like='Unnamed', axis=1).columns
res = res.drop(columns=cols_to_drop)

# Shapes after the clean-up.
print(res.shape, adm.shape, df.shape)

# Total count of missing values before and after replacing them with 0.
print(res.isnull().sum().sum())
res = res.fillna(0)
print(res.isnull().sum().sum())

(56678, 690) (58976, 41) (56678, 689)
(56678, 689) (58976, 41) (56678, 689)
0
0


loader = TimeSeriesSurvivalDataLoader(
    temporal_data=temporal_surv,
    observation_times=temporal_surv_horizons,
    static_data=static_surv,
    T=T,
    E=E,
    time_horizons=time_horizons,
)

loader.dataframe()

In [94]:
# Column names of the merged table.
res.columns

Index(['HOSPITAL_EXPIRE_FLAG', 'SUBJECT_ID', 'HADM_ID', 'ADMITTIME',
       '1_diagnosis', '2_diagnosis', '3_diagnosis', '4_diagnosis',
       '5_diagnosis', '6_diagnosis',
       ...
       'RELIGION_Otra', 'RELIGION_UNOBTAINABLE', 'MARITAL_STATUS_MARRIED',
       'MARITAL_STATUS_Otra', 'MARITAL_STATUS_SINGLE',
       'ETHNICITY_BLACK/AFRICAN AMERICAN', 'ETHNICITY_Otra', 'ETHNICITY_WHITE',
       'GENDER_M', 'GENDER_Otra'],
      dtype='object', length=689)

In [None]:

# Select a set of columns from `res` as the static data.
# NOTE(review): this cell has no execution count, so it was never run. The result of
# this first assignment is overwritten by the plain list below. The `res.columns`
# output lists one-hot names such as 'GENDER_M' and 'RELIGION_Otra', so the raw
# names used here are presumably absent from `res` and would raise KeyError. Verify.
static_data = res[['ADMISSION_TYPE', 'ADMISSION_LOCATION',
                'DISCHARGE_LOCATION', 'INSURANCE',  'RELIGION',
                'MARITAL_STATUS', 'ETHNICITY','GENDER','SUBJECT_ID', 'HADM_ID']]


# Rebind `static_data` to a list of column names only; compared with the selection
# above it also contains 'HOSPITAL_EXPIRE_FLAG' and 'ADMITTIME'.
static_data = ['ADMISSION_TYPE', 'ADMISSION_LOCATION',
                'DISCHARGE_LOCATION', 'INSURANCE',  'RELIGION',
                'MARITAL_STATUS', 'ETHNICITY','GENDER','HOSPITAL_EXPIRE_FLAG', 'SUBJECT_ID', 'HADM_ID', 'ADMITTIME',]


                

In [103]:
# Build visit ranks.
# Parse ADMITTIME as datetimes, then order the rows by SUBJECT_ID and ADMITTIME
# so that the running count below follows admission order within each subject.
res = (
    res
    .assign(ADMITTIME=pd.to_datetime(res['ADMITTIME']))
    .sort_values(by=['SUBJECT_ID', 'ADMITTIME'])
)

# 'visit_rank' is the 1-based position of each row within its SUBJECT_ID group.
res['visit_rank'] = res.groupby('SUBJECT_ID').cumcount().add(1)

# Split the table into one DataFrame per visit rank: element 0 holds the rows
# with visit_rank == 1, element 1 those with visit_rank == 2, and so on.
# NOTE(review): this rebinds `temporal_surv`, which the PBC cell also assigns.
max_visits = res['visit_rank'].max()
temporal_surv = [
    res.loc[res['visit_rank'] == rank] for rank in range(1, max_visits + 1)
]


In [96]:
temporal_surv[0].columns

Index(['HOSPITAL_EXPIRE_FLAG', 'SUBJECT_ID', 'HADM_ID', 'ADMITTIME',
       '1_diagnosis', '2_diagnosis', '3_diagnosis', '4_diagnosis',
       '5_diagnosis', '6_diagnosis',
       ...
       'RELIGION_UNOBTAINABLE', 'MARITAL_STATUS_MARRIED',
       'MARITAL_STATUS_Otra', 'MARITAL_STATUS_SINGLE',
       'ETHNICITY_BLACK/AFRICAN AMERICAN', 'ETHNICITY_Otra', 'ETHNICITY_WHITE',
       'GENDER_M', 'GENDER_Otra', 'visit_rank'],
      dtype='object', length=690)

In [104]:
import pandas as pd

# Align every per-rank frame in `temporal_surv` to a common set of rows:
# the SUBJECT_ID values present in the first frame, temporal_surv[0].

# Reference frame holding only the unique SUBJECT_ID values of the first frame.
unique_subjects = pd.DataFrame(
    temporal_surv[0]['SUBJECT_ID'].unique(), columns=['SUBJECT_ID']
)

# Aligned frames, in the same order as `temporal_surv`.
new_df_list = []

# A left join keeps exactly the reference subjects, which is the stated goal;
# an outer join would also admit subjects missing from the reference frame.
# Subjects with no row in a given frame get a row of missing values, filled with 0.
# The loop variable is deliberately not named `df`: that name holds the one-hot
# input table loaded earlier in the notebook and must not be overwritten here.
# NOTE(review): fillna(0) applies to every column, including the datetime
# ADMITTIME column — confirm that 0 is an acceptable placeholder there.
for visit_df in temporal_surv:
    new_df = pd.merge(unique_subjects, visit_df, on='SUBJECT_ID', how='left').fillna(0)
    new_df_list.append(new_df)

In [105]:
# (rows, columns) of the first ten aligned frames.
for idx in range(10):
    print(new_df_list[idx].shape)

(44952, 690)
(44952, 690)
(44952, 690)
(44952, 690)
(44952, 690)
(44952, 690)
(44952, 690)
(44952, 690)
(44952, 690)
(44952, 690)
