1. In a normal convolutional layer (say, Conv2d in PyTorch):
	•	The input has M channels (e.g., M = 3 for an RGB image).
	•	Each filter spans all M channels with a K × K spatial kernel and produces 1 output channel.
	•	To produce N output channels, you learn N full 3D filters
→ each of size (M × K × K)

So the total number of weights (ignoring bias) is:

Params_standard = N × M × K × K

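Each output value is a weighted sum over all M input channels and all K × K spatial positions, so it costs M × K × K multiply-adds, repeated for every pixel of every output channel.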
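To make the "sum over all input channels" concrete, here is a minimal pure-Python sketch of a standard convolution, assuming stride 1, no padding and no bias. The function name `standard_conv` and the nested-list layout are illustrative choices, not how PyTorch implements `nn.Conv2d`; like `nn.Conv2d`, it computes a cross-correlation (the kernel is not flipped).

```python
def standard_conv(x, weights):
    """Naive multi-channel convolution: stride 1, no padding, no bias.

    x:       M input channels, each an H x W list of lists
    weights: N filters, each M x K x K (one K x K slice per input channel)
    returns: N output channels, each (H-K+1) x (W-K+1)
    """
    m = len(x)
    h, w = len(x[0]), len(x[0][0])
    k = len(weights[0][0])
    out = []
    for filt in weights:  # one full 3D filter per output channel
        out.append([
            [
                # every output value sums over ALL m channels and k*k positions
                sum(x[c][i + a][j + b] * filt[c][a][b]
                    for c in range(m) for a in range(k) for b in range(k))
                for j in range(w - k + 1)
            ]
            for i in range(h - k + 1)
        ])
    return out


x = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],  # channel 0
     [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]  # channel 1
ones = [[1] * 3 for _ in range(3)]
weights = [[ones, ones]]                 # N = 1 filter of shape 2 x 3 x 3
print(standard_conv(x, weights))         # [[[54]]]  (45 from channel 0 + 9 from channel 1)
```

The single output value needs M × K × K = 2 × 3 × 3 = 18 multiply-adds, because the filter touches both channels at all nine positions.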

2. What a Depthwise Separable Convolution Does

It splits a regular convolution into two simpler steps:
	1.	Depthwise Convolution
→ One filter per input channel (not across all channels)
→ Each filter only looks at its own channel.
	2.	Pointwise Convolution (1×1 Conv)
→ Mixes the depthwise outputs of all channels linearly into N output channels, using N filters of size 1 × 1 × M.

So instead of one large convolution mixing space + channels at once,
we separate spatial filtering and channel mixing.
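The two steps can be sketched in the same pure-Python style. The helpers `depthwise_conv` and `pointwise_conv` are hypothetical names, and stride 1, no padding and no bias are assumed. The depthwise step never mixes channels, and the pointwise step never looks at neighbouring pixels.

```python
def depthwise_conv(x, kernels):
    """Step 1: one K x K kernel per input channel; channel c only sees kernels[c].

    x: C channels, each an H x W list of lists. Stride 1, no padding, no bias.
    """
    out = []
    for ch, ker in zip(x, kernels):
        h, w, k = len(ch), len(ch[0]), len(ker)
        out.append([
            [sum(ch[i + a][j + b] * ker[a][b] for a in range(k) for b in range(k))
             for j in range(w - k + 1)]
            for i in range(h - k + 1)
        ])
    return out


def pointwise_conv(x, weights):
    """Step 2: a 1x1 convolution. weights is N x C; output channel n at pixel (i, j)
    is a linear combination of the C input values at that same pixel."""
    h, w = len(x[0]), len(x[0][0])
    return [
        [[sum(wc * x[c][i][j] for c, wc in enumerate(row)) for j in range(w)]
         for i in range(h)]
        for row in weights
    ]


x = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
     [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]
ones = [[1] * 3 for _ in range(3)]
mid = depthwise_conv(x, [ones, ones])
print(mid)                                             # [[[45]], [[9]]]
print(pointwise_conv(mid, [[1, 1], [1, -1], [2, 0]]))  # [[[54]], [[36]], [[90]]]
```

The first output channel (weights [1, 1]) reproduces the 54 a standard 2 × 3 × 3 filter of ones would give, but the separable version can only express filters whose spatial pattern per channel is fixed by the depthwise step.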

3. Step-by-Step Example

Let’s say:
Input shape = (H, W, 32)  (32 channels)

Kernel size = 3×3

Output channels = 64

Standard Conv:

Each of 64 filters has shape (3×3×32)

Parameters = 64 × 32 × 3 × 3 = 18,432

Depthwise Separable Conv:

(a) Depthwise step:

One 3×3 filter per channel → 32 filters of shape (3×3×1)

Parameters = 32 × 3 × 3 = 288

(b) Pointwise step (1×1 Conv):

64 filters of shape (1×1×32)

Parameters = 64 × 32 × 1 × 1 = 2,048

Total = 288 + 2,048 = 2,336 parameters

That’s about 7.9× fewer parameters than the standard conv (18,432 / 2,336 ≈ 7.89).
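The same arithmetic holds for any M, N and K: Params_separable = M × K × K + N × M, so Params_separable / Params_standard = 1/N + 1/K². A small sketch (the helper `param_counts` is hypothetical; bias is ignored) reproduces the numbers above and checks that ratio:

```python
def param_counts(m_in, n_out, k):
    """Weight counts (bias ignored) for a standard vs. a depthwise separable conv."""
    standard = n_out * m_in * k * k
    depthwise = m_in * k * k  # one K x K filter per input channel
    pointwise = n_out * m_in  # N filters of shape 1 x 1 x M
    return standard, depthwise + pointwise


std, sep = param_counts(m_in=32, n_out=64, k=3)
print(std, sep, round(std / sep, 2))              # 18432 2336 7.89
# The ratio simplifies to 1/N + 1/K^2
print(abs(sep / std - (1 / 64 + 1 / 9)) < 1e-12)  # True
```

For a 3×3 kernel the saving therefore approaches 9× as N grows, which is why this example lands just under 8×.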

In [1]:
import torch
from torch import nn

In [14]:
class DepthWiseSeparableConv(nn.Module):
  def __init__(self, n_in, n_out, kernel_size=3, stride=1, padding=1):
    super().__init__()

    self.depthwise = nn.Conv2d(n_in, n_in, kernel_size, stride=stride, padding=padding, groups=n_in, bias=False)
    self.pointwise = nn.Conv2d(n_in, n_out, kernel_size=1, bias=False)
    self.bn = nn.BatchNorm2d(n_out)
    self.relu = nn.ReLU()

  def forward(self, x):
    x = self.depthwise(x)
    x = self.pointwise(x)
    x = self.bn(x)
    x = self.relu(x)
    return x

In [15]:
x = torch.randn(1, 32, 64, 64)
model = DepthWiseSeparableConv(32, 64)
y = model(x)
print(y.shape)

torch.Size([1, 64, 64, 64])


In [16]:
# number of parameters in the depthwise separable block (conv weights + BatchNorm weight and bias)
sum(p.numel() for p in model.parameters())

2464
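The 2,464 above is more than the 2,336 computed by hand because BatchNorm2d adds a learnable weight and bias per output channel (2 × 64 = 128); its running mean and variance are buffers, not parameters. A sketch that builds the same three layers standalone (same arguments as inside `DepthWiseSeparableConv(32, 64)`, assuming PyTorch is installed) and prints each one's weight shape and parameter count:

```python
from torch import nn

# Same layer configuration as DepthWiseSeparableConv(32, 64) above
depthwise = nn.Conv2d(32, 32, 3, padding=1, groups=32, bias=False)
pointwise = nn.Conv2d(32, 64, kernel_size=1, bias=False)
bn = nn.BatchNorm2d(64)

for name, module in [("depthwise", depthwise), ("pointwise", pointwise), ("bn", bn)]:
    n = sum(p.numel() for p in module.parameters())
    print(name, tuple(module.weight.shape), n)
# depthwise (32, 1, 3, 3) 288
# pointwise (64, 32, 1, 1) 2048
# bn (64,) 128
```

`groups=n_in` is what turns the first layer into a depthwise convolution: its weight shape becomes (32, 1, 3, 3), so each filter sees a single input channel.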

In [17]:
class ClassicCNN(nn.Module):
  def __init__(self, n_in, n_out, kernel_size=3, stride=1, padding=1):
    super().__init__()
    # pass stride and padding through, so a 3×3 kernel with padding=1 keeps the 64×64 size
    self.layer = nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
    self.bn = nn.BatchNorm2d(n_out)
    self.relu = nn.ReLU()

  def forward(self, x):
    x = self.layer(x)
    x = self.bn(x)
    x = self.relu(x)
    return x

In [18]:
x = torch.randn(1, 32, 64, 64)
model = ClassicCNN(32, 64)
y = model(x)
print(y.shape)

torch.Size([1, 64, 64, 64])


In [19]:
# number of parameters in the classic conv block (18,432 conv weights + 128 BatchNorm parameters)
sum(p.numel() for p in model.parameters())

18560
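Compute shrinks by the same factor as parameters here: with stride 1 and both layers producing a 64×64 output (padding=1 assumed), every conv weight is used once per output pixel. A rough multiply-accumulate count, ignoring bias, BatchNorm and ReLU (the helper `conv_macs` is hypothetical):

```python
def conv_macs(m_in, n_out, k, h_out, w_out):
    """Multiply-accumulate counts per forward pass (bias, BatchNorm and ReLU ignored)."""
    standard = n_out * m_in * k * k * h_out * w_out
    depthwise = m_in * k * k * h_out * w_out
    pointwise = n_out * m_in * h_out * w_out
    return standard, depthwise + pointwise


std, sep = conv_macs(32, 64, 3, 64, 64)
print(std, sep, round(std / sep, 2))  # 75497472 9568256 7.89
```

Fewer multiply-adds does not guarantee a proportionally faster runtime: actual speed depends on how efficiently the backend executes grouped (depthwise) convolutions.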