In [1]:
import torch
import torch.nn as nn
import numpy as np
import math
from scipy import linalg
import os
import sys


# Project root. The ROBUST_L0_ROOT environment variable overrides the default so the
# notebook is not tied to a single machine; without it the original location is used.
root = os.path.abspath(os.environ.get('ROBUST_L0_ROOT', '/home/mbeliaev/home/code/robust-l0/'))
# Guarded append: re-running this cell must not stack duplicate sys.path entries.
if root not in sys.path:
    sys.path.append(root)
# Use the first GPU when CUDA is usable, otherwise fall back to the CPU so the
# remaining cells (which only ever call `.to(device)`) still run without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

from utils.helpers import *
from utils.trunc import *
from utils.models import Net

In [7]:
# Independent-clip experiment: batch, channel and kernel geometry.
bs, in_ch, out_ch = 128, 1, 1
l, ker_dim = 28, 3
out_dim = l - ker_dim + 1  # side length left after sliding a ker_dim window over l
k = 5

# Draw the input batch first and the kernel second (same RNG draw order as before).
x = torch.randn(bs, in_ch, l, l).to(device)
weight = torch.randn(out_ch, in_ch, ker_dim, ker_dim).to(device)

print('input: ', x.shape)
print('weight: ', weight.shape)

input:  torch.Size([128, 1, 28, 28])
weight:  torch.Size([1, 1, 3, 3])


In [8]:
# Build the truncated convolution layer from the input shape, the number of output
# channels, the kernel size and k.
trunc = trunc_conv(x.shape, out_ch, ker_dim, k)
# Replace the layer's weight storage with the random kernel `weight`.
trunc.weight.data = weight
# Last expression of the cell: moves the layer to `device`; its return value is displayed.
trunc.to(device)

trunc_conv()

In [9]:
# Forward pass through the truncated convolution; only the output shape is displayed.
trunc(x).shape

torch.Size([128, 1, 28, 28])

In [4]:
# Gradient check for trunc_simple_clip: push a random batch through the layer,
# back-propagate the summed output and print the gradient stored on its weight.
in_features = 10
x = torch.randn(bs, in_features).to(device)
x.requires_grad_(True)
clip = trunc_simple_clip(in_features)
clip.to(device)

out = clip(x)
loss = out.sum()
loss.backward()

print(clip.weight.grad)

None


In [5]:
# Gradient check for trunc_simple: push a random batch through the layer,
# back-propagate the summed output and display the gradient stored on its weight.
in_features = 10
x = torch.randn(bs, in_features).to(device)
x.requires_grad_(True)
clip = trunc_simple(in_features)
clip.to(device)

out = clip(x)
loss = out.sum()
loss.backward()

# bare last expression so the notebook renders the gradient tensor
clip.weight.grad

tensor([ -7.1002,  -7.0575,   6.0792,  -0.6289, -10.7291,   0.0230,  10.3470,
          5.1034,   2.1088,   5.3068], device='cuda:0')

In [None]:
# Run `trunc_old` and `trunc` on the same batch and compare their outputs.
out_old = trunc_old(x)
out = trunc(x)
print(out.shape)
print(out_old.shape)

# Element-wise agreement within an absolute tolerance. Exact `==` on floating-point
# outputs flags harmless rounding differences between two implementations as
# mismatches; 1e-5 matches the threshold used by the other checks in this notebook.
torch.isclose(out, out_old, atol=1e-5)

In [None]:
# Check that the convolution written as a dense matrix product matches conv2d.

# Matrix form of the kernel for an [in_ch, l, l] input, moved to `device`.
flat_kern = toeplitz_mult_ch(weight.detach().cpu(), [in_ch,l,l])
flat_kern = torch.tensor(flat_kern).to(device)

out = torch.nn.functional.conv2d(x,weight)
# Flatten every sample and multiply by the transposed matrix (cast to float32 first).
out_t = x.flatten(1)@flat_kern.T.float()

# Compare the *magnitude* of the difference: without abs() every negative error,
# however large, is below the threshold and would be reported as a match.
(out - out_t.reshape(out.shape)).abs() < 0.00001

In [None]:
# Matrix form of the kernel, plus a boolean mask of its non-zero entries.
flat_kern = torch.tensor(toeplitz_mult_ch(weight.detach().cpu(), [in_ch, l, l])).to(device)
print(flat_kern.shape)
kern_mask = flat_kern != 0
# expected total of non-zero entries: (out_ch*out_dim**2) times (in_ch*ker_dim**2)
assert kern_mask.sum() == (out_ch*out_dim**2)*(in_ch*ker_dim**2)
# reciprocal of the non-zero count found in each column of the matrix
r_scale = 1/kern_mask.sum(dim=0)
print(r_scale.shape)



In [None]:
# conv2d yields one sum per window; flatten those per sample, then divide by the
# number of weights in a window (in_ch*ker_dim**2) to get a per-window average.
kern_avg = torch.nn.functional.conv2d(x, weight.detach()).reshape(bs, out_ch*out_dim**2)
kern_avg = kern_avg/(in_ch*ker_dim**2)
print(kern_avg.shape)

In [None]:
# Score every input position: broadcast each flattened sample against the kernel
# matrix, centre by the per-window average, mask, then reduce over axis 1.
r_vals = x.reshape(-1,1,in_ch*l**2)*flat_kern # element-wise product of each flattened sample with every matrix row
print(r_vals.shape)
r_vals -= kern_avg.reshape(-1,out_ch*out_dim**2,1) # subtract each row's average (broadcast along the last axis)
r_vals *= kern_mask # zero the entries where the kernel matrix itself is zero
# NOTE(review): r_scale is defined elsewhere in this notebook as 1/count, so dividing
# by it multiplies by the count instead of averaging — confirm this is intended.
r_vals = r_vals.sum(axis=1)/r_scale # reduce over axis 1, then divide by r_scale
print(r_vals.shape)

In [None]:
# Column indices of the k largest and the k smallest scores in every row of r_vals.
idx_top = torch.topk(r_vals, k).indices  # k indices per row of r_vals
print(idx_top.shape)
idx_bot = torch.topk(-r_vals, k).indices

In [None]:
# Keep-mask: ones everywhere except, per row, the positions listed in idx_top / idx_bot.
z = torch.ones_like(r_vals).float()
print(z.shape)
# scatter_ writes 0 at z[i, idx[i, j]] for every row i and every j
z.scatter_(1, idx_top, 0.0)
z.scatter_(1, idx_bot, 0.0)

# reshape the flat mask so it lines up with x
z = z.view(x.shape)
print(z.shape)

In [None]:
# Display the mask tensor built in the previous cell (0 marks a zeroed position).
z

In [None]:
# Instantiate the two layer implementations that the following cells compare.
# NOTE(review): trunc_conv is called here with three positional arguments
# (l, ker_dim, 5), while an earlier cell calls trunc_conv(x.shape, out_ch, ker_dim, k)
# — confirm which signature is current.
old_trunc = trunc_conv(l,ker_dim,5)
new_trunc = trunc_conv_new(l,ker_dim,5)

# Point the new layer at the old layer's weight and bias objects so both use the same parameters.
new_trunc.weight = old_trunc.weight
new_trunc.bias = old_trunc.bias

# Move both layers to `device`; the return value of the last call is displayed.
new_trunc.to(device)
old_trunc.to(device)

In [None]:
# Forward pass of the old layer on x (kept in its own cell, separate from the new layer's).
old_out = old_trunc(x)

In [None]:
# Forward pass of the new layer on the same x.
new_out = new_trunc(x)

In [None]:
# so far 4.8s and 1.3s (author's timing note; presumably old vs. new forward pass — confirm)
# Element-wise agreement within an absolute tolerance: exact `==` on floating-point
# outputs would flag harmless rounding differences between the two implementations.
torch.isclose(old_out, new_out, atol=1e-5)

In [None]:
# r_net= Net('cnn_small', 512, x.shape, 2, 'clip')
# r_net.to(device)

# net = Net('cnn_small', 512, x.shape, 2, 'conv')
# net.to(device)

In [None]:
# Loop-based version of the masking step: scores are accumulated in NumPy one
# image and one kernel-matrix row at a time, then used to mask x before conv2d.
x_vals = x.clone().detach().cpu().numpy()

# zero accumulator, one row per image with l**2 entries (squeeze drops the size-1 axis)
r_vals = np.zeros_like(x_vals.reshape(-1,1,l**2)).squeeze()
# print('r_vals: ',r_vals.shape)
# matrix form of the single-channel kernel weight[0,0] for an [l, l] input
flat_kern = toeplitz_1_ch(weight[0,0].detach().cpu(), [l,l])
# print('flat_kern:  ',flat_kern.shape)
# True where the kernel matrix is non-zero
kern_mask = flat_kern != 0
# expected total of non-zero entries: out_dim**2 rows times ker_dim**2 weights
assert kern_mask.sum().sum() == (out_dim**2)*(ker_dim**2)
# reciprocal of the non-zero count in each column
r_scale = 1/kern_mask.sum(axis=0)
# print('r_scale: ',r_scale.shape)

for im in range(r_vals.shape[0]):
    for i_kern in range(flat_kern.shape[0]):
        # average of this row's weighted pixels (channel 0 of the image only)
        kern_avg = np.dot(flat_kern[i_kern],x_vals[im,0].flatten())/(ker_dim**2)
        # print('kern_avg: ',kern_avg.shape)
        # print('kern_mask i_kern: ', kern_mask[i_kern])
        # centred per-pixel products, zeroed wherever this row of the matrix is zero
        temp = (flat_kern[i_kern]*x_vals[im,0].flatten()-kern_avg)*kern_mask[i_kern]
        # print('temp: ',temp.shape)
        # print('r_vals[im]: ', r_vals[im].shape)
        r_vals[im] += temp 
        # break
    # break
    # NOTE(review): r_scale is 1/count, so this division multiplies by the count
    # rather than averaging — confirm this is intended.
    r_vals[im] /= r_scale
# print(r_vals)

# back to torch: per row, indices of the k largest and k smallest scores
r_vals = torch.tensor(r_vals).to(device)
_, idx_top = torch.topk(r_vals,k) # k indices per row of r_vals
_, idx_bot = torch.topk(-1*r_vals,k)

# print(idx_top)
# build a 0/1 mask rather than editing x in place
z = torch.ones_like(r_vals)
# row i is paired with column idx[i, j]; the transpose makes the two index arrays broadcast
z[np.arange(r_vals.shape[0]),idx_top.T] = 0
z[np.arange(r_vals.shape[0]),idx_bot.T] = 0

# print(z)

# convolve the masked input with the original kernel
out = nn.functional.conv2d(z.view(x.shape)*x,weight.to(device))

In [None]:
# Vectorized NumPy version of the masking step; the result is stored in out_new.
x_vals = x.clone().detach().cpu().numpy()

# r_vals holds the scores, one row per image with l**2 entries
r_vals = np.zeros_like(x_vals.reshape(-1,1,l**2)).squeeze()
# print('r_vals: ',r_vals.shape)
# flat_kern is the convolution in matrix form for kernel weight[0,0] and an [l, l] input
flat_kern = toeplitz_1_ch(weight[0,0].detach().cpu(), [l,l])
# print('flat_kern:  ',flat_kern.shape)
# True where the kernel matrix is non-zero
kern_mask = flat_kern != 0
# print('kern_mask: ', kern_mask.shape)
# expected total of non-zero entries: out_dim**2 rows times ker_dim**2 weights
assert kern_mask.sum().sum() == (out_dim**2)*(ker_dim**2)
# r_scale is shared by all images: reciprocal of the non-zero count in each column
r_scale = 1/kern_mask.sum(axis=0)
# print('r_scale: ',r_scale.shape)

# the per-window average is the convolution output divided by the window size
kern_avg = nn.functional.conv2d(torch.tensor(x_vals),weight.detach().cpu()).numpy()
kern_avg = kern_avg.reshape(x_vals.shape[0],out_dim**2)/(ker_dim**2)
# print('kern_avg: ', kern_avg.shape)

# broadcast every flattened image against every matrix row, centre by the row's
# average and zero the entries where the matrix is zero
temp = x_vals.reshape(-1,1,l**2)*flat_kern
temp -= kern_avg.reshape(-1,out_dim**2,1)
temp *= kern_mask
# print('temp: ', temp.shape)
# sum over the matrix-row axis; the in-place += keeps r_vals' own dtype
r_vals += temp.sum(axis=1)
# NOTE(review): r_scale is 1/count, so this division multiplies by the count
# rather than averaging — confirm this is intended.
r_vals /= r_scale

# back to torch: per row, indices of the k largest and k smallest scores
r_vals = torch.tensor(r_vals).to(device)
_, idx_top = torch.topk(r_vals,k) # k indices per row of r_vals
_, idx_bot = torch.topk(-1*r_vals,k)

# print(idx_top)
# build a 0/1 mask rather than editing x in place
z = torch.ones_like(r_vals)
# row i is paired with column idx[i, j]; the transpose makes the two index arrays broadcast
z[np.arange(r_vals.shape[0]),idx_top.T] = 0
z[np.arange(r_vals.shape[0]),idx_bot.T] = 0

# print(z)

# convolve the masked input with the original kernel
out_new = nn.functional.conv2d(z.view(x.shape)*x,weight.to(device))

In [None]:
# x_vals = x.clone().detach().cpu().numpy()

# r_vals = np.zeros_like(x_vals.reshape(-1,1,l**2)).squeeze()
# print('r_vals: ',r_vals.shape)
# flat_kern = toeplitz_1_ch(weight[0,0].detach().cpu(), [l,l])
# print('flat_kern:  ',flat_kern.shape)
# kern_mask = flat_kern != 0
# assert kern_mask.sum().sum() == (out_dim**2)*(ker_dim**2)
# r_scale = 1/kern_mask.sum(axis=0)
# print('r_scale: ',r_scale.shape)

# for im in range(r_vals.shape[0]):
#     for i_kern in range(flat_kern.shape[0]):
#         kern_avg = np.dot(flat_kern[i_kern],x_vals[im,0].flatten())/(ker_dim**2)
#         print('kern_avg: ',kern_avg.shape, kern_avg)
#         print('kern_mask: ', kern_mask.shape)
#         print('kern_mask i_kern: ', kern_mask[i_kern].shape)
#         temp = (flat_kern[i_kern]*x_vals[im,0].flatten())
#         print('temp: ',temp.shape)
#         temp -= kern_avg
#         # print(temp)
#         temp *= kern_mask[i_kern] 
#         print('temp: ',temp.shape)
#         print('r_vals[im]: ', r_vals[im].shape)
#         r_vals[im] += temp 
#         break
#     break

In [None]:
# regular ind trunc
# NOTE(review): this cell uses `in_ch` as the per-sample feature count and multiplies
# by the global `weight`; it looks written for a 2-D x and a 2-D weight rather than the
# 4-D conv kernel defined at the top of the notebook — confirm before running.

print(x)
x_vals = x.clone().detach() #(bs, in_features)
print('x shape: ',x_vals.shape)
# broadcast x against weight, then sum over axis 1
temp = (x_vals.view(-1,1,in_ch)*weight).sum(axis=1) # (bs, out_features, in_features)
print('sum of x@W along input pixel dimension: ',temp.shape)
# x = torch.matmul(x,weight.T) #(bs, out_features)
# print('x@W: ',x.shape)
# print(temp.sum(axis=1))
# sum over last axis gives you x@W
# sum over first axis gives you largest contribution per in_feature
# for x in bs, out_features element wise prod. of w and x
# per row of temp: indices of the k largest and the k smallest entries
_, idx_top = torch.topk(temp,k) # k indices per row of temp
_, idx_bot = torch.topk(-1*temp,k)
print(idx_top)
print(idx_bot)
# mask of ones with one row per sample; the selected positions are zeroed row by row
z = torch.ones(bs,in_ch).to(device)
z[np.arange(bs),idx_top.T] = 0
z[np.arange(bs),idx_bot.T] = 0
print(z)
# print(x)
# NOTE: this rebinds x to the masked product, so re-running the cell starts from the
# previous result rather than from the original input.
x = torch.matmul(z*x,weight.T)
# print(z*x)
print(x)

# print(val_top.shape)

# x -= val_top.sum(axis=-1)
# x += val_bot.sum(axis=-1)


In [None]:
# Indexing experiment: `z[:, idx]` applies the index to every row, so each column
# that occurs anywhere in idx_top is zeroed for all rows. The row-paired form used
# in the other cells is z[np.arange(bs), idx_top.T].
z = torch.ones(bs, in_ch, device=device)
print(z)
print(idx_top)
z[:, idx_top.T] = 0
print(z)


In [None]:
# Index idx_top with an explicit array of row numbers. The original passed the
# function object `np.arange` itself as the index, which raises at runtime.
idx_top[np.arange(idx_top.shape[0])]

In [None]:
# Multi-channel setup used by the following cells to compare my_conv with nn.Conv2d.
w = h = 28
kernel_dim = 3
bs = 128
in_ch = out_ch = 3

# random input batch first, random kernel second (same RNG draw order as before)
image = torch.randn(bs, in_ch, w, h).to(device)
kernel = torch.randn(out_ch, in_ch, kernel_dim, kernel_dim).to(device)

In [None]:
# PyTorch's Conv2d next to the custom my_conv layer, built with the same arguments.
nn_conv = nn.Conv2d(in_ch, out_ch, kernel_dim, bias=True).to(device)
conv = my_conv(in_ch, out_ch, kernel_dim, bias=True).to(device)

# Point the custom layer at the Conv2d layer's parameter objects so both layers
# use the same weight and bias.
conv.bias = nn_conv.bias
conv.weight = nn_conv.weight

In [None]:
# Forward the same batch through the PyTorch layer and through the custom layer.
nn_out = nn_conv(image)
out = conv(image)

In [None]:
# Shapes of both outputs, then whether every element agrees within 1e-5.
print(nn_out.shape)
print(out.shape)
# Compare the magnitude of the difference: without abs() any negative error,
# however large, is below the threshold and would be counted as a match.
print(((nn_out - out).abs()<0.00001).sum() == np.prod(nn_out.shape))
# print(((out - my_out)))