# Space

In [None]:
import os
import logging
import pandas as pd 
from pprint import pprint 
from IPython.display import display, HTML
pd.set_option('display.max_columns', None)
# Notebook bootstrap: locate the workspace root, make project code importable,
# configure logging, and pin the visible GPU.

# The workspace root is the part of the current directory up to and including
# the first occurrence of KEY. If KEY is absent from the cwd, the result is
# simply cwd + KEY and the chdir below raises.
KEY = 'WorkSpace'
WORKSPACE_PATH = os.getcwd().split(KEY)[0] + KEY
# print(WORKSPACE_PATH)
os.chdir(WORKSPACE_PATH)
import sys
# Imported only after the chdir above — presumably `proj_space` lives in the
# workspace root and is resolved through the cwd entry of sys.path (TODO confirm).
from proj_space import SPACE
# Make the project code directory importable for the later `recfldtkn` / `nn` imports.
sys.path.append(SPACE['CODE_FN'])
SPACE['WORKSPACE_PATH'] = WORKSPACE_PATH
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='[%(levelname)s:%(asctime)s:(%(filename)s@%(lineno)d %(name)s)]: %(message)s')

# Turn off the `datasets` library's caching for this session.
from datasets import disable_caching
disable_caching()

SPACE['MODEL_ENDPOINT'] = 'vTest'

# Expose only CUDA device 0 to this process. This is also repeated in a later cell.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# Part 1: AIData

In [None]:


from recfldtkn.base import apply_multiple_conditions
import datasets 
import numpy as np

#### ----------- part 0: load dataset -----------
##########################################################################################################################


#############################################################
AIDataName = 'EventFood2CGM_bf5min_WellDoc_v2_v0323'
#############################################################

# Load the saved AIData dataset and report how many distinct PIDs it holds.
path = os.path.join(SPACE['DATA_AIDATA'], AIDataName)
dataset = datasets.load_from_disk(path)
PID_with_food_full = list(set(dataset['PID']))
print(len(PID_with_food_full)) # 654

# Columns without a '--' in their name are treated as tag columns; only those
# are converted to pandas and used for rule-based row selection below.
columns = dataset.column_names
columns_tag = [column_name for column_name in columns if '--' not in column_name]
df_tag = dataset.select_columns(columns_tag).to_pandas()

# One selection per (split, disease type) pair, generated in the order
# t1d/t2d within Train, then Valid, then Test. Every selection AND-combines
# a disease-type rule with a split rule. Extra rules on other tag columns
# (e.g. 'PID', 'GenderGroup', 'AgeGroup') use the same
# [column, operator, value] form and can be added to 'Rules'.
Split_to_Selection_Food = {
    f'eval_food_t{disease_no}d_{split_label.lower()}': {
        'Rules': [
            ['DiseaseTypeGroup', '==', f'DiseaseType.{disease_no}.0'],
            ['Split', '==', split_label],
        ],
        'Op': 'and',
    }
    for split_label in ('Train', 'Valid', 'Test')
    for disease_no in (1, 2)
}

# For each selection, evaluate its rules on the tag frame and keep the dataset
# rows whose resulting flag equals 1.
split_to_dataset = {
    split_name: dataset.select(
        np.where(
            apply_multiple_conditions(df_tag, Selection['Rules'], Selection['Op']) == 1
        )[0]
    )
    for split_name, Selection in Split_to_Selection_Food.items()
}

split_to_dataset_food = split_to_dataset
# print(split_to_dataset_food)



In [None]:
#### ----------- part: define entry arguments -----------
# Entry arguments describing how one case is turned into model input/output.
# The dict is passed unchanged into ModelArgs['OneEntryArgs'] when the model
# config is built later in this notebook.
OneEntryArgs = {
     # ----------------- Input Part -----------------
    'Input_Part': {
        'EntryInputMethod': 'Mto1Period_MultiTknInStepNoWgt',
        # Case features fed to the model; the names here combine TargetField
        # with one of the Before/After period labels listed below.
        'CF_list': [
            'cf.TargetCGM_Bf24H',
            'cf.TargetCGM_Af2H',
            # 'cf.TargetCGM_Af2Hto8H',
        ],
        'TargetField': 'TargetCGM',
        'BeforePeriods': ['Bf24H'],
        'AfterPeriods': ['Af2H'],
        # False here; the string modes noted by the author are
        # 'NoFutureEvent' and 'WithFutureEvent'.
        'InferenceMode': False, # 'WithFutureEvent' #  # 'NoFutureEvent', 'WithFutureEvent', 
    }, 

    # ----------------- Output Part -----------------
    'Output_Part': {
        'EntryOutputMethod': 'CausalLM',
        'set_transform': True,
        'num_proc': 4, 
    },
}


In [None]:
# df_case

In [None]:
# ds_tfm = Data['ds_tfm']
# # ds_tfm

# batch_size = 4
# batch = ds_tfm[:batch_size]
# for k, v in batch.items(): print(k, v.shape)
# batch

# Part 2: Model Init

## Step 1: init_model

In [None]:
from nn.cgmlhm.configuration_cgmlhm import CgmLhmConfig 

# The AIData settings mapping is read from the `config_name` slot of the
# dataset info. `config` is a temporary name for it here and is rebound to
# the model config at the end of this cell.
config = vars(dataset.info)['config_name']
CF_to_CFvocab = config['CF_to_CFvocab']
CF_to_CFArgs = config['CaseSettingInfo']['Case_Args_Settings']['CF_to_CFArgs']
TriggerCaseBaseName = config['TriggerCaseBaseName']
TriggerCaseBaseArgs = config['TriggerCaseBaseName_to_TriggerCaseBaseArgs'][TriggerCaseBaseName]
TriggerName = TriggerCaseBaseArgs['Trigger']['TriggerName']

# Keyword arguments for CgmLhmConfig: the entry arguments, the per-CF
# vocabularies, and the layer counts of the model's sub-modules.
ModelArgs = dict(
    model_type='cgmlhm',
    OneEntryArgs=OneEntryArgs,
    CF_to_CFvocab=CF_to_CFvocab,
    fe_num_hidden_layers=6,
    sc_num_hidden_layers=0,
    tf_n_layer=0,
)

config = CgmLhmConfig(**ModelArgs)
# Displayed as the cell output.
config.field_to_fieldinfo

In [None]:
from nn.cgmlhm.modeling_cgmlhm import GgmLhmLMHeadModel

model = GgmLhmLMHeadModel(config)
model

In [None]:
batch

In [None]:
for layer_name, params in model.named_parameters():
    print(layer_name, params.shape)

# Part 3: Forward

In [None]:
import numpy as np 
import torch 

# Take the first `batch2dp` rows as one batch for a manual forward pass.
# NOTE(review): `ds_tfm` is not assigned in any visible cell; its only
# assignment (`ds_tfm = Data['ds_tfm']`) is commented out earlier in the
# notebook, so it must be defined before this cell runs.
batch2dp = 8
batch = ds_tfm.select(range(batch2dp))[:batch2dp]

In [None]:
output = model(**batch)
output.loss

In [None]:
past_key_values_lsm, past_key_values_fusor = output.past_key_values# [0][0].shape
print(past_key_values_lsm[0][0].shape)
print(len(past_key_values_lsm), len(past_key_values_lsm[0]))

# past_key_values_fusor could be None
if past_key_values_fusor is not None:   
    print(past_key_values_fusor[0][0].shape)
    print(len(past_key_values_fusor), len(past_key_values_fusor[0]))

# Part 4: Train

In [None]:
# aidata.TrainEvalsInTrain

In [None]:
# aidata.Name_to_DsAIData
###############################
# NOTE(review): `aidata` is not defined in any visible cell of this notebook;
# it must exist before this cell runs.
TrainSetName = aidata.TrainEvals['TrainSetName_InTrain']
EvalSetNames = aidata.TrainEvals['EvalSetNames_InTrain']
# Upper bounds on the number of examples used; None disables sub-sampling.
max_train_samples = 1000
max_eval_samples = 64
###############################


# ------------ train datasets ------------
TrainData = aidata.Name_to_Data[TrainSetName]
ds_tfm_train = TrainData['ds_tfm']
if max_train_samples is not None:
    # Keep the effective size in its own name so the configured cap is not
    # overwritten by the size of this particular dataset.
    num_train_samples = min(len(ds_tfm_train), max_train_samples)
    ds_tfm_train = ds_tfm_train.shuffle(seed=42).select(range(num_train_samples))
logger.info(ds_tfm_train)


# ------------ eval datasets ------------
# Maps eval-set name -> (optionally sub-sampled) dataset. Names missing from
# aidata.Name_to_Data are logged and skipped.
eval_dataset_dict = {}
for evalname in EvalSetNames:
    if evalname not in aidata.Name_to_Data:
        logger.info(f'{evalname} not in aidata.Name_to_Data')
        continue
    eval_dataset = aidata.Name_to_Data[evalname]['ds_tfm']
    if max_eval_samples is not None:
        # The cap is applied per dataset. Writing the result back into
        # `max_eval_samples` would let one small eval set shrink the limit
        # for every eval set that follows it.
        num_eval_samples = min(len(eval_dataset), max_eval_samples)
        eval_dataset = eval_dataset.shuffle(seed=42).select(range(num_eval_samples))
    eval_dataset_dict[evalname] = eval_dataset
logger.info('---- eval_datasets ----')
logger.info(eval_dataset_dict)


In [None]:
print(len(ds_tfm_train))
for k, v in eval_dataset_dict.items():
    print(k, len(v))    

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
from transformers import Trainer, TrainingArguments, TrainerCallback


#################################
# Keyword arguments for transformers.TrainingArguments.
# Each key appears exactly once: a dict literal silently keeps only the last
# value of a repeated key, so the checkpoint settings live in one place.
HuggingFaceTrainingArgs = {
    'output_dir': '_test',  # will be updated to model_instance.model_checkpoint_path
    'overwrite_output_dir': False,

    # ------- training -------
    'do_train': True,
    'num_train_epochs': 10,
    'per_device_train_batch_size': 4, # 64, # 4, # 64
    'per_device_eval_batch_size': 4, # 64, # 4, # 64
    'gradient_accumulation_steps': 4,

    # ------- logging / evaluation -------
    'logging_steps': 1,
    'do_eval': True,
    'eval_steps': 100,
    'eval_strategy': 'steps',
    'report_to': 'wandb',

    # ------- checkpointing -------
    # Save every 1000 steps and keep at most 3 checkpoints.
    'save_strategy': 'steps',
    'save_steps': 1000,
    'save_total_limit': 3,

    # ------- do not change these -------
    'remove_unused_columns': False, # <--- must be False.
    'dataloader_drop_last': True,
}
#################################

training_args = TrainingArguments(**HuggingFaceTrainingArgs)
training_args

In [None]:
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    is_torch_tpu_available,
    set_seed,
)

print(training_args.seed)
set_seed(training_args.seed)

In [None]:
from datetime import datetime
from datasets.fingerprint import Hasher 

###################
# Number of trailing token positions scored separately by the metrics
# function (it is forwarded through the Trainer's compute_metrics lambda).
AfTknNum = 24
###################

# Experiment id = hour-resolution timestamp + "-" + a fingerprint hash of the
# AIData arguments and the model config.
timestamp = datetime.now().strftime("%Y%m%d-%H")
experiment_id = "-".join([timestamp, Hasher().hash([aidata.OneAIDataArgs, config])])

print(experiment_id)

In [None]:
class TimestampCallback(TrainerCallback):
    """Trainer callback that stamps each log record with step and wall-clock time."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Add ``step`` and ``timestamp`` entries to ``logs`` in place.

        Args:
            args, state, control: Standard callback arguments; only
                ``state.global_step`` is read.
            logs: Mutable mapping of values being logged. When it is ``None``
                (the declared default) there is nothing to annotate.
        """
        if logs is None:
            # Item assignment on None would raise TypeError.
            return
        logs["step"] = state.global_step
        logs["timestamp"] = str(datetime.now())

In [None]:
import evaluate

def compute_metrics_for_ntp(eval_preds, experiment_id, AfTknNum = 24):
    """Score next-token predictions over all positions and over the last positions.

    Args:
        eval_preds: ``(preds, labels)`` pair. ``preds`` are expected to be
            token ids already reduced by ``argmax`` (see
            ``preprocess_logits_for_metrics``), with the same shape as ``labels``.
        experiment_id: Passed to ``evaluate.load`` for both metrics.
        AfTknNum: Number of trailing positions (after shifting) scored
            separately as the "af" variant.

    Returns:
        dict with ``rMSE`` / ``rMSEaf`` (square roots of the two MSE values)
        and ``ACUU`` / ``ACUUaf`` (the two accuracy values).
    """
    accuracy_metric = evaluate.load("accuracy", experiment_id = experiment_id)
    mse_metric = evaluate.load('mse', experiment_id = experiment_id)

    predictions, references = eval_preds
    # Next-token alignment: the prediction at position t is compared with the
    # label at position t + 1, so drop the first label and the last prediction.
    references = references[:, 1:]
    predictions = predictions[:, :-1]

    flat_refs = references.reshape(-1)
    flat_preds = predictions.reshape(-1)

    tail_refs = references[:, -AfTknNum:].reshape(-1)
    tail_preds = predictions[:, -AfTknNum:].reshape(-1)

    acc_all = accuracy_metric.compute(predictions=flat_preds, references=flat_refs)
    mse_all = mse_metric.compute(predictions=flat_preds, references=flat_refs)
    acc_tail = accuracy_metric.compute(predictions=tail_preds, references=tail_refs)
    mse_tail = mse_metric.compute(predictions=tail_preds, references=tail_refs)

    # Merge the four result dicts; the tail variants get an '_af' suffix.
    scores = {
        **acc_all,
        **{name + '_af': value for name, value in acc_tail.items()},
        **mse_all,
        **{name + '_af': value for name, value in mse_tail.items()},
    }

    scores['rMSE'] = np.sqrt(scores['mse'])
    scores['rMSEaf'] = np.sqrt(scores['mse_af'])
    scores['ACUU'] = scores['accuracy']
    scores['ACUUaf'] = scores['accuracy_af']

    # Only the derived entries are reported; drop the raw ones.
    for raw_name in ('mse', 'mse_af', 'accuracy', 'accuracy_af'):
        del scores[raw_name]
    return scores

In [None]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted ids before metrics are computed.

    Args:
        logits: Either the logits tensor or a tuple whose first element is the
            logits tensor (extra outputs such as past_key_values may follow).
        labels: Unused; present because the Trainer passes it.

    Returns:
        The argmax over the last dimension of the logits.
    """
    primary = logits[0] if isinstance(logits, tuple) else logits
    return primary.argmax(dim=-1)

In [None]:
from torch.profiler import profile, ProfilerActivity, schedule as profiler_schedule

class CorrectProfilerCallback(TrainerCallback):
    """Profile a short window of training steps with ``torch.profiler``.

    A profiler is created when training begins, started at the first training
    step and advanced once per completed step. The wait -> warmup -> active
    phases are handled entirely by the profiler schedule, so a run of
    ``wait + warmup + active`` steps covers exactly one full cycle.

    Args:
        wait: Number of initial steps the schedule skips.
        warmup: Number of steps traced but discarded.
        active: Number of steps recorded.
    """

    def __init__(self, wait=1, warmup=1, active=3):
        self.wait_steps = wait
        self.warmup_steps = warmup
        self.active_steps = active
        self.profiler = None
        self.step_count = 0  # number of profiler.step() calls made so far
        self.profiling_active = False  # True between profiler.start() and stop()

    def on_train_begin(self, args, state, control, **kwargs):
        """Build a fresh profiler (CUDA activity only) for this training run."""
        self.profiler = profile(
            activities=[ProfilerActivity.CUDA],
            schedule=profiler_schedule(
                wait=self.wait_steps,
                warmup=self.warmup_steps,
                active=self.active_steps,
                repeat=1
            ),
            on_trace_ready=self._on_trace_ready,
            record_shapes=True
        )

    def _on_trace_ready(self, prof):
        """Print the aggregated profile table once a trace is available."""
        print(prof.key_averages().table(sort_by="cuda_time_total"))

    def on_step_begin(self, args, state, control, **kwargs):
        """Start the profiler at the first training step of the run.

        The schedule already implements the ``wait`` phase. Delaying
        ``start()`` by ``wait`` steps as well would skip those steps twice
        and push the active window past the intended step budget.
        """
        if not self.profiling_active:
            self.profiler.start()
            self.profiling_active = True

    def on_step_end(self, args, state, control, **kwargs):
        """Advance the profiler schedule by one step."""
        if self.profiling_active:
            self.profiler.step()
            self.step_count += 1

    def on_train_end(self, args, state, control, **kwargs):
        """Stop the profiler if it was started."""
        if self.profiling_active:
            self.profiler.stop()
            self.profiling_active = False

In [None]:
# Assemble the Hugging Face Trainer from the objects built in earlier cells.
# NOTE(review): TimestampCallback is defined in this notebook but is not
# registered in `callbacks` below; only the profiler callback is active.
trainer = Trainer(
    # Model instance created in the "Model Init" part.
    model = model,
    # TrainingArguments built from HuggingFaceTrainingArgs.
    args = training_args,
    # (Optionally sub-sampled) training dataset.
    train_dataset = ds_tfm_train, # if training_args.do_train else None,
    # Dict of eval-set name -> dataset, used for in-training evaluation.
    eval_dataset = eval_dataset_dict, # <--- for in-training evaluation
    # No tokenizer is passed (author's note: hard to save).
    # tokenizer = tokenizer, # Apr 2024: don't add tokenizer, hard to save.
    # Batches are built with the library's default collator.
    data_collator = default_data_collator,
    # The lambda reads the module-level `experiment_id` and `AfTknNum` each time it is called.
    compute_metrics = lambda x: compute_metrics_for_ntp(x, experiment_id, AfTknNum),
    # Reduces logits to argmax ids before they reach compute_metrics.
    preprocess_logits_for_metrics = preprocess_logits_for_metrics,
    callbacks = [CorrectProfilerCallback(wait=1, warmup=1, active=3)],
)

logger.info(trainer)

In [None]:
# Run training (for exactly 5 steps: wait=1 + warmup=1 + active=3)
training_args.max_steps = 5
trainer.train()

In [None]:
torch.cuda.synchronize()  # Before starting profiling

# Final Train

In [None]:
len(ds_tfm_train)

In [None]:
training_args.output_dir

In [None]:
from transformers.trainer_utils import get_last_checkpoint

def prepare_last_checkpoint(training_args):
    """Return the latest checkpoint found in ``training_args.output_dir``, if any.

    A lookup only happens when the output directory exists, ``do_train`` is
    set and ``overwrite_output_dir`` is not set; otherwise ``None`` is returned.

    Args:
        training_args: Object exposing ``output_dir``, ``do_train``,
            ``overwrite_output_dir`` and ``resume_from_checkpoint``.

    Returns:
        The value returned by ``get_last_checkpoint`` for the output
        directory, or ``None`` when no lookup was performed or nothing was found.

    Raises:
        ValueError: If the output directory is non-empty but contains no
            checkpoint (and overwriting was not requested).
    """
    resumable = (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    )
    if not resumable:
        return None

    last_checkpoint = get_last_checkpoint(training_args.output_dir)

    if last_checkpoint is None:
        # Existing files without a checkpoint: refuse to mix runs in one directory.
        if len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
    elif training_args.resume_from_checkpoint is None:
        # Only announce the auto-detected checkpoint when the caller did not
        # explicitly name one to resume from.
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. "
            "To avoid this behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )

    return last_checkpoint

In [None]:
checkpoint = prepare_last_checkpoint(training_args)
print(checkpoint)

In [None]:
for batch in trainer.get_train_dataloader():
    print(f"Batch shape: {batch['input_ids'].shape}")
    break  # Just check the first batch

In [None]:
5466579 / 64 / 4 / 5

In [None]:
train_result = trainer.train(resume_from_checkpoint = checkpoint)