In [None]:
import os

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from model.net import *
from utils.training import *
from data.data import *

## Experiment Config

In [None]:
# Experiment settings; this dict is handed to the model, data and training helpers.
args = dict(
    USE_CUDA=torch.cuda.is_available(),  # use the GPU whenever torch reports one
    BATCH_SIZE=32,
    N_EPOCHS=30,
    LEARNING_RATE=0.01,
    MOMENTUM=0.9,
    DATASET_NAME='mnist',
    LAMBDA_recon=1,  # 0.0005 (presumably an earlier/alternative weight - TODO confirm)
    LAMBDA_margin=0,
)

## Model Loading

In [None]:
# Configuration for 49 16-d vectors in the Primary Capsule layer.
# With this configuration the softmax dimension must be set to 0.
class Config:
    """Plain container of layer-size settings for the capsule network.

    Every value is a fixed integer assigned in ``__init__``; the constructor
    takes no arguments. Attribute prefixes group the settings by layer:
    ``cnn_*``, ``pc_*`` (primary capsule), ``dc_*`` / ``dc_2_*`` (first and
    second digit capsule) and ``input_*`` (decoder input size).
    """

    def __init__(self):
        # The same route count (1 * 7 * 7 = 49) is used by three layers below.
        # presumably channels * height * width of the capsule grid - TODO confirm
        n_routes = 1 * 7 * 7

        # --- CNN (cnn) ---
        self.cnn_in_channels = 1
        self.cnn_out_channels = 12
        self.cnn_kernel_size = 15

        # --- Primary Capsule (pc) ---
        self.pc_num_capsules = 16
        self.pc_in_channels = 12
        self.pc_out_channels = 1
        self.pc_kernel_size = 8
        self.pc_num_routes = n_routes

        # --- Digit Capsule 1 (dc) ---
        self.dc_num_capsules = 49
        self.dc_num_routes = n_routes
        self.dc_in_channels = 16
        self.dc_out_channels = 16

        # --- Digit Capsule 2 (dc_2) ---
        self.dc_2_num_capsules = 10
        self.dc_2_num_routes = n_routes
        self.dc_2_in_channels = 16
        self.dc_2_out_channels = 16

        # --- Decoder: size of the input image ---
        self.input_width = 28
        self.input_height = 28

# Seed torch's RNG before the model is constructed.
torch.manual_seed(1)
config = Config()

# Build the network from the experiment settings and the layer configuration.
net = CapsNet(args, config)
# net = torch.nn.DataParallel(net)  # optional multi-GPU wrapper (disabled)
if args['USE_CUDA']:
    net = net.cuda()

# NOTE(review): `model_path` is not defined in any visible cell; presumably it
# is provided by one of the wildcard imports - confirm before a fresh run.
checkpoint_file = os.path.join(model_path, 'CapsNetMNIST.pth')

# NOTE(security): torch.load unpickles the file, so only load checkpoints from
# a trusted source. Tensors are mapped to the CPU first; load_state_dict then
# copies them into the parameters of `net`.
net.load_state_dict(torch.load(checkpoint_file, map_location='cpu'))

## Loading Dataset

In [None]:
# Unpack the pair returned by dataset(args) into the train and test loaders
# consumed by the training cell.
# NOTE(review): `dataset` is not defined in this notebook; presumably it comes
# from one of the wildcard imports (data.data?) - confirm.
trainloader, testloader = dataset(args)

## Training CapsuleNet

In [None]:
# NOTE(review): Adam is created with its default hyper-parameters; neither
# args['LEARNING_RATE'] nor args['MOMENTUM'] is passed to it in this cell.
optimizer = torch.optim.Adam(net.parameters())

# Epochs are numbered from 1; each one is a training pass followed by a test pass.
n_epochs = args['N_EPOCHS']
for epoch in range(1, n_epochs + 1):
    train_CapsNet(net, optimizer, trainloader, epoch, args)
    test_CapsNet(net, testloader, epoch, args)

In [None]:
# torch.save(net.state_dict(), "./CapsNetMNIST.pth")

In [None]:
# Alternative (disabled) config for 16 1-d vectors in the capsule layer. Set the softmax dimension to 1 in this case.
# class Config:
#     def __init__(self, dataset='mnist'):
#         # CNN (cnn)
#         self.cnn_in_channels = 1
#         self.cnn_out_channels = 12
#         self.cnn_kernel_size = 15

#         # Primary Capsule (pc)
#         self.pc_num_capsules = 1
#         self.pc_in_channels = 12
#         self.pc_out_channels = 16
#         self.pc_kernel_size = 8
#         self.pc_num_routes = 16 * 7 * 7

#         # Digit Capsule 1 (dc)
#         self.dc_num_capsules = 49
#         self.dc_num_routes = 16 * 7 * 7
#         self.dc_in_channels = 1
#         self.dc_out_channels = 1 #16
        
#         # Digit Capsule 2 (dc)
#         self.dc_2_num_capsules = 10
#         self.dc_2_num_routes = 7 * 7
#         self.dc_2_in_channels = 1 #16
#         self.dc_2_out_channels = 16

#         # Decoder
#         self.input_width = 28
#         self.input_height = 28