In [2]:
from torchinfo import summary
import torch
from torch import nn

In [12]:

class CNN1DModel(nn.Module):
    """Stack of 1-D convolutions, global average pooling and a linear head.

    Each conv stage is ``Conv1d -> ReLU`` and doubles the channel count,
    starting at ``hidden_units``. ``AdaptiveAvgPool1d(1)`` collapses the
    length axis, so the size of the linear head does not depend on the
    input signal length.

    Args:
        in_shape: Number of input channels (``C_in`` of a batched input ``(N, C_in, L)``).
        out_shape: Number of output features of the final ``nn.Linear``.
        hidden_units: Output channels of the first conv; stage ``i`` (0-based)
            has ``hidden_units * 2**i`` output channels.
        signal_length: Expected input length ``L``. It is NOT used to build any
            layer (the adaptive pooling makes the network length-agnostic); it is
            kept for interface compatibility and stored as ``self.signal_length``.
        kernel_size: Kernel size of every conv (default 2, as originally hard-coded).
        stride: Stride of every conv (default 1).
        padding: Zero padding of every conv (default 1).
        dropout: Dropout probability applied before the linear head (default 0.5).
        num_conv_layers: Number of ``Conv1d -> ReLU`` stages (default 4).

    Raises:
        ValueError: If ``num_conv_layers`` is smaller than 1.
    """

    def __init__(self, in_shape: int, out_shape: int, hidden_units: int, signal_length: int,
                 kernel_size: int = 2, stride: int = 1, padding: int = 1,
                 dropout: float = 0.5, num_conv_layers: int = 4):
        super().__init__()
        if num_conv_layers < 1:
            raise ValueError(f"num_conv_layers must be >= 1, got {num_conv_layers}")
        self.name = "SimpleCNN1D"
        self.signal_length = signal_length

        layers = []
        in_channels = in_shape
        for stage in range(num_conv_layers):
            out_channels = hidden_units * 2 ** stage
            layers += [
                nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=kernel_size, stride=stride, padding=padding),
                nn.ReLU(),
            ]
            in_channels = out_channels

        layers += [
            nn.AdaptiveAvgPool1d(1),  # (N, C, L') -> (N, C, 1) whatever L' is
            nn.Flatten(),             # (N, C, 1) -> (N, C)
            nn.Dropout(p=dropout, inplace=False),
            nn.Linear(in_features=in_channels, out_features=out_shape),
        ]
        # Keep the attribute name and the flat Sequential layout so state_dict keys
        # ("simple_conv.0.weight", ..., "simple_conv.11.bias" with the defaults)
        # match checkpoints saved from the original hand-written definition.
        self.simple_conv = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batched input of shape (N, in_shape, L) to an (N, out_shape) output."""
        return self.simple_conv(x)

In [13]:
# Model configuration. These values must stay in sync with the `input_size`
# passed to torchinfo's `summary` in the next cell.
IN_CHANNELS = 4      # passed as `in_shape` -> in_channels of the first Conv1d
N_OUTPUTS = 6 * 20   # 120 output features; NOTE(review): presumably 6 x 20 targets - confirm
HIDDEN_UNITS = 72    # channels of the first conv stage (doubled at each later stage)
SIGNAL_LENGTH = 36   # expected input length

simple_1dcnn = CNN1DModel(
    IN_CHANNELS,
    N_OUTPUTS,
    hidden_units=HIDDEN_UNITS,
    signal_length=SIGNAL_LENGTH,
).float()
simple_1dcnn

CNN1DModel(
  (simple_conv): Sequential(
    (0): Conv1d(4, 72, kernel_size=(2,), stride=(1,), padding=(1,))
    (1): ReLU()
    (2): Conv1d(72, 144, kernel_size=(2,), stride=(1,), padding=(1,))
    (3): ReLU()
    (4): Conv1d(144, 288, kernel_size=(2,), stride=(1,), padding=(1,))
    (5): ReLU()
    (6): Conv1d(288, 576, kernel_size=(2,), stride=(1,), padding=(1,))
    (7): ReLU()
    (8): AdaptiveAvgPool1d(output_size=1)
    (9): Flatten(start_dim=1, end_dim=-1)
    (10): Dropout(p=0.5, inplace=False)
    (11): Linear(in_features=576, out_features=120, bias=True)
  )
)

In [14]:
# With verbose=2 torchinfo's `summary` prints the (weight/bias-level) table itself;
# the previous `print(summary(...))` printed the whole table a second time.
# Keep the returned ModelStatistics instead. Assigning it also stops Jupyter from
# echoing its repr as the cell output.
model_stats = summary(
    model=simple_1dcnn,
    input_size=(1, 4, 36),  # (batch, in_shape, signal_length) as used to build simple_1dcnn
    dtypes=[torch.float],
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)

Layer (type (var_name))                  Kernel Shape     Output Shape     Param #          Mult-Adds
CNN1DModel (CNN1DModel)                  --               [1, 120]         --               --
├─Sequential (simple_conv)               --               [1, 120]         --               --
│    └─0.weight                          [72, 4, 2]                        ├─576
│    └─0.bias                            [72]                              ├─72
│    └─2.weight                          [144, 72, 2]                      ├─20,736
│    └─2.bias                            [144]                             ├─144
│    └─4.weight                          [288, 144, 2]                     ├─82,944
│    └─4.bias                            [288]                             ├─288
│    └─6.weight                          [576, 288, 2]                     ├─331,776
│    └─6.bias                            [576]                             ├─576
│    └─11.weight                         [120, 576]