In [1]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F
import vnn
import torchvision
from imp import reload

In [2]:
#do non-vectorized version first

class Local2d(nn.Module):
    """Locally connected 2-D layer: like ``nn.Conv2d`` but with untied weights.

    Every output location ``(i, j)`` owns its own filter bank, so the weight has
    shape ``(h_out, w_out, out_channels, in_channels, kernel_size, kernel_size)``.
    This is the straightforward per-location loop implementation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: side length of the square receptive field.
        h_in, w_in: spatial size of the (unpadded) input. Because weights are
            tied to output positions, only inputs of exactly this size are accepted.
        stride: step between neighbouring receptive fields.
        padding: zero padding added on every side of the input.
        bias: if True, learn one bias per (out_channel, row, column); otherwise
            ``self.bias`` is registered as ``None``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, h_in, w_in, stride=1, padding=0, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.h_in = h_in
        self.w_in = w_in
        self.stride = stride
        self.padding = padding
        self.has_bias = bias
        # Standard convolution output-size formula.
        h_out = (h_in + 2 * padding - kernel_size) // stride + 1
        w_out = (w_in + 2 * padding - kernel_size) // stride + 1
        self.h_out = h_out
        self.w_out = w_out
        fan_in = in_channels * kernel_size ** 2
        self.weight = nn.Parameter(
            torch.randn(h_out, w_out, out_channels, in_channels, kernel_size, kernel_size) / np.sqrt(fan_in))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels, h_out, w_out))
        else:
            # Keep the attribute defined (as nn.Conv2d does) so `layer.bias` never raises.
            self.register_parameter("bias", None)
        if padding > 0:
            self.padder = nn.ZeroPad2d(padding)

    def forward(self, input):
        """Apply the layer.

        Args:
            input: tensor of shape (batch, in_channels, h_in, w_in).

        Returns:
            Tensor of shape (batch, out_channels, h_out, w_out) with the same
            dtype and device as ``input``.

        Raises:
            ValueError: if ``input`` does not have the configured shape (a
                larger input would otherwise be silently cropped).
        """
        expected = (self.in_channels, self.h_in, self.w_in)
        if input.dim() != 4 or tuple(input.shape[1:]) != expected:
            raise ValueError(
                f"Local2d expected input of shape (batch, {self.in_channels}, {self.h_in}, {self.w_in}), "
                f"got {tuple(input.shape)}")
        batch_size = input.shape[0]
        padded_input = self.padder(input) if self.padding > 0 else input
        # new_zeros keeps the input's dtype and device (torch.zeros would force float32).
        output = input.new_zeros(batch_size, self.out_channels, self.h_out, self.w_out)
        for i in range(self.h_out):
            for j in range(self.w_out):
                i1 = i * self.stride
                i2 = i1 + self.kernel_size
                j1 = j * self.stride
                j2 = j1 + self.kernel_size
                input_chunk = padded_input[:, :, i1:i2, j1:j2]
                # Filter bank specific to output location (i, j).
                weight_for_chunk = self.weight[i, j]
                output[:, :, i, j] = torch.einsum("oikl,bikl->bo", weight_for_chunk, input_chunk)
        if self.has_bias:
            output = output + self.bias[None, :, :, :]
        return output
                
        

In [3]:
# Scalar baseline: three locally connected stages (each halving the spatial
# size, 32 -> 16 -> 8 -> 4) with ReLUs, then a two-layer MLP head.
scalar_stages = [
    Local2d(3, 64, 5, 32, 32, 2, 2),
    Local2d(64, 128, 5, 16, 16, 2, 2),
    Local2d(128, 256, 3, 8, 8, 2, 1),
]
_last_stage = scalar_stages[-1]
scalar_flat_features = _last_stage.out_channels * _last_stage.h_out * _last_stage.w_out  # 256 * 4 * 4 = 4096

scalar_layers = []
for stage in scalar_stages:
    scalar_layers.extend([stage, nn.ReLU()])
scalar_layers.extend([
    nn.Flatten(),
    nn.Linear(scalar_flat_features, 1024),
    nn.ReLU(),
    nn.Linear(1024, 10),
])
model = nn.Sequential(*scalar_layers).to(0)



In [4]:
#CIFAR

# CIFAR-10 images: ToTensor gives [0, 1], Normalize((x - 0.5) / 0.5) maps to [-1, 1].
CIFAR_ROOT = "./data"
BATCH_SIZE = 128
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])
train_set, test_set = (
    torchvision.datasets.CIFAR10(CIFAR_ROOT, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Train the scalar locally connected baseline with Adam on CIFAR-10.
loss_fn = nn.CrossEntropyLoss(reduction="mean")
opt = optim.Adam(model.parameters(), lr=1e-3)
# Use the device the model actually lives on instead of hard-coding GPU 0.
device = next(model.parameters()).device
n_train = len(train_loader.dataset)  # 50000 for the CIFAR-10 training split
for epoch_idx in range(100):
    print(epoch_idx)
    model.train()
    epoch_loss = 0.
    epoch_correct = 0.
    for batch_idx, (data, labels) in enumerate(train_loader):
        if batch_idx % 50 == 0:
            print(batch_idx)
        out = model(data.to(device))
        loss = loss_fn(out, labels.to(device))
        epoch_loss += loss.item()
        epoch_correct += (out.argmax(dim=1).cpu() == labels).float().sum().item()
        opt.zero_grad()
        loss.backward()
        opt.step()
    # Mean per-batch loss and training accuracy for this epoch.
    print(epoch_loss / len(train_loader))
    print(epoch_correct / n_train)


0
0
50
100
150
200
250
300
350
1.4825032312241966
0.46608
1
0
50
100
150
200
250
300
350
1.137232693717303
0.59414
2
0
50
100
150
200
250
300
350
0.9318077039840581
0.6676
3
0
50
100
150
200
250
300
350
0.7528206577996159
0.73078
4
0
50
100
150
200
250
300
350
0.585695365200872
0.78874
5
0
50
100
150
200
250
300
350
0.43390200666302
0.84674
6
0
50
100
150
200
250
300
350
0.3167865096836749
0.88792
7
0
50
100
150
200
250
300
350
0.2354684372615936
0.91616
8
0
50
100
150
200
250
300
350
0.18606120422291939
0.93558
9
0
50
100
150
200
250
300


KeyboardInterrupt: 

In [9]:
# Test-set accuracy. No gradients are needed, so run under no_grad instead of
# building (and then detaching) an autograd graph for every batch.
device = next(model.parameters()).device
was_training = model.training
model.eval()
num_correct = 0
total = 0
with torch.no_grad():
    for data, labels in test_loader:
        out = model(data.to(device))
        num_correct += (out.argmax(dim=1).cpu() == labels).int().sum().item()
        total += len(data)
model.train(was_training)  # restore whatever mode the model was in before
acc = num_correct / total
acc

0.6586

In [4]:
def init_local(weight, first_layer=False, mono=False):
    """Re-initialise a Local2d-style weight tensor in place.

    Args:
        weight: tensor of shape (h_out, w_out, out_channels, in_channels, kh, kw),
            modified in place (call under ``torch.no_grad()`` when it is a leaf
            Parameter that requires grad).
        first_layer: if True use variance 1/fan_in, otherwise 2/fan_in.
        mono: if True delegate to ``init_local_mono`` instead.

    Returns:
        None (whatever ``init_local_mono`` returns when ``mono`` is True).
    """
    if mono:
        return init_local_mono(weight, first_layer)
    # Per-filter fan-in = in_channels * kh * kw (also correct for non-square kernels).
    fan_in = weight.shape[3:].numel()
    f = 1 if first_layer else 0.5
    # normal_ overwrites every element, so no prior zeroing is needed.
    weight.normal_(0., 1. / np.sqrt(f * fan_in))

def init_local_mono(weight, first_layer):
    """Sign-split ("mono") initialisation of a Local2d weight, in place.

    For non-first layers a Gaussian tensor W with half the output and half the
    input channels is drawn and spread over a 2x2 channel-parity pattern:
    (even out, even in) and (odd out, odd in) receive relu(W), while
    (even out, odd in) and (odd out, even in) receive relu(-W). Every entry is
    written, and the result is non-negative. First layers are delegated to
    ``init_local_mono_l0``.

    Args:
        weight: tensor of shape (h_out, w_out, out_channels, in_channels, kh, kw).
        first_layer: if True use ``init_local_mono_l0`` instead.

    Raises:
        ValueError: if out_channels or in_channels is odd (the parity split
            requires both to be even).
    """
    if first_layer:
        return init_local_mono_l0(weight)
    h_out, w_out, out_channels, in_channels = weight.shape[:4]
    kernel_shape = tuple(weight.shape[4:])
    if out_channels % 2 or in_channels % 2:
        raise ValueError(
            f"init_local_mono needs even out/in channels, got out={out_channels}, in={in_channels}")
    kernel_numel = weight.shape[4:].numel()
    W = torch.randn(h_out, w_out, out_channels // 2, in_channels // 2, *kernel_shape,
                    dtype=weight.dtype, device=weight.device) / np.sqrt(0.25 * in_channels * kernel_numel)
    positive, negative = F.relu(W), F.relu(-W)
    weight[:, :, ::2, ::2] = positive
    weight[:, :, ::2, 1::2] = negative
    weight[:, :, 1::2, 1::2] = positive
    weight[:, :, 1::2, ::2] = negative

def init_local_mono_l0(weight):
    """First-layer sign-split initialisation of a Local2d weight, in place.

    Draws a Gaussian W with half the output channels and writes W into the
    even output channels and -W into the odd ones, so consecutive output
    channels hold opposite-signed copies of the same filters.

    Args:
        weight: tensor of shape (h_out, w_out, out_channels, in_channels, kh, kw).

    Raises:
        ValueError: if out_channels is odd.
    """
    h_out, w_out, out_channels = weight.shape[:3]
    filter_shape_3d = tuple(weight.shape[3:])
    if out_channels % 2:
        raise ValueError(f"init_local_mono_l0 needs an even out_channels, got {out_channels}")
    # Per-filter fan-in = in_channels * kh * kw.
    fan_in = weight.shape[3:].numel()
    W = torch.randn((h_out, w_out, out_channels // 2) + filter_shape_3d,
                    dtype=weight.dtype, device=weight.device) / np.sqrt(fan_in)
    weight[:, :, ::2] = W
    weight[:, :, 1::2] = -W

class VecLocal2d(nn.Module):
    """Locally connected layer applied to every category slice of the input.

    A single bias-free ``Local2d`` is shared across the category axis (the
    category axis is folded into the batch axis before calling it), and a
    per-category bias of shape (category_dim, out_channels, h_out, w_out) is added.

    Args:
        category_dim: size of the category axis (dim 1 of the input).
        in_channels, out_channels, kernel_size, h_in, w_in, stride, padding:
            forwarded to ``Local2d``.
        mono: use the sign-split ("mono") initialisation; for non-first layers
            ``post_step_callback`` also keeps the weights non-negative.
        first_layer: selects the first-layer initialisation; the weights are
            additionally scaled by sqrt(category_dim).
        device: device on which the parameters are created.
    """

    def __init__(self, category_dim, in_channels, out_channels, kernel_size, h_in, w_in,
                 stride=1, padding=0,
                 mono=False, first_layer=False, device="cpu"):
        super().__init__()
        self.category_dim = category_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.h_in = h_in
        self.w_in = w_in
        self.mono = mono
        self.first_layer = first_layer
        self.device = device
        self.stride = stride
        self.padding = padding
        self.lc = Local2d(in_channels, out_channels, kernel_size, h_in, w_in,
                          stride=stride, padding=padding, bias=False).to(device)
        h_out, w_out = self.lc.h_out, self.lc.w_out
        self.bias = nn.Parameter(torch.zeros(category_dim, out_channels, h_out, w_out, device=device))
        with torch.no_grad():
            init_local(self.lc.weight, first_layer=first_layer, mono=mono)
            if first_layer:
                self.lc.weight *= np.sqrt(category_dim)

    @property
    def weight(self):
        """The shared Local2d weight tensor."""
        return self.lc.weight

    def forward(self, input):
        """Apply the shared Local2d to every (sample, category) slice.

        Args:
            input: tensor of shape (batch, category_dim, in_channels, h_in, w_in).

        Returns:
            Tensor of shape (batch, category_dim, out_channels, h_out, w_out).

        Side effects: stores ``input.detach()`` in ``self.input`` and
        (batch, out_channels, h_out, w_out) in ``self.mask_shape``. These are
        presumably read by code outside this class (e.g. vnn); confirm before
        removing them.
        """
        self.input = input.detach()
        batch_size, category_dim = input.shape[:2]
        # Fold categories into the batch axis; reshape (unlike view) also
        # accepts non-contiguous inputs.
        flat_input = input.reshape((batch_size * category_dim,) + tuple(input.shape[2:]))
        flat_output = self.lc(flat_input)
        output = flat_output.reshape((batch_size, category_dim) + tuple(flat_output.shape[1:]))
        output = output + self.bias[None, :, :, :, :]
        self.mask_shape = (output.shape[0],) + output.shape[2:]
        return output

    def post_step_callback(self):
        """Project the weights back to non-negative values after an optimizer step.

        Only active for mono, non-first layers. abs() reflects negative weights
        rather than clamping them to zero.
        """
        if self.mono and (not self.first_layer):
            with torch.no_grad():
                self.lc.weight.abs_()


In [5]:
mono = False
# One category per CIFAR-10 class: the model's per-category outputs are used as 10-way logits.
n_categories = 10
# (in_channels, out_channels, kernel_size, input side length, stride, padding) per stage.
vec_stage_specs = [
    (30, 64, 5, 32, 2, 2),
    (64, 128, 5, 16, 2, 2),
    (128, 256, 3, 8, 2, 1),
]
vec_layers = []
for stage_idx, (c_in, c_out, k, side, s, p) in enumerate(vec_stage_specs):
    vec_layers.append(VecLocal2d(n_categories, c_in, c_out, k, side, side, s, p,
                                 first_layer=(stage_idx == 0), mono=mono))
    vec_layers.append(vnn.ctReLU())
vec_layers.append(vnn.Flatten())
vec_layers.append(vnn.Linear(n_categories, 4096, 1024, mono=mono))
vec_layers.append(vnn.tReLU())
vec_layers.append(vnn.Linear(n_categories, 1024, 1, mono=mono))
model = nn.Sequential(*vec_layers).to(0)



In [6]:
# Train the vector-valued model: inputs are expanded over 10 categories and the
# single output per category is used as that class's logit.
loss_fn = nn.CrossEntropyLoss(reduction="mean")
opt = optim.Adam(model.parameters(), lr=1e-3)
# Use the device the model actually lives on instead of hard-coding GPU 0.
device = next(model.parameters()).device
n_train = len(train_loader.dataset)  # 50000 for the CIFAR-10 training split
for epoch_idx in range(100):
    print(epoch_idx)
    model.train()
    epoch_loss = 0.
    epoch_correct = 0.
    for batch_idx, (data, labels) in enumerate(train_loader):
        if batch_idx % 25 == 0:
            print(batch_idx)
        # NOTE: `input` shadows the builtin but is kept because later cells read it.
        input = vnn.expand_input_conv(data, 10).to(device)
        out = model(input)[..., 0]
        loss = loss_fn(out, labels.to(device))
        epoch_loss += loss.item()
        epoch_correct += (out.argmax(dim=1).cpu() == labels).float().sum().item()
        opt.zero_grad()
        loss.backward()
        opt.step()
        if mono:
            # Re-impose the mono weight constraints (e.g. VecLocal2d keeps its
            # non-first-layer weights non-negative); a no-op when mono is False.
            # NOTE(review): assumes every module's post_step_callback is meant to run after each step.
            for module in model.modules():
                callback = getattr(module, "post_step_callback", None)
                if callback is not None:
                    callback()
    # Mean per-batch loss and training accuracy for this epoch.
    print(epoch_loss / len(train_loader))
    print(epoch_correct / n_train)


0
0
Instantiated t with shape (10, 64, 16, 16)
Instantiated t with shape (10, 128, 8, 8)
Instantiated t with shape (10, 256, 4, 4)
Instantiated t with shape (10, 1024)
25
50
75
100
125
150
175
200
225
250
275
300
325
350
375
1.53617662389565
0.46854
1
0
25
50
75
100
125
150
175
200
225
250
275
300
325
350
375
1.1691053960939197
0.59284
2
0
25
50
75
100
125
150
175
200
225
250
275
300
325
350
375
0.9859638319295996
0.65764
3
0
25
50
75
100
125
150
175
200
225
250
275
300
325
350
375
0.8380102280460661
0.7064
4
0
25
50
75
100
125
150
175
200
225
250
275
300
325
350
375
0.7220104371800142
0.74846
5
0
25
50
75
100
125
150
175
200
225
250
275
300
325
350
375
0.6066160157818319
0.7893
6
0
25
50
75
100
125
150
175
200
225
250
275
300
325
350
375
0.5325303315506567
0.81552
7
0
25
50
75
100
125
150
175
200
225
250
275
300
325
350
375
0.4619803993826937
0.84024
8
0
25
50
75
100
125
150
175
200
225
250
275
300
325
350
375
0.41709781851609956
0.85584
9
0
25
50
75
100
125
150


KeyboardInterrupt: 

In [19]:
# Diagnostic: spread of activations after the first three modules of the model
# (VecLocal2d -> ctReLU -> VecLocal2d), presumably on the last training batch `input`.
first_three_modules = model[:3]
first_three_modules(input).std()

Instantiated t with shape (10, 64, 16, 16)


tensor(0.8995, grad_fn=<StdBackward0>)

In [13]:
# Diagnostic: standard deviation of the most recently computed logits `out`.
torch.std(out)

tensor(0.8682, device='cuda:0', grad_fn=<StdBackward0>)