# Reverse MNIST

## Preliminaries

In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
dtype = torch.float
device = torch.device("cpu")
device = torch.device("cuda:0") # Uncomment this to run on GPU
import NeuralNetwork as NN
import Layer
import importlib
import time
from copy import deepcopy
from IPython.display import display
from ipywidgets import FloatProgress

import mnist_loader
importlib.reload(mnist_loader)
importlib.reload(NN)
importlib.reload(Layer);

## Load MNIST DATA

In [2]:
train, validate, test = mnist_loader.load_data_wrapper()
train = [torch.tensor(train[0][:500]).float().to(device), torch.tensor(train[1][:500]).float().to(device)]

## Create Network

In [4]:
# Reversed
net = NN.NeuralNetwork()
net.AddLayer(Layer.InputPELayer(n=10))  # 0 Class vector
net.AddLayer(Layer.TopPELayer(n=784))  # 1 Input
net.Connect(0,1)
# Input layer is one-hot
net.layers[0].sigma = Layer.softmax
net.layers[0].sigma_p = Layer.softmax_p
# Top vector reconstructed image, [0, 1]
net.layers[-1].sigma = Layer.logistic
net.layers[-1].sigma_p = Layer.logistic_p

net.SetTau(0.08)

## Train the Network

In [5]:
# Train in reverse direction: one-hot -> reconstructed image
epochs = 1
T = 3.
start_time = time.time()
batch_size = 200
net.learning_tau = torch.tensor(batch_size).float().to(device)
fp = FloatProgress(min=0,max=epochs*len(train[0]))  
display(fp)
for k in range(epochs):
    batches = NN.MakeBatches(train[1], train[0], batch_size=batch_size)
    for x in batches:
        net.Reset()
        net.Infer(T, x[0], x[1])
        fp.value += batch_size
end_time = time.time()
print('Total time: '+str(end_time-start_time))

Total time: 0.8628246784210205


## Save the Model

In [7]:
net.Save('MNIST reversed.npy')

## Probe the Model

In [6]:
net.connections[0].W

tensor([[-3.6231e+12, -3.6231e+12, -3.6231e+12,  ..., -3.6231e+12,
         -3.6231e+12, -3.6231e+12],
        [-6.1488e+13, -6.1488e+13, -6.1488e+13,  ..., -6.1488e+13,
         -6.1488e+13, -6.1488e+13],
        [ 5.2267e+13,  5.2267e+13,  5.2267e+13,  ...,  5.2267e+13,
          5.2267e+13,  5.2267e+13],
        ...,
        [ 4.2727e+13,  4.2727e+13,  4.2727e+13,  ...,  4.2727e+13,
          4.2727e+13,  4.2727e+13],
        [-6.7169e+12, -6.7169e+12, -6.7169e+12,  ..., -6.7169e+12,
         -6.7169e+12, -6.7169e+12],
        [ 4.4713e+13,  4.4713e+13,  4.4713e+13,  ...,  4.4713e+13,
          4.4713e+13,  4.4713e+13]], device='cuda:0')

## Test the Network