# Space

In [None]:
# Notebook bootstrap: locate the workspace root, make it the working
# directory, expose the project code on sys.path and configure logging.
import os
import logging
import pandas as pd 
from pprint import pprint 
from IPython.display import display, HTML
# Do not truncate the column list when a DataFrame is displayed.
pd.set_option('display.max_columns', None)
# The workspace root is the part of the cwd up to and including the first
# occurrence of KEY.
# NOTE(review): when KEY does not occur in the cwd, split() returns the whole
# path and KEY is simply appended, so the chdir below would target
# '<cwd>WorkSpace' -- confirm the notebook is always started inside the workspace.
KEY = 'WorkSpace'
WORKSPACE_PATH = os.getcwd().split(KEY)[0] + KEY
# print(WORKSPACE_PATH)
os.chdir(WORKSPACE_PATH)
import sys
# Imported only after the chdir -- presumably proj_space is resolved relative
# to the workspace root; verify before moving this import to the top.
from proj_space import SPACE
# Make the project code directory importable for the later cells.
sys.path.append(SPACE['CODE_FN'])
SPACE['WORKSPACE_PATH'] = WORKSPACE_PATH
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='[%(levelname)s:%(asctime)s:(%(filename)s@%(lineno)d %(name)s)]: %(message)s')

# Switch off the `datasets` cache for this session.
from datasets import disable_caching
disable_caching()

SPACE['MODEL_ENDPOINT'] = 'vTest'

# Part 1: AIData

In [None]:


# One day: 288 points; this data keeps 24 per day (24 / 288 = 1/12).
from datasets import load_from_disk

# Full-size alternative: 'CGM_32h_24pd_WellDoc_v2_v0323'
AIDataName = 'CGM_32h_24pd_WellDoc_v2_sample'  # CGM, 32h, 24 data per day.

path = os.path.join(SPACE['DATA_AIDATA'], AIDataName)
print(path)
dataset = load_from_disk(path)

# The dataset-level settings are read from the 'config_name' entry of the
# dataset info.
config = dataset.info.__dict__['config_name']
print(list(config))

CF_to_CFvocab = config['CF_to_CFvocab']
print(list(CF_to_CFvocab))

CF_to_CFArgs = config['CaseSettingInfo']['Case_Args_Settings']['CF_to_CFArgs']
print(list(CF_to_CFArgs))

# The trigger name is looked up through the arguments of the trigger case base.
TriggerCaseBaseName = config['TriggerCaseBaseName']
TriggerCaseBaseArgs = config['TriggerCaseBaseName_to_TriggerCaseBaseArgs'][TriggerCaseBaseName]
TriggerName = TriggerCaseBaseArgs['Trigger']['TriggerName']
TriggerName


In [None]:
# df_tag.columns

from recfldtkn.base import assign_caseSplitTag_to_dsCase
from recfldtkn.base import apply_multiple_conditions
import numpy as np 


# Tag columns are the dataset columns whose name contains no '--' separator;
# only those are materialised as a DataFrame.
columns = dataset.column_names
columns_tag = [column_name for column_name in columns if '--' not in column_name]
df_tag = dataset.select_columns(columns_tag).to_pandas()

def map_age_to_agegroup(age):
    """Return the age-bracket label for ``age``.

    The brackets are '0-17', '18-39', '40-64' and '65+'; a lower bound is
    inclusive, an upper bound exclusive.  A value that is not smaller than
    any bound (this includes NaN, which compares false everywhere) ends up
    in '65+'.
    """
    # Upper bounds in ascending order: the first one exceeding ``age`` wins.
    for upper_bound, label in ((18, '0-17'), (40, '18-39'), (65, '40-64')):
        if age < upper_bound:
            return label
    return '65+'
    
# ---- derived tagging columns ----
df_tag['Year'] = df_tag['ObsDT'].dt.year
df_tag['Cohort'] = df_tag['PID'].astype(str).str[0]  # first character of the PID
df_tag['Age'] = df_tag['Year'] - df_tag['YearOfBirth']
df_tag['AgeGroup'] = df_tag['Age'].apply(map_age_to_agegroup)

# Mirror the derived columns onto the dataset so the split rules can refer to
# them; the insertion order is Age, Cohort, Year, AgeGroup.
for derived_column in ('Age', 'Cohort', 'Year', 'AgeGroup'):
    dataset = dataset.add_column(derived_column, df_tag[derived_column].values)


In [None]:
def _cohort_rules():
    """Return fresh copies of the three filters every split starts with."""
    return [
        ['Age', '>=', 40],
        ['Cohort', 'in', ['1', '2', '3']],  # <--- add Cohort column
        ['Year', 'in', [2020, 2021, 2022, 2023]],  # <--- add Year column
    ]


def _gender_rule():
    """Return a fresh copy of the gender filter shared by every split."""
    return ['GenderGroup', 'in', ['Gender.1', 'Gender.2']]


# Train / Val / Test differ only in their ObsDT window; all rules of a split
# are combined with 'and'.
Split_to_Selection = {
    'Train': {
        'Rules': _cohort_rules() + [
            _gender_rule(),
            ['ObsDT', '<', '2022-07-01'],
            ['ObsDT', '>=', '2021-01-01'],
        ],
        'Op': 'and',
    },
    'Val': {
        'Rules': _cohort_rules() + [
            ['ObsDT', '<', '2023-01-01'],
            ['ObsDT', '>=', '2022-07-01'],
            _gender_rule(),
        ],
        'Op': 'and',
    },
    'Test': {
        'Rules': _cohort_rules() + [
            ['ObsDT', '>=', '2023-01-01'],
            ['ObsDT', '<', '2024-01-01'],
            _gender_rule(),
        ],
        'Op': 'and',
    },
}

In [None]:
split_to_dataset = {}
for split_name, Selection in Split_to_Selection.items():
    # Evaluate the split's rules on the tag frame (presumably one 0/1 or
    # boolean flag per row -- confirm against apply_multiple_conditions).
    index = apply_multiple_conditions(df_tag, Selection['Rules'], Selection['Op'])
    # The positions flagged with 1 are the rows that belong to this split.
    split_to_dataset[split_name] = dataset.select(np.where(index == 1)[0])

split_to_dataset

In [None]:
# Wrap every split's dataset in the {'ds_case': ...} layout.
# A comprehension is used on purpose: a `for split, dataset in ...` loop would
# rebind the notebook-level `dataset` (the full AIData) to the last split, so
# re-running an earlier cell would silently operate on the wrong data.
Name_to_Data = {
    split_name: {'ds_case': split_dataset}
    for split_name, split_dataset in split_to_dataset.items()
}
Name_to_Data

In [None]:
# ----------------- Input Part -----------------
_input_part = dict(
    EntryInputMethod='Mto1Period_MultiTknInStepNoWgt',
    CF_list=[
        'cf.TargetCGM_Bf24H',
        'cf.TargetCGM_Af2H',
        # 'cf.TargetCGM_Af2Hto8H',
    ],
    TargetField='TargetCGM',
    BeforePeriods=['Bf24H'],
    AfterPeriods=['Af2H'],
    InferenceMode=False,  # other values seen: 'NoFutureEvent', 'WithFutureEvent'
)

# ----------------- Output Part -----------------
# Alternative output settings, kept for reference:
#   MaskedLM:     {'EntryOutputMethod': 'MaskedLM', 'MaskingRate': 0.15,
#                  'set_transform': True, 'num_proc': 4}
#   SupervisedFT: {'EntryOutputMethod': 'SupervisedFT', 'AfStepNum': 24,
#                  'set_transform': True, 'num_proc': 4}
#                 (AfStepNum 24: AfterPeriods is Af2H, so 12 * 2 = 24)
_output_part = dict(
    EntryOutputMethod='CausalLM',
    set_transform=True,
    num_proc=4,
)

OneEntryArgs = {'Input_Part': _input_part, 'Output_Part': _output_part}

from recfldtkn.aidata_base.entry import EntryAIData_Builder

# Entry builder bound to the trigger name, the entry arguments and the
# project SPACE settings.
entry = EntryAIData_Builder(
    TriggerName=TriggerName,
    OneEntryArgs=OneEntryArgs,
    SPACE=SPACE,
)

In [None]:
# Run the entry setup over every split; the result replaces Name_to_Data
# (the next cell reads a 'ds_tfm' entry from the 'Train' split).
Name_to_Data = entry.setup_EntryFn_to_NameToData(Name_to_Data, CF_to_CFvocab, OneEntryArgs)
# Name_to_Data

In [None]:
# Take the transformed training split and slice one small batch from it; this
# batch is the input of every forward pass further down.
Data = Name_to_Data['Train']
ds_tfm = Data['ds_tfm']

batch_size = 4
batch = ds_tfm[:batch_size]
batch


# Part 2: Model Init

In [None]:
# Names of the CFs that have a vocabulary.
list(CF_to_CFvocab)

In [None]:
OneEntryArgs

In [None]:
# Unpack the parts of the input settings that the model configuration relies
# on.  'TimeField' / 'EventFields' are not present in OneEntryArgs here, so
# only the target field and the CF list are read.
InputPart = OneEntryArgs['Input_Part']
TargetField = InputPart['TargetField']
CF_list = InputPart['CF_list']

TargetField


### CgmLhmConfig

In [None]:
# from nn.cgmlhm.configuration_cgmlhm import CgmLhmConfig

from collections import OrderedDict
from typing import Any, List, Mapping, Optional

from transformers import PreTrainedTokenizer, TensorType, is_torch_available
from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfigWithPast, PatchingSpec
from transformers.utils import logging

logger = logging.get_logger(__name__)



class CgmLhmConfig(PretrainedConfig):
    """Configuration of the composite ``cgmlhm`` model.

    A single flat config carries the settings of four sub-modules, told apart
    by attribute prefix:

    * ``lsm_`` -- the LSM sub-model,
    * ``fe_``  -- the field encoders,
    * ``sc_``  -- the step connector,
    * ``tf_``  -- the temporal fusor.

    ``n_embd``, ``initializer_range`` and ``use_cache`` are shared: each is
    copied into the per-prefix attribute of every sub-module
    (``lsm_n_embd`` / ``fe_hidden_size`` / ``sc_hidden_size`` / ``tf_n_embd``,
    ``<prefix>_initializer_range``, ``<prefix>_use_cache``).

    When both ``OneEntryArgs`` and ``CF_to_CFvocab`` are given, the
    vocabulary-dependent attributes (``field_to_fieldinfo``,
    ``target_field_vocab``, ``lsm_vocab_size`` and the ``lsm_*_token_id``
    values) are derived from them, see :meth:`initalize_field_info`.

    Args:
        n_embd: Hidden width shared by all sub-modules.
        initializer_range: Shared initializer range.
        use_cache: Shared cache flag.
        time_field: Stored as ``self.time_field``.
        lsm_*, fe_*, sc_*, tf_*: Stored unchanged under the same name.
        CF_to_CFvocab: Mapping from CF name to its vocabulary; every
            vocabulary is read through ``['input_ids']['tkn2tid']``.
        OneEntryArgs: Entry settings; ``['Input_Part']`` must provide
            ``'TargetField'`` and ``'CF_list'`` and may provide
            ``'TimeField'`` and ``'EventFields'``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "cgmlhm"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "n_layer": "tf_n_layer",
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        # An alias is resolved with a single dictionary lookup, so it has to
        # point at a stored attribute: routing this one through "n_layer"
        # (itself an alias) would end in an AttributeError on access.
        "num_hidden_layers": "tf_n_layer",
        "layer_norm_epsilon": "layer_norm_eps",
        "hidden_act": "activation_function",
        # NOTE(review): "n_positions", "n_head", "layer_norm_eps" and
        # "activation_function" are never assigned by __init__ (only their
        # prefixed variants are), so those four aliases resolve only when the
        # attribute arrives through **kwargs -- confirm the intended targets.
    }

    def __init__(
        self,
        n_embd=768,
        initializer_range=0.02,
        use_cache=True,

        time_field='Time',

        # lsm settings
        lsm_n_positions=1024,
        lsm_n_layer=12,
        lsm_n_head=12,
        lsm_n_inner=None,
        lsm_activation_function="gelu_new",
        lsm_resid_pdrop=0.1,
        lsm_embd_pdrop=0.1,
        lsm_attn_pdrop=0.1,
        lsm_layer_norm_epsilon=1e-5,
        lsm_summary_type="cls_index",
        lsm_summary_use_proj=True,
        lsm_summary_activation=None,
        lsm_summary_proj_to_labels=True,
        lsm_summary_first_dropout=0.1,
        lsm_scale_attn_weights=True,
        lsm_scale_attn_by_inverse_layer_idx=False,
        lsm_reorder_and_upcast_attn=False,

        # field encoder settings
        fe_num_hidden_layers=6,
        fe_timestep_lookback=600,
        fe_timestep_lookahead=300,
        fe_embd_pdrop=0.1,
        fe_use_field_type_embedding=True,
        fe_num_attention_heads=12,
        fe_intermediate_size=3072,
        fe_hidden_act="gelu",
        fe_hidden_dropout_prob=0.1,
        fe_attention_probs_dropout_prob=0.1,
        fe_max_position_embeddings=512,
        fe_layer_norm_eps=1e-12,
        fe_position_embedding_type="absolute",
        fe_classifier_dropout=None,

        # step connector settings
        sc_num_hidden_layers=2,
        sc_num_attention_heads=12,
        sc_intermediate_size=3072,
        sc_hidden_act="gelu",
        sc_hidden_dropout_prob=0.1,
        sc_attention_probs_dropout_prob=0.1,
        sc_max_position_embeddings=512,
        sc_layer_norm_eps=1e-12,
        sc_position_embedding_type="absolute",
        sc_classifier_dropout=None,

        # temporal fusor settings
        tf_n_layer=4,
        tf_n_head=12,
        tf_n_inner=None,
        tf_activation_function="gelu_new",
        tf_resid_pdrop=0.1,
        tf_embd_pdrop=0.1,
        tf_attn_pdrop=0.1,
        tf_layer_norm_epsilon=1e-5,
        tf_summary_type="cls_index",
        tf_summary_use_proj=True,
        tf_summary_activation=None,
        tf_summary_proj_to_labels=True,
        tf_summary_first_dropout=0.1,
        tf_scale_attn_weights=True,
        tf_scale_attn_by_inverse_layer_idx=False,
        tf_reorder_and_upcast_attn=False,

        CF_to_CFvocab=None,
        OneEntryArgs=None,
        **kwargs,
    ):
        # ---- shared settings ----
        self.n_embd = n_embd
        self.initializer_range = initializer_range
        self.use_cache = use_cache

        self.time_field = time_field

        # ---- lsm ----
        self.lsm_n_embd = n_embd
        self.lsm_n_positions = lsm_n_positions
        self.lsm_n_layer = lsm_n_layer
        self.lsm_n_head = lsm_n_head
        self.lsm_n_inner = lsm_n_inner
        self.lsm_activation_function = lsm_activation_function
        self.lsm_resid_pdrop = lsm_resid_pdrop
        self.lsm_embd_pdrop = lsm_embd_pdrop
        self.lsm_attn_pdrop = lsm_attn_pdrop
        self.lsm_layer_norm_epsilon = lsm_layer_norm_epsilon
        self.lsm_initializer_range = initializer_range
        self.lsm_summary_type = lsm_summary_type
        self.lsm_summary_use_proj = lsm_summary_use_proj
        self.lsm_summary_activation = lsm_summary_activation
        self.lsm_summary_proj_to_labels = lsm_summary_proj_to_labels
        self.lsm_summary_first_dropout = lsm_summary_first_dropout
        self.lsm_scale_attn_weights = lsm_scale_attn_weights
        self.lsm_use_cache = use_cache
        self.lsm_scale_attn_by_inverse_layer_idx = lsm_scale_attn_by_inverse_layer_idx
        self.lsm_reorder_and_upcast_attn = lsm_reorder_and_upcast_attn

        # ---- field encoder ----
        self.fe_hidden_size = n_embd
        self.fe_timestep_lookback = fe_timestep_lookback
        self.fe_timestep_lookahead = fe_timestep_lookahead
        self.fe_embd_pdrop = fe_embd_pdrop
        self.fe_use_field_type_embedding = fe_use_field_type_embedding
        self.fe_num_hidden_layers = fe_num_hidden_layers
        self.fe_num_attention_heads = fe_num_attention_heads
        self.fe_intermediate_size = fe_intermediate_size
        self.fe_hidden_act = fe_hidden_act
        self.fe_hidden_dropout_prob = fe_hidden_dropout_prob
        self.fe_attention_probs_dropout_prob = fe_attention_probs_dropout_prob
        self.fe_max_position_embeddings = fe_max_position_embeddings
        self.fe_initializer_range = initializer_range
        self.fe_layer_norm_eps = fe_layer_norm_eps
        self.fe_position_embedding_type = fe_position_embedding_type
        self.fe_use_cache = use_cache
        self.fe_classifier_dropout = fe_classifier_dropout

        # ---- step connector ----
        self.sc_hidden_size = n_embd
        self.sc_num_hidden_layers = sc_num_hidden_layers
        self.sc_num_attention_heads = sc_num_attention_heads
        self.sc_intermediate_size = sc_intermediate_size
        self.sc_hidden_act = sc_hidden_act
        self.sc_hidden_dropout_prob = sc_hidden_dropout_prob
        self.sc_attention_probs_dropout_prob = sc_attention_probs_dropout_prob
        self.sc_max_position_embeddings = sc_max_position_embeddings
        self.sc_initializer_range = initializer_range
        self.sc_layer_norm_eps = sc_layer_norm_eps
        self.sc_position_embedding_type = sc_position_embedding_type
        self.sc_use_cache = use_cache
        self.sc_classifier_dropout = sc_classifier_dropout

        # ---- temporal fusor ----
        # `n_layer` / `num_hidden_layers` are aliases of `tf_n_layer` (see
        # attribute_map), so no separate `self.n_layer` assignment is needed.
        self.tf_n_layer = tf_n_layer
        self.tf_n_embd = n_embd
        self.tf_n_head = tf_n_head
        self.tf_n_inner = tf_n_inner
        self.tf_activation_function = tf_activation_function
        self.tf_resid_pdrop = tf_resid_pdrop
        self.tf_embd_pdrop = tf_embd_pdrop
        self.tf_attn_pdrop = tf_attn_pdrop
        self.tf_layer_norm_epsilon = tf_layer_norm_epsilon
        self.tf_initializer_range = initializer_range
        self.tf_summary_type = tf_summary_type
        self.tf_summary_use_proj = tf_summary_use_proj
        self.tf_summary_activation = tf_summary_activation
        self.tf_summary_proj_to_labels = tf_summary_proj_to_labels
        self.tf_summary_first_dropout = tf_summary_first_dropout
        self.tf_scale_attn_weights = tf_scale_attn_weights
        self.tf_use_cache = use_cache
        self.tf_scale_attn_by_inverse_layer_idx = tf_scale_attn_by_inverse_layer_idx
        self.tf_reorder_and_upcast_attn = tf_reorder_and_upcast_attn

        # ---- vocabulary-dependent settings ----
        self.OneEntryArgs = OneEntryArgs
        self.CF_to_CFvocab = CF_to_CFvocab
        if OneEntryArgs is not None and CF_to_CFvocab is not None:
            self.initalize_field_info()

        super().__init__(**kwargs)

    @staticmethod
    def _vocab_for_field(field, CF_list, CF_to_CFvocab):
        """Return the vocabulary of the first CF in ``CF_list`` whose name contains ``field``."""
        matching_CFs = [CF for CF in CF_list if field in CF]
        if not matching_CFs:
            raise ValueError(
                f"No CF in CF_list contains the field name {field!r}; CF_list = {list(CF_list)!r}"
            )
        return CF_to_CFvocab[matching_CFs[0]]

    @staticmethod
    def _token_info(cf_vocab):
        """Summarise one CF vocabulary as its size plus the special-token ids."""
        tkn2tid = cf_vocab['input_ids']['tkn2tid']
        return {
            'vocab_size': len(tkn2tid),
            'bos_token_id': tkn2tid['[BOS]'],
            'eos_token_id': tkn2tid['[EOS]'],
            # NOTE(review): the padding id is fixed to 0 instead of being
            # read from the vocabulary -- confirm that holds for every CF.
            'pad_token_id': 0,
        }

    def initalize_field_info(self):
        """Derive the vocabulary-dependent attributes.

        Reads ``self.OneEntryArgs['Input_Part']`` and ``self.CF_to_CFvocab``
        and sets:

        * ``field_to_fieldinfo`` -- for the optional ``TimeField`` and every
          entry of ``EventFields``: ``vocab_size``, ``bos_token_id``,
          ``eos_token_id`` and ``pad_token_id``;
        * ``target_field_vocab`` -- the vocabulary of the target field;
        * ``lsm_vocab_size``, ``lsm_bos_token_id``, ``lsm_eos_token_id``,
          ``lsm_pad_token_id`` -- the same summary for the target field.

        Does nothing when either source attribute is missing or ``None``.

        Raises:
            ValueError: If no CF in ``CF_list`` contains a field's name.
            KeyError: If a vocabulary lacks ``'[BOS]'`` or ``'[EOS]'``.
        """
        OneEntryArgs = getattr(self, 'OneEntryArgs', None)
        CF_to_CFvocab = getattr(self, 'CF_to_CFvocab', None)
        # __init__ always sets both attributes (possibly to None), so the
        # guard has to test the value, not just the attribute's existence.
        if OneEntryArgs is None or CF_to_CFvocab is None:
            return None

        InputPart = OneEntryArgs['Input_Part']
        TargetField = InputPart['TargetField']
        TimeField = InputPart.get('TimeField', None)
        EventFields = list(InputPart.get('EventFields', []))
        CF_list = InputPart['CF_list']

        # The time field, when configured, is encoded like any event field.
        FieldList = EventFields if TimeField is None else [TimeField] + EventFields

        self.field_to_fieldinfo = {
            field: self._token_info(self._vocab_for_field(field, CF_list, CF_to_CFvocab))
            for field in FieldList
        }

        target_field_vocab = self._vocab_for_field(TargetField, CF_list, CF_to_CFvocab)
        self.target_field_vocab = target_field_vocab
        target_info = self._token_info(target_field_vocab)
        self.lsm_vocab_size = target_info['vocab_size']
        self.lsm_bos_token_id = target_info['bos_token_id']
        self.lsm_eos_token_id = target_info['eos_token_id']
        self.lsm_pad_token_id = target_info['pad_token_id']

    def to_dict(self):
        """Return the parent's dict form without the raw vocab / entry-args payloads.

        ``CF_to_CFvocab``, ``target_field_vocab`` and ``OneEntryArgs`` are
        dropped; the values derived from them (``field_to_fieldinfo``,
        ``lsm_vocab_size``, ...) stay in the output.
        """
        output = super().to_dict()
        for excluded_key in ('CF_to_CFvocab', 'target_field_vocab', 'OneEntryArgs'):
            output.pop(excluded_key, None)
        return output


In [None]:
# Build the composite config from keyword arguments.  Note that `config` held
# the dataset settings in the AIData cell and is rebound to the model config
# from here on.
ModelArgs = dict(
    model_type='cgmlhm',
    OneEntryArgs=OneEntryArgs,
    CF_to_CFvocab=CF_to_CFvocab,
)

config = CgmLhmConfig(**ModelArgs)
config.field_to_fieldinfo

In [None]:
config

In [None]:
from nn.cgmlhm.configuration_cgmlsm import CgmLsmConfig

# Collect the `lsm_`-prefixed settings with the prefix stripped.  Keys are
# matched on the prefix itself: a substring test would also pick up keys that
# merely contain 'lsm_' somewhere, and splitting on the prefix would then
# produce a wrong keyword name for them.
LSM_PREFIX = 'lsm_'
lsm_kwargs = {
    key[len(LSM_PREFIX):]: value
    for key, value in config.to_dict().items()
    if key.startswith(LSM_PREFIX)
}
lsm_kwargs

lsm_config = CgmLsmConfig(**lsm_kwargs)
lsm_config

In [None]:
from nn.cgmlhm.configuration_fieldencoder import FieldEncoderConfig

# Collect the `fe_`-prefixed settings with the prefix stripped.  Keys are
# matched on the prefix itself, so a key that only contains 'fe_' somewhere
# in the middle is not mistaken for a field-encoder setting.
FE_PREFIX = 'fe_'
fe_kwargs = {
    key[len(FE_PREFIX):]: value
    for key, value in config.to_dict().items()
    if key.startswith(FE_PREFIX)
}

# One encoder config per field: the shared fe settings plus the field's own
# vocabulary size and special-token ids.
field_to_feconfig = {}
for field, fieldinfo in config.field_to_fieldinfo.items():
    field_to_feconfig[field] = FieldEncoderConfig(
        field=field,
        fieldinfo=fieldinfo,
        **fieldinfo,
        **fe_kwargs,
    )

pprint(field_to_feconfig)

In [None]:
from nn.cgmlhm.configuration_fieldencoder import FieldEncoderConfig

# Collect the `sc_`-prefixed settings with the prefix stripped.  Keys are
# matched on the prefix itself, so a key that only contains 'sc_' somewhere
# in the middle is not mistaken for a step-connector setting.
SC_PREFIX = 'sc_'
sc_kwargs = {
    key[len(SC_PREFIX):]: value
    for key, value in config.to_dict().items()
    if key.startswith(SC_PREFIX)
}
pprint(sc_kwargs, compact=True, sort_dicts=True)

# The step connector is configured through the field-encoder config class.
sc_config = FieldEncoderConfig(**sc_kwargs)
sc_config


In [None]:
from nn.cgmlhm.configuration_cgmlsm import CgmLsmConfig

# Collect the `tf_`-prefixed settings with the prefix stripped.  Keys are
# matched on the prefix itself, so a key that only contains 'tf_' somewhere
# in the middle is not mistaken for a temporal-fusor setting.
TF_PREFIX = 'tf_'
tf_kwargs = {
    key[len(TF_PREFIX):]: value
    for key, value in config.to_dict().items()
    if key.startswith(TF_PREFIX)
}

# The temporal fusor is configured through the LSM config class.
tf_config = CgmLsmConfig(**tf_kwargs)
tf_config

# CGMLSM

## Step 1: model config

In [None]:
lsm_config

## Step 2: model structure

In [None]:
batch

In [None]:
from nn.cgmlhm.modeling_cgmlsm import CgmLsmModel

# Instantiate the LSM from the config assembled out of the lsm_* settings.
lsm_model = CgmLsmModel(lsm_config)
lsm_model

# Notes on the printed structure:
# - vocab size: 409
# - the input length does not affect the size of the embedding, MLP or attention parameters
# - hidden_size is the embedding size

In [None]:
# Keep only the first sample of the batch; the slice preserves the leading axis.
input_ids = batch['input_ids'][:1, :]
input_ids

In [None]:
# proc
embeddings = lsm_model.wte

output = embeddings(input_ids)
output.shape

In [None]:
# `h` is walked layer by layer in the next cell.
repr_layer = lsm_model.h
repr_layer

In [None]:
# Feed the embedded input through the stack one layer at a time, printing the
# tensor shape before and after every layer.
input_tensor = output

for idx, one_layer in enumerate(repr_layer):
    print('\n\nlayer', idx)
    print('input_tensor:', input_tensor.shape)
    # Call the module itself rather than `.forward` so that any registered
    # hooks run as they would in a normal forward pass.
    layer_outputs = one_layer(input_tensor)
    print(one_layer)
    # The hidden states are the first element of what the layer returns.
    output_tensor = layer_outputs[0]
    print('output_tensor:', output_tensor.shape)

    # prepare for the next layer
    input_tensor = output_tensor

output_tensor.shape

In [None]:
# The output of the last layer in the manual walk is taken as the representation.
repr_output = output_tensor
repr_output.shape

In [None]:
# pred_layer = lsm_model.pred
lsm_model

In [None]:
input_ids.shape 

In [None]:
# Same input through the model's own forward pass.  Note that this rebinds
# `output`, which the manual layer walk reads as its starting tensor.
output = lsm_model(input_ids)
last_hidden_state = output.last_hidden_state
repr_output = last_hidden_state
repr_output.shape

## Step 3: forward

In [None]:
# lsm_model_inputs = {k: v for k, v in batch.items() if '--' not in k}
# lsm_model_inputs

# The LSM only receives the token ids of the whole batch.
lsm_model_inputs = dict(input_ids=batch['input_ids'])

print(batch['input_ids'].shape)
lsm_outputs = lsm_model(**lsm_model_inputs)
repr_output = lsm_outputs.last_hidden_state
print(repr_output.shape)


# FieldEncoder

## Step 1: model config

In [None]:
field_to_feconfig

## Step 2: model structure

In [None]:
from nn.cgmlhm.modeling_fieldencoder import FieldEncoderModel
import torch

# One encoder per field, registered under the field name so the encoders are
# tracked as sub-modules.
field_encoders = torch.nn.ModuleDict()
for field_name, encoder_config in field_to_feconfig.items():
    field_encoders[field_name] = FieldEncoderModel(encoder_config)

field_encoders

## Step 3: forward

In [None]:
# Run every field encoder on its slice of the batch and store, per field, the
# hidden states and pooled outputs reshaped back to (batch, step, ...).
field_to_encoder_outputs = {}

for field, fe_config in field_to_feconfig.items():

    print('\n===================')
    print(field)
    field_encoder = field_encoders[field]

    # Inputs of a field are the batch entries keyed '<field>--<value_name>'.
    batch_field_inputs = {k.split('--')[1]: v for k, v in batch.items() if field + '--' in k}

    for value_name, values in batch_field_inputs.items():
        print(value_name, values.shape)

    print('reshape')
    # Merge the first two axes so that every (sample, step) pair becomes one
    # row for the encoder; a and b are kept to undo the merge afterwards.
    flattened_inputs = {}
    for value_name, values in batch_field_inputs.items():
        a, b = values.size(0), values.size(1)
        flattened_inputs[value_name] = values.view(a * b, -1)
        print(value_name, flattened_inputs[value_name].shape)
    batch_field_inputs = flattened_inputs

    use_event_indicators = 'event_indicators' in batch_field_inputs
    if use_event_indicators:
        # Only rows flagged by the indicator are encoded.  squeeze(-1) removes
        # just the trailing axis; a bare squeeze() would also remove the row
        # axis when a * b == 1 and leave a 0-d mask.
        mask = batch_field_inputs.pop('event_indicators').bool().squeeze(-1)
        batch_field_inputs = {k: v[mask] for k, v in batch_field_inputs.items()}

    # Positions holding the padding id are masked out of attention.
    input_ids_field = batch_field_inputs['input_ids']
    attention_mask_field = input_ids_field.ne(fe_config.pad_token_id)
    batch_field_inputs['attention_mask'] = attention_mask_field

    print('\nfinal batch_field_inputs')
    for k, v in batch_field_inputs.items():
        print(k, v.shape)

    # (To inspect the embedding layer alone, call
    #  field_encoder.embeddings(**batch_field_inputs).)

    print('\nfield_encoder')
    field_outputs = field_encoder(**batch_field_inputs)
    hidden_state = field_outputs.last_hidden_state
    print('field_outputs.hidden_state', hidden_state.shape)

    pooler_output = field_outputs.pooler_output
    print('field_outputs.pooler_output', pooler_output.shape)

    if use_event_indicators:
        # Scatter the encoded rows back to their original positions; rows
        # without an event stay zero.  new_zeros keeps the dtype and device
        # of the encoder output, which a plain torch.zeros would not.
        n_rows = mask.shape[0]
        hidden_state_origin = hidden_state.new_zeros([n_rows] + list(hidden_state.shape[1:]))
        hidden_state_origin[mask] = hidden_state

        pooler_output_origin = pooler_output.new_zeros([n_rows] + list(pooler_output.shape[1:]))
        pooler_output_origin[mask] = pooler_output
    else:
        hidden_state_origin = hidden_state
        pooler_output_origin = pooler_output

    # Undo the (a, b) merge.
    hidden_state_origin = hidden_state_origin.reshape([a, b] + list(hidden_state.shape[1:]))
    print('field_outputs.hidden_state_origin', hidden_state_origin.shape)

    pooler_output_origin = pooler_output_origin.reshape([a, b] + list(pooler_output.shape[1:]))
    print('field_outputs.pooler_output_origin', pooler_output_origin.shape)

    field_to_encoder_outputs[field] = {
        'hidden_state': hidden_state_origin,
        'pooler_output': pooler_output_origin,
    }

# StepConnector

## Step 1: model config

In [None]:
sc_config

## Step 2: model structure

In [None]:
from nn.cgmlhm.modeling_cgmlhm import StepConnector


# Build the step connector from the config assembled out of the sc_* settings.
step_connector_model = StepConnector(sc_config)
step_connector_model

## Step 3: forward

In [None]:
# The LSM's last hidden state of the full batch is the target-field state.
target_state = lsm_outputs.last_hidden_state
target_state.shape

In [None]:
# Pooled output of every encoded field, in field_to_encoder_outputs order.
field_states = [
    encoder_outputs['pooler_output']
    for encoder_outputs in field_to_encoder_outputs.values()
]

field_states

# The target state always comes first, followed by the field states.
step_state_list = [target_state, *field_states]

In [None]:
if len(step_state_list) <= 1:
    # Only the target state is present: there is nothing to connect.
    step_states = target_state
else:
    # Stack target and field states on a new axis 1 and let the step
    # connector pool them into one state per step.
    step_field_states = torch.stack(step_state_list, dim=1)
    connector_output = step_connector_model(step_field_states)
    step_states = connector_output.pooler_output

In [None]:
step_states.shape

# CGMLHM

## Step 1: model_config

In [None]:
config

## Step 2: model_structure


In [None]:
from nn.cgmlhm.modeling_cgmlhm import CgmLhmModel

# Build the full model directly from the composite config.
model = CgmLhmModel(config)
model

## Step 3: forward

In [None]:
batch['input_ids'].shape

In [None]:
# Full forward pass: every entry of the batch is passed as a keyword argument.
outputs = model(**batch)

# model: repr 

In [None]:
outputs.last_hidden_state.shape

# CGMLHM-LMHead

In [None]:
# NOTE(review): the class name starts with 'Ggm', unlike the other 'Cgm...'
# classes -- presumably it matches the spelling in modeling_cgmlhm; confirm.
from nn.cgmlhm.modeling_cgmlhm import GgmLhmLMHeadModel

# Rebinds `model`: from here on it is the LM-head variant.
model = GgmLhmLMHeadModel(config)
model

In [None]:
# (lm_head): Linear(in_features=768, out_features=409, bias=False)
# 768: embed_size, hidden_size
# 409: vocab_size

In [None]:
batch['input_ids'].shape

In [None]:
# Run only the `lhm` sub-module of the LM-head model to get the representation.
lhm_outputs = model.lhm(**batch)
repr_output = lhm_outputs.last_hidden_state
repr_output.shape

In [None]:
# [i for i in lhm_outputs]

In [None]:
# Positional access to the outputs; presumably element 0 is the same tensor
# as `.last_hidden_state` read in the previous cell -- confirm.
repr_output = lhm_outputs[0]
repr_output.shape

In [None]:
model.lm_head

In [None]:
# Project the representation through the LM head to get the logits.
lm_logits = model.lm_head(repr_output)
lm_logits.shape

In [None]:
# repr_tensor 

# prediction layer 

# logits 

# loss 

In [None]:
# Calculate loss manually for next token prediction
# Calculate the next-token-prediction loss by hand.
import torch.nn.functional as F

labels = batch['labels']

# The logits at position t score the token at position t + 1, so the targets
# are the labels moved one step to the left:
#   input : [t0, t1, t2, t3, ...]
#   target: [t1, t2, t3, t4, ...]
# The last position has no next token; -100 makes cross_entropy skip it.
shifted_labels = labels.new_full(labels.shape, -100)
shifted_labels[:, :-1] = labels[:, 1:]

# Flatten to (positions, vocab) and (positions,).  reshape() also handles
# non-contiguous logits, for which view() would raise.
lm_logits_flat = lm_logits.reshape(-1, lm_logits.size(-1))
shifted_labels_flat = shifted_labels.reshape(-1)

# Calculate cross entropy loss with shifted labels
loss = F.cross_entropy(lm_logits_flat, shifted_labels_flat, ignore_index=-100)
print(f"Manually calculated loss for next token prediction: {loss.item()}")

# This should now match the output.loss from model(**batch) in the next cell

In [None]:
# Loss computed by the model itself, for comparison with the manually
# calculated value printed in the previous cell.
output = model(**batch)
output.loss