In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [3]:
# 1) A small 5x5 single-channel "image" to run the edge detectors over.
pixels = [
    [1., 2., 0., 1., 3.],
    [0., 1., 2., 2., 1.],
    [3., 1., 0., 1., 0.],
    [2., 2., 1., 0., 1.],
    [0., 1., 3., 2., 2.],
]

# nn.Conv2d consumes (batch, channels, height, width): prepend both leading axes at once.
image = torch.tensor(pixels)[None, None]

print("Input Image Shape:", image.shape)
print(image[0, 0])


Input Image Shape: torch.Size([1, 1, 5, 5])
tensor([[1., 2., 0., 1., 3.],
        [0., 1., 2., 2., 1.],
        [3., 1., 0., 1., 0.],
        [2., 2., 1., 0., 1.],
        [0., 1., 3., 2., 2.]])


In [4]:
# 2) Two 3x3 Sobel-style edge-detection kernels:
#    - vertical_kernel weights the right column +, left column - (responds to vertical edges)
#    - horizontal_kernel weights the bottom row +, top row - (responds to horizontal edges)

vertical_kernel = torch.tensor([
    [-1., 0., 1.],
    [-2., 0., 2.],
    [-1., 0., 1.],
])

# The horizontal detector is exactly the transpose of the vertical one;
# .contiguous() materialises it as its own tensor rather than a strided view.
horizontal_kernel = vertical_kernel.T.contiguous()

# Kernel bank of shape (2, 3, 3): index 0 = vertical, index 1 = horizontal.
kernels = torch.stack((vertical_kernel, horizontal_kernel))

for title, kernel in (("Vertical Kernel:", vertical_kernel),
                      ("\nHorizontal Kernel:", horizontal_kernel)):
    print(title)
    print(kernel)


Vertical Kernel:
tensor([[-1.,  0.,  1.],
        [-2.,  0.,  2.],
        [-1.,  0.,  1.]])

Horizontal Kernel:
tensor([[-1., -2., -1.],
        [ 0.,  0.,  0.],
        [ 1.,  2.,  1.]])


In [5]:
# 2b) Re-use the kernel bank built in the previous cell rather than defining the
#     same two tensors a second time, and sanity-check it before it is loaded
#     into the conv layer.
assert kernels.shape == (2, 3, 3), f"unexpected kernel bank shape {tuple(kernels.shape)}"
# The horizontal detector is the transpose of the vertical detector.
assert torch.equal(kernels[1], kernels[0].T)
# Each kernel's weights sum to zero, so a constant (edge-free) patch gives zero response.
assert torch.all(kernels.sum(dim=(1, 2)) == 0)

print("Vertical Kernel:")
print(kernels[0])

print("\nHorizontal Kernel:")
print(kernels[1])


Vertical Kernel:
tensor([[-1.,  0.,  1.],
        [-2.,  0.,  2.],
        [-1.,  0.,  1.]])

Horizontal Kernel:
tensor([[-1., -2., -1.],
        [ 0.,  0.,  0.],
        [ 1.,  2.,  1.]])


In [6]:
# 3) Create a Conv2D layer with:
#    - 1 input channel
#    - 2 output channels (for 2 filters)
#    - kernel size = 3
#    - padding = 1 (to keep output size = 5x5)
#    - bias = False (we want only the kernel effect)

conv = nn.Conv2d(
    in_channels=1,
    out_channels=2,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False
)

# Load our custom kernels into the conv layer.
# Conv2d weight layout is (out_channels, in_channels, kH, kW), so insert the in_channels axis.
custom_weight = kernels.unsqueeze(1)  # (2, 3, 3) -> (2, 1, 3, 3)

# copy_ broadcasts a smaller source silently, so check the shape explicitly first.
assert custom_weight.shape == conv.weight.shape, (
    f"kernel bank shape {tuple(custom_weight.shape)} does not match "
    f"conv weight shape {tuple(conv.weight.shape)}"
)

# Copy the values into the existing Parameter instead of rebinding `.data`:
# the weight then no longer shares storage with `kernels`.
# no_grad is required because the weight is a leaf tensor that requires grad.
with torch.no_grad():
    conv.weight.copy_(custom_weight)

print("Conv layer weight shape:", conv.weight.shape)


Conv layer weight shape: torch.Size([2, 1, 3, 3])


In [7]:
# 4) Run the image through the conv layer. Note that nn.Conv2d computes a
#    cross-correlation: the kernels are applied as written, without flipping.
conv_output = conv(image)

print("Convolution Output Shape:", conv_output.shape)

# One header per output channel, in the order the kernels were stacked.
feature_map_headers = (
    "\nRaw Output - Feature Map 0 (Vertical edges):",
    "\nRaw Output - Feature Map 1 (Horizontal edges):",
)
for channel, header in enumerate(feature_map_headers):
    print(header)
    print(conv_output[0, channel])


Convolution Output Shape: torch.Size([1, 2, 5, 5])

Raw Output - Feature Map 0 (Vertical edges):
tensor([[ 5.,  0., -1.,  5., -4.],
        [ 5.,  0.,  1.,  1., -6.],
        [ 5., -5., -1., -1., -4.],
        [ 6., -2., -3., -1., -3.],
        [ 4.,  5.,  0., -2., -4.]], grad_fn=<SelectBackward0>)

Raw Output - Feature Map 1 (Horizontal edges):
tensor([[ 1.,  4.,  7.,  7.,  4.],
        [ 3.,  0., -1., -3., -6.],
        [ 5.,  3., -3., -5., -2.],
        [-6.,  0.,  7.,  7.,  5.],
        [-6., -7., -4., -2., -2.]], grad_fn=<SelectBackward0>)


In [8]:
# 5) ReLU clamps every negative response to zero, keeping only the positive
#    side of each edge response (as an activation layer in a real CNN would).
activated_output = F.relu(conv_output)

relu_headers = ("After ReLU - Feature Map 0:", "\nAfter ReLU - Feature Map 1:")
for channel_idx, heading in enumerate(relu_headers):
    print(heading)
    print(activated_output[0, channel_idx])


After ReLU - Feature Map 0:
tensor([[5., 0., 0., 5., 0.],
        [5., 0., 1., 1., 0.],
        [5., 0., 0., 0., 0.],
        [6., 0., 0., 0., 0.],
        [4., 5., 0., 0., 0.]], grad_fn=<SelectBackward0>)

After ReLU - Feature Map 1:
tensor([[1., 4., 7., 7., 4.],
        [3., 0., 0., 0., 0.],
        [5., 3., 0., 0., 0.],
        [0., 0., 7., 7., 5.],
        [0., 0., 0., 0., 0.]], grad_fn=<SelectBackward0>)
