In [None]:
import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from matplotlib import cm
from conv_sparse_model import ConvSparseLayer

from train_conv_sparse_model import load_mnist_data
from train_conv_sparse_model import plot_filters

In [None]:
# Prefer GPU 3 (as on the original multi-GPU machine); fall back to the
# default GPU when fewer GPUs exist, and to the CPU when CUDA is unavailable.
GPU_INDEX = 3
if torch.cuda.is_available():
    if torch.cuda.device_count() > GPU_INDEX:
        device = torch.device("cuda", GPU_INDEX)
    else:
        device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Compare `device.type`, not the device object: a torch.device does not
# compare equal to a plain string, so `device == "cpu"` was always False and
# the CPU path silently used the GPU batch size.
batch_size = 8 if device.type == "cpu" else 64

train_loader = load_mnist_data(batch_size)
# One fixed batch, reused by the visualisation cell further down.
example_data, example_targets = next(iter(train_loader))
example_data = example_data.to(device)

# NOTE(review): lam / activation_lr / max_activation_iter are presumably the
# sparsity penalty weight, the step size and the iteration cap of the
# activation (sparse-code) inference — confirm against ConvSparseLayer.
sparse_layer = ConvSparseLayer(in_channels=1,
                               out_channels=16,
                               kernel_size=8,
                               stride=1,
                               padding=0,
                               lam=0.05,
                               activation_lr=1e-4,
                               max_activation_iter=1000
                               )
sparse_layer.to(device)

# Adam over all of the layer's parameters.
learning_rate = 1e-3
filter_optimizer = torch.optim.Adam(sparse_layer.parameters(),
                                    lr=learning_rate)

In [None]:
# Resume from the saved "best" checkpoint (model + optimizer state).
# map_location remaps tensors that were saved from a GPU onto the device
# selected above, so the notebook also runs on CPU-only machines.
checkpoint = torch.load("mnist_out/sparse_conv3d_model-best.pt",
                        map_location=device)
sparse_layer.load_state_dict(checkpoint['model_state_dict'])
filter_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Continue training the loaded layer for a few more epochs.
N_EPOCHS = 3
LOG_EVERY = 100  # batches between loss printouts (avoids flooding the output)

for epoch in range(N_EPOCHS):
    # Labels are not needed for this unsupervised reconstruction loss.
    for batch_idx, (local_batch, _) in enumerate(train_loader):
        local_batch = local_batch.to(device)
        # NOTE(review): forward is called here with one argument and its
        # result used directly as the activations, whereas the visualisation
        # cell does `activations, _ = sparse_layer(example_data, u_init)`.
        # One of the two call sites likely disagrees with
        # ConvSparseLayer.forward — confirm.
        activations = sparse_layer(local_batch)
        loss = sparse_layer.loss(local_batch, activations)

        filter_optimizer.zero_grad()
        loss.backward()
        filter_optimizer.step()
        # Presumably re-normalises the filters after each update; see
        # ConvSparseLayer.normalize_weights.
        sparse_layer.normalize_weights()

        if batch_idx % LOG_EVERY == 0:
            # .item() gives a plain float instead of a tensor repr with grad_fn.
            print('epoch={} batch={} loss={:.6f}'.format(
                epoch, batch_idx, loss.item()))

In [None]:
# Infer activations for the example batch and reconstruct the images from
# them. The initial state `u_init` is sized from the actual batch and
# allocated on the same device as the data (torch.zeros defaults to the CPU,
# which would mismatch a GPU `example_data`).
n_examples = example_data.shape[0]
# list() so this works whether get_output_shape returns a list, tuple or Size.
u_init = torch.zeros([n_examples, sparse_layer.out_channels] +
                     list(sparse_layer.get_output_shape(example_data)),
                     device=device)

activations, _ = sparse_layer(example_data, u_init)
reconstructions = sparse_layer.reconstructions(
    activations).detach().cpu().numpy()

print("SHAPES")
print(example_data.shape)
print(reconstructions.shape)

# One row per example: original on the left, reconstruction on the right.
img_to_show = min(3, n_examples)
fig, axes = plt.subplots(img_to_show, 2, squeeze=False)
for i in range(img_to_show):
    ax_orig, ax_recon = axes[i]

    ax_orig.imshow(example_data[i, 0, :, :].detach().cpu().numpy(),
                   cmap='gray', interpolation='none')
    ax_orig.set_title("Original Image\nGround Truth: {}".format(
        example_targets[i].item()))
    ax_orig.set_xticks([])
    ax_orig.set_yticks([])

    ax_recon.imshow(reconstructions[i, 0, :, :], cmap='gray',
                    interpolation='none')
    ax_recon.set_title("Reconstruction")
    ax_recon.set_xticks([])
    ax_recon.set_yticks([])

# Lay out once, after every subplot exists.
fig.tight_layout()

In [None]:
# Keep plot_filters' return value under its own name: assigning it to `plt`
# shadowed matplotlib.pyplot and broke re-running the plotting cell above.
# The returned object exposes .savefig (presumably a Figure or the pyplot
# module — see train_conv_sparse_model.plot_filters).
filters_plot = plot_filters(sparse_layer.filters.cpu().detach())
filters_plot.savefig('mnist_out/filters.png')