# Space

In [None]:
import sys
import os 
import logging
import pandas as pd
from pprint import pprint 
from IPython.display import display, HTML

# Workspace root = the part of the current working directory that precedes
# the KEY marker (the whole cwd when the marker does not occur in it).
KEY = '1-WORKSPACE'
WORKSPACE_PATH = os.getcwd().partition(KEY)[0]
print(WORKSPACE_PATH)
os.chdir(WORKSPACE_PATH)

logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s:%(asctime)s:(%(filename)s@%(lineno)d %(name)s)]: %(message)s',
)

# Folder layout, expressed relative to the workspace root we just chdir'ed into.
SPACE = {
    'DATA_RAW': '_Data/0-Data_Raw',
    'DATA_RFT': '_Data/1-Data_RFT',
    'DATA_CASE': '_Data/2-Data_CASE',
    'DATA_AIDATA': '_Data/3-Data_AIDATA',
    'DATA_EXTERNAL': 'code/external',
    'CODE_FN': 'code/pipeline',
}

# Fail early when the pipeline code folder is missing, then make it importable.
assert os.path.exists(SPACE['CODE_FN']), f'{SPACE["CODE_FN"]} not found'
print(SPACE['CODE_FN'])
sys.path.append(SPACE['CODE_FN'])

# CF Data

In [None]:
from recfldtkn.aidata_base.entry import EntryAIData_Builder
import datasets

# Names identifying the AI dataset, the case-feature (CF) dataset and the cohorts.
OneAIDataName = 'DietEventBench'
CF_DataName = 'DietEvent-CGM5MinEntry-1ea9d787eef20fb7'
CohortName_list = ['WellDoc2022CGM']

# One "<cohort>/<CF dataset name>" entry per cohort.
CF_DataName_list = [f'{cohort}/{CF_DataName}' for cohort in CohortName_list]

entry = EntryAIData_Builder(SPACE=SPACE)

# The config is read from dataset.info.config_name; later cells index it
# like a dict (data_config['CF_to_CFvocab']).
dataset = entry.merge_one_cf_dataset(CF_DataName_list)
data_config = dataset.info.config_name
print('total', dataset)

In [None]:
# Vocabulary for the hour-minute CF: one 'HH:MM' token per 5-minute step of a day.
CFName = 'HM5MinStep'
interval_delta = pd.Timedelta(minutes=5)

# 24 * 12 five-minute steps; the calendar date is only a carrier, just the
# hour and minute of each step end up in the token.
idx2tkn = [
    f'{step_dt.hour:02d}:{step_dt.minute:02d}'
    for step_dt in (
        pd.Timestamp('2022-01-01 00:00:00') + interval_delta * step
        for step in range(24 * 12)
    )
]
tkn2idx = dict(zip(idx2tkn, range(len(idx2tkn))))

# Register the new vocabulary next to the ones shipped with the dataset config
# (this writes into the mapping held by data_config).
CF_to_CFvocab = data_config['CF_to_CFvocab']
CF_to_CFvocab[CFName] = {'idx2tkn': idx2tkn, 'tkn2idx': tkn2idx}

In [None]:
# CF_to_CFvocab

In [None]:
# NOTE (original author): "should be a split here" -- no split is performed in
# this cell; the whole `dataset` is registered under the 'ds_case' key.
Data = {'ds_case': dataset}

# INPUT: Mto1Period_1TknInStep

## Args

In [None]:
# Arguments of the entry-input method. Only the input part is configured here.
OneEntryArgs = {
    'Input_Part': {
        'EntryInputMethod': '1TknInStepWt5MinHM',

        # Case features used as input. The after-window variant would add
        # 'CGMValueAf2h' here together with 'AfterPeriods': ['Af2h'].
        'CF_list': ['CGMValueBf24h'],
        'BeforePeriods': ['Bf24h'],

        'TimeIndex': True,
        # When True, get_INPUT_CFs keeps only the CFs whose name contains
        # one of the BeforePeriods entries.
        'InferenceMode': False,

        'TargetField': 'CGMValue',
        'TargetRange': [40, 400],

        # Hour-minute sequence settings; None is the alternative that
        # builds no hour-minute sequence.
        'HM': {'start': -24, 'unit': 'h', 'interval': '5m'},
    },
}

EntryInputMethod = OneEntryArgs['Input_Part']['EntryInputMethod']

## Function

In [None]:
import itertools 
import inspect


## %%%%%%%%%%%%%%%%%%%%% user functions
def get_INPUT_CFs(OneEntryArgs):
    """Return the case-feature (CF) names that are used as model input.

    Reads 'CF_list', 'InferenceMode' and 'BeforePeriods' from
    OneEntryArgs['Input_Part'].

    Returns:
        The configured CF list itself (same list object) when InferenceMode
        is not True; otherwise a new list holding only the CF names that
        contain at least one of the BeforePeriods strings.

    Raises:
        AssertionError: if 'CF_list' is not exactly a list.
        KeyError: if one of the three keys is missing.
    """
    input_part = OneEntryArgs['Input_Part']
    cf_names = input_part['CF_list']
    assert type(cf_names) == list, f'InputCFs must be a list, but got {type(cf_names)}'

    inference_mode = input_part['InferenceMode']
    before_periods = input_part['BeforePeriods']

    # Outside inference mode every configured CF is an input.
    if not (inference_mode == True):
        return cf_names

    # In inference mode only the "before" windows are available.
    return [
        cf_name
        for cf_name in cf_names
        if any(period in cf_name for period in before_periods)
    ]

# Keep the function's source text on the function object itself;
# presumably read when the function is exported to a .py module -- TODO confirm.
get_INPUT_CFs.fn_string = inspect.getsource(get_INPUT_CFs)

In [None]:
import inspect
import torch 
import numpy as np 

## %%%%%%%%%%%%%%%%%%%%% user functions
import itertools

def tfm_fn_AIInputData(examples, OneEntryArgs, CF_to_CFvocab):
    """Turn a batch of case examples into model input tensors.

    Args:
        examples: batch mapping. For every input CF it must provide the column
            f"{cf}--tid" (assumed to hold one token-id sequence per example).
            When 'HM' is configured it must also provide 'ObsDT' (assumed to
            hold one timestamp per example that supports `+ pd.Timedelta`).
        OneEntryArgs: entry arguments; OneEntryArgs['Input_Part'] supplies
            'TargetRange' ([low, high]) and, optionally, 'HM' with the keys
            'start', 'unit' and 'interval'.
        CF_to_CFvocab: vocabularies; CF_to_CFvocab['HM5MinStep']['tkn2idx']
            maps 'HH:MM' tokens to ids. It is only read when 'HM' is configured.

    Returns:
        dict with
            'input_ids': LongTensor built from, per example, the concatenation
                of all input-CF sequences after clipping values to [low, high].
            'hm_ids': LongTensor with one hour-minute id per position of
                'input_ids'; an empty tensor when 'HM' is None.

    Raises:
        ValueError: if the configured HM interval is not '5m'.
        KeyError: if a required column / vocab entry is missing, or a computed
            'HH:MM' token is absent from the vocab (e.g. a timestamp that does
            not fall on a step of the vocab).
    """
    Input_Part = OneEntryArgs['Input_Part']
    INPUT_CFs = get_INPUT_CFs(OneEntryArgs)          # e.g. ['CGMValueBf24h', ...]
    low, high = Input_Part['TargetRange']            # e.g. [40, 400]

    # One column of sequences per input CF.
    tid_lists = [examples[f"{cf}--tid"] for cf in INPUT_CFs]

    # Per example: clip every CF sequence to [low, high] and concatenate them.
    flat_seqs = []
    for per_cf_seqs in zip(*tid_lists):
        clamped = []
        for seq in per_cf_seqs:
            clamped.extend(np.clip(seq, low, high).tolist())
        flat_seqs.append(clamped)

    input_ids = torch.tensor(flat_seqs, dtype=torch.long)

    HM_seq_list = []
    HM_args = Input_Part.get('HM', None)
    if HM_args is not None:
        HM_interval = HM_args['interval']
        if HM_interval == '5m':
            interval_delta = pd.Timedelta(minutes=5)
        else:
            raise ValueError(f"Not implemented interval: {HM_interval}")

        # The HM vocab and the observation times are only needed on this path,
        # so a vocab without 'HM5MinStep' still works when 'HM' is None.
        tkn2idx = CF_to_CFvocab['HM5MinStep']['tkn2idx']
        start_offset = pd.Timedelta(value=HM_args['start'], unit=HM_args['unit'])

        # Every example gets one HM id per input position (loop-invariant).
        seq_len = len(flat_seqs[0]) if flat_seqs else 0

        for now in examples['ObsDT']:
            HM_start_dt = now + start_offset
            HM_seq_list.append([
                tkn2idx[f'{step_dt.hour:02d}:{step_dt.minute:02d}']
                for step_dt in (HM_start_dt + k * interval_delta for k in range(seq_len))
            ])

    return {
        'input_ids': input_ids,
        'hm_ids': torch.tensor(HM_seq_list, dtype=torch.long),
    }

# Keep the function's source text on the function object itself;
# presumably read when the function is exported to a .py module -- TODO confirm.
tfm_fn_AIInputData.fn_string = inspect.getsource(tfm_fn_AIInputData)

In [None]:
# Smoke-test the transform on the first four rows of the dataset.
CF_to_CFvocab = data_config['CF_to_CFvocab']
examples = dataset[:4]

examples_tfm = tfm_fn_AIInputData(
    examples=examples,
    OneEntryArgs=OneEntryArgs,
    CF_to_CFvocab=CF_to_CFvocab,
)
examples_tfm

In [None]:
import inspect
import numpy as np
from scipy.sparse import csr_matrix, hstack
import itertools

def entry_fn_AIInputData(Data, 
                         CF_to_CFvocab, 
                         OneEntryArgs,
                         tfm_fn_AIInputData = None):
    """Attach the on-the-fly input transform to the case dataset.

    Args:
        Data: dict holding the case data under 'ds_case', either a dataset
            object exposing `with_transform` (e.g. `datasets.Dataset`) or a
            pandas DataFrame, which is converted with
            `datasets.Dataset.from_pandas`.
        CF_to_CFvocab: vocabularies forwarded to the transform.
        OneEntryArgs: entry arguments forwarded to the transform.
        tfm_fn_AIInputData: callable `(examples, OneEntryArgs, CF_to_CFvocab)`
            applied to every accessed batch. Required despite the None default.

    Returns:
        The same `Data` dict, with the transformed view stored under 'ds_tfm'.
        `Data['ds_case']` is left untouched.

    Raises:
        ValueError: if no transform function is given.
    """
    # Fail fast: a None transform would only blow up later, on first access.
    if tfm_fn_AIInputData is None:
        raise ValueError('tfm_fn_AIInputData must be provided')

    def transform_fn(examples):
        return tfm_fn_AIInputData(examples, OneEntryArgs, CF_to_CFvocab)

    ds_case = Data['ds_case']
    if isinstance(ds_case, pd.DataFrame):
        ds_case = datasets.Dataset.from_pandas(ds_case)

    # with_transform returns a new dataset object, so the caller's dataset
    # keeps returning raw rows and this function can safely be re-run.
    Data['ds_tfm'] = ds_case.with_transform(transform_fn)
    return Data

# Keep the function's source text on the function object itself;
# presumably read when the function is exported to a .py module -- TODO confirm.
entry_fn_AIInputData.fn_string = inspect.getsource(entry_fn_AIInputData)

In [None]:
# Keys currently registered in Data (before the entry function runs).
print(list(Data))

# Run the entry function with the vocabularies from the dataset config.
CF_to_CFvocab = data_config['CF_to_CFvocab']
Data = entry_fn_AIInputData(
    Data=Data,
    CF_to_CFvocab=CF_to_CFvocab,
    OneEntryArgs=OneEntryArgs,
    tfm_fn_AIInputData=tfm_fn_AIInputData,
)

ds_tfm = Data['ds_tfm']
ds_tfm

In [None]:
# Index the transformed dataset to inspect the transform output for the first four rows.
batch = ds_tfm[:4]
batch

## Save Entry Fn

In [None]:
from recfldtkn.aidata_base.entry import AIDATA_ENTRYINPUT_PATH
from recfldtkn.base import Base

# Target module: <CODE_FN>/<AIDATA_ENTRYINPUT_PATH>/<EntryInputMethod>.py
pypath = os.path.join(SPACE['CODE_FN'], AIDATA_ENTRYINPUT_PATH, f'{EntryInputMethod}.py')

# Lines passed as `prefix` to the converter -- presumably emitted at the top of
# the generated module so the exported functions find pd / np / torch / datasets.
prefix = [
    'import itertools',
    'import pandas as pd',
    'import numpy as np',
    'import datasets',
    'import torch',
]

# Functions to export, in definition order.
fn_variables = [
    get_INPUT_CFs,
    tfm_fn_AIInputData,
    entry_fn_AIInputData,
]

pycode = Base.convert_variables_to_pystirng(fn_variables = fn_variables, prefix = prefix)

print(pypath)
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(os.path.dirname(pypath), exist_ok=True)
# Write as UTF-8 explicitly: the exported source may contain non-ASCII
# characters, and Python reads source files as UTF-8 by default.
with open(pypath, 'w', encoding='utf-8') as file:
    file.write(pycode)