<a href="https://colab.research.google.com/github/PavloZakala/CNN/blob/main/CNN1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Задача

Написати реалізацію операцій forward та backward для Conv2d

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function


def to_tuple(param):
    """Normalise a scalar-or-pair argument to pair form.

    An ``int`` is duplicated into ``(param, param)``; any other value
    (typically an ``(h, w)`` tuple) is returned unchanged.
    """
    return (param, param) if isinstance(param, int) else param

class MyConv2dFunction(Function):
    """Hand-written forward/backward for a 2-D convolution (cross-correlation).

    Both passes iterate only over the ``kH * kW`` kernel taps. For every tap
    the matching strided window of the zero-padded input is combined with the
    ``(C_out, C_in)`` weight slice of that tap in a single ``einsum``, instead
    of visiting every scalar multiply-add in Python.
    """

    @staticmethod
    def _tap_view(padded, h0, w0, stride_h, stride_w, H_out, W_out):
        """Return the window of ``padded`` touched by one kernel tap.

        ``h0``/``w0`` are the tap offsets (kernel index times dilation). The
        result is a view of shape ``(N, C, H_out, W_out)`` whose element
        ``[n, c, h, w]`` is ``padded[n, c, h0 + h * stride_h, w0 + w * stride_w]``.
        Being a view, in-place updates on it write through to ``padded``.
        """
        # Step first, then truncate: this avoids computing an explicit stop
        # index that could turn negative for degenerate output sizes.
        return padded[:, :, h0::stride_h, w0::stride_w][:, :, :H_out, :W_out]

    @staticmethod
    def forward(ctx, x, weight, bias=None, stride=1, padding=0, dilation=1):
        """
        x:       (N, C_in, H, W)
        weight:  (C_out, C_in, kH, kW)
        bias:    (C_out,) or None
        stride, padding, dilation: int or (h, w) tuples

        Returns a tensor of shape (N, C_out, H_out, W_out) with the dtype and
        device of ``x``.

        Raises ValueError if the channel dimension of ``weight`` does not
        match the channel dimension of ``x``.
        """
        stride_h, stride_w = to_tuple(stride)
        pad_h, pad_w = to_tuple(padding)
        dil_h, dil_w = to_tuple(dilation)

        N, C_in, H_in, W_in = x.shape
        C_out, C_in_weight, kH, kW = weight.shape
        if C_in_weight != C_in:
            # Without this check a weight with extra input channels would be
            # silently truncated and produce a wrong result.
            raise ValueError(
                f"weight expects {C_in_weight} input channels, but x has {C_in}"
            )

        H_out = (H_in + 2 * pad_h - dil_h * (kH - 1) - 1) // stride_h + 1
        W_out = (W_in + 2 * pad_w - dil_w * (kW - 1) - 1) // stride_w + 1

        # F.pad takes pads for the last dimension first: (left, right, top, bottom).
        x_padded = F.pad(x, (pad_w, pad_w, pad_h, pad_h), mode='constant', value=0)

        output = torch.zeros((N, C_out, H_out, W_out), device=x.device, dtype=x.dtype)

        # Accumulate the contribution of every kernel tap (i, j).
        for i in range(kH):
            for j in range(kW):
                patch = MyConv2dFunction._tap_view(
                    x_padded, i * dil_h, j * dil_w, stride_h, stride_w, H_out, W_out
                )
                # Contract over input channels: (N, C_in, H, W) x (C_out, C_in).
                output += torch.einsum('nchw,oc->nohw', patch, weight[:, :, i, j])

        if bias is not None:
            output += bias.view(1, C_out, 1, 1)

        ctx.save_for_backward(x, weight, bias)
        ctx.params = (stride, padding, dilation)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        grad_output: (N, C_out, H_out, W_out)
        returns gradients for (x, weight, bias, stride, padding, dilation)

        A gradient is computed only when ``ctx.needs_input_grad`` requests it;
        otherwise ``None`` is returned in its slot. stride, padding and
        dilation are not differentiable and always get ``None``.
        """
        x, weight, bias = ctx.saved_tensors
        stride, padding, dilation = ctx.params

        stride_h, stride_w = to_tuple(stride)
        pad_h, pad_w = to_tuple(padding)
        dil_h, dil_w = to_tuple(dilation)

        _, _, H_in, W_in = x.shape
        _, _, kH, kW = weight.shape
        _, _, H_out, W_out = grad_output.shape

        # ctx.needs_input_grad: 0 for x, 1 for weight, 2 for bias
        need_x = ctx.needs_input_grad[0]
        need_weight = ctx.needs_input_grad[1]
        need_bias = bias is not None and ctx.needs_input_grad[2]

        grad_x = None
        grad_weight = None
        grad_bias = None

        if need_bias:
            # The bias is added to every (n, h, w) position of its channel.
            grad_bias = grad_output.sum(dim=(0, 2, 3))

        if need_x or need_weight:
            x_padded = F.pad(x, (pad_w, pad_w, pad_h, pad_h), mode='constant', value=0)
            if need_x:
                grad_x_padded = torch.zeros_like(x_padded)
            if need_weight:
                grad_weight = torch.zeros_like(weight)

            for i in range(kH):
                for j in range(kW):
                    h0 = i * dil_h
                    w0 = j * dil_w

                    if need_weight:
                        # dL/dW[o, c, i, j] = sum_{n,h,w} dL/dO[n, o, h, w] * x_pad[n, c, h_in, w_in]
                        patch = MyConv2dFunction._tap_view(
                            x_padded, h0, w0, stride_h, stride_w, H_out, W_out
                        )
                        grad_weight[:, :, i, j] = torch.einsum(
                            'nohw,nchw->oc', grad_output, patch
                        )

                    if need_x:
                        # dL/dx_pad[n, c, h_in, w_in] += sum_o dL/dO[n, o, h, w] * W[o, c, i, j]
                        window = MyConv2dFunction._tap_view(
                            grad_x_padded, h0, w0, stride_h, stride_w, H_out, W_out
                        )
                        # In-place add on the view writes into grad_x_padded.
                        window += torch.einsum(
                            'nohw,oc->nchw', grad_output, weight[:, :, i, j]
                        )

            if need_x:
                # Drop the padding border so the gradient has the shape of x.
                grad_x = grad_x_padded[:, :, pad_h:pad_h + H_in, pad_w:pad_w + W_in]

        # x, weight, bias, stride, padding, dilation
        return grad_x, grad_weight, grad_bias, None, None, None


class MyConv2d(nn.Module):
    """``nn.Module`` wrapper around ``MyConv2dFunction``.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels produced by the convolution.
        kernel_size: int or (kH, kW) tuple.
        stride, padding, dilation: int or (h, w) tuples, forwarded unchanged
            to ``MyConv2dFunction``.
        bias: if True, a learnable bias initialised to zeros is added;
            otherwise ``self.bias`` is None.

    The weight has shape (out_channels, in_channels, kH, kW) and is
    initialised from a standard normal distribution.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        # Normalise so an int kernel size works like int stride/padding/dilation;
        # unpacking a bare int with * would raise TypeError.
        self.kernel_size = to_tuple(kernel_size)
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, *self.kernel_size))
        self.bias = nn.Parameter(torch.zeros(out_channels)) if bias else None
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

    def forward(self, x):
        """Apply the convolution to ``x`` using the stored parameters."""
        return MyConv2dFunction.apply(x, self.weight, self.bias, self.stride, self.padding, self.dilation)

In [6]:
# Test: compare MyConv2d with nn.Conv2d on forward values and on the
# gradients w.r.t. input, weight and bias.
device = "cuda" if torch.cuda.is_available() else "cpu"
B, C, H, W = 2, 1, 3, 3
C_out = 2

x = torch.randn(B, C, H, W, device=device, requires_grad=True)
# Separate leaf with the same values for the reference path; sharing one leaf
# would accumulate both losses into a single x.grad and make it uncheckable.
x_ref = x.detach().clone().requires_grad_(True)

myconv = MyConv2d(C, C_out, (3, 3), stride=1, padding=1, dilation=1, bias=True).to(device)
target_conv = nn.Conv2d(C, C_out, (3, 3), stride=1, padding=1, dilation=1, bias=True).to(device)

# Give the reference layer the same parameters as the custom one.
with torch.no_grad():
    target_conv.weight.copy_(myconv.weight)
    if myconv.bias is not None:
        target_conv.bias.copy_(myconv.bias)

# Forward check
y_my = myconv(x)
y_target = target_conv(x_ref)
print(y_my.shape, y_target.shape)
print(y_my)
print(y_target)

print("Forward check:", torch.allclose(y_my, y_target, atol=1e-6, rtol=1e-5))

# Backward check
loss_my = y_my.square().mean()
loss_ref = y_target.square().mean()
loss_my.backward()
loss_ref.backward()

print("x grad close:", torch.allclose(x.grad, x_ref.grad, atol=1e-6, rtol=1e-5))
print("w grad close:", torch.allclose(myconv.weight.grad, target_conv.weight.grad, atol=1e-6, rtol=1e-5))
if myconv.bias is not None:
    print("b grad close:", torch.allclose(myconv.bias.grad, target_conv.bias.grad, atol=1e-6, rtol=1e-5))


torch.Size([2, 2, 3, 3]) torch.Size([2, 2, 3, 3])
tensor([[[[ 2.3729, -2.6914,  0.6745],
          [ 0.8500, -1.1675,  1.9616],
          [-3.9082,  2.1504, -0.4034]],

         [[ 1.9253,  0.2542, -1.2055],
          [ 0.0866, -1.6695,  1.0285],
          [ 2.3610,  2.2851, -1.1022]]],


        [[[ 3.5413, -0.6972,  0.4987],
          [-2.8518,  3.5412, -3.5332],
          [-3.5641,  1.1669, -1.1376]],

         [[-1.9640, -1.0002,  1.8033],
          [-0.7915, -0.7532, -0.6293],
          [ 3.3096,  3.2629, -0.1718]]]], device='cuda:0',
       grad_fn=<MyConv2dFunctionBackward>)
tensor([[[[ 2.3729, -2.6914,  0.6745],
          [ 0.8500, -1.1675,  1.9616],
          [-3.9082,  2.1504, -0.4034]],

         [[ 1.9253,  0.2542, -1.2055],
          [ 0.0866, -1.6695,  1.0285],
          [ 2.3610,  2.2851, -1.1022]]],


        [[[ 3.5413, -0.6972,  0.4987],
          [-2.8518,  3.5412, -3.5332],
          [-3.5641,  1.1669, -1.1376]],

         [[-1.9640, -1.0002,  1.8033],
          [-0