In [2]:
import argparse
import numpy as np
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary

from utils import set_seeds, create_src_causal_mask, CEDataset
from train import test_narce, test
from loss import focal_loss

from mamba_ssm.models.config_mamba import MambaConfig
from models import RNN, TCN, TSTransformer, BaselineMamba
from nar_model import NARMamba, AdapterMamba, NarcePipeline

In [7]:
class Args:
    """Notebook stand-in for command-line arguments.

    In this notebook the values are interpolated into the checkpoint file
    names that the following cells load.
    """

    # NAR checkpoint selection.
    nar_model = 'mamba2_v1'
    nar_dataset = 10000  # presumably the NAR training-set size — TODO confirm

    # Adapter checkpoint selection; the adapter name also determines its depth.
    adapter_model = 'mamba2_12L'
    sensor_dataset = 4000  # presumably the sensor training-set size — TODO confirm

    # Seed suffix used in the checkpoint file names.
    seed = 53


# The class object itself (not an instance) serves as the argument namespace.
args = Args

In [9]:
# Rebuild the standalone NAR model and restore its saved weights; its embedding
# layer and NAR backbone are inspected in later cells.
embed_nar_model_path = 'narce/saved_model/nar/{}-{}-{}.pt'.format(args.nar_model, args.nar_dataset, args.seed)

embedding_input_dim = 128
nar_vocab_size = 9  # Depends on the # unique tokens NAR takes in, which is # classes of atomic event
output_dim = 4  # The number of complex event classes

# These hyperparameters must reproduce the architecture the checkpoint was
# saved from; otherwise load_state_dict (strict by default) raises on the
# mismatched keys/shapes.
mamba_config = MambaConfig(d_model=embedding_input_dim, n_layer=12, ssm_cfg={"layer": "Mamba2", "headdim": 32,})
embed_nar_model = NARMamba(mamba_config, nar_vocab_size=nar_vocab_size, out_cls_dim=output_dim)

# map_location='cpu' lets a checkpoint written from a GPU run load on any host;
# load_state_dict then copies the tensors into the model's existing (CPU)
# parameters, so the resulting model is the same either way.
embed_nar_model.load_state_dict(torch.load(embed_nar_model_path, map_location='cpu'))
summary(embed_nar_model)

Layer (type:depth-idx)                             Param #
NARMamba                                           --
├─Embedding: 1-1                                   1,152
├─Sequential: 1-2                                  --
│    └─MambaModel: 2-1                             --
│    │    └─MixerModel: 3-1                        1,620,896
│    └─Linear: 2-2                                 516
Total params: 1,622,564
Trainable params: 1,622,564
Non-trainable params: 0

In [11]:
# Build the NARCE pipeline (adapter + NAR backbone) and restore the saved
# pipeline checkpoint for the selected NAR / adapter / dataset / seed combination.

# The NAR backbone below is always built with the 'mamba2_v1' architecture, so a
# different args.nar_model would either fail to load or silently mismatch the
# standalone NAR loaded earlier. Fail loudly instead.
if args.nar_model != 'mamba2_v1':
    raise ValueError("NAR model is not defined: {!r}".format(args.nar_model))

narce_model_path = 'narce/saved_model/pipeline/{}-{}-{}-{}-{}.pt'.format(
    args.nar_model, args.nar_dataset, args.adapter_model, args.sensor_dataset, args.seed
)

# The adapter name encodes its depth.
adapter_n_layers = {'mamba2_6L': 6, 'mamba2_12L': 12}
if args.adapter_model not in adapter_n_layers:
    raise ValueError("Model is not defined: {!r}".format(args.adapter_model))
n_layer = adapter_n_layers[args.adapter_model]

adapter_model = AdapterMamba(d_model=embedding_input_dim, n_layer=n_layer)

# mamba2_v1 NAR architecture; only its `.nar` submodule is handed to the
# pipeline (the standalone model's embedding layer is discarded).
mamba_config = MambaConfig(d_model=embedding_input_dim, n_layer=12, ssm_cfg={"layer": "Mamba2", "headdim": 32,})
nar_model = NARMamba(mamba_config, nar_vocab_size=nar_vocab_size, out_cls_dim=output_dim).nar

# NOTE(review): the recorded summary reports every parameter (including the NAR
# backbone) as trainable, so NarcePipeline does not appear to set
# requires_grad=False on `frozen_nar` itself — confirm where freezing happens.
narce_model = NarcePipeline(
    frozen_nar=nar_model,
    adapter_model=adapter_model,
)

# map_location='cpu' allows loading a GPU-saved checkpoint on a CPU-only host.
narce_model.load_state_dict(torch.load(narce_model_path, map_location='cpu'))
summary(narce_model)

Layer (type:depth-idx)                             Param #
NarcePipeline                                      --
├─AdapterMamba: 1-1                                --
│    └─MambaModel: 2-1                             --
│    │    └─MixerModel: 3-1                        1,620,896
├─Sequential: 1-2                                  --
│    └─MambaModel: 2-2                             --
│    │    └─MixerModel: 3-2                        1,620,896
│    └─Linear: 2-3                                 516
Total params: 3,242,308
Trainable params: 3,242,308
Non-trainable params: 0

In [17]:
# Check that the NAR backbone inside the pipeline checkpoint carries exactly the
# same weights as the standalone NAR checkpoint.
# Comparing str() output is not reliable: torch summarises tensors above the
# print threshold (1000 elements by default) with '...' and rounds values to
# 4 decimals, so differing weights can print identically. Compare exactly.
pipeline_nar_state = narce_model.nar.state_dict()
standalone_nar_state = embed_nar_model.nar.state_dict()
nar_weights_identical = pipeline_nar_state.keys() == standalone_nar_state.keys() and all(
    torch.equal(pipeline_nar_state[name], standalone_nar_state[name]) for name in pipeline_nar_state
)
nar_weights_identical

True

In [23]:
# Inspect the learned embedding vectors for NAR token ids 0 and 1.
token_ids = torch.tensor([0, 1])
embed_nar_model.embedding(token_ids)

tensor([[ 1.6387, -1.3919,  0.8471,  1.3545,  0.5181,  1.2689, -0.0990, -0.5393,
         -0.2284,  0.9630,  1.1960, -0.4800, -0.9591,  0.8946, -0.5756,  0.4781,
          0.3824, -0.7835,  0.2404, -0.0248,  0.7036, -1.5191, -1.2571,  0.9365,
          1.0513,  0.4242, -0.0968, -0.8237,  0.7854,  0.3558, -0.1880, -0.3501,
         -0.9367,  1.4402,  1.8848,  0.2540,  0.0710, -0.2804, -0.0677,  2.0380,
          1.4018, -0.9650,  0.3454,  1.1501,  0.7461,  0.4283, -0.0179, -1.0685,
         -1.2970,  2.1924, -0.8828,  0.6288, -1.6202,  0.4118, -1.4721, -1.1071,
         -1.3488,  0.5550, -0.6932,  1.1773,  1.7721,  0.2907,  0.3414,  0.1358,
          2.3439, -1.2300,  0.4248, -0.1094,  1.0541,  1.2931,  1.0693, -0.2783,
         -1.8275,  0.8321,  0.7709,  1.5701, -0.1058, -0.1178, -0.2884,  1.2954,
         -1.7843,  1.3189, -0.2649,  1.6235, -0.0680,  0.2250,  1.9222,  0.0809,
         -1.2229,  1.1344, -0.0624,  1.7212,  1.2224, -1.1569, -0.6702,  1.5044,
         -0.9704,  0.2339,  

In [33]:
# Scratch check: an empty dict is falsy, yet it is still a distinct object
# from None, so an identity test against None is True.
not ({} is None)

True