In [None]:
#!pip install quimb -U
#!pip install autoray
#!pip install --upgrade tbb

In [2]:
#################
#### IMPORTS ####
#################

# Arrays
import numpy as np
import cytoolz

# Deep Learning stuff
import torch
import torchvision
import torchvision.transforms as transforms

# Images display and plots
import matplotlib.pyplot as plt

# Fancy progress bars
import tqdm.notebook as tq

# Tensor Network Stuff
%config InlineBackend.figure_formats = ['svg']
import quimb.tensor as qtn # Tensor Network library
import quimb

import collections
import opt_einsum as oe
import itertools

from TNutils import *



In [None]:
import TNutils
import importlib
importlib.reload(TNutils)
from TNutils import *

In [3]:
# Load the train/test splits. get_data() is not defined in this notebook;
# presumably it comes from `from TNutils import *` -- confirm.
# The captured output of this cell shows it downloading MNIST into
# classifier_data/MNIST/raw.
# TODO: Get full dataset
train_set, test_set = get_data()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to classifier_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting classifier_data/MNIST/raw/train-images-idx3-ubyte.gz to classifier_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to classifier_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting classifier_data/MNIST/raw/train-labels-idx1-ubyte.gz to classifier_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to classifier_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting classifier_data/MNIST/raw/t10k-images-idx3-ubyte.gz to classifier_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to classifier_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting classifier_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to classifier_data/MNIST/raw



In [37]:
import time
from matplotlib.ticker import MaxNLocator

def plot_nll(nlls,baseline):
    """Plot a negative log-likelihood curve together with a horizontal baseline.

    Draws on the current matplotlib Axes (creating one if needed); the caller
    is responsible for calling ``plt.show()``.

    Parameters
    ----------
    nlls : sequence of float
        NLL values, plotted against their index.
    baseline : float
        Reference value drawn as a dashed red horizontal line.
    """
    ax = plt.gca()
    ax.plot(nlls)
    ax.set_title('Negative log-likelihood')
    ax.axhline(baseline, color='r', linestyle='dashed')
    ax.legend(['training set', 'baseline'])
    ax.set_xlabel('epochs')
    ax.set_ylabel(r'$\mathcal{L}$')
    # The pyplot module has no `xaxis` attribute; the tick locator must be set
    # on the Axes' x-axis. Integer ticks match the "epochs" x-label.
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))


def training_and_probing(
    period_epochs,
    periods,
    mps,
    _imgs,
    img_cache,
    batch_size,
    initial_lr = 0.5,
    lr_update = lambda x: x/2,
    val_imgs = None,
    period_samples = 0,
    corrupted_set = None,
    plot = False,
    return_val_costs = False,
    **kwargs):
    """Train ``mps`` in ``periods`` rounds, updating the learning rate between rounds.

    Each round calls ``learning_epoch_cached`` once with ``period_epochs`` and
    the current learning rate, then replaces the learning rate with
    ``lr_update(lr)``. Optionally tracks validation NLL, redraws the training
    curve, and draws samples after every round.

    Parameters
    ----------
    period_epochs : int
        Passed to ``learning_epoch_cached`` every round (presumably the number
        of epochs per round -- confirm in TNutils).
    periods : int
        Number of training rounds.
    mps :
        The tensor network being trained. It is never reassigned here, so
        ``learning_epoch_cached`` is assumed to update it in place -- confirm.
    _imgs : array-like
        Training images. ``log(len(_imgs))`` is drawn as the plot baseline.
    img_cache :
        Cache passed to ``computeNLL_cached`` and ``learning_epoch_cached``
        (built with ``left_right_cache`` in this notebook).
    batch_size : int
        Forwarded to ``learning_epoch_cached`` as ``batch_size``.
    initial_lr : float
        Learning rate used for the first round.
    lr_update : callable
        Maps the current learning rate to the next one; applied after each round.
    val_imgs : optional
        Validation images. When not ``None``, ``computeNLL`` is evaluated
        before training and after every round.
    period_samples : int
        Number of ``generate_sample(mps)`` calls made after each round.
    corrupted_set : optional
        Currently unused; kept for interface compatibility.
    plot : bool
        If True, redraw the training NLL curve after every round (pausing 2 s).
    return_val_costs : bool
        If True, also return the validation costs (an empty list when
        ``val_imgs`` is ``None``).
    **kwargs
        Forwarded to ``learning_epoch_cached`` (e.g. ``max_bond``).

    Returns
    -------
    train_costs : list
        The initial training NLL followed by the costs returned by each
        ``learning_epoch_cached`` call.
    samples : list
        All generated samples, in generation order.
    val_costs : list
        Only returned when ``return_val_costs`` is True.
    """
    # Training NLL before any update.
    train_costs = [computeNLL_cached(mps, _imgs, img_cache, 0)]

    # Use `is not None` rather than truthiness: image sets in this notebook are
    # NumPy arrays, and `if array:` raises ValueError for more than one element.
    has_val = val_imgs is not None
    # TODO: adapt computeNLL to tneinsum3
    val_costs = [computeNLL(mps, val_imgs, 0)] if has_val else []

    samples = []
    lr = initial_lr

    for period in range(periods):
        costs = learning_epoch_cached(mps, _imgs, period_epochs, lr, img_cache, batch_size=batch_size, **kwargs)
        train_costs.extend(costs)
        if has_val:
            val_costs.append(computeNLL(mps, val_imgs, 0))
        lr = lr_update(lr)
        if plot:
            plot_nll(train_costs, np.log(len(_imgs)))
            plt.show()
            # Short pause so each redrawn figure is visible before training resumes.
            time.sleep(2)

        # TODO: friendlier sampling strategy
        samples.extend(generate_sample(mps) for _ in range(period_samples))

    if return_val_costs:
        return train_costs, samples, val_costs
    return train_costs, samples
        
        


In [None]:
initial_bdim = 16   # passed to initialize_mps as `bdim`
train_size = 1000   # number of training images taken from train_set
scale = True        # rescale every MPS tensor right after initialization

# Convert each training image with tens_picture and stack the results
_imgs = np.array([tens_picture(img) for img in train_set[:train_size]])

# Initialize MPS; _imgs.shape[1] is presumably the number of sites (one per
# converted pixel) -- confirm against initialize_mps in TNutils.
mps = initialize_mps(_imgs.shape[1], bdim=initial_bdim)

# Scale tensors so each one's largest-magnitude entry is 1. Dividing by the
# signed max does not bound magnitudes when entries can be negative (and can
# flip signs); abs() keeps this backend-agnostic (NumPy or torch data).
# Modifying `ten` directly avoids assuming mps.tensors is in site order.
if scale:
    for ten in mps.tensors:
        ten.modify(data=ten.data / abs(ten.data).max())

# Initialize the cache
img_cache = left_right_cache(mps, _imgs)

In [None]:
# Training schedule: `periods` rounds, each calling learning_epoch_cached with
# `period_epochs`; the learning rate starts at 0.08 and is halved after every round.
period_epochs = 5
periods = 10
batch_size = 200
# Not a named parameter of training_and_probing -- it is forwarded through
# **kwargs to learning_epoch_cached (presumably a bond-dimension cap; confirm).
max_bond = 500

nlls, samples = training_and_probing(
    period_epochs, periods, mps, _imgs, img_cache, batch_size,
    initial_lr=0.08,
    lr_update=lambda lr: lr / 2,
    val_imgs=None,
    period_samples=0,
    corrupted_set=None,
    plot=True,
    max_bond=max_bond,
)
