In [14]:
"""
    Example for a simple model
"""

from abc import ABCMeta
# from nets.fc import FCNet
import torch
from torch import nn, Tensor
import itertools as it

In [None]:
# channels=[32, 64, 64, 64, 64, 64, 64, 128, 128, 128, 128, 128, 128, 256, 256, 256, 256, 256, 256, 1]

In [22]:
class MyModel(nn.Module, metaclass=ABCMeta):
    """
    Example for a simple model
    """
    def __init__(self, input_dim: int = 50, num_hid: int = 256, output_dim: int = 2, dropout: float = 0.2):
        super(MyModel, self).__init__()
        test_params = [
        self.img_encoder = ResNetClassifier(
            in_size=(3,224,224),
            out_classes=1024,
            channels=[32, 64, 64, 64, 64, 128, 128, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 1024, 1024, 1024, 1024],
            pool_every=2,
            activation_type='relu',
            activation_params=dict(),
            pooling_type='avg',
            pooling_params=dict(kernel_size=2),
            batchnorm=True,
            dropout=0.1,
        )
#         self.img_encoder = ConvClassifier(**test_params)

    def forward(self, x) -> Tensor:
        img = x[0]
        out = self.img_encoder(img)
        print(img)
        raise
        return self.classifier(x)

SyntaxError: invalid syntax (<ipython-input-22-5d78e084f71c>, line 8)

In [50]:
ACTIVATIONS = {"relu": nn.ReLU, "lrelu": nn.LeakyReLU}
POOLINGS = {"avg": nn.AvgPool2d, "max": nn.MaxPool2d}

class ResidualBlock(nn.Module):
    """
    """
    def __init__(
        self,
        in_channels: int,
        channels: list,
        kernel_sizes: list,
        batchnorm=False,
        dropout=0.0,
        activation_type: str = "relu",
        activation_params: dict = {},
        **kwargs,
    ):

        super().__init__()
        assert channels and kernel_sizes
        assert len(channels) == len(kernel_sizes)
        assert all(map(lambda x: x % 2 == 1, kernel_sizes))

        if activation_type not in ACTIVATIONS:
            raise ValueError("Unsupported activation type")

        self.main_path, self.shortcut_path = None, None
        main_layers = []
        shortcut_layers = []

        # - extract number of conv layers
        N = len(channels)

        # - first conv layer 
        main_layers.append(
            nn.Conv2d(
                in_channels,
                channels[0],
                kernel_size= kernel_sizes[0],
                padding=(int((kernel_sizes[0]-1)/2),
                int((kernel_sizes[0]-1)/2)), bias=True))
        if dropout !=0:
            main_layers.append(torch.nn.Dropout2d(p=dropout))
        if batchnorm == True:
            main_layers.append(torch.nn.BatchNorm2d(channels[0]))
        main_layers.append(ACTIVATIONS[activation_type]())

        #middle layers
        for i in range(1,N-1):
            main_layers.append(
                nn.Conv2d(
                    channels[i-1],
                    channels[i],
                    kernel_size= kernel_sizes[i],
                    padding=(int((kernel_sizes[i]-1)/2),
                    int((kernel_sizes[i]-1)/2)), bias=True))
            if dropout !=0:
                main_layers.append(torch.nn.Dropout2d(p=dropout))
            if batchnorm == True:
                main_layers.append(torch.nn.BatchNorm2d(channels[i]))
            main_layers.append(ACTIVATIONS[activation_type]())
        if N > 1:
            main_layers.append(
                nn.Conv2d(
                    channels[N-2],
                    channels[N-1],
                    kernel_size= kernel_sizes[N-1],
                    padding=(int((kernel_sizes[N-1]-1)/2),
                    int((kernel_sizes[N-1]-1)/2)), bias=True))
        if (in_channels != channels[N-1]):
            shortcut_layers.append(nn.Conv2d (in_channels, channels[N-1], kernel_size= 1, bias=False))

        self.main_path = nn.Sequential(*main_layers)
        self.shortcut_path = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        out = self.main_path(x)
        out = out + self.shortcut_path(x)
        relu = torch.nn.ReLU()
        out = relu(out)
        return out


class ResNetClassifier(nn.Module):
    def __init__(
        self,
        in_size,
        out_classes,
        channels,
        pool_every,
#         hidden_dims,
        activation_type: str = "relu",
        activation_params: dict = {},
        pooling_type: str = "max",
        pooling_params: dict = {},
        batchnorm=False,
        dropout=0.0,
        **kwargs,
    ):
        """
        See arguments of ConvClassifier & ResidualBlock.
        """
        super().__init__()
#             in_size, out_classes, channels, pool_every, activation_type,
#             activation_params, pooling_type, pooling_params, batchnorm, dropout, **kwargs
#         )
        self.batchnorm = batchnorm
        self.dropout = dropout
        self.conv_params=dict(kernel_size=3, stride=1, padding=1)
        self.out_classes = out_classes
        self.in_size = in_size
        self.channels = channels
        self.pool_every = pool_every
#         self.hidden_dims = hidden_dims
        self.activation_type = activation_type
        self.activation_params = activation_params
        self.pooling_type = pooling_type
        self.pooling_params = pooling_params
        self.feature_extractor = self._make_feature_extractor()
        self.liner = torch.nn.Linear(100, self.out_classes)

#         super().__init__(
#             in_size, out_classes, channels, pool_every, hidden_dims, **kwargs
#         )

    def _make_feature_extractor(self):
        in_channels, in_h, in_w, = tuple(self.in_size)
        layers = []
        
        # - extract number of conv layers
        N = len(self.channels)
        
        #1st layer
        temp_in_channels = in_channels
        temp_channels = []
        temp_kernel_sizes = []
        
        #middle layers
        for i in range(1,N):
            temp_channels.append(self.channels[i-1])
            temp_kernel_sizes.append(3)
            if ((i % self.pool_every)==0 and i!=0):
                layers.append(
                    ResidualBlock(
                        in_channels=temp_in_channels,
                        channels=temp_channels,
                        kernel_sizes=temp_kernel_sizes,
                        batchnorm=self.batchnorm,
                        dropout=self.dropout,
                        activation_type=self.activation_type))
                temp_in_channels = self.channels[i-1]
                temp_channels = []
                temp_kernel_sizes = []
                layers.append(POOLINGS[self.pooling_type](self.pooling_params['kernel_size']))
        temp_channels.append(self.channels[N-1])
        temp_kernel_sizes.append(3)
        layers.append(ResidualBlock(
                in_channels=temp_in_channels,
                channels=temp_channels,
                kernel_sizes=temp_kernel_sizes,
                batchnorm=self.batchnorm,
                dropout=self.dropout,
                activation_type=self.activation_type))
        if ((N % self.pool_every)==0):
            layers.append(POOLINGS[self.pooling_type](self.pooling_params['kernel_size']))
        seq = nn.Sequential(*layers)
        return seq
    
    
    def forward(self, x):
        features = self.feature_extractor(x)
        features = features.view(features.size(0), -1)
        features = self.liner(features)
        return features