In [13]:
import numpy as np
import tensorflow as tf
%matplotlib inline
train, test = tf.keras.datasets.cifar10.load_data()
x_train, y_train = train[0], train[1]
x_test, y_test = test[0], test[1]

class LinearLayer:
    
    def forward(self, x):
        self.input = x
        self.output = x.dot(self.weights)
        return self.output
        
    def update(self, dloss_dw, lr):
        self.weights = self.weights - lr*dloss_dw
    
    def __init__(self, layer):
        self.weights = layer
        
class ConvLayer:
    
    def convolve(self, x):
        f = self.filter
        fn, fd, fr, fc = f.shape
        xn, xd, xr, xc = x.shape
        x_pad = (fr - 1) // 2
        y_pad = (fc - 1) // 2
        x = np.pad(x, ((0, 0), (0, 0), (x_pad, x_pad), (y_pad, y_pad)))
        o = np.zeros((xn, fn, xr, xc))
        for r in range(xr):
            for c in range(xc):
                for n in range(fn):
                    fm = f[n]
                    xm = x[:, :, r:r+fr, c:c+fc]
                    res = fm * xm
                    o[:, n, r, c] += np.sum(np.sum(res, (2, 3)), 1)
        return o
    
    def maxpool(self, x):
        
        xn, d, xr, xc = x.shape
        o = np.zeros((xn, d, xr//2, xc//2))
        ix = np.zeros_like(x)
        for n in range(xn):
            for r in range(0, xr, 2):
                for c in range(0, xc, 2):
                    m = np.amax(x[n, :, r:r+2, c:c+2], (1, 2))
                    o[n, :, r//2, c//2] = m
                    
                    mm = x[n, :, r:r+2, c:c+2].reshape(d, 2*2)
                    im = np.zeros_like(mm)
                    
                    im[range(d), np.argmax(mm, 1)] = 1
                    im = im*(np.all(mm==0, 1)==0).reshape(-1, 1)
                    
                    ix[n, :, r:r+2, c:c+2] = im.reshape((d, 2, 2))
                    if n ==0 and r==5 and c==5:
                        print(ix[n, :, r:r+2, c:c+2])
        self.ixs.append(ix)
        
        return o
        
    
    def backpool(self, x):
        self.backpool_inputs.append(x)
        ix = self.ixs.pop()
        xn, d, xr, xc = x.shape
        result = np.zeros_like(ix)

        for n in range(xn):
            for r in range(xr):
                for c in range(xc):
                    xm = x[n, :, r, c].reshape(d, 1, 1)
                    fm = ix[n, :, r*2:r*2+2, c*2:c*2+2]

                    result[n, :, r*2:r*2+2, c*2:c*2+2] = xm*fm
        self.backpool_results.append(result)  
        return result * (self.relu_o>0)

    def do_df(self, f):
        
        x = self.input
    
        fn, fd, fr, fc  = f.shape
        xn, xd, xr, xc = x.shape
        x_pad = (self.filter.shape[2] - 1) // 2
        y_pad = (self.filter.shape[3] - 1) // 2
        x = np.pad(x, ((0, 0), (0, 0), (x_pad, x_pad), (y_pad, y_pad)))
        result = np.zeros_like(self.filter)
        
        for r in range(self.filter.shape[2]):
            for c in range(self.filter.shape[3]):

                xm = np.swapaxes(np.swapaxes(x[:, :, r:r+fr, c:c+fc], 0, 1).reshape((xd, -1)), 0, 1)
                fm = np.swapaxes(f, 0, 1).reshape((fd, -1))
                result[:, :, r, c] = fm @ xm
                    
        return result

    def do_di(self, f):
        
        x = np.rot90(self.filter, k=2, axes=(2, 3))
        
        fn, fd, fr, fc = f.shape 
        xn, xd, xr, xc = x.shape 
        x_pad = fr - xr + (xr - 1) //2
        y_pad = fc  - xc + (xc - 1) //2
        x = np.pad(x, ((0, 0), (0, 0), (x_pad, x_pad), (y_pad, y_pad)))
        result = np.zeros((fn, xd, fr, fc))
        for n in range(fn):
            for r in range(fr):
                for c in range(fc):
                    xm = np.swapaxes(x[:, :, r:r+fr, c:c+fc], 0, 1)
                    fm = f[n]
                    
                    prod = xm * fm
                    result[n, :, r, c] = np.sum(np.sum(prod, (2, 3)), 1)
                    
                    '''
                    xm = np.swapaxes(x[:, :, r:r + fr, c:c+fc], 0, 1)
                    res = xm * f[n]
                    result[n, :, r, c] = np.sum(np.sum(res, (2, 3)), (1))
                    '''
                    
        return np.rot90(result, k=2, axes=(2, 3))
    
    def forward(self, x):
        self.input= x
        conv_o = self.convolve(x)
        relu_o = np.maximum(conv_o, 0)
        self.relu_o = relu_o
        pool_o = self.maxpool(relu_o)
        self.output = pool_o
        return pool_o
    
    def backward(self, dloss_do):
        o = self.backpool(dloss_do)
        return self.do_df(o), self.do_di(o)
    
    def update(self, dloss_df, lr):
        self.filter = self.filter - lr*dloss_df
    
    def __init__(self, layer):
        self.backpool_results = []
        self.backpool_inputs = []
        self.filter = layer
        self.ixs = []

class CNN:

    def NLLLoss(self, lsm, targets):
        out = np.zeros_like(targets)
        for i in range(targets.shape[0]): out[i] = lsm[i][targets[i]]
        return -out.sum()/float(len(out))
    
    def log_softmax(self, x):
        c = x.max()
        logsumexp = np.log(np.exp(x - c).sum())
        return x - c - logsumexp
    
    def forward(self, x):
        self.input = x
        for l in self.conv_layers:
            x = l.forward(x)
        fcl_o = self.fcl.forward(x.reshape((1, -1)))
        lsm = self.log_softmax(fcl_o)
        return lsm
    
    def __init__(self, conv_layers, fcl_layer, learning_rate):
        
        self.conv_layers = conv_layers
        self.fcl = fcl_layer
        self.lr = learning_rate
        
    def eval(self, n):
        
        guesses, corr = [], []
        for i in range(n):
            
            samp = np.random.randint(0, 50000)
            shared_inp = np.asarray([x_train[samp]]).reshape((1, 3, 32, 32))
            shared_out = [y_train[samp].item()]
            
            
            p = self.forward(shared_inp)
            my_pred = np.argmax(np.exp(p))
            guesses.append(my_pred)
            corr.append(shared_out[0])
        return (np.asarray(guesses)==np.asarray(corr)).mean()
        
        
    def backward(self, p, y):
        
        loss = self.NLLLoss(p, y)
        out = np.zeros((y.shape[0],10))
        out[range(out.shape[0]),y] = 1
        d_out = -out / len(y)

        dloss_dlsmi = d_out - np.exp(p)*d_out.sum(axis=1).reshape((-1, 1))
        dloss_dfclo = dloss_dlsmi
        
        dloss_dfcli = self.fcl
        
        dfclo_dfcl = self.fcl.input
        dloss_dfcl = np.dot(dfclo_dfcl.reshape((-1, 1)), dloss_dfclo.reshape(1, -1))
        
        dfclo_dfcli = self.fcl.weights
        dloss_dfcli = np.dot(dfclo_dfcli, dloss_dfclo.reshape((-1, 1)))
        
        conv_layers = [l for l in self.conv_layers]
        
        dloss_do = dloss_dfcli.reshape((conv_layers[-1].output.shape))
        dloss_dfcli = dloss_do
        dloss_dfs = []
        dloss_dos = []
        
        for i in range(len(conv_layers)):
            layer = conv_layers.pop()
            dloss_df, dloss_do = layer.backward(dloss_do)
            dloss_dfs.append(dloss_df)
            dloss_dos.append(dloss_do)
            self.conv_layers[-(i+1)].update(dloss_df, self.lr)
        self.fcl.update(dloss_dfcl, self.lr)
        return dloss_dfclo, dloss_dfcli, dloss_dfcl, dloss_dfs, dloss_dos, loss

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    
    def fun(self, module,grad_in,grad_out):
        
        if grad_in[0] is not None: self.stored_gradients.append(grad_in[0])
        self.stored_gradients.append(grad_out[0])
    def __init__(self):
        super().__init__()
        self.stored_gradients = []
        self.ngpu=0
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels=16, kernel_size=5, stride=1, padding=2, bias=False)
        self.pool = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(in_channels = 16, out_channels=20, kernel_size=5, stride=1, padding=2, bias=False)
        self.conv3 = nn.Conv2d(in_channels = 20, out_channels=20, kernel_size=5, stride=1, padding=2, bias=False)
        self.fcl = nn.Linear(4 * 4 * 20, 10)
        self.lsm = nn.LogSoftmax(dim=1)
        self.conv1.register_full_backward_hook(self.fun)
        self.conv2.register_full_backward_hook(self.fun)
        self.conv3.register_full_backward_hook(self.fun)
        
    def forward(self, x):
        self.stored_gradients = []
        c1 = self.pool(F.relu(self.conv1(x)))
        c1.retain_grad()
        c2 = self.pool(F.relu(self.conv2(c1)))
        c2.retain_grad()
        c3 = self.pool(F.relu(self.conv3(c2)))
        c3.retain_grad()
        flat = torch.flatten(c3, 1) # flatten all dimensions except batch
        flat.retain_grad()
        fcl = self.fcl(flat)
        fcl.retain_grad()
        lsm = self.lsm(fcl)
        lsm.retain_grad()
        return lsm, fcl, flat, c3, c2, c1

In [19]:
torchmodel = Net()
criterion = nn.NLLLoss()
learning_rate = .00005
optimizer = optim.SGD(torchmodel.parameters(), lr=learning_rate)

c1 = ConvLayer(torchmodel.conv1.weight.detach().numpy().copy())
c2 = ConvLayer(torchmodel.conv2.weight.detach().numpy().copy())
c3 = ConvLayer(torchmodel.conv3.weight.detach().numpy().copy())
fcl = LinearLayer(np.swapaxes(torchmodel.fcl.weight.detach().numpy().copy(), 0, 1))
mymodel = CNN([c1, c2, c3], fcl, learning_rate)

for i in range(100):
    samp = np.random.randint(0, 50000)
    shared_inp = np.asarray([x_train[samp]]).reshape((1, 3, 32, 32))
    shared_out = [y_train[samp].item()]
    
    torch_lsm, torch_fcl, torch_flat, torch_c3, torch_c2, torch_c1 = torchmodel(torch.tensor(shared_inp, dtype=torch.float))
    torch_pred = np.argmax(np.exp(torch_lsm.detach().numpy().copy()))
    torch_loss = criterion(torch_lsm,  torch.tensor(shared_out))

    p = mymodel.forward(shared_inp)
    my_pred = np.argmax(np.exp(p))
    dloss_dfclo, dloss_dfcli, dloss_dfcl, dloss_dfs, dloss_dos, my_loss = mymodel.backward(p, np.asarray([shared_out]))
    torch_loss.backward()
    optimizer.step()
    print(torch_loss.item(), my_loss, shared_out, torch_pred, my_pred)

20.37076187133789 20.0 [9] 4 4
7.297242164611816 7.0 [3] 8 8
8.110270500183105 9.0 [7] 3 3
8.37322998046875 6.0 [8] 9 5
1.5587868690490723 2.0 [7] 9 8
15.295458793640137 8.0 [1] 7 5
6.854683876037598 3.0 [6] 7 7
10.73036003112793 7.0 [2] 7 6
3.931396245956421 0.0 [2] 7 2
8.44135570526123 3.0 [4] 7 2
5.3414998054504395 8.0 [0] 8 4
2.6941730976104736 3.0 [7] 8 0
1.7557027339935303 2.0 [2] 8 7
3.102165460586548 2.0 [5] 9 7
2.3847246170043945 4.0 [9] 2 7
3.6294443607330322 0.0 [9] 2 9
1.635027527809143 3.0 [2] 2 9
1.798170804977417 3.0 [4] 2 2
2.7127537727355957 3.0 [7] 9 4
3.3734943866729736 3.0 [5] 8 4
2.8225667476654053 0.0 [5] 9 5
2.2785072326660156 4.0 [9] 4 5
2.2637596130371094 9.0 [0] 2 9
3.2505240440368652 3.0 [7] 2 0
3.6455914974212646 0.0 [7] 2 7
2.6040854454040527 3.0 [5] 0 7
2.1916465759277344 4.0 [9] 2 5
1.4896893501281738 4.0 [2] 2 9
3.960935592651367 6.0 [1] 2 5
1.4328631162643433 4.0 [0] 0 9
2.4903674125671387 1.0 [1] 2 0
1.7320274114608765 2.0 [2] 2 1
1.8260191679000854 0.

In [20]:
mymodel.eval(50)

0.12

In [None]:
'''

shared_inp = x_train[0].reshape((1, 3, 32, 32))
shared_out = [5]
    
torchmodel = Net()
criterion = nn.NLLLoss()
optimizer = optim.SGD(torchmodel.parameters(), lr=0.001)
torch_lsm, torch_fcl, torch_flat, torch_c3, torch_c2, torch_c1 = torchmodel(torch.tensor(shared_inp, dtype=torch.float))
torch_loss = criterion(torch_lsm,  torch.tensor(shared_out))

c1 = ConvLayer(torchmodel.conv1.weight.detach().numpy().copy())
c2 = ConvLayer(torchmodel.conv2.weight.detach().numpy().copy())
c3 = ConvLayer(torchmodel.conv3.weight.detach().numpy().copy())
fcl = LinearLayer(np.swapaxes(torchmodel.fcl.weight.detach().numpy().copy(), 0, 1))
mymodel = CNN([c1, c2, c3], fcl)
targets = np.asarray([shared_out])
p = mymodel.forward(shared_inp)
dloss_dfclo, dloss_dfcli, dloss_dfcl, dloss_dfs, dloss_dos, myloss = mymodel.backward(p, np.asarray([shared_out]))

torch_loss.backward()

print(torch_loss.item(), myloss)

'''
'''

fcl_grad = torch_fcl.grad.detach().numpy().copy()
flat_grad = torch_flat.grad.detach().numpy().copy()
torch_c3_grad = torch_c3.grad.detach().numpy().copy()
torch_c2_grad = torch_c2.grad.detach().numpy().copy()
torch_c1_grad = torch_c1.grad.detach().numpy().copy()
print(fcl_grad.shape, flat_grad.shape, torch_c3_grad.shape, torch_c2_grad.shape, torch_c1_grad.shape)
for i in range(len(torchmodel.stored_gradients)):
    print(torchmodel.stored_gradients[i].shape)
torchprev3 = (torchmodel.conv3.weight.detach().numpy().copy())
torchprev2 = (torchmodel.conv2.weight.detach().numpy().copy())
torchprev1 = (torchmodel.conv1.weight.detach().numpy().copy())
optimizer.step()
torchpost3 = (torchmodel.conv3.weight.detach().numpy().copy())
torchpost2 = (torchmodel.conv2.weight.detach().numpy().copy())
torchpost1 = (torchmodel.conv1.weight.detach().numpy().copy())
grad3 = (torchprev3-torchpost3)*1000
grad2 = (torchprev2-torchpost2)*1000
grad1 = (torchprev1-torchpost1)*1000
print(grad3[1, 0])
print(dloss_dfs[0][1, 0])
print(grad1[-1, -1])
print(dloss_dfs[2][-1, -1])
print(np.round(np.asarray(torchmodel.stored_gradients[1])[0, 0, 0:8, 0:8], 3)*1000)
print(np.round(mymodel.conv_layers[-1].backpool_results[0][0, 0,  0:8, 0:8], 3)*1000)
'''