In [1]:
import numpy as np
from functools import partial

import sys
sys.path.append('../../')
import peach as p


import tensorly as tl
from tensorly.decomposition import parafac
from tensorly.kruskal_tensor import kruskal_to_tensor, KruskalTensor
from tensorly.base import unfold

from tensorly.decomposition import quantized_parafac
from tensorly.quantization import quantize_qint

import torch
tl.set_backend('pytorch')

In [2]:
numpy_to_struct = {np.float32 : 'f', np.int8 : 'i'}
pytorch_to_struct = {torch.float32 : 'f', torch.int8 : 'i'}

N_ITER_MAX = 10000
RANDOM_INIT_STARTS = 10


params_als_shared = {'n_iter_max' : N_ITER_MAX,\
                    'stop_criterion' : 'rec_error_decrease',\
                    'tol' : None,\
                    'normalize_factors' : True,\
                    'svd' : 'numpy_svd',\
                    'verbose' : 0}

params_als = {'orthogonalise' : False,\
              'non_negative' : False,\
              'mask' : None,\
              'return_errors' : False}

params_qint = {'dtype' : torch.qint8,\
               'qscheme' : torch.per_channel_affine,\
               'dim' : 1}

params_qals = {'qmodes' : [0, 1, 2],\
               'warmup_iters' : 1,\
               'return_scale_zeropoint' : True,\
               'return_rec_errors' : False,\
               'return_qrec_errors' : False}

In [3]:
def get_factors_init(shape, rank, dtype = 'f'):
    if dtype == 'f':
        factors = [torch.randn((i, rank)).type(torch.float32)  for i in shape] 
    elif dtype == 'i':
        factors = [torch.randint(-256, 256, (i, rank)).type(torch.int8) for i in shape]
        
    return factors

pytorch_to_struct = {torch.float32 : 'f',\
                     torch.float64 : 'd',\
                     torch.float16 : 'e',\
                     torch.int8 : 'b',\
                     torch.int16 : 'h',\
                     torch.int32 : 'i',\
                     torch.int64 : 'q',\
                     torch.uint8 : 'B',\
                     torch.qint8 : 'b',\
                     torch.quint8 : 'B',\
                     torch.qint32 : 'i'}

## Generate tensor in Kruskal format

In [4]:
rank = 9
shape = (16, 16, 9)

dtype =  torch.float32
struct_dtype = pytorch_to_struct[dtype]

factors = get_factors_init(shape, rank, dtype = struct_dtype)
weights = torch.ones(rank).type(dtype)

krt = KruskalTensor((weights, factors))
t = kruskal_to_tensor(krt)

tnorm = tl.norm(t.type(torch.float32))
print('||factors||: {}, factors type: {}'.format(tnorm, t.dtype))                                                                       


||factors||: 132.60455322265625, factors type: torch.float32


In [5]:
random_init_starts = RANDOM_INIT_STARTS
inits = ['svd'] + ['random'] * random_init_starts
random_states = [None] + [int(torch.randint(high = 11999, size = (1,))) for _ in range(random_init_starts)]


## ALS

In [6]:
rank = 8

run_id = 0
init = inits[run_id]
random_state = random_states[run_id]

out_als = parafac(t, rank,\
              init=init,\
              random_state=random_state,\
              **params_als_shared,\
              **params_als)

In [7]:
t_als = kruskal_to_tensor(out_als) 
rec_error_als = tl.norm(t - t_als)

print('||tensor - als_factors||: {}'.format(rec_error_als))

||tensor - als_factors||: 21.873809814453125


## ALS + post quantization

In [8]:
qfactors_post = []
scales_post = []
zero_points_post = []

for i in range(len(out_als.factors)):
    q, s, z = quantize_qint(out_als.factors[i], **params_qint, return_scale_zeropoint = True)
    qfactors_post.append(q)
    scales_post.append(s)
    zero_points_post.append(z)

qt_post = kruskal_to_tensor(KruskalTensor((out_als.weights, qfactors_post)))
rec_error_qpost = tl.norm(t - qt_post)

print('||tensor - als_factors_quantized||: {}'.format(rec_error_qpost))

||tensor - als_factors_quantized||: 21.992183685302734


In [9]:
qfactors_post[0]/scales_post[0],  zero_points_post

(tensor([[  40.0000,  -82.0000,   -8.0000,  -43.0000,   55.0000,   -5.0000,
           -95.0000,  -37.0000],
         [ -57.0000,  -66.0000,  -65.0000,   55.0000,  -50.0000,   -3.0000,
           -34.0000,   25.0000],
         [ -23.0000,  -31.0000,  -97.0000,  -29.0000,  -16.0000,  -20.0000,
            30.0000,   14.0000],
         [  19.0000,  -12.0000,   13.0000,  -64.0000,  -63.0000,   42.0000,
           -58.0000,   27.0000],
         [   3.0000,   31.0000,   42.0000,  -22.0000,   42.0000,   45.0000,
             3.0000,  -36.0000],
         [ -42.0000,  -12.0000,   58.0000,    4.0000,  -54.0000,   26.0000,
            47.0000,  -37.0000],
         [ -10.0000,  -24.0000,   36.0000,   61.0000,   10.0000,  -17.0000,
            45.0000,  -31.0000],
         [  87.0000,  -27.0000,  -53.0000,  -89.0000,   43.0000,   79.0000,
           -86.0000,   12.0000],
         [ -85.0000,   62.0000,  -10.0000,  -50.0000,  -22.0000,   70.0000,
           -68.0000,  -26.0000],
         [  22.0000

## QALS

In [10]:
(fout, qout), (errors, qerrors), scales, zero_points = quantized_parafac(t, rank,\
                                                                         init=init,\
                                                                         random_state=random_state,\
                                                                         **params_als_shared,\
                                                                         **params_qals,\
                                                                         **params_qint)

t_approx_qals = kruskal_to_tensor(KruskalTensor(qout))
rec_error_qals = tl.norm(t - t_approx_qals)

print('||tensor - qals_factors||: {}'.format(rec_error_qals))

||tensor - qals_factors||: 21.998783111572266


In [11]:
qout.factors[0]/scales[0], qout.weights

(tensor([[ 40.0000, -81.0000,   7.0000,  80.0000, -45.0000,   5.0000, -43.0000,
          -37.0000],
         [-57.0000, -66.0000,  58.0000,  29.0000,  41.0000,   4.0000,  55.0000,
           25.0000],
         [-23.0000, -31.0000,  87.0000, -26.0000,  13.0000,  22.0000, -29.0000,
           14.0000],
         [ 20.0000, -12.0000, -12.0000,  50.0000,  52.0000, -46.0000, -64.0000,
           27.0000],
         [  3.0000,  31.0000, -38.0000,  -2.0000, -34.0000, -49.0000, -22.0000,
          -36.0000],
         [-42.0000, -12.0000, -52.0000, -40.0000,  45.0000, -29.0000,   4.0000,
          -37.0000],
         [-10.0000, -24.0000, -33.0000, -39.0000,  -8.0000,  19.0000,  61.0000,
          -31.0000],
         [ 87.0000, -27.0000,  48.0000,  73.0000, -35.0000, -86.0000, -90.0000,
           11.0000],
         [-85.0000,  62.0000,   9.0000,  58.0000,  18.0000, -76.0000, -49.0000,
          -25.0000],
         [ 22.0000, -16.0000,   6.0000, -42.0000, -44.0000, -25.0000, -10.0000,
          -

## For simulated annealing

In [12]:
def restore_factors(factors_flatten, shapes):
    factors_restored = []
    l = 0
    for shape in shapes:
        dl = np.prod(shape)
        factors_restored.append(torch.tensor(factors_flatten)[l : l + dl].reshape(shape))
        l += dl
    return factors_restored

def norm_error(factors_flatten, shapes, t, weights = None,\
               scales = None, zero_points = zero_points):
    
    factors_restored = restore_factors(factors_flatten, shapes)
    if scales is not None:
        factors_restored = [(factors_restored[i] - zero_points[0])*scales[i] for i in range(len(scales))]
        
    tensor_restored = kruskal_to_tensor(KruskalTensor((weights, factors_restored)))

    tensor_restored.shape

    norm_error = tl.norm(tensor_restored - t)
    
    print('Norm error: {}'.format(norm_error))
    
    return float(norm_error)

## ALS + post quantization + simulated annealing

In [13]:
qfactors_post[0]/scales_post[0], scales_post, zero_points, out_als.weights

(tensor([[  40.0000,  -82.0000,   -8.0000,  -43.0000,   55.0000,   -5.0000,
           -95.0000,  -37.0000],
         [ -57.0000,  -66.0000,  -65.0000,   55.0000,  -50.0000,   -3.0000,
           -34.0000,   25.0000],
         [ -23.0000,  -31.0000,  -97.0000,  -29.0000,  -16.0000,  -20.0000,
            30.0000,   14.0000],
         [  19.0000,  -12.0000,   13.0000,  -64.0000,  -63.0000,   42.0000,
           -58.0000,   27.0000],
         [   3.0000,   31.0000,   42.0000,  -22.0000,   42.0000,   45.0000,
             3.0000,  -36.0000],
         [ -42.0000,  -12.0000,   58.0000,    4.0000,  -54.0000,   26.0000,
            47.0000,  -37.0000],
         [ -10.0000,  -24.0000,   36.0000,   61.0000,   10.0000,  -17.0000,
            45.0000,  -31.0000],
         [  87.0000,  -27.0000,  -53.0000,  -89.0000,   43.0000,   79.0000,
           -86.0000,   12.0000],
         [ -85.0000,   62.0000,  -10.0000,  -50.0000,  -22.0000,   70.0000,
           -68.0000,  -26.0000],
         [  22.0000

In [14]:
qfactors_post_int = [qfactors_post[i]/scales_post[i]\
                     for i in range(len(qfactors_post))]

shapes = [factor.shape for factor in qfactors_post_int]
factors_flatten_init =  torch.cat([factor.flatten() for factor in qfactors_post_int])
factors_flatten_init = factors_flatten_init.type(torch.int8)


bsa = None

# numpy_to_struct = {np.float32 : 'f', np.float64 : 'd',  np.float16 : 'e',\
#                    np.int8 : 'b', np.int16 : 'h', np.int32 : 'i', np.int64 : 'q',\
#                    np.uint8 : 'B', np.uint16 : 'H', np.uint32 : 'I', np.uint64 : 'Q'}


mask = np.zeros(8).astype(bool)
mask[-2:] = True

neighbor_params = {'nb':2, 'mask': mask}
neighbor = p.sa.MaskedBitsNeighbor(**neighbor_params)

bsa_als = p.BinarySA(partial(norm_error,\
                         shapes = shapes, t = t, weights = out_als.weights,\
                         scales = scales_post, zero_points = zero_points_post),\
                 np.array(factors_flatten_init).astype(np.int8),\
#                  np.array(flatten_factors_opt),\
                 [(-2**7, 2**7)]*len(factors_flatten_init),\
                 'i'*len(factors_flatten_init),\
                  emax = 1e-10,\
                  imax = 10000,\
                  T0 = 1,\
                  rt = 0.8, neighbor = neighbor)

Norm error: 22.53983497619629


In [16]:
flatten_factors_opt_als, error_als =  bsa_als()

Norm error: 21.998380661010742
Norm error: 28.177337646484375
Norm error: 21.998380661010742
Norm error: 33.068416595458984
Norm error: 21.998380661010742
Norm error: 31.785926818847656
Norm error: 21.998380661010742
Norm error: 59.587337493896484
Norm error: 21.998380661010742
Norm error: 42.09645080566406
Norm error: 21.998380661010742
Norm error: 55.714454650878906
Norm error: 21.998380661010742
Norm error: 22.018604278564453
Norm error: 21.998380661010742
Norm error: 25.422475814819336
Norm error: 21.998380661010742
Norm error: 22.011640548706055
Norm error: 21.998380661010742
Norm error: 38.034881591796875
Norm error: 21.998380661010742
Norm error: 58.88566589355469
Norm error: 21.998380661010742
Norm error: 25.42574691772461
Norm error: 21.998380661010742
Norm error: 45.86314010620117
Norm error: 21.998380661010742
Norm error: 24.92977523803711
Norm error: 21.998380661010742
Norm error: 45.96036911010742
Norm error: 21.998380661010742
Norm error: 44.70004653930664
Norm error: 21.

## QALS + simulated annealing

In [17]:
qout.factors[0]/scales[0], scales, zero_points, qout.weights

(tensor([[ 40.0000, -81.0000,   7.0000,  80.0000, -45.0000,   5.0000, -43.0000,
          -37.0000],
         [-57.0000, -66.0000,  58.0000,  29.0000,  41.0000,   4.0000,  55.0000,
           25.0000],
         [-23.0000, -31.0000,  87.0000, -26.0000,  13.0000,  22.0000, -29.0000,
           14.0000],
         [ 20.0000, -12.0000, -12.0000,  50.0000,  52.0000, -46.0000, -64.0000,
           27.0000],
         [  3.0000,  31.0000, -38.0000,  -2.0000, -34.0000, -49.0000, -22.0000,
          -36.0000],
         [-42.0000, -12.0000, -52.0000, -40.0000,  45.0000, -29.0000,   4.0000,
          -37.0000],
         [-10.0000, -24.0000, -33.0000, -39.0000,  -8.0000,  19.0000,  61.0000,
          -31.0000],
         [ 87.0000, -27.0000,  48.0000,  73.0000, -35.0000, -86.0000, -90.0000,
           11.0000],
         [-85.0000,  62.0000,   9.0000,  58.0000,  18.0000, -76.0000, -49.0000,
          -25.0000],
         [ 22.0000, -16.0000,   6.0000, -42.0000, -44.0000, -25.0000, -10.0000,
          -

In [18]:
qout_factors_int = [qout.factors[i]/scales[i] for i in range(len(qout.factors))]

shapes = [factor.shape for factor in qout_factors_int]
factors_flatten_init =  torch.cat([factor.flatten() for factor in qout_factors_int])
factors_flatten_init = factors_flatten_init.type(torch.int8)


bsa = None

# numpy_to_struct = {np.float32 : 'f', np.float64 : 'd',  np.float16 : 'e',\
#                    np.int8 : 'b', np.int16 : 'h', np.int32 : 'i', np.int64 : 'q',\
#                    np.uint8 : 'B', np.uint16 : 'H', np.uint32 : 'I', np.uint64 : 'Q'}


mask = np.zeros(8).astype(bool)
mask[-2:] = True

neighbor_params = {'nb':2, 'mask': mask}
neighbor = p.sa.MaskedBitsNeighbor(**neighbor_params)

bsa_qals = p.BinarySA(partial(norm_error,\
                         shapes = shapes, t = t, weights = qout.weights,\
                         scales = scales, zero_points = zero_points),\
                 np.array(factors_flatten_init).astype(np.int8),\
#                  np.array(flatten_factors_opt),\
                 [(-2**7, 2**7)]*len(factors_flatten_init),\
                 'i'*len(factors_flatten_init),\
                  emax = 1e-10,\
                  imax = 10000,\
                  T0 = 1,\
                  rt = 0.8, neighbor = neighbor)

Norm error: 22.58852767944336


In [20]:
flatten_factors_opt_qals, error_qals =  bsa_qals()

Norm error: 22.027708053588867
Norm error: 31.105695724487305
Norm error: 22.027708053588867
Norm error: 27.949188232421875
Norm error: 22.027708053588867
Norm error: 60.16566467285156
Norm error: 22.027708053588867
Norm error: 25.516637802124023
Norm error: 22.027708053588867
Norm error: 49.08021545410156
Norm error: 22.027708053588867
Norm error: 31.264850616455078
Norm error: 22.027708053588867
Norm error: 26.79775619506836
Norm error: 22.027708053588867
Norm error: 53.3630485534668
Norm error: 22.027708053588867
Norm error: 22.091899871826172
Norm error: 22.027708053588867
Norm error: 27.758569717407227
Norm error: 22.027708053588867
Norm error: 28.20972442626953
Norm error: 22.027708053588867
Norm error: 39.25651168823242
Norm error: 22.027708053588867
Norm error: 39.05971145629883
Norm error: 22.027708053588867
Norm error: 51.181209564208984
Norm error: 22.027708053588867
Norm error: 25.776613235473633
Norm error: 22.027708053588867
Norm error: 22.031818389892578
Norm error: 22.0

Norm error: 35.830814361572266
Norm error: 22.027708053588867
Norm error: 47.307003021240234
Norm error: 22.027708053588867
Norm error: 24.188682556152344
Norm error: 22.027708053588867
Norm error: 22.57244300842285
Norm error: 22.027708053588867
Norm error: 22.427173614501953
Norm error: 22.027708053588867
Norm error: 36.51369857788086
Norm error: 22.027708053588867
Norm error: 31.51897621154785
Norm error: 22.027708053588867
Norm error: 28.968097686767578
Norm error: 22.027708053588867
Norm error: 43.19017791748047
Norm error: 22.027708053588867
Norm error: 22.398393630981445
Norm error: 22.027708053588867
Norm error: 58.95134735107422
Norm error: 22.027708053588867
Norm error: 26.40958023071289
Norm error: 22.027708053588867
Norm error: 34.41197204589844
Norm error: 22.027708053588867
Norm error: 33.89562225341797
Norm error: 22.027708053588867
Norm error: 54.54009246826172
Norm error: 22.027708053588867
Norm error: 24.408166885375977
Norm error: 22.027708053588867
Norm error: 24.30

Norm error: 48.93480682373047
Norm error: 22.027708053588867
Norm error: 25.39876365661621
Norm error: 22.027708053588867
Norm error: 35.10333251953125
Norm error: 22.027708053588867
Norm error: 24.83453369140625
Norm error: 22.027708053588867
Norm error: 27.223302841186523
Norm error: 22.027708053588867
Norm error: 22.095808029174805
Norm error: 22.027708053588867
Norm error: 65.22598266601562
Norm error: 22.027708053588867
Norm error: 35.28342819213867
Norm error: 22.027708053588867
Norm error: 22.77098846435547
Norm error: 22.027708053588867
Norm error: 40.18872833251953
Norm error: 22.027708053588867
Norm error: 23.968456268310547
Norm error: 22.027708053588867
Norm error: 36.34998321533203
Norm error: 22.027708053588867
Norm error: 32.988487243652344
Norm error: 22.027708053588867
Norm error: 36.64154052734375
Norm error: 22.027708053588867
Norm error: 34.967384338378906
Norm error: 22.027708053588867
Norm error: 31.167381286621094
Norm error: 22.027708053588867
Norm error: 35.375

Norm error: 29.109256744384766
Norm error: 22.027708053588867
Norm error: 43.51509475708008
Norm error: 22.027708053588867
Norm error: 22.271217346191406
Norm error: 22.027708053588867
Norm error: 48.570152282714844
Norm error: 22.027708053588867
Norm error: 34.563446044921875
Norm error: 22.027708053588867
Norm error: 45.36761474609375
Norm error: 22.027708053588867
Norm error: 34.49357223510742
Norm error: 22.027708053588867
Norm error: 23.691497802734375
Norm error: 22.027708053588867
Norm error: 28.940778732299805
Norm error: 22.027708053588867
Norm error: 39.88970184326172
Norm error: 22.027708053588867
Norm error: 51.927467346191406
Norm error: 22.027708053588867
Norm error: 31.295146942138672
Norm error: 22.027708053588867
Norm error: 27.286218643188477
Norm error: 22.027708053588867
Norm error: 24.338239669799805
Norm error: 22.027708053588867
Norm error: 40.52605056762695
Norm error: 22.027708053588867
Norm error: 35.97270202636719
Norm error: 22.027708053588867
Norm error: 34

In [22]:
flatten_factors_opt_qals

(40,
 -83,
 5,
 80,
 -45,
 4,
 -44,
 -38,
 -59,
 -66,
 57,
 28,
 40,
 4,
 56,
 24,
 -24,
 -32,
 85,
 -27,
 13,
 20,
 -31,
 13,
 20,
 -13,
 -12,
 48,
 51,
 -46,
 -66,
 26,
 2,
 30,
 -38,
 -3,
 -35,
 -49,
 -23,
 -36,
 -42,
 -12,
 -52,
 -40,
 44,
 -31,
 4,
 -38,
 -10,
 -24,
 -33,
 -39,
 -8,
 17,
 61,
 -32,
 87,
 -28,
 46,
 72,
 -36,
 -85,
 -91,
 11,
 -88,
 62,
 8,
 57,
 16,
 -76,
 -52,
 -27,
 22,
 -16,
 5,
 -42,
 -43,
 -26,
 -12,
 -68,
 21,
 33,
 10,
 -12,
 -24,
 -86,
 4,
 18,
 24,
 -16,
 -16,
 -11,
 -40,
 5,
 23,
 -68,
 19,
 -11,
 -61,
 -9,
 40,
 61,
 -69,
 -58,
 -63,
 -10,
 -47,
 9,
 8,
 -71,
 -19,
 4,
 11,
 20,
 -16,
 52,
 89,
 -8,
 -84,
 41,
 -7,
 86,
 -52,
 58,
 40,
 8,
 -28,
 85,
 53,
 45,
 35,
 60,
 64,
 -4,
 12,
 -5,
 -53,
 66,
 47,
 9,
 -48,
 13,
 -39,
 12,
 52,
 -5,
 58,
 70,
 25,
 44,
 -29,
 -52,
 0,
 -6,
 48,
 -48,
 11,
 -31,
 36,
 -44,
 24,
 -10,
 -92,
 52,
 80,
 17,
 -16,
 -88,
 -1,
 -35,
 44,
 72,
 -24,
 -33,
 22,
 32,
 -104,
 -56,
 16,
 27,
 -25,
 68,
 -13,
 4,
 8,
 92,
 4