In [2]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
import time
USE_GPU = True
from model import *
from data_loaders import load_mnist
from tools import weights_init_xavier
from stats import Statistics
import matplotlib.pyplot as plt
import os.path as path
SAVE_DIR = "saved_models"
FILENAME = "modelk.pt"

In [3]:
def initialize_weights(capsnet):
  """Apply `weights_init_xavier` recursively to three CapsNet sub-modules.

  Covers the first conv layer, the primary capsules and the decoder, in that order.
  NOTE: `capsnet.digit_caps.W` is deliberately not re-initialised here (its Xavier
  init was disabled), so it keeps whatever initialisation CapsNet gives it.
  """
  for submodule in (capsnet.conv_layer.conv, capsnet.primary_capsules, capsnet.decoder):
    submodule.apply(weights_init_xavier)

In [4]:
# Capsule network with a fully-connected reconstruction decoder ("Conv" is the alternative).
capsnet = CapsNet(reconstruction_type="FC")
if USE_GPU:
  capsnet.cuda()  # nn.Module.cuda() moves the parameters onto the GPU in place

# Adam over every network parameter, with its default hyper-parameters.
optimizer = torch.optim.Adam(params=capsnet.parameters())

In [5]:
# Resume from SAVE_DIR/FILENAME when it exists, otherwise start from fresh Xavier weights.
# NOTE(review): the training cell saves checkpoints as "model<epoch>.pt"; FILENAME must
# name one of those for resuming to work.
filepath = path.join(SAVE_DIR, FILENAME)
if path.isfile(filepath):
  print("Saved model found")
  # Deserialize onto the CPU first: load_state_dict() copies every tensor into the
  # model's existing parameters (already on the GPU when USE_GPU), so this works no
  # matter which device the checkpoint was saved from and avoids a transient GPU copy.
  # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
  capsnet.load_state_dict(torch.load(filepath, map_location="cpu"))
else:
  print("Saved model not found; Model initialized.")
  initialize_weights(capsnet)

Saved model not found; Model initialized.


In [6]:
"""Hyperparameters"""
max_epochs, batch_size = 1000, 128  # epoch cap, mini-batch size
# load_mnist (project data_loaders module) returns the (train, test) loader pair.
train_loader, test_loader = load_mnist(batch_size)

In [7]:
import time
import numpy as np

class Statistics:
  """Accumulate train/test loss and test accuracy per reporting interval and keep a history.

  NOTE(review): this definition shadows the `Statistics` imported from `stats` in the
  first cell; keep only one of the two.
  """

  def __init__(self):
    # One entry is appended to each history list per save_stats() call.
    self.TEST_LOSSES = []
    self.TRAIN_LOSSES = []
    self.TEST_ACC = []

    self.reset_tracking_stats()

  def reset_tracking_stats(self):
    """Zero the running sums/counters and restart the interval timer."""
    self.train_loss = 0
    self.train_steps = 0
    self.test_loss = 0
    self.test_steps = 0
    self.test_correct = 0
    self.test_num_samples = 0
    self.time = time.time()

  def track_train(self, train_loss):
    """Add one training batch's scalar loss to the running sum."""
    self.train_steps += 1
    self.train_loss += train_loss

  def track_test(self, test_loss, target, prediction):
    """Add one test batch's scalar loss and count its correct predictions.

    Args:
      test_loss: scalar loss of the batch.
      target: tensor with samples along dim 0 and class scores along dim 1
        (one-hot in this notebook); its argmax over dim 1 is the true class.
      prediction: tensor compared the same way -- argmax over dim 1 is taken as
        the predicted class (assumed same layout as `target`; TODO confirm).
    """
    self.test_correct += (target.max(dim=1)[1] == prediction.max(dim=1)[1]).sum().item()
    self.test_num_samples += target.size(0)

    self.test_steps += 1
    self.test_loss += test_loss

  @staticmethod
  def _mean(total, count):
    """Return total / count, or NaN when nothing was accumulated (count == 0)."""
    return total / count if count else float("nan")

  def save_stats(self, epoch):
    """Append the interval's averages to the history, print them, and start a new interval.

    An interval with no train or test steps reports NaN instead of raising
    ZeroDivisionError.
    """
    time_spent = time.time() - self.time
    train_loss = self._mean(self.train_loss, self.train_steps)
    test_loss = self._mean(self.test_loss, self.test_steps)
    test_acc = self._mean(self.test_correct, self.test_num_samples)
    self.TEST_ACC.append(test_acc)
    self.TEST_LOSSES.append(test_loss)
    self.TRAIN_LOSSES.append(train_loss)
    print("Epoch: {:3.0f} \t Time: {:3.0f} \t Test: {:.3f} \t Train: {:.3f} \t Accuracy: {:3.4f}".format(epoch, time_spent, test_loss, train_loss, test_acc*100))
    self.reset_tracking_stats()
    

In [None]:
import os

stats = Statistics()
# Evaluate and checkpoint every `display_step` training batches (batch 0 excluded).
# NOTE(review): assuming the standard 60k MNIST training split, batch_size=128 gives
# ~469 batches per epoch, so this fires about once per epoch -- revisit if that changes.
display_step = 450


def _one_hot_batch(data, target, num_classes=10):
  """Return (data, one-hot target), both moved to the GPU when USE_GPU is set.

  `target` holds integer class indices; row k of the identity matrix is the one-hot
  code for class k. (Variable wrappers are unnecessary since PyTorch 0.4.)
  """
  target = torch.eye(num_classes).index_select(dim=0, index=target)
  if USE_GPU:
    data, target = data.cuda(), target.cuda()
  return data, target


def _evaluate_and_checkpoint(epoch):
  """Score the whole test set into `stats`, save a checkpoint, then return to train mode."""
  capsnet.eval()
  # Inference only: no_grad() stops autograd from building a graph for every test batch.
  with torch.no_grad():
    for test_data, test_target in test_loader:
      test_data, test_target = _one_hot_batch(test_data, test_target)
      # No target is passed at evaluation time (the training call below passes one).
      output, reconstructions, masked = capsnet(test_data)
      test_loss = capsnet.loss(test_data, test_target, output, reconstructions)
      stats.track_test(test_loss.item(), test_target, masked)
  stats.save_stats(epoch)

  # torch.save does not create missing directories.
  os.makedirs(SAVE_DIR, exist_ok=True)
  torch.save(capsnet.state_dict(), path.join(SAVE_DIR, "model{}.pt".format(epoch)))
  capsnet.train()


for epoch in range(max_epochs):
  capsnet.train()
  # Iterate the loader lazily; list(...) would pull every batch of the epoch into memory first.
  for batch, (data, target) in enumerate(train_loader):
    data, target = _one_hot_batch(data, target)

    optimizer.zero_grad()
    output, reconstructions, masked = capsnet(data, target)
    loss = capsnet.loss(data, target, output, reconstructions)
    loss.backward()
    optimizer.step()

    stats.track_train(loss.item())

    if batch % display_step == 0 and batch != 0:
      _evaluate_and_checkpoint(epoch)
  

In [None]:
# Run the network in inference mode on a single test batch.
capsnet.eval()
# `.next()` is a Python 2 alias that newer PyTorch DataLoader iterators no longer provide.
data, target = next(iter(test_loader))
with torch.no_grad():
  # Respect USE_GPU instead of unconditionally calling .cuda().
  output, reconstruction, masked = capsnet(data.cuda() if USE_GPU else data)

In [None]:
# Predicted class = capsule with the largest squared length (same argmax as the length).
# Reshape to (batch, -1) instead of squeeze() so a batch of one keeps its batch dimension.
capsule_sq_lengths = (output ** 2).sum(dim=2)
predictions = capsule_sq_lengths.view(capsule_sq_lengths.size(0), -1).argmax(dim=1).cpu().numpy()

In [None]:
def _to_unit_range(image):
  """Linearly rescale a NumPy array to [0, 1]; a constant array maps to all zeros."""
  lo, hi = image.min(), image.max()
  if hi == lo:
    return np.zeros_like(image)
  return (image - lo) / (hi - lo)


i = 7  # index of the test-batch sample to visualise
print(target[i], predictions[i])

# Normalise each image with its OWN min/max: subtract the minimum (adding abs(min) is
# wrong when min > 0), and never rescale the input with the reconstruction's statistics.
im = _to_unit_range(reconstruction[i, 0].detach().cpu().numpy())
im2 = _to_unit_range(data[i, 0].detach().cpu().numpy())

fig, (ax_rec, ax_in) = plt.subplots(2, 1)
ax_rec.imshow(im, cmap="gray")
ax_rec.set_title("Reconstruction (predicted: {})".format(predictions[i]))
ax_in.imshow(im2, cmap="gray")
ax_in.set_title("Input (label: {})".format(int(target[i])))
fig.tight_layout()
plt.show()