In [1]:
from torchinfo import summary
import torch
from torch import nn

In [7]:
class CNN1DModel(nn.Module):
    """Stack of 1-D convolutions (each followed by ReLU) with a dropout + linear head.

    The output of the last convolution is flattened, passed through dropout and
    mapped by a single ``nn.Linear`` layer to ``out_shape`` features.

    Args:
        in_shape: Number of input channels (``in_channels`` of the first Conv1d).
        out_shape: Number of output features of the final Linear layer.
        hidden_units: Number of channels produced by every Conv1d layer.
        signal_length: Length of the input sequence (last input dimension). It is
            used to size the Linear layer, so inputs must have exactly this length.
        kernel_size: Kernel size of every Conv1d layer.
        stride: Stride of every Conv1d layer.
        padding: Zero padding added on both sides by every Conv1d layer.
        num_conv_layers: Number of Conv1d + ReLU pairs.
        dropout: Dropout probability applied before the Linear layer.

    Raises:
        ValueError: If ``num_conv_layers`` < 1, or if the convolutions would
            reduce the sequence to a non-positive length.
    """

    def __init__(
        self,
        in_shape: int,
        out_shape: int,
        hidden_units: int,
        signal_length: int,
        kernel_size: int = 2,
        stride: int = 1,
        padding: int = 1,
        num_conv_layers: int = 5,
        dropout: float = 0.5,
    ):
        super().__init__()
        self.name = "SimpleCNN1D"

        if num_conv_layers < 1:
            raise ValueError(f"num_conv_layers must be >= 1, got {num_conv_layers}")

        layers = []
        channels = in_shape
        length = signal_length
        for layer_idx in range(num_conv_layers):
            layers.append(
                nn.Conv1d(
                    in_channels=channels,
                    out_channels=hidden_units,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            layers.append(nn.ReLU())
            channels = hidden_units
            # Track the sequence length so the Linear layer can be sized below.
            length = self._conv_output_length(length, kernel_size, stride, padding)
            if length < 1:
                raise ValueError(
                    f"Conv layer {layer_idx} produces non-positive length {length} "
                    f"(signal_length={signal_length}, kernel_size={kernel_size}, "
                    f"stride={stride}, padding={padding})"
                )

        layers.append(nn.Flatten())
        layers.append(nn.Dropout(p=dropout, inplace=False))
        layers.append(nn.Linear(in_features=hidden_units * length, out_features=out_shape))

        # A single flat Sequential keeps module indices (and state_dict keys)
        # of the form "simple_conv.<i>".
        self.simple_conv = nn.Sequential(*layers)

    @staticmethod
    def _conv_output_length(input_length: int, kernel_size: int, stride: int, padding: int) -> int:
        """Output length of an ``nn.Conv1d`` with dilation 1: floor((L + 2p - k) / s) + 1."""
        return (input_length + 2 * padding - kernel_size) // stride + 1

    def forward(self, x):
        """Run the model.

        Args:
            x: Input tensor; presumably (batch, in_shape, signal_length) as required
                by ``nn.Conv1d`` and the Linear sizing — TODO confirm with callers.

        Returns:
            Tensor of shape (batch, out_shape).
        """
        return self.simple_conv(x)

In [16]:
# 36 input channels, sequences of length 4, 72 channels per conv layer,
# and a head producing 6 * 20 = 120 output features.
cnn1d_36ch = CNN1DModel(
    in_shape=36,
    out_shape=6 * 20,
    hidden_units=72,
    signal_length=4,
).float()
cnn1d_36ch

CNN1DModel(
  (simple_conv): Sequential(
    (0): Conv1d(36, 72, kernel_size=(2,), stride=(1,), padding=(1,))
    (1): ReLU()
    (2): Conv1d(72, 72, kernel_size=(2,), stride=(1,), padding=(1,))
    (3): ReLU()
    (4): Conv1d(72, 72, kernel_size=(2,), stride=(1,), padding=(1,))
    (5): ReLU()
    (6): Conv1d(72, 72, kernel_size=(2,), stride=(1,), padding=(1,))
    (7): ReLU()
    (8): Conv1d(72, 72, kernel_size=(2,), stride=(1,), padding=(1,))
    (9): ReLU()
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Dropout(p=0.5, inplace=False)
    (12): Linear(in_features=648, out_features=120, bias=True)
  )
)

In [17]:
# torchinfo's summary() prints the table itself whenever verbose > 0, so wrapping
# it in print() would emit the table twice. Assign the returned statistics object
# instead (a bare last expression would also be echoed again by Jupyter).
cnn1d_36ch_stats = summary(
    model=cnn1d_36ch,
    input_size=(1, 36, 4),  # (batch, channels, length): must match in_shape / signal_length used to build the model
    dtypes=[torch.float],
    verbose=2,  # 2 = also list every weight and bias tensor
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)

Layer (type (var_name))                  Kernel Shape     Output Shape     Param #          Mult-Adds
CNN1DModel (CNN1DModel)                  --               [1, 120]         --               --
├─Sequential (simple_conv)               --               [1, 120]         --               --
│    └─0.weight                          [72, 36, 2]                       ├─5,184
│    └─0.bias                            [72]                              ├─72
│    └─2.weight                          [72, 72, 2]                       ├─10,368
│    └─2.bias                            [72]                              ├─72
│    └─4.weight                          [72, 72, 2]                       ├─10,368
│    └─4.bias                            [72]                              ├─72
│    └─6.weight                          [72, 72, 2]                       ├─10,368
│    └─6.bias                            [72]                              ├─72
│    └─8.weight                          [72, 72, 2] 