In [1]:
from torchinfo import summary
import torch
from torch import nn

In [None]:
class CNN1DModel(nn.Module):
    def __init__(self, in_shape: int, out_shape: int, hidden_units: int, signal_length: int):
        super().__init__()
        self.name = "SimpleCNN1D"
        
        #Hyperparameters
        padding = 1
        kernel_size = 2
        stride = 1
        conv_out_size = int((signal_length+2*padding-kernel_size)/stride + 1)
        
        #Model
        self.simple_conv = nn.Sequential(
        nn.Conv1d(in_channels=in_shape, out_channels=hidden_units, kernel_size = kernel_size, stride=stride, padding=padding),
        nn.ReLU(),
        nn.Conv1d(in_channels=hidden_units, out_channels=hidden_units*2, kernel_size = kernel_size, stride=stride, padding=padding),
        nn.ReLU(),
        nn.Conv1d(in_channels=hidden_units*2, out_channels=hidden_units*4, kernel_size = kernel_size, stride=stride, padding=padding),
        nn.ReLU(),
        nn.Conv1d(in_channels=hidden_units*4, out_channels=hidden_units*8, kernel_size = kernel_size, stride=stride, padding=padding),
        nn.ReLU(),
        nn.Flatten(),
        nn.Dropout(p=0.5, inplace=False),
        nn.Linear(in_features = (hidden_units*8)*conv_out_size, out_features = out_shape)
        )
    def forward(self, x):
        return self.simple_conv(x)

In [None]:
simple_1dcnn = CNN1DModel(4,6*20,hidden_units=72,signal_length = 36).float()
simple_1dcnn

SimpleCNN1DModel(
  (simple_conv): Sequential(
    (0): Conv1d(4, 72, kernel_size=(2,), stride=(1,), padding=(1,))
    (1): ReLU()
    (2): Flatten(start_dim=1, end_dim=-1)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=2664, out_features=120, bias=True)
  )
)

In [18]:
print(summary(model = simple_1dcnn,
        input_size=(1,4,36),
        dtypes=[torch.float],
        verbose=2,
        col_width=16,
        col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
        row_settings=["var_names"],
        ))

Layer (type (var_name))                  Kernel Shape     Output Shape     Param #          Mult-Adds
SimpleCNN1DModel (SimpleCNN1DModel)      --               [1, 120]         --               --
├─Sequential (simple_conv)               --               [1, 120]         --               --
│    └─0.weight                          [72, 4, 2]                        ├─576
│    └─0.bias                            [72]                              ├─72
│    └─4.weight                          [120, 2664]                       ├─319,680
│    └─4.bias                            [120]                             └─120
│    └─Conv1d (0)                        [2]              [1, 72, 37]      648              23,976
│    │    └─weight                       [4, 72, 2]                        ├─576
│    │    └─bias                         [72]                              └─72
│    └─ReLU (1)                          --               [1, 72, 37]      --               --
│    └─Flatten (2)        