"""
Copyright 2026 Zsolt Bedőházi, András M. Biricz

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime

In [None]:
RND_SEED = 10  # seed passed to np.random.seed before every fold shuffle in the cells below

### Load saved dataframe

In [None]:
# Load the merged metadata table saved by an earlier step and peek at it.
merged_df = pd.read_csv("merged_metadata_v2.1.csv")
merged_df.head(n=2)

### Filter to stages I–IV and exclude neoadjuvant treatment

In [None]:
np.unique(merged_df.response_neoadjuv_therapy, return_counts=True)

In [None]:
np.unique(merged_df.immuno_therapy_cd, return_counts=True)

In [None]:
filt_stages_treatment = (merged_df.stage > 0) &\
( (merged_df.response_neoadjuv_therapy == 0 ) | ( np.isnan(merged_df.response_neoadjuv_therapy)) ) &\
 merged_df.immuno_therapy_cd.isin([0, 82, 85, 86, 87]) 
filt_stages_treatment.sum()

In [None]:
merged_df = merged_df[filt_stages_treatment]
merged_df.reset_index(inplace=True, drop=True)

### For rows with a missing (NaN) neoadjuvant response, exclude those treated before the biopsy (using treatment dates, or summary codes when dates are missing)

In [None]:
merged_df.response_neoadjuv_therapy.isna()

In [None]:
merged_neo_nan_df = merged_df[ merged_df.response_neoadjuv_therapy.isna() ]
merged_neo_nan_df.shape

In [None]:
biopsy_date = merged_neo_nan_df.biopsy_dt.values.astype(str)
biopsy_date[ biopsy_date == 'nan' ] = '0100-01-01' # set to this to mark NaNs with 0
biopsy_date = np.array( [ datetime.timestamp( datetime.strptime(b, "%Y-%m-%d") ) for b in biopsy_date ] ) # shift back
biopsy_date

In [None]:
radiation_start_dt_date = merged_neo_nan_df.radiation_start_dt.values.astype(str)
radiation_start_dt_date[ radiation_start_dt_date == 'nan' ] = '0100-01-01' # set to this to mark NaNs with 0
radiation_start_dt_date = np.array( [ datetime.timestamp( datetime.strptime(b, "%Y-%m-%d") ) for b in radiation_start_dt_date ] ) # shift back
radiation_start_dt_date

In [None]:
plt.figure(figsize=(2,2))
plt.hist(radiation_start_dt_date)

In [None]:
rx_chemo_dt_date = merged_neo_nan_df.rx_chemo_dt.values.astype(str)
rx_chemo_dt_date[ rx_chemo_dt_date == 'nan' ] = '0100-01-01' # set to this to mark NaNs with 0
rx_chemo_dt_date = np.array( [ datetime.timestamp( datetime.strptime(b, "%Y-%m-%d") ) for b in rx_chemo_dt_date ] ) # shift back
rx_chemo_dt_date

In [None]:
plt.figure(figsize=(2,2))
plt.hist(rx_chemo_dt_date)

In [None]:
rx_hormone_dt_date = merged_neo_nan_df.rx_hormone_dt.values.astype(str)
rx_hormone_dt_date[ rx_hormone_dt_date == 'nan' ] = '0100-01-01' # set to this to mark NaNs with 0
rx_hormone_dt_date = np.array( [ datetime.timestamp( datetime.strptime(b, "%Y-%m-%d") ) for b in rx_hormone_dt_date ] ) # shift back
rx_hormone_dt_date

In [None]:
plt.figure(figsize=(2,2))
plt.hist(rx_hormone_dt_date)

#### Entries with all three treatment dates present

In [None]:
filt_all_treament_date = (radiation_start_dt_date > 0) & ( rx_chemo_dt_date > 0) & (rx_hormone_dt_date > 0) 
filt_all_treament_date[:5], filt_all_treament_date.sum()

In [None]:
filt_all_treatment_date_after_biopsy_date = (biopsy_date[filt_all_treament_date] < radiation_start_dt_date[filt_all_treament_date] )&\
        (biopsy_date[filt_all_treament_date] < rx_chemo_dt_date[filt_all_treament_date] )&\
        (biopsy_date[filt_all_treament_date] < rx_hormone_dt_date[filt_all_treament_date] )
filt_all_treatment_date_after_biopsy_date.sum()

#### Entries missing at least one treatment date

In [None]:
filt_any_treatment_date_nan = (radiation_start_dt_date < 0) | ( rx_chemo_dt_date < 0) | (rx_hormone_dt_date < 0) 
filt_any_treatment_date_nan[:5], filt_any_treatment_date_nan.sum()

In [None]:
filt_any_treatment_nan_with_code = ( merged_neo_nan_df[filt_any_treatment_date_nan].radiation_summ_cd.isin([0,7]) ) &\
                                    ( merged_neo_nan_df[filt_any_treatment_date_nan].chemo_summ_cd.isin([0,7]) ) &\
                                    ( merged_neo_nan_df[filt_any_treatment_date_nan].hormone_summ_cd.isin([0,7]) )
filt_any_treatment_nan_with_code.sum()

### Build new dataframe

In [None]:
merged_no_neo_df = merged_df[ ~merged_df.response_neoadjuv_therapy.isna() ]
merged_no_neo_df.shape

In [None]:
np.unique(merged_no_neo_df.radiation_summ_cd, return_counts=True)

In [None]:
merged_nan_neo_date_filtered = merged_neo_nan_df[filt_all_treament_date][filt_all_treatment_date_after_biopsy_date]
merged_nan_neo_date_filtered.shape

In [None]:
merged_nan_neo_code_filtered = merged_neo_nan_df[filt_any_treatment_date_nan][filt_any_treatment_nan_with_code]
merged_nan_neo_code_filtered.shape

In [None]:
merged_df = pd.concat( (merged_no_neo_df, merged_nan_neo_date_filtered, merged_nan_neo_code_filtered ) )
merged_df.sort_values(by='patient_ngsci_id', inplace=True)
merged_df.reset_index(inplace=True, drop=True)
merged_df.shape

In [None]:
np.unique(merged_df.radiation_summ_cd, return_counts=True)

In [None]:
np.unique(merged_df.hormone_summ_cd, return_counts=True)

In [None]:
np.unique(merged_df.stage, return_counts=True)

### Select one representative row per patient (the patient's remaining rows are kept separately)

In [None]:
patients_unique, patients_count = np.unique( merged_df.patient_ngsci_id.values, return_counts=True )
patients_unique.shape, patients_count.shape

In [None]:
patients_to_num = dict( zip( patients_unique, np.arange(patients_unique.shape[0]) ) )

In [None]:
merged_df['patient_num_ngsci'] = np.array( [ patients_to_num[p] for p in merged_df.patient_ngsci_id.values ] )

In [None]:
( merged_df.groupby('patient_num_ngsci')['stage'].nunique() > 1 ).sum(), merged_df.groupby('patient_num_ngsci')['stage'].nunique().sum()
## conclusion: there is no multiple cases with different stage !

In [None]:
merged_df.groupby('patient_num_ngsci')['biopsy_dt']

In [None]:
patient_num_for_multi_biopsies = (merged_df.groupby('patient_num_ngsci')['biopsy_dt'].nunique() > 1)
patient_num_filter = np.arange(patients_unique.shape[0])[ patient_num_for_multi_biopsies ]
patient_num_for_multi_biopsies.sum(), patient_num_filter

In [None]:
patiens_multi_biopsies = merged_df[ np.in1d( merged_df['patient_num_ngsci'], patient_num_filter ) ] 
patiens_multi_biopsies.shape#head()

In [None]:
np.unique( patiens_multi_biopsies.patient_num_ngsci )

In [None]:
patients_grouped_biopsies = patiens_multi_biopsies.groupby('patient_num_ngsci').agg({'biopsy_dt': list}).reset_index()
patients_grouped_biopsies.head()

In [None]:
biopsy_single_filt = np.diff( np.append( merged_df.patient_num_ngsci.values, patients_unique.shape[0] ) ) == 1
biopsy_single_filt.shape

In [None]:
plt.plot(merged_df.patient_num_ngsci)

In [None]:
merged_df_single = merged_df[biopsy_single_filt]
merged_df_multiple = merged_df[~biopsy_single_filt]
merged_df_single.shape, merged_df_multiple.shape

In [None]:
merged_df.to_csv('merged_df_latest.csv')

In [None]:
merged_df_single.reset_index(inplace=True, drop=True)
#merged_df_multiple.reset_index(inplace=True, drop=True)

### Check which variable to partition first

In [None]:
race_uqs, race_counts = np.unique( merged_df_single.race, return_counts=True )
race_uqs, race_counts

In [None]:
np.unique( merged_df_single.stage.values, return_counts=True )

In [None]:
fold_0 = []
fold_1 = []
fold_2 = []
fold_3 = []
fold_4 = [] 
fold_5 = []

In [None]:
idxs_out = []
for r in range(3,5): # go along stage III and IV values
    current_subsample_df_indices = merged_df_single[ merged_df_single.stage == r ].index.values
    print( np.unique( merged_df_single[ merged_df_single.stage == r ].stage.values ) )
    current_subsample_indices_from_zero = np.arange(current_subsample_df_indices.shape[0])
    print(current_subsample_df_indices)
    np.random.seed(RND_SEED)
    rnd_idx = np.random.permutation( current_subsample_indices_from_zero.shape[0] )
    idxs_out.append( current_subsample_df_indices[rnd_idx] )
    
    for f in range(6): # 6 folds
        current_rnd_idx = current_subsample_df_indices[ rnd_idx[f::6] ]
        eval(f'fold_{f}').append( current_rnd_idx )
        
        #print(current_rnd_idx)
idxs_out = np.sort( np.concatenate(idxs_out) )

In [None]:
idxs_rest_filt = ~np.in1d( merged_df_single.index, idxs_out )
np.unique( merged_df_single[ idxs_rest_filt ].race, return_counts=True ), np.unique( merged_df_single[ idxs_rest_filt ].stage.values, return_counts=True )

In [None]:
merged_df_single_rest = merged_df_single[ idxs_rest_filt ]
#merged_df_single_rest.reset_index(inplace=True, drop=True)
merged_df_single_rest.head(2)

#### Now distribute only the race-minority groups


In [None]:
idxs_out = []
for r in race_uqs[1:]: # drop group 1 in this selection
    current_subsample_df_indices = merged_df_single_rest[ merged_df_single_rest.race == r ].index.values
    print( np.unique( merged_df_single_rest[ merged_df_single_rest.race == r ].race.values ) )
    current_subsample_indices_from_zero = np.arange(current_subsample_df_indices.shape[0])
    print(current_subsample_df_indices)
    np.random.seed(RND_SEED)
    rnd_idx = np.random.permutation( current_subsample_indices_from_zero.shape[0] )
    idxs_out.append( current_subsample_df_indices[rnd_idx] )
    
    
    for f in range(6): # 6 folds
        current_rnd_idx = current_subsample_df_indices[ rnd_idx[f::6] ]
        eval(f'fold_{f}').append( current_rnd_idx )
        
        #print(current_rnd_idx)
idxs_out = np.sort( np.concatenate(idxs_out) )

In [None]:
idxs_rest_filt = ~np.in1d( merged_df_single_rest.index, idxs_out )
np.unique( merged_df_single_rest[ idxs_rest_filt ].race, return_counts=True ), np.unique( merged_df_single_rest[ idxs_rest_filt ].stage.values, return_counts=True )

In [None]:
merged_df_single_rest = merged_df_single_rest[ idxs_rest_filt ]
#merged_df_single_rest.reset_index(inplace=True, drop=True)
merged_df_single_rest.head(2)

#### mortality

In [None]:
idxs_out = []
current_subsample_df_indices = merged_df_single_rest[ merged_df_single_rest.mortality == 1 ].index.values
print( np.unique( merged_df_single_rest[ merged_df_single_rest.mortality == 1 ].mortality.values ) )
current_subsample_indices_from_zero = np.arange(current_subsample_df_indices.shape[0])
print(current_subsample_df_indices)
np.random.seed(RND_SEED)
rnd_idx = np.random.permutation( current_subsample_indices_from_zero.shape[0] )
idxs_out.append( current_subsample_df_indices )

for f in range(6): # 6 folds
    current_rnd_idx = current_subsample_df_indices[ rnd_idx[f::6] ]
    eval(f'fold_{f}').append( current_rnd_idx )

    #print(current_rnd_idx)
idxs_out = np.sort( np.concatenate(idxs_out) )

In [None]:
idxs_rest_filt = ~np.in1d( merged_df_single_rest.index, idxs_out )
np.unique( merged_df_single_rest[ idxs_rest_filt ].mortality, return_counts=True )

In [None]:
merged_df_single_rest = merged_df_single_rest[ idxs_rest_filt ]
#merged_df_single_rest.reset_index(inplace=True, drop=True)
merged_df_single_rest.head(2)

In [None]:
merged_df_single_rest.shape

#### age between 50-60

In [None]:
idxs_out = []
#current_subsample_df_indices = merged_df_single_rest[ merged_df_single_rest.tobacco == 1 ].index.values

current_subsample_df_indices = merged_df_single_rest[ np.logical_and(merged_df_single_rest.age.values > 49, 
                                      merged_df_single_rest.age.values < 61) ].index.values

print( np.unique( merged_df_single_rest[ np.logical_and(merged_df_single_rest.age.values > 49, 
                                      merged_df_single_rest.age.values < 61) ].age.values ) )
current_subsample_indices_from_zero = np.arange(current_subsample_df_indices.shape[0])
print(current_subsample_df_indices)
np.random.seed(RND_SEED)
rnd_idx = np.random.permutation( current_subsample_indices_from_zero.shape[0] )
idxs_out.append( current_subsample_df_indices )

for f in range(6): # 6 folds
    current_rnd_idx = current_subsample_df_indices[ rnd_idx[f::6] ]
    eval(f'fold_{f}').append( current_rnd_idx )

    #print(current_rnd_idx)
idxs_out = np.sort( np.concatenate(idxs_out) )

In [None]:
idxs_rest_filt = ~np.in1d( merged_df_single_rest.index, idxs_out )
np.unique( merged_df_single_rest[ idxs_rest_filt ].age.values, return_counts=True )

In [None]:
merged_df_single_rest = merged_df_single_rest[ idxs_rest_filt ]
#merged_df_single_rest.reset_index(inplace=True, drop=True)
merged_df_single_rest.head(2)

In [None]:
np.unique( merged_df_single_rest.stage, return_counts=True )

In [None]:
merged_df_single_rest.shape

In [None]:
idxs_out = []
for r in range(5): # go along all stage values
    current_subsample_df_indices = merged_df_single_rest[ merged_df_single_rest.stage == r ].index.values
    print( np.unique( merged_df_single_rest[ merged_df_single_rest.stage == r ].stage.values ) )
    current_subsample_indices_from_zero = np.arange(current_subsample_df_indices.shape[0])
    print(current_subsample_df_indices)
    np.random.seed(RND_SEED)
    rnd_idx = np.random.permutation( current_subsample_indices_from_zero.shape[0] )
    idxs_out.append( current_subsample_df_indices[rnd_idx] )
    
    
    for f in range(6): # 6 folds
        current_rnd_idx = current_subsample_df_indices[ rnd_idx[f::6] ]
        eval(f'fold_{f}').append( current_rnd_idx )
        
        #print(current_rnd_idx)
idxs_out = np.sort( np.concatenate(idxs_out) )

In [None]:
idxs_rest_filt = ~np.in1d( merged_df_single_rest.index, idxs_out )
np.unique( merged_df_single_rest[ idxs_rest_filt ].tobacco, return_counts=True )

In [None]:
merged_df_single_rest = merged_df_single_rest[ idxs_rest_filt ]
#merged_df_single_rest.reset_index(inplace=True, drop=True)
merged_df_single_rest.head(2)

### Check if indices for all folds are correct

In [None]:
current_fold_uq_all = []
for i in range(6):
    current_fold_uq, current_fold_c = np.unique( np.concatenate( eval(f'fold_{i}')), return_counts=True )
    print( f'Num of duplicate indices in fold {i}: ', ( current_fold_c > 1).sum() )
    current_fold_uq_all.append(current_fold_uq)

In [None]:
np.concatenate( current_fold_uq_all ).shape # this gives back the single patient dataframe - done

In [None]:
#local_test_indices = np.concatenate( fold_5 ) # THIS will be the local test fold

### Save local test fold

### Add each patient's remaining rows to that patient's fold (folds 0–4: training/validation, fold 5: test)

In [None]:
fold_0 = np.concatenate(fold_0)
fold_1 = np.concatenate(fold_1)
fold_2 = np.concatenate(fold_2)
fold_3 = np.concatenate(fold_3)
fold_4 = np.concatenate(fold_4)
fold_5 = np.concatenate(fold_5)
fold_0.shape, fold_1.shape, fold_2.shape, fold_3.shape, fold_4.shape, fold_5.shape

In [None]:
fold_0.shape[0] + fold_1.shape[0] + fold_2.shape[0] + fold_3.shape[0] + fold_4.shape[0] + fold_5.shape[0]

In [None]:
np.unique( merged_df_multiple.stage.values, return_counts=True )

In [None]:
multi_df_to_merge = merged_df_multiple[['biopsy_id', 'stage', 'metastatic_cancer', 'patient_num_ngsci']]
multi_df_to_merge.head()

In [None]:
multi_df_to_merge.shape

In [None]:
#os.makedirs('cv_splits_multi_stratified/', exist_ok=True) # not needed due to sklearn laters
fold_lists = (fold_0, fold_1, fold_2, fold_3, fold_4, fold_5)
split_columns = ['biopsy_id', 'stage', 'metastatic_cancer', 'patient_num_ngsci']
folds_multi_stratified_list = []
for f, fold in enumerate(fold_lists): # fold 5 is the test fold
    # `fold` holds index labels of merged_df_single. That frame has a 0..n-1 index, so
    # the labels double as positions for .iloc.
    current_fold_final_df = merged_df_single.iloc[fold][split_columns]
    print(current_fold_final_df.shape)

    # Attach every extra row of this fold's patients. Joining on the key column alone
    # avoids suffixed *_x / *_y copies of the other columns, which would otherwise end
    # up as extra columns after the concat below.
    to_merge_df = current_fold_final_df[['patient_num_ngsci']].merge( multi_df_to_merge, on='patient_num_ngsci' )
    print(np.unique(to_merge_df.stage.values, return_counts=True))

    current_fold_final_df_merged = pd.concat( [current_fold_final_df, to_merge_df], axis=0 ).drop('patient_num_ngsci', axis=1)
    print(current_fold_final_df_merged.shape, np.unique(current_fold_final_df_merged.stage.values, return_counts=True))

    if f < 5: # keep the training/validation folds in memory
        folds_multi_stratified_list.append( current_fold_final_df_merged )
    # The test fold (f == 5) is not stored; writing it to
    # cv_splits_multi_stratified/test_split_multi_stratified.csv is not needed due to sklearn later.

# Build the 1-D object array element by element. np.array(list_of_dataframes, dtype=object)
# descends into the frames: if all folds happen to have the same shape, it returns a 3-D
# array of cell values instead of an array of DataFrames.
folds_multi_stratified_all = np.empty(len(folds_multi_stratified_list), dtype=object)
for j, fold_df in enumerate(folds_multi_stratified_list):
    folds_multi_stratified_all[j] = fold_df

### Assemble the training and validation splits for each of the 5 CV folds (saving to CSV is disabled)

In [None]:
# For CV round i, fold 4-i is the validation fold; the other four folds are concatenated
# into the training set. The CSV export of each split is disabled (not needed due to
# sklearn later).
for i in range(5):
    print(f'CURRENT FOLD: {i}')
    validation_idx = 4 - i
    training_idx = np.delete(np.arange(5), validation_idx)
    print(training_idx, validation_idx)

    current_training_df = pd.concat([folds_multi_stratified_all[j] for j in training_idx])
    current_validation_df = folds_multi_stratified_all[validation_idx]
    print( current_training_df.head(1) )
    print( current_validation_df.head(1), '\n\n' )