In [1]:
import os
import torch
import numpy as np
import os.path as op
from torch import nn
from time import time
from capsules import *
import skimage.io as io
from torch.optim import Adam
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable
from skimage.transform import resize
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

In [2]:
class CatsCaps(nn.Module):
    
    def __init__(self):
        super(CatsCaps, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=256,
                kernel_size=9,
                stride=1
            ),
            nn.ReLU(inplace=True)
        )
        self.primcaps = PrimaryCapsuleLayer()
        self.digicaps = SecondaryCapsuleLayer(n_capsules=2, n_routes=18432)
        self.decoder = RegularizingDecoder(dims=[32,1024,4096,12288])
        
    def forward(self, x):
        """Compute forward of capsules, get the longest vectors, reconstruct the pictures"""
        u = self.conv(x)
        u = self.primcaps(u)
        internal = self.digicaps(u)
        internal = internal.squeeze().transpose(1,0)
        #internal = self.digicaps(self.primcaps(self.conv(x)))
        lengths = F.softmax(
            (internal**2).sum(dim=-1)**0.5, dim=-1
        )
        _, max_caps_index = lengths.max(dim=-1)
        masked = Variable(torch.eye(2))
        masked = masked.cuda() if torch.cuda.is_available() else masked
        masked = masked.index_select(dim=0, index=max_caps_index)
        reconstruction = self.decoder(
            (internal*masked[:,:,None]).view(x.size(0), -1)
        )
        return(internal, reconstruction, lengths, max_caps_index)

In [3]:
cn = CatsCaps().cuda()
optimizer = Adam(cn.parameters())
capsule_loss = CapsuleLoss().cuda()



### Data preprocessing

Get the data from https://www.kaggle.com/c/dogs-vs-cats/data and store it somewhere

In [4]:
dpath = "/home/bakirillov/Documents/datasets/cats_vs_dogs/train/"

In [5]:
def load_file(fn):
    i = np.swapaxes(io.imread(fn)/255.0, 0, 2)
    label = 0 if op.split(fn)[-1].split(".")[0] == "cat" else 1
    return(i, label)

In [6]:
def iterate_minibatches(filenames, batchsize, im_shape):
    indices = np.random.permutation(np.arange(len(filenames)))
    for start in range(0, len(indices), batchsize):
        ix = indices[start: start + batchsize]
        il = [load_file(a) for a in np.array(filenames)[ix]]
        X = np.stack([resize(a[0], im_shape) for a in il])
        y = np.array([a[1] for a in il])
        yield X, y

In [7]:
filenames = [op.join(dpath, b) for b in [a for a in os.walk(dpath)][0][2]]

In [8]:
train_fn, test_fn, _, _ = train_test_split(filenames, filenames)

In [9]:
len(train_fn)

18750

In [10]:
len(test_fn)

6250

### Model training

In [13]:
training_loss = []
training_accuracy = []
testing_loss = []
testing_accuracy = []
start_time = time()
n_epochs = 10
batch_size = 10
for epoch in range(n_epochs):
    print("Epoch #"+str(epoch))
    cn.train()
    running_loss = []
    running_accuracy = []
    i = 0
    epoch_time = time()
    for batch_X, batch_y in iterate_minibatches(train_fn, batch_size, (3, 64, 64)):
        if i % batch_size == 0:
            print("Minibatch "+str(i)+" of "+str(round(len(train_fn)/batch_size)))
        i += 1
        optimizer.zero_grad()
        inp = Variable(
            torch.from_numpy(batch_X.reshape(batch_X.shape[0],3,64,64)/255.0).type(torch.FloatTensor)
        ).cuda()
        real_class = Variable(
            make_y(torch.from_numpy(batch_y).type(torch.LongTensor).cuda(), 2)
        )
        internal, reconstruction, classes, max_index = cn(inp)
        loss = capsule_loss(
            real_class, inp.view(inp.size(0), 3*64*64), classes, reconstruction
        )
        loss.backward()
        optimizer.step()
        running_loss.append(loss.cpu().data.numpy())
        running_accuracy.append(
            accuracy_score(max_index.cpu().data.numpy(), batch_y)
        )
    training_loss.append(np.mean(running_loss))
    training_accuracy.append(np.mean(running_accuracy))
    print("Training loss: "+str(training_loss[-1]))
    print("Training accuracy: "+str(training_accuracy[-1]))
    print("Training in "+str(epoch)+"th epoch took "+str((time() - epoch_time)/60) + " minutes")
    cn.eval()
    running_test_loss = []
    running_test_accuracy = []
    i = 0
    for batch_X, batch_y in iterate_minibatches(test_fn, batch_size, (3, 64, 64)):
        if i % batch_size == 0:
            print("Testing minibatch "+str(i)+" of "+str(round(len(test_fn)/batch_size)))
        i += 1
        test_X = Variable(
            torch.from_numpy(batch_X.reshape(batch_X.shape[0],3,64,64)/255.0).type(torch.FloatTensor)
        ).cuda()
        test_Y = Variable(
            make_y(torch.from_numpy(batch_y).type(torch.LongTensor).cuda(), 2)
        )
        test_internal, test_reconstruction, test_classes, test_ind = cn(test_X)
        #test_classes = test_classes.type(torch.cuda.FloatTensor)
        running_test_loss.append(
            capsule_loss(
                test_Y, test_X.reshape(test_X.size(0), 3*64*64), 
                test_classes, test_reconstruction
            ).cpu().data.numpy()
        )
        running_test_accuracy.append(
            accuracy_score(test_ind.cpu().data.numpy(), batch_y)
        )
    testing_loss.append(np.mean(running_test_loss))
    testing_accuracy.append(np.mean(running_test_accuracy))
    print("Testing loss: "+str(testing_loss[-1]))
    print("Testing accuracy: "+str(testing_accuracy[-1]))
    print(str(epoch)+"th epoch took "+str((time() - epoch_time)/60) + " minutes")
print("All epochs took "+str((time() - start_time)/60) + " minutes")

Epoch #0


  warn("The default mode, 'constant', will be changed to 'reflect' in "


Minibatch 0 of 1875
Minibatch 10 of 1875
Minibatch 20 of 1875
Minibatch 30 of 1875
Minibatch 40 of 1875
Minibatch 50 of 1875
Minibatch 60 of 1875
Minibatch 70 of 1875
Minibatch 80 of 1875
Minibatch 90 of 1875
Minibatch 100 of 1875
Minibatch 110 of 1875
Minibatch 120 of 1875
Minibatch 130 of 1875
Minibatch 140 of 1875
Minibatch 150 of 1875
Minibatch 160 of 1875
Minibatch 170 of 1875
Minibatch 180 of 1875
Minibatch 190 of 1875
Minibatch 200 of 1875
Minibatch 210 of 1875
Minibatch 220 of 1875
Minibatch 230 of 1875
Minibatch 240 of 1875
Minibatch 250 of 1875
Minibatch 260 of 1875
Minibatch 270 of 1875
Minibatch 280 of 1875
Minibatch 290 of 1875
Minibatch 300 of 1875
Minibatch 310 of 1875
Minibatch 320 of 1875
Minibatch 330 of 1875
Minibatch 340 of 1875
Minibatch 350 of 1875
Minibatch 360 of 1875
Minibatch 370 of 1875
Minibatch 380 of 1875
Minibatch 390 of 1875
Minibatch 400 of 1875
Minibatch 410 of 1875
Minibatch 420 of 1875
Minibatch 430 of 1875
Minibatch 440 of 1875
Minibatch 450 of 1875

Minibatch 890 of 1875
Minibatch 900 of 1875
Minibatch 910 of 1875
Minibatch 920 of 1875
Minibatch 930 of 1875
Minibatch 940 of 1875
Minibatch 950 of 1875
Minibatch 960 of 1875
Minibatch 970 of 1875
Minibatch 980 of 1875
Minibatch 990 of 1875
Minibatch 1000 of 1875
Minibatch 1010 of 1875
Minibatch 1020 of 1875
Minibatch 1030 of 1875
Minibatch 1040 of 1875
Minibatch 1050 of 1875
Minibatch 1060 of 1875
Minibatch 1070 of 1875
Minibatch 1080 of 1875
Minibatch 1090 of 1875
Minibatch 1100 of 1875
Minibatch 1110 of 1875
Minibatch 1120 of 1875
Minibatch 1130 of 1875
Minibatch 1140 of 1875
Minibatch 1150 of 1875
Minibatch 1160 of 1875
Minibatch 1170 of 1875
Minibatch 1180 of 1875
Minibatch 1190 of 1875
Minibatch 1200 of 1875
Minibatch 1210 of 1875
Minibatch 1220 of 1875
Minibatch 1230 of 1875
Minibatch 1240 of 1875
Minibatch 1250 of 1875
Minibatch 1260 of 1875
Minibatch 1270 of 1875
Minibatch 1280 of 1875
Minibatch 1290 of 1875
Minibatch 1300 of 1875
Minibatch 1310 of 1875
Minibatch 1320 of 1875

Minibatch 1750 of 1875
Minibatch 1760 of 1875
Minibatch 1770 of 1875
Minibatch 1780 of 1875
Minibatch 1790 of 1875
Minibatch 1800 of 1875
Minibatch 1810 of 1875
Minibatch 1820 of 1875
Minibatch 1830 of 1875
Minibatch 1840 of 1875
Minibatch 1850 of 1875
Minibatch 1860 of 1875
Minibatch 1870 of 1875
Training loss: 0.1872793
Training accuracy: 0.7086933333333334
Training in 2th epoch took 21.357480351130167 minutes
Testing minibatch 0 of 625
Testing minibatch 10 of 625
Testing minibatch 20 of 625
Testing minibatch 30 of 625
Testing minibatch 40 of 625
Testing minibatch 50 of 625
Testing minibatch 60 of 625
Testing minibatch 70 of 625
Testing minibatch 80 of 625
Testing minibatch 90 of 625
Testing minibatch 100 of 625
Testing minibatch 110 of 625
Testing minibatch 120 of 625
Testing minibatch 130 of 625
Testing minibatch 140 of 625
Testing minibatch 150 of 625
Testing minibatch 160 of 625
Testing minibatch 170 of 625
Testing minibatch 180 of 625
Testing minibatch 190 of 625
Testing minibat

Testing minibatch 540 of 625
Testing minibatch 550 of 625
Testing minibatch 560 of 625
Testing minibatch 570 of 625
Testing minibatch 580 of 625
Testing minibatch 590 of 625
Testing minibatch 600 of 625
Testing minibatch 610 of 625
Testing minibatch 620 of 625
Testing loss: 0.191315
Testing accuracy: 0.69376
3th epoch took 24.460268970330556 minutes
Epoch #4
Minibatch 0 of 1875
Minibatch 10 of 1875
Minibatch 20 of 1875
Minibatch 30 of 1875
Minibatch 40 of 1875
Minibatch 50 of 1875
Minibatch 60 of 1875
Minibatch 70 of 1875
Minibatch 80 of 1875
Minibatch 90 of 1875
Minibatch 100 of 1875
Minibatch 110 of 1875
Minibatch 120 of 1875
Minibatch 130 of 1875
Minibatch 140 of 1875
Minibatch 150 of 1875
Minibatch 160 of 1875
Minibatch 170 of 1875
Minibatch 180 of 1875
Minibatch 190 of 1875
Minibatch 200 of 1875
Minibatch 210 of 1875
Minibatch 220 of 1875
Minibatch 230 of 1875
Minibatch 240 of 1875
Minibatch 250 of 1875
Minibatch 260 of 1875
Minibatch 270 of 1875
Minibatch 280 of 1875
Minibatch 29

Minibatch 730 of 1875
Minibatch 740 of 1875
Minibatch 750 of 1875
Minibatch 760 of 1875
Minibatch 770 of 1875
Minibatch 780 of 1875
Minibatch 790 of 1875
Minibatch 800 of 1875
Minibatch 810 of 1875
Minibatch 820 of 1875
Minibatch 830 of 1875
Minibatch 840 of 1875
Minibatch 850 of 1875
Minibatch 860 of 1875
Minibatch 870 of 1875
Minibatch 880 of 1875
Minibatch 890 of 1875
Minibatch 900 of 1875
Minibatch 910 of 1875
Minibatch 920 of 1875
Minibatch 930 of 1875
Minibatch 940 of 1875
Minibatch 950 of 1875
Minibatch 960 of 1875
Minibatch 970 of 1875
Minibatch 980 of 1875
Minibatch 990 of 1875
Minibatch 1000 of 1875
Minibatch 1010 of 1875
Minibatch 1020 of 1875
Minibatch 1030 of 1875
Minibatch 1040 of 1875
Minibatch 1050 of 1875
Minibatch 1060 of 1875
Minibatch 1070 of 1875
Minibatch 1080 of 1875
Minibatch 1090 of 1875
Minibatch 1100 of 1875
Minibatch 1110 of 1875
Minibatch 1120 of 1875
Minibatch 1130 of 1875
Minibatch 1140 of 1875
Minibatch 1150 of 1875
Minibatch 1160 of 1875
Minibatch 1170 

Minibatch 1600 of 1875
Minibatch 1610 of 1875
Minibatch 1620 of 1875
Minibatch 1630 of 1875
Minibatch 1640 of 1875
Minibatch 1650 of 1875
Minibatch 1660 of 1875
Minibatch 1670 of 1875
Minibatch 1680 of 1875
Minibatch 1690 of 1875
Minibatch 1700 of 1875
Minibatch 1710 of 1875
Minibatch 1720 of 1875
Minibatch 1730 of 1875
Minibatch 1740 of 1875
Minibatch 1750 of 1875
Minibatch 1760 of 1875
Minibatch 1770 of 1875
Minibatch 1780 of 1875
Minibatch 1790 of 1875
Minibatch 1800 of 1875
Minibatch 1810 of 1875
Minibatch 1820 of 1875
Minibatch 1830 of 1875
Minibatch 1840 of 1875
Minibatch 1850 of 1875
Minibatch 1860 of 1875
Minibatch 1870 of 1875
Training loss: 0.17743951
Training accuracy: 0.7334933333333333
Training in 6th epoch took 21.256649498144785 minutes
Testing minibatch 0 of 625
Testing minibatch 10 of 625
Testing minibatch 20 of 625
Testing minibatch 30 of 625
Testing minibatch 40 of 625
Testing minibatch 50 of 625
Testing minibatch 60 of 625
Testing minibatch 70 of 625
Testing minibat

Testing minibatch 420 of 625
Testing minibatch 430 of 625
Testing minibatch 440 of 625
Testing minibatch 450 of 625
Testing minibatch 460 of 625
Testing minibatch 470 of 625
Testing minibatch 480 of 625
Testing minibatch 490 of 625
Testing minibatch 500 of 625
Testing minibatch 510 of 625
Testing minibatch 520 of 625
Testing minibatch 530 of 625
Testing minibatch 540 of 625
Testing minibatch 550 of 625
Testing minibatch 560 of 625
Testing minibatch 570 of 625
Testing minibatch 580 of 625
Testing minibatch 590 of 625
Testing minibatch 600 of 625
Testing minibatch 610 of 625
Testing minibatch 620 of 625
Testing loss: 0.18426318
Testing accuracy: 0.7124799999999999
7th epoch took 24.460166490077974 minutes
Epoch #8
Minibatch 0 of 1875
Minibatch 10 of 1875
Minibatch 20 of 1875
Minibatch 30 of 1875
Minibatch 40 of 1875
Minibatch 50 of 1875
Minibatch 60 of 1875
Minibatch 70 of 1875
Minibatch 80 of 1875
Minibatch 90 of 1875
Minibatch 100 of 1875
Minibatch 110 of 1875
Minibatch 120 of 1875
Min

Minibatch 560 of 1875
Minibatch 570 of 1875
Minibatch 580 of 1875
Minibatch 590 of 1875
Minibatch 600 of 1875
Minibatch 610 of 1875
Minibatch 620 of 1875
Minibatch 630 of 1875
Minibatch 640 of 1875
Minibatch 650 of 1875
Minibatch 660 of 1875
Minibatch 670 of 1875
Minibatch 680 of 1875
Minibatch 690 of 1875
Minibatch 700 of 1875
Minibatch 710 of 1875
Minibatch 720 of 1875
Minibatch 730 of 1875
Minibatch 740 of 1875
Minibatch 750 of 1875
Minibatch 760 of 1875
Minibatch 770 of 1875
Minibatch 780 of 1875
Minibatch 790 of 1875
Minibatch 800 of 1875
Minibatch 810 of 1875
Minibatch 820 of 1875
Minibatch 830 of 1875
Minibatch 840 of 1875
Minibatch 850 of 1875
Minibatch 860 of 1875
Minibatch 870 of 1875
Minibatch 880 of 1875
Minibatch 890 of 1875
Minibatch 900 of 1875
Minibatch 910 of 1875
Minibatch 920 of 1875
Minibatch 930 of 1875
Minibatch 940 of 1875
Minibatch 950 of 1875
Minibatch 960 of 1875
Minibatch 970 of 1875
Minibatch 980 of 1875
Minibatch 990 of 1875
Minibatch 1000 of 1875
Minibatch

In [14]:
torch.save(cn.state_dict(), "CatsVsDogs_capsule_20.ptch")