📝 **Author:** Amirhossein Heydari - 📧 **Email:** <amirhosseinheydari78@gmail.com> - 📍 **Origin:** [mr-pylin/pytorch-workshop](https://github.com/mr-pylin/pytorch-workshop)

---


**Table of contents**<a id='toc0_'></a>    
- [Dependencies](#toc1_)    
- [Xception](#toc2_)    
  - [Custom Xception](#toc2_1_)    
    - [Initialize the Model](#toc2_1_1_)    
    - [Model Summary](#toc2_1_2_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc1_'></a>[Dependencies](#toc0_)


In [None]:
import torch
import torch.nn.functional as F
from torch import nn
from torchinfo import summary

# <a id='toc2_'></a>[Xception](#toc0_)

- Xception, short for **Extreme Inception**, was introduced by [*François Chollet*](https://scholar.google.com/citations?user=VfYhf2wAAAAJ&hl=en) from [*Google Research*](https://research.google/) in 2017.
- It is based on the paper [Xception: Deep Learning with Depthwise Separable Convolutions](https://openaccess.thecvf.com/content_cvpr_2017/html/Chollet_Xception_Deep_Learning_CVPR_2017_paper.html)
- It was trained on the [ImageNet](https://www.image-net.org/) dataset (typically preprocessed with resizing and cropping to 299x299 for Xception) [[ImageNet viewer](https://navigu.net/#imagenet)]
- Known for its use of **depthwise separable convolutions** to enhance efficiency and performance
- The architecture replaces the **Inception modules** used in **Inception-v3** with **depthwise separable convolution** layers
- Designed to optimize the feature extraction process while reducing computational complexity
- Used as a baseline for many applications, including image classification and feature extraction tasks

<figure style="text-align: center;">
  <img src="../../../assets/images/original/cnn/architectures/xception.svg" alt="xception-architecture.svg" style="width: 100%;">
  <figcaption>Xception Module (Depthwise Separable Convolution)</figcaption>
</figure>


## <a id='toc2_1_'></a>[Custom Xception](#toc0_)


In [4]:
class SeparableConv2d(nn.Module):
    """Depthwise separable 2D convolution.

    A per-channel spatial convolution (``depthwise``) is followed by a 1x1
    convolution (``pointwise``) that mixes channels and sets the output width.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the pointwise convolution.
        kernel_size: Spatial kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Zero-padding of the depthwise convolution.
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        bias: bool = False,
    ):
        super().__init__()

        # groups == in_channels gives every input channel its own spatial filter.
        self.depthwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )

        # The 1x1 convolution recombines the filtered channels.
        self.pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.depthwise(x))

In [5]:
class XceptionBlock(nn.Module):
    """Residual Xception block built from depthwise separable convolutions.

    The main branch (``rep``) stacks ``reps`` units of
    (activation -> SeparableConv2d -> BatchNorm2d) and, when ``strides != 1``,
    ends with a 3x3 max-pool of that stride. The shortcut branch is the
    identity, or a strided 1x1 convolution followed by batch norm whenever the
    channel count or the spatial resolution changes. The two branches are summed.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
        reps: Number of separable-convolution units in the main branch.
        strides: Stride of the trailing max-pool and of the shortcut convolution.
        start_with_relu: If ``True`` the main branch begins with a ReLU;
            if ``False`` the first unit receives the input unchanged.
        grow_first: If ``True`` the first unit maps ``in_channels`` to
            ``out_channels``; otherwise the last unit does.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        reps: int,
        strides: int = 1,
        start_with_relu: bool = True,
        grow_first: bool = True,
    ):
        super().__init__()
        if out_channels != in_channels or strides != 1:
            # Projection shortcut: match channels and resolution of the main branch.
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=strides, bias=False)
            self.skip_bn = nn.BatchNorm2d(out_channels)
        else:
            self.skip = None

        # Shared activation for every unit except the first one (see below).
        self.relu = nn.ReLU(inplace=True)

        rep = []
        filters = in_channels
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_channels, out_channels, 3, stride=1, padding=1))
            rep.append(nn.BatchNorm2d(out_channels))
            filters = out_channels

        for _ in range(reps - 1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_channels, out_channels, 3, stride=1, padding=1))
            rep.append(nn.BatchNorm2d(out_channels))

        # The leading activation receives the block input, which the shortcut
        # branch also reads. An in-place ReLU there would overwrite that tensor:
        # the identity shortcut would then add relu(x) instead of x, and the
        # projection shortcut would fail in backward because the tensor saved by
        # its convolution was modified. Use an out-of-place ReLU instead, or a
        # no-op when the caller asked for no leading activation. nn.Identity
        # keeps the Sequential indices (and thus state_dict keys) unchanged.
        if rep:
            rep[0] = nn.ReLU(inplace=False) if start_with_relu else nn.Identity()

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))

        self.rep = nn.Sequential(*rep)

    def forward(self, x):
        """Return ``rep(x) + shortcut(x)``."""
        skip = x
        if self.skip is not None:
            skip = self.skip(skip)
            skip = self.skip_bn(skip)

        x = self.rep(x)
        # Out-of-place add: never mutate a tensor another branch may still need.
        return x + skip

In [None]:
class Xception(nn.Module):
    """Xception image classifier.

    The network is organised as an entry flow (two strided/plain 3x3
    convolutions followed by three down-sampling blocks), a middle flow of
    eight identical 728-channel blocks, an exit flow ending at 2048 channels,
    global average pooling and a linear classification head.

    Args:
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, num_classes: int = 1000):
        super().__init__()

        # --- entry flow: convolutional stem, then three stride-2 blocks ---
        entry_layers = [
            nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        # (in_channels, out_channels, start_with_relu) for each down-sampling block
        entry_specs = ((64, 128, False), (128, 256, True), (256, 728, True))
        for c_in, c_out, leading_relu in entry_specs:
            entry_layers.append(
                XceptionBlock(c_in, c_out, reps=2, strides=2, start_with_relu=leading_relu, grow_first=True)
            )
        self.entry_flow = nn.Sequential(*entry_layers)

        # --- middle flow: eight resolution-preserving blocks ---
        middle_blocks = []
        for _ in range(8):
            middle_blocks.append(
                XceptionBlock(728, 728, reps=3, strides=1, start_with_relu=True, grow_first=True)
            )
        self.middle_flow = nn.Sequential(*middle_blocks)

        # --- exit flow: one more down-sampling block, then two widening convs ---
        exit_layers = [
            XceptionBlock(728, 1024, reps=2, strides=2, start_with_relu=True, grow_first=False),
            SeparableConv2d(1024, 1536, 3, stride=1, padding=1),
            nn.BatchNorm2d(1536),
            nn.ReLU(inplace=True),
            SeparableConv2d(1536, 2048, 3, stride=1, padding=1),
            nn.BatchNorm2d(2048),
            nn.ReLU(inplace=True),
        ]
        self.exit_flow = nn.Sequential(*exit_layers)

        # --- classification head ---
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Run the three flows, pool globally and return the class logits."""
        features = self.exit_flow(self.middle_flow(self.entry_flow(x)))
        # Collapse each 2048-channel map to a single value, then drop the 1x1 dims.
        pooled = F.adaptive_avg_pool2d(features, (1, 1))
        return self.fc(torch.flatten(pooled, 1))

### <a id='toc2_1_1_'></a>[Initialize the Model](#toc0_)


In [11]:
# Build the network with the default 1000-class head; no pretrained weights are loaded.
model = Xception()

In [None]:
# Evaluate the model object so the notebook displays its nested module structure.
model

### <a id='toc2_1_2_'></a>[Model Summary](#toc0_)


In [None]:
# Layer-by-layer summary for a single 3x299x299 input, computed on the CPU.
summary(model, (1, 3, 299, 299), device="cpu")