In [1]:
#################
#### IMPORTS ####
#################

# Arrays
import numpy as np
import cytoolz

# Deep Learning stuff
import torch
import torchvision
import torchvision.transforms as transforms

# Images display and plots
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import ListedColormap
import matplotlib.pylab as pl

# Fancy progress bars
import tqdm.notebook as tq

# Tensor Network Stuff
%config InlineBackend.figure_formats = ['svg']
import quimb.tensor as qtn # Tensor Network library
import quimb

import collections
import opt_einsum as oe
import itertools
import copy


In [2]:
import sys, os
sys.path.insert(0, '../')
# My functions
from TNutils import *


In [7]:
def get_data(train_size = 1000, test_size = 100, grayscale_threshold = .5,
             seed = None, root = 'classifier_data'):
    '''
    Prepare the MNIST dataset for the training algorithm:
     * Choose randomly a subset from the whole dataset
     * Flatten each image to mirror the mps structure
     * Normalize images from [0,255] to [0,1]
     * Apply a threshold to each pixel: values above the threshold are set
       to 1, all the others to 0. This algorithm only deals with binary
       states {0,1} instead of a continuous range from 0 to 1.

    Parameters
    ----------
    train_size : int
        Number of images in the returned training set.
    test_size : int
        Number of images in the returned test set. Sampling is without
        replacement, so test images never overlap with training images.
    grayscale_threshold : float
        Binarization threshold, must lie in the open interval ]0,1[.
    seed : int or None
        Seed for the random subset selection. With ``None`` (default) the
        subset is drawn from NumPy's global random state (``np.random``).
    root : str
        Directory passed to ``torchvision.datasets.MNIST`` for download/caching.

    Returns
    -------
    train, test : np.ndarray
        Float arrays of shape ``(train_size, H*W)`` and ``(test_size, H*W)``
        whose entries are 0. or 1.

    Raises
    ------
    ValueError
        If ``grayscale_threshold`` is not in ]0,1[ (checked before any
        download), or if ``train_size + test_size`` exceeds the dataset size.
    '''
    # Validate cheap arguments first so a bad call fails before the slow download.
    if (grayscale_threshold <= 0) or (grayscale_threshold >= 1):
        raise ValueError('grayscale_threshold must be in range ]0,1[')

    # Download (or load from `root`) the MNIST training split. Only the raw
    # `.data` tensor is used below, so no torchvision transform is needed.
    mnist = torchvision.datasets.MNIST(root, train=True, download=True)

    # Convert torch.tensor to numpy
    images = mnist.data.numpy()

    # Randomly pick train_size + test_size distinct images.
    n_total = train_size + test_size
    if seed is None:
        subset_indexes = np.random.choice(images.shape[0], size=n_total, replace=False)
    else:
        rng = np.random.default_rng(seed)
        subset_indexes = rng.choice(images.shape[0], size=n_total, replace=False)
    images = images[subset_indexes]

    # Flatten every image: (n, H, W) -> (n, H*W). Explicit sizes (not -1) keep
    # this valid for an empty subset.
    images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])

    # Normalize from 0 - 255 to 0 - 1 with a fixed scale: dividing by the
    # subset maximum would make the threshold depend on the sampled images
    # and would raise on an empty subset.
    images = images / 255.0

    # As in the paper, we only deal with {0,1} values, not a range.
    images = (images > grayscale_threshold).astype(np.float64)

    # Return training set and test set
    return images[:train_size], images[train_size:]


In [8]:
# Load all 60 000 MNIST training images (random order, binarized); no test
# split is drawn here - validation data is carved out of train_set later.
train_set, test_set = get_data(test_size = 0, train_size = 60000)


In [9]:
# Report the shape of each split, one line per split.
for _template, _split in (('Shape of the training set: {}', train_set),
                          ('Shape of the test set:     {}', test_set)):
    print(_template.format(_split.shape))


Shape of the training set: (60000, 784)
Shape of the test set:     (0, 784)


In [10]:
size_train = 10_000  # number of training images used for the experiments below

In [11]:
# First `size_train` images are used for training, the following 100 for validation.
data = train_set[:size_train]
val_data = train_set[size_train:size_train + 100]

In [12]:
# Wrap every training image with TNutils' `tens_picture`.
# NOTE(review): only the first 6 pixels of each image are kept (`[:6]`),
# presumably to keep the benchmarks below fast - confirm this is intended.
_imgs = np.array([tens_picture(pixels[:6]) for pixels in data])

In [13]:
mps = initialize_mps(
    _imgs.shape[1],  # presumably the number of MPS sites (one per kept pixel) - TODO confirm
    bdim=500,
)

In [14]:
# Build the cache of `mps` against every wrapped image
# (presumably left/right partial contractions, per the helper's name).
img_cache = left_right_cache(mps, _imgs)

In [3]:
from dask.distributed import Client

# Local dask cluster used by the delayed-extraction experiments below.
# The worker count is capped at the available cores so smaller machines are
# not oversubscribed (the recorded run shows 16 single-threaded workers).
client = Client(n_workers=min(16, os.cpu_count() or 1))

In [4]:
# Display the dask Client summary (scheduler, dashboard link, workers).
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 16
Total threads: 16,Total memory: 15.40 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:54084,Workers: 16
Dashboard: http://127.0.0.1:8787/status,Total threads: 16
Started: Just now,Total memory: 15.40 GiB

0,1
Comm: tcp://127.0.0.1:54192,Total threads: 1
Dashboard: http://127.0.0.1:54193/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54097,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-oqlvcuzv,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-oqlvcuzv

0,1
Comm: tcp://127.0.0.1:54202,Total threads: 1
Dashboard: http://127.0.0.1:54204/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54087,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-xwzy181z,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-xwzy181z

0,1
Comm: tcp://127.0.0.1:54219,Total threads: 1
Dashboard: http://127.0.0.1:54220/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54102,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-vk8g2g8g,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-vk8g2g8g

0,1
Comm: tcp://127.0.0.1:54210,Total threads: 1
Dashboard: http://127.0.0.1:54211/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54098,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-1fudxk7t,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-1fudxk7t

0,1
Comm: tcp://127.0.0.1:54181,Total threads: 1
Dashboard: http://127.0.0.1:54183/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54089,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-2vnuh2do,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-2vnuh2do

0,1
Comm: tcp://127.0.0.1:54195,Total threads: 1
Dashboard: http://127.0.0.1:54197/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54101,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-essueu0z,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-essueu0z

0,1
Comm: tcp://127.0.0.1:54201,Total threads: 1
Dashboard: http://127.0.0.1:54203/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54088,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-279t5qc7,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-279t5qc7

0,1
Comm: tcp://127.0.0.1:54213,Total threads: 1
Dashboard: http://127.0.0.1:54214/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54092,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-f7ottelo,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-f7ottelo

0,1
Comm: tcp://127.0.0.1:54222,Total threads: 1
Dashboard: http://127.0.0.1:54223/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54093,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-3uo4ur23,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-3uo4ur23

0,1
Comm: tcp://127.0.0.1:54196,Total threads: 1
Dashboard: http://127.0.0.1:54198/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54094,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-h9l0ycfu,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-h9l0ycfu

0,1
Comm: tcp://127.0.0.1:54225,Total threads: 1
Dashboard: http://127.0.0.1:54226/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54096,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-sr2hpxth,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-sr2hpxth

0,1
Comm: tcp://127.0.0.1:54228,Total threads: 1
Dashboard: http://127.0.0.1:54229/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54099,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-gue21xyt,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-gue21xyt

0,1
Comm: tcp://127.0.0.1:54189,Total threads: 1
Dashboard: http://127.0.0.1:54190/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54090,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-d4am56dt,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-d4am56dt

0,1
Comm: tcp://127.0.0.1:54216,Total threads: 1
Dashboard: http://127.0.0.1:54217/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54095,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-skr0kd79,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-skr0kd79

0,1
Comm: tcp://127.0.0.1:54207,Total threads: 1
Dashboard: http://127.0.0.1:54208/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54100,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-kt5rn6l7,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-kt5rn6l7

0,1
Comm: tcp://127.0.0.1:54182,Total threads: 1
Dashboard: http://127.0.0.1:54184/status,Memory: 0.96 GiB
Nanny: tcp://127.0.0.1:54091,
Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-ldmryluk,Local directory: C:\Users\Saverio\Documents\GitHub\mps_born_machines\tests\dask-worker-space\worker-ldmryluk


In [15]:
import dask as ds

def _into_data(tensor_array):
    '''
    Extract ``.data`` from every element of ``tensor_array`` through dask.

    One delayed task is built per element and all of them are evaluated in a
    single ``compute`` call on dask's currently configured scheduler.

    Returns a list with one payload per input element, in input order.
    '''
    def _payload(tensor):
        return tensor.data

    pending = [ds.delayed(_payload)(tensor) for tensor in tensor_array]
    return list(ds.compute(*pending))

In [18]:
def datize(x):
    '''Return ``x.data``; meant to be mapped over arrays of tensors.'''
    payload = getattr(x, 'data')
    return payload

In [21]:
%%timeit
# Time the dask-based `.data` extraction on one cache slice (compare with the plain `map` cell below).
_r = _into_data(img_cache[:,0,3])

5.02 s ± 129 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [22]:
%%timeit
# Baseline: extract `.data` from the same slice with a plain Python `map` (no dask).
r = np.array(list(map(datize,img_cache[:,0,3])))

4.39 ms ± 254 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
%%timeit
tneinsum3(img_cache[:,0,3],img_cache[:,1,4])


64.4 ms ± 2.77 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [24]:
%%timeit
# Site pair (4, 5) contracted with the explicit 'numpy' backend of `tneinsum3`.
tneinsum3(img_cache[:,0,4],img_cache[:,1,4+1],backend = 'numpy')

59.6 ms ± 4.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [25]:
%%timeit
# Site pair (4, 5) contracted with the explicit 'torch' backend of `tneinsum3`.
tneinsum3(img_cache[:,0,4],img_cache[:,1,4+1],backend = 'torch')

63.2 ms ± 2.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [26]:
%%timeit
# Site pair (4, 5) contracted with the explicit 'tensorflow' backend of `tneinsum3`.
tneinsum3(img_cache[:,0,4],img_cache[:,1,4+1],backend = 'tensorflow')

71.5 ms ± 21.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
%%timeit
for i in range(1,201):
    tneinsum3(img_cache[:,0,i],img_cache[:,1,i+1],backend = 'jax')

ModuleNotFoundError: No module named 'jax'

In [27]:
def _into_data(tensor_array, dtype = np.float32):
    '''
    Extract ``.data`` from every element of ``tensor_array`` through dask,
    casting each payload with ``.astype(dtype)``.

    NOTE(review): this cell re-defines (and shadows) the `_into_data` from an
    earlier cell; keep a single definition once the benchmark is settled.

    Parameters
    ----------
    tensor_array : iterable
        Objects exposing a ``.data`` array (here: slices of ``img_cache``).
    dtype : numpy dtype, default np.float32
        Target dtype of every extracted array.

    Returns
    -------
    list
        One cast array per input element, in input order.
    '''
    def _cast_payload(tensor):
        return tensor.data.astype(dtype)

    # One delayed callable reused for every element (instead of a fresh lambda
    # per iteration); all tasks are evaluated in a single dask.compute call.
    delayed_cast = ds.delayed(_cast_payload)
    lazy_payloads = [delayed_cast(tensor) for tensor in tensor_array]
    return list(ds.compute(*lazy_payloads))


In [30]:
%%timeit
# Re-time the dask-based extraction using the float32-casting `_into_data` defined just above.
_r = _into_data(img_cache[:,0,3])


67.7 ms ± 3.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [31]:
# NOTE(review): identical re-definition of the `datize` helper from an earlier
# cell; harmless, but a single definition would be cleaner.
def datize(x):
    '''Return the ``data`` attribute of ``x`` (used with ``map`` below).'''
    return x.data

In [32]:
%%timeit
# Plain `map` baseline on the same slice, to compare against the dask timing above.
r = np.array(list(map(datize,img_cache[:,0,3])))


263 µs ± 26.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
