In [1]:
from torchinfo import summary
import torch
from torch import nn

In [26]:
class SimpleCNN1DModel(nn.Module):
    """Single-layer 1D CNN followed by a linear head.

    Pipeline: Conv1d -> ReLU -> Flatten -> Dropout(p=0.5) -> Linear.

    Args:
        in_shape: Number of input channels fed to the convolution.
        out_shape: Number of output features of the final linear layer.
        hidden_units: Number of output channels of the convolution.
        seq_len: Length of the last (sequence) dimension of the input.
            Defaults to 4, which yields the ``hidden_units * 5`` flattened
            size that was previously hard-coded.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """

    def __init__(self, in_shape, out_shape, hidden_units, seq_len=4):
        super().__init__()
        self.name = "SimpleCNN1D"

        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")

        kernel_size = 2
        stride = 1
        padding = 1
        # Conv1d output length (dilation=1):
        #   floor((L + 2*padding - kernel_size) / stride) + 1
        # Deriving it here keeps the Linear layer in sync with the conv
        # settings instead of relying on a magic constant.
        conv_out_len = (seq_len + 2 * padding - kernel_size) // stride + 1

        self.simple_conv = nn.Sequential(
            nn.Conv1d(
                in_channels=in_shape,
                out_channels=hidden_units,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            ),
            nn.ReLU(),
            # Collapse (channels, length) into one feature axis per sample.
            nn.Flatten(),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(
                in_features=hidden_units * conv_out_len,
                out_features=out_shape,
            ),
        )

    def forward(self, x):
        """Run ``x`` through the conv/linear stack and return the result."""
        return self.simple_conv(x)

In [27]:
# Build the model with 36 input channels, 6 * 20 = 120 outputs and 1024
# convolution channels, then cast its parameters to float32.
simple_1dcnn = SimpleCNN1DModel(
    in_shape=36,
    out_shape=6 * 20,
    hidden_units=1024,
).float()
# Bare expression so the notebook displays the module structure.
simple_1dcnn

SimpleCNN1DModel(
  (simple_conv): Sequential(
    (0): Conv1d(36, 1024, kernel_size=(2,), stride=(1,), padding=(1,))
    (1): ReLU()
    (2): Flatten(start_dim=1, end_dim=-1)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=5120, out_features=120, bias=True)
  )
)

In [28]:
# verbose=2 makes torchinfo print the detailed table (including per-layer
# weight/bias rows) on its own. Binding the result to a name, instead of
# wrapping the call in print(), keeps the table from being emitted twice.
model_stats = summary(
    model=simple_1dcnn,
    input_size=(1, 36, 4),  # (batch, channels, sequence length) for Conv1d
    dtypes=[torch.float],
    verbose=2,
    col_width=16,
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)

Layer (type (var_name))                  Kernel Shape     Output Shape     Param #          Mult-Adds
SimpleCNN1DModel (SimpleCNN1DModel)      --               [1, 120]         --               --
├─Sequential (simple_conv)               --               [1, 120]         --               --
│    └─0.weight                          [1024, 36, 2]                     ├─73,728
│    └─0.bias                            [1024]                            ├─1,024
│    └─4.weight                          [120, 5120]                       ├─614,400
│    └─4.bias                            [120]                             └─120
│    └─Conv1d (0)                        [2]              [1, 1024, 5]     74,752           373,760
│    │    └─weight                       [36, 1024, 2]                     ├─73,728
│    │    └─bias                         [1024]                            └─1,024
│    └─ReLU (1)                          --               [1, 1024, 5]     --               --
│    └─Flatte