# Space

In [None]:
import os
import logging
import pandas as pd 
from pprint import pprint 
from IPython.display import display, HTML
# Show every column when a DataFrame is displayed in this notebook.
pd.set_option('display.max_columns', None)

# Resolve the workspace root as the part of the current working directory that
# ends at the first occurrence of KEY, then make it the working directory.
KEY = 'WorkSpace'
_start_dir = os.getcwd()
if KEY not in _start_dir:
    # Without this guard the path below would become `<cwd>WorkSpace` and the
    # problem would only surface as a confusing os.chdir failure.
    raise FileNotFoundError(
        f'Cannot locate the workspace root: {KEY!r} does not appear in the '
        f'current working directory {_start_dir!r}. Start the notebook from '
        f'inside the {KEY} folder.'
    )
WORKSPACE_PATH = _start_dir.split(KEY)[0] + KEY
os.chdir(WORKSPACE_PATH)

# These imports deliberately stay after os.chdir.
# NOTE(review): presumably proj_space is only importable from the workspace root - confirm.
import sys
from proj_space import SPACE

# Re-running this cell must not keep appending the same entry to sys.path.
if SPACE['CODE_FN'] not in sys.path:
    sys.path.append(SPACE['CODE_FN'])
SPACE['WORKSPACE_PATH'] = WORKSPACE_PATH

# Logger for this notebook, plus a root-logger configuration request at INFO level
# using a format that includes level, timestamp, file, line number and logger name.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='[%(levelname)s:%(asctime)s:(%(filename)s@%(lineno)d %(name)s)]: %(message)s')

# Turn off the `datasets` library caching for this session.
from datasets import disable_caching
disable_caching()

# Step 1: AIData

In [None]:
# Oneday: 288, 24pd. 1/12
from datasets import load_from_disk

# Name of the saved AIData dataset to load.
AIDataName = 'CGM_32h_24pd_WellDoc_v2' # v2 6 cohorts. 
# The dataset directory sits under SPACE['DATA_AIDATA'].
path = os.path.join(SPACE['DATA_AIDATA'], AIDataName)
print(path)
# Load the dataset from disk and display its summary as the cell output.
dataset = load_from_disk(path)
dataset

In [None]:
# Container handed to the entry functions below: they read 'ds_case' from it
# and store their result under 'ds_tfm'.
Data = {
    'ds_case': dataset,
}

In [None]:
# The AIData configuration is stored in the `config_name` entry of the dataset info.
config = vars(dataset.info)['config_name']
print(list(config))

# NOTE(review): presumably maps each case-feature (CF) name to its vocabulary - confirm.
CF_to_CFvocab = config['CF_to_CFvocab']
print(list(CF_to_CFvocab))

# Step 2: EntryFn - Input_Part

## Args

In [None]:
# Arguments consumed by the input-side entry functions defined below.
OneEntryArgs = {
    # ----------------- Input Part -----------------
    'Input_Part': {
        'EntryInputMethod': 'Mto1Period_1TknInStep',
        # Each name is looked up in the dataset as the column `<name>--input_ids`.
        'CF_list': [
            'cf.TargetCGM_Bf24H',
            'cf.TargetCGM_Af2H',
        ],
        'TargetField': 'TargetCGM',
        # In inference mode only CFs whose name contains one of these entries are kept.
        'BeforePeriods': ['Bf24H'],
        'AfterPeriods': ['Af2H'],
        'InferenceMode': False, # True, # True, # False, # True, 
    }, 
}

# Also used further down as the file name of the saved entry-function module.
EntryInputMethod = OneEntryArgs['Input_Part']['EntryInputMethod']

## Function Development

In [None]:
import torch 
import datasets
import inspect
import numpy as np
import itertools


## %%%%%%%%%%%%%%%%%%%%% user functions
def get_INPUT_CFs(OneEntryArgs):
    """Return the case-feature (CF) names used to build the model input.

    Reads `CF_list`, `InferenceMode` and `BeforePeriods` from
    `OneEntryArgs['Input_Part']`.

    Outside inference mode, `CF_list` itself is returned (same list object).
    In inference mode a new list is returned that keeps only the names
    containing at least one `BeforePeriods` entry as a substring.

    Raises:
        AssertionError: if `CF_list` is not exactly a `list`.
        KeyError: if one of the keys read above is missing.
    """
    input_args = OneEntryArgs['Input_Part']
    cf_names = input_args['CF_list']
    assert type(cf_names) == list, f'InputCFs must be a list, but got {type(cf_names)}'

    inference_mode = input_args['InferenceMode']
    before_periods = input_args['BeforePeriods']

    # Training / evaluation: every configured CF is part of the input.
    if not (inference_mode == True):
        return cf_names

    # Inference: drop every CF that does not mention a "before" period.
    kept = []
    for cf_name in cf_names:
        hits = [period in cf_name for period in before_periods]
        if any(hits):
            kept.append(cf_name)
    return kept


def tfm_fn_AIInputData(examples, OneEntryArgs, CF_to_CFvocab):
    """Build the `input_ids` tensor for one batch of examples.

    For every CF returned by `get_INPUT_CFs(OneEntryArgs)` the column
    `<CF>--input_ids` is read from `examples`, and the per-example token
    sequences of all CFs are concatenated in CF order.

    Args:
        examples: Batch mapping of column name to per-example values. Each
            `<CF>--input_ids` entry must hold one iterable of token ids per
            example.
        OneEntryArgs: Entry arguments, forwarded to `get_INPUT_CFs`.
        CF_to_CFvocab: Not used by this function; accepted because callers
            pass it.

    Returns:
        dict with the single key `input_ids`, a `torch.LongTensor` built from
        the concatenated rows.
        NOTE(review): assumes every concatenated row has the same length so
        the rows stack into a 2-D array - confirm.

    Raises:
        ValueError: if the selected columns do not all hold the same number
            of examples.
    """
    INPUT_CFs = get_INPUT_CFs(OneEntryArgs)
    columns = [examples[cf + '--input_ids'] for cf in INPUT_CFs]

    # zip() would silently truncate to the shortest column, so check first.
    if len({len(column) for column in columns}) > 1:
        raise ValueError('All input_ids columns must have the same number of examples')

    # One concatenated token list per example, without a pandas round-trip.
    rows = [list(itertools.chain.from_iterable(parts)) for parts in zip(*columns)]
    examples_tfm = {'input_ids': torch.LongTensor(np.array(rows))}
    # `labels` is deliberately not set here; the task-side entry function adds it.
    return examples_tfm


def entry_fn_AIInputData(Data, 
                         CF_to_CFvocab, 
                         OneEntryArgs,
                         tfm_fn_AIInputData = None):
    """Attach the input transform to `Data['ds_case']` and expose it as `Data['ds_tfm']`.

    Args:
        Data: dict holding `ds_case`, either a dataset object that provides
            `set_transform` or a pandas DataFrame (converted with
            `datasets.Dataset.from_pandas`).
        CF_to_CFvocab: Passed through unchanged to the transform function.
        OneEntryArgs: Passed through unchanged to the transform function.
        tfm_fn_AIInputData: Callable `(examples, OneEntryArgs, CF_to_CFvocab)`
            returning the transformed batch. When None, the module-level
            function of the same name is used if one exists.

    Returns:
        The same `Data` dict with `ds_tfm` set to the dataset that carries the
        transform. `set_transform` is called on the dataset object itself, so
        a dataset passed in as `ds_case` is affected as well.

    Raises:
        TypeError: if no callable transform function can be resolved. This is
            raised here instead of on the first dataset access.
    """
    tfm_fn = tfm_fn_AIInputData
    if tfm_fn is None:
        # The parameter shadows the module-level function of the same name,
        # so the fallback has to be looked up through globals().
        tfm_fn = globals().get('tfm_fn_AIInputData')
    if not callable(tfm_fn):
        raise TypeError(
            'entry_fn_AIInputData requires a callable tfm_fn_AIInputData, '
            f'but got {type(tfm_fn)}'
        )

    def transform_fn(examples):
        return tfm_fn(examples, OneEntryArgs, CF_to_CFvocab)

    ds_case = Data['ds_case']
    if isinstance(ds_case, pd.DataFrame):
        ds_case = datasets.Dataset.from_pandas(ds_case)
    ds_case.set_transform(transform_fn)
    Data['ds_tfm'] = ds_case
    return Data


# Keep each function's source text on the function object itself.
# NOTE(review): presumably read by Base.convert_variables_to_pystirng when the module is saved - confirm.
get_INPUT_CFs.fn_string = inspect.getsource(get_INPUT_CFs)
tfm_fn_AIInputData.fn_string = inspect.getsource(tfm_fn_AIInputData)
entry_fn_AIInputData.fn_string = inspect.getsource(entry_fn_AIInputData)

## Examine

In [None]:
# Run the input entry function on the loaded dataset; it stores the result under Data['ds_tfm'].
Data = entry_fn_AIInputData(Data, 
                            CF_to_CFvocab, 
                            OneEntryArgs,
                            tfm_fn_AIInputData)


# Display the dataset that now carries the input transform.
ds_tfm = Data['ds_tfm']
ds_tfm

In [None]:
# Pull the first four examples as a sample batch and display it.
batch = ds_tfm[:4]
batch

In [None]:
# Print the shape of every entry in the sample batch.
for k, v in batch.items():
    print(k, v.shape)

## Save Entry Fn

In [None]:
from recfldtkn.aidata_base.entry import AIDATA_ENTRYINPUT_PATH
from recfldtkn.base import Base

# Target file: <CODE_FN>/<AIDATA_ENTRYINPUT_PATH>/<EntryInputMethod>.py
pypath = os.path.join(SPACE['CODE_FN'],  AIDATA_ENTRYINPUT_PATH, f'{EntryInputMethod}.py')

# Import lines handed to the code generator; they cover every module the
# functions below reference (itertools, pd, np, datasets, torch).
# NOTE(review): presumably emitted ahead of the function sources - confirm.
prefix = [
    'import itertools',
    'import pandas as pd', 
    'import numpy as np', 
    'import datasets',
    'import torch',
    ]

# Functions whose source is written to the module, in this order.
fn_variables = [
    get_INPUT_CFs,
    tfm_fn_AIInputData,
    entry_fn_AIInputData,
]

pycode = Base.convert_variables_to_pystirng(fn_variables = fn_variables, prefix = prefix)

print(pypath)
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(os.path.dirname(pypath), exist_ok=True)
# Python source files are UTF-8 by default, so do not rely on the platform encoding.
with open(pypath, 'w', encoding='utf-8') as file:
    file.write(pycode)

# Step 3: EntryFn - Output_Part 

## Args

In [None]:
# TaskType = 'MLUniLabel'
SeriesName  = 'Bf24.Af2H'

OneTaskName = 'cgm_bf24h_af2h_5min'
# Full entry arguments: the same Input_Part as above plus the output and task parts.
OneEntryArgs = {
    # ----------------- Input Part -----------------
    'Input_Part': {
        'EntryInputMethod': 'Mto1Period_1TknInStep',
        # Each name is looked up in the dataset as the column `<name>--input_ids`.
        'CF_list': [
            'cf.TargetCGM_Bf24H',
            'cf.TargetCGM_Af2H',
        ],
        'TargetField': 'TargetCGM',
        # In inference mode only CFs whose name contains one of these entries are kept.
        'BeforePeriods': ['Bf24H'],
        'AfterPeriods': ['Af2H'],
        'InferenceMode': False, # True, # True, # False, # True, 
    }, 

    # ----------------- Output Part -----------------
    'Output_Part': {
        'EntryOutputMethod': 'NTP',
        # True: attach the transform with set_transform; otherwise the dataset is mapped.
        'set_transform': True,
    },

    # ----------------- Task Part -----------------
    'Task_Part': {
        'Tagging': [],
        'Filtering': [], 
    },
}

# Data = {'df_case': caseset.df_case, 'ds_case': caseset.ds_case}

# Also used further down as the file name of the saved entry-function module.
EntryOutputMethod = OneEntryArgs['Output_Part']['EntryOutputMethod']
print([i for i in CF_to_CFvocab])

## Function

In [None]:
## %%%%%%%%%%%%%%%%%%%%%
# UniLabel
import inspect 
import numpy as np 
# from recfldtkn.loadtools import convert_variables_to_pystirng


def get_OUTPUT_CFs(OneEntryArgs):
    """Return the output-side CF names configured in `OneEntryArgs`.

    The names come from `OneEntryArgs['Output_Part']['CF_list']`. An empty
    list is returned when there is no `Output_Part` section or when that
    section has no `CF_list` entry.
    """
    if 'Output_Part' in OneEntryArgs:
        output_args = OneEntryArgs['Output_Part']
        return output_args.get('CF_list', [])
    return []
# Keep the function's source text on the function object itself.
get_OUTPUT_CFs.fn_string = inspect.getsource(get_OUTPUT_CFs)


def entry_fn_AITaskData(Data, 
                        CF_to_CFvocab, 
                        OneEntryArgs,
                        tfm_fn_AIInputData = None,
                        entry_fn_AIInputData = None,
                        ):
    """Attach the input transform plus `labels` to `Data['ds_case']`.

    The transform calls `tfm_fn_AIInputData` on each batch and then adds
    `labels` as a clone of the resulting `input_ids`.

    Args:
        Data: dict holding `ds_case`, either a dataset object that provides
            `set_transform` / `map` or a pandas DataFrame (converted with
            `datasets.Dataset.from_pandas`).
        CF_to_CFvocab: Passed through unchanged to `tfm_fn_AIInputData`.
        OneEntryArgs: Passed through to `tfm_fn_AIInputData`; in addition
            `OneEntryArgs['Output_Part']['set_transform']` selects between
            `set_transform` (True) and `map` (anything else).
        tfm_fn_AIInputData: Required callable
            `(examples, OneEntryArgs, CF_to_CFvocab)` returning a dict whose
            `input_ids` value supports `.clone()`.
        entry_fn_AIInputData: Not used by this function; accepted because
            callers pass it.

    Returns:
        The same `Data` dict with `ds_tfm` set to the transformed dataset.

    Raises:
        TypeError: if `tfm_fn_AIInputData` is not callable. This is raised
            here instead of on the first dataset access.
    """
    if not callable(tfm_fn_AIInputData):
        raise TypeError(
            'entry_fn_AITaskData requires a callable tfm_fn_AIInputData, '
            f'but got {type(tfm_fn_AIInputData)}'
        )

    def transform_fn(examples):
        examples_tfm = tfm_fn_AIInputData(examples, OneEntryArgs, CF_to_CFvocab)
        # Labels start out as an independent copy of the input ids.
        examples_tfm['labels'] = examples_tfm['input_ids'].clone()
        return examples_tfm

    ds_case = Data['ds_case']
    if isinstance(ds_case, pd.DataFrame):
        ds_case = datasets.Dataset.from_pandas(ds_case)

    ####### You can either set transform, or map it.
    if OneEntryArgs['Output_Part']['set_transform'] == True:
        ds_case.set_transform(transform_fn)
    else:
        # The transform works on batches (column -> list of examples), so map
        # has to run in batched mode as well.
        ds_case = ds_case.map(transform_fn, batched=True)
    Data['ds_tfm'] = ds_case
    return Data

# Keep the function's source text on the function object itself.
entry_fn_AITaskData.fn_string = inspect.getsource(entry_fn_AITaskData)

In [None]:
# Run the task entry function; it stores the result under Data['ds_tfm'].
Data = entry_fn_AITaskData(Data, 
                           CF_to_CFvocab, 
                           OneEntryArgs,
                           tfm_fn_AIInputData,
                           entry_fn_AIInputData)

# Display the dataset that now carries the task transform.
ds_tfm = Data['ds_tfm']
ds_tfm

In [None]:
# Pull the first four examples and print the shape of every entry in the batch.
batch = ds_tfm[:4]
for k, v in batch.items():
    print(k, v.shape)


In [None]:
from recfldtkn.base import Base
from recfldtkn.aidata_base.entry import AIDATA_ENTRYOUTPUT_PATH

# Import lines handed to the code generator.
# NOTE(review): presumably emitted ahead of the function sources - confirm.
prefix = [
    'import torch',
    'import pandas as pd', 
    'import numpy as np', 
    'import datasets',
    ]

# Functions whose source is written to the module, in this order.
fn_variables = [
    get_OUTPUT_CFs,
    entry_fn_AITaskData,
]
pycode = Base.convert_variables_to_pystirng(fn_variables = fn_variables, prefix = prefix)

# Target file: <CODE_FN>/<AIDATA_ENTRYOUTPUT_PATH>/<EntryOutputMethod>.py
pypath = os.path.join(SPACE['CODE_FN'], AIDATA_ENTRYOUTPUT_PATH, f'{EntryOutputMethod}.py')
print(pypath)
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(os.path.dirname(pypath), exist_ok=True)
# Python source files are UTF-8 by default, so do not rely on the platform encoding.
with open(pypath, 'w', encoding='utf-8') as file:
    file.write(pycode)