In [None]:
import os
import sys
if 'google.colab' in sys.modules:
    import gdown
    if 'torch' not in sys.modules:
        !pip3 install torch==1.2.0+cu92 torchvision==0.4.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
    if 'skia-python' not in sys.modules:
        !pip3 install skia-python

    if os.getcwd() != '/content/DeepGraphemics':
        !git clone https://github.com/bensapirstein/DeepGraphemics.git
        %cd DeepGraphemics/

    url = 'https://drive.google.com/drive/folders/1X3ERUGyhMZo_ZlHApI1XkjcZAVcnTRNd?usp=drive_link'

    gdown.download_folder(url)

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Unpack each dataset variant (one zip archive per category) into datasets/.
# Uses the stdlib zipfile module so the notebook does not depend on an `unzip` binary.
import zipfile

categories = ["base", "moderate", "rotation", "rich_moderate", "rich_rotation"]


def unzip_datasets(categories, data_dir="datasets"):
    """Extract ``{data_dir}/{category}_dataset.zip`` into ``data_dir`` for each category.

    Categories whose ``{category}_dataset`` directory already exists are skipped.
    A missing archive is reported and skipped instead of aborting the loop.

    Args:
        categories: iterable of dataset-variant names, e.g. "base".
        data_dir: directory holding the archives and receiving the extracted folders.

    Returns:
        list of the categories that were actually extracted by this call.
    """
    extracted = []
    for ds_type in categories:
        root_dir = os.path.join(data_dir, f"{ds_type}_dataset")
        if os.path.exists(root_dir):
            continue
        zipped_data = f"{root_dir}.zip"
        if not os.path.exists(zipped_data):
            print(f"Skipping {ds_type}: {zipped_data} not found")
            continue
        with zipfile.ZipFile(zipped_data) as archive:
            archive.extractall(data_dir)
        extracted.append(ds_type)
    return extracted


extracted_categories = unzip_datasets(categories)

In [None]:
import numpy as np
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.optim import lr_scheduler
import torch.nn.functional as F
from src.capsnet import CapsNet, ReconstructionNet, CapsNetWithReconstruction, MarginLoss

# Training configuration.
batch_size = 128
test_batch_size = 1000      # NOTE(review): the test DataLoader below is built with batch_size, not this
epochs = 9
save_every = 3              # save a checkpoint every N epochs
lr = 0.001                  # Adam learning rate
no_cuda = False             # set True to force CPU even when a GPU is available
seed = 42
log_interval = 10           # print the training loss every N batches
routing_iterations = 3      # passed to CapsNet
with_reconstruction = True  # wrap CapsNet with ReconstructionNet and add the reconstruction loss
n_classes = 13              # number of classes, passed to CapsNet and ReconstructionNet

cuda = not no_cuda and torch.cuda.is_available()

# Seed every RNG the notebook draws from: torch (weight init, DataLoader shuffling)
# and numpy (random image sampling in the visualization section).
torch.manual_seed(seed)
np.random.seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)

# Extra DataLoader options for GPU runs (one worker process, pinned host memory).
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

In [None]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from src.datasets import GraphemesDataset

# Images are only converted to tensors; no normalization or augmentation.
img_transform = transforms.Compose([
    transforms.ToTensor()
])

ds_type = "moderate"   # which unzipped dataset variant to use
letter = "aleph"       # forwarded to GraphemesDataset as by_letter
root_dir = f"datasets/{ds_type}_dataset"

# `kwargs` (num_workers / pin_memory when CUDA is available) comes from the config cell.
train_dataset = GraphemesDataset(root_dir, train=True, by_letter=letter, transform=img_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **kwargs)

# Evaluation aggregates over the whole set, so shuffling is unnecessary; a fixed
# order also avoids consuming torch RNG state on every test pass.
test_dataset = GraphemesDataset(root_dir, train=False, by_letter=letter, transform=img_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **kwargs)

In [None]:
import src.plot as plot

# Preview the first 64 training samples on an 8-row grid; show_classes /
# class_labels presumably annotate each tile with its class name.
preview_count = 64
fig, axes = plot.dataset_first_n(
    train_dataset,
    preview_count,
    show_classes=True,
    class_labels=train_dataset.classes,
    nrows=8,
    hspace=0.5,
    cmap='gray',
)

In [None]:


# Capsule classifier, optionally wrapped together with a decoder network.
capsnet = CapsNet(routing_iterations, n_classes)

if with_reconstruction:
    # 16 = length of each DigitCaps vector fed to the decoder.
    reconstruction_model = ReconstructionNet(16, n_classes)
    # Weight of the reconstruction (MSE) term relative to the margin loss.
    reconstruction_alpha = 0.5
    model = CapsNetWithReconstruction(capsnet, reconstruction_model)
else:
    model = capsnet

if cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=lr)

# NOTE(review): patience=15 exceeds epochs=9, so this scheduler never lowers the LR
# in the current configuration.
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer,
    verbose=True,
    patience=15,
    min_lr=1e-6,
)

# Positional arguments are presumably (m_plus, m_minus, lambda) of the CapsNet
# margin loss — confirm against src.capsnet.MarginLoss.
loss_fn = MarginLoss(0.9, 0.1, 0.5)


def train(epoch):
    """Run one training epoch over ``train_loader``.

    Reads the module-level ``model``, ``optimizer``, ``loss_fn`` and config flags.
    With reconstruction enabled the loss is
    ``reconstruction_alpha * MSE(reconstruction, flattened input) + margin loss``;
    otherwise it is the margin loss alone. Prints progress every
    ``log_interval`` batches.

    Args:
        epoch: epoch number, used only in the log line.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if cuda:
            data, target = data.cuda(), target.cuda()
        # Plain tensors take part in autograd directly; Variable wrappers are not needed.
        optimizer.zero_grad()
        if with_reconstruction:
            output, probs = model(data, target)
            # Flatten each sample to one row. Assumes the decoder emits as many
            # values per sample as the input image has pixels (784 for 1x28x28).
            reconstruction_loss = F.mse_loss(output, data.view(data.size(0), -1))
            margin_loss = loss_fn(probs, target)
            loss = reconstruction_alpha * reconstruction_loss + margin_loss
        else:
            output, probs = model(data)
            loss = loss_fn(probs, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            # Use the loader's batch size: len(data) is smaller on a final partial batch.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * train_loader.batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test():
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    The loss per sample is the summed margin loss, plus
    ``reconstruction_alpha`` times the summed squared reconstruction error when
    reconstruction is enabled.

    Returns:
        float: total test loss divided by the number of test samples.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    # Variable(..., volatile=True) is ignored by modern PyTorch; no_grad() is what
    # actually stops autograd from building graphs during evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            if cuda:
                data, target = data.cuda(), target.cuda()

            if with_reconstruction:
                output, probs = model(data, target)
                reconstruction_loss = F.mse_loss(
                    output, data.view(data.size(0), -1), reduction='sum').item()
                test_loss += loss_fn(probs, target, size_average=False).item()
                test_loss += reconstruction_alpha * reconstruction_loss
            else:
                output, probs = model(data)
                test_loss += loss_fn(probs, target, size_average=False).item()

            pred = probs.max(1, keepdim=True)[1]  # index of the max probability
            # .item() keeps `correct` a Python int: it prints as a number and the
            # percentage below uses float division independent of tensor dtype rules.
            correct += pred.eq(target.view_as(pred)).sum().item()

    n_samples = len(test_loader.dataset)
    test_loss /= n_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, n_samples,
        100. * correct / n_samples))
    return test_loss


In [None]:
dst_dir = "pretrained_models/"
# torch.save does not create missing directories.
os.makedirs(dst_dir, exist_ok=True)

for epoch in range(1, epochs + 1):
    train(epoch)
    test_loss = test()
    # ReduceLROnPlateau monitors the test loss.
    scheduler.step(test_loss)
    if epoch % save_every == 0:
        checkpoint_name = '{:03d}_model_dict_{}routing_reconstruction{}.pth'.format(
            epoch, routing_iterations, with_reconstruction)
        torch.save(model.state_dict(), dst_dir + checkpoint_name)

# Helper functions

In [None]:
# (1x28x28 tensor input)
def get_digit_caps(model, image):
    """Run one image through ``model.capsnet`` and return its capsule output.

    Args:
        model: module exposing a ``.capsnet`` submodule that returns
            ``(digit_caps, probs)``.
        image: a single unbatched image tensor (presumably 1x28x28); a batch
            dimension of size 1 is added here.

    Returns:
        The ``digit_caps`` output for the batch of one, computed without autograd.
    """
    # Run on the model's device so a CUDA-resident model works with CPU dataset tensors.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else image.device
    with torch.no_grad():
        digit_caps, _probs = model.capsnet(image.unsqueeze(0).to(device))
    return digit_caps

# takes digit_caps output and target label
def get_reconstruction(model, digit_caps, label, image_shape=(28, 28)):
    """Decode ``digit_caps`` for class ``label`` into a 2-D numpy image.

    Args:
        model: module exposing ``.reconstruction_net(digit_caps, target)``.
        digit_caps: capsule output for a batch of one sample.
        label: class index (int or 0-dim integer tensor) passed to the decoder as target.
        image_shape: shape of the returned array; defaults to 28x28.

    Returns:
        numpy array of shape ``image_shape`` holding the first (only) sample.
    """
    # Build the target on the capsules' device so a CUDA model does not hit a device mismatch.
    target = torch.tensor([int(label)], dtype=torch.long, device=digit_caps.device)
    with torch.no_grad():
        reconstruction = model.reconstruction_net(digit_caps, target)
    return reconstruction.cpu().numpy()[0].reshape(image_shape)

# create reconstructions with perturbed digit capsule
def dimension_perturbation_reconstructions(model, digit_caps, label, dimension, dim_values):
    """Decode ``label``'s capsule once per value in ``dim_values``.

    For each value, a copy of ``digit_caps`` has entry ``[0, label, dimension]``
    set to that value (overwritten, not offset) and is passed to
    ``get_reconstruction``. The input tensor itself is never modified.

    Returns:
        list of reconstructions, one per entry of ``dim_values``, in order.
    """
    def _reconstruct_with(value):
        tweaked_caps = digit_caps.clone()
        tweaked_caps[0, label, dimension] = value
        return get_reconstruction(model, tweaked_caps, label)

    return [_reconstruct_with(value) for value in dim_values]

# Visualizations

In [None]:
# Inline plotting is already enabled in the setup cell; only pyplot is needed here.
import matplotlib.pyplot as plt

## Sample reconstructions

In [None]:
# test_size=0 presumably keeps every sample of the letter in one split — confirm in GraphemesDataset.
dataset = GraphemesDataset(root_dir, test_size=0, by_letter=letter, transform=img_transform)

n_samples = 8  # number of original/reconstruction pairs plotted in the next cell

# Set evaluation mode explicitly so this cell does not rely on test() having run last.
model.eval()

# Draw distinct images when the dataset is large enough (randint could repeat one).
idx = np.random.choice(len(dataset), size=n_samples, replace=len(dataset) < n_samples)

images = []
reconstructions = []
for i in idx:
    image_tensor, label = dataset[i]
    digit_caps = get_digit_caps(model, image_tensor)
    reconstruction = get_reconstruction(model, digit_caps, label)
    images.append(image_tensor.numpy()[0])  # first channel as a 2-D array
    reconstructions.append(reconstruction)

In [None]:
# Plot reconstructions: original images on the top row, decoder output below.
fig, axs = plt.subplots(2, 8, figsize=(16, 4))
for row, row_title in enumerate(('Org image', 'Reconstruction')):
    axs[row, 0].set_ylabel(row_title, size='large')

for col in range(8):
    for row, picture in ((0, images[col]), (1, reconstructions[col])):
        ax = axs[row, col]
        ax.imshow(picture, cmap='gray')
        ax.set_yticks([])
        ax.set_xticks([])

## What the individual dimensions of a capsule represent

We can visualize what an individual dimension of a capsule represents by perturbing the value of each dimension (Sec. 5.1 and Figure 4 of Sabour et al., 2017, *Dynamic Routing Between Capsules*).
Each row shows the reconstruction when one of the 16 dimensions in the DigitCaps representation is set to values from −0.25 to 0.25 in steps of 0.05. The code below overwrites the dimension with each value rather than adding it as an offset.

In [None]:
# Perturb every capsule dimension of the first dataset image across a fixed value grid.
digit, label = dataset[0]
perturbation_values = [0.05*i for i in range(-5, 6)]
n_capsule_dims = 16  # length of a DigitCaps vector

digit_caps = get_digit_caps(model, digit)
perturbed_reconstructions = [
    dimension_perturbation_reconstructions(model, digit_caps, label,
                                           dim, perturbation_values)
    for dim in range(n_capsule_dims)
]

In [None]:
# One row per capsule dimension, one column per perturbation value; the grid size
# follows the computed data instead of a hard-coded 16 x 11.
n_dims = len(perturbed_reconstructions)
n_values = len(perturbed_reconstructions[0])
fig, axs = plt.subplots(n_dims, n_values, figsize=(n_values*1.5, n_dims*1.5), squeeze=False)
for j, value in enumerate(perturbation_values):
    axs[0, j].set_title('{:+.2f}'.format(value))
for i, row_images in enumerate(perturbed_reconstructions):
    axs[i, 0].set_ylabel('dim {}'.format(i), size='large')
    for j, reconstruction in enumerate(row_images):
        ax = axs[i, j]
        ax.imshow(reconstruction, cmap='gray')
        ax.set_yticks([])
        ax.set_xticks([])