# Space

In [None]:
import os
import logging
import pandas as pd 
from pprint import pprint 
from IPython.display import display, HTML
# Show every column when a DataFrame is displayed.
pd.set_option('display.max_columns', None)
# The workspace root is the current working directory truncated right after
# the first occurrence of KEY.
# NOTE(review): when KEY is not part of the cwd, `split` returns the cwd
# unchanged and KEY is simply appended to it, so the chdir below will most
# likely fail — confirm the notebook is always started inside the workspace.
KEY = 'WorkSpace'
WORKSPACE_PATH = os.getcwd().split(KEY)[0] + KEY
# print(WORKSPACE_PATH)
os.chdir(WORKSPACE_PATH)
import sys
# `proj_space` is imported only after the chdir above.
from proj_space import SPACE
# Add the path stored under SPACE['CODE_FN'] to the import search path
# (presumably the project code folder — verify against proj_space).
sys.path.append(SPACE['CODE_FN'])
SPACE['WORKSPACE_PATH'] = WORKSPACE_PATH
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='[%(levelname)s:%(asctime)s:(%(filename)s@%(lineno)d %(name)s)]: %(message)s')

from datasets import disable_caching
# Turn off caching in the `datasets` library for this session.
disable_caching()

# Record the model endpoint name on the shared SPACE dict.
SPACE['MODEL_ENDPOINT'] = 'vTest'

# Part 1: AIData

In [None]:
from recfldtkn.aidata_base.aidata import AIData

# Location and name of the stored AIData to load.
DATA_AIDATA = SPACE['DATA_AIDATA']
OneAIDataName = 'CgmLhm_Bf24Af2Af2t8H_5Min_3Cohort_EventFlt_Sample'


# Entry arguments applied to the loaded AIData below. Three sections:
# task (tagging / filtering / splitting), input construction, output method.
OneEntryArgs = {
    # ----------------- Task Part -----------------
    'Task_Part': {

        'Tagging': {
            # 'TagName_to_TaggingMethod': {
            #     # TagName: TaggingMethod {Rules: [(x,x,x)], Op: and or}
            # },
            # 'ColumnsAddToDsCase': [],
            'TagFilter': True, # <--- still need to add the Filter tag, as we need to do the RandomDownSample.
            'TagSplit': False, # <--- no need to add the Split tag anymore, as we already have it.
        },

        'Filtering': {
            # 'FilterTagging': None,
            # Each rule is [column, operator, value]; rules are combined with 'Op'.
            'FilterTagging': {
                "Rules": [
                    ['RandDownSample', '<=', 0.5],
                    ['co.Bf24H_Food_recnum:recnum', '>=', 1], 
                    ], 
                'Op': 'and',
            }
        }, 
        
        'Splitting': {
            # 'SplitTagging': { # <----- for the Tagging part.
            #     'RANDOM_SEED': 32,
            #     'out_ratio': 0.1,
            #     'test_ratio': 'tail0.1',
            #     'valid_ratio': 0.1
            # },
            # Names of the training set and of the evaluation sets.
            'TrainEvals': {
                'TrainSetName': 'In-Train', 
                'EvalSetNames': ['In-Test', 'In-Valid', 'Out']
            },
        }
    },

    # ----------------- Input Part -----------------
    'Input_Part': {
        'EntryInputMethod': 'Mto1Period_MultiTknInStep',
        # Case features, one per (field, period) pair listed below.
        'CF_list': [
            'cf.TargetCGM_Bf24H',
            'cf.TargetCGM_Af2H',
            'cf.TimeSparse_Bf24H', 
            'cf.TimeSparse_Af2H',
            'cf.DietSparse_Bf24H',
            'cf.DietSparse_Af2H',
        ],
        'TargetField': 'TargetCGM',
        'TimeField':   'Time',
        'EventFields': [
            'Diet',
        ],
        'BeforePeriods': ['Bf24H'],
        'AfterPeriods': ['Af2H'],
        'InferenceMode': False, # alternatives noted by the author: 'NoFutureEvent', 'WithFutureEvent'
    }, 

    # ----------------- Output Part -----------------
    'Output_Part': {
        'EntryOutputMethod': 'NTP',
    },
}

# Load the stored AIData, apply the entry arguments, and display the result.
aidata = AIData.load_aidata(DATA_AIDATA, OneAIDataName, SPACE)
aidata.update_NameToData_with_OneEntryArgs(OneEntryArgs)
# NOTE(review): this reads `Name_to_DS` while later cells read `Name_to_Data`
# — confirm both attributes exist on AIData.
dataset = aidata.Name_to_DS
dataset



In [None]:
# Take the first split held by the AIData object and preview its case table.
# aidata.Name_to_DsAIData
Name_to_Data = aidata.Name_to_Data  # [split_name]
split_name = list(aidata.Name_to_Data)[0]
Data = Name_to_Data[split_name]
df_case = Data['df_case']
df_case.head()

In [None]:
# Transformed dataset stored for the selected split.
ds_tfm = Data['ds_tfm']
# ds_tfm

# Slice the first `batch_size` rows as a small sample batch and display it.
batch_size = 4
batch = ds_tfm[:batch_size]
batch

# Part 2: Model Init

## Step 1: init_model

In [None]:
from nn.cgmlhm.configuration_cgmlhm import CgmLhmConfig

# Arguments for the model configuration: the model type plus the entry
# arguments and the per-case-feature vocabularies taken from the AIData.
ModelArgs = dict(
    model_type='cgmlhm',
    OneEntryArgs=aidata.OneEntryArgs,
    CF_to_CFvocab=aidata.CF_to_CFvocab,
)

config = CgmLhmConfig(**ModelArgs)
# print(config)
config.field_to_fieldinfo


In [None]:
# NOTE(review): the class is spelled `GgmLhm...` while the config class is
# `CgmLhm...` — confirm this is the name actually exported by the module.
from nn.cgmlhm.modeling_cgmlhm import GgmLhmLMHeadModel

# Build the model from the configuration and display it.
model = GgmLhmLMHeadModel(config)
model

# Part 3: Forward

In [None]:
import numpy as np 
import torch 

# Number of data points in the demo batch used by the cells that follow.
batch2dp = 8
# Select the first `batch2dp` rows, then slice the selection to materialise
# them as one batch.
batch = ds_tfm.select(range(batch2dp))[:batch2dp]

In [None]:
# Forward pass on the demo batch; display the loss carried by the output.
output = model(**batch)
output.loss

In [None]:
# Materialise whatever iterating over the model output yields.
list(output)

In [None]:
# `output.past_key_values` unpacks into two caches: one for the LSM part and
# one for the fusor part.
past_key_values_lsm, past_key_values_fusor = output.past_key_values# [0][0].shape
print(past_key_values_lsm[0][0].shape)
print(len(past_key_values_lsm), len(past_key_values_lsm[0]))

# past_key_values_fusor could be None, so only inspect it when it is present.
if past_key_values_fusor is not None:
    print(past_key_values_fusor[0][0].shape)
    print(len(past_key_values_fusor), len(past_key_values_fusor[0]))
else:
    print('past_key_values_fusor is None')

In [None]:
# Names of the entries available on the model output.
list(output.keys())

# Part 4: Inference

In [None]:
# Move every tensor of the batch to the model's device, then report the
# device and shape of the tensor that is now stored in the batch (not of the
# pre-move tensor).
for k, v in batch.items():
    moved = v.to(model.device)
    batch[k] = moved
    print(k, moved.device, moved.shape)

## 1. NTP 

In [None]:
###############################
num_old_tokens = 289

items_list = [
    'losses_each_seq',
    'losses_each_token',
    'predicted_ntp_labels',
]
###############################

# Truncate every tensor of the batch to the first `num_old_tokens` positions.
batch_ntp = {k: v[:, :num_old_tokens] for k, v in batch.items()}

for k, v in batch_ntp.items():
    print(k, v.shape)

# Forward pass on the truncated batch.
output = model(**batch_ntp)
logits = output.logits

# Per-token loss: the logits at position t are scored against the label at t+1.
labels = batch['labels'][:, :num_old_tokens]
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# CrossEntropyLoss expects the class axis second, hence the permute.
loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
logits_permuted = shift_logits.permute(0, 2, 1)
losses = loss_fn(logits_permuted, shift_labels)

batch_ntp_output = {}

if 'losses_each_seq' in items_list:
    # One averaged loss value per sequence.
    losses_each_seq = losses.mean(dim=1).detach().cpu().numpy().tolist()
    batch_ntp_output['losses_each_seq'] = losses_each_seq

if 'losses_each_token' in items_list:
    # One array of per-token losses per sequence.
    losses_each_token = list(losses.detach().cpu().numpy())
    batch_ntp_output['losses_each_token'] = losses_each_token

if 'predicted_ntp_labels' in items_list:
    # Highest-scoring token at every position, one array per sequence.
    predicted_ntp_labels = list(torch.argmax(logits, dim=-1).detach().cpu().numpy())
    batch_ntp_output['predicted_ntp_labels'] = predicted_ntp_labels

df_ouput = pd.DataFrame(batch_ntp_output)
df_ouput

## 2. Gen

In [None]:
from transformers import GenerationConfig

###############################
items_list = ['hist', 'real', 'pred', 'logits']
num_old_tokens = 289
max_new_tokens = 24
do_sample = False
with_future_events = False # True
###############################


# Options later turned into a GenerationConfig.
HF_GenerationConfig = {
    'max_new_tokens': max_new_tokens,
    'do_sample': do_sample,
    'return_dict_in_generate': True,
}
if 'logits' in items_list:
    HF_GenerationConfig['output_scores'] = True

# Regular tensors are truncated to the history window ...
batch_gen = {k: v[:, :num_old_tokens] for k, v in batch.items() if '--' not in k}

# ... while tensors whose key contains '--' are kept at full length.
batch_gen_field = {k: v for k, v in batch.items() if '--' in k}
if not with_future_events:
    for k, v in batch_gen_field.items():
        if 'event_indicators' in k:
            # Work on a copy: `batch` is reused by later cells (including the
            # with-future-events generation), so zeroing it in place would
            # silently erase the future events there as well.
            v = v.clone()
            v[:, num_old_tokens:] = 0
            batch_gen_field[k] = v
batch_gen.update(batch_gen_field)

for k, v in batch_gen.items():
    print(k, v.shape)

In [None]:
# Display the history length currently in use.
num_old_tokens

In [None]:
# Sum of the Diet event indicators located after the history window.
batch_gen['Diet--event_indicators'][:, num_old_tokens:].sum()

In [None]:
# Turn the HF_GenerationConfig options into a GenerationConfig and generate
# from the prepared `batch_gen` inputs.
generation_config = GenerationConfig(**HF_GenerationConfig)
gen_outputs = model.generate(generation_config = generation_config, 
                              **batch_gen)

In [None]:
# Shape of the sequences returned by generate.
gen_outputs.sequences.shape

In [None]:
# Gather the requested arrays (first axis = sample), then split each array
# into one entry per sample so that every sample becomes one DataFrame row.
batch_gen_output = {}

if 'hist' in items_list:
    # Input ids that were given to the model as history.
    hist = batch_gen['input_ids'].cpu().numpy()
    batch_gen_output['hist'] = hist

if 'real' in items_list:
    # Labels over the window that follows the history.
    real = batch['labels'][:, num_old_tokens: num_old_tokens+max_new_tokens].cpu().numpy()
    batch_gen_output['real'] = real

if 'pred' in items_list:
    # The last `max_new_tokens` positions of the generated sequences.
    sequences = gen_outputs['sequences']
    pred = sequences[:, -max_new_tokens:].cpu().numpy()
    batch_gen_output['pred'] = pred

if 'logits' in items_list:
    # `scores` holds one tensor per generated step; stack them and move the
    # step axis behind the first axis of each step tensor.
    logits = gen_outputs['scores']
    per_step = [logit.cpu().numpy() for logit in logits]
    logit_scores = np.array(per_step).transpose(1, 0, 2)
    batch_gen_output['logit_scores'] = logit_scores


batch_gen_output = {k: list(v) for k, v in batch_gen_output.items()}


df_output_gen = pd.DataFrame(batch_gen_output)
df_output_gen

## Step 1: Process_A_Single_Batch

In [None]:
def _zero_future_event_indicators(field_tensors, num_old_tokens):
    """Return a new dict in which event indicators after the history are 0.

    Tensors whose key contains 'event_indicators' are cloned before being
    modified, so the tensors passed in are never written to.
    """
    masked = {}
    for name, tensor in field_tensors.items():
        if 'event_indicators' in name:
            tensor = tensor.clone()
            tensor[:, num_old_tokens:] = 0   # set future events to 0
        masked[name] = tensor
    return masked


def _scores_to_array(scores):
    """Stack the per-step score tensors and put the step axis second.

    Assumes every step tensor is 2-D (presumably [sample, vocab] — confirm
    against the model's `generate`), giving a [sample, step, vocab] array.
    """
    return np.array([step.cpu().numpy() for step in scores]).transpose(1, 0, 2)


def process_a_single_batch(model, batch, InferenceArgs = None):
    """Run next-token-prediction scoring and/or generation on one batch.

    Args:
        model: Object callable as ``model(**inputs)`` (the result must expose
            ``.logits``) and providing ``generate(generation_config=...,
            **inputs)`` whose result can be indexed with 'sequences' and
            'scores'.
        batch: Dict mapping names to tensors indexed as [sample, position].
            'labels' is always read; 'input_ids' is read when 'hist' is
            requested. Entries whose key contains '--' are handed to
            ``generate`` at full length; all others are truncated to the
            history window. The batch is not modified.
        InferenceArgs: Optional dict with optional sections
            'NTP_Args' ('num_old_tokens', 'items_list') and
            'GEN_Args' ('num_old_tokens', 'max_new_tokens', 'do_sample',
            'items_list'). A missing section is skipped.
            NTP items: 'losses_each_seq', 'losses_each_token',
            'predicted_ntp_labels'.
            GEN items: 'hist', 'real', 'pred_wfe', 'logits_wfe', 'pred_nfe',
            'logits_nfe' (wfe: with future events, nfe: without).

    Returns:
        Dict mapping each requested item name to a list with one entry per
        sample, ready to be turned into a DataFrame.

    Note:
        The generation section uses the global name ``GenerationConfig``; it
        has to be imported (``from transformers import GenerationConfig``)
        before this function is called with 'GEN_Args'.
    """
    if InferenceArgs is None:
        InferenceArgs = {}

    # ------------ next-token-prediction part ----------------
    batch_ntp_output = {}
    NTP_Args = InferenceArgs.get('NTP_Args', None)
    if NTP_Args is not None:
        num_old_tokens = NTP_Args['num_old_tokens']
        items_list = NTP_Args['items_list']

        batch_ntp = {k: v[:, :num_old_tokens] for k, v in batch.items()}
        lm_logits = model(**batch_ntp).logits

        # Logits at position t are scored against the label at position t+1.
        labels = batch['labels'][:, :num_old_tokens]
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # CrossEntropyLoss expects the class axis second, hence the permute.
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
        losses = loss_fn(shift_logits.permute(0, 2, 1), shift_labels)

        if 'losses_each_seq' in items_list:
            batch_ntp_output['losses_each_seq'] = (
                losses.mean(dim=1).detach().cpu().numpy().tolist()
            )

        if 'losses_each_token' in items_list:
            batch_ntp_output['losses_each_token'] = list(losses.detach().cpu().numpy())

        if 'predicted_ntp_labels' in items_list:
            predicted = torch.argmax(lm_logits, dim=-1).detach().cpu().numpy()
            batch_ntp_output['predicted_ntp_labels'] = list(predicted)

    # ------------ generation part ----------------
    batch_gen_output = {}
    GEN_Args = InferenceArgs.get('GEN_Args', None)
    if GEN_Args is not None:
        items_list = GEN_Args['items_list']
        num_old_tokens = GEN_Args['num_old_tokens']
        max_new_tokens = GEN_Args['max_new_tokens']
        do_sample = GEN_Args['do_sample']

        HF_GenerationConfig = {
            'max_new_tokens': max_new_tokens,
            'do_sample': do_sample,
            'return_dict_in_generate': True,
        }
        if any('logits' in item for item in items_list):
            HF_GenerationConfig['output_scores'] = True
        generation_config = GenerationConfig(**HF_GenerationConfig)

        batch_gen = {k: v[:, :num_old_tokens] for k, v in batch.items() if '--' not in k}
        event_fields = {k: v for k, v in batch.items() if '--' in k}

        # Generate once per requested variant. A variant is needed when either
        # its predictions or its logits are requested, so asking for logits
        # alone no longer fails on a missing generation result.
        gen_outputs = {}
        for tag in ('wfe', 'nfe'):
            if f'pred_{tag}' not in items_list and f'logits_{tag}' not in items_list:
                continue
            if tag == 'wfe':
                fields = event_fields
            else:
                # Masking works on copies so that `batch` keeps its future
                # events for the caller and for later calls.
                fields = _zero_future_event_indicators(event_fields, num_old_tokens)
            gen_inputs = {**batch_gen, **fields}
            gen_outputs[tag] = model.generate(generation_config = generation_config, **gen_inputs)

        arrays = {}
        if 'hist' in items_list:
            arrays['hist'] = batch_gen['input_ids'].cpu().numpy()

        if 'real' in items_list:
            real = batch['labels'][:, num_old_tokens: num_old_tokens+max_new_tokens]
            arrays['real'] = real.cpu().numpy()

        for tag in ('wfe', 'nfe'):
            if f'pred_{tag}' in items_list:
                sequences = gen_outputs[tag]['sequences']
                arrays[f'pred_{tag}'] = sequences[:, -max_new_tokens:].cpu().numpy()
            if f'logits_{tag}' in items_list:
                arrays[f'logits_{tag}'] = _scores_to_array(gen_outputs[tag]['scores'])

        # One list entry per sample, so each sample becomes one DataFrame row.
        batch_gen_output = {k: list(v) for k, v in arrays.items()}

    batch_output = {**batch_ntp_output, **batch_gen_output}
    return batch_output

In [None]:
# batch_gen['input_ids'].shape
# Arguments for process_a_single_batch: one section for next-token-prediction
# scoring and one for generation. `items_list` names the outputs to return.
InferenceArgs = {
    'NTP_Args': {
        'num_old_tokens': 289, 
        'items_list': ['losses_each_seq', 'losses_each_token', 'predicted_ntp_labels']
    }, 
    'GEN_Args': {
        'num_old_tokens': 289,
        'max_new_tokens': 24,
        'do_sample': False,
        'items_list': ['hist', 'real', 'pred_wfe', 'logits_wfe', 'pred_nfe', 'logits_nfe'], # wfe: with future events, nfe: without future events
    },
}

In [None]:
# Run the combined NTP + generation routine on the demo batch and show the
# result as a table with one row per sample.
batch_output = process_a_single_batch(model, batch, InferenceArgs)

df_batch = pd.DataFrame(batch_output)
df_batch

In [None]:
# First row of the batch results.
rec = df_batch.iloc[0]
rec

In [None]:
# Shape of the with-future-events generation scores of that row.
rec['logits_wfe'].shape

## Step 2: df_case_eval

In [None]:
########################
# Use the first split held by the AIData object.
Split_Name = list(aidata.Name_to_Data)[0]
Data = aidata.Name_to_Data[Split_Name]
########################

# Transformed dataset and case table of that split.
ds_tfm, df_case = Data['ds_tfm'], Data['df_case']
print(ds_tfm)
display(df_case.head())

In [None]:
#################################
max_inference_num = 1000
save_df = False
load_df = False
chunk_size = 12800
batch_size = 16
#################################

# case_id_columns = aidata.case_id_columns

# Use the first split held by the AIData object.
Split_Name = list(aidata.Name_to_Data)[0]
Data = aidata.Name_to_Data[Split_Name]

ds_tfm = Data['ds_tfm']
df_case = Data['df_case']

if max_inference_num is not None:
    # Cap at the number of available rows: selecting indices beyond the end
    # of the dataset is an error when the split is smaller than the cap.
    num_to_keep = min(max_inference_num, len(ds_tfm))
    ds_tfm = ds_tfm.select(range(num_to_keep))
    df_case = df_case.iloc[:num_to_keep]

print(ds_tfm)
print(df_case.shape)
display(df_case.head())

In [None]:
from tqdm import tqdm

###################
# df_case
# ds_tfm
###################

print(model.device)
# Number of chunks needed to cover every case (ceiling division). At least
# one chunk is always produced so that the chunk variables exist for the next
# cell; an exact multiple of `chunk_size` no longer yields an empty trailing
# chunk.
num_cases = len(df_case)
chunk_numbers = max(1, (num_cases + chunk_size - 1) // chunk_size)
print(chunk_numbers)

# NOTE: each iteration overwrites the chunk variables, so after the loop they
# hold the LAST chunk only.
for chunk_id in range(chunk_numbers):
    # chunk_id = 0
    start = chunk_id * chunk_size
    end = min(start + chunk_size, num_cases)
    print(start, end)

    df_case_chunk = df_case.iloc[start:end].reset_index(drop = True)
    ds_tfm_chunk = ds_tfm.select(range(start, end))
    print(ds_tfm_chunk)
    print(df_case_chunk.shape)

In [None]:
# dataset_chunk = ds_tfm.select(range(start, end)) # ds: chunk_size, 1024.

# TODO: update the folder path and file path
# folder = os.path.join(SPACE['MODEL_ROOT'], model_checkpoint_name, task)
# if not os.path.exists(folder): os.makedirs(folder)


# file = os.path.join(folder, f'chunk_{chunk_id:05}_s{start}_e{end}.p')

# if load_df == True and os.path.exists(file):
#     logger.info(f'Loading chunk {chunk_id} from {file}')
#     inference_results_list.append(file)
#     continue

# Switch to evaluation mode once instead of once per batch.
model.eval()

# Collect one result frame per batch and concatenate a single time at the
# end; growing a DataFrame with concat inside the loop re-copies all earlier
# rows on every iteration.
batch_frames = []
for batch_s in tqdm(range(0, len(ds_tfm_chunk), batch_size)):
    batch_e = min(batch_s + batch_size, len(ds_tfm_chunk))
    batch = ds_tfm_chunk[batch_s: batch_e]
    for k, v in batch.items():
        batch[k] = v.to(model.device)
    with torch.no_grad():
        output = process_a_single_batch(model, batch, InferenceArgs)

    batch_frames.append(pd.DataFrame(output))

if batch_frames:
    df_eval_chunk = pd.concat(batch_frames, axis = 0).reset_index(drop=True)
else:
    # Empty chunk: pd.concat refuses an empty list.
    df_eval_chunk = pd.DataFrame()

# Put the case columns and the inference results side by side (both frames
# carry a fresh 0..n-1 index).
df_chunk = pd.concat([df_case_chunk, df_eval_chunk], axis = 1)

df_chunk

# Part 5: Evaluation

In [None]:
# Evaluate on the chunk table built above (case columns + inference results).
df_case_eval = df_chunk
df_case_eval.head()

In [None]:
# First evaluated case, used as the example for the single-point evaluation.
case = df_case_eval.iloc[0]
case

In [None]:
from nn.eval.seqeval import SeqEvalForOneDataPoint

# Observed history, real future values, and the with-future-events prediction
# of the example case.
x_obs_seq = case['hist']
y_real_seq = case['real']
y_pred_seq = case['pred_wfe']
# Spelled `metric_list` (was `etric_list`) so the name matches the one the
# later evaluation cells read.
metric_list = ['rMSE', 'MAE']

print(len(x_obs_seq), len(y_real_seq), len(y_pred_seq))

eval_dp = SeqEvalForOneDataPoint(x_obs_seq, y_real_seq, y_pred_seq, metric_list)
print(eval_dp)
print(eval_dp.get_metric_scores())
eval_dp.plot_cgm_sensor()

In [None]:
from nn.eval.seqeval import SeqEvalForOneDataPointWithHorizons


x_obs_seq_total = case['hist']
y_real_seq_total = case['real']
y_pred_seq_total = case['pred_wfe']
metric_list = ['rMSE', 'MAE']

# Horizon label -> [start, end] step indices into the predicted window.
# Labels follow 5-minute steps (the AIData name contains '5Min', and the
# first two labels already used 6 steps = 30 min, 12 steps = 60 min). The
# index ranges are unchanged; only the three labels that overstated their
# span were corrected: 18 steps = 90 min, 24 steps = 120 min.
# NOTE(review): if 0-120 / 60-120 / 0-180 min windows were really intended,
# the indices (and the number of generated tokens) must change instead.
horizon_to_se = {
    '000-030min': [0, 6],
    '000-060min': [0, 12],
    '000-090min': [0, 18],
    '000-120min': [0, 24],
    '030-090min': [6, 18],
}

eval_dp = SeqEvalForOneDataPointWithHorizons(x_obs_seq_total,
                                             y_real_seq_total,
                                             y_pred_seq_total,
                                             metric_list,
                                             horizon_to_se)
eval_dp.get_complete_metrics_with_horizon()



In [None]:
from nn.eval.seqeval import SeqEvalForOneEvalSet

# Name of the evaluated set and of the columns that hold the history, the
# real values, and the prediction to score.
setname = 'test'
x_hist_seq_name, y_real_seq_name, y_pred_seq_name = 'hist', 'real', 'pred_wfe'

# Metrics and horizons come from the earlier evaluation cells.
eval_kwargs = dict(
    setname=setname,
    df_case_eval=df_case_eval,
    x_hist_seq_name=x_hist_seq_name,
    y_real_seq_name=y_real_seq_name,
    y_pred_seq_name=y_pred_seq_name,
    metric_list=metric_list,
    horizon_to_se=horizon_to_se,
)
eval_instance = SeqEvalForOneEvalSet(**eval_kwargs)

eval_results = eval_instance.get_evaluation_report()
eval_results