In [None]:
#################
#### IMPORTS ####
#################

# Arrays
import numpy as np

# Deep Learning stuff
import torch
import torchvision
import torchvision.transforms as transforms

# Images display and plots
import matplotlib.pyplot as plt

# Fancy progress bars
import tqdm.notebook as tq

# Tensor Network Stuff
%config InlineBackend.figure_formats = ['svg']
import quimb.tensor as qtn # Tensor Network library
import quimb


In [None]:
import sys
# Make the parent directory importable so the project helper module TNutils can be found.
# NOTE(review): '../' is resolved relative to the kernel's current working directory,
# not to this notebook's location — launch Jupyter from the notebook's folder.
sys.path.insert(0, '../')
# My functions
# NOTE(review): the star import hides where names come from. Names this notebook uses
# from TNutils: get_data, plot_img, partial_removal_img, initialize_mps,
# quimb_transform_img2state, computepsi, learning_epoch, computeNLL, generate_sample.
# Consider importing them explicitly.
from TNutils import *


### 1. Handling MNIST images

Get MNIST data
1. Actually download data
2. Flatten each image
3. Normalize [0,255] -> [0,1]
4. Transform from grayscale to binary images: [0,1] -> {0,1}

In [None]:
# Fetch MNIST and apply the preprocessing listed above (flatten, normalize, binarize);
# returns the training and test splits.
train_set, test_set = get_data()


In [None]:
# Report the dimensions of both splits (images are flattened, so each row is one image).
print(f'Shape of the training set: {train_set.shape}')
print(f'Shape of the test set:     {test_set.shape}')


Example of an image:

In [None]:
# Show the training image at index 1.
plot_img(train_set[1])


Later on, we would like to remove parts of an image and reconstruct them.
There is a function for partially removing parts of an image (it is meant for the test set; here it is demonstrated on a training image):

In [None]:
# Blank out part of the same training image and display it.
# fraction=0.4 / axis=1 presumably remove 40% of the image along axis 1 —
# see TNutils.partial_removal_img for the exact semantics.
plot_img(partial_removal_img(train_set[1], fraction=0.4, axis=1))

### 2. MPS 

1. Create an MPS network
2. Canonicalize towards the second tensor
3. Rename indexes (in a readable form)

In [None]:
# Small toy MPS used only for inspection below.
# Ldim/bdim presumably set the number of sites and the bond dimension — TODO confirm in TNutils.
toymps = initialize_mps(Ldim=10, bdim=10)

Inspect shape/canonicalization

In [None]:
# Text drawing of the toy MPS (quimb), to inspect bond sizes / canonical form.
toymps.show()

Inspect indexes

In [None]:
# Display the individual tensors of the toy MPS, including their index names.
toymps.tensors

In [None]:
# MPS used for the psi(v) timing comparison below; the default length presumably
# matches one flattened image — TODO confirm in TNutils.initialize_mps.
mps = initialize_mps(bdim=30)

I developed two ways to compute the contraction of the MPS with an image, namely $\psi(v)$:

1. (SLOW) explicitly creates the tensor network of the image and contracts it with the MPS;
2. uses einsums.

The two methods give the same result, but the second is much faster.

In [None]:
%%time

# Slow path: build the image's tensor network and fully contract it with the MPS via `@`.
slow_psi = quimb_transform_img2state(train_set[0]) @ mps

In [None]:
%%time 

fast_psi = computepsi(mps, train_set[0])

### 3. Learning

In [None]:
# Start learning from a fresh MPS (bdim=30, presumably the bond dimension)
# and a tiny training batch made of the first two training images.
mps = initialize_mps(bdim=30)
imgs = train_set[:2]


In [None]:
# Compute probability of the first image of the training set of the untrained network.
# Born rule |psi(v)|^2: identical to psi**2 for real amplitudes, still correct if they are complex.
abs(computepsi(mps, imgs[0]))**2

In [None]:
# One training pass of the MPS on the batch `imgs`; the return value is not kept,
# so `mps` is presumably updated in place (if so, re-running this cell trains it again).
# TODO(review): the positional arguments 1 and 0.1 are undocumented here — confirm their
# meaning (e.g. batch size / learning rate) in TNutils.learning_epoch and pass them by keyword.
learning_epoch(mps, imgs, 1, 0.1)


In [None]:
# After training, |psi(v)|^2 can exceed one, so re-normalize the MPS.
# Dividing by its norm gives a unit-norm state; re-running this cell leaves it unchanged.
mps = mps / mps.norm()


In [None]:
# Compute probability of the first image of the training set of the trained network.
# Born rule |psi(v)|^2: identical to psi**2 for real amplitudes, still correct if they are complex.
abs(computepsi(mps, imgs[0]))**2


In [None]:
# Compute probability of random binary noise under the trained network.
# A seeded generator keeps the result reproducible on Restart & Run All, and the noise
# length is taken from a real (flattened) training image instead of hard-coding 784.
noise_rng = np.random.default_rng(0)  # any fixed seed works; it only pins the sample
noise_img = noise_rng.integers(0, 2, size=len(imgs[0]))
abs(computepsi(mps, noise_img))**2


In [None]:
# Negative log-likelihood of the training batch under the current (normalized) MPS.
computeNLL(mps, imgs)

In [None]:
# Probability |psi(v)|^2 of every image in the training batch under the trained MPS.
# Element 0 is the first-image probability, so this supersedes a single-image repeat.
[abs(computepsi(mps, img))**2 for img in imgs]

### 4. Generation

In [None]:
gen = generate_sample(mps)

# Per the author's note, generate_sample skips the last step of its generation loop, so the
# first pixel is missing and is padded with 0. Pad only when it is actually missing, so this
# cell keeps working (instead of shifting the image) if generate_sample is ever fixed.
n_pixels = len(train_set[0])  # length of one flattened image
if len(gen) < n_pixels:
    gen.appendleft(0)
assert len(gen) == n_pixels, f'generated sample has {len(gen)} pixels, expected {n_pixels}'

# Hand plot_img a plain array, like the dataset rows it is used with elsewhere.
plot_img(np.asarray(gen))