## Global update with `Dask`

`Dask` is a Python framework for parallelizing computations. To install the extra dependencies (not yet listed in `pyproject.toml`), run:

In [None]:
%%bash
python -m pip install dask distributed --upgrade

`graphviz` (or `cytoscape`) is also needed to visualize the task graph.

In [None]:
from dask.distributed import Client, LocalCluster

# Create a Dask client with default arguments; subsequent dask.compute calls in
# this notebook run while it is active.
client = Client()
client # Displays the information of the local client (the output includes the dashboard link)

This automatically creates a Dask `Client` configured according to our system's specifications. Follow the link in the output of the previous cell to open the Dask dashboard, which shows the task stream and other useful information live during the computation.

Imports:

In [None]:
%config InlineBackend.figure_format = 'svg'

from multiprocessing import pool
import os
import sys
import json 
import time
import dask
from dask.distributed import Client, wait, performance_report # LocalCluster
from dask_jobqueue import SLURMCluster
import itertools
import numpy as np
import tnad.FeatureMap as fm
from tnad.losses import loss_miss, loss_reg
from tnad.gradients import gradient_miss, gradient_reg
from tnad.optimization import load_mnist_train_data, data_preprocessing
from tnad import smpo
import tnad.procedures as p
import math
import quimb.tensor as qtn
import quimb as qu
import matplotlib.pyplot as plt
from tqdm import tqdm

Functions from `procedures.py` adapted to the Dask framework:

In [None]:
def local_update_sweep_dyncanonization_renorm(P, n_epochs, n_iters, data, batch_size, alpha, lamda_init, lamda_init_2, bond_dim, decay_rate=None, expdecay_tol=None):
    """Train ``P`` with two-site sweep updates, evaluating per-sample terms with Dask.

    For every batch ``data[it]`` the function sweeps over neighbouring site
    pairs, first left-to-right and then right-to-left. At each pair it merges
    the two site tensors and takes a gradient step on the merged tensor, using
    the ``loss_miss``/``gradient_miss`` terms summed over the batch. It then
    normalizes the result and splits it back into two tensors truncated to
    ``bond_dim``.

    Args:
        P: operator to train (``train_SMPO`` builds it with
            ``smpo.SpacedMatrixProductOperator.rand``); modified in place.
        n_epochs: number of passes over the batches.
        n_iters: number of batches per epoch; ``data[0] .. data[n_iters-1]`` are used.
        data: indexable of batches; each ``data[it]`` is iterated sample by sample.
        batch_size: normalization factor applied to the summed loss and gradient.
        alpha: regularization weight. Unused, because the regularization
            gradient is disabled in this procedure.
        lamda_init: learning rate before epoch ``expdecay_tol``. It applies to
            every epoch when ``expdecay_tol`` is None.
        lamda_init_2: learning rate from epoch ``expdecay_tol`` on. After that
            epoch it is multiplied by ``(1 - decay_rate/100) ** epoch`` when
            ``decay_rate`` is given, and stays constant when ``decay_rate``
            is None.
        bond_dim: maximum bond dimension kept when splitting the updated tensor.
        decay_rate: exponential decay rate of the learning rate, in percent.
        expdecay_tol: epoch at which the schedule switches to ``lamda_init_2``.

    Returns:
        ``(P, loss_array, loss_reg_array, loss_miss_array)``, the same 4-tuple
        as ``global_update_costfuncnorm``, so ``train_SMPO`` can unpack it.
        One loss value is recorded per pair update. ``loss_reg_array`` is empty
        because no regularization loss is computed, and ``loss_miss_array`` is
        a copy of ``loss_array``.

    Raises:
        ValueError: if a batch contains no samples.
    """
    N_features = P.nsites

    # Sweep order (loop invariant): (0,1), (1,2), ..., (N-2,N-1), then
    # (N-1,N-2), ..., (1,0).
    forward_pairs = list(zip(range(0, N_features - 1), range(1, N_features)))
    backward_pairs = list(reversed(list(zip(range(1, N_features), range(0, N_features - 1)))))
    sweeps = forward_pairs + backward_pairs

    loss_array = []
    for epoch in range(n_epochs):
        # Learning-rate schedule; it depends on the epoch only.
        if expdecay_tol is None or epoch < expdecay_tol:
            lamda = lamda_init
        elif decay_rate is None or epoch == expdecay_tol:
            lamda = lamda_init_2
        else:
            # NOTE(review): the exponent is the absolute epoch, not
            # epoch - expdecay_tol (kept from the original schedule); confirm intent.
            lamda = lamda_init_2 * math.pow(1 - decay_rate / 100, epoch)

        for it in (pbar := tqdm(range(n_iters))):
            pbar.set_description("Epoch #"+str(epoch)+", sample in batch:")
            for sites in sweeps:
                sitel, siter = sites
                site_tags = [P.site_tag(site) for site in sites]
                # Canonize P around the two active sites.
                # NOTE(review): cur_orthog=sites tells quimb P is already canonical
                # there, which may make this call a no-op; confirm intent.
                P.canonize(sites, cur_orthog=sites)
                # Unmodified reference copy used for the loss/gradient.
                P_ref = P.copy(deep=True)
                # Merge the two site tensors into a single tensor.
                origl, origr = P.select_tensors(site_tags, which="any")
                tensor_orig = origl & origr ^ all
                # Remember the bond between the two sites so the split can reuse it.
                bond_ind_removed = P.bond(site_tags[0], site_tags[1])

                # Outer virtual bond of the "left" tensor of the pair (none at the chain ends).
                if sitel == 0 or (sitel == N_features-1 and sitel > siter):
                    vindl = []
                elif sitel > 0 and sitel < siter:
                    vindl = [P.bond(sitel-1, sitel)]
                else:
                    vindl = [P.bond(sitel, sitel+1)]
                # Outer virtual bond of the "right" tensor of the pair.
                if siter == N_features - 1 or (siter == 0 and siter < sitel):
                    vindr = []
                elif siter < N_features-1 and siter > sitel:
                    vindr = [P.bond(siter, siter+1)]
                else:
                    vindr = [P.bond(siter-1, siter)]

                # Remove the two popped sites from P.
                P.delete(site_tags, which="any")

                loss_terms = []
                grad_terms = []
                for sample in data[it]:
                    phi, _ = fm.embed(sample.flatten(), fm.trigonometric)
                    loss_terms.append(dask.delayed(loss_miss)(phi, P_ref))
                    grad_terms.append(dask.delayed(gradient_miss)(phi, P_ref, P, sites))
                if not loss_terms:
                    raise ValueError(f"batch {it} contains no samples")
                # A single compute call evaluates losses and gradients in one graph.
                loss_values, grad_values = dask.compute(loss_terms, grad_terms)
                loss_miss_batch = sum(loss_values)
                grad_miss = sum(grad_values)

                loss = (1/batch_size)*(loss_miss_batch)
                loss_array.append(loss)

                # Gradient of the miss loss, tagged like the merged two-site tensor.
                grad_miss.drop_tags()
                grad_miss.add_tag(site_tags[0])
                grad_miss.add_tag(site_tags[1])
                total_grad = (1/batch_size)*grad_miss

                # Gradient step, then renormalize.
                tensor_new = tensor_orig - lamda*total_grad
                tensor_new.normalize(inplace=True)

                # Split the updated tensor back into two site tensors.
                lower_ind = [f'b{sitel}'] if f'b{sitel}' in P.lower_inds else []
                tensorl, tensorr = tensor_new.split(get="tensors", left_inds=[*vindl, P.upper_ind(sitel), *lower_ind], bond_ind=bond_ind_removed, max_bond=bond_dim)

                # Re-insert the new tensors into P under their site tags.
                for site, tensor in zip(sites, (tensorl, tensorr)):
                    tensor.drop_tags()
                    tensor.add_tag(P.site_tag(site))
                    P.add_tensor(tensor)

    return P, loss_array, [], list(loss_array)

                    
def get_sample_grad(sample, embed_func, P, P_rem, tensor):
    """Return the ``gradient_miss`` contribution of a single sample.

    ``sample`` is flattened and embedded with ``fm.embed(..., embed_func)``.
    The embedded state is passed to ``gradient_miss`` together with ``P``,
    ``P_rem`` and the site index ``tensor`` (wrapped in a list).
    """
    flat_sample = sample.flatten()
    embedded, _ = fm.embed(flat_sample, embed_func)
    return gradient_miss(embedded, P, P_rem, [tensor])

def get_sample_loss(sample, embed_func, P):
    """Return ``loss_miss`` for a single sample under the operator ``P``.

    The sample is flattened and embedded with ``fm.embed(..., embed_func)``
    before evaluating the loss.
    """
    flat_sample = sample.flatten()
    embedded, _ = fm.embed(flat_sample, embed_func)
    return loss_miss(embedded, P)

def get_total_grad(P, tensor, data, embed_func, batch_size, alpha):
    """Return the total gradient of the loss w.r.t. the site tensor ``tensor``.

    The total is ``(1/batch_size) * sum_samples(gradient_miss) + gradient_reg``.
    Both gradients are retagged with the site's tag. The per-sample loop here
    runs sequentially; the caller is expected to parallelize over sites by
    wrapping this whole function in ``dask.delayed``.

    Args:
        P: the full operator.
        tensor: integer site index whose gradient is computed.
        data: iterable of samples for the current batch.
        embed_func: feature map forwarded to ``fm.embed``.
        batch_size: normalization factor for the summed miss gradient.
        alpha: regularization weight forwarded to ``gradient_reg``.
    """
    # Copy of P with the selected site removed.
    site_tag = None
    P_rem = P.copy(deep=True)
    site_tag = P_rem.site_tag(tensor)
    P_rem.delete(site_tag, which="any")

    # Sum of the per-sample miss gradients, tagged with the site.
    miss_total = sum(
        get_sample_grad(sample, embed_func, P, P_rem, tensor) for sample in data
    )
    miss_total.drop_tags()
    miss_total.add_tag(site_tag)

    # The regularization gradient may be the scalar 0; only retag an actual tensor.
    reg_total = gradient_reg(P, P_rem, alpha, [tensor])
    if reg_total != 0:
        reg_total.drop_tags()
        reg_total.add_tag(site_tag)

    return (1/batch_size)*(miss_total) + reg_total

def global_update_costfuncnorm(P, n_epochs, n_iters, data, batch_size, alpha, lamda_init, lamda_init_2, bond_dim, decay_rate=None, expdecay_tol=None):
    """Train ``P`` by updating every site tensor at once from the same batch.

    For each batch ``data[it]`` the gradient of every site tensor
    (``get_total_grad``) and the losses (``get_sample_loss`` and ``loss_reg``)
    are all evaluated against the current ``P`` in a single ``dask.compute``
    call. Every site tensor is then updated in place with a gradient step.

    Args:
        P: operator to train; its site tensors are modified in place.
        n_epochs: number of passes over the batches.
        n_iters: number of batches per epoch; ``data[0] .. data[n_iters-1]`` are used.
        data: indexable of batches; each ``data[it]`` is iterated sample by sample.
        batch_size: normalization factor for the summed miss loss and gradient.
        alpha: regularization weight forwarded to ``loss_reg`` and ``get_total_grad``.
        lamda_init: learning rate before epoch ``expdecay_tol``. It applies to
            every epoch when ``expdecay_tol`` is None.
        lamda_init_2: learning rate from epoch ``expdecay_tol`` on. After that
            epoch it is multiplied by ``(1 - decay_rate/100) ** epoch`` when
            ``decay_rate`` is given, and stays constant when ``decay_rate``
            is None.
        bond_dim: unused by this procedure. It is accepted so the signature
            matches the local-update procedure.
        decay_rate: exponential decay rate of the learning rate, in percent.
        expdecay_tol: epoch at which the schedule switches to ``lamda_init_2``.

    Returns:
        ``(P, loss_array, loss_reg_array, loss_miss_array)``, each list holding
        one float per batch update: total loss, regularization loss, and miss
        loss (normalized by ``batch_size``).
    """
    loss_array = []
    loss_reg_array = []
    loss_miss_array = []
    n_tensors = P.nsites
    embed_func = fm.trigonometric

    for epoch in range(n_epochs):
        # Learning-rate schedule; it depends on the epoch only.
        if expdecay_tol is None or epoch < expdecay_tol:
            lamda = lamda_init
        elif decay_rate is None or epoch == expdecay_tol:
            lamda = lamda_init_2
        else:
            # NOTE(review): the exponent is the absolute epoch, not
            # epoch - expdecay_tol (kept from the original schedule); confirm intent.
            lamda = lamda_init_2 * math.pow(1 - decay_rate / 100, epoch)

        for it in (pbar := tqdm(range(n_iters))):
            pbar.set_description("Epoch #"+str(epoch)+", sample in batch:")
            batch = data[it]

            # Wrap the large inputs once. The graph then holds a single copy of P
            # and of the batch, instead of embedding them in every task.
            P_delayed = dask.delayed(P, traverse=False)
            batch_delayed = dask.delayed(batch, traverse=False)

            # One gradient task per site tensor.
            grad_tasks = [
                dask.delayed(get_total_grad)(P_delayed, tensor, batch_delayed, embed_func, batch_size, alpha)
                for tensor in range(n_tensors)
            ]
            # One miss-loss task per sample, plus one regularization-loss task.
            sample_loss_tasks = [
                dask.delayed(get_sample_loss)(sample, embed_func, P_delayed)
                for sample in batch
            ]
            reg_task = dask.delayed(loss_reg)(P_delayed, alpha)

            # Everything is evaluated against the same (not yet updated) P in one graph.
            grad_per_tensor, sample_losses, loss_reg_value = dask.compute(grad_tasks, sample_loss_tasks, reg_task)

            loss_miss_value = (1/batch_size)*sum(sample_losses)
            loss_array.append(loss_miss_value + loss_reg_value)
            loss_reg_array.append(loss_reg_value)
            loss_miss_array.append(loss_miss_value)

            # Gradient step on every site tensor, in place (sequential, cheap).
            for tensor in range(n_tensors):
                tensor_orig = P.select_tensors(P.site_tag(tensor), which="any")[0]
                grad = grad_per_tensor[tensor].transpose_like(tensor_orig)
                tensor_orig.modify(data=tensor_orig.data - lamda*grad.data)

    return P, loss_array, loss_reg_array, loss_miss_array

def train_SMPO(data, spacing, n_epochs, alpha, opt_procedure, lamda_init=2e-5, lamda_init_2=2e-3, decay_rate=None, expdecay_tol=None, bond_dim=4, init_func='normal', scale=0.5, batch_size=32, seed: int = None):
    """Initialize a random SMPO and train it with ``opt_procedure``.

    Args:
        data: samples indexed as ``data[i, row, col]``. The number of SMPO
            sites is ``rows * cols``.
        spacing, bond_dim, init_func, scale, seed: forwarded to
            ``smpo.SpacedMatrixProductOperator.rand``.
        n_epochs, alpha, lamda_init, lamda_init_2, decay_rate, expdecay_tol:
            forwarded to ``opt_procedure``.
        opt_procedure: training function with the signature of
            ``global_update_costfuncnorm``.
        batch_size: number of samples per batch. Trailing samples that do not
            fill a complete batch are dropped.

    Returns:
        ``(P, loss_array, loss_reg_array, loss_miss_array)`` as returned by
        ``opt_procedure``.

    Raises:
        ValueError: if there are fewer samples than ``batch_size``.
    """
    train_data = np.array(data)
    N_features = train_data.shape[1]*train_data.shape[2]
    n_samples = train_data.shape[0]
    n_iters = n_samples // batch_size
    if n_iters == 0:
        raise ValueError(f"batch_size={batch_size} exceeds the number of samples ({n_samples})")

    # n_iters batches of exactly batch_size consecutive samples. The previous
    # np.split(train_data, batch_size) made batch_size batches instead.
    train_data_batched = train_data[:n_iters*batch_size].reshape(n_iters, batch_size, *train_data.shape[1:])

    # Initialize P (freshly created, so no defensive copy is needed).
    P = smpo.SpacedMatrixProductOperator.rand(n=N_features, spacing=spacing, bond_dim=bond_dim, init_func=init_func, scale=scale, seed=seed)

    P, loss_array, loss_reg_array, loss_miss_array = opt_procedure(P, n_epochs, n_iters, train_data_batched, batch_size, alpha, lamda_init=lamda_init, lamda_init_2=lamda_init_2, bond_dim=bond_dim, decay_rate=decay_rate, expdecay_tol=expdecay_tol)
    return P, loss_array, loss_reg_array, loss_miss_array

# args: hyper-parameters of this run (also collected in `params` below)
train_size = 64
batch_size = 32
# preprocessing options forwarded to `data_preprocessing`
strides = (2,2)
pool_size = (2,2)
padding = 'same'
reduced_shape = (14,14)
opt_procedure = 'global_update_costfuncnorm' # or 'local_update_sweep_dyncanonization_renorm'
spacing = 8
n_epochs = 3
alpha = 0.4  # regularization weight
lambda_init = 2e-5  # learning rate before epoch `expdecay_tol`
lambda_init_2 = lambda_init # learning rate from epoch `expdecay_tol` on (alternative value: 2e-3)
decay_rate = 0.01  # percent; the rate is scaled by (1 - decay_rate/100)**epoch
# NOTE(review): with n_epochs = 3 < expdecay_tol = 20 the decayed schedule is never reached.
expdecay_tol = 20
bond_dim = 5
init_func = 'normal'
scale_init_p = 0.5

params = {
    'train_size': train_size, 
    'batch_size': batch_size,
    'strides': strides,
    'pool_size': pool_size,
    'padding': padding,
    'reduced_shape': reduced_shape,
    'opt_procedure': opt_procedure,
    'spacing': spacing,
    'n_epochs': n_epochs,
    'alpha': alpha,
    'lambda_init': lambda_init,
    'lambda_init_2': lambda_init_2,
    'decay_rate': decay_rate,
    'expdecay_tol': expdecay_tol,
    'bond_dim': bond_dim,
    'init_func': init_func,
    'scale_init_p': scale_init_p
    }

# Output locations (not used in the cells shown here).
# NOTE(review): the paths say "128_trainsize" while train_size = 64; confirm.
results_folder = "/gpfs/scratch/bsc21/bsc21504/tnad/v_30-09-22/output/dask/128_trainsize/results"
reports_folder = "/gpfs/scratch/bsc21/bsc21504/tnad/v_30-09-22/output/dask/128_trainsize/preports"

We can now take a look at the `train_data`:

In [None]:
# Imports repeated so this cell can run on its own.
from tnad.optimization import load_mnist_train_data, data_preprocessing
import tnad.procedures as p
import matplotlib.pyplot as plt

# Load `train_size` training images and preprocess them with the configured
# strides / pool_size / padding / reduced_shape.
train_data = load_mnist_train_data(train_size=train_size) # seed=123456
data = data_preprocessing(train_data, strides=strides, pool_size=pool_size, padding=padding, reduced_shape=reduced_shape)
# Show the first raw (unpreprocessed) training image.
plt.imshow(train_data[0,:,:], interpolation='nearest', cmap='Greys')
plt.show()

In [None]:
# Train the SMPO with the configured procedure and time the whole run
# (including data loading and preprocessing).
start_time = time.time()

train_data = load_mnist_train_data(train_size=train_size)
data = data_preprocessing(train_data, strides=strides, pool_size=pool_size, padding=padding, reduced_shape=reduced_shape)

# Resolve the configured procedure name. The function goes into `opt_fn` so that
# `opt_procedure` keeps its string value and the cell can be re-run.
opt_procedures = {
    'local_update_sweep_dyncanonization_renorm': local_update_sweep_dyncanonization_renorm,
    'global_update_costfuncnorm': global_update_costfuncnorm,
}
try:
    opt_fn = opt_procedures[opt_procedure]
except KeyError:
    raise ValueError(f"Unknown opt_procedure {opt_procedure!r}; expected one of {sorted(opt_procedures)}") from None

# Keyword arguments keep the call correct if train_SMPO's parameter order changes.
P, loss_array, loss_reg_array, loss_miss_array = train_SMPO(
    data, spacing, n_epochs, alpha, opt_fn,
    lamda_init=lambda_init,
    lamda_init_2=lambda_init_2,
    decay_rate=decay_rate,
    expdecay_tol=expdecay_tol,
    bond_dim=bond_dim,
    init_func=init_func,
    scale=scale_init_p,
    batch_size=batch_size,
)
# The procedures return plain floats, so dask.compute passes them through
# unchanged. It would still resolve any lazy values.
loss_array_values = dask.compute(*loss_array)

computation_time = time.time()-start_time
print("computation time: ", computation_time, flush=True)
print(loss_array_values)

delayed_time = 0.

In [None]:
# Plot the loss histories returned by the training procedure.
# One value is recorded per update step (per batch for the global update,
# per site pair for the local sweep), so the x axis counts iterations across
# all epochs, not epochs.
markersize = 4.

fig, ax = plt.subplots(figsize=(10,4))
for values, marker, label in (
    (loss_array, "o", r"total loss"),
    (loss_reg_array, "s", r"loss reg"),
    (loss_miss_array, "d", r"loss miss"),
):
    ax.plot(range(len(values)), values, marker, linestyle="--", markersize=markersize, label=label)

ax.set_xlabel("iteration", fontsize=16)
ax.set_ylabel("loss", fontsize=16)

ax.tick_params(labelsize=14)

plt.subplots_adjust(
    top=0.97,
    bottom=0.14,
    left=0.13,
    right=0.97,
    hspace=0.,
    wspace=0.
)

ax.legend(loc="best", prop={'size': 17}, handlelength=2)