In [1]:
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
from torchsummary import summary
import pdb

In [2]:
class LinearFunction(torch.autograd.Function):
    @staticmethod
    # ctx is the first argument to forward
    def forward(ctx, input, weight, bias=None):
        # The forward pass can use ctx.
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        print(f'baclward: ', input)
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
#         pdb.set_trace()
        grad_weight[:input.shape[0]] = input
        print(f'grad_weight: ', grad_weight)
        return grad_input, grad_weight, grad_bias



class malLinear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)
        nn.init.uniform_(self.weight, -0.1, 0.1)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -0.1, 0.1)

    def forward(self, input):
        return LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        return 'input_features={}, output_features={}, bias={}'.format(
            self.input_features, self.output_features, self.bias is not None
        )

In [3]:
class BasicBlock(nn.Module):
    """
    ResNet for CIFAR code from 
    ` https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py`
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
            )

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.conv2(out)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

    
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10, cifar:bool=False):
        super(ResNet, self).__init__()
        self.in_planes = 64

        if cifar:
          self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        else:
          self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if cifar:
          self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        else:
          self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=2)
        
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = malLinear(512*block.expansion, 10)


    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(out)    
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        print(f'forward:  {out}')
        out = self.linear(out)
        return out


In [4]:
def ResNet18(**kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


In [5]:
m = ResNet18(num_classes=4, cifar=False)

In [6]:
def load_image_imgnet(path:str, size:int=224):
    img = Image.open(path)
    p = transforms.Compose([transforms.Resize((size, size), Image.BICUBIC)])
    img = np.array(p(img))
    img = img.transpose(2,0,1).reshape(1,3,224,224)
    print(f'loaded image w/ shape {img.shape}')
    return img

In [None]:
x = load_image_imgnet('data/building.jpg')
x.shape, type(x)
x_t = torch.tensor(x).float()

In [8]:
x_t = torch.cat((x_t.clone().detach(), x_t.clone().detach()), 0)
x_t.shape

torch.Size([2, 3, 224, 224])

In [9]:
c = m(x_t)
c

forward:  tensor([[0.7111, 1.7504, 0.4832,  ..., 0.0518, 0.0000, 0.8694],
        [0.7111, 1.7504, 0.4832,  ..., 0.0518, 0.0000, 0.8694]],
       grad_fn=<ViewBackward0>)


tensor([[-2.5440,  3.1191, -2.9588,  0.8289,  0.6569, -1.8108,  2.8210, -0.1291,
          1.9475,  2.5500],
        [-2.5440,  3.1191, -2.9588,  0.8289,  0.6569, -1.8108,  2.8210, -0.1291,
          1.9475,  2.5500]], grad_fn=<LinearFunctionBackward>)

In [10]:
l = nn.MSELoss()
b = l(c, torch.zeros(c.shape))
b.backward()

baclward:  tensor([[0.7111, 1.7504, 0.4832,  ..., 0.0518, 0.0000, 0.8694],
        [0.7111, 1.7504, 0.4832,  ..., 0.0518, 0.0000, 0.8694]],
       grad_fn=<ViewBackward0>)
grad_weight:  tensor([[ 7.1113e-01,  1.7504e+00,  4.8317e-01,  ...,  5.1783e-02,
          0.0000e+00,  8.6944e-01],
        [ 7.1113e-01,  1.7504e+00,  4.8317e-01,  ...,  5.1783e-02,
          0.0000e+00,  8.6944e-01],
        [-4.2082e-01, -1.0358e+00, -2.8592e-01,  ..., -3.0643e-02,
          0.0000e+00, -5.1450e-01],
        ...,
        [-1.8361e-02, -4.5194e-02, -1.2475e-02,  ..., -1.3370e-03,
          0.0000e+00, -2.2448e-02],
        [ 2.7699e-01,  6.8178e-01,  1.8820e-01,  ...,  2.0170e-02,
          0.0000e+00,  3.3865e-01],
        [ 3.6267e-01,  8.9268e-01,  2.4641e-01,  ...,  2.6409e-02,
          0.0000e+00,  4.4340e-01]])


In [11]:
(m.linear.weight.grad[1] == m.linear.weight.grad[0]).all()

tensor(True)