In [1]:
import torch
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os

from model.net import *
from utils.training import *
from data.data import *

## Experiment Config

In [5]:
# Directory the notebook was launched from; checkpoint paths are joined onto it.
model_path = os.getcwd()

# Experiment-wide settings, passed as one dict to CapsNet(), dataset() and the
# train/test helpers.
args = dict(
    USE_CUDA=bool(torch.cuda.is_available()),  # run on the GPU whenever one is visible
    BATCH_SIZE=256,
    N_EPOCHS=30,
    # NOTE(review): the optimizer cell builds Adam without passing these two
    # values; whether the imported helpers read them is not visible here.
    LEARNING_RATE=0.1,
    MOMENTUM=0.9,
    DATASET_NAME='mnist',
    LAMBDA_recon=1,   # 0.0005 was noted next to this value (presumably an earlier setting)
    LAMBDA_margin=0,
)

## Model Loading

In [6]:
# Configuration for 49 16-d vectors in the Primary Capsule layer.
# Set the Softmax dimension to 0 when this configuration is used.
class Config:
    """Plain container of layer hyperparameters handed to ``CapsNet``.

    Attributes are grouped by prefix: ``cnn_*`` (convolution), ``pc_*``
    (primary capsules), ``dc_*`` (first digit-capsule layer), ``dc_2_*``
    (second digit-capsule layer) and ``input_*`` (decoder output size).
    """

    def __init__(self):
        # Values that recur across several layers are named once here.
        num_routes = 1 * 7 * 7   # 49 routes; presumably a 7x7 grid - TODO confirm
        capsule_dim = 16         # vector length used throughout the capsule layers

        # CNN (cnn)
        self.cnn_in_channels = 1
        self.cnn_out_channels = 12
        self.cnn_kernel_size = 15

        # Primary Capsule (pc)
        self.pc_num_capsules = capsule_dim
        self.pc_in_channels = self.cnn_out_channels   # 12, same as the conv output
        self.pc_out_channels = 1
        self.pc_kernel_size = 8
        self.pc_num_routes = num_routes

        # Digit Capsule 1 (dc)
        self.dc_num_capsules = 49
        self.dc_num_routes = num_routes
        self.dc_in_channels = capsule_dim
        self.dc_out_channels = capsule_dim

        # Digit Capsule 2 (dc_2)
        self.dc_2_num_capsules = 10
        self.dc_2_num_routes = num_routes
        self.dc_2_in_channels = capsule_dim
        self.dc_2_out_channels = capsule_dim

        # Decoder
        self.input_width = 28
        self.input_height = 28
# Seed torch's RNG before the network is constructed.
torch.manual_seed(1)
config = Config()

net = CapsNet(args, config)
# capsule_net = torch.nn.DataParallel(capsule_net)
net = net.cuda() if args['USE_CUDA'] else net

# Freeze all layers except the decoder: these four sub-modules are switched to
# eval mode and their parameters are excluded from gradient computation.
# NOTE(review): if train_CapsNet() calls net.train(), the eval() below is
# undone on the first epoch - confirm in utils.training.
to_freeze = [
    net.conv_layer,
    net.primary_capsules,
    net.digit_capsules_1,
    net.digit_capsules_2,
]
for layer in to_freeze:
    layer.eval()
    for param in layer.parameters():
        param.requires_grad = False

# Restore the saved weights. Tensors are deserialised onto the CPU and then
# copied into the network's existing parameters by load_state_dict().
# Kept as the last expression so the key-matching report is displayed.
net.load_state_dict(
    torch.load(
        os.path.join(model_path, 'CapsNetMNIST.pth'),
        map_location='cpu',
    )
)

<All keys matched successfully>

## Loading Dataset

In [7]:
# Build the data loaders from the experiment settings. `dataset` comes from one
# of the star imports at the top (presumably data.data - TODO confirm) and its
# two return values are unpacked as the training and test loaders.
trainloader, testloader = dataset(args)

## Training CapsuleNet (decoder only; all other layers are frozen)

In [8]:
# Only parameters that still require gradients (everything that was not frozen
# above) are handed to the optimizer. Adam runs with its library defaults:
# args['LEARNING_RATE'] and args['MOMENTUM'] are not passed here.
optimizer = torch.optim.Adam([p for p in net.parameters() if p.requires_grad])

# One training pass followed by one evaluation pass per epoch (1-based index).
for epoch in range(1, args['N_EPOCHS'] + 1):
    train_CapsNet(net, optimizer, trainloader, epoch, args)
    test_CapsNet(net, testloader, epoch, args)

  1%|▏         | 3/235 [00:00<02:04,  1.86it/s]

Epoch: [1/30], Batch: [1/235], train accuracy: 0.992188, loss: 0.002686


 44%|████▍     | 103/235 [00:08<00:09, 13.20it/s]

Epoch: [1/30], Batch: [101/235], train accuracy: 0.976562, loss: 0.002621


 86%|████████▋ | 203/235 [00:16<00:02, 11.41it/s]

Epoch: [1/30], Batch: [201/235], train accuracy: 0.996094, loss: 0.002596


100%|██████████| 235/235 [00:19<00:00, 12.08it/s]

Epoch: [1/30], train loss: 0.002656



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [1/30], test accuracy: 0.985100, loss: 0.679176


  0%|          | 1/235 [00:00<00:42,  5.53it/s]

Epoch: [2/30], Batch: [1/235], train accuracy: 1.000000, loss: 0.002697


 44%|████▍     | 103/235 [00:09<00:11, 11.43it/s]

Epoch: [2/30], Batch: [101/235], train accuracy: 0.984375, loss: 0.002665


 86%|████████▋ | 203/235 [00:17<00:02, 11.44it/s]

Epoch: [2/30], Batch: [201/235], train accuracy: 0.996094, loss: 0.002622


100%|██████████| 235/235 [00:20<00:00, 11.44it/s]

Epoch: [2/30], train loss: 0.002643



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [2/30], test accuracy: 0.985100, loss: 0.675888


  0%|          | 1/235 [00:00<00:44,  5.20it/s]

Epoch: [3/30], Batch: [1/235], train accuracy: 0.984375, loss: 0.002616


 44%|████▍     | 103/235 [00:09<00:11, 11.43it/s]

Epoch: [3/30], Batch: [101/235], train accuracy: 0.988281, loss: 0.002683


 86%|████████▋ | 203/235 [00:17<00:02, 11.44it/s]

Epoch: [3/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.002585


100%|██████████| 235/235 [00:20<00:00, 11.44it/s]

Epoch: [3/30], train loss: 0.002631



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [3/30], test accuracy: 0.985100, loss: 0.675460


  0%|          | 1/235 [00:00<00:45,  5.13it/s]

Epoch: [4/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002620


 44%|████▍     | 103/235 [00:09<00:11, 11.43it/s]

Epoch: [4/30], Batch: [101/235], train accuracy: 0.992188, loss: 0.002569


 86%|████████▋ | 203/235 [00:17<00:02, 11.43it/s]

Epoch: [4/30], Batch: [201/235], train accuracy: 0.984375, loss: 0.002641


100%|██████████| 235/235 [00:20<00:00, 11.43it/s]

Epoch: [4/30], train loss: 0.002621



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [4/30], test accuracy: 0.985100, loss: 0.675790


  0%|          | 1/235 [00:00<00:44,  5.21it/s]

Epoch: [5/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002567


 44%|████▍     | 103/235 [00:09<00:11, 11.43it/s]

Epoch: [5/30], Batch: [101/235], train accuracy: 0.988281, loss: 0.002602


 86%|████████▋ | 203/235 [00:17<00:02, 11.44it/s]

Epoch: [5/30], Batch: [201/235], train accuracy: 0.996094, loss: 0.002584


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [5/30], train loss: 0.002615



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [5/30], test accuracy: 0.985100, loss: 0.670443


  0%|          | 1/235 [00:00<00:45,  5.09it/s]

Epoch: [6/30], Batch: [1/235], train accuracy: 0.984375, loss: 0.002658


 44%|████▍     | 103/235 [00:09<00:11, 11.43it/s]

Epoch: [6/30], Batch: [101/235], train accuracy: 0.988281, loss: 0.002614


 86%|████████▋ | 203/235 [00:17<00:02, 11.43it/s]

Epoch: [6/30], Batch: [201/235], train accuracy: 0.996094, loss: 0.002593


100%|██████████| 235/235 [00:20<00:00, 11.43it/s]

Epoch: [6/30], train loss: 0.002606



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [6/30], test accuracy: 0.985100, loss: 0.670192


  0%|          | 1/235 [00:00<00:44,  5.21it/s]

Epoch: [7/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002583


 44%|████▍     | 103/235 [00:09<00:11, 11.46it/s]

Epoch: [7/30], Batch: [101/235], train accuracy: 0.980469, loss: 0.002609


 86%|████████▋ | 203/235 [00:17<00:02, 11.46it/s]

Epoch: [7/30], Batch: [201/235], train accuracy: 0.980469, loss: 0.002534


100%|██████████| 235/235 [00:20<00:00, 11.45it/s]

Epoch: [7/30], train loss: 0.002598



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [7/30], test accuracy: 0.985100, loss: 0.667033


  0%|          | 1/235 [00:00<00:44,  5.23it/s]

Epoch: [8/30], Batch: [1/235], train accuracy: 0.984375, loss: 0.002609


 44%|████▍     | 103/235 [00:09<00:11, 11.47it/s]

Epoch: [8/30], Batch: [101/235], train accuracy: 0.988281, loss: 0.002596


 86%|████████▋ | 203/235 [00:17<00:02, 11.47it/s]

Epoch: [8/30], Batch: [201/235], train accuracy: 0.972656, loss: 0.002572


100%|██████████| 235/235 [00:20<00:00, 11.45it/s]

Epoch: [8/30], train loss: 0.002592



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [8/30], test accuracy: 0.985100, loss: 0.664126


  0%|          | 1/235 [00:00<00:45,  5.11it/s]

Epoch: [9/30], Batch: [1/235], train accuracy: 0.996094, loss: 0.002522


 44%|████▍     | 103/235 [00:09<00:11, 11.45it/s]

Epoch: [9/30], Batch: [101/235], train accuracy: 0.992188, loss: 0.002522


 86%|████████▋ | 203/235 [00:17<00:02, 11.45it/s]

Epoch: [9/30], Batch: [201/235], train accuracy: 0.980469, loss: 0.002572


100%|██████████| 235/235 [00:20<00:00, 11.44it/s]

Epoch: [9/30], train loss: 0.002582



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [9/30], test accuracy: 0.985100, loss: 0.664451


  0%|          | 1/235 [00:00<00:44,  5.21it/s]

Epoch: [10/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002620


 44%|████▍     | 103/235 [00:09<00:11, 11.43it/s]

Epoch: [10/30], Batch: [101/235], train accuracy: 0.996094, loss: 0.002511


 86%|████████▋ | 203/235 [00:17<00:02, 11.44it/s]

Epoch: [10/30], Batch: [201/235], train accuracy: 0.984375, loss: 0.002531


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [10/30], train loss: 0.002576



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [10/30], test accuracy: 0.985100, loss: 0.663504


  0%|          | 1/235 [00:00<00:44,  5.23it/s]

Epoch: [11/30], Batch: [1/235], train accuracy: 0.984375, loss: 0.002528


 44%|████▍     | 103/235 [00:09<00:11, 11.42it/s]

Epoch: [11/30], Batch: [101/235], train accuracy: 0.988281, loss: 0.002555


 86%|████████▋ | 203/235 [00:17<00:02, 11.42it/s]

Epoch: [11/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.002560


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [11/30], train loss: 0.002571



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [11/30], test accuracy: 0.985100, loss: 0.662825


  0%|          | 1/235 [00:00<00:44,  5.22it/s]

Epoch: [12/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002609


 44%|████▍     | 103/235 [00:09<00:11, 11.42it/s]

Epoch: [12/30], Batch: [101/235], train accuracy: 0.996094, loss: 0.002548


 86%|████████▋ | 203/235 [00:17<00:02, 11.43it/s]

Epoch: [12/30], Batch: [201/235], train accuracy: 0.980469, loss: 0.002483


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [12/30], train loss: 0.002566



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [12/30], test accuracy: 0.985100, loss: 0.660917


  1%|▏         | 3/235 [00:00<00:38,  5.98it/s]

Epoch: [13/30], Batch: [1/235], train accuracy: 0.984375, loss: 0.002507


 44%|████▍     | 103/235 [00:09<00:11, 11.42it/s]

Epoch: [13/30], Batch: [101/235], train accuracy: 0.992188, loss: 0.002533


 86%|████████▋ | 203/235 [00:17<00:02, 11.41it/s]

Epoch: [13/30], Batch: [201/235], train accuracy: 0.976562, loss: 0.002550


100%|██████████| 235/235 [00:20<00:00, 11.41it/s]

Epoch: [13/30], train loss: 0.002562



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [13/30], test accuracy: 0.985100, loss: 0.660540


  0%|          | 1/235 [00:00<00:45,  5.12it/s]

Epoch: [14/30], Batch: [1/235], train accuracy: 0.984375, loss: 0.002657


 44%|████▍     | 103/235 [00:09<00:11, 11.39it/s]

Epoch: [14/30], Batch: [101/235], train accuracy: 0.992188, loss: 0.002584


 86%|████████▋ | 203/235 [00:17<00:02, 11.40it/s]

Epoch: [14/30], Batch: [201/235], train accuracy: 0.996094, loss: 0.002500


100%|██████████| 235/235 [00:20<00:00, 11.39it/s]

Epoch: [14/30], train loss: 0.002556



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [14/30], test accuracy: 0.985100, loss: 0.657980


  0%|          | 1/235 [00:00<00:45,  5.19it/s]

Epoch: [15/30], Batch: [1/235], train accuracy: 0.984375, loss: 0.002540


 44%|████▍     | 103/235 [00:09<00:11, 11.41it/s]

Epoch: [15/30], Batch: [101/235], train accuracy: 0.996094, loss: 0.002546


 86%|████████▋ | 203/235 [00:17<00:02, 11.45it/s]

Epoch: [15/30], Batch: [201/235], train accuracy: 0.984375, loss: 0.002539


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [15/30], train loss: 0.002553



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [15/30], test accuracy: 0.985100, loss: 0.656802


  0%|          | 1/235 [00:00<00:45,  5.14it/s]

Epoch: [16/30], Batch: [1/235], train accuracy: 0.984375, loss: 0.002482


 44%|████▍     | 103/235 [00:09<00:11, 11.42it/s]

Epoch: [16/30], Batch: [101/235], train accuracy: 0.988281, loss: 0.002490


 86%|████████▋ | 203/235 [00:17<00:02, 11.45it/s]

Epoch: [16/30], Batch: [201/235], train accuracy: 0.972656, loss: 0.002528


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [16/30], train loss: 0.002546



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [16/30], test accuracy: 0.985100, loss: 0.656492


  0%|          | 1/235 [00:00<00:45,  5.17it/s]

Epoch: [17/30], Batch: [1/235], train accuracy: 0.996094, loss: 0.002560


 44%|████▍     | 103/235 [00:09<00:11, 11.41it/s]

Epoch: [17/30], Batch: [101/235], train accuracy: 0.976562, loss: 0.002542


 86%|████████▋ | 203/235 [00:17<00:02, 11.40it/s]

Epoch: [17/30], Batch: [201/235], train accuracy: 0.964844, loss: 0.002563


100%|██████████| 235/235 [00:20<00:00, 11.41it/s]

Epoch: [17/30], train loss: 0.002545



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [17/30], test accuracy: 0.985100, loss: 0.655309


  0%|          | 1/235 [00:00<00:46,  5.09it/s]

Epoch: [18/30], Batch: [1/235], train accuracy: 0.976562, loss: 0.002529


 44%|████▍     | 103/235 [00:09<00:11, 11.42it/s]

Epoch: [18/30], Batch: [101/235], train accuracy: 0.988281, loss: 0.002539


 86%|████████▋ | 203/235 [00:17<00:02, 11.36it/s]

Epoch: [18/30], Batch: [201/235], train accuracy: 0.996094, loss: 0.002589


100%|██████████| 235/235 [00:20<00:00, 11.41it/s]

Epoch: [18/30], train loss: 0.002542



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [18/30], test accuracy: 0.985100, loss: 0.654209


  0%|          | 1/235 [00:00<00:45,  5.14it/s]

Epoch: [19/30], Batch: [1/235], train accuracy: 0.996094, loss: 0.002577


 44%|████▍     | 103/235 [00:09<00:11, 11.42it/s]

Epoch: [19/30], Batch: [101/235], train accuracy: 1.000000, loss: 0.002513


 86%|████████▋ | 203/235 [00:17<00:02, 11.40it/s]

Epoch: [19/30], Batch: [201/235], train accuracy: 0.980469, loss: 0.002533


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [19/30], train loss: 0.002538



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [19/30], test accuracy: 0.985100, loss: 0.653713


  0%|          | 1/235 [00:00<00:45,  5.19it/s]

Epoch: [20/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002516


 44%|████▍     | 103/235 [00:09<00:11, 11.37it/s]

Epoch: [20/30], Batch: [101/235], train accuracy: 0.980469, loss: 0.002494


 86%|████████▋ | 203/235 [00:17<00:02, 11.37it/s]

Epoch: [20/30], Batch: [201/235], train accuracy: 0.984375, loss: 0.002559


100%|██████████| 235/235 [00:20<00:00, 11.41it/s]

Epoch: [20/30], train loss: 0.002535



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [20/30], test accuracy: 0.985100, loss: 0.653753


  0%|          | 1/235 [00:00<00:46,  5.06it/s]

Epoch: [21/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002602


 44%|████▍     | 103/235 [00:09<00:11, 11.45it/s]

Epoch: [21/30], Batch: [101/235], train accuracy: 0.992188, loss: 0.002529


 86%|████████▋ | 203/235 [00:17<00:02, 11.47it/s]

Epoch: [21/30], Batch: [201/235], train accuracy: 0.992188, loss: 0.002490


100%|██████████| 235/235 [00:20<00:00, 11.44it/s]

Epoch: [21/30], train loss: 0.002532



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [21/30], test accuracy: 0.985100, loss: 0.654294


  0%|          | 1/235 [00:00<00:44,  5.26it/s]

Epoch: [22/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002489


 44%|████▍     | 103/235 [00:09<00:11, 11.44it/s]

Epoch: [22/30], Batch: [101/235], train accuracy: 0.992188, loss: 0.002429


 86%|████████▋ | 203/235 [00:17<00:02, 11.47it/s]

Epoch: [22/30], Batch: [201/235], train accuracy: 0.996094, loss: 0.002587


100%|██████████| 235/235 [00:20<00:00, 11.43it/s]

Epoch: [22/30], train loss: 0.002528



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [22/30], test accuracy: 0.985100, loss: 0.651219


  0%|          | 1/235 [00:00<00:44,  5.28it/s]

Epoch: [23/30], Batch: [1/235], train accuracy: 0.980469, loss: 0.002609


 44%|████▍     | 103/235 [00:09<00:11, 11.42it/s]

Epoch: [23/30], Batch: [101/235], train accuracy: 0.984375, loss: 0.002473


 86%|████████▋ | 203/235 [00:17<00:02, 11.41it/s]

Epoch: [23/30], Batch: [201/235], train accuracy: 0.984375, loss: 0.002513


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [23/30], train loss: 0.002525



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [23/30], test accuracy: 0.985100, loss: 0.652518


  1%|▏         | 3/235 [00:00<00:39,  5.82it/s]

Epoch: [24/30], Batch: [1/235], train accuracy: 1.000000, loss: 0.002505


 44%|████▍     | 103/235 [00:09<00:11, 11.43it/s]

Epoch: [24/30], Batch: [101/235], train accuracy: 0.980469, loss: 0.002512


 86%|████████▋ | 203/235 [00:17<00:02, 11.49it/s]

Epoch: [24/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.002561


100%|██████████| 235/235 [00:20<00:00, 11.43it/s]

Epoch: [24/30], train loss: 0.002523



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [24/30], test accuracy: 0.985100, loss: 0.652744


  0%|          | 1/235 [00:00<00:44,  5.28it/s]

Epoch: [25/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002510


 44%|████▍     | 103/235 [00:09<00:11, 11.46it/s]

Epoch: [25/30], Batch: [101/235], train accuracy: 0.972656, loss: 0.002520


 86%|████████▋ | 203/235 [00:17<00:02, 11.47it/s]

Epoch: [25/30], Batch: [201/235], train accuracy: 0.992188, loss: 0.002595


100%|██████████| 235/235 [00:20<00:00, 11.46it/s]

Epoch: [25/30], train loss: 0.002521



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [25/30], test accuracy: 0.985100, loss: 0.650468


  0%|          | 1/235 [00:00<00:45,  5.19it/s]

Epoch: [26/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002526


 44%|████▍     | 103/235 [00:09<00:11, 11.46it/s]

Epoch: [26/30], Batch: [101/235], train accuracy: 0.984375, loss: 0.002454


 86%|████████▋ | 203/235 [00:17<00:02, 11.46it/s]

Epoch: [26/30], Batch: [201/235], train accuracy: 0.988281, loss: 0.002520


100%|██████████| 235/235 [00:20<00:00, 11.44it/s]

Epoch: [26/30], train loss: 0.002520



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [26/30], test accuracy: 0.985100, loss: 0.650810


  0%|          | 1/235 [00:00<00:44,  5.28it/s]

Epoch: [27/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002488


 44%|████▍     | 103/235 [00:09<00:11, 11.40it/s]

Epoch: [27/30], Batch: [101/235], train accuracy: 0.988281, loss: 0.002528


 86%|████████▋ | 203/235 [00:17<00:02, 11.41it/s]

Epoch: [27/30], Batch: [201/235], train accuracy: 0.972656, loss: 0.002514


100%|██████████| 235/235 [00:20<00:00, 11.41it/s]

Epoch: [27/30], train loss: 0.002517



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [27/30], test accuracy: 0.985100, loss: 0.650139


  0%|          | 1/235 [00:00<00:44,  5.23it/s]

Epoch: [28/30], Batch: [1/235], train accuracy: 0.988281, loss: 0.002528


 44%|████▍     | 103/235 [00:09<00:11, 11.42it/s]

Epoch: [28/30], Batch: [101/235], train accuracy: 0.976562, loss: 0.002489


 86%|████████▋ | 203/235 [00:17<00:02, 11.42it/s]

Epoch: [28/30], Batch: [201/235], train accuracy: 0.992188, loss: 0.002475


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [28/30], train loss: 0.002513



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [28/30], test accuracy: 0.985100, loss: 0.649508


  0%|          | 1/235 [00:00<00:45,  5.11it/s]

Epoch: [29/30], Batch: [1/235], train accuracy: 1.000000, loss: 0.002417


 44%|████▍     | 103/235 [00:09<00:11, 11.41it/s]

Epoch: [29/30], Batch: [101/235], train accuracy: 0.984375, loss: 0.002512


 86%|████████▋ | 203/235 [00:17<00:02, 11.42it/s]

Epoch: [29/30], Batch: [201/235], train accuracy: 0.984375, loss: 0.002513


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [29/30], train loss: 0.002512



  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: [29/30], test accuracy: 0.985100, loss: 0.647231


  0%|          | 1/235 [00:00<00:45,  5.11it/s]

Epoch: [30/30], Batch: [1/235], train accuracy: 0.968750, loss: 0.002516


 44%|████▍     | 103/235 [00:09<00:11, 11.42it/s]

Epoch: [30/30], Batch: [101/235], train accuracy: 0.980469, loss: 0.002486


 86%|████████▋ | 203/235 [00:17<00:02, 11.44it/s]

Epoch: [30/30], Batch: [201/235], train accuracy: 0.968750, loss: 0.002529


100%|██████████| 235/235 [00:20<00:00, 11.42it/s]

Epoch: [30/30], train loss: 0.002506





Epoch: [30/30], test accuracy: 0.985100, loss: 0.649247


In [9]:
# Persist the fine-tuned weights (state_dict only, not the whole module) to a
# new file in the current working directory, separate from the
# 'CapsNetMNIST.pth' checkpoint that was loaded earlier.
torch.save(net.state_dict(), "./CapsNetMNIST_Recon.pth")

In [9]:
# Sanity check: show the shape of every parameter that is still trainable.
for param in net.parameters():
    if param.requires_grad:
        print(param.shape)

torch.Size([512, 160])
torch.Size([512])
torch.Size([1024, 512])
torch.Size([1024])
torch.Size([784, 1024])
torch.Size([784])


In [6]:
# Show which parameters the optimizer's filter selects, as a list of shapes.
# The predicate and the parameter iterator must go through filter(); handing
# them to print() as two separate arguments only prints the repr of the lambda
# and of the generator object. Shapes are printed rather than the tensors
# themselves to keep the output short.
print([tuple(p.shape) for p in filter(lambda p: p.requires_grad, net.parameters())])

<function <lambda> at 0x7f7c076463b0> <generator object Module.parameters at 0x7f7c71d49ad0>


In [None]:
#Config for 16 1d vectors in Capsule Layer. Set the Softmax Dimension to 1 in this case
# class Config:
#     def __init__(self, dataset='mnist'):
#         # CNN (cnn)
#         self.cnn_in_channels = 1
#         self.cnn_out_channels = 12
#         self.cnn_kernel_size = 15

#         # Primary Capsule (pc)
#         self.pc_num_capsules = 1
#         self.pc_in_channels = 12
#         self.pc_out_channels = 16
#         self.pc_kernel_size = 8
#         self.pc_num_routes = 16 * 7 * 7

#         # Digit Capsule 1 (dc)
#         self.dc_num_capsules = 49
#         self.dc_num_routes = 16 * 7 * 7
#         self.dc_in_channels = 1
#         self.dc_out_channels = 1 #16
        
#         # Digit Capsule 2 (dc)
#         self.dc_2_num_capsules = 10
#         self.dc_2_num_routes = 7 * 7
#         self.dc_2_in_channels = 1 #16
#         self.dc_2_out_channels = 16

#         # Decoder
#         self.input_width = 28
#         self.input_height = 28