In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("./../..")

In [3]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as T
from torch import optim

import numpy as np

# local imports
from effcn.models_mnist import BaselineCNN
from misc.utils import count_parameters


#### Test capsulwise activation

In [221]:
a = torch.rand([1,1,28,28])
a.shape

torch.Size([1, 1, 28, 28])

In [222]:
conv =  nn.Conv2d(1, 32, kernel_size=(5, 5), padding="valid")

In [223]:
x = conv(a)
x.shape

torch.Size([1, 32, 24, 24])

In [224]:
conv.weight.shape

torch.Size([32, 1, 5, 5])

In [225]:
def squash_func(x, eps=10e-21):
    """
        IN:
            x (b, n, d)
        OUT:
            squash(x) (b, n, d)
    """
    x_norm = torch.norm(x, dim=2, keepdim=True)
    return (1 - 1 / (torch.exp(x_norm) + eps)) * (x / (x_norm + eps))

In [226]:
s = squash_func(x)
s.shape
s.min()

tensor(-0.4229, grad_fn=<MinBackward1>)

In [227]:
v = x.view(x.shape[0],x.shape[1],-1)
v = v.permute(0,2,1)
s = squash_func(v)
s.shape

torch.Size([1, 576, 32])

In [228]:
x_norm = torch.norm(v, dim=2, keepdim=True)
x_norm.shape

torch.Size([1, 576, 1])

In [229]:
eps=10e-21
k = (1 - 1 / (torch.exp(x_norm) + eps)) * (v / (x_norm + eps))
k.shape

torch.Size([1, 576, 32])

In [230]:
k.min()

tensor(-0.3995, grad_fn=<MinBackward1>)

In [231]:
t = k.permute(0,2,1)
t = t.view(x.shape)
t.shape

torch.Size([1, 32, 24, 24])

#### Ref Model

In [232]:
model = BaselineCNN()

In [233]:
model

BaselineCNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)

In [234]:
count_parameters(model)

28938

In [4]:
def squash_conv_func(v, eps=10e-21):
    """
        IN:
            x (b, c, h, w)
        OUT:
            squash(x) (b, c, h, w)
    """
    #shap to capsule squashing for tests
    x = v.view(v.shape[0],v.shape[1],-1)
    x = x.permute(0,2,1)
    
    x_norm = torch.norm(x, dim=2, keepdim=True)
    k = (1 - 1 / (torch.exp(x_norm) + eps)) * (x / (x_norm + eps))
    
    #reshape to comv
    t = k.permute(0,2,1)
    t = t.view(v.shape)
    
    return t

In [16]:
class Squash(nn.Module):
    def __init__(self, eps=1e-21):
        super().__init__()
        self.eps = eps

    def forward(self, v):
        """
            IN:
                x (b, c, h, w)
            OUT:
                squash(x) (b, c, h, w)
        """
        #shap to capsule squashing for tests
        x = v.view(v.shape[0],v.shape[1],-1)
        x = x.permute(0,2,1)

        x_norm = torch.norm(x, dim=2, keepdim=True)
        k = (1 - 1 / (torch.exp(x_norm) + self.eps)) * (x / (x_norm + self.eps))

        #reshape to comv
        t = k.permute(0,2,1)
        t = t.view(v.shape)

        return t

In [237]:
#test on normal conv
a = torch.rand([1,1,28,28])
conv =  nn.Conv2d(1, 32, kernel_size=(5, 5), padding="valid")
#conv =  nn.Conv2d(1, 32, kernel_size=(28, 28), groups=1, padding="valid")
x = conv(a)
squa = Squash()
s = squa(x)

s.shape

torch.Size([1, 32, 24, 24])

In [238]:
count_parameters(conv)

832

In [239]:
#test on deepwise conv conv
a = torch.rand([1,1,28,28])
conv =  nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(5, 5), padding="valid"),
            nn.Conv2d(32, 32, kernel_size=(24, 24), groups=32, padding="valid")
        )
x = conv(a)
squa = Squash()
s = squa(x)

s.shape

torch.Size([1, 32, 1, 1])

In [240]:
count_parameters(conv)

19296

In [294]:
class BaselineCNN(nn.Module):
    """
        Baseline CNN Model for MNIST
    """

    def __init__(self):
        super(BaselineCNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            #nn.AvgPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            #nn.AvgPool2d(kernel_size=2),
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
        #self.out = nn.Linear(32 * 28 * 28, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output

#### Test on Model

In [320]:
class SquashCNN(nn.Module):
    """
        Baseline CNN Model for MNIST
    """

    def __init__(self):
        super(SquashCNN, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            Squash(),
            nn.MaxPool2d(kernel_size=2),
            #nn.AvgPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            Squash(),
            nn.MaxPool2d(2),
            #nn.AvgPool2d(kernel_size=2),
        )
        # fully connected layer, output 10 classes
        self.out = nn.Linear(32 * 7 * 7, 10)
        #self.out = nn.Linear(32 * 28 * 28, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output

In [321]:
model = SquashCNN()
#model = BaselineCNN()
#model = EffBB()
#model = SquashEffBB()
count_parameters(model)

28938

In [322]:
if torch.cuda.is_available():  
    dev = "cuda" 
else:  
    dev = "cpu"  
device = torch.device(dev)
device

device(type='cuda')

In [323]:
    transform_train = T.Compose([
        T.RandomAffine(
            degrees=(-30, 30),
            shear=(-30, 30),
            # translate=(0.9, 0.9),
        ),
        T.RandomResizedCrop(
            28,
            scale=(0.8, 1.2),
            ratio=(1, 1),
        ),
        T.ToTensor()
    ])
    transform_valid = T.Compose([
        T.ToTensor()
    ])

In [324]:
ds_train = datasets.MNIST(root='../../data', train=False, download=True, transform= transform_train)
ds_valid = datasets.MNIST(root="../../data", train=False, download=True, transform=T.ToTensor())

In [325]:
dl_train = torch.utils.data.DataLoader(ds_train, 
                                          batch_size=256, 
                                          shuffle=True, 
                                          num_workers=4)
dl_valid = torch.utils.data.DataLoader(ds_valid, 
                                          batch_size=256, 
                                          shuffle=True, 
                                          num_workers=4)

In [326]:
#model = BaselineCNN()
#model = EffBB()
#model = SquashEffBB()
model = SquashCNN()
model = model.to(device)

In [327]:
loss_func = nn.CrossEntropyLoss() 
optimizer = optim.Adam(model.parameters(), lr = 0.01) 

In [328]:
num_epochs = 100
model.train()
for epoch in range(num_epochs):
    for idx, (x, y_true) in enumerate(dl_train):
        x = x.to(device)
        y_true = y_true.to(device)
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_func(y_pred, y_true)         
        loss.backward()
        optimizer.step()
        
        if idx % 1000 == 0:
            print("Epoch[{}/{}] - step {} loss: {:.4f}".format(epoch, num_epochs, idx, loss.item()))

Epoch[0/100] - step 0 loss: 2.3056
Epoch[1/100] - step 0 loss: 0.5747
Epoch[2/100] - step 0 loss: 0.3746
Epoch[3/100] - step 0 loss: 0.2459
Epoch[4/100] - step 0 loss: 0.2097
Epoch[5/100] - step 0 loss: 0.1655
Epoch[6/100] - step 0 loss: 0.1932
Epoch[7/100] - step 0 loss: 0.1764
Epoch[8/100] - step 0 loss: 0.1461
Epoch[9/100] - step 0 loss: 0.1336
Epoch[10/100] - step 0 loss: 0.2451
Epoch[11/100] - step 0 loss: 0.0933
Epoch[12/100] - step 0 loss: 0.1032
Epoch[13/100] - step 0 loss: 0.0817
Epoch[14/100] - step 0 loss: 0.1564
Epoch[15/100] - step 0 loss: 0.1193
Epoch[16/100] - step 0 loss: 0.1110
Epoch[17/100] - step 0 loss: 0.1258
Epoch[18/100] - step 0 loss: 0.0921
Epoch[19/100] - step 0 loss: 0.1140
Epoch[20/100] - step 0 loss: 0.0752
Epoch[21/100] - step 0 loss: 0.1031
Epoch[22/100] - step 0 loss: 0.0934
Epoch[23/100] - step 0 loss: 0.1492
Epoch[24/100] - step 0 loss: 0.2002
Epoch[25/100] - step 0 loss: 0.1290
Epoch[26/100] - step 0 loss: 0.1231
Epoch[27/100] - step 0 loss: 0.0909
Ep

In [329]:
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for x, y_true in dl_valid:
        x = x.to(device)
        y_true = y_true.to(device)
        y_pred = model(x)
        y_pred = torch.max(y_pred, 1)[1]
        correct += (y_pred == y_true).sum().item()
        total += y_true.shape[0]
    acc = correct / total

In [330]:
print(acc)
print(total - correct)

0.9958
42


Augmentation

Model: SquashCNN
epochs: 20

maxpool
Epoch[19/20] - step 0 loss: 0.0036
0.9887
113
params: 28938

-----------
epochs: 100

maxpool

acc: 0.9958
(total - correct): 42
params: 28938


----------------------------------------------------
Model: BaselineCNN
epochs: 20

maxpool
Epoch[19/20] - step 0 loss: 0.0002
acc:0.9837
(total - correct): 163
params: 28938

-----------
epochs: 100

maxpool
Epoch[99/100] - step 0 loss: 0.0604
acc:0.9963
(total - correct): 37
params: 28938

Model: SquashCNN
epochs: 20

maxpool
Epoch[19/20] - step 0 loss: 0.0036
acc:1.0
(total - correct): 0
params: 28938

nopool
Epoch[19/20] - step 0 loss: 0.0103
acc: 0.9964
(total - correct): 36
params: 264138


avgpool
Epoch[19/20] - step 0 loss: 0.0517
acc: 0.9919
(total - correct): 81
params: 28938

----------------------------------------------------
Model: BaselineCNN
epochs: 20

maxpool
Epoch[19/20] - step 0 loss: 0.0002
acc:0.999
(total - correct): 10
params: 28938

nopool
Epoch[19/20] - step 0 loss: 0.0006
acc: 0.9955
(total - correct): 45
params: 264138


avgpool
Epoch[19/20] - step 0 loss: 0.0038
acc: 0.9989
(total - correct): 11
params: 28938

----------------------------------------------------
Model: EFFBB
epochs: 20

nopool
Epoch[19/20] - step 0 loss: 0.0038
acc: 
(total - correct): 
params: 234378

maxpool
Epoch[19/20] - step 0 loss: 0.0001
acc:1.0
(total - correct): 0
params: 131402

avgpool
Epoch[19/20] - step 0 loss: 0.0356
acc:0.9974
(total - correct): 
params: 131402

----------------------------------------------------
Model: SquashEFFBB
epochs: 20

nopool
Epoch[19/20] - step 0 loss: 0.0583
acc: 0.9741
(total - correct): 259
params: 234378

maxpool
Epoch[19/20] - step 0 loss: 0.0022
acc:1.0
(total - correct): 0
params: 131402

avgpool
Epoch[19/20] - step 0 loss: 0.0356
acc:0.9855
(total - correct): 145
params: 131402


In [331]:
class EffBB(nn.Module):
    """
        Backbone model from Efficient-CapsNet for MNIST
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(5, 5), padding="valid"),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            #nn.MaxPool2d(2),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding="valid"),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.AvgPool2d(kernel_size=2),
            #nn.MaxPool2d(2),
            nn.Conv2d(64, 64, kernel_size=(3, 3), padding="valid"),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=2, padding="valid"),
            nn.ReLU(inplace=True),
            #nn.BatchNorm2d(128),
        )
        #self.out = nn.Linear(128 * 9 * 9, 10)
        self.out = nn.Linear(128, 10)
    def forward(self, x):
        """
            IN:
                x (b, 1, 28, 28)
            OUT:
                x (b, 128, 9, 9)
        """
        x = self.layers(x)
        x = x.view(x.size(0), -1)
        x = self.out(x)
        return x

In [332]:
class SquashEffBB(nn.Module):
    """
        Backbone model from Efficient-CapsNet for MNIST
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(5, 5), padding="valid"),
            Squash(),
            #nn.MaxPool2d(2),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding="valid"),
            Squash(),
            #nn.MaxPool2d(2),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(64, 64, kernel_size=(3, 3), padding="valid"),
            Squash(),
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=2, padding="valid"),
            Squash(),
        )
        #self.out = nn.Linear(128 * 9 * 9, 10)
        self.out = nn.Linear(128, 10)
    def forward(self, x):
        """
            IN:
                x (b, 1, 28, 28)
            OUT:
                x (b, 128, 9, 9)
        """
        x = self.layers(x)
        x = x.view(x.size(0), -1)
        x = self.out(x)
        return x

In [333]:
#test on normal conv
a = torch.rand([1,1,28,28])
mod  =  SquashEffBB()
#mod  =  EffBB()

x = mod(a)



x.shape

torch.Size([1, 10])