# Space: Workspace and Environment Setup

In [None]:
# Environment setup: locate the workspace root, make the project code
# importable, configure logging, and disable HF `datasets` caching.
import os
import sys
import logging
import pandas as pd
from pprint import pprint
from IPython.display import display, HTML
pd.set_option('display.max_columns', None)

# The workspace root is the cwd prefix up to and including the first
# occurrence of KEY, so the notebook may be started anywhere below it.
KEY = 'WorkSpace'
if KEY not in os.getcwd():
    # Without this guard `split` returns the whole cwd and `+ KEY` fabricates
    # a path that does not exist (or could point at an unrelated directory).
    raise FileNotFoundError(
        f"Current working directory {os.getcwd()!r} is not inside a '{KEY}' "
        "directory; start the notebook from within the workspace."
    )
WORKSPACE_PATH = os.getcwd().split(KEY)[0] + KEY
os.chdir(WORKSPACE_PATH)

# `proj_space` is presumably found via the (new) working directory, so this
# import must stay after the chdir above.
from proj_space import SPACE
sys.path.append(SPACE['CODE_FN'])
SPACE['WORKSPACE_PATH'] = WORKSPACE_PATH
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='[%(levelname)s:%(asctime)s:(%(filename)s@%(lineno)d %(name)s)]: %(message)s')

from datasets import disable_caching
disable_caching()

SPACE['MODEL_ENDPOINT'] = 'vTest'

# Part 1: AIData

## Step 1: AIData Args

In [None]:

###################################
# Records/features to load. Presumably: top-level key = human/entity name,
# inner keys = record names, values = record-feature (recfeat) names to build.
# Commented-out entries are alternative records that can be switched on.
HumanRecordRecfeat_Args = {
    'P': {
        # 'BP': [],
        'CGM5Min': ['CGM5Min-N2Cin1'],
        # 'Carb': ['Carb-N2Cin20'],
        # 'Exercise': ['Exercise-Nume'],
        # 'Food': ['Food-NutriNume'],
        'P': ['P-DemoCate'],
        # 'Sleep': ['Sleep-Nume'],
        # 'Step': ['Step-Nume'],
        # 'Weight': ['Weight-Nume'],
        # 'PHeight': [],
    }
}

# Cohorts whose records are loaded; uncomment to include more cohorts.
CohortName_list = [
    # 'WellDoc2022CGM',
    # 'WellDoc2023CVSTDC',
    'WellDoc2023CVSDeRx',
]
# Passed to Record_Base. NOTE(review): 'ds' presumably selects the HF
# `datasets` backend — confirm against Record_Base.
Record_Proc_Config = {'save_data': True, 'load_data': True, 'via_method': 'ds'}
Inference_Entry = None  # None: this is not inference mode
###################################

In [None]:
TriggerCaseBaseName = 'Bf24HAf2H_CGM'

# Trigger-case definition. Judging by the names: cases are triggered on CGM
# 5-minute entries and observed 24h before / 2h after the trigger point.
TriggerCaseBaseArgs = dict(
    # Trigger stage (these settings are relatively stable between experiments).
    Trigger=dict(
        TriggerName='CGM5MinEntry',
        TagRec=['TagRec.PDemoFromP'],
        Group='GrpGenderDisease',
        Filter='FltBasicDemo',
        ObsTask=dict(
            TagCF_list=['TagCF.Bf24hCGMinfo', 'TagCF.Af2hCGMinfo'],
            CF_list=[],
        ),
    ),
    # Case-set filtering stage: same tag CFs, plus the target CGM CFs.
    FilterCaseSet=dict(
        Filter='FltMiniBfAfCGMRecInfo',
        ObsTask=dict(
            TagCF_list=['TagCF.Bf24hCGMinfo', 'TagCF.Af2hCGMinfo'],
            CF_list=['cf.TargetCGM_Bf24H', 'cf.TargetCGM_Af2H'],
        ),
    ),
)


# Processing options for Case_Base.
Case_Proc_Config = dict(
    max_trigger_case_num=None,
    use_task_cache=False,
    caseset_chunk_size=200000,  # 200k for CGM, 50k for others.
    save_data=True,
    load_data=True,
    load_casecollection=True,
    via_method='ds',
    n_cpus=1,
    batch_size=1000,
)

In [None]:
###################################
# Name of the model unit these AIData entries are built for.
ModelUnitName = 'CGMLSMBf24Af2H-5Min'

# Template for one AIData entry, in three parts: Task (tagging / filtering /
# splitting of cases), Input (how CFs become model inputs) and Output.
OneEntryArgsTemplate = dict(
    # ----------------- Task Part -----------------
    Task_Part=dict(
        TaskCFs_Args=[],
        Tagging=dict(
            # Form: {TagName: {'Rules': [(column, op, value), ...], 'Op': 'and' | 'or'}}
            TagName_to_TaggingMethod={},
            ColumnsAddToDsCase=[],
            TagFilter=True,
            TagSplit=True,
        ),
        Filtering=dict(
            FilterTagging=None,  # no case filter in the template
        ),
        Splitting=dict(
            # NOTE(review): 'tail0.1' presumably holds out the last 10% as the test split — confirm.
            SplitTagging=dict(RANDOM_SEED=32, out_ratio=0.1, test_ratio='tail0.1', valid_ratio=0.1),
            TrainEvals=dict(TrainSetName='In-Train', EvalSetNames=['In-Test', 'In-Valid', 'Out']),
        ),
    ),
    # ----------------- Input Part -----------------
    Input_Part=dict(
        EntryInputMethod='Mto1Period_1TknInStep',
        InputCFs_Args=['cf.TargetCGM_Bf24H', 'cf.TargetCGM_Af2H'],
    ),
    # ----------------- Output Part -----------------
    Output_Part=dict(
        EntryOutputMethod='NTP',
        OutputCFs_Args=[],
    ),
)


# Sub-AIData variants of the template. Each value presumably overrides the
# template entry addressed by a colon-separated key path.
SubAIDataName_to_Args = {
    # Example filtered variant (random 10% downsample):
    # 'RandomDownSample0.1': {
    #     'Task_Part:Filtering:FilterTagging': {
    #         "Rules": [('RandDownSample', '<=', 0.1)],
    #         'Op': 'and',
    #     },
    # },
    'FullDataNoFiltering': {
        'Task_Part:Filtering:FilterTagging': None,
    },
}

## Step 2: AIData Base

In [None]:
from config.config_record.Cohort import CohortName_to_OneCohortArgs
from config.config_case.CKPD import Ckpd_to_CkpdObsConfig
from recfldtkn.record_base import Record_Base

# Record base over the selected cohorts and record/recfeat configuration.
record_base = Record_Base(CohortName_list, HumanRecordRecfeat_Args,
                          CohortName_to_OneCohortArgs,
                          SPACE=SPACE,
                          Inference_Entry=Inference_Entry,
                          Record_Proc_Config=Record_Proc_Config)

In [None]:
from config.config_case.GROUP import GROUP_TO_GROUPMethodArgs
from config.config_case.CF import CF_to_CFArgs
from config.config_case.CKPD import Ckpd_to_CkpdObsConfig
from config.config_case.TagRec import TagRec_to_TagRecArgs
from config.config_case.TagCF import TagCF_to_TagCFArgs 
from config.config_case.Flt import FltName_to_FltArgs
from config.config_case.CASE import TriggerCaseBaseName_to_TriggerCaseBaseArgs
from recfldtkn.case_base.case_base import Case_Base

# Lookup tables that presumably resolve the names used in TriggerCaseBaseArgs
# (Ckpd / CF / TagCF / TagRec / Flt / Group names) to their definitions.
Case_Args_Settings = {
    'Ckpd_to_CkpdObsConfig': Ckpd_to_CkpdObsConfig,
    'CF_to_CFArgs': CF_to_CFArgs,
    'TagCF_to_TagCFArgs': TagCF_to_TagCFArgs,
    'TagRec_to_TagRecArgs': TagRec_to_TagRecArgs,
    'FltName_to_FltArgs': FltName_to_FltArgs,
    'GROUP_TO_GROUPMethodArgs': GROUP_TO_GROUPMethodArgs,
}

# Which cohorts this trigger case runs over (previously assigned twice).
TriggerCaseBaseName_to_CohortNameList = {TriggerCaseBaseName: CohortName_list}
# Register this notebook's trigger case under its name. NOTE: this mutates the
# dict imported from config.config_case.CASE (shared module state).
TriggerCaseBaseName_to_TriggerCaseBaseArgs[TriggerCaseBaseName] = TriggerCaseBaseArgs

# Observed runtime: ~2 min with 1 CPU, ~1m40s with 8 CPUs (n_cpus in Case_Proc_Config).
case_base = Case_Base(
    record_base = record_base, 
    TriggerCaseBaseName_to_CohortNameList = TriggerCaseBaseName_to_CohortNameList, 
    TriggerCaseBaseName_to_TriggerCaseBaseArgs = TriggerCaseBaseName_to_TriggerCaseBaseArgs,
    Case_Proc_Config = Case_Proc_Config,
    Case_Args_Settings = Case_Args_Settings, 
)

In [None]:
from recfldtkn.aidata_base.aidata_base import get_OneAIDataName_to_OneAIDataArgs, AIData_Base

# Combine the entry template with each SubAIDataName_to_Args variant; the
# result presumably maps each AIData name to its complete args.
OneAIDataName_to_OneAIDataArgs = get_OneAIDataName_to_OneAIDataArgs(
    ModelUnitName,
    CohortName_list,
    TriggerCaseBaseName,
    TriggerCaseBaseArgs,
    OneEntryArgsTemplate,
    SubAIDataName_to_Args,
)
pprint(OneAIDataName_to_OneAIDataArgs, sort_dicts=False)

# AIData base built on top of the case base.
aidata_base = AIData_Base(
    case_base=case_base,
    OneAIDataName_to_OneAIDataArgs=OneAIDataName_to_OneAIDataArgs,
    SPACE=SPACE,
)

## Step 3: AIData 

In [None]:
# Pick the first AIData name registered on aidata_base (IndexError if none).
OneAIDataName = aidata_base.get_AIDataName_list()[0]
pprint(OneAIDataName)

# Show the full args used to build that AIData.
OneAIData_Args = aidata_base.get_OneAIDataArgs_from_OneAIDataName(OneAIDataName)
pprint(OneAIData_Args, sort_dicts=False)

In [None]:
# Build/load the AIData for the selected name; the bare expression displays it.
aidata = aidata_base.get_aidata_from_OneAIDataName(OneAIDataName)
aidata

In [None]:
# Print the transformed dataset ('ds_tfm') of every split. Presumably the
# 'In-*' splits are hold-in data and 'Out' is the hold-out split.
Name_to_Data = aidata.Name_to_Data
for split_label, split_data in Name_to_Data.items():
    print(split_label, ':', split_data['ds_tfm'])

## Step 4: Prepare A Batch

In [None]:
# Inspect the first split of the AIData (in dict insertion order).
if not aidata.Name_to_Data:
    raise ValueError('aidata.Name_to_Data is empty; there is no split to inspect.')
# Take the first key directly instead of materialising a list of all keys.
split_name = next(iter(aidata.Name_to_Data))
dataset = aidata.Name_to_Data[split_name]
df_case = dataset['df_case']
df_case.head()

In [None]:
# Transformed (model-ready) dataset of the selected split; displayed below.
ds_tfm = dataset['ds_tfm']
ds_tfm

In [None]:
# Slice the first `batch_size` examples into one batch for inspection.
# NOTE(review): presumably ds_tfm has a torch format set, so the slice yields
# a dict of batched tensors — confirm.
batch_size = 4
batch = ds_tfm[:batch_size]
batch

In [None]:
# Token ids of the batch; display their shape.
input_ids = batch['input_ids']
input_ids.shape

In [None]:
# Display the full token-id batch.
input_ids

In [None]:
# Third example of the batch (requires batch_size >= 3; assumes input_ids is 2-D).
# Length 313 = 288 five-minute steps (24h before) + 1 (observation point) + 24 steps (2h after).
input_ids[2, :]

In [None]:
# Labels of the batch; display their shape.
labels = batch['labels']
labels.shape

In [None]:
# Display the whole batch again before model initialisation.
batch

# Part 2: Model Init

## Step 1: init_model

In [None]:
import os

# Expose only GPU index 1 to this process. Presumably this must run before
# torch touches CUDA (the first torch.cuda call is in the model-init cell).
os.environ.update(CUDA_VISIBLE_DEVICES="1")

In [None]:
# Base model arguments; vocabulary-related fields are added in the config cell.
ModelArgs = dict(model_type='cgmgpt_lm')

In [None]:
# Hugging Face imports. In the visible notebook only MODEL_FOR_CAUSAL_LM_MAPPING
# is used; the other names are imported but not referenced further below.
import transformers
import logging
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    HfArgumentParser,
    TrainingArguments,
)
# Config classes registered for causal-LM heads (reference only).
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
# MODEL_CONFIG_CLASSES
# Their `model_type` strings (reference only).
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
# MODEL_TYPES

############# Project-specific network: CgmGpt config + LM-head model #############
from nn.cgmlsm.configuration_cgmgpt import CgmGptConfig
from nn.cgmlsm.modeling_cgmgpt import CgmGptLMHeadModel
#################################################################

In [None]:
# ----- Mirrors what init_model does: build a CgmGptConfig whose vocabulary
# matches the token vocabulary produced by the AIData pipeline.
CF_to_CFvocab = aidata.CF_to_CFvocab
if not CF_to_CFvocab:
    raise ValueError('aidata.CF_to_CFvocab is empty; cannot size the model vocabulary.')
# Only the first CF's vocab is used (first key, without building a list).
# NOTE(review): assumes all input CFs share one token vocabulary — confirm.
CF = next(iter(CF_to_CFvocab))
CFvocab = CF_to_CFvocab[CF]
tkn2tid = CFvocab['input_ids']['tkn2tid']

config_kwargs = {
    'vocab_size': len(tkn2tid),
    'bos_token_id': tkn2tid['[BOS]'],
    'eos_token_id': tkn2tid['[EOS]'],
    # NOTE(review): hard-coded; presumably id 0 is the padding token — verify against tkn2tid.
    'pad_token_id': 0,
}

# Merge into ModelArgs in place, then build the config from it.
ModelArgs.update(config_kwargs)

pprint(ModelArgs)
config = CgmGptConfig(**ModelArgs)
pprint(config)

In [None]:
import torch

# Fresh (randomly initialised) model from the config above.
model = CgmGptLMHeadModel(config)

# Total parameter count (trainable and frozen alike).
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(total_params)

# Use a GPU when one is visible to this process, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

print(model)

# Part 3: Save and Load

## Step 1: Save

In [None]:
##########
# Output directory for this checkpoint. Relative paths resolve against the
# current working directory (the setup cell chdir's to WORKSPACE_PATH).
model_checkpoint_path = '_test2'
##########

In [None]:
# Case-level dataset of the split inspected above; displayed below.
ds_case = aidata.Name_to_Data[split_name]['ds_case']
ds_case

In [None]:
# NOTE(review): private attribute (presumably of a `datasets.Dataset`) showing
# its current format kwargs; may change between library versions.
ds_case._format_kwargs

In [None]:
# -----------  save aidata -----------
# Persist the processed AIData under <checkpoint>/Data, next to the model.
data_path = os.path.join(model_checkpoint_path, 'Data')
# exist_ok replaces the check-then-create (os.path.exists + os.makedirs) race.
os.makedirs(data_path, exist_ok=True)
aidata.save_aidata(data_path)

In [None]:
# save ModelInstance
# Writes `model` (the instance created in Part 2) to <checkpoint>/Model.
os.makedirs(model_checkpoint_path, exist_ok=True)
model_path = os.path.join(model_checkpoint_path, 'Model')

########################### TODO: update this.
model.save_pretrained(model_path)
###########################

In [None]:
# ----------- save ModelInstanceArgs -----------
# Store the arguments of this model instance as JSON next to the checkpoint.
# Training/inference/evaluation args are not recorded yet (placeholders below).
import json

ModelInstanceArgs = {
    'ModelArgs': ModelArgs,
    # 'TrainingArgs': TrainingArgs,
    # 'InferenceArgs': InferenceArgs,
    # 'EvaluationArgs': EvaluationArgs,
    'SPACE': SPACE,
}

# Ensure the directory exists even if the save cells above were skipped.
os.makedirs(model_checkpoint_path, exist_ok=True)
ModelInstanceArgs_path = os.path.join(model_checkpoint_path, 'ModelInstanceArgs.json')
# NOTE(review): json.dump needs every value in ModelArgs/SPACE to be
# JSON-serialisable — confirm SPACE holds only plain str/number/list/dict values.
with open(ModelInstanceArgs_path, 'w', encoding='utf-8') as f:
    json.dump(ModelInstanceArgs, f, indent=4)

## Step 2: Load

In [None]:
# To load a different checkpoint, point this at its directory, e.g.:
# model_checkpoint_path = '../_Model/vTestCGMFull/models/CGMOnlyLSM/checkpoint-8401' 

# Default: reload the checkpoint saved in Step 1 (placeholder self-assignment).
model_checkpoint_path = model_checkpoint_path

In [None]:
# Reload the saved weights as a new instance of the current model's class.
assert model  # a model instance must already exist in this session
model_path = os.path.join(model_checkpoint_path, 'Model')
model = type(model).from_pretrained(model_path)
model