# Experiment 4:
## Reconstructing using multiple generative models

This experiment visually illustrates how multiple generative models behave when used to reconstruct images from the MNIST data set, by considering one generative model for each label (digit class).

In [None]:
import numpy as np;
import numpy.linalg as linalg;

import matplotlib.pyplot as plt;
from matplotlib.pyplot import figure;

import torch;
import torch.nn as nn;
import torch.nn.functional as F;
import torch.optim as optim;

from torch.utils.data import DataLoader;
import torchvision;

import pickle;

import time;
import random;
import math;

from IPython.display import clear_output;

import MNIST_utils as MNIST;
import MNIST_generative as generative;
import sparsity_utils as sparsity;
import general_utils as utils;

In [None]:
# Load MNIST; load_dataset returns ((train_data, test_data), (train_loader, test_loader)).
datasets, loaders = MNIST.load_dataset()
train_data, test_data = datasets
train_loader, test_loader = loaders

# Experiment parameters: the numbers of measurements m to sweep over, and the
# noise level handed to utils.generateMeasurements_Gaussian.
mspace = [10, 20, 30, 50, 100]
noise = 0.1

In [None]:
# One generative model per label (0-9). All ten are created first and only
# then trained, each on the loader returned by MNIST.getTrainLoader(label).
# The arguments (2, 20) are passed straight through to
# generative.createNetwork; their meaning is defined in MNIST_generative.
utils.reset_seeds()

networks = [generative.createNetwork(2, 20)[1] for _ in range(10)]
for label, net in enumerate(networks):
    generative.trainNetwork(net, MNIST.getTrainLoader(label), num_epochs=20)

In [None]:
# Show the ground-truth test image (loader index 1).
x_true = MNIST.getImageAsVector(test_loader, 1)
plt.imshow(MNIST.VectorToImage(x_true), cmap="binary")
plt.show()

# Baseline without measurements: project x_true directly with each model and
# display every projection.
imgs = []

utils.reset_seeds()
for label in range(10):
    projection = generative.project(
        x_true,
        networks[label],
        num_epochs=200,
        learning_rate=0.0001,
    )
    plt.imshow(MNIST.VectorToImage(projection), cmap="binary")
    plt.show()
    imgs.append(projection)

# best_data[label] = (projection, ||normalize(projection) - normalize(x_true)||)
best_data = []
for label, projection in enumerate(imgs):
    true_error = linalg.norm(utils.normalize(projection) - utils.normalize(x_true))
    best_data.append((projection, true_error))
    print(label, ":", true_error)

In [None]:
# Reconstruct the same test image from m quantized Gaussian measurements,
# using each per-label model as the projector, for every m in mspace.
# Each (m, model) pair is run NUM_RESTARTS times. The kept run is chosen by
# _consistency_violation, which looks only at A, q and the reconstruction.
x_true = MNIST.getImageAsVector(test_loader, 1)
plt.imshow(MNIST.VectorToImage(x_true), cmap="binary")
plt.show()


def _consistency_violation(A, q, y):
    """Score how strongly the estimate ``y`` disagrees with the measurements ``q``.

    Returns ``||min(q * (A . normalize(y)), 0)||_1``. Entries where ``q`` and
    the measured normalized estimate have the same sign contribute 0. Entries
    whose signs disagree contribute the magnitude of their product. Lower is
    better.
    """
    products = np.multiply(A.dot(utils.normalize(y)), q)
    return linalg.norm(np.minimum(products, 0).ravel(), ord=1)


NUM_RESTARTS = 3  # independent reconstruct_BIP runs per (m, model)

# data[j][i] = (kept reconstruction, its objective,
#               ||normalize(x_true) - normalize(reconstruction)||)
data = [[() for _ in range(10)] for _ in mspace]
utils.reset_seeds()
for j, m in enumerate(mspace):
    A, q = utils.generateMeasurements_Gaussian(x_true, m, noise)
    q = utils.quantize(q)

    for i in range(10):
        # Bind this label's network as a default argument. The projector then
        # does not depend on the value of the loop variable when it is called.
        def projector(x, net=networks[i]):
            return generative.project(x, net, num_epochs=50, learning_rate=0.1)

        best = None        # value reconstruct_BIP returned as `accuracy` for the kept run
        best_obj = np.inf
        y_best = None
        for _ in range(NUM_RESTARTS):
            # NOTE(review): x_true is passed through; presumably only used to
            # report `accuracy` — confirm in MNIST_utils.reconstruct_BIP.
            y, accuracy = MNIST.reconstruct_BIP(A, q, x_true, projector,
                                                learning_param=0.04, iterations=50)
            obj = _consistency_violation(A, q, y)
            # Always keep the first run. Otherwise a NaN (or huge) objective
            # would leave y_best unset, because NaN never compares less.
            if y_best is None or obj < best_obj:
                best_obj, best, y_best = obj, accuracy, y

        error = linalg.norm(utils.normalize(x_true) - utils.normalize(y_best))
        data[j][i] = (y_best, best_obj, error)
        print(m, i, best, best_obj)
    

In [None]:
# For each m, the smallest measurement-consistency objective (entry [1])
# reached by any model. The figure cell outlines models attaining it.
thresholds = [min(entry[1] for entry in data[row]) for row in range(len(mspace))]

print(thresholds)

In [None]:
# Figure layout: one grid row per selected m plus a final row for the direct
# projections, and one column per model. Each row holds a 28x28 image, a
# caption strip of height accheight, and a 1-pixel separator.
accheight = 14
models = len(networks)

filterindices = [0, 1, 2, 3, 4]  # indices into mspace to include in the figure
ms = len(filterindices)

row_height = 28 + accheight + 1
col_width = 28 + 1
results = np.zeros((row_height * (ms + 1), models * col_width))
figure(figsize=(5, 5), dpi=500)

def drawBorder(x1, y1, x2, y2, color):
    """Outline a rectangle on the current axes with lines of linewidth 1.

    The outline spans (x1, y1) to (x2 - 1, y2 - 1), so x2 and y2 act as
    exclusive bounds, like the end of a slice.
    """
    x_last = x2 - 1
    y_last = y2 - 1
    line_style = dict(colors=color, linewidth=1)
    plt.vlines([x1, x_last], y1, y_last, **line_style)
    plt.hlines([y1, y_last], x1, x_last, **line_style)
    
from pathlib import Path

# Layout constants: each grid row is image + caption strip + separator, and
# each grid column is image + separator.
img_side = 28
row_height = img_side + accheight + 1
col_width = img_side + 1
width = models * col_width
height = row_height * (ms + 1)


def _scale_to_unit_peak(img):
    """Return |img| scaled so its maximum is 1; an all-zero image is returned unscaled."""
    img = np.abs(img)
    peak = np.max(img)
    # Guard: dividing by a zero peak would fill the tile with NaNs.
    return img / peak if peak > 0 else img


# Reconstructions: one row per selected m, one column per model.
for r, m_idx in enumerate(filterindices):
    top = row_height * r + 1
    for j in range(models):
        left = col_width * j + 1
        tile = MNIST.VectorToImage(_scale_to_unit_peak(data[m_idx][j][0]))
        results[top:top + img_side, left:left + img_side] = tile

# Final row: the direct projections computed without measurements.
offset = row_height * ms + 2
for j in range(models):
    left = col_width * j + 1
    tile = _scale_to_unit_peak(MNIST.VectorToImage(best_data[j][0]))
    results[offset:offset + img_side, left:left + img_side] = tile

plt.imshow(results, cmap='binary')
plt.xticks([])
plt.yticks([])
plt.ylim([height, 0])  # inverted so grid row 0 is drawn at the top
plt.xlim([0, width])

# Grid lines: row boundaries, caption-strip boundaries, column separators,
# and the borders of the final (projection) row.
plt.hlines([r * row_height for r in range(ms + 1)], 0, width, colors='black', linewidth=1)
plt.hlines([r * row_height - accheight for r in range(ms + 1)], 0, width, colors='black', linewidth=1)
plt.vlines([c * col_width for c in range(1, models)], 0, height + 1, colors='black', linewidth=1)
plt.hlines([ms * row_height + 1], 0, width, colors='black', linewidth=1)
plt.hlines([(ms + 1) * row_height - accheight + 1], 0, width, colors='black', linewidth=1)

# Captions under each reconstruction: objective (entry [1]) and entry [2].
for r, m_idx in enumerate(filterindices):
    for j in range(models):
        entry = data[m_idx][j]
        caption = "obj. {:.3f}".format(entry[1]) + "\n" + "acc. {:.3f}".format(entry[2])
        plt.text(col_width * j + 4, (r + 1) * row_height - 4, caption,
                 fontsize=3.6, fontweight="bold", ha="left")

for j in range(models):
    caption = "acc. {:.3f}".format(best_data[j][1])
    plt.text(col_width * j + 4, height - 6, caption,
             fontsize=3.6, fontweight="bold", ha="left")

# Row labels (value of m) and axis labels.
for r, m_idx in enumerate(filterindices):
    plt.text(-9, row_height * r + 20, str(mspace[m_idx]), fontsize=5.0, fontweight="bold", ha="center")
plt.text(-9, -4, "m", fontsize=5.0, fontweight="bold", ha="center")
plt.text(-14, row_height * ms + 21, "projector", fontsize=4.5, fontweight="bold", ha="center")

# Outline the model(s) reaching the smallest objective for each plotted m.
# Index data/thresholds through filterindices, like the image and caption loops.
for r, m_idx in enumerate(filterindices):
    for j in range(models):
        if data[m_idx][j][1] <= thresholds[m_idx]:
            drawBorder(col_width * j + 1, row_height * r + 1,
                       col_width * j + img_side + 1, row_height * r + img_side + 1,
                       "firebrick")

plt.rcParams['axes.titley'] = 1.005
# NOTE(review): the title assumes the test image at loader index 1 is a "2";
# update it if x_true changes.
plt.title("Multi model reconstruction of \"2\"", fontsize=10)

# Build the path portably ("\i" in the old literal was an invalid escape and
# not a directory separator on POSIX), and make sure the directory exists.
out_path = Path("results_images") / "images_multi_model_half_k20.png"
out_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path, dpi=1000)