In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
%matplotlib inline

from pclib.nn.models import ConvClassifierInv, ConvClassifier
from pclib.nn.layers import Conv2d
from pclib.optim.train_conv import train_conv as train
from pclib.utils.plot import plot_stats
from pclib.optim.eval import track_vfe, accuracy
from pclib.utils.functional import format_y
from pclib.utils.customdataset import PreloadedDataset

In [39]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [57]:
class Stats(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('output', torch.zeros(0))

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.stats = Stats()

        self.layers = nn.Sequential(
            nn.Linear(10, 1),
            nn.ReLU(),
        )
    
    def forward(self, x):
        out = self.layers(x) + torch.randn_like(x)
        self.stats.output = torch.cat([self.stats.output, out.mean().view(1)])
        return self.layers(x)

model = Model()
x = torch.randn(1, 10)
model(x)
model(x)
model(x)
state_dict = model.state_dict()

model = Model()
model.load_state_dict(state_dict)
model.stats.output

RuntimeError: Error(s) in loading state_dict for Model:
	size mismatch for stats.output: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([0]).

In [3]:
def connectivity_matrix(in_shape, kernel_size, stride, padding):
    """
    Returns a connectivity matrix for a convolutional layer.
    """
    in_h, in_w = in_shape
    # out_h, out_w = out_shape
    out_h, out_w = (in_h + 2*padding[0] - kernel_size[0])//stride[0] + 1, (in_w + 2*padding[1] - kernel_size[1])//stride[1] + 1
    k_h, k_w = kernel_size
    s_h, s_w = stride
    p_h, p_w = padding
    
    # Initialize connectivity matrix
    C = torch.zeros(out_h*out_w, in_h*in_w)
    
    # Iterate over output pixels
    for i in range(out_h):
        for j in range(out_w):
            # Iterate over kernel
            for k in range(k_h):
                for l in range(k_w):
                    # Input pixel index
                    ii = i*s_h + k - p_h
                    jj = j*s_w + l - p_w
                    
                    # If input pixel is within bounds
                    if ii >= 0 and ii < in_h and jj >= 0 and jj < in_w:
                        # Input pixel index
                        in_idx = ii*in_w + jj
                        # Output pixel index
                        out_idx = i*out_w + j
                        # Set connectivity
                        C[out_idx, in_idx] = 1
    return C.t()

In [4]:
# torch.manual_seed(42)
in_shape = (1, 300, 300)
kernel_size = 2
stride = 2
padding = 0
out_channels = 1
out_h, out_w = (in_shape[1] + 2*padding - kernel_size)//stride + 1, (in_shape[2] + 2*padding - kernel_size)//stride + 1
print(out_h, out_w)


layer = nn.Conv2d(
    in_shape[0], 
    out_channels, 
    kernel_size,
    stride,
    padding, 
    bias=False
)

mat = connectivity_matrix((in_shape[1], in_shape[2]), (kernel_size, kernel_size), (stride,stride), (padding,padding))

def run(mode='manual'):
    assert mode in ['manual', 'auto'], 'Mode must be manual or auto'
    if mode == 'auto':
        x = torch.randn(in_shape).unsqueeze(0)
        true = torch.randn((out_h, out_w)).unsqueeze(0)

        pred = layer(x)

        error = (true - pred)
        error.pow(2).mean().backward()
    elif mode == 'manual':
        with torch.no_grad():
            x = torch.randn(in_shape).unsqueeze(0)
            true = torch.randn((out_h, out_w)).unsqueeze(0)

            pred = layer(x)

            error = (true - pred)
            outer = -2 * torch.outer(x.flatten(), error.flatten()) * mat
            # nonzeros = outer.nonzero(as_tuple=False)
            # sorted_nonzeros = sorted(nonzeros, key=lambda x: x[1])
            # vals = [outer[i,j] for i,j in sorted_nonzeros]
            # vals_tight = torch.stack([torch.tensor(vals[i::kernel_size**2]) for i in range(kernel_size**2)])
            # grad = vals_tight.mean(dim=1).reshape(layer.weight.shape)
            grad = outer

    # torch.isclose(manual_grad, layer.weight.grad).all()


150 150


In [5]:
torch.manual_seed(42)
in_shape = (3, 1, 3, 3)
kernel_size = 2
stride = 1
padding = 0
out_channels = 1
out_dim = (in_shape[2] + 2*padding - kernel_size)//stride + 1
print(out_dim)


layer = nn.Conv2d(
    in_shape[1], 
    out_channels, 
    kernel_size,
    stride,
    padding, 
    bias=True
)

x1 = torch.randn((3, 1, out_dim, out_dim))
x2 = torch.randn(in_shape, requires_grad=True)

x1_hat = layer(x2)
x2_hat = torch.randn_like(x2)
# x2_hat = x2

e1 = x1 - x1_hat
e2 = x2 - x2_hat

# e1 = torch.zeros_like(e1)

vfe = e1.square().sum() + e2.square().sum()
vfe.backward(retain_graph=True)

# e2.square().sum().backward(retain_graph=True)

x_manualgrad = torch.nn.grad.conv2d_input(x2.shape, layer.weight, e1, stride=stride, padding=padding)
x_manualgrad = -2*(x_manualgrad - e2)
w_grad = -2 * torch.nn.grad.conv2d_weight(x2, layer.weight.shape, e1, stride=stride, padding=padding)
b_grad = -2 * e1.sum(dim=(0,2,3))

# torch.isclose(x_autograd, x_manualgrad)#, torch.isclose(w_grad, layer.weight.grad).all()
x2.grad, x_manualgrad, layer.weight.grad, w_grad, layer.bias.grad, b_grad

2


(tensor([[[[-1.3307, -6.5906, -0.6320],
           [-4.0812, -6.6632,  0.5901],
           [ 3.3621, -4.8324, -6.9174]]],
 
 
         [[[-1.0896,  2.0920, -1.8908],
           [-3.1428,  5.5075, -0.8493],
           [ 1.5190,  1.8407,  2.8830]]],
 
 
         [[[-0.4862, -1.5449,  1.2864],
           [-5.9475,  0.1644,  2.7083],
           [ 3.4047, -2.1701, -3.0102]]]]),
 tensor([[[[-1.3307, -6.5906, -0.6320],
           [-4.0812, -6.6632,  0.5901],
           [ 3.3621, -4.8324, -6.9174]]],
 
 
         [[[-1.0896,  2.0920, -1.8908],
           [-3.1428,  5.5075, -0.8493],
           [ 1.5190,  1.8407,  2.8830]]],
 
 
         [[[-0.4862, -1.5449,  1.2864],
           [-5.9475,  0.1644,  2.7083],
           [ 3.4047, -2.1701, -3.0102]]]], grad_fn=<MulBackward0>),
 tensor([[[[ 11.0869,  11.0813],
           [-10.2167,  16.1286]]]]),
 tensor([[[[ 11.0869,  11.0813],
           [-10.2167,  16.1286]]]], grad_fn=<MulBackward0>),
 tensor([-5.9108]),
 tensor([-5.9108], grad_fn=<MulBackward0

In [141]:
torch.manual_seed(42)
batch_size = 3
in_channels = 1
in_dim = 3
kernel_size = 2
stride = 1
padding = 0
out_channels = 1
out_dim = (in_shape[2] + 2*padding - kernel_size)//stride + 1
print(out_dim)


layer = nn.ConvTranspose2d(
    in_channels,
    out_channels, 
    kernel_size,
    stride,
    padding, 
    bias=True
)

x1 = torch.randn((batch_size, in_channels, in_dim, in_dim))
x2 = torch.randn((batch_size, out_channels, out_dim, out_dim), requires_grad=True)

x1_hat = layer(x2)
x2_hat = torch.randn_like(x2)
# x2_hat = x2

e1 = x1 - x1_hat
e2 = x2 - x2_hat

# e1 = torch.zeros_like(e1)

vfe = e1.square().sum() + e2.square().sum()
vfe.backward(retain_graph=True)

# e2.square().sum().backward(retain_graph=True)

x_manualgrad = torch.nn.grad.conv2d_input(x2.shape, layer.weight, e1, stride=stride, padding=padding)
x_manualgrad = -2*(x_manualgrad - e2)
w_grad = -2 * torch.nn.grad.conv2d_weight(x2, layer.weight.shape, e1, stride=stride, padding=padding)
b_grad = -2 * e1.sum(dim=(0,2,3))

# torch.isclose(x_autograd, x_manualgrad)#, torch.isclose(w_grad, layer.weight.grad).all()
x2.grad, x_manualgrad, layer.weight.grad, w_grad, layer.bias.grad, b_grad

2


RuntimeError: could not create a descriptor for a dilated convolution forward propagation primitive

In [27]:
torch.manual_seed(42)
batch_size = 1
C0, H0, W0 = 1, 4, 4
C1, H1, W1 = 2, 2, 2
kernel_size = 3
stride = 1
padding = 1
upsample = 2

downsample = nn.MaxPool2d(kernel_size=2, stride=2)
layer = Conv2d(
    (C1, H1, W1),
    C0, 
    kernel_size,
    stride,
    padding, 
    upsample=upsample,
    has_bias=True
)

s0 = {
    'x': torch.randn((batch_size, C0, H0, W0)),
    'e': torch.randn((batch_size, C0, H0, W0))
}
s1 = {
    'x': torch.randn((batch_size, C1, H1, W1), requires_grad=True),
    'e': torch.randn((batch_size, C1, H1, W1))
}

s0['pred'] = layer.predict(s1)
s1['pred'] = torch.randn_like(s1['x'])
s1['pred'] = s1['x'].clone()


s0['e'] = s0['x'] - s0['pred']
s1['e'] = s1['x'] - s1['pred']

vfe = s0['e'].square().sum() + s1['e'].square().sum()
# vfe = s0['e'].square().sum()
# vfe = s1['e'].square().sum()

vfe.backward(retain_graph=True)

# print(downsample(s0['e']).shape)

s1x_grad = torch.nn.grad.conv2d_input(s1['x'].shape, layer.conv_td[0].weight, downsample(s0['e']), stride=stride, padding=padding)
print(s1['x'].grad, "\n", 2*s1x_grad)

# x_manualgrad = torch.nn.grad.conv2d_input(x2.shape, layer.weight, e1, stride=stride, padding=padding)
# x_manualgrad = -2*(x_manualgrad - e2)
# w_grad = -2 * torch.nn.grad.conv2d_weight(x2, layer.weight.shape, e1, stride=stride, padding=padding)
# b_grad = -2 * e1.sum(dim=(0,2,3))

# # torch.isclose(x_autograd, x_manualgrad)#, torch.isclose(w_grad, layer.weight.grad).all()
# x2.grad, x_manualgrad, layer.weight.grad, w_grad, layer.bias.grad, b_grad

tensor([[[[-1.1280, -0.4667],
          [-0.2554, -0.1569]],

         [[-1.0597, -1.1517],
          [-1.4969, -1.2904]]]]) 
 tensor([[[[1.3864, 0.3418],
          [0.6081, 1.2601]],

         [[0.8293, 1.1670],
          [1.1651, 1.0561]]]], grad_fn=<MulBackward0>)


In [37]:
upsample = nn.Upsample(scale_factor=2, mode='nearest')
downsample = nn.MaxPool2d(kernel_size=2, stride=2)
x = torch.randn((1, 1, 2, 2))
y = upsample(x)
x_hat = downsample(y)
x, y, x_hat

(tensor([[[[ 0.0366,  0.3688],
           [-0.4128, -0.2120]]]]),
 tensor([[[[ 0.0366,  0.0366,  0.3688,  0.3688],
           [ 0.0366,  0.0366,  0.3688,  0.3688],
           [-0.4128, -0.4128, -0.2120, -0.2120],
           [-0.4128, -0.4128, -0.2120, -0.2120]]]]),
 tensor([[[[ 0.0366,  0.3688],
           [-0.4128, -0.2120]]]]))