In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [4]:
input_image=torch.rand(1,3,299,299)

https://github.com/tstandley/Xception-PyTorch/blob/master/xception.py

In [21]:
class conv_bn(nn.Module):
    def __init__(self,in_ch,out_ch,k=3,s=1):
        super(conv_bn,self).__init__()
        
        self.conv1=nn.Conv2d(in_ch,out_ch,kernel_size=k,stride=s,bias=False)
        self.bn=nn.BatchNorm2d(out_ch)
        
    def forward(self,x):
        x=self.conv1(x)
        x=self.bn(x)
        return x
        

In [27]:
class sep_bn(nn.Module):
    def __init__(self,in_ch,out_ch,s=1,p=0):
        super(sep_bn,self).__init__()
        
        self.conv1=nn.Conv2d(in_ch,in_ch,kernel_size=3,stride=s,padding=p,dilation=1,groups=in_ch,bias=False)
        self.bn=nn.BatchNorm2d(in_ch)
        self.conv2= nn.Conv2d(in_ch,out_ch,kernel_size=1,stride=1,padding=0,dilation=1,groups=1,bias=False)
        self.bn2=nn.BatchNorm2d(out_ch)
        
    def forward(self,x):
        x=self.conv1(x)
        x=self.bn(x)
        x=self.conv2(x)
        x=self.bn2(x)
        return x

![](entry_flow.png)

In [32]:
class entry_flow(nn.Module):
    def __init__(self):
        super(entry_flow,self).__init__()
        
        self.conv_32=conv_bn(3,32,s=2)
        self.conv_64=conv_bn(32,64)
        
        self.sc1=sep_bn(64,128,p=1)
        self.sc2=sep_bn(128,128,p=1)
        self.maxpool1=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        self.skip_conv1=conv_bn(64,128,k=1,s=2)
        
        self.sc3=sep_bn(128,256,p=1)
        self.sc4=sep_bn(256,256,p=1)
        self.maxpool2=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        self.skip_conv2=conv_bn(128,256,k=1,s=2)
        
        self.sc5=sep_bn(256,728,p=1)
        self.sc6=sep_bn(728,728,p=1)
        self.maxpool3=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        self.skip_conv3=conv_bn(256,728,k=1,s=2)
        
    def forward(self,x):
        x=self.conv_32(x)
        x=F.relu(x)
        x=self.conv_64(x)
        tensor=F.relu(x)
        
        x=self.sc1(x)
        x=F.relu(x)
        x=self.sc2(x)
        x=self.maxpool1(x)
        print(x.shape)
        tensor=self.skip_conv1(tensor)
        print(tensor.shape)
        x=tensor+x
        x=F.relu(x)
        
        x=self.sc3(x)
        x=F.relu(x)
        x=self.sc4(x)
        x=self.maxpool2(x)
        
        tensor=self.skip_conv2(tensor)
        
        x=tensor+x
        x=F.relu(x)
        
        x=self.sc5(x)
        x=F.relu(x)
        x=self.sc6(x)
        x=self.maxpool3(x)
        
        tensor=self.skip_conv3(tensor) 
        x=tensor+x
        return x
        

In [33]:
mo=entry_flow()
x=torch.rand(1,3,299,299)
out=mo(x)
print(out.shape)

torch.Size([1, 128, 74, 74])
torch.Size([1, 128, 74, 74])
torch.Size([1, 728, 19, 19])


![](middle_flow.png)

In [68]:
class middle_flow(nn.Module):
    def __init__(self):
        super(middle_flow,self).__init__()
        
        layers=[]
        
        for _ in range(8):
            layers.append(nn.ReLU(True))
            layers.append(sep_bn(728,728,p=1))
            layers.append(nn.ReLU(True))
            layers.append(sep_bn(728,728,p=1))
            layers.append(nn.ReLU(True))
            layers.append(sep_bn(728,728,p=1))
            
        self.feature=nn.Sequential(*layers)
    
    def forward(self,x):
#         print(self.feature[:6])
        tensor=self.feature[:6](x)
        x=tensor+x
        tensor=self.feature[6:12](x)
        x=tensor+x
        tensor=self.feature[12:18](x)
        x=tensor+x
        tensor=self.feature[18:24](x)
        x=tensor+x
        tensor=self.feature[24:30](x)
        x=tensor+x
        tensor=self.feature[30:36](x)
        x=tensor+x
        tensor=self.feature[36:42](x)
        x=tensor+x
        tensor=self.feature[42:48](x)
        x=tensor+x
        return x

In [69]:
mo=entry_flow()
x=torch.rand(1,3,299,299)
out=mo(x)
mid=middle_flow()
out2=mid(out)
print(out2.shape)

torch.Size([1, 128, 74, 74])
torch.Size([1, 128, 74, 74])
torch.Size([1, 728, 19, 19])


![](exit_flow.png)

In [74]:
class exit_flow(nn.Module):
    def __init__(self):
        super(exit_flow,self).__init__()
        
        self.sc1=sep_bn(728,728,p=1)
        self.sc2=sep_bn(728,1024,p=1)
        self.maxpool1=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        
        self.skip_conv1=conv_bn(728,1024,k=1,s=2)
        
        self.sc3=sep_bn(1024,1536,p=1)
        self.sc4=sep_bn(1536,2048,p=1)
        self.fc=nn.Linear(2048,1000)
        
        
    def forward(self,x):
        x0=F.relu(x)
        x=self.sc1(x0)
        x=F.relu(x)
        x=self.sc2(x)
        x=self.maxpool1(x)
        
        tensor=self.skip_conv1(x0)
        
        x=tensor+x
        
        x=self.sc3(x)
        x=F.relu(x)
        x=self.sc4(x)
        x=F.relu(x)
        
        x=F.adaptive_avg_pool2d(x,(1,1))
        x=x.view(x.size(0),-1)
        
        x=self.fc(x)
        
        return x

In [75]:
mo=entry_flow()
x=torch.rand(1,3,299,299)
out=mo(x)
mid=middle_flow()
out2=mid(out)

end=exit_flow()

out3=end(out2)
print(out3.shape)

torch.Size([1, 128, 74, 74])
torch.Size([1, 128, 74, 74])
torch.Size([1, 1000])
