In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class mfm(nn.Module):
    """Max-Feature-Map (MFM) layer.

    A learnable filter produces ``2 * out_channels`` features. They are split
    along dimension 1 into chunks of ``out_channels``, and the first two chunks
    are combined with an element-wise maximum, so ``forward`` returns
    ``out_channels`` features.

    Args:
        in_channels: input channels (convolutional variant) or input features
            (linear variant).
        out_channels: number of features returned by ``forward``.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``; unused by the
            linear variant.
        type: ``1`` selects an ``nn.Conv2d`` filter; any other value selects an
            ``nn.Linear`` filter.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, type=1):
        super().__init__()
        self.out_channels = out_channels
        n_filters = 2 * out_channels
        if type != 1:
            self.filter = nn.Linear(in_channels, n_filters)
        else:
            self.filter = nn.Conv2d(
                in_channels,
                n_filters,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )

    def forward(self, x):
        halves = torch.split(self.filter(x), self.out_channels, dim=1)
        # Element-wise competition between the two feature halves.
        return torch.max(halves[0], halves[1])

In [3]:
class group(nn.Module):
    """Two stacked Max-Feature-Map convolutions: a 1x1 layer, then a spatial one.

    ``conv_a`` keeps the channel count at ``in_channels`` (1x1 kernel, stride 1,
    no padding). ``conv`` then maps ``in_channels`` to ``out_channels`` with the
    given ``kernel_size`` / ``stride`` / ``padding``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        # Registration order (conv_a before conv) determines state_dict key order.
        self.conv_a = mfm(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.conv = mfm(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        return self.conv(self.conv_a(x))

In [4]:
class resblock(nn.Module):
    """Residual block: two 3x3 Max-Feature-Map convolutions plus an identity shortcut.

    ``forward`` computes ``conv2(conv1(x)) + x``. The shortcut has no
    projection, so the input and output channel counts must be equal.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by ``conv1`` and ``conv2``; must equal
            ``in_channels``.

    Raises:
        ValueError: if ``in_channels != out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(resblock, self).__init__()
        if in_channels != out_channels:
            # Fail fast with a clear message: the identity shortcut cannot
            # change the channel count (and a 1-channel input would otherwise
            # broadcast silently in the residual add).
            raise ValueError(
                f"resblock requires in_channels == out_channels for its identity "
                f"shortcut, got {in_channels} and {out_channels}"
            )
        self.conv1 = mfm(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # conv2 consumes conv1's output, which has out_channels channels.
        self.conv2 = mfm(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        res = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = out + res
        return out

In [5]:
class network_29layers_v2_FV(nn.Module):
    """LightCNN-29 (v2) style backbone with a 256-d embedding and a classifier head.

    Layout:
        * ``conv1``: 1 -> 48 channels, 5x5 kernel.
        * Four stages, each a stack of ``block`` modules followed by a ``group``
          transition, with channels 48 -> 96 -> 192 -> 128 -> 128.
        * Downsampling: the sum of a 2x2 max-pool and a 2x2 average-pool, applied
          after conv1, group1, group2 and group4 (stage 3 has no pooling).
        * Head: ``fc`` (``8*8*128`` -> 256) and ``fc2_`` (256 -> ``num_classes``,
          no bias).

    With four 2x poolings, a 128x128 input leaves an 8x8 map, which matches
    ``fc``'s input size. TODO: confirm 128x128 is the intended input resolution.

    Args:
        block: module class for the residual stages, constructed as
            ``block(in_channels, out_channels)``.
        layers: four ints, the number of ``block`` modules per stage.
        num_classes: output size of the classifier ``fc2_``.
    """

    def __init__(self, block, layers, num_classes=79994):
        super().__init__()
        # Registration order below fixes the state_dict key order; keep it.
        self.conv1 = mfm(1, 48, 5, 1, 2)

        self.block1 = self._make_layer(block, layers[0], 48, 48)
        self.group1 = group(48, 96, 3, 1, 1)

        self.block2 = self._make_layer(block, layers[1], 96, 96)
        self.group2 = group(96, 192, 3, 1, 1)

        self.block3 = self._make_layer(block, layers[2], 192, 192)
        self.group3 = group(192, 128, 3, 1, 1)

        self.block4 = self._make_layer(block, layers[3], 128, 128)
        self.group4 = group(128, 128, 3, 1, 1)

        self.fc = nn.Linear(8 * 8 * 128, 256)
        self.fc2_ = nn.Linear(256, num_classes, bias=False)

    def _make_layer(self, block, num_blocks, in_channels, out_channels):
        """Return ``nn.Sequential`` of ``num_blocks`` copies of ``block(in_channels, out_channels)``."""
        return nn.Sequential(*[block(in_channels, out_channels) for _ in range(num_blocks)])

    @staticmethod
    def _downsample(x):
        """Halve the spatial size by summing 2x2 max-pooling and 2x2 average-pooling."""
        return F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)

    def forward(self, x):
        """Return ``(logits, embedding)`` for a batch ``x``.

        ``x`` must have 1 channel, since ``conv1`` is ``mfm(1, ...)``.
        ``embedding`` is the 256-d output of ``fc``, taken before dropout.
        ``logits`` are ``fc2_`` applied to the dropped-out embedding; dropout
        uses the PyTorch default p and is active only in training mode.
        """
        x = self._downsample(self.conv1(x))

        x = self._downsample(self.group1(self.block1(x)))
        x = self._downsample(self.group2(self.block2(x)))

        # Stages 3 and 4 run back to back; the only pooling comes after group4.
        x = self.group3(self.block3(x))
        x = self._downsample(self.group4(self.block4(x)))

        flat = x.view(x.size(0), -1)
        embedding = self.fc(flat)
        logits = self.fc2_(F.dropout(embedding, training=self.training))
        return logits, embedding

In [6]:
def LightCNN_FV_Net(**kwargs):
    """Build ``network_29layers_v2_FV`` with ``resblock`` stages of 1, 2, 3 and 4 blocks.

    Keyword arguments (e.g. ``num_classes``) are forwarded unchanged to the
    network constructor.
    """
    blocks_per_stage = [1, 2, 3, 4]
    return network_29layers_v2_FV(resblock, blocks_per_stage, **kwargs)

In [7]:
# Bind the model to a name so later cells can use it without relying on the
# cell's output history (`_` / `Out[N]`). The bare `model` line displays its repr.
model = LightCNN_FV_Net(num_classes=79994)
model

network_29layers_v2_FV(
  (conv1): mfm(
    (filter): Conv2d(1, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  )
  (block1): Sequential(
    (0): resblock(
      (conv1): mfm(
        (filter): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (conv2): mfm(
        (filter): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
  )
  (group1): group(
    (conv_a): mfm(
      (filter): Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))
    )
    (conv): mfm(
      (filter): Conv2d(48, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (block2): Sequential(
    (0): resblock(
      (conv1): mfm(
        (filter): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (conv2): mfm(
        (filter): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
    (1): resblock(
      (conv1): mfm(
        (filter): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1