In [1]:
import os
import sys

import torch
from torch import nn

# Directory that contains the project's `layers` package (imported by later cells).
# Change this path according to where the repository is checked out, or set the
# BABE_SRC environment variable to override it without editing the notebook.
SRC_DIR = os.environ.get('BABE_SRC', '/hpc/compgen/users/mpages/babe/src')
# Only append once so re-running this cell does not pile up duplicate entries.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [2]:
def count_parameters(model):
    """Count the scalar parameters of a torch module.

    Returns a tuple ``(trainable, frozen, total)``:
    the number of elements in parameters with ``requires_grad=True``, the
    number with ``requires_grad=False``, and the sum of both.
    """
    trainable = frozen = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
        else:
            frozen += param.numel()
    return trainable, frozen, trainable + frozen

# Convolution

Count the number of parameters of every CNN architecture and check its output shape on a random input batch.

## Bonito

In [6]:
# Bonito front-end: three Conv1d + SiLU stages with padding kernel_size // 2;
# only the last conv (kernel 19, stride 5) changes the sequence length.
cnn_bonito = nn.Sequential(
    nn.Conv1d(1, 4, kernel_size=5, stride=1, padding=5 // 2, bias=True),
    nn.SiLU(),
    nn.Conv1d(4, 16, kernel_size=5, stride=1, padding=5 // 2, bias=True),
    nn.SiLU(),
    nn.Conv1d(16, 384, kernel_size=19, stride=5, padding=19 // 2, bias=True),
    nn.SiLU(),
).to(device)

In [7]:
# (trainable, frozen, total) parameter counts
count_parameters(model=cnn_bonito)

(117480, 0, 117480)

In [25]:
# Shape probe on a random input of shape (batch=64, channels=1, length=500).
# Only the output shape is needed, so skip building the autograd graph.
x = torch.rand((64, 1, 500), device=device)
with torch.no_grad():
    y = cnn_bonito(x)
print(y.shape)

torch.Size([64, 384, 100])


## CATCaller

In [9]:
# CATCaller front-end hyperparameters.
d_model = 512   # channels after the second conv; the first conv outputs d_model // 2
padding = 1
kernel = 3
stride = 2      # each conv halves the sequence length
dilation = 1

# Two Conv1d (no bias) -> BatchNorm1d -> ReLU stages: 1 -> d_model // 2 -> d_model channels.
cnn_catcaller = nn.Sequential(
    nn.Conv1d(1, d_model // 2, kernel, stride=stride, padding=padding, dilation=dilation, bias=False),
    nn.BatchNorm1d(d_model // 2),
    nn.ReLU(),
    nn.Conv1d(d_model // 2, d_model, kernel, stride=stride, padding=padding, dilation=dilation, bias=False),
    nn.BatchNorm1d(d_model),
    nn.ReLU(),
).to(device)

In [10]:
# (trainable, frozen, total) parameter counts
count_parameters(model=cnn_catcaller)

(395520, 0, 395520)

In [26]:
# Shape probe on a random input of shape (batch=64, channels=1, length=500).
# Only the output shape is needed, so skip building the autograd graph.
x = torch.rand((64, 1, 500), device=device)
with torch.no_grad():
    y = cnn_catcaller(x)
print(y.shape)

torch.Size([64, 512, 125])


## CausalCall

In [12]:
from layers.causalcall import CausalCallConvBlock

# CausalCall front-end: num_blocks CausalCallConvBlocks whose dilation is
# multiplied by dilation_multiplier at each block (1, 2, 4, 8, 16 here).
num_blocks = 5
num_channels = 256
kernel_size = 3
dilation_multiplier = 2
dilation = 1  # dilation of the first block

# Named `causalcall_blocks` rather than `layers` so the imported `layers`
# package is not shadowed by a list in the notebook namespace.
# The third argument is 1 for the first block and num_channels // 2 afterwards
# (presumably the input channel count — confirm in layers.causalcall).
causalcall_blocks = [
    CausalCallConvBlock(
        kernel_size,
        num_channels,
        1 if i == 0 else num_channels // 2,
        dilation * dilation_multiplier ** i,  # computed per block; `dilation` itself is not mutated
    )
    for i in range(num_blocks)
]

cnn_causalcall = nn.Sequential(*causalcall_blocks).to(device)

In [13]:
# (trainable, frozen, total) parameter counts
count_parameters(model=cnn_causalcall)

(956928, 0, 956928)

In [28]:
# Shape probe on a random input of shape (batch=64, channels=1, length=500).
# Only the output shape is needed, so skip building the autograd graph.
x = torch.rand((64, 1, 500), device=device)
with torch.no_grad():
    y = cnn_causalcall(x)
print(y.shape)

torch.Size([64, 128, 500])


## Halcyon

In [15]:
from layers.halcyon import HalcyonCNNBlock, HalcyonInceptionBlock

# Branch settings shared by all three inception blocks (one entry per branch).
num_kernels = [1, 3, 5, 7, 3]
num_channels = [64, 96, 128, 160, 32]
strides = [1 for _ in num_channels]
paddings = ['same' for _ in num_channels]
use_bn = [True for _ in num_channels]

# Five HalcyonCNNBlocks with two MaxPool1d(3, 2) in between, followed by three
# inception blocks whose scaler is 0.8, 0.8**2 and 0.8**3.
# NOTE(review): the inception input widths 382 and 304 equal
# sum(int(c * scaler) for c in num_channels) for scaler 0.8 and 0.8**2, which
# presumably is each inception block's output width — confirm in layers.halcyon.
cnn_halcyon = nn.Sequential(
    HalcyonCNNBlock(1, 64, 3, 2, "valid", True),
    HalcyonCNNBlock(64, 64, 3, 1, "valid", True),
    HalcyonCNNBlock(64, 128, 3, 1, "valid", True),
    nn.MaxPool1d(3, 2),
    HalcyonCNNBlock(128, 160, 3, 1, "valid", True),
    HalcyonCNNBlock(160, 384, 3, 1, "valid", True),
    nn.MaxPool1d(3, 2),
    HalcyonInceptionBlock(384, num_channels, num_kernels, strides, paddings, use_bn, scaler=0.8**1),
    HalcyonInceptionBlock(382, num_channels, num_kernels, strides, paddings, use_bn, scaler=0.8**2),
    HalcyonInceptionBlock(304, num_channels, num_kernels, strides, paddings, use_bn, scaler=0.8**3),
).to(device)

In [16]:
# (trainable, frozen, total) parameter counts
count_parameters(model=cnn_halcyon)

(1329643, 0, 1329643)

In [31]:
# Shape probe on a random input of shape (batch=64, channels=1, length=500).
# Only the output shape is needed, so skip building the autograd graph.
x = torch.rand((64, 1, 500), device=device)
with torch.no_grad():
    y = cnn_halcyon(x)
print(y.shape)

torch.Size([64, 243, 58])


## MinCall

In [18]:
from layers.mincall import MinCallConvBlock

# MinCall front-end: an initial Conv1d from 1 to num_channels channels, then
# num_layers MinCallConvBlocks with a MaxPool1d inserted before every
# pool_every-th block (before blocks 24 and 48 with these settings).
num_layers = 72
pool_every = 24
kernel_size = 3
padding = 'same'
num_channels = 64
max_pool_kernel = 2

# Named `mincall_layers` rather than `layers` so the imported `layers`
# package is not shadowed by a list in the notebook namespace.
mincall_layers = [nn.Conv1d(1, num_channels, kernel_size, 1, padding)]
for i in range(num_layers):
    if i > 0 and i % pool_every == 0:
        mincall_layers.append(nn.MaxPool1d(max_pool_kernel))
    mincall_layers.append(MinCallConvBlock(kernel_size, num_channels, num_channels, padding))

cnn_mincall = nn.Sequential(*mincall_layers).to(device)

In [19]:
# (trainable, frozen, total) parameter counts
count_parameters(model=cnn_mincall)

(1797376, 0, 1797376)

In [32]:
# Shape probe on a random input of shape (batch=64, channels=1, length=500).
# Only the output shape is needed, so skip building the autograd graph.
x = torch.rand((64, 1, 500), device=device)
with torch.no_grad():
    y = cnn_mincall(x)
print(y.shape)

torch.Size([64, 64, 125])


## SACall

In [33]:
# SACall front-end hyperparameters.
d_model = 256
kernel = 3
maxpooling_stride = 2

# Three identical stages of Conv1d (stride 1, padding 1, no bias) -> BatchNorm1d
# -> ReLU -> MaxPool1d(kernel, maxpooling_stride, padding 1); channels go
# 1 -> d_model // 2 -> d_model -> d_model.
cnn_sacall = nn.Sequential(*[
    module
    for c_in, c_out in [(1, d_model // 2), (d_model // 2, d_model), (d_model, d_model)]
    for module in (
        nn.Conv1d(c_in, c_out, kernel, 1, 1, bias=False),
        nn.BatchNorm1d(c_out),
        nn.ReLU(),
        nn.MaxPool1d(kernel, maxpooling_stride, 1),
    )
]).to(device)

In [34]:
# (trainable, frozen, total) parameter counts
count_parameters(model=cnn_sacall)

(296576, 0, 296576)

In [35]:
# Shape probe on a random input of shape (batch=64, channels=1, length=500).
# Only the output shape is needed, so skip building the autograd graph.
x = torch.rand((64, 1, 500), device=device)
with torch.no_grad():
    y = cnn_sacall(x)
print(y.shape)

torch.Size([64, 256, 63])


## URNano

In [3]:
from layers.urnano import URNetDownBlock, URNetFlatBlock, URNetUpBlock, URNet

padding = 'same'
stride = 1
n_channels = [64, 128, 256, 512]
kernel = 11  # used only by the first down block; every other conv uses kernel 3
maxpooling = [2, 2, 2]  # the GitHub json uses [3, 2, 2]; changed because we use even-length segments

# Down path: channels 1 -> 64 -> 128 -> 256, kernels (11, 3, 3), pooling from `maxpooling`.
down = nn.ModuleList([
    URNetDownBlock(c_in, c_out, k, pool, stride, padding)
    for c_in, c_out, k, pool in zip([1] + n_channels[:2], n_channels[:3], [kernel, 3, 3], maxpooling)
])
# Flat block: 256 -> 512 channels.
flat = nn.ModuleList([URNetFlatBlock(n_channels[2], n_channels[3], 3, stride, padding)])
# Up path mirrors the down path (512 -> 256 -> 128 -> 64); maxpooling[j] is passed twice, as in the original spec.
up = nn.ModuleList([
    URNetUpBlock(n_channels[j + 1], n_channels[j], 3, maxpooling[j], maxpooling[j], stride, padding)
    for j in reversed(range(3))
])

# U-Net body followed by a final Conv1d -> BatchNorm1d -> ReLU at 64 channels.
cnn_urnano = nn.Sequential(
    URNet(down, flat, up),
    nn.Conv1d(n_channels[0], n_channels[0], 3, stride, padding),
    nn.BatchNorm1d(n_channels[0]),
    nn.ReLU(),
).to(device)

In [4]:
# (trainable, frozen, total) parameter counts
count_parameters(model=cnn_urnano)

(3510464, 0, 3510464)

In [5]:
# Shape probe on a random input of shape (batch=64, channels=1, length=4000).
# Only the output shape is needed, so skip building the autograd graph
# (it would hold every activation of the U-Net at full length).
x = torch.rand((64, 1, 4000), device=device)
with torch.no_grad():
    y = cnn_urnano(x)
print(y.shape)

torch.Size([64, 64, 4000])


  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


# Time

## Bonito

In [3]:
from layers.bonito import BonitoLSTM

In [4]:
# Five stacked BonitoLSTM(384, 384) layers with alternating direction,
# starting and ending with reverse=True.
rnn = nn.Sequential(*[
    BonitoLSTM(384, 384, reverse=(layer_idx % 2 == 0))
    for layer_idx in range(5)
])

In [5]:
# (trainable, frozen, total) parameter counts
count_parameters(model=rnn)

(5913600, 0, 5913600)

## CATCaller

In [6]:
from layers.catcaller import CATCallerEncoderLayer
from layers.layers import PositionalEncoding

# Model-wide settings. padding/kernel/stride/dilation mirror the CATCaller CNN
# cell and are not used anywhere in this cell.
d_model = 512
d_ff = 512
dropout = 0.1
padding = 1
kernel = 3
stride = 2
dilation = 1

# Encoder settings; kernel_size holds one kernel size per encoder layer.
num_encoder_layers = 6
embed_dims = 512
heads = 4
kernel_size = [3, 7, 15, 31, 31, 31]
weight_softmax = True
weight_dropout = 0.1
with_linear = True
glu = True

pe = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=5000)

encoder = nn.ModuleList([
    CATCallerEncoderLayer(
        d_model=d_model,
        d_ff=d_ff,
        kernel_size=kernel_size[layer_idx],
        num_heads=heads,
        channels=embed_dims,
        dropout=dropout,
        weight_softmax=weight_softmax,
        weight_dropout=weight_dropout,
        with_linear=with_linear,
        glu=glu,
    )
    for layer_idx in range(num_encoder_layers)
])

In [7]:
# (trainable, frozen, total) for the positional encoding, then for the encoder stack
print(count_parameters(model=pe))
print(count_parameters(model=encoder))

(0, 0, 0)
(14437376, 0, 14437376)


## SACall

In [8]:
# SACall sequence model: positional encoding followed by an n_layers-deep
# nn.TransformerEncoder (n_head heads, feed-forward width d_ff).
d_model = 256
dropout = 0.1
n_layers = 6
n_head = 8
d_ff = 1024

pe = PositionalEncoding(d_model=d_model, dropout=dropout, max_len=4000)
encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=n_head, dim_feedforward=d_ff, dropout=dropout
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
rnn = nn.Sequential(pe, encoder)

In [9]:
# (trainable, frozen, total) parameter counts
count_parameters(model=rnn)

(4738560, 0, 4738560)

## URNano

In [12]:
# URNano recurrent part: 3-layer bidirectional GRU, hidden size 64.
# NOTE(review): input_size is 128, while the URNano CNN probe above outputs 64 channels — confirm.
rnn = nn.GRU(input_size=128, hidden_size=64, num_layers=3, bidirectional=True)

In [13]:
# (trainable, frozen, total) parameter counts
count_parameters(model=rnn)

(223488, 0, 223488)

## LSTM3

In [14]:
# 3-layer bidirectional LSTM, input and hidden size 256.
rnn = nn.LSTM(input_size=256, hidden_size=256, num_layers=3, bidirectional=True)

In [15]:
# (trainable, frozen, total) parameter counts
count_parameters(model=rnn)

(4206592, 0, 4206592)

## LSTM1

In [16]:
# Single-layer bidirectional LSTM, input and hidden size 256.
rnn = nn.LSTM(input_size=256, hidden_size=256, num_layers=1, bidirectional=True)

In [17]:
# (trainable, frozen, total) parameter counts
count_parameters(model=rnn)

(1052672, 0, 1052672)