In [424]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [425]:
class Xception(nn.Module):
    """Image classifier built by chaining an entry, a middle and an exit flow.

    Args:
        in_channels: Number of channels of the input image. Defaults to 3.
        num_classes: Number of outputs of the final linear layer. Defaults to 10.
    """

    def __init__(self, in_channels=3, num_classes=10):
        super(Xception, self).__init__()
        self.entryflow = EntryFlow(in_channels)
        # 728 is the channel count the entry flow's last block produces.
        self.middleflow = MiddleFlow(728)
        self.exitflow = ExitFlow(728, num_classes)

    def forward(self, x):
        """Run the three flows in order and return the exit flow's result.

        Args:
            x: Input batch handed to the entry flow.

        Returns:
            Whatever ``self.exitflow`` returns for the transformed input.
        """
        x = self.entryflow(x)
        x = self.middleflow(x)
        # Deliberately silent: forward() runs for every batch, so printing
        # here would flood the notebook output during training or inference.
        return self.exitflow(x)

In [426]:
class SeparableConv(nn.Module):
    """Depthwise 3x3 convolution followed by a pointwise 1x1 convolution.

    Both convolutions are bias-free and each is followed by batch
    normalisation and ReLU. The depthwise stage uses stride 1 and padding 1,
    so the spatial size of the input is preserved.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the pointwise convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # groups == in_channels: every input channel gets its own 3x3 filter.
        self.depthwiseConv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=in_channels,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        # A single parameter-free ReLU module is reused after both stages.
        self.relu = nn.ReLU(inplace=True)

        # 1x1 convolution mixes channels and sets the output width.
        self.pointwiseConv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply depthwise conv -> BN -> ReLU -> pointwise conv -> BN -> ReLU."""
        depthwise = self.depthwiseConv(x)
        depthwise = self.relu(self.bn1(depthwise))
        pointwise = self.pointwiseConv(depthwise)
        return self.relu(self.bn2(pointwise))

In [427]:
class StandardConv(nn.Module):
    """Unpadded, bias-free convolution followed by batch norm and ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size. Defaults to 3.
        stride: Convolution stride. Defaults to 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2):
        super().__init__()
        # No padding argument is given, so Conv2d's default (0) applies.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ReLU(BN(conv(x)))."""
        features = self.conv(x)
        normalised = self.bn(features)
        return self.relu(normalised)

In [428]:
class Block(nn.Module):
    """Stack of separable convolutions with an optional projected shortcut.

    The main path is: an optional leading ReLU, ``reps`` SeparableConv layers
    (with a ReLU inserted before every one except the first), and an optional
    3x3 stride-2 max-pool. When ``skip`` is true, the block input is passed
    through a 1x1 convolution and batch norm and added to the main path.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by every SeparableConv and by the
            shortcut projection.
        reps: Number of SeparableConv layers in the main path.
        first_relu: Whether the main path starts with a ReLU.
        max_pool: Whether the main path ends with a stride-2 max-pool.
        skip: Whether ``forward`` adds the projected shortcut.
    """

    def __init__(self, in_channels, out_channels, reps, first_relu=True, max_pool=False, skip=True):
        super(Block, self).__init__()
        self.skip = skip
        self.layers = self.make_layers(in_channels, out_channels, reps, first_relu, max_pool)
        # The max-pool is the only layer that changes resolution, so the
        # shortcut must downsample exactly when the main path does; a fixed
        # stride of 2 would mismatch the main path whenever max_pool is False.
        skip_stride = 2 if max_pool else 1
        # Registered even when skip is False so parameter names stay the same
        # as before for existing checkpoints.
        self.skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=skip_stride, bias=False)
        self.skip_bn = nn.BatchNorm2d(out_channels)

    def make_layers(self, in_channels, out_channels, reps, first_relu, max_pool):
        """Build the main path as an ``nn.Sequential``.

        Returns:
            The sequential container described in the class docstring.
        """
        layers = []

        if first_relu:
            # Must NOT be in-place: this ReLU receives the block input itself,
            # and that same tensor is read again afterwards by the shortcut
            # (here or in the caller). An in-place ReLU would silently turn
            # the shortcut into relu(x).
            layers.append(nn.ReLU(inplace=False))

        for rep in range(reps):
            if rep > 0:
                # Safe to run in place: the input is a fresh SeparableConv output.
                layers.append(nn.ReLU(inplace=True))
            layers.append(SeparableConv(in_channels, out_channels))
            # Every layer after the first maps out_channels -> out_channels.
            in_channels = out_channels

        if max_pool:
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the main path output, plus the projected input when ``skip`` is set."""
        out = self.layers(x)
        if self.skip:
            # Out-of-place add: without a trailing max-pool ``out`` is the
            # result of an in-place ReLU that autograd needs intact.
            out = out + self.skip_bn(self.skip_conv(x))
        return out

In [429]:
class EntryFlow(nn.Module):
    """Stem of the network: two plain convolutions followed by three Blocks.

    Channel progression is in_channels -> 32 -> 64 -> 128 -> 256 -> 728.

    Args:
        in_channels: Channels of the input image.
    """

    def __init__(self, in_channels):
        super().__init__()
        # conv1 keeps StandardConv's default stride; conv2 overrides it to 1.
        self.conv1 = StandardConv(in_channels, 32)
        self.conv2 = StandardConv(32, 64, stride=1)

        # block1 receives conv2's already-activated output, so it skips the leading ReLU.
        self.block1 = Block(in_channels=64, out_channels=128, reps=2, first_relu=False, max_pool=True)
        self.block2 = Block(in_channels=128, out_channels=256, reps=2, max_pool=True)
        self.block3 = Block(in_channels=256, out_channels=728, reps=2, max_pool=True)

    def forward(self, x):
        """Pass ``x`` through conv1, conv2, block1, block2 and block3 in order."""
        stages = (self.conv1, self.conv2, self.block1, self.block2, self.block3)
        for stage in stages:
            x = stage(x)
        return x

In [430]:
class MiddleFlow(nn.Module):
    """Eight identical channel-preserving Blocks, each wrapped as ``block(x) + x``.

    The Blocks are built with ``skip=False`` and ``reps=3``; the addition in
    ``forward`` is the only shortcut around each of them. They are stored as
    ``block1`` ... ``block8``.

    Args:
        in_channels: Channel count used as both input and output of every Block.
    """

    def __init__(self, in_channels):
        super().__init__()
        for index in range(1, 9):
            # setattr on an nn.Module registers the child under that name.
            setattr(self, "block{}".format(index), Block(in_channels, in_channels, reps=3, skip=False))

    def forward(self, x):
        """Apply block1 through block8 in order, adding each block's input to its output."""
        for index in range(1, 9):
            block = getattr(self, "block{}".format(index))
            x = block(x) + x
        return x

In [431]:
class ExitFlow(nn.Module):
    """Final stage: one downsampling residual block, two separable convs, pooling and a classifier.

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, in_channels, num_classes):
        super(ExitFlow, self).__init__()
        self.layers = nn.Sequential(
            # Must NOT be in-place: forward() reads ``x`` again for the skip
            # branch, and an in-place ReLU here would feed relu(x) into
            # skip_conv instead of the original features.
            nn.ReLU(inplace=False),
            SeparableConv(in_channels, in_channels),
            # NOTE(review): SeparableConv already ends in BN + ReLU; this extra
            # BatchNorm2d is kept so parameter names and counts are unchanged.
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            SeparableConv(in_channels, 1024),
            nn.BatchNorm2d(1024),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # 1x1 stride-2 projection so the shortcut matches the pooled main path.
        self.skip_conv = nn.Conv2d(in_channels, 1024, 1, 2, bias=False)
        self.skip_bn = nn.BatchNorm2d(1024)

        self.sepconv1 = SeparableConv(1024, 1536)
        self.sepconv2 = SeparableConv(1536, 2048)
        self.fc_layer = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Return the classifier output for feature map ``x``.

        Returns:
            Tensor of shape (batch, num_classes): the features are globally
            average-pooled to 1x1, flattened per sample and passed through
            the linear layer.
        """
        out = self.layers(x)
        out += self.skip_bn(self.skip_conv(x))

        out = self.sepconv1(out)
        out = self.sepconv2(out)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        # Collapse (batch, 2048, 1, 1) to (batch, 2048) for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.fc_layer(out)
        return out

In [432]:
# Third-party helper used below to print a layer-by-layer summary of the model.
from torchsummary import summary

# Build the network with its default arguments (in_channels=3, num_classes=10).
model = Xception()
# Report each layer's output shape and parameter count for a 3 x 299 x 299 input.
summary(model, (3, 299, 299))

exitflow
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 149, 149]             864
       BatchNorm2d-2         [-1, 32, 149, 149]              64
              ReLU-3         [-1, 32, 149, 149]               0
      StandardConv-4         [-1, 32, 149, 149]               0
            Conv2d-5         [-1, 64, 147, 147]          18,432
       BatchNorm2d-6         [-1, 64, 147, 147]             128
              ReLU-7         [-1, 64, 147, 147]               0
      StandardConv-8         [-1, 64, 147, 147]               0
            Conv2d-9         [-1, 64, 147, 147]             576
      BatchNorm2d-10         [-1, 64, 147, 147]             128
             ReLU-11         [-1, 64, 147, 147]               0
           Conv2d-12        [-1, 128, 147, 147]           8,192
      BatchNorm2d-13        [-1, 128, 147, 147]             256
             ReLU-14        [-