In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
import time
USE_GPU = True
from model import *
from data_loaders import load_mnist
from tools import weights_init_xavier
import matplotlib.pyplot as plt
import os.path as path
SAVE_DIR = "saved_models"
FILENAME = "modelk.pt"

In [2]:
def initialize_weights(capsnet):
  capsnet.conv_layer.conv.apply(weights_init_xavier)
  capsnet.primary_capsules.apply(weights_init_xavier)
  capsnet.decoder.apply(weights_init_xavier)
  #nn.init.xavier_normal_(capsnet.digit_caps.W)

In [3]:
capsnet = CapsNet(reconstruction_type="FC")
if USE_GPU:
  capsnet.cuda()
optimizer = torch.optim.Adam(capsnet.parameters())

In [4]:
filepath = path.join(SAVE_DIR, FILENAME)
if path.isfile(filepath):
  print("Saved model found")
  capsnet.load_state_dict(torch.load(filepath))
else:
  print("Saved model not found; Model initialized.")
  initialize_weights(capsnet)

Saved model not found; Model initialized.


In [5]:
"""Hyperparameters"""
max_epochs = 1000
batch_size = 128
train_loader, test_loader = load_mnist(batch_size)

In [7]:
iter(train_loader).next()[0][0]

tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0

In [None]:
t = time.time()
Te_LOSS = []
Tr_LOSS = []
test_acc = []
display_step = 450
for epoch in range(max_epochs):
  capsnet.train()
  train_loss = 0
  for batch, (data, target) in list(enumerate(train_loader)):
    target = torch.eye(10).index_select(dim=0, index=target)
    data, target = Variable(data), Variable(target)
    
    if USE_GPU:
      data, target = data.cuda(), target.cuda()
    
    optimizer.zero_grad()
    
    output, reconstructions, masked = capsnet(data, target)
    loss = capsnet.loss(data, target, output, reconstructions)
    loss.backward()
    optimizer.step()
    
    train_loss += loss.data.item()
    if batch % display_step == 0 and batch != 0:
      eval_time = time.time()

      capsnet.eval()
      test_loss = 0
      test_correct = 0
      test_total = 0
      for batch_id, (data, target) in enumerate(test_loader):
        target = torch.eye(10).index_select(dim=0, index=target)
        data, target = Variable(data), Variable(target)

        if USE_GPU:
          data,target = data.cuda(), target.cuda()

        output, reconstruction, masked = capsnet(data)
        loss = capsnet.loss(data, target, output, reconstruction)

        test_loss += loss.data.item()
        test_total += data.size(0)
        test_correct += sum(np.argmax(masked.data.cpu().numpy(),1 ) == np.argmax(target.data.cpu().numpy(), 1))

      acc = test_correct / test_total
      Te_LOSS.append(test_loss / len(test_loader))
      Tr_LOSS.append(train_loss / len(train_loader))
      test_acc.append(acc)
      test_loss /= len(test_loader)
      train_loss /= len(train_loader)
      time_spent = time.time() - t
      t = time.time()
      
      print("Epoch: {:3.0f} \t Time: {:3.0f} \t Test: {:.3f} \t Train: {:.3f} \t Accuracy: {:3.4f}".format(epoch, time_spent,test_loss, train_loss, acc*100))
      train_loss = 0
      filepath = path.join(SAVE_DIR, "model{}.pt".format(epoch))
      torch.save(capsnet.state_dict(), filepath)
      capsnet.train()
  

Epoch:   0 	 Time: 106 	 Test: 0.455 	 Train: 0.691 	 Accuracy: 91.6700
Epoch:   1 	 Time: 111 	 Test: 0.174 	 Train: 0.272 	 Accuracy: 97.9700
Epoch:   2 	 Time: 112 	 Test: 0.124 	 Train: 0.138 	 Accuracy: 98.3500
Epoch:   3 	 Time: 112 	 Test: 0.081 	 Train: 0.095 	 Accuracy: 98.6900
Epoch:   4 	 Time: 111 	 Test: 0.054 	 Train: 0.059 	 Accuracy: 98.6800
Epoch:   5 	 Time: 107 	 Test: 0.037 	 Train: 0.044 	 Accuracy: 98.9800
Epoch:   6 	 Time: 112 	 Test: 0.044 	 Train: 0.034 	 Accuracy: 98.7600
Epoch:   7 	 Time: 112 	 Test: 0.034 	 Train: 0.030 	 Accuracy: 99.1800
Epoch:   8 	 Time: 112 	 Test: 0.040 	 Train: 0.028 	 Accuracy: 98.9100
Epoch:   9 	 Time: 106 	 Test: 0.055 	 Train: 0.025 	 Accuracy: 98.9200
Epoch:  10 	 Time: 112 	 Test: 0.035 	 Train: 0.024 	 Accuracy: 99.0600
Epoch:  11 	 Time: 107 	 Test: 0.024 	 Train: 0.022 	 Accuracy: 99.3300
Epoch:  12 	 Time: 111 	 Test: 0.033 	 Train: 0.021 	 Accuracy: 99.2700
Epoch:  13 	 Time: 111 	 Test: 0.024 	 Train: 0.019 	 Accuracy: 

In [None]:
capsnet.eval()
data, target = iter(test_loader).next()
output, reconstruction, masked = capsnet(data.cuda())

In [None]:
predictions = torch.max((output**2).sum(dim=2).squeeze(), dim=1)[1].cpu().data.numpy()

In [None]:
i = 3
print(target[i], predictions[i])
im = reconstruction[i,0].data.cpu().numpy()
im += abs(im.min())
im /= im.max()
plt.imshow(im, cmap="gray")