In [2]:
from torchinfo import summary
import torch
from torch import nn

In [None]:
class InversionModel(nn.Module):
    def __init__(self, scales: list[int], nwl_points: int, in_channels: int = 4, c1_filters: int = 1024, c2_filters: int = 1024, kernel_size: int = 5, stride: int = 1, padding: int = 0, pool_size: int = 2, hidden_units: int = 1024, n_outputs: int = 5):
        super(InversionModel, self).__init__()
        # 1. Multi-scale feature mapping
        self.multi_scale_feature_mapping = MultiScaleFeatureMapping(scales=scales, in_channels=in_channels, c1_filters=c1_filters, c2_filters=c2_filters, kernel_size=kernel_size, stride=stride, padding=padding, pool_size=pool_size)
        # 2. Flatten layer
        self.flatten_layer = nn.Flatten()
        # 3. Flatten size calculation layers
        def get_output_shape(input_size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
            return (input_size - kernel_size + 2 * padding) // stride + 1
        pool_stride = 2
        total_features = 0
        for s in scales:
            conv_output_shapes = nwl_points // s
            for _ in range(2):
                conv_output_shapes = get_output_shape(conv_output_shapes, kernel_size, stride=1, padding=0)//pool_stride 
            total_features += conv_output_shapes
        flatten_size = total_features * c2_filters
        # 4. Linear layers
        self.initial_linear_layer = nn.Linear(in_features = flatten_size, out_features = hidden_units)
        self.linear_layer = nn.Sequential(
            nn.Linear(in_features = hidden_units, out_features = hidden_units),
            nn.ReLU())
        self.linear_layers = nn.ModuleList([self.linear_layer for _ in range(4)])
        # 5. Output layer
        self.output_layer = nn.Linear(in_features = hidden_units, out_features = n_outputs)

    def forward(self, x):
        x = self.multi_scale_feature_mapping(x)
        x = self.flatten_layer(x)
        x = self.initial_linear_layer(x)
        for layer in self.linear_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

############################################################################################################
# Helper classes
class MultiScaleFeatureMapping(nn.Module):
    def __init__(self, scales: list[int], in_channels: int = 4, c1_filters: int = 16, c2_filters: int = 32, kernel_size: int = 5, stride: int = 1, padding: int = 1, pool_size: int = 2):
        super(MultiScaleFeatureMapping, self).__init__()
        self.scales = scales
        self.c2_filters = c2_filters
        self.coarse_grains = nn.ModuleDict({f"scale_{s}": CoarseGrain(scale = s) for s in scales})
        self.conv_block = ConvBlock(in_channels=in_channels, c1_filters=c1_filters, c2_filters=c2_filters, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, pool_size=pool_size)
        
    def forward(self, x):
        feature_maps = torch.empty((x.size(0), self.c2_filters, 0), dtype=torch.float32, device=x.device)
        for s in self.scales:
            stokes_cg = self.coarse_grains[f"scale_{s}"](x)
            conv_block_output = self.conv_block(stokes_cg)
            feature_maps = torch.cat((feature_maps, conv_block_output), dim=-1)
        return feature_maps
    
class CoarseGrain(nn.Module):
    def __init__(self, scale: int):
        super(CoarseGrain, self).__init__()
        self.scale = scale

    def forward(self, x):
        unfolded = x.unfold(dimension=2, size=self.scale, step=self.scale)
        stokes_scale = unfolded.mean(dim=-1)
        
        return stokes_scale
    
class ConvBlock(nn.Module):
    def __init__(self, in_channels: int = 4, c1_filters: int = 1024, c2_filters: int = 1024, kernel_size: int = 5, stride: int = 1, padding: int = 0, pool_size: int = 2):
        super(ConvBlock, self).__init__()
        self.c1 = nn.Sequential(nn.Conv1d(in_channels = in_channels, out_channels=c1_filters, kernel_size=kernel_size, stride=stride, padding=padding),
                                    nn.ReLU(),
                                    nn.MaxPool1d(kernel_size=pool_size))
        self.c2 = nn.Sequential(nn.Conv1d(in_channels = c1_filters, out_channels=c2_filters, kernel_size=kernel_size, stride=stride, padding=padding),
                                    nn.ReLU(),
                                    nn.MaxPool1d(kernel_size=pool_size))

    def forward(self, x):
        x = self.c1(x)
        x = self.c2(x)
        x = self.c2(x)
        return x

In [9]:
scales = [1,2,4]
thermody_model = InversionModel(scales=scales, 
                        nwl_points=112,
                        n_outputs=3*20)

In [10]:
print(summary(model = thermody_model,
        input_size=(1,4,112),
        dtypes=[torch.float],
        verbose=2,
        col_width=16,
        col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
        row_settings=["var_names"],
        ))

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [CoarseGrain: 3, ConvBlock: 2, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, CoarseGrain: 3, ConvBlock: 2, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, CoarseGrain: 3, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Sequential: 3, Conv1d: 4, ReLU: 4, MaxPool1d: 4, Conv1d: 4, ReLU: 4]