## custom FC test


In [None]:
# custom FC test

import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed the global RNG so the torch.randn tensors created in this cell are
# the same on every run.
torch.manual_seed(42)

class MyLinear(torch.autograd.Function):
    """Fully connected (linear) layer with a hand-written backward pass.

    The forward pass delegates to ``F.linear`` (``input @ weight.T + bias``);
    the backward pass computes the gradients for ``input``, ``weight`` and
    ``bias`` explicitly.  ``input`` may have any number of leading batch
    dimensions (including none), matching what ``F.linear`` accepts.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute the linear transform and stash the operands for backward.

        Args:
            ctx: autograd context used to carry tensors to ``backward``.
            input: tensor whose last dimension is ``in_features``.
            weight: tensor of shape ``(out_features, in_features)``.
            bias: optional tensor of shape ``(out_features,)``.

        Returns:
            Tensor shaped like ``input`` with the last dimension replaced by
            ``out_features``.
        """
        ctx.save_for_backward(input, weight, bias)
        return F.linear(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_input, grad_weight, grad_bias)`` for ``forward``.

        Each entry is ``None`` when autograd does not need that gradient
        (or, for the bias, when no bias was supplied).
        """
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        out_features, in_features = weight.shape

        # Collapse every leading (batch-like) dimension into a single one so
        # the weight and bias reductions work for 1-D, 2-D and N-D inputs.
        grad_output_2d = grad_output.reshape(-1, out_features)

        if ctx.needs_input_grad[0]:
            # matmul broadcasts over leading dims, so the result already has
            # the same shape as ``input``.
            grad_input = grad_output @ weight
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output_2d.t() @ input.reshape(-1, in_features)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output_2d.sum(0)

        return grad_input, grad_weight, grad_bias


# Usage: run nn.Linear and MyLinear on identical data and check that both
# produce the same gradients.
BATCH = 2
IN_FEATURE = 3
OUT_FEATURE = 5

# Data: every tensor gets an independent leaf copy (the "...2" variants) so
# that each implementation accumulates into its own .grad buffers.
input = torch.randn(BATCH, IN_FEATURE, requires_grad=True)
input2 = input.clone().detach().requires_grad_(True)
weight = torch.randn(OUT_FEATURE, IN_FEATURE, requires_grad=True)
weight2 = weight.clone().detach().requires_grad_(True)
bias = torch.randn(OUT_FEATURE, requires_grad=True)
bias2 = bias.clone().detach().requires_grad_(True)

# Reference implementation: nn.Linear with its parameters replaced by the
# copied weight and bias.
print("\n\nnn.linear")
print("nn.linear")
print("nn.linear")
linear = nn.Linear(IN_FEATURE, OUT_FEATURE)
linear.weight = nn.Parameter(weight2)
linear.bias = nn.Parameter(bias2)

# Forward pass
output_linear = linear(input2)

# Backward pass
output_linear.sum().backward()
print('input2.grad.shape', input2.grad.shape)
# print(input2.grad)
print('linear.weight.grad.shape', linear.weight.grad.shape)
# print(linear.weight.grad)
# Report the bias *gradient* shape, in line with the input/weight lines above.
print('linear.bias.grad.shape', linear.bias.grad.shape)
# print(linear.bias.grad)

my_linear = MyLinear.apply


# Custom implementation under test.
print("\n\nMylinear")
print("Mylinear")
print("Mylinear")
output = my_linear(input, weight, bias)
output.sum().backward()
print('input.grad.shape', input.grad.shape)
# print(input.grad)
print('weight.grad.shape', weight.grad.shape)
# print(weight.grad)
print('bias.grad.shape', bias.grad.shape)
# print(bias.grad)


# Each line prints True when the custom gradient matches the reference.
print("\nsame?")

print('input_grad', torch.allclose(input.grad, input2.grad))
print('weight_grad', torch.allclose(weight.grad, linear.weight.grad))
print('bias_grad', torch.allclose(bias.grad, linear.bias.grad))

## custom Conv test

In [None]:
#custom Conv test


import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed the global RNG so the torch.randn tensors created in this cell are
# the same on every run.
torch.manual_seed(42)

class MyConv2d(torch.autograd.Function):
    """2-D convolution with a hand-written backward pass.

    The forward pass delegates to ``F.conv2d``; the backward pass computes
    the gradients for ``input``, ``weight`` and ``bias`` explicitly.

    NOTE: ``padding`` defaults to 1 here, whereas ``F.conv2d`` defaults to 0.
    """

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
        """Run the convolution and stash what ``backward`` needs.

        Args:
            ctx: autograd context used to carry state to ``backward``.
            input: input tensor passed to ``F.conv2d``.
            weight: convolution kernel passed to ``F.conv2d``.
            bias: optional bias passed to ``F.conv2d``.
            stride, padding, dilation, groups: forwarded to ``F.conv2d``.

        Returns:
            The result of ``F.conv2d`` for the given arguments.
        """
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups

        return F.conv2d(input, weight, bias=bias, stride=stride, padding=padding,
                        dilation=dilation, groups=groups)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients for every argument of ``forward``.

        Only ``input``, ``weight`` and ``bias`` can receive a gradient; the
        four hyper-parameters always get ``None``.
        """
        input, weight, bias = ctx.saved_tensors
        conv_kwargs = dict(stride=ctx.stride, padding=ctx.padding,
                           dilation=ctx.dilation, groups=ctx.groups)

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # conv2d_input is told the exact input shape.  A bare
            # F.conv_transpose2d call cannot know it and yields a too-small
            # gradient whenever the stride does not evenly tile the padded
            # input (e.g. 6x6 input, 3x3 kernel, padding 1, stride 2).
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output,
                                                    **conv_kwargs)
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output,
                                                      **conv_kwargs)
        if bias is not None and ctx.needs_input_grad[2]:
            # Sum over the first and the last two dimensions; assumes a
            # batched 4-D grad_output of shape (N, C_out, H_out, W_out), which
            # leaves one value per output channel.
            grad_bias = grad_output.sum((0, -1, -2))

        return grad_input, grad_weight, grad_bias, None, None, None, None
    
    
# Usage: run nn.Conv2d and MyConv2d on identical data and check that both
# produce the same gradients.
BATCH = 2
CHANNEL = 2      # input channels
KERNEL = 3       # number of kernels, i.e. output channels
INPUT_H_W = 5    # input height and width
WEIGHT_R_C = 3   # kernel rows and columns
PADDING = 1

# Data: every tensor gets an independent leaf copy (the "...2" variants) so
# that each implementation accumulates into its own .grad buffers.
input = torch.randn(BATCH, CHANNEL, INPUT_H_W, INPUT_H_W, requires_grad=True)
input2 = input.clone().detach().requires_grad_(True)
weight = torch.randn(KERNEL, CHANNEL, WEIGHT_R_C, WEIGHT_R_C, requires_grad=True)
weight2 = weight.clone().detach().requires_grad_(True)
bias = torch.randn(KERNEL, requires_grad=True)
bias2 = bias.clone().detach().requires_grad_(True)


# Reference implementation: nn.Conv2d with its parameters replaced by the
# copied weight and bias.
print("\n\nnn.Conv2d")
print("nn.Conv2d")
print("nn.Conv2d")
conv = nn.Conv2d(CHANNEL, KERNEL, kernel_size=WEIGHT_R_C, stride=1,
                 padding=PADDING, dilation=1, groups=1)
conv.weight = nn.Parameter(weight2)
conv.bias = nn.Parameter(bias2)

# Forward pass
output_conv = conv(input2)

# Backward pass
output_conv.sum().backward()

print('input2.grad.shape', input2.grad.shape)
# print('input2.grad', input2.grad)
print('conv.weight.grad.shape', conv.weight.grad.shape)
# print('conv.weight.grad',conv.weight.grad)
print('conv.bias.grad.shape', conv.bias.grad.shape)
# print('conv.bias.grad',conv.bias.grad)


# Custom implementation under test.
print("\n\nMyConv2d")
print("MyConv2d")
print("MyConv2d")
my_conv = MyConv2d.apply

# Function.apply takes positional arguments only:
# (input, weight, bias, stride, padding, dilation, groups).
output = my_conv(input, weight, bias, 1, PADDING, 1, 1)
output.sum().backward()
print('input.grad.shape', input.grad.shape)
# print('input.grad', input.grad)
print('weight.grad.shape', weight.grad.shape)
# print('weight.grad', weight.grad)
print('bias.grad.shape', bias.grad.shape)
# print('bias.grad',bias.grad)


# Each line prints True when the custom gradient matches the reference.
print("\n\nsame?\n")
print('input_grad', torch.allclose(input.grad, input2.grad))
print('weight_grad', torch.allclose(weight.grad, conv.weight.grad))
print('bias_grad', torch.allclose(bias.grad, conv.bias.grad))


