# Space (workspace and project path setup)

In [None]:
import os
import logging
import pandas as pd 
from pprint import pprint 
from IPython.display import display, HTML
pd.set_option('display.max_columns', None)

# Locate the project root: the current path truncated just after the first KEY folder,
# then chdir there so that `proj_space` can be imported from it.
KEY = 'WorkSpace'
if KEY not in os.getcwd():
    # Without this check split() returns the whole cwd and os.chdir fails on
    # '<cwd>WorkSpace' with a confusing FileNotFoundError.
    raise RuntimeError(f'Current directory {os.getcwd()!r} is not inside a {KEY!r} folder.')
WORKSPACE_PATH = os.getcwd().split(KEY)[0] + KEY
os.chdir(WORKSPACE_PATH)
import sys
from proj_space import SPACE
# Make the project code directory importable (presumably needed by the project imports in later cells).
sys.path.append(SPACE['CODE_FN'])
SPACE['WORKSPACE_PATH'] = WORKSPACE_PATH

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='[%(levelname)s:%(asctime)s:(%(filename)s@%(lineno)d %(name)s)]: %(message)s')

# Disable Hugging Face `datasets` caching of transform results for this session.
from datasets import disable_caching
disable_caching()

# Step 1: Record and Case Base

In [None]:
from config.config_case.GROUP import GROUP_TO_GROUPMethodArgs
from config.config_case.CF import CF_to_CFArgs
from config.config_case.CKPD import Ckpd_to_CkpdObsConfig
from config.config_case.TagRec import TagRec_to_TagRecArgs
from config.config_case.TagCF import TagCF_to_TagCFArgs 
from config.config_case.Flt import FltName_to_FltArgs
from config.config_case.CASE import TriggerCaseBaseName_to_TriggerCaseBaseArgs

from config.config_record.Cohort import CohortName_to_OneCohortArgs
from config.config_case.CKPD import Ckpd_to_CkpdObsConfig

from recfldtkn.record_base import Record_Base
from recfldtkn.case_base.case_base import Case_Base

# All cohort names registered in the cohort config.
CohortNames = list(CohortName_to_OneCohortArgs)
print(CohortNames)

In [None]:
###################################
# Configuration shared by the Record_Base / Case_Base builders in the following cells.
Inference_Entry = None # this is not inference mode
# Name -> argument registries for each case-building component; passed to
# update_and_assert_CaseInfo and Case_Base below.
Case_Args_Settings = {
    'Ckpd_to_CkpdObsConfig': Ckpd_to_CkpdObsConfig,
    'CF_to_CFArgs': CF_to_CFArgs,
    'TagCF_to_TagCFArgs': TagCF_to_TagCFArgs,
    'TagRec_to_TagRecArgs': TagRec_to_TagRecArgs,
    'FltName_to_FltArgs': FltName_to_FltArgs,
    'GROUP_TO_GROUPMethodArgs': GROUP_TO_GROUPMethodArgs,
}


# Record-level processing options (passed to Record_Base).
Record_Proc_Config = {
    'save_data': True, 
    'load_data':True, 
    'via_method': 'ds',
}

# Case-level processing options (passed to update_and_assert_CaseInfo and Case_Base).
Case_Proc_Config = {
    'max_trigger_case_num': None, 
    'use_task_cache': False, 
    'caseset_chunk_size': 200000, # 200k for CGM, 50k for others.
    'save_data': True, 
    'load_data': True, 
    'load_casecollection': True, 
    'via_method': 'ds',
    'n_cpus': 1, 
    'batch_size': 1000,  
}
###################################  

# Cohorts whose records are used to build the cases.
CohortName_list = [ 
    'WellDoc2023CVSDeRx',
]

# The trigger CaseBase built in this notebook. Its args have three stages, keyed
# 'Trigger', 'FilterCaseSet' and 'FilterCaseSetWithAnyEvents'; each names a filter
# and the TagCFs/CFs computed for it.
TriggerCaseBaseName = 'Bf24HAf2H_CGM_And_Event'
TriggerCaseBaseArgs =  {
    # --------- these three are relatively stable ----------------
    'Trigger': {
        'TriggerName': 'CGM5MinEntry', 
        'TagRec': [
            'TagRec.PDemoFromP',
        ],
        'Group': 'GrpGenderDisease', # 
        'Filter': 'FltBasicDemo',
        'ObsTask': {
            'TagCF_list': [
                'TagCF.Bf24hCGMinfo', 
                'TagCF.Af2hCGMinfo',
            ],
            'CF_list':  [],
        }
    },
    # --------------------------------

    # --------------------------------
    'FilterCaseSet': {
        'Filter': 'FltMiniBfAfCGMRecInfo',
        'ObsTask': {
            'TagCF_list': [
                'TagCF.Bf24hCGMinfo', 
                'TagCF.Af2hCGMinfo',
            ],
            'CF_list':  [
                'cf.TargetCGM_Bf24H', 
                'cf.TargetCGM_Af2H',
            ],
        },
    },

    # --------------------------------
    'FilterCaseSetWithAnyEvents': {
        'Filter': 'FltWithBf24hAf2HAf2Ht8H-MEDAL-OR',
        # 'Filter': 'FltWithBf24hFood', 
        'ObsTask': {
            'TagCF_list': [
                'TagCF.Bf24hCGMinfo', 
                'TagCF.Af2hCGMinfo',
                
                'TagCF.Bf24hRecNum',
                'TagCF.Af2hRecNum',
                # 'TagCF.Af2ht8hRecNum',
            ],

            'CF_list':  [
                'cf.PDemo',
                'cf.TargetCGM_Bf24H', 
                'cf.TargetCGM_Af2H',
                # 'cf.TargetCGM_Af2Hto8H',
                'cf.TimeSparse_Bf24H', 
                'cf.TimeSparse_Af2H',
                # 'cf.TimeSparse_Af2Hto8H',
                'cf.DietSparse_Bf24H',
                'cf.DietSparse_Af2H',
                # 'cf.DietSparse_Af2Hto8H',
                ],
        },
    },
    # --------------------------------
}

# Register the args under TriggerCaseBaseName. NOTE: this mutates the imported
# TriggerCaseBaseName_to_TriggerCaseBaseArgs registry in place.
TriggerCaseBaseName_to_TriggerCaseBaseArgs[TriggerCaseBaseName] = TriggerCaseBaseArgs
pprint(TriggerCaseBaseArgs, sort_dicts=False)

In [None]:
from recfldtkn.check import update_and_assert_CaseInfo
from recfldtkn.check import retrive_pipeline_info
# Pipeline info derived from SPACE, used to validate/expand the case settings.
PIPELINE_INFO = retrive_pipeline_info(SPACE)


# Validate the trigger CaseBase args against the registries and processing config.
CaseSettingInfo = update_and_assert_CaseInfo(
                                TriggerCaseBaseName,
                                TriggerCaseBaseArgs,
                                Case_Args_Settings,
                                Case_Proc_Config, 
                                PIPELINE_INFO, 
                                )

# Record_Base for the selected cohorts, built with the record-feature args returned above.
HumanRecordRecfeat_Args = CaseSettingInfo['HumanRecordRecfeat_Args']
record_base = Record_Base(CohortName_list, 
                            HumanRecordRecfeat_Args,
                            CohortName_to_OneCohortArgs,
                            SPACE = SPACE, 
                            Inference_Entry = Inference_Entry,
                            Record_Proc_Config = Record_Proc_Config,
                            )

In [None]:
# The trigger CaseBase defined above is built from the cohorts in CohortName_list.
TriggerCaseBaseName_to_CohortNameList = {
    TriggerCaseBaseName: CohortName_list,
}

# A bare expression in the middle of a cell is never rendered; display it explicitly.
display(TriggerCaseBaseName_to_CohortNameList)

# Build the Case_Base from the record base and the case configuration above.
case_base = Case_Base(
    record_base = record_base, 
    TriggerCaseBaseName_to_CohortNameList = TriggerCaseBaseName_to_CohortNameList, 
    TriggerCaseBaseName_to_TriggerCaseBaseArgs = TriggerCaseBaseName_to_TriggerCaseBaseArgs,
    Case_Proc_Config = Case_Proc_Config,
    Case_Args_Settings = Case_Args_Settings, 
)

In [None]:
# CaseSet name -> CaseSet for the trigger CaseBase built above.
CaseSetNameToCaseset = case_base.TriggerCaseBaseName_to_CaseSetNameToCaseset[TriggerCaseBaseName]
CaseSetNameToCaseset

In [None]:
# Inspect the last CaseSet produced for this trigger CaseBase. Fail loudly when there is
# none, instead of silently reusing a `caseset` left over from an earlier run.
assert len(CaseSetNameToCaseset) > 0, f'No CaseSets found for {TriggerCaseBaseName}'
name, caseset = list(CaseSetNameToCaseset.items())[-1]
caseset

In [None]:
# HF Dataset view of the selected CaseSet.
caseset.ds_case

In [None]:
# pandas view of the same CaseSet.
df_case = caseset.df_case

In [None]:
# CF names that have a vocab entry for this trigger CaseBase.
[CF for CF in case_base.TriggerCaseBaseName_to_CFtoCFvocab[TriggerCaseBaseName]]

## Prepare Data

In [None]:
from recfldtkn.base import apply_multiple_conditions
import numpy as np 

def filter_data(Data, rules):
    """Keep only the cases that satisfy ``rules``, in both the DataFrame and the Dataset view.

    Args:
        Data: dict with 'df_case' (pandas DataFrame) and 'ds_case' (HF ``datasets.Dataset``).
            The two are assumed to be row-aligned (row i of df_case is row i of ds_case).
        rules: list of [column, operator, value] triples, passed to ``apply_multiple_conditions``.

    Returns:
        dict with the filtered 'df_case' (index reset) and 'ds_case'. Both carry a boolean
        'selected' column. The input DataFrame and Dataset are not modified.
    """
    special_column = 'selected'
    df_case = Data['df_case']
    ds_case = Data['ds_case']

    result = apply_multiple_conditions(df_case, rules)
    # assign() returns a copy: the caller's DataFrame (e.g. caseset.df_case) must not
    # silently gain a 'selected' column.
    df_case = df_case.assign(**{special_column: result})
    mask = df_case[special_column]

    # Dataset.add_column raises if the column already exists (e.g. when filtering data
    # that was already filtered once), so drop any stale copy first.
    if special_column in ds_case.column_names:
        ds_case = ds_case.remove_columns(special_column)
    ds_case = ds_case.add_column(special_column, mask.values)

    df_case_filter = df_case[mask].reset_index(drop=True)

    # Use the mask directly instead of reading the column back out of the Dataset.
    indices = np.flatnonzero(mask.to_numpy(dtype=bool))
    ds_case_filter = ds_case.select(indices)

    return {'df_case': df_case_filter, 'ds_case': ds_case_filter}

# Example usage
# Keep only cases matching every rule. The column name suggests "at least one food
# record in the 24h before the trigger" — confirm against the CF definitions.
rules = [
    ['co.Bf24H_Food_recnum:recnum', '>=', 1],
    # ['co.Bf24H_Carb_recnum:recnum', '>=', 1],
]

# caseset
Data = {'df_case': caseset.df_case, 'ds_case': caseset.ds_case}
print(Data['df_case'].shape)
Data = filter_data(Data, rules)
print(Data['df_case'].shape)

In [None]:
# CF name -> vocab for the trigger CaseBase; print just the CF names.
CF_to_CFvocab = case_base.TriggerCaseBaseName_to_CFtoCFvocab[TriggerCaseBaseName]
print(list(CF_to_CFvocab))

# Step 2: EntryFn - Input_Part

## Args

In [None]:
# Entry configuration for the input side only; read by get_INPUT_CFs / tfm_fn_AIInputData.
OneEntryArgs = {
    # ----------------- Input Part -----------------
    'Input_Part': {
        'EntryInputMethod': 'Mto1Period_MultiTknInStep',
        'CF_list': [
            'cf.TargetCGM_Bf24H',
            'cf.TargetCGM_Af2H',
            # 'cf.TimeSparse_Bf24H', 
            # 'cf.TimeSparse_Af2H',
            # 'cf.DietSparse_Bf24H',
            # 'cf.DietSparse_Af2H',
        ],
        # CFs whose name contains this substring are concatenated into 'input_ids'.
        'TargetField': 'TargetCGM',
        # 'TimeField':   'Time',
        # 'EventFields': [
        #     'Diet',
        # ],
        'BeforePeriods': ['Bf24H'],
        'AfterPeriods': ['Af2H'],
        # False keeps every CF in CF_list; 'NoFutureEvent' / 'WithFutureEvent' prune it (see get_INPUT_CFs).
        'InferenceMode': False, # 'WithFutureEvent' #  # 'NoFutureEvent', 'WithFutureEvent', 
    }, 
}

EntryInputMethod = OneEntryArgs['Input_Part']['EntryInputMethod']

## InputCFs

In [None]:
import torch 
import datasets
import inspect
import numpy as np
from scipy.sparse import csr_matrix, hstack
import itertools

## %%%%%%%%%%%%%%%%%%%%% user functions
def get_INPUT_CFs(OneEntryArgs):
    """Return the input CF names for an entry, pruned according to its inference mode.

    Reads ``OneEntryArgs['Input_Part']`` keys 'CF_list', 'InferenceMode',
    'BeforePeriods' and 'TargetField'.

    - 'NoFutureEvent': keep only CFs whose name contains a before-period tag.
    - 'WithFutureEvent': keep before-period CFs plus any non-target CF.
    - anything else: return CF_list itself, unchanged.

    Raises:
        AssertionError: if CF_list is not a list.
    """
    input_part = OneEntryArgs['Input_Part']
    cf_list = input_part['CF_list']
    assert type(cf_list) == list, f'InputCFs must be a list, but got {type(cf_list)}'

    mode = input_part['InferenceMode']
    before_periods = input_part['BeforePeriods']
    target_field = input_part['TargetField']

    def _in_before_period(cf_name):
        return any([period in cf_name for period in before_periods])

    if mode == 'NoFutureEvent':
        return [cf for cf in cf_list if _in_before_period(cf)]
    if mode == 'WithFutureEvent':
        return [cf for cf in cf_list if _in_before_period(cf) or target_field not in cf]
    return cf_list

get_INPUT_CFs.fn_string = inspect.getsource(get_INPUT_CFs)

In [None]:
# EntryInputMethod = OneEntryArgs['Input_Part']['EntryInputMethod']
# Input CFs after InferenceMode pruning.
InputCFs = get_INPUT_CFs(OneEntryArgs)
InputCFs

In [None]:
# NOTE: this is the same object as Data['ds_case'], not a copy.
ds_case = Data['ds_case']
ds_case

## Examples

In [None]:
# A reproducible batch of 64 cases (seed=42) used to prototype the transform below.
# NOTE(review): if a later cell has already set a transform on this dataset in place,
# re-running this cell presumably returns transformed tensors instead of raw columns.
examples = ds_case.shuffle(seed=42)[:64] # .select(range(5))  
# examples = ds_case[:4] 
pprint(examples, sort_dicts=False, compact=True)

In [None]:
# Full entry configuration (input part only at this stage).
OneEntryArgs

In [None]:
# The input-side settings read by get_INPUT_CFs / tfm_fn_AIInputData.
Input_Part = OneEntryArgs['Input_Part']
Input_Part

In [None]:
# Use the same CF selection as tfm_fn_AIInputData (get_INPUT_CFs applies the
# InferenceMode pruning) rather than the raw CF_list, so the prototype matches it.
InputCFs = get_INPUT_CFs(OneEntryArgs)
InputCFs

## TargetCF

In [None]:
# Substring that marks target CFs (e.g. 'cf.TargetCGM_Bf24H' contains 'TargetCGM').
TargetField = Input_Part['TargetField']
TargetField

In [None]:
# Target CFs: the input CFs whose name contains TargetField.
TargetCFs = [i for i in InputCFs if TargetField in i]
TargetCFs

In [None]:
from datetime import datetime 
# Timed prototype of the target-CF step of tfm_fn_AIInputData: two ways of
# concatenating the target CFs' input_ids into one (batch, total_len) tensor.


s = datetime.now()
examples_tfm = {}

############################################################
# # 0:00:00.002059
## method 1:
# df = pd.DataFrame({cf: examples[cf + '--input_ids'] for cf in TargetCFs})
# df['input_ids'] = df.apply(lambda x: list(itertools.chain(*x.values)), axis=1)
# examples_tfm['input_ids'] = torch.LongTensor(np.array(df['input_ids'].to_list())) # ().copy()


############################################################
# # 0:00:00.000868
# method 2: 
# Step 1: Directly access columns as numpy arrays
target_arrays = [np.array(examples[f"{cf}--input_ids"]) for cf in TargetCFs]
# Step 2: Concatenate along columns (axis=1) to combine features
# Assuming each array has shape (batch_size, seq_len); np.concatenate fails if TargetCFs is empty.
stacked_ids = np.concatenate(target_arrays, axis=1)
examples_tfm['input_ids'] = torch.LongTensor(stacked_ids)
# examples_tfm['input_ids'] = stacked_ids # torch.LongTensor()


e = datetime.now()
print(f'TargetCFs: {e-s}')
examples_tfm


## Update the Emptiness of Event CFs

In [None]:
def detect_empty_values(values):
    """Return True when one sample's CF value is the single-token placeholder ``[[0]]``.

    A value counts as empty when it has exactly one step holding exactly one token
    whose id is 0 (presumably the pad/empty token — confirm against the vocab).
    Everything else, including ``[]``, is reported as non-empty.
    """
    return len(values) == 1 and len(values[0]) == 1 and int(values[0][0]) == 0
detect_empty_values.fn_string = inspect.getsource(detect_empty_values)


def update_emptiness_of_examples(examples, CF):
    """Replace the empty placeholder of one event CF with real empty lists.

    Args:
        examples: batched dict mapping '<CF>--<seqtype>' to per-sample lists.
            '<CF>--input_ids' must be present; it decides which samples are empty.
        CF: the CF name prefix, e.g. 'cf.DietSparse_Bf24H'. It should be an event CF
            (one whose values are lists of steps).

    Returns:
        dict holding the updated '<CF>--input_ids' / '--input_wgts' / '--timestep'
        entries, only for keys present in ``examples``. Empty samples become ``[]``.
        ``examples`` itself is not modified.
    """
    is_empty = [detect_empty_values(values) for values in examples[CF + '--input_ids']]
    examples_updated = {}
    for item in ['input_ids', 'input_wgts', 'timestep']:
        key = f'{CF}--{item}'
        if key not in examples:
            continue
        batch_values = examples[key]
        # A fresh [] per sample: sharing one list object across samples would make an
        # in-place change to one empty sample show up in every other empty sample.
        examples_updated[key] = [
            [] if empty else batch_values[dp_idx]
            for dp_idx, empty in enumerate(is_empty)
        ]
    return examples_updated
update_emptiness_of_examples.fn_string = inspect.getsource(update_emptiness_of_examples)


## tfm_fn_AIInputData

In [None]:
from datetime import datetime


def tfm_fn_AIInputData(examples, OneEntryArgs, CF_to_CFvocab):
    """Turn a batch of case examples into model-ready input tensors.

    Args:
        examples: batched dict mapping '<CF>--<seqtype>' (seqtype: input_ids /
            input_wgts / timestep) to per-sample lists. NOTE: entries for event CFs
            are overwritten in place with ``update_emptiness_of_examples`` results.
        OneEntryArgs: entry config. Reads ``Input_Part['TargetField']``, optional
            'EventFields' and 'TimeField', plus whatever ``get_INPUT_CFs`` reads.
        CF_to_CFvocab: CF name -> vocab; only the first time CF's vocab is used.

    Returns:
        dict with:
          - 'input_ids': LongTensor of the target CFs' input_ids concatenated along axis 1
            (assumes every target CF has a fixed per-sample length — TODO confirm);
          - '<EventField>--{input_ids,input_wgts,timestep_orig_ids,event_indicators}'
            for each configured EventField;
          - '<TimeField>--<key>' for each key returned by get_timestepinfo_array, when
            TimeField is configured.
        Keys containing '_wgt' become FloatTensors; all others become LongTensors.

    Raises:
        ValueError: if no input CF contains TargetField, or a configured TimeField
            matches no input CF.

    NOTE(review): get_timestep_info, update_seqtype_base_on_timestep, vectorized_pad
    and get_timestepinfo_array are not defined in the visible notebook cells. They
    must be in scope whenever EventFields or TimeField are configured.
    """
    Input_Part = OneEntryArgs['Input_Part']
    InputCFs = get_INPUT_CFs(OneEntryArgs)

    examples_tfm = {}

    # ------------------------------------------------------------ #
    # Target field: concatenate the target CFs' input_ids along the sequence axis.
    TargetField = Input_Part['TargetField']
    TargetCFs = [i for i in InputCFs if TargetField in i]
    if len(TargetCFs) == 0:
        # np.concatenate([]) would otherwise fail with an unhelpful message.
        raise ValueError(f'No input CF contains TargetField {TargetField!r}; InputCFs = {InputCFs}')
    target_arrays = [np.array(examples[f"{cf}--input_ids"]) for cf in TargetCFs]
    stacked_ids = np.concatenate(target_arrays, axis=1)
    examples_tfm['input_ids'] = torch.LongTensor(stacked_ids)

    # ------------------------------------------------------------ #
    # Event CFs: the non-target (and non-time) input CFs, only when EventFields are configured.
    EventFields = Input_Part.get('EventFields', [])
    TimeField = Input_Part.get('TimeField', None)
    if len(EventFields) > 0:
        if TimeField is not None:
            EventCFs = [i for i in InputCFs if TargetField not in i and TimeField not in i]
        else:
            EventCFs = [i for i in InputCFs if TargetField not in i]
    else:
        EventCFs = []

    # Replace empty-placeholder samples with real empty lists, writing back into `examples`.
    for EventCF in EventCFs:
        examples_updated = update_emptiness_of_examples(examples, EventCF)
        for k, v in examples_updated.items():
            examples[k] = v

    # Multi EventCFs
    timestep_info = None
    for OneEvent in EventFields:
        OneEventCFs = [i for i in InputCFs if OneEvent in i]

        # Per sample, chain this event's CFs together for each seqtype.
        example_event_info = {}
        for seqtype in ['input_ids', 'input_wgts', 'timestep']:
            columns_data = [examples[f"{cf}--{seqtype}"] for cf in OneEventCFs]
            example_event_info[seqtype] = [
                list(itertools.chain(*sample_items)) for sample_items in zip(*columns_data)
            ]

        # timestep_info is computed from the FIRST event's CFs only and then reused for
        # every later event (and for the time field below).
        # NOTE(review): confirm that all events share one timeline.
        if timestep_info is None:
            timestep_info = get_timestep_info(examples, OneEventCFs)

        names_orig = list(example_event_info)
        event_info_final = {name: [] for name in ['input_ids', 'input_wgts', 'timestep_orig_ids', 'event_indicators']}
        max_timesteps = len(timestep_info['timestep_orig_ids'])
        max_features = 0
        for items_sample in zip(*example_event_info.values()):
            single_data_point = dict(zip(names_orig, items_sample))
            updated_data_point = update_seqtype_base_on_timestep(single_data_point, timestep_info)
            for name, value in updated_data_point.items():
                if name == 'input_ids':
                    # default=0: a sample with no steps must not raise on max() of an empty sequence.
                    max_features = max(max_features, max((len(i) for i in value), default=0))
                event_info_final[name].append(value)

        # Pad the ragged per-sample step lists to (max_timesteps, max_features).
        for seqtype in ['input_ids', 'input_wgts']:
            event_info_final[seqtype] = vectorized_pad(event_info_final[seqtype], max_timesteps, max_features, pad_value=0)

        for seqtype in ['timestep_orig_ids', 'event_indicators']:
            event_info_final[seqtype] = np.array(event_info_final[seqtype])

        for k, v in event_info_final.items():
            if '_wgt' in k:
                event_info_final[k] = torch.FloatTensor(v)
            else:
                event_info_final[k] = torch.LongTensor(v)
            examples_tfm[OneEvent + '--' + k] = event_info_final[k]

    # ------------------------------------------------------------ #
    # Time field
    if TimeField is not None:
        TimeCFs = [i for i in InputCFs if TimeField in i]
        if len(TimeCFs) == 0:
            raise ValueError(f'No input CF contains TimeField {TimeField!r}; InputCFs = {InputCFs}')
        CFvocab = CF_to_CFvocab[TimeCFs[0]]
        if timestep_info is None:
            timestep_info = get_timestep_info(examples, TimeCFs)

        time_info_final = get_timestepinfo_array(examples, timestep_info, CFvocab)
        for k, v in time_info_final.items():
            if '_wgt' in k:
                time_info_final[k] = torch.FloatTensor(v)
            else:
                time_info_final[k] = torch.LongTensor(v)

            examples_tfm[TimeField + '--' + k] = time_info_final[k]

    return examples_tfm


tfm_fn_AIInputData.fn_string = inspect.getsource(tfm_fn_AIInputData)

In [None]:
# print('\n==============================================\n')
# Time one call of tfm_fn_AIInputData on the sampled `examples` batch.
from datetime import datetime
s = datetime.now()
examples_tfm = tfm_fn_AIInputData(examples, OneEntryArgs, CF_to_CFvocab)
e = datetime.now()
print(f'tfm_fn_AIInputData_new: {e-s}')


# Reference timings recorded earlier:
# old version
# get_INPUT_CFs: 0:00:00.000003
# TargetField: 0:00:00.001757
# EventFields: 0:00:00.000003
# update_emptiness_of_examples: 0:00:00.000056
# Multi EventCFs -- Diet: 0:00:00.470472
# TimeField: 0:00:00.205376
# tfm_fn_AIInputData: 0:00:00.678363


# Diet and Time
# tfm_fn_AIInputData_new: 0:00:00.047536 
# 

In [None]:
# Shape of every output tensor.
for k, v in examples_tfm.items():
    print(k, v.shape)

## entry_fn_AIInputData

In [None]:
def entry_fn_AIInputData(Data, 
                         CF_to_CFvocab, 
                         OneEntryArgs,
                         tfm_fn_AIInputData = None):
    """Attach the input transform to Data['ds_case'] and store the result as Data['ds_tfm'].

    Args:
        Data: dict holding 'ds_case' (a ``datasets.Dataset``, or a pandas DataFrame that
            is converted). It is updated in place with 'ds_tfm' and returned.
        CF_to_CFvocab: passed through to ``tfm_fn_AIInputData``.
        OneEntryArgs: entry config. Optional 'use_map' (default False) selects an eager
            ``Dataset.map`` instead of a lazy transform; 'num_proc' (default 4) sets the
            number of map worker processes.
        tfm_fn_AIInputData: callable (examples, OneEntryArgs, CF_to_CFvocab) -> dict of tensors.

    Returns:
        Data, with Data['ds_tfm'] set. Data['ds_case'] is left untransformed.

    Raises:
        ValueError: if ``tfm_fn_AIInputData`` is not given.
    """
    if tfm_fn_AIInputData is None:
        # Otherwise the failure only surfaces later, inside the DataLoader, as
        # "'NoneType' object is not callable".
        raise ValueError('entry_fn_AIInputData requires tfm_fn_AIInputData.')

    transform_fn = lambda examples: tfm_fn_AIInputData(examples, OneEntryArgs, CF_to_CFvocab)

    ds_case = Data['ds_case']
    if isinstance(ds_case, pd.DataFrame):
        ds_case = datasets.Dataset.from_pandas(ds_case)

    use_map = OneEntryArgs.get('use_map', False)
    num_proc = OneEntryArgs.get('num_proc', 4)
    if use_map:
        ds_tfm = ds_case.map(transform_fn, batched = True, num_proc = num_proc)
    else:
        # with_transform returns a new Dataset. set_transform would mutate
        # Data['ds_case'] in place, so the "raw" dataset (shared with other cells)
        # would start returning transformed tensors.
        ds_tfm = ds_case.with_transform(transform_fn)
    Data['ds_tfm'] = ds_tfm
    return Data

tfm_fn_AIInputData.fn_string = inspect.getsource(tfm_fn_AIInputData)
entry_fn_AIInputData.fn_string = inspect.getsource(entry_fn_AIInputData)

## Examine

In [None]:
# Attach the input transform; Data['ds_tfm'] is the transformed dataset.
Data = entry_fn_AIInputData(Data, 
                            CF_to_CFvocab, 
                            OneEntryArgs,
                            tfm_fn_AIInputData)

ds_tfm = Data['ds_tfm']
ds_tfm

In [None]:
# Fetch the first 4 rows of the transformed dataset and show each tensor's shape.
batch = ds_tfm[:4]
# batch

for k, v in batch.items():
    print(k, v.shape)

In [None]:

from torch.utils.data import DataLoader
import itertools
import time
import numpy as np

# Measure steady-state DataLoader throughput over ds_tfm (transform cost included).
WARMUP_BATCHES = 5  # batches excluded from timing (worker start-up, first-call overhead)

# 1. Create DataLoader with your actual training parameters
loader = DataLoader(
    dataset=ds_tfm,           # Your transformed dataset
    batch_size=32,            # Use your real batch size
    num_workers=4,            # Match your training setup
    pin_memory=True,          # Same as training config
    shuffle=False             # Disable for consistent measurement
)
num_batches = len(loader)

# 2. Warm up on the SAME iterator that is then timed. With num_workers > 0 and the
#    default persistent_workers=False, each new `for _ in loader` starts fresh worker
#    processes, so a separate warm-up pass would not keep start-up cost out of the
#    timed pass (and a full warm-up epoch doubles the run time).
print("Warming up...")
batch_iter = iter(loader)
n_warmup = sum(1 for _ in itertools.islice(batch_iter, WARMUP_BATCHES))

# 3. Timed measurement over the remaining batches
n_timed = num_batches - n_warmup
print(f"Testing with {n_timed} batches...")

start_time = time.perf_counter()  # More precise timer
for _ in batch_iter:
    pass
total_time = time.perf_counter() - start_time

# 4. Calculate metrics. With shuffle=False and the default drop_last=False only the
#    final batch can be partial, and it is among the timed batches, so every warm-up
#    batch is full.
if n_timed <= 0 or total_time <= 0:
    print(f"\nNot enough batches to time: {num_batches} total, {n_warmup} used for warm-up.")
else:
    n_samples = len(ds_tfm) - n_warmup * loader.batch_size
    throughput = n_timed / total_time
    samples_per_sec = n_samples / total_time

    print(f"\nResults:")
    print(f"- Batches/s: {throughput:.1f}")
    print(f"- Samples/s: {samples_per_sec:.1f}")
    print(f"- Batch time: {1000*total_time/n_timed:.1f}ms")
    print(f"- Total time: {total_time:.2f}s")

# Step 3: EntryFn - Output_Part 

## Args

In [None]:
# TaskType = 'MLUniLabel'
# NOTE: SeriesName and OneTaskName are not referenced in the visible cells.
SeriesName  = 'Bf24.Af2H'
OneTaskName = 'cgm_lhm_bf24h_af2h_5min'
# Full entry configuration: the same input part as above, plus output and task parts.
OneEntryArgs = {
    # ----------------- Input Part -----------------
    'Input_Part': {
        'EntryInputMethod': 'Mto1Period_MultiTknInStep',
        'CF_list': [
            'cf.TargetCGM_Bf24H',
            'cf.TargetCGM_Af2H',
            # 'cf.TimeSparse_Bf24H', 
            # 'cf.TimeSparse_Af2H',
            # 'cf.DietSparse_Bf24H',
            # 'cf.DietSparse_Af2H',
        ],
        'TargetField': 'TargetCGM',
        # 'TimeField':   'Time',
        # 'EventFields': [
        #     'Diet',
        # ],
        'BeforePeriods': ['Bf24H'],
        'AfterPeriods': ['Af2H'],
        # False keeps every CF in CF_list; 'NoFutureEvent' / 'WithFutureEvent' prune it (see get_INPUT_CFs).
        'InferenceMode': False, # 'WithFutureEvent' #  # 'NoFutureEvent', 'WithFutureEvent', 
    }, 

    # ----------------- Output Part -----------------
    'Output_Part': {
        # Output method name; the (disabled) export cell below uses it as the module file name.
        'EntryOutputMethod': 'NTP',
    },

    # ----------------- Task Part -----------------
    'Task_Part': {
        'Tagging': [],
        'Filtering': [], 
    },
}

# Data = {'df_case': caseset.df_case, 'ds_case': caseset.ds_case}

EntryOutputMethod = OneEntryArgs['Output_Part']['EntryOutputMethod']
CF_to_CFvocab = case_base.TriggerCaseBaseName_to_CFtoCFvocab[TriggerCaseBaseName]
print([i for i in CF_to_CFvocab])

## Function

In [None]:
## %%%%%%%%%%%%%%%%%%%%%
# UniLabel
import inspect 
import numpy as np 
# from recfldtkn.loadtools import convert_variables_to_pystirng

def get_OUTPUT_CFs(OneEntryArgs):
    """Return ``OneEntryArgs['Output_Part']['CF_list']``, or ``[]`` when it is absent.

    An entry without an 'Output_Part', or whose Output_Part has no 'CF_list',
    yields an empty list.
    """
    if 'Output_Part' in OneEntryArgs:
        return OneEntryArgs['Output_Part'].get('CF_list', [])
    return []
get_OUTPUT_CFs.fn_string = inspect.getsource(get_OUTPUT_CFs)


def entry_fn_AITaskData(Data, 
                        CF_to_CFvocab, 
                        OneEntryArgs,
                        tfm_fn_AIInputData = None,
                        entry_fn_AIInputData = None,
                        ):
    """Attach the input + label transform and store the result as Data['ds_tfm'].

    The labels are a copy of the model input ids (``labels = input_ids.clone()``).

    Args:
        Data: dict holding 'ds_case' (a ``datasets.Dataset``, or a pandas DataFrame that
            is converted). It is updated in place with 'ds_tfm' and returned.
        CF_to_CFvocab: passed through to ``tfm_fn_AIInputData``.
        OneEntryArgs: entry config. Optional 'use_map' (default False) and
            'num_proc' (default 4), as in entry_fn_AIInputData.
        tfm_fn_AIInputData: callable (examples, OneEntryArgs, CF_to_CFvocab) -> dict
            containing an 'input_ids' tensor.
        entry_fn_AIInputData: accepted for interface compatibility; not used here.

    Returns:
        Data, with Data['ds_tfm'] set. Data['ds_case'] is left untransformed.

    Raises:
        ValueError: if ``tfm_fn_AIInputData`` is not given.
    """
    if tfm_fn_AIInputData is None:
        raise ValueError('entry_fn_AITaskData requires tfm_fn_AIInputData.')

    def transform_fn_output(examples, tfm_fn_AIInputData, OneEntryArgs, CF_to_CFvocab):
        examples_tfm = tfm_fn_AIInputData(examples, OneEntryArgs, CF_to_CFvocab)
        # The targets are the input ids themselves; any next-token shift presumably
        # happens inside the model — confirm.
        examples_tfm['labels'] = examples_tfm['input_ids'].clone()
        return examples_tfm

    transform_fn = lambda examples: transform_fn_output(examples, tfm_fn_AIInputData, OneEntryArgs, CF_to_CFvocab)

    ds_case = Data['ds_case']
    if isinstance(ds_case, pd.DataFrame):
        ds_case = datasets.Dataset.from_pandas(ds_case)

    use_map = OneEntryArgs.get('use_map', False)
    num_proc = OneEntryArgs.get('num_proc', 4)
    if use_map:
        ds_tfm = ds_case.map(transform_fn, batched = True, num_proc = num_proc)
    else:
        # with_transform returns a new Dataset. set_transform would mutate
        # Data['ds_case'] (shared with other cells) in place.
        ds_tfm = ds_case.with_transform(transform_fn)

    Data['ds_tfm'] = ds_tfm
    return Data

entry_fn_AITaskData.fn_string = inspect.getsource(entry_fn_AITaskData)

In [None]:
# Replace Data['ds_tfm'] with the input + label transform.
Data = entry_fn_AITaskData(Data, 
                           CF_to_CFvocab, 
                           OneEntryArgs,
                           tfm_fn_AIInputData,
                           entry_fn_AIInputData)

ds_tfm = Data['ds_tfm']
ds_tfm

In [None]:
# First 4 transformed rows, now including 'labels'.
batch = ds_tfm[:4]
batch

In [None]:
# Shape of every tensor in the batch.
for k, v in batch.items():
    print(k, v.shape)

In [None]:
# Disabled: would export get_OUTPUT_CFs and entry_fn_AITaskData as a standalone module
# at <CODE_FN>/<AIDATA_ENTRYOUTPUT_PATH>/<EntryOutputMethod>.py.
# from recfldtkn.base import Base
# from recfldtkn.aidata_base.entry import AIDATA_ENTRYOUTPUT_PATH

# prefix = [
#     'import torch',
#     'import pandas as pd', 
#     'import numpy as np', 
#     'import datasets',
#     ]
# fn_variables = [
#     get_OUTPUT_CFs,
#     entry_fn_AITaskData,
# ]
# pycode = Base.convert_variables_to_pystirng(fn_variables = fn_variables, prefix = prefix)
# pypath = os.path.join(SPACE['CODE_FN'], AIDATA_ENTRYOUTPUT_PATH, f'{EntryOutputMethod}.py')
# print(pypath)
# if not os.path.exists(os.path.dirname(pypath)): os.makedirs(os.path.dirname(pypath))
# with open(pypath, 'w') as file: file.write(pycode)

In [None]:

from torch.utils.data import DataLoader
import itertools
import time
import numpy as np

# Measure steady-state DataLoader throughput over ds_tfm (transform cost included).
WARMUP_BATCHES = 5  # batches excluded from timing (worker start-up, first-call overhead)

# 1. Create DataLoader with your actual training parameters
loader = DataLoader(
    dataset=ds_tfm,           # Your transformed dataset
    batch_size=32,            # Use your real batch size
    num_workers=4,            # Match your training setup
    pin_memory=True,          # Same as training config
    shuffle=False             # Disable for consistent measurement
)
num_batches = len(loader)

# 2. Warm up on the SAME iterator that is then timed. With num_workers > 0 and the
#    default persistent_workers=False, each new `for _ in loader` starts fresh worker
#    processes, so a separate warm-up pass would not keep start-up cost out of the
#    timed pass (and a full warm-up epoch doubles the run time).
print("Warming up...")
batch_iter = iter(loader)
n_warmup = sum(1 for _ in itertools.islice(batch_iter, WARMUP_BATCHES))

# 3. Timed measurement over the remaining batches
n_timed = num_batches - n_warmup
print(f"Testing with {n_timed} batches...")

start_time = time.perf_counter()  # More precise timer
for _ in batch_iter:
    pass
total_time = time.perf_counter() - start_time

# 4. Calculate metrics. With shuffle=False and the default drop_last=False only the
#    final batch can be partial, and it is among the timed batches, so every warm-up
#    batch is full.
if n_timed <= 0 or total_time <= 0:
    print(f"\nNot enough batches to time: {num_batches} total, {n_warmup} used for warm-up.")
else:
    n_samples = len(ds_tfm) - n_warmup * loader.batch_size
    throughput = n_timed / total_time
    samples_per_sec = n_samples / total_time

    print(f"\nResults:")
    print(f"- Batches/s: {throughput:.1f}")
    print(f"- Samples/s: {samples_per_sec:.1f}")
    print(f"- Batch time: {1000*total_time/n_timed:.1f}ms")
    print(f"- Total time: {total_time:.2f}s")