In [1]:
import os
import sys

import torch
# Root of the Baikal-ML checkout (keep the trailing slash or build paths with
# os.path.join). Set the BAIKAL_ML_PATH environment variable to point at your
# own checkout instead of editing this cell.
PROJECT_PATH = os.environ.get("BAIKAL_ML_PATH", "/home/albert/Baikal-ML/")
# Guarded so re-running this cell does not append duplicate sys.path entries.
if PROJECT_PATH not in sys.path:
    sys.path.append(PROJECT_PATH)

In [2]:
from nnetworks.layers.config import LstmConfig, DenseInput, MaskedConv1DConfig, ResBlockConfig
from nnetworks.models.config import MuNuSepLstmConfig, MuNuSepResNetConfig
from nnetworks.models.munusep_lstm import MuNuSepLstm
from nnetworks.models.munusep_resnet import MuNuSepResNet

from nnetworks.models.config_manager import save_model_cfg, model_from_yaml

# LSTM-based model

In [3]:
# LSTM-based separator: two stacked LSTM layers followed by a dense head.
# The second LSTM takes 256 = 2 x 128 inputs, presumably the bidirectional
# output of the first layer (TODO confirm). LstmConfig is called positionally;
# the 4th flag (True for the first layer, False for the second) is not named
# here — TODO switch to keywords once LstmConfig's field order is confirmed.
cfg = MuNuSepLstmConfig(
    [  # recurrent layers
        LstmConfig(5, 128, 1, True, dropout=0),
        LstmConfig(256, 128, 1, False, dropout=0),
    ],
    [  # dense head; positional args presumably (in_features, units) as in the ResNet cell
        DenseInput(256, 32, activation={'ReLU': None}, do_norm=True),
        DenseInput(32, 2, activation={'Softmax': {"dim": 1}}, do_norm=False),
    ],
)

model = MuNuSepLstm(cfg)
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Parameters whose ``requires_grad`` flag is False (frozen) are skipped.
    Returns 0 for a module with no trainable parameters.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Trainable-parameter count of the LSTM model, shown as the cell output.
count_parameters(model=model)

542882

In [5]:
name_for_model_arch = "munusep_all_rnn"
# Opt-in switch replacing the previously commented-out save call. It stays off
# by default so that running the notebook does not rewrite the YAML on disk.
SAVE_LSTM_CFG = False
if SAVE_LSTM_CFG:
    save_model_cfg(
        cfg,
        os.path.join(PROJECT_PATH, "nnetworks", "models", "configurations", f"{name_for_model_arch}.yaml"),
    )

In [6]:
# Rebuild the LSTM model from its YAML configuration, resolving the path from
# PROJECT_PATH rather than a hard-coded home directory.
# NOTE(review): this loads whatever YAML is on disk; it only matches `cfg`
# above if that config was saved first — confirm before relying on it.
lstm_model_from_yaml = model_from_yaml(
    MuNuSepLstm,
    os.path.join(PROJECT_PATH, "nnetworks", "models", "configurations", f"{name_for_model_arch}.yaml"),
)
lstm_model_from_yaml

MuNuSepLstm(
  (lstm_layers): ModuleList(
    (0): LstmLayer(
      (lstm_layer): LSTM(5, 128, batch_first=True, bidirectional=True)
      (norm_layer): MaskedLayerNorm1D()
    )
    (1): LstmLayer(
      (lstm_layer): LSTM(256, 128, batch_first=True, bidirectional=True)
      (norm_layer): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
  )
  (dense_layers): ModuleList(
    (0): DenseBlock(
      (dropout_layer): Dropout(p=0.2, inplace=False)
      (dense_layer): Linear(in_features=256, out_features=32, bias=True)
      (activation): ReLU()
      (norm1d): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): DenseBlock(
      (dropout_layer): Dropout(p=0.2, inplace=False)
      (dense_layer): Linear(in_features=32, out_features=2, bias=True)
      (activation): Softmax(dim=1)
    )
  )
)

# ResNet-based model

In [None]:
def _masked_conv(in_channels, out_channels, kernel_size, strides):
    """Return a MaskedConv1DConfig using the conv settings shared by this ResNet.

    Every convolution below uses LeakyReLU(negative_slope=0.1), dropout 0.1
    and batch norm; only the shape arguments vary. A fresh activation dict is
    built on each call, so no two configs share a mutable object.
    """
    return MaskedConv1DConfig(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        strides=strides,
        activation={"LeakyReLU": {"negative_slope": 0.1}},
        dropout=0.1,
        do_batch_norm=True,
    )


# ResNet-based separator: two residual blocks, average pooling, dense head.
# NOTE(review): in each block `id` and `skip` both read the block input, `cd`
# reads `id`'s output, and the next stage is twice the per-branch width
# (64 + 64 -> 128, 128 + 128 -> 256). That suggests ResBlock concatenates the
# `cd` and `skip` outputs; confirm against the ResBlock implementation.
cfg = MuNuSepResNetConfig(
    res_blocks=[
        ResBlockConfig(
            id=_masked_conv(5, 64, kernel_size=16, strides=1),
            cd=_masked_conv(64, 64, kernel_size=16, strides=2),
            skip=_masked_conv(5, 64, kernel_size=32, strides=2),
        ),
        ResBlockConfig(
            id=_masked_conv(128, 128, kernel_size=8, strides=1),
            cd=_masked_conv(128, 128, kernel_size=8, strides=2),
            skip=_masked_conv(128, 128, kernel_size=16, strides=2),
        ),
    ],
    pooling_type="Average",
    dense_layers=[
        DenseInput(in_features=256, units=128, activation={'ReLU': None}, dropout=0.2, do_norm=True),
        DenseInput(in_features=128, units=32, activation={'ReLU': None}, dropout=0.2, do_norm=True),
        # dropout is not set here, so DenseInput's default value is used.
        DenseInput(in_features=32, units=2, activation={'Softmax': {"dim": 1}}, do_norm=False),
    ],
)

model = MuNuSepResNet(cfg)
# Only a cell's last expression is displayed, so print the count explicitly.
# count_parameters is the helper defined in the LSTM section above.
print(f"Trainable parameters: {count_parameters(model)}")

# Smoke test: one forward pass on a random batch of 3 with an all-ones mask.
# Shapes follow the original test: presumably (batch, channels=5, length=10)
# for the input and (batch, 1, length) for the mask; TODO confirm against
# MuNuSepResNet.forward.
# eval() disables dropout and batch-norm running-stat updates, and no_grad()
# skips autograd bookkeeping. The seeded generator makes the input
# reproducible without touching torch's global RNG. The model stays in eval mode.
model.eval()
smoke_generator = torch.Generator().manual_seed(0)
with torch.no_grad():
    smoke_out = model(torch.rand((3, 5, 10), generator=smoke_generator), torch.ones((3, 1, 10)))
# The head ends in a 2-unit softmax over dim 1; this assumes forward returns
# that tensor directly.
assert smoke_out.shape == (3, 2), smoke_out.shape
smoke_out

In [10]:
name_for_model_arch = "munusep_all_resnet"
# Resolve the output path from PROJECT_PATH (set in the first cell) instead of
# a hard-coded home directory, so the notebook works in any checkout.
# NOTE: this overwrites the YAML every time the cell runs.
save_model_cfg(
    cfg,
    os.path.join(PROJECT_PATH, "nnetworks", "models", "configurations", f"{name_for_model_arch}.yaml"),
)