In [9]:
import torch
import torch.nn as nn

In [3]:
# Pick the compute device once for the notebook: use the GPU when PyTorch can
# see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [45]:
class Conv2D(nn.Module):
    """Hand-rolled 2D cross-correlation layer (stride 1, no padding).

    The layer owns a weight of shape ``(out_channels, in_channels, k_h, k_w)``
    and a bias of shape ``(out_channels,)``. Both are registered as
    ``nn.Parameter`` so they show up in ``parameters()`` / ``state_dict()``,
    follow ``.to(...)`` and accumulate gradients as leaf tensors.

    Args:
        in_channels: Number of channels in each input image.
        out_channels: Number of filters, i.e. output feature maps.
        kernel_size: ``(k_h, k_w)`` pair, or a single int for a square kernel.
        device: Device on which the parameters are created. When ``None``,
            CUDA is used if available, otherwise the CPU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, device=None):
        super().__init__()
        # Honour an explicitly requested device; only auto-detect when the
        # caller did not pass one.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Allow an int as shorthand for a square kernel.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        k_h, k_w = kernel_size
        # Kaiming init for CNN: std = sqrt(2 / fan_in), fan_in = C_in * k_h * k_w.
        std = (2.0 / (in_channels * k_h * k_w)) ** 0.5
        # Scale *before* wrapping in nn.Parameter so the weight stays a leaf
        # tensor; multiplying a requires_grad tensor would yield a non-leaf
        # whose .grad is never populated.
        self.weight = nn.Parameter(
            torch.randn(out_channels, in_channels, k_h, k_w, device=device) * std
        )
        self.bias = nn.Parameter(torch.zeros(out_channels, device=device))

    def forward(self, x):
        """Apply the layer to a batch.

        Args:
            x: Batch of images, iterated over its first dimension; each item
                is treated as ``(C_in, H, W)``.

        Returns:
            Tensor of shape ``(N, C_out, H - k_h + 1, W - k_w + 1)`` with the
            per-filter bias added.
        """
        feature_maps = torch.stack(
            [self.corr2d_multi_in_out(img, self.weight) for img in x]
        )
        # Reshape bias to (1, C_out, 1, 1) so it broadcasts over batch and space.
        return feature_maps + self.bias.view(1, -1, 1, 1)

    def call(self, x):
        """Alias for ``forward(x)``, kept for existing callers.

        ``nn.Module`` already makes instances callable, so ``layer(x)`` works
        without this method.
        """
        return self.forward(x)

    def corr2d(self, X, K):
        """Cross-correlate one 2D map ``X`` with one 2D kernel ``K``.

        Returns:
            Tensor of shape ``(X.shape[0] - h + 1, X.shape[1] - w + 1)`` where
            ``(h, w) = K.shape``.
        """
        h, w = K.shape
        # Allocate the output next to the input (not on a notebook-global
        # device) and in the promoted dtype of the operands, so the result is
        # neither moved across devices nor silently downcast.
        Y = torch.zeros(
            X.shape[0] - h + 1,
            X.shape[1] - w + 1,
            device=X.device,
            dtype=torch.result_type(X, K),
        )
        for i in range(Y.shape[0]):
            for j in range(Y.shape[1]):
                # Element-wise product of the current window with the kernel.
                Y[i, j] = (X[i:i+h, j:j+w] * K).sum()
        return Y

    def corr2d_multi_in(self, X, K):
        """Cross-correlate a multi-channel input with a single filter.

        X: (C_in, H, W), K: (C_in, k_h, k_w).
        """
        # Sum the per-channel results into one 2D slice.
        return sum(self.corr2d(x, k) for x, k in zip(X, K))

    def corr2d_multi_in_out(self, X, K):
        """Cross-correlate a multi-channel input with every filter.

        X: (C_in, H, W), K: (C_out, C_in, k_h, k_w).
        """
        # Apply filters one by one to get C_out feature maps.
        return torch.stack([self.corr2d_multi_in(X, k) for k in K])