In [1]:
import os
# Pin the run to one GPU: number devices by PCI bus ID so index "2" is stable,
# then expose only that device to CUDA (it becomes "cuda" / "cuda:0" in torch).
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="2")

In [2]:
import tensorly as tl
from tensorly.decomposition import parafac, quantized_parafac
from tensorly.kruskal_tensor import kruskal_to_tensor, KruskalTensor
from tensorly.base import unfold
from tensorly.quantization import quantize_qint

import torch
tl.set_backend('pytorch')

import numpy as np
from collections import defaultdict
import os
import logging
import sys

formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')


def setup_logger(name, log_file, level=logging.INFO):
    """Return the logger ``name`` configured to write to ``log_file``.

    Records are formatted with the module-level ``formatter``. The function is
    safe to call repeatedly with the same arguments (e.g. when a notebook cell
    is re-run). A ``FileHandler`` for ``log_file`` is attached only if the logger
    does not already have one. Otherwise every call would add another handler,
    duplicating each record in the file and leaking open file handles.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Path of the log file (``FileHandler`` default mode, i.e. append).
        level: Minimum level set on the logger (default ``logging.INFO``).

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # FileHandler stores its target as an absolute path in `baseFilename`.
    target = os.path.abspath(log_file)
    already_attached = any(
        isinstance(existing, logging.FileHandler) and existing.baseFilename == target
        for existing in logger.handlers
    )
    if not already_attached:
        handler = logging.FileHandler(log_file)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger

In [3]:
# Sanity check before the experiment: compare_als defaults to device='cuda',
# so this should display True on the GPU selected in the first cell.
torch.cuda.is_available()

True

# Example 3. Quantized ALS
Compare the standard ALS algorithm for computing a CP decomposition with its quantized version, in which the approximated factor is quantized at the end of each ALS step.

In [4]:
def generate_tensor(shape, rank, device=None, generator=None):
    """Generate a random full tensor as a sum of ``rank`` rank-one terms.

    Each factor matrix has i.i.d. standard-normal entries and all CP weights
    are one. The full tensor is materialised from this Kruskal representation,
    so its CP rank is at most ``rank``.

    Args:
        shape: Sequence of mode sizes; factor ``k`` has shape ``(shape[k], rank)``.
        rank: Number of rank-one components.
        device: Optional torch device for the factors, the weights and hence the
            result. ``None`` keeps the previous behaviour (torch's default device).
            Building on the GPU directly avoids materialising a large tensor on
            the CPU and copying it afterwards.
        generator: Optional ``torch.Generator`` for reproducible sampling; it
            must live on ``device``.

    Returns:
        The full tensor produced by ``kruskal_to_tensor``.
    """
    factors = [torch.randn((size, rank), generator=generator, device=device)
               for size in shape]
    weights = torch.ones(rank, device=device)

    # Tensor in Kruskal (CP) format, then converted to full format.
    krt = KruskalTensor((weights, factors))
    return kruskal_to_tensor(krt)

In [10]:
def _approx_ranks(n):
    """Sorted unique candidate ranks int(sqrt(n)), n//2, n, n + n//2, 2n, dropping non-positive ones."""
    candidates = (int(np.sqrt(n)), n // 2, n, n + n // 2, 2 * n)
    return sorted({r for r in candidates if r > 0})


def _rec_errors_for_init(t, rank, init, random_state, params_als_shared,
                         params_als, params_qals, params_qint, device, logger):
    """Compute the three reconstruction errors for one initialisation and log each as it is computed.

    Returns ``(rec_error_als, rec_error_qpost, rec_error_qals)``. Any exception
    raised by the decompositions or the quantization propagates to the caller.
    """
    ## ALS with float factors
    out = parafac(t, rank,
                  init=init,
                  random_state=random_state,
                  **params_als_shared,
                  **params_als)
    rec_error_als = tl.norm(t - kruskal_to_tensor(out))
    logger.info('||tensor - als_factors||: {}'.format(rec_error_als))

    ## Post-ALS factor quantization (quantize_qint is applied on CPU, result moved back)
    qfactors_post = [quantize_qint(factor.cpu(), **params_qint).to(device)
                     for factor in out.factors]
    qt_post = kruskal_to_tensor(KruskalTensor((out.weights, qfactors_post)))
    rec_error_qpost = tl.norm(t - qt_post)
    logger.info('||tensor - als_factors_quantized||: {}'.format(rec_error_qpost))

    ## Quantized ALS
    qout, qerrors, scales, zero_points = quantized_parafac(t, rank,
                                                           init=init,
                                                           random_state=random_state,
                                                           **params_als_shared,
                                                           **params_qals,
                                                           **params_qint)
    rec_error_qals = tl.norm(t - kruskal_to_tensor(KruskalTensor(qout)))
    logger.info('||tensor - qals_factors||: {}'.format(rec_error_qals))

    return rec_error_als, rec_error_qpost, rec_error_qals


def compare_als(shape, inits, random_states,
                params_als_shared, params_als, params_qals, params_qint,
                nruns=2, loggsfile=None, delta_rank=1, device='cuda'):
    """Compare float ALS, post-quantized ALS factors and quantized ALS on random tensors.

    For every candidate rank derived from ``shape[0]`` and every run, a random
    tensor of rank ``rank + delta_rank`` is generated. It is then approximated
    with rank ``rank`` once per ``(init, random_state)`` pair.

    Args:
        shape: Tensor shape passed to ``generate_tensor``.
        inits, random_states: Zipped pairwise; one decomposition attempt per pair.
        params_als_shared: Keyword args shared by ``parafac`` and ``quantized_parafac``.
        params_als: Extra keyword args for ``parafac`` only.
        params_qals: Extra keyword args for ``quantized_parafac`` only.
        params_qint: Quantization keyword args for ``quantize_qint`` and ``quantized_parafac``.
        nruns: Number of independent random tensors per rank.
        loggsfile: Log file path without the ``.log`` suffix. If falsy, the
            logger named after ``shape`` is used without a file handler (this
            previously raised ``NameError``).
        delta_rank: Offset between the generating rank and the approximation rank.
        device: Device the generated tensor is moved to.

    Returns:
        ``{rank: {run: {'tensor_norm', 'als_factors', 'als_factors_quantized',
        'qals_factors'}}}``. The three error lists are appended together only
        when all three methods succeed for an init, so they stay index-aligned.
    """
    logger_name = 'logger_{}'.format(shape)
    if loggsfile:
        logger = setup_logger(logger_name, '{}.log'.format(loggsfile))
    else:
        logger = logging.getLogger(logger_name)

    result = {}
    for rank in _approx_ranks(shape[0]):
        rec_errors_all = {run: {'tensor_norm': None,
                                'als_factors': [],
                                'als_factors_quantized': [],
                                'qals_factors': []} for run in range(nruns)}

        for run in range(nruns):
            t = generate_tensor(shape, rank + delta_rank).to(device)
            tnorm = tl.norm(t)
            run_errors = rec_errors_all[run]
            run_errors['tensor_norm'] = tnorm

            logger.info('\n-----------------------------------\n')
            logger.info("Shape = {}, rank_init = {}, rank_approx = {}, run = {}, ||tensor|| = {}".format(shape, rank + delta_rank, rank, run, tnorm))

            for init, random_state in zip(inits, random_states):
                logger.info('\n===================================\nrandom state {}'.format(random_state))
                try:
                    errors = _rec_errors_for_init(t, rank, init, random_state,
                                                  params_als_shared, params_als,
                                                  params_qals, params_qint,
                                                  device, logger)
                except Exception:
                    # Keep the traceback in the log; unlike a bare `except:`,
                    # KeyboardInterrupt/SystemExit still stop the experiment.
                    logger.exception('Bad random state')
                    continue

                rec_error_als, rec_error_qpost, rec_error_qals = errors
                run_errors['als_factors'].append(rec_error_als)
                run_errors['als_factors_quantized'].append(rec_error_qpost)
                run_errors['qals_factors'].append(rec_error_qals)

        result[rank] = rec_errors_all
    return result

### Run comparison 

In [11]:
# Experiment configuration.
N_ITER_MAX = 50000        # passed to parafac / quantized_parafac as n_iter_max
RANDOM_INIT_STARTS = 10   # number of 'random' initialisations tried besides 'svd'
NRUNS = 3                 # independent random tensors per (shape, rank)
# Tensors of shape (n, n, 9) for n = 16, 32, ..., 512.
SHAPES = [(n, n, 9) for n in (16, 32, 64, 128, 256, 512)]

In [13]:
# Keyword arguments shared by parafac and quantized_parafac.
# NOTE(review): tol=1e-18 presumably makes runs stop close to n_iter_max — confirm.
params_als_shared = dict(
    n_iter_max=N_ITER_MAX,
    stop_criterion='rec_error_decrease',
    tol=1e-18,
    normalize_factors=True,
    svd='numpy_svd',
    verbose=0,
)

# Extra keyword arguments for the float parafac only.
params_als = dict(
    orthogonalise=False,
    non_negative=False,
    mask=None,
    return_errors=False,
)

# Quantization settings (quantize_qint and quantized_parafac):
# signed 8-bit, per-channel affine along dim 1.
params_qint = dict(
    dtype=torch.qint8,
    qscheme=torch.per_channel_affine,
    dim=1,
)

# Extra keyword arguments for quantized_parafac: quantize all three modes and
# also return the scales / zero points.
params_qals = dict(
    qmodes=[0, 1, 2],
    return_scale_zeropoint=True,
)

In [None]:
# Seed torch's global RNG so that the random init states drawn below, and the
# synthetic tensors drawn with torch.randn inside generate_tensor, are the
# same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

delta_rank = 0
# Results root; set QALS_RESULTS_ROOT to relocate without editing the notebook.
results_root = os.environ.get('QALS_RESULTS_ROOT', '/workspace/raid/data/jgusak/tensorly')
save_dir = '{}/qals_test_results_drank{}'.format(results_root, delta_rank)
loggs_dir = save_dir + '/loggs'

# Creates save_dir as a parent of loggs_dir; exist_ok keeps the cell re-runnable.
os.makedirs(loggs_dir, exist_ok=True)

random_init_starts = RANDOM_INIT_STARTS
# One 'svd' init (random_state None) followed by `random_init_starts` seeded random inits.
inits = ['svd'] + ['random'] * random_init_starts
random_states = [None] + [int(torch.randint(high=11999, size=(1,))) for _ in range(random_init_starts)]

nruns = NRUNS
shapes = SHAPES

for shape in shapes:
    loggsfile = '{}/{}'.format(loggs_dir, '_'.join(map(str, shape)))

    result = compare_als(shape, inits, random_states,
                         params_als_shared, params_als, params_qals, params_qint,
                         nruns=nruns, loggsfile=loggsfile, delta_rank=delta_rank)
    # `result` is a dict, so np.save pickles it into a 0-d object array;
    # read it back with np.load(path, allow_pickle=True).item().
    np.save('{}/{}.npy'.format(save_dir, shape), result)


### Load saved results