In [1]:
import os, time
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from model import Net
import argparse
import utils
from collections import OrderedDict
import numpy as np


  (fname, cnt))
  (fname, cnt))


In [2]:
def normal_init(m, mean, std):
    """Initialise a 2-D (transposed) convolution layer in place.

    Weights are re-drawn from a normal distribution with the given ``mean``
    and ``std``; the bias, when the layer has one, is set to zero. Modules of
    any other type are left untouched, so the function can be applied to every
    sub-module of a network.

    Args:
        m: the module to initialise.
        mean: mean of the normal distribution used for the weights.
        std: standard deviation of that distribution.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        # Layers constructed with bias=False expose ``bias`` as None.
        if m.bias is not None:
            m.bias.data.zero_()

# Fixed noise batch: 5 * 5 latent vectors of length 100, shaped (25, 100, 1, 1).
# It is fed to G by show_result(..., isFix=True).
fixed_z_ = torch.randn((5 * 5, 100)).view(-1, 100, 1, 1)
if torch.cuda.is_available():
    # Tensor.cuda() returns a GPU copy instead of moving the tensor in place,
    # so the result has to be re-bound or the noise silently stays on the CPU.
    fixed_z_ = fixed_z_.cuda()
fixed_z_ = Variable(fixed_z_)
def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False):
    """Render a 5x5 grid of images produced by the global generator ``G``.

    Args:
        num_epoch: value placed in the ``'Epoch N'`` caption under the grid.
        show: when true the figure is displayed with ``plt.show()``; otherwise
            it is closed after being written.
        save: accepted for interface compatibility. NOTE: it is not consulted;
            the figure is always written to ``path``.
        path: file the figure is saved to.
        isFix: when true the module-level ``fixed_z_`` noise is used instead of
            a freshly sampled batch.
    """
    # Fresh noise is sampled even when isFix is set, which keeps the random
    # number stream identical for both modes.
    z_ = torch.randn((5*5, 100)).view(-1, 100, 1, 1)
    if torch.cuda.is_available():
        # .cuda() returns a copy; re-bind so the noise really lives on the GPU.
        z_ = z_.cuda()
    z_ = Variable(z_, volatile=True)

    # Generate in eval mode, then put G back into training mode.
    G.eval()
    test_images = G(fixed_z_) if isFix else G(z_)
    G.train()

    size_figure_grid = 5
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    for k in range(size_figure_grid * size_figure_grid):
        i, j = divmod(k, size_figure_grid)
        ax[i, j].cla()
        # Channel 0 of the k-th generated image is drawn as a grayscale picture.
        ax[i, j].imshow(test_images[k, 0].cpu().data.numpy(), cmap='gray')

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()

In [8]:
def load_model(model_type, dict_file):
    """Build a model of the requested kind and load saved weights into it.

    The checkpoint is mapped onto the CPU while loading. A leading
    ``'module.'`` on a parameter name is stripped before the weights are
    handed to ``load_state_dict``.

    Args:
        model_type: ``'g'`` (``generator(128)``), ``'d'`` (``Net`` configured
            from the global ``args``), ``'b_g'`` (``base_generator(128)``) or
            ``'b_d'`` (``base_discriminator(128)``).
        dict_file: path of the saved state dict.

    Returns:
        The constructed model with the weights loaded.

    Raises:
        ValueError: if ``model_type`` is not one of the values listed above.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    state_dict = torch.load(dict_file, map_location=lambda storage, loc: storage)

    prefix = 'module.'
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )

    if model_type == 'g':
        model = generator(128)
    elif model_type == 'd':
        model = Net(num_conv_in_channel=args.num_conv_in_channel,
                    num_conv_out_channel=args.num_conv_out_channel,
                    num_primary_unit=args.num_primary_unit,
                    primary_unit_size=args.primary_unit_size,
                    num_classes=args.num_classes,
                    output_unit_size=args.output_unit_size,
                    num_routing=args.num_routing,
                    use_reconstruction_loss=args.use_reconstruction_loss,
                    regularization_scale=args.regularization_scale,
                    input_width=args.input_width,
                    input_height=args.input_height,
                    cuda_enabled=args.cuda)
    elif model_type == 'b_g':
        model = base_generator(128)
    elif model_type == 'b_d':
        model = base_discriminator(128)
    else:
        # Fail with a clear message instead of an AttributeError on None below.
        raise ValueError(
            "unknown model_type {!r}; expected 'g', 'd', 'b_g' or 'b_d'".format(model_type))

    model.load_state_dict(new_state_dict)
    return model

In [4]:
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image

class Synthetic_MNIST_LOADER():
    """Indexable access to a folder of numbered synthetic digit JPEGs.

    Item ``index`` (0-based) is read from ``<folder><index + 1>.jpg``, pasted
    at offset (1, 1) onto a 28x28 canvas, converted to a single-channel image
    and returned as a tensor together with the path it was read from.

    Args:
        folder: prefix prepended to ``'<n>.jpg'``; include the trailing
            separator.
        num_images: number of images available, i.e. the value of ``len()``.
    """

    def __init__(self, folder='../synthetic_mnist/', num_images=38):
        self.folder = folder
        self.num_images = num_images
        self.to_tensor = transforms.ToTensor()

    def __getitem__(self, index):
        """Return ``(image_tensor, path)`` for the 0-based ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``. This is
                also what lets a plain ``for`` loop over the loader stop.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                'index {} out of range for {} images'.format(index, len(self)))

        img_id = self.folder + str(index+1) + '.jpg'

        # Pad onto a 28x28 canvas with a one-pixel offset (the original comment
        # describes this as taking the image from 26x26 to 28x28). The context
        # manager closes the underlying file once the pixels have been copied.
        with Image.open(img_id) as image:
            new_image = Image.new("RGB", (28,28))
            new_image.paste(image, (1, 1))

        image = new_image.convert('L')

        image_tensor = self.to_tensor(image)
        path = img_id
        return (image_tensor, path)

    def __len__(self):
        return self.num_images

In [12]:
import sys
# Blank out the kernel's own command line so argparse only sees the defaults.
sys.argv=['']
parser = argparse.ArgumentParser(description='Example of Capsule Network')
parser.add_argument('--epochs', type=int, default=10,
                    help='number of training epochs. default=10')
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate. default=0.01')
parser.add_argument('--batch-size', type=int, default=256,
                    help='training batch size. default=256')
parser.add_argument('--test-batch-size', type=int,
                    default=128, help='testing batch size. default=128')
parser.add_argument('--log-interval', type=int, default=10,
                    help='how many batches to wait before logging training status. default=10')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training. default=false')
parser.add_argument('--threads', type=int, default=4,
                    help='number of threads for data loader to use. default=4')
parser.add_argument('--seed', type=int, default=42,
                    help='random seed for training. default=42')
parser.add_argument('--num-conv-out-channel', type=int, default=256,
                    help='number of channels produced by the convolution. default=256')
parser.add_argument('--num-conv-in-channel', type=int, default=1,
                    help='number of input channels to the convolution. default=1')
parser.add_argument('--num-primary-unit', type=int, default=8,
                    help='number of primary unit. default=8')
parser.add_argument('--primary-unit-size', type=int,
                    default=1152, help='primary unit size is 32 * 6 * 6. default=1152')
parser.add_argument('--num-classes', type=int, default=1,
                    help='number of digit classes. 1 unit for one MNIST digit. default=1')
parser.add_argument('--output-unit-size', type=int,
                    default=16, help='output unit size. default=16')
parser.add_argument('--num-routing', type=int,
                    default=3, help='number of routing iteration. default=3')
parser.add_argument('--use-reconstruction-loss', type=utils.str2bool, nargs='?', default=True,
                    help='use an additional reconstruction loss. default=True')
parser.add_argument('--regularization-scale', type=float, default=0.0005,
                    help='regularization coefficient for reconstruction loss. default=0.0005')
parser.add_argument('--dataset', help='the name of dataset (mnist, cifar10)', default='mnist')
parser.add_argument('--input-width', type=int,
                    default=28, help='input image width to the convolution. default=28 for MNIST')
parser.add_argument('--input-height', type=int,
                    default=28, help='input image height to the convolution. default=28 for MNIST')
parser.add_argument('--is-training', type=int,
                    default=1, help='Whether or not is training, default is yes')
parser.add_argument('--weights', type=str,
                    default=None, help='Load pretrained weights, default is none')

args = parser.parse_args()

# CUDA is switched off unconditionally here, so args.cuda below is always False
# and everything in this cell runs on the CPU.
args.no_cuda = True
print(args)

# Check GPU or CUDA is available
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Get reproducible results by manually seeding the random number generator
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)



# training parameters
batch_size = 128
lr = 0.0002
train_epoch = 20

# data_loader
img_size = 28#64

# torchvision renamed Scale to Resize and later removed Scale; use whichever
# this installation provides.
resize_transform = getattr(transforms, 'Resize', None) or transforms.Scale

transform = transforms.Compose([
        resize_transform(img_size),
        transforms.ToTensor(),
        # MNIST tensors have a single channel, so give one mean/std value;
        # three-channel statistics fail to broadcast on current torchvision.
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

synth_loader = Synthetic_MNIST_LOADER()

# network

D = load_model('d', 'MNIST_CAPSGAN_FC_results/discriminator_param.pkl')

if args.cuda:
    print('Utilize GPUs for computation')
    print('Number of GPU available', torch.cuda.device_count())
    print('CUDNN version: ', torch.backends.cudnn.version())
    # torch.backends.cudnn.benchmark = True
    D = torch.nn.DataParallel(D).cuda()

# results save folder
if not os.path.isdir('synthetic_results'):
    os.mkdir('synthetic_results')

    
def thresh(vals, t):
    """Zero every entry of ``vals`` that is smaller than ``t``.

    The array is modified in place and the same object is returned.
    """
    vals[vals < t] = 0
    return vals

# (output name, first image number, last image number); numbers are 1-based and
# inclusive, matching the '<n>.jpg' file names read by synth_loader.
sets = [("3_fade", 1, 4), ("8_trans", 5, 10), ("4_rot", 11, 18), ("0_size", 19, 22), ("8_skew", 23, 28), ("3_width", 29, 34), ("4_stroke", 35, 38)]


def plot_feature_set(name, start, end, out_dir='figs/16_layer'):
    """Plot and save the D features of synthetic images ``start``..``end``.

    Row 0 shows the feature vector of every image in the range, row 1 the
    thresholded difference to the previous image and row 2 the thresholded
    difference to the first image of the range (column 0 of rows 1 and 2 stays
    empty). The figure is written to ``<out_dir>/<name>.png`` and the raw
    feature arrays to ``<out_dir>/<name>.npy``.

    Args:
        name: base name of the two output files.
        start: 1-based number of the first image (inclusive).
        end: 1-based number of the last image (inclusive).
        out_dir: directory for the outputs; created when missing.
    """
    n_cols = end - start + 1
    # squeeze=False keeps ``ax`` two-dimensional even for a single column.
    fig, ax = plt.subplots(3, n_cols, squeeze=False)
    fig.set_size_inches(18.5, 10.5)
    for axis in ax.ravel():
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)
        axis.cla()

    # NOTE(review): D is used in whatever train/eval mode it is currently in --
    # confirm whether D.eval() should be called before extracting features.
    values = []
    for index in range(start - 1, end):
        x_, path = synth_loader[index]
        # Add the batch dimension, run the forward pass, then read the
        # activations from D.out_features.
        D(x_.unsqueeze(0))
        feats = D.out_features.squeeze().unsqueeze(0)
        values.append(feats.data.cpu().numpy())

    for col in range(n_cols):
        ax[0, col].imshow(values[col], cmap='gray')
    for col in range(1, n_cols):
        step_diff = values[col - 1] - values[col]
        ax[1, col].imshow(thresh(step_diff, np.max(step_diff)/2), cmap='gray')
        first_diff = values[0] - values[col]
        ax[2, col].imshow(thresh(first_diff, np.max(first_diff)/2), cmap='gray')

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    fig.savefig(os.path.join(out_dir, name + '.png'))
    # Close rather than clear, so figures do not pile up across sets.
    plt.close(fig)
    np.save(os.path.join(out_dir, name), values)


for name, start, end in sets:
    plot_feature_set(name, start, end)
    


Namespace(batch_size=256, dataset='mnist', epochs=10, input_height=28, input_width=28, is_training=1, log_interval=10, lr=0.01, no_cuda=True, num_classes=1, num_conv_in_channel=1, num_conv_out_channel=256, num_primary_unit=8, num_routing=3, output_unit_size=16, primary_unit_size=1152, regularization_scale=0.0005, seed=42, test_batch_size=128, threads=4, use_reconstruction_loss=True, weights=None)




<matplotlib.figure.Figure at 0x7fb1aaa30e10>

<matplotlib.figure.Figure at 0x7fb1a9e7f518>

<matplotlib.figure.Figure at 0x7fb1a994cdd8>

<matplotlib.figure.Figure at 0x7fb1aa4ff828>

<matplotlib.figure.Figure at 0x7fb1a961c240>

<matplotlib.figure.Figure at 0x7fb1a8856a58>

<matplotlib.figure.Figure at 0x7fb1a9de7a20>