In [None]:
import sys
import os
import logging
import pandas as pd
import datasets
from pprint import pprint
# Workspace root = the part of the current directory path that precedes the
# marker folder name. When the marker is absent, split() returns the path
# unchanged, so the chdir below is a no-op.
KEY = '2-NOTEBOOK'
WORKSPACE_PATH = os.getcwd().split(KEY)[0]
print(WORKSPACE_PATH)
os.chdir(WORKSPACE_PATH)

# Notebook logger plus root logging configuration at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s:%(asctime)s:(%(filename)s@%(lineno)d %(name)s)]: %(message)s',
)

# Project locations, relative to the workspace root.
SPACE = dict(
    DATA_RAW='_Data/0-Data_Raw',
    DATA_RFT='_Data/1-Data_RFT',
    DATA_CASE='_Data/2-Data_CASE',
    DATA_AIDATA='_Data/3-Data_AIDATA',
    DATA_EXTERNAL='code/external',
    CODE_FN='code/pipeline',
    MODEL_ROOT='./_Model',
)

# Fail early if the pipeline code directory is missing, then make it importable.
assert os.path.exists(SPACE['CODE_FN']), f'{SPACE["CODE_FN"]} not found'
print(SPACE['CODE_FN'])
sys.path.append(SPACE['CODE_FN'])

In [None]:
# 2021-12-03:00:00 before 24h. 
# 2021-12-03:00:00

In [None]:
# Read the full case table from the split directory.
DATA_SPLIT = '_Data/4-Data_Split'
path = os.path.join(DATA_SPLIT, 'WellDoc_full.parquet')
df_case_all = pd.read_parquet(path)

# Keep only cases whose ObsDT has a minute component of 0.
idx = df_case_all['ObsDT'].dt.minute.eq(0)
df_case_hours = df_case_all.loc[idx].reset_index(drop=True)

df_case = df_case_hours

# Minimum CGM record counts required in each window around a case.
# NOTE(review): 289 / 24 / 72 presumably correspond to full coverage at a
# 5-minute sampling interval (24h + 1, 2h, 6h) - confirm.
idx1 = df_case['CGMInfoBf24h-Num-RecNum'].ge(289)
idx2 = df_case['CGMInfoAf2h-Num-RecNum'].ge(24)
idx3 = df_case['CGMInfoAf2to8h-Num-RecNum'].ge(12 * 6)
df_case_good = df_case.loc[idx1 & idx2 & idx3].reset_index(drop=True)

In [None]:
# Tally how many rows of df_case_good fall on each (PID, Date) pair,
# ordered by PID and then Date.
df_patday_all = (
    df_case_good
    .value_counts(subset=['PID', 'Date'])
    .sort_index()
    .reset_index()
)

In [None]:
import pandas as pd
# Split directory; re-declared here with the same value used when loading the case table.
DATA_SPLIT = '_Data/4-Data_Split'
path = os.path.join(DATA_SPLIT, 'WellDoc_patient_split_info.parquet')
# Per-patient table; later cells read its n_early_days / n_middle_days /
# n_late_days, split, stratum and first-date columns.
df_patient_info = pd.read_parquet(path)
df_patient_info

In [None]:
# Eligibility rules: a minimum number of days in every time bin ...
idx1 = df_patient_info['n_early_days'].ge(15)
idx2 = df_patient_info['n_middle_days'].ge(2)
idx3 = df_patient_info['n_late_days'].ge(3)

# ... and a split that has actually been assigned.
idx4 = df_patient_info['split'].ne('unassigned')

# Report the table size before and after filtering.
print(df_patient_info.shape)
df_patient_selected = df_patient_info.loc[idx1 & idx2 & idx3 & idx4].reset_index(drop=True)
print(df_patient_selected.shape)
df_patient_selected

In [None]:
# Show (rows, columns) of the quality-filtered hourly case table.
df_case_good.shape

# df_case_good

In [None]:
# List the columns available on the eligible-patient table.
df_patient_selected.columns

In [None]:
# Per-patient lookup: split, stratum and the first dates of the middle and late
# periods, with both date columns parsed to datetime.
df_patient_stratum = (
    df_patient_selected[['PID', 'split', 'stratum', 'middle_first_date', 'late_first_date']]
    .reset_index(drop=True)
    .assign(
        middle_first_date=lambda frame: pd.to_datetime(frame['middle_first_date']),
        late_first_date=lambda frame: pd.to_datetime(frame['late_first_date']),
    )
)

In [None]:
# Number of qualifying hourly cases per (PID, Date). The counts column is named
# explicitly so that it is 'count' on every pandas version: before pandas 2.0
# value_counts() returned an unnamed Series and reset_index() labelled it 0,
# which made the 'count' lookup below fail with a KeyError.
df_patday_all = (
    df_case_good[['PID', 'Date']]
    .value_counts()
    .sort_index()
    .reset_index(name='count')
)

# Keep patient-days that have at least 24 qualifying cases.
idx = df_patday_all['count'] >= 24
print(df_patday_all.shape)
df_patday_good = df_patday_all[idx].reset_index(drop=True)
print(df_patday_good.shape)

# Inner join on PID: patient-days of patients that are not in
# df_patient_stratum are dropped here.
df_patday_good = pd.merge(df_patday_good, df_patient_stratum, on='PID')
print(df_patday_good.shape)

# Parse Date so it can be compared with the datetime boundary columns.
df_patday_good['Date'] = pd.to_datetime(df_patday_good['Date'])
df_patday_good

In [None]:
def get_early_middle_late_label(row):
    """Classify one patient-day as 'early', 'middle' or 'late'.

    Parameters
    ----------
    row : mapping
        Must provide 'Date', 'middle_first_date' and 'late_first_date'.

    Returns
    -------
    str
        'middle' when middle_first_date <= Date < late_first_date,
        'late' when Date >= late_first_date, otherwise 'early'
        (which also covers rows where the comparisons are all False).
    """
    day = row['Date']
    middle_start = row['middle_first_date']
    late_start = row['late_first_date']

    in_middle_window = day >= middle_start and day < late_start
    if in_middle_window:
        return 'middle'
    if day >= late_start:
        return 'late'
    return 'early'

# Label every patient-day row (row-wise apply) and store the result in a new
# 'time_bin' column, then show how many rows landed in each bin.
df_patday_good['time_bin'] = df_patday_good.apply(get_early_middle_late_label, axis=1)
df_patday_good['time_bin'].value_counts().sort_index()

# df_patday_good['label'].value_counts().sort_index()


In [None]:
# Alias (not a copy): both names refer to the same labelled patient-day frame.
df_with_stratum = df_patday_good

In [None]:
# Distinct strata, in order of first appearance.
stratum_list = df_with_stratum['stratum'].unique()

# Per stratum: number of distinct patients and number of patient-day rows.
stratum_stats = (
    df_with_stratum
    .groupby('stratum')['PID']
    .agg(n_patients='nunique', n_days='count')
    .reset_index()
)

stratum_stats

In [None]:
# A stratum is flagged when it has fewer than 13 distinct patients or fewer
# than 400 patient-day rows.
insufficient_strata = stratum_stats.loc[
    stratum_stats['n_patients'].lt(13) | stratum_stats['n_days'].lt(400)
]

# Stop the notebook here if any stratum is flagged.
if not insufficient_strata.empty:
    print("WARNING: The following strata have insufficient data:")
    print(insufficient_strata)
    raise ValueError("Some strata have insufficient patients or days")


In [None]:
# Display the labelled patient-day table that feeds the sampling loop.
df_with_stratum

In [None]:
# Number of eligible patients for every (stratum, split) combination.
df_patient_selected[['stratum', 'split']].value_counts().sort_index().reset_index()

In [None]:
#  ------------ Process each stratum ------------

# Sampling plan per split:
#   'total'                   - patients to pick per stratum for this split
#   'early'/'middle'/'late'   - days taken from each time bin (train / val)
#   'recent'                  - most recent days taken per patient (test only)
mini_set_day_info = {
    'train': {'total': 10, 'early': 15, 'middle': 2, 'late': 3},
    'val':   {'total': 2,  'early': 15, 'middle': 2, 'late': 3},
    'test':  {'total': 3,  'early': 0, 'middle': 0, 'late': 0, 'recent': 10}
}

output_columns = ['PID', 'Date', 'stratum', 'split_timebin']

# One frame per stratum, concatenated once after the loop. Growing final_df
# with pd.concat inside the loop is quadratic and starts from an empty
# DataFrame, which recent pandas versions warn about.
stratum_frames = []

for stratum in stratum_list:
    df_one_group = df_with_stratum[df_with_stratum['stratum'] == stratum].copy()

    # Selected patient-day rows for this stratum.
    final_rows = []

    # Group on the part of the split label that precedes the first '-'.
    split_group = df_one_group.groupby(df_one_group['split'].str.split('-').str[0])

    for split_name, group_df in split_group:
        if split_name not in mini_set_day_info:
            continue
        split_info = mini_set_day_info[split_name]

        # Visit patients in a reproducible random order (fixed seed).
        pids = group_df['PID'].unique()
        shuffled_pids = pd.Series(pids).sample(frac=1, random_state=42).tolist()

        selected_count = 0
        for pid in shuffled_pids:
            patient_df = group_df[group_df['PID'] == pid].sort_values('Date')
            counts = patient_df['time_bin'].value_counts()

            # --- Train & Val logic ---
            # The patient must have enough days in every time bin; take the last
            # early days and the first middle / late days.
            if split_name in ['train', 'val']:
                if (counts.get('early', 0) >= split_info['early'] and
                    counts.get('middle', 0) >= split_info['middle'] and
                    counts.get('late', 0) >= split_info['late']):

                    subset = pd.concat([
                        patient_df[patient_df['time_bin'] == 'early'].tail(split_info['early']),
                        patient_df[patient_df['time_bin'] == 'middle'].head(split_info['middle']),
                        patient_df[patient_df['time_bin'] == 'late'].head(split_info['late']),
                    ])
                    subset['split_timebin'] = f"{split_name}-" + subset['time_bin']
                    final_rows.append(subset[output_columns])
                    selected_count += 1

            # --- Test logic ---
            # Take the most recent days regardless of time bin.
            elif split_name == 'test':
                if len(patient_df) >= split_info['recent']:
                    subset = patient_df.tail(split_info['recent']).copy()
                    subset['split_timebin'] = f"{split_name}-recent"
                    final_rows.append(subset[output_columns])
                    selected_count += 1

            if selected_count == split_info['total']:
                break

        # Surface under-filled splits instead of silently returning fewer patients.
        if selected_count < split_info['total']:
            logger.warning(
                "Stratum %s, split %s: selected %d of %d requested patients",
                stratum, split_name, selected_count, split_info['total'],
            )

    # pd.concat([]) fails with an opaque "No objects to concatenate"; raise the
    # same exception type with a message that names the stratum.
    if not final_rows:
        raise ValueError(f"No patient in stratum {stratum!r} met the selection criteria")

    stratum_frames.append(pd.concat(final_rows))

# Final result: a single concat with a fresh 0..n-1 index (no duplicate labels).
if stratum_frames:
    final_df = pd.concat(stratum_frames, ignore_index=True)
else:
    final_df = pd.DataFrame(columns=output_columns)


final_df

In [None]:
# Number of selected patient-day rows per split/time-bin label.
final_df[['split_timebin']].value_counts().sort_index().reset_index()

In [None]:
# Sanity check: selected patient-day rows divided by 12.
# NOTE(review): 12 is presumably the number of strata - confirm.
len(final_df) / 12 

In [None]:
# Display the selected patient-day rows.
final_df

In [None]:
# NOTE(review): scratch arithmetic; presumably 15 patients per stratum
# (10 train + 2 val + 3 test in mini_set_day_info) times 12 strata - confirm.
15 * 12

In [None]:
# Preview the first rows of the quality-filtered case table.
df_case_good.head()

In [None]:
# Parse the join key to datetime, as was done for the patient-day table that
# final_df was sampled from.
df_case_good['Date'] = pd.to_datetime(df_case_good['Date'])

# Inner join: keep only the cases that fall on a selected (PID, Date).
df_case_fairglucose = df_case_good.merge(final_df, on=['PID', 'Date'])

# Average number of cases per selected patient-day row.
df_case_fairglucose.shape[0] / final_df.shape[0]

In [None]:
# Display the cases that belong to the selected patient-days.
df_case_fairglucose

In [None]:
# Output file inside the split directory (DATA_SPLIT is set in an earlier cell).
path = os.path.join(DATA_SPLIT, 'WellDoc_ds_case_fairglucose_split.parquet')


# Write the selected cases to parquet; an existing file at this path is replaced.
df_case_fairglucose.to_parquet(path)

In [None]:
# Cases per (stratum, split_timebin). The counts column is named explicitly so
# it is 'count' on every pandas version (before 2.0, reset_index() labelled it 0
# and the pivot below failed on values='count').
df_count = (
    df_case_fairglucose[['stratum', 'split_timebin']]
    .value_counts()
    .sort_index()
    .reset_index(name='count')
)

# One row per stratum, one column per split/time-bin label; absent
# combinations become 0.
df_table = df_count.pivot(index='stratum', columns='split_timebin', values='count').fillna(0)


cols = ['train-early', 'train-middle', 'train-late', 'val-early', 'val-middle', 'val-late', 'test-recent']

# reindex instead of df_table[cols]: a label with no selected cases at all is
# shown as a zero column rather than raising a KeyError.
df_table = df_table.reindex(columns=cols, fill_value=0)
df_table