In [1]:
import torch
import torch.nn as nn
from torch.nn import  functional as F
import numpy as np
import os

In [54]:
class ResidualBlock(nn.Module):
    '''
    Implements the sub-module: Residual Block.

    Computes relu(left(x) + shortcut(x)), where ``left`` is
    conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN (bias-free convolutions)
    and the shortcut is the identity when ``shortcut`` is None, otherwise the
    given module.

    ``save``/``load`` exchange parameters with .npy files named
    ``conv{n}_weight.npy`` and ``bn{n}_{weight,bias,mean,var}.npy``; the
    counters ``conv_num``/``bn_num`` are advanced and returned so consecutive
    blocks use consecutive file numbers.
    '''
    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(ResidualBlock, self).__init__()
        self.left = nn.Sequential(
                nn.Conv2d(inchannel,outchannel,3,stride, 1,bias=False),
                nn.BatchNorm2d(outchannel),
                nn.ReLU(inplace=True),
                nn.Conv2d(outchannel,outchannel,3,1,1,bias=False),
                nn.BatchNorm2d(outchannel) )
        # NOTE(review): save/load index the shortcut as [conv, bn], so a non-None
        # shortcut is assumed to be Sequential(Conv2d(bias=False), BatchNorm2d).
        self.right = shortcut

    def forward(self, x):
        out = self.left(x)
        residual = x if self.right is None else self.right(x)
        out += residual
        return F.relu(out)

    def _conv_bn_pairs(self):
        """(conv, bn) module pairs in the order their files are numbered."""
        pairs = [(self.left[0], self.left[1]), (self.left[3], self.left[4])]
        if self.right is not None:
            pairs.append((self.right[0], self.right[1]))
        return pairs

    @staticmethod
    def _bn_tensors(bn):
        """(file suffix, tensor) pairs for one BatchNorm layer."""
        return (("weight", bn.weight), ("bias", bn.bias),
                ("mean", bn.running_mean), ("var", bn.running_var))

    @staticmethod
    def _assign(target, array):
        """Copy ``array`` into the existing tensor ``target`` in place.

        Unlike rebinding ``target.data``, this keeps ``target``'s dtype and
        device (float64 arrays do not turn a parameter into a double tensor)
        and rejects shape mismatches instead of deferring the failure.
        """
        source = torch.as_tensor(np.asarray(array))
        if tuple(source.shape) != tuple(target.shape):
            raise ValueError("expected shape {}, got {}".format(
                tuple(target.shape), tuple(source.shape)))
        with torch.no_grad():
            target.copy_(source)

    def save(self, dir, conv_num, bn_num):
        """Write this block's conv/BN tensors to ``dir`` as .npy files.

        Files are numbered from ``conv_num``/``bn_num`` in the order: left
        conv1, left bn1, left conv2, left bn2, then (if present) shortcut conv
        and shortcut bn. ``dir`` is created if it does not exist.

        Returns:
            (next_conv_num, next_bn_num).
        """
        os.makedirs(dir, exist_ok=True)
        for conv, bn in self._conv_bn_pairs():
            np.save(os.path.join(dir, "conv{}_weight.npy".format(conv_num)),
                    conv.weight.detach().cpu().numpy())
            conv_num += 1
            for suffix, tensor in self._bn_tensors(bn):
                np.save(os.path.join(dir, "bn{}_{}.npy".format(bn_num, suffix)),
                        tensor.detach().cpu().numpy())
            bn_num += 1
        return conv_num, bn_num

    def load(self, dir, conv_num, bn_num):
        """Read tensors written by ``save`` (same numbering) into this block.

        Values are copied into the existing parameters/buffers, so their dtype
        and device are preserved.

        Returns:
            (next_conv_num, next_bn_num).

        Raises:
            FileNotFoundError: a required .npy file is missing.
            ValueError: a stored array's shape does not match the module.
        """
        for conv, bn in self._conv_bn_pairs():
            self._assign(conv.weight,
                         np.load(os.path.join(dir, "conv{}_weight.npy".format(conv_num))))
            conv_num += 1
            for suffix, tensor in self._bn_tensors(bn):
                self._assign(tensor,
                             np.load(os.path.join(dir, "bn{}_{}.npy".format(bn_num, suffix))))
            bn_num += 1
        return conv_num, bn_num

In [60]:
class conv_layer:
    """NumPy 2-D convolution layer (NCHW) with a hand-written backward pass.

    Attributes:
        kernel: array of shape (out_channels, in_channels, kernel_h, kernel_w).
        bias: array of shape (out_channels,), or None when ``shift`` is False.

    ``forward`` caches its (padded) input; ``backward`` stores the input
    gradient in ``self.in_diff_tensor`` and applies a plain SGD step with
    learning rate ``lr`` to ``kernel`` and ``bias``.
    """

    def __init__(self, in_channels, out_channels, kernel_h, kernel_w, same = True, stride = 1, shift = True):
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels (filters).
            kernel_h, kernel_w: kernel height and width.
            same: if True, zero-pad the input by (kernel_h-1)//2 rows and
                (kernel_w-1)//2 columns on each side before convolving.
            stride: step of the sliding window in both directions.
            shift: if True, the layer has a learnable per-channel bias.
        """
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_h = kernel_h
        self.kernel_w = kernel_w
        self.same = same
        self.stride = stride
        self.shift = shift

        self.init_param()

    def init_param(self):
        """Glorot/Xavier-uniform kernel and zero bias."""
        fan_in = self.in_channels * self.kernel_h * self.kernel_w
        fan_out = self.out_channels * self.kernel_h * self.kernel_w
        # Symmetric bound so the initial kernel is zero-mean.
        bound = np.sqrt(6.0 / (fan_in + fan_out))
        self.kernel = np.random.uniform(
            low = -bound,
            high = bound,
            size = (self.out_channels, self.in_channels, self.kernel_h, self.kernel_w)
        )
        self.bias = np.zeros([self.out_channels]) if self.shift else None

    @staticmethod
    def pad(in_tensor, pad_h, pad_w):
        """Zero-pad the two spatial axes of an NCHW array by pad_h / pad_w per side."""
        batch_num, in_channels, in_h, in_w = in_tensor.shape
        padded = np.zeros([batch_num, in_channels, in_h + 2*pad_h, in_w + 2*pad_w])
        padded[:, :, pad_h:pad_h+in_h, pad_w:pad_w+in_w] = in_tensor
        return padded

    @staticmethod
    def convolution(in_tensor, kernel, stride = 1, dilate = 1):
        """Valid (unpadded) cross-correlation of ``in_tensor`` with ``kernel``.

        Args:
            in_tensor: array (batch, in_channels, in_h, in_w).
            kernel: array (out_channels, in_channels, kernel_h, kernel_w).
            stride: window step.
            dilate: accepted for interface compatibility; ignored.

        Returns:
            Array (batch, out_channels, out_h, out_w) with
            out_h = (in_h - kernel_h) // stride + 1 (likewise for out_w),
            the same size rule as torch.nn.functional.conv2d.
        """
        batch_num, in_channels, in_h, in_w = in_tensor.shape
        out_channels, kernel_in, kernel_h, kernel_w = kernel.shape
        assert kernel_in == in_channels

        out_h = (in_h - kernel_h) // stride + 1
        out_w = (in_w - kernel_w) // stride + 1

        kernel = kernel.reshape(out_channels, -1)

        # im2col: each column is one flattened receptive field; columns are
        # grouped by output position, then batch, to match the reshape below.
        extend_in = np.zeros([in_channels*kernel_h*kernel_w, batch_num*out_h*out_w])
        for i in range(out_h):
            for j in range(out_w):
                part_in = in_tensor[:, :, i*stride:i*stride+kernel_h, j*stride:j*stride+kernel_w].reshape(batch_num, -1)
                extend_in[:, (i*out_w+j)*batch_num:(i*out_w+j+1)*batch_num] = part_in.T

        out_tensor = np.dot(kernel, extend_in)
        out_tensor = out_tensor.reshape(out_channels, out_h*out_w, batch_num)
        return out_tensor.transpose(2,0,1).reshape(batch_num, out_channels, out_h, out_w)

    def forward(self, in_tensor):
        """Convolve an NCHW batch; returns (batch, out_channels, out_h, out_w)."""
        if self.same:
            in_tensor = conv_layer.pad(in_tensor, (self.kernel_h-1)//2, (self.kernel_w-1)//2)

        # Cached after padding; backward needs it for the kernel gradient.
        self.in_tensor = in_tensor.copy()

        self.out_tensor = conv_layer.convolution(in_tensor, self.kernel, self.stride)

        if self.shift:
            self.out_tensor += self.bias.reshape(1,self.out_channels,1,1)

        return self.out_tensor

    def backward(self, out_diff_tensor, lr):
        """Back-propagate ``out_diff_tensor`` and take an SGD step.

        Args:
            out_diff_tensor: gradient of the loss w.r.t. the last forward
                output; must have the same shape as that output.
            lr: learning rate; kernel and bias are updated as
                ``param -= lr * grad``.

        Sets ``self.in_diff_tensor`` to the gradient w.r.t. the (unpadded)
        input of the last forward call, computed with the pre-update kernel.
        """
        assert out_diff_tensor.shape == self.out_tensor.shape

        if self.shift:
            bias_diff = np.sum(out_diff_tensor, axis = (0,2,3)).reshape(self.bias.shape)
            self.bias -= lr * bias_diff

        batch_num, out_channels, out_h, out_w = out_diff_tensor.shape
        padded_h, padded_w = self.in_tensor.shape[2], self.in_tensor.shape[3]
        s = self.stride

        # Dilate the output gradient: (stride-1) zeros between entries and no
        # trailing zeros, so its size does not depend on padded_h % stride.
        dil_h = (out_h - 1) * s + 1
        dil_w = (out_w - 1) * s + 1
        extend_out = np.zeros([batch_num, out_channels, dil_h, dil_w])
        extend_out[:, :, ::s, ::s] = out_diff_tensor

        # Only the first used_h x used_w input positions are covered by any
        # window; the remainder (if any) receives zero gradient.
        used_h = dil_h + self.kernel_h - 1
        used_w = dil_w + self.kernel_w - 1
        used_in = self.in_tensor[:, :, :used_h, :used_w]

        # dK[o, c] = correlation of input channel c with dilated gradient o,
        # treating channels as the batch axis and the batch as channels.
        kernel_diff = conv_layer.convolution(used_in.transpose(1,0,2,3), extend_out.transpose(1,0,2,3))
        kernel_diff = kernel_diff.transpose(1,0,2,3)

        # dX = full convolution of the dilated gradient with the 180-degree
        # rotated kernel (in/out channels swapped).
        padded = conv_layer.pad(extend_out, self.kernel_h-1, self.kernel_w-1)
        kernel_flip = self.kernel[:, :, ::-1, ::-1]
        used_diff = conv_layer.convolution(padded, kernel_flip.transpose(1,0,2,3))

        in_diff = np.zeros(self.in_tensor.shape)
        in_diff[:, :, :used_h, :used_w] = used_diff
        if self.same:
            pad_h = (self.kernel_h-1)//2
            pad_w = (self.kernel_w-1)//2
            # Explicit end indices: a [pad:-pad] slice is empty when pad == 0.
            in_diff = in_diff[:, :, pad_h:padded_h-pad_h, pad_w:padded_w-pad_w]
        self.in_diff_tensor = in_diff

        self.kernel -= lr * kernel_diff

    def save(self, path, conv_num):
        """Save kernel (and bias) as conv{conv_num}_*.npy under ``path``; returns conv_num + 1."""
        os.makedirs(path, exist_ok=True)

        np.save(os.path.join(path, "conv{}_weight.npy".format(conv_num)), self.kernel)
        if self.shift:
            np.save(os.path.join(path, "conv{}_bias.npy".format(conv_num)), self.bias)

        return conv_num + 1

    def load(self, path, conv_num):
        """Load kernel (and bias) written by ``save``; returns conv_num + 1."""
        assert os.path.exists(path)

        self.kernel = np.load(os.path.join(path, "conv{}_weight.npy".format(conv_num)))
        if self.shift:
            # format() applies to the file name only, never to the directory part.
            self.bias = np.load(os.path.join(path, "conv{}_bias.npy".format(conv_num)))

        return conv_num + 1
    
class relu:
    """Element-wise ReLU on NumPy arrays with a hand-written backward pass.

    ``forward`` caches copies of its input and output; ``backward`` stores the
    gradient w.r.t. the input in ``self.in_diff_tensor`` and returns None.
    """

    def forward(self, in_tensor):
        """Return a copy of ``in_tensor`` whose strictly negative entries are 0."""
        self.in_tensor = in_tensor.copy()
        negative = self.in_tensor < 0.0
        rectified = in_tensor.copy()
        rectified[negative] = 0.0
        self.out_tensor = rectified
        return rectified

    def backward(self, out_diff_tensor, lr = 0):
        """Zero the upstream gradient wherever the cached input was negative.

        ``lr`` is accepted only for interface parity with the parametrised
        layers; ReLU has nothing to update.
        NOTE(review): at an input of exactly 0 the gradient passes through
        here, whereas torch.nn.ReLU uses 0 there — only exact zeros differ.
        """
        assert self.out_tensor.shape == out_diff_tensor.shape
        passed = out_diff_tensor.copy()
        passed[self.in_tensor < 0.0] = 0.0
        self.in_diff_tensor = passed



class bn_layer:
    """NumPy batch normalisation over axis 1 (channels) of NCHW arrays.

    Training mode normalises with the batch statistics and updates the
    running estimates ``moving_avg``/``moving_var`` by an exponential moving
    average with weight ``moving_rate``; eval mode normalises with the running
    estimates. ``backward`` (training mode only) stores the input gradient in
    ``self.in_diff_tensor`` and applies an SGD step to ``gamma`` and ``bias``.
    """

    def __init__(self, neural_num, moving_rate = 0.1):
        """
        Args:
            neural_num: number of channels (size of axis 1 of the input).
            moving_rate: weight of the current batch in the running-stat update.
        """
        self.gamma = np.random.uniform(low=0, high=1, size=neural_num)
        self.bias = np.zeros([neural_num])
        self.moving_avg = np.zeros([neural_num])
        self.moving_var = np.ones([neural_num])
        self.neural_num = neural_num
        self.moving_rate = moving_rate
        self.is_train = True
        self.epsilon = 1e-5

    def train(self):
        """Use batch statistics (and update running statistics) in forward."""
        self.is_train = True

    def eval(self):
        """Use the running statistics in forward."""
        self.is_train = False

    def forward(self, in_tensor):
        """Return gamma * normalized(in_tensor) + bias, per channel."""
        assert in_tensor.shape[1] == self.neural_num

        self.in_tensor = in_tensor.copy()

        if self.is_train:
            mean = in_tensor.mean(axis=(0,2,3))
            var = in_tensor.var(axis=(0,2,3))
            # Normalisation uses the biased batch variance, but the running
            # estimate uses the unbiased n/(n-1) correction (as in the BN paper
            # and torch.nn.BatchNorm2d) so eval mode matches the reference.
            m = in_tensor.shape[0] * in_tensor.shape[2] * in_tensor.shape[3]
            unbiased_var = var * m / max(m - 1, 1)
            self.moving_avg = mean * self.moving_rate + (1 - self.moving_rate) * self.moving_avg
            self.moving_var = unbiased_var * self.moving_rate + (1 - self.moving_rate) * self.moving_var
            self.var = var
            self.mean = mean
        else:
            mean = self.moving_avg
            var = self.moving_var

        self.normalized = (in_tensor - mean.reshape(1,-1,1,1)) / np.sqrt(var.reshape(1,-1,1,1)+self.epsilon)
        out_tensor = self.gamma.reshape(1,-1,1,1) * self.normalized + self.bias.reshape(1,-1,1,1)

        return out_tensor

    def backward(self, out_diff_tensor, lr):
        """Back-propagate through a training-mode forward and update gamma/bias.

        Args:
            out_diff_tensor: gradient w.r.t. the forward output (same shape as
                the cached input).
            lr: learning rate for ``param -= lr * grad``.
        """
        assert out_diff_tensor.shape == self.in_tensor.shape
        assert self.is_train

        # Number of values each channel's statistics were computed over.
        m = self.in_tensor.shape[0] * self.in_tensor.shape[2] * self.in_tensor.shape[3]

        normalized_diff = self.gamma.reshape(1,-1,1,1) * out_diff_tensor
        var_diff = -0.5 * np.sum(normalized_diff*self.normalized, axis=(0,2,3)) / (self.var + self.epsilon)
        # The d(var)/d(mean) term is omitted because sum(x - mean) == 0.
        mean_diff = -1.0 * np.sum(normalized_diff, axis=(0,2,3)) / np.sqrt(self.var + self.epsilon)
        in_diff_tensor1 = normalized_diff / np.sqrt(self.var.reshape(1,-1,1,1)+self.epsilon)
        in_diff_tensor2 = var_diff.reshape(1,-1,1,1) * (self.in_tensor - self.mean.reshape(1,-1,1,1)) * 2 / m
        in_diff_tensor3 = mean_diff.reshape(1,-1,1,1) / m
        self.in_diff_tensor = in_diff_tensor1 + in_diff_tensor2 + in_diff_tensor3

        gamma_diff = np.sum(self.normalized * out_diff_tensor, axis=(0,2,3))
        self.gamma -= lr * gamma_diff

        bias_diff = np.sum(out_diff_tensor, axis=(0,2,3))
        self.bias -= lr * bias_diff

    def save(self, path, bn_num):
        """Save gamma/bias/running stats as bn{bn_num}_*.npy; returns bn_num + 1."""
        os.makedirs(path, exist_ok=True)

        np.save(os.path.join(path, "bn{}_weight.npy".format(bn_num)), self.gamma)
        np.save(os.path.join(path, "bn{}_bias.npy".format(bn_num)), self.bias)
        np.save(os.path.join(path, "bn{}_mean.npy".format(bn_num)), self.moving_avg)
        np.save(os.path.join(path, "bn{}_var.npy".format(bn_num)), self.moving_var)

        return bn_num + 1

    def load(self, path, bn_num):
        """Load the files written by ``save``; returns bn_num + 1."""
        assert os.path.exists(path)

        self.gamma = np.load(os.path.join(path, "bn{}_weight.npy".format(bn_num)))
        self.bias = np.load(os.path.join(path, "bn{}_bias.npy".format(bn_num)))
        self.moving_avg = np.load(os.path.join(path, "bn{}_mean.npy".format(bn_num)))
        self.moving_var = np.load(os.path.join(path, "bn{}_var.npy".format(bn_num)))

        return bn_num + 1

class ResBlock:
    """NumPy residual block computing relu(path1(x) + path2(x)).

    path1 is conv3x3(stride) -> bn -> relu -> conv3x3 -> bn with bias-free
    convolutions. path2 is the ``shortcut`` passed to the constructor: None
    means identity; otherwise it is indexed as [conv_layer, bn_layer] by
    train/eval/save/load.
    """

    def __init__(self, in_channels, out_channels, stride=1, shortcut=None):
        self.path1 = [
            conv_layer(in_channels, out_channels, 3, 3, stride = stride, shift=False),
            bn_layer(out_channels),
            relu(),
            conv_layer(out_channels, out_channels, 3, 3, shift=False),
            bn_layer(out_channels)
        ]
        self.path2 = shortcut
        self.relu = relu()

    def _bn_layers(self):
        """Batch-norm layers of both paths."""
        layers = [self.path1[1], self.path1[4]]
        if self.path2 is not None:
            layers.append(self.path2[1])
        return layers

    def _conv_bn_pairs(self):
        """(conv, bn) pairs in the order their parameter files are numbered."""
        pairs = [(self.path1[0], self.path1[1]), (self.path1[3], self.path1[4])]
        if self.path2 is not None:
            pairs.append((self.path2[0], self.path2[1]))
        return pairs

    def train(self):
        """Put every batch-norm layer into training mode."""
        for layer in self._bn_layers():
            layer.train()

    def eval(self):
        """Put every batch-norm layer into eval mode."""
        for layer in self._bn_layers():
            layer.eval()

    def forward(self, in_tensor):
        """Return relu(path1(in_tensor) + path2(in_tensor)) and cache it."""
        x1 = in_tensor.copy()
        x2 = in_tensor.copy()

        for l in self.path1:
            x1 = l.forward(x1)
        if self.path2 is not None:
            for l in self.path2:
                x2 = l.forward(x2)
        self.out_tensor = self.relu.forward(x1+x2)

        return self.out_tensor

    def backward(self, out_diff_tensor, lr):
        """Back-propagate through the block; must follow a forward call.

        Args:
            out_diff_tensor: gradient of the loss w.r.t. forward's output.
            lr: learning rate handed to every layer's backward, each of which
                updates its own parameters.

        Stores the gradient w.r.t. the block input in ``self.in_diff_tensor``.
        """
        self.relu.backward(out_diff_tensor)
        merged_diff = self.relu.in_diff_tensor

        # The sum x1 + x2 sends the same gradient down both branches.
        diff1 = merged_diff.copy()
        for l in reversed(self.path1):
            l.backward(diff1, lr)
            diff1 = l.in_diff_tensor

        diff2 = merged_diff.copy()
        if self.path2 is not None:
            for l in reversed(self.path2):
                l.backward(diff2, lr)
                diff2 = l.in_diff_tensor

        self.in_diff_tensor = diff1 + diff2

    def save(self, path, conv_num, bn_num):
        """Save all conv/bn layers; returns the advanced (conv_num, bn_num)."""
        for conv, bn in self._conv_bn_pairs():
            conv_num = conv.save(path, conv_num)
            bn_num = bn.save(path, bn_num)
        return conv_num, bn_num

    def load(self, path, conv_num, bn_num):
        """Load all conv/bn layers; returns the advanced (conv_num, bn_num)."""
        for conv, bn in self._conv_bn_pairs():
            conv_num = conv.load(path, conv_num)
            bn_num = bn.load(path, bn_num)
        return conv_num, bn_num

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from components import *
from network import *
from model import *

In [7]:
# Fixed seed so the comparison input (and the outputs below) reproduce under
# Restart & Run All.
np.random.seed(0)
in_n = np.random.randn(2,3,64,64)          # float64 input for the NumPy model
in_t = torch.from_numpy(in_n).float()      # float32 copy for the PyTorch model
in_t.requires_grad_(True);                 # track d(loss)/d(input) for the gradient check

In [8]:
# Build the two models under comparison: bn1 takes NumPy arrays (see
# bn1.forward(in_n) below), bn2 is called as a PyTorch module.
# Despite the "bn" prefix, both are whole networks, not batch-norm layers.
# NOTE(review): 20 is presumably the number of output classes — confirm
# against network.resnet34 / model.ResNet.
bn1 = resnet34(20)
bn2 = ResNet(20)
# Presumably copies bn2's parameters into bn1 through .npy files in ./model7,
# so both networks start from identical weights.
bn2.save("model7")
bn1.load("model7")

In [9]:
# Forward the same input through both networks; out1 (NumPy) and out2
# (PyTorch) should agree up to float32 rounding.
# NOTE(review): both models are presumably in training mode here (bn_layer
# defaults to is_train=True; nn.Module defaults to train()), so batch norm uses
# batch statistics and updates running stats — re-running changes them.
# NOTE(review): the torch.Size lines in the output do not come from the prints
# below; they look like debug prints inside the models' forward.
out1 = bn1.forward(in_n)
out2 = bn2(in_t)
print(out1)
print(out2)

torch.Size([2, 64, 16, 16]) torch.Size([2, 64, 16, 16])
torch.Size([2, 64, 16, 16]) torch.Size([2, 64, 16, 16])
torch.Size([2, 64, 16, 16]) torch.Size([2, 64, 16, 16])
torch.Size([2, 128, 8, 8]) torch.Size([2, 128, 8, 8])
torch.Size([2, 128, 8, 8]) torch.Size([2, 128, 8, 8])
torch.Size([2, 128, 8, 8]) torch.Size([2, 128, 8, 8])
torch.Size([2, 128, 8, 8]) torch.Size([2, 128, 8, 8])
torch.Size([2, 256, 4, 4]) torch.Size([2, 256, 4, 4])
torch.Size([2, 256, 4, 4]) torch.Size([2, 256, 4, 4])
torch.Size([2, 256, 4, 4]) torch.Size([2, 256, 4, 4])
torch.Size([2, 256, 4, 4]) torch.Size([2, 256, 4, 4])
torch.Size([2, 256, 4, 4]) torch.Size([2, 256, 4, 4])
torch.Size([2, 256, 4, 4]) torch.Size([2, 256, 4, 4])
torch.Size([2, 512, 2, 2]) torch.Size([2, 512, 2, 2])
torch.Size([2, 512, 2, 2]) torch.Size([2, 512, 2, 2])
torch.Size([2, 512, 2, 2]) torch.Size([2, 512, 2, 2])
[[0.47470274 0.53944822 0.51017137 0.30650067 0.49524024 0.53134661
  0.50256288 0.66426079 0.46259805 0.37517949 0.54271628 0.583



In [10]:
# Random upstream gradient: equivalent to the loss sum(k * out1).
k = np.random.uniform(0,1,out1.shape)
# Back-propagate through the NumPy network with lr=1. Assuming bn1 is built
# from conv_layer/bn_layer-style layers (param -= lr * grad), each parameter
# then changes by exactly its gradient.
# NOTE(review): this mutates bn1's weights, so a later forward no longer
# matches bn2.
bn1.backward(k,1)
print(bn1.in_diff_tensor)

[[[[ 5.90819158e-03  3.43736262e-03  1.50333080e-03 ...  2.64158754e-03
     3.52910164e-05 -9.12694691e-04]
   [ 1.23524310e-02  4.55476877e-03 -4.08180879e-03 ... -5.30715725e-04
    -8.11735846e-04 -1.82933754e-03]
   [ 5.72507029e-03 -2.77475036e-02  9.37338680e-04 ... -1.62348890e-03
     2.00137579e-04 -2.56272212e-03]
   ...
   [-3.22040804e-03  5.49114252e-03  7.41648161e-03 ...  5.05466426e-03
    -2.20717292e-03 -1.28212594e-03]
   [-5.76905078e-04  3.01659316e-03  6.66994189e-05 ...  1.19905746e-03
     1.20188261e-03  4.41400115e-03]
   [-2.16947151e-03  5.78977073e-03 -1.69361598e-03 ...  3.73768499e-03
     2.63824601e-03 -7.48866818e-04]]

  [[-2.60327317e-03  2.23995162e-03 -3.47933889e-04 ... -3.34248006e-03
    -5.34795291e-03  1.63161807e-03]
   [-1.68622785e-02  8.84059937e-03  6.17426215e-03 ... -5.87906103e-03
     5.10255968e-04 -3.10303440e-03]
   [ 1.81166008e-04 -5.10856921e-03 -6.16627417e-03 ...  2.35884178e-03
    -2.29381486e-03 -2.69982435e-03]
   ...
   

In [11]:
# Same scalar loss on the PyTorch side, sum(k * out2); in_t.grad should match
# bn1.in_diff_tensor printed by the previous cell.
# NOTE(review): re-running without rebuilding out2 likely fails (graph freed
# after backward); in_t.grad also accumulates across backward calls.
l=torch.sum(torch.from_numpy(k).float()*out2)
l.backward()
in_t.grad

tensor([[[[ 5.9043e-03,  3.4333e-03,  1.5049e-03,  ...,  2.6432e-03,
            3.6375e-05, -9.1327e-04],
          [ 1.2349e-02,  4.5567e-03, -4.0727e-03,  ..., -5.3365e-04,
           -8.1134e-04, -1.8312e-03],
          [ 5.7230e-03, -2.7743e-02,  9.3942e-04,  ..., -1.6268e-03,
            2.0067e-04, -2.5628e-03],
          ...,
          [-3.2228e-03,  5.4898e-03,  7.4173e-03,  ...,  5.0559e-03,
           -2.2053e-03, -1.2808e-03],
          [-5.7643e-04,  3.0170e-03,  6.7040e-05,  ...,  1.1989e-03,
            1.2030e-03,  4.4135e-03],
          [-2.1686e-03,  5.7898e-03, -1.6913e-03,  ...,  3.7366e-03,
            2.6377e-03, -7.4924e-04]],

         [[-2.5999e-03,  2.2403e-03, -3.5068e-04,  ..., -3.3393e-03,
           -5.3478e-03,  1.6316e-03],
          [-1.6861e-02,  8.8377e-03,  6.1746e-03,  ..., -5.8805e-03,
            5.0853e-04, -3.1026e-03],
          [ 1.7967e-04, -5.1077e-03, -6.1697e-03,  ...,  2.3568e-03,
           -2.2931e-03, -2.6975e-03],
          ...,
     

In [52]:
# FIXME: stale cell — its execution count (In [52]) is out of sequence with the
# surrounding cells, and it fails on Restart & Run All: `record` is not defined
# anywhere in this notebook, and bn2 is now a whole ResNet, which presumably has
# no `.left` attribute (ResidualBlock does).
# Apparent intent: compare PyTorch's gamma gradient of a block's last BN layer
# with the NumPy gamma change (record[3] presumably held gamma before backward).
# NOTE(review): the two recorded outputs below do not agree with each other.
print(bn2.left[4].weight.grad)
print(record[3]-bn1.path1[4].gamma)

tensor([ 1.1314,  0.3261, -0.0415])
[ 2.219511   -0.45285076 -0.7259828 ]


In [25]:
# The NumPy ReLU (`relu`, presumably from the star imports) and PyTorch's
# reference implementation, compared in the following cells.
relu1 = relu()
relu2 = nn.ReLU(inplace=False)

In [26]:
# Fixed seed so the ReLU comparison reproduces under Restart & Run All.
# Note: this rebinds in_n / in_t from the ResNet comparison above.
np.random.seed(0)
in_n = np.random.randn(1,3,4,4)            # float64 input for the NumPy ReLU
in_t = torch.from_numpy(in_n).float()      # float32 copy for nn.ReLU
in_t.requires_grad_(True)

tensor([[[[-0.8657,  0.7099,  1.3849, -2.6403],
          [ 0.7308,  0.8718,  0.2971,  1.6389],
          [-0.7588,  0.1648, -0.9381,  0.2895],
          [ 2.1914,  1.7482,  0.7808, -2.2425]],

         [[-0.2162,  0.9270,  1.2617, -0.3084],
          [-1.1653, -0.4890,  1.0565, -1.0408],
          [ 0.9162, -0.4509, -0.5819,  0.4063],
          [-3.6674, -0.2084, -0.1758,  0.1968]],

         [[ 0.5012,  0.0386,  1.1023,  1.4681],
          [ 2.4505,  1.7032, -0.0731, -0.8587],
          [ 0.0627,  1.3656, -1.3541, -1.7107],
          [ 1.4836, -2.0714, -2.0978,  0.8290]]]], requires_grad=True)

In [27]:
# Forward the same input through both ReLUs; the printed results should match.
# Note: this rebinds out1 / out2 from the ResNet comparison above.
out1 = relu1.forward(in_n)
print(out1)
out2 = relu2(in_t)
print(out2)

[[[[0.         0.70992776 1.38488918 0.        ]
   [0.7308161  0.87184872 0.29706023 1.6389423 ]
   [0.         0.16479439 0.         0.28948098]
   [2.1914497  1.74815641 0.7808052  0.        ]]

  [[0.         0.92701882 1.26173421 0.        ]
   [0.         0.         1.05650258 0.        ]
   [0.91617213 0.         0.         0.40627932]
   [0.         0.         0.         0.19680612]]

  [[0.50123578 0.03857217 1.1023445  1.46813431]
   [2.45050671 1.70323962 0.         0.        ]
   [0.0626919  1.36563698 0.         0.        ]
   [1.48359461 0.         0.         0.82895613]]]]
tensor([[[[0.0000, 0.7099, 1.3849, 0.0000],
          [0.7308, 0.8718, 0.2971, 1.6389],
          [0.0000, 0.1648, 0.0000, 0.2895],
          [2.1914, 1.7482, 0.7808, 0.0000]],

         [[0.0000, 0.9270, 1.2617, 0.0000],
          [0.0000, 0.0000, 1.0565, 0.0000],
          [0.9162, 0.0000, 0.0000, 0.4063],
          [0.0000, 0.0000, 0.0000, 0.1968]],

         [[0.5012, 0.0386, 1.1023, 1.4681],
     

In [28]:
# All-ones upstream gradient (loss = sum of outputs); relu1 stores the input
# gradient in in_diff_tensor, to be compared with in_t.grad in the next cell.
relu1.backward(np.ones(out1.shape, dtype=float))
print(relu1.in_diff_tensor)

[[[[0. 1. 1. 0.]
   [1. 1. 1. 1.]
   [0. 1. 0. 1.]
   [1. 1. 1. 0.]]

  [[0. 1. 1. 0.]
   [0. 0. 1. 0.]
   [1. 0. 0. 1.]
   [0. 0. 0. 1.]]

  [[1. 1. 1. 1.]
   [1. 1. 0. 0.]
   [1. 1. 0. 0.]
   [1. 0. 0. 1.]]]]


In [29]:
# PyTorch side of the ReLU check: loss = sum(out2), so in_t.grad is 1 where the
# ReLU passed the input through and 0 elsewhere; it should match the NumPy mask above.
# NOTE(review): re-running without rebuilding out2 likely fails (graph freed
# after backward).
l = torch.sum(out2)
l.backward()
print(in_t.grad)

tensor([[[[0., 1., 1., 0.],
          [1., 1., 1., 1.],
          [0., 1., 0., 1.],
          [1., 1., 1., 0.]],

         [[0., 1., 1., 0.],
          [0., 0., 1., 0.],
          [1., 0., 0., 1.],
          [0., 0., 0., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 0., 0.],
          [1., 1., 0., 0.],
          [1., 0., 0., 1.]]]])
