In [1]:
import numpy as np
from functools import partial

import sys
sys.path.append('../../')
import peach as p


import tensorly as tl
from tensorly.decomposition import parafac
from tensorly.kruskal_tensor import kruskal_to_tensor, KruskalTensor
from tensorly.base import unfold

from tensorly.decomposition import quantized_parafac
from tensorly.quantization import quantize_qint

import torch
tl.set_backend('pytorch')

import copy

In [2]:
numpy_to_struct = {np.float32 : 'f', np.int8 : 'i'}
pytorch_to_struct = {torch.float32 : 'f', torch.int8 : 'i'}

N_ITER_MAX = 10000
RANDOM_INIT_STARTS = 10


params_als_shared = {'n_iter_max' : N_ITER_MAX,\
                    'tol' : 1e-18,\
                    'normalize_factors' : True,\
                    'svd' : 'numpy_svd',\
                    'verbose' : 0}

params_als = {'orthogonalise' : False,\
              'non_negative' : False,\
              'mask' : None,\
              'return_errors' : False,\
              'stop_criterion' : 'rec_error_deviation'}

params_qint = {'dtype' : torch.qint8,\
               'qscheme' : torch.per_channel_affine,\
               'dim' : 1}

params_qals = {'qmodes' : [0, 1, 2],\
               'warmup_iters' : 1,\
               'return_scale_zeropoint' : True,\
               'return_rec_errors' : False,\
               'return_qrec_errors' : False,\
               'stop_criterion' : 'both_error_deviation'}

In [3]:
def get_factors_init(shape, rank, dtype = 'f'):
    if dtype == 'f':
        factors = [torch.randn((i, rank)).type(torch.float32)  for i in shape] 
    elif dtype == 'i':
        factors = [torch.randint(-256, 256, (i, rank)).type(torch.int8) for i in shape]
        
    return factors

pytorch_to_struct = {torch.float32 : 'f',\
                     torch.float64 : 'd',\
                     torch.float16 : 'e',\
                     torch.int8 : 'b',\
                     torch.int16 : 'h',\
                     torch.int32 : 'i',\
                     torch.int64 : 'q',\
                     torch.uint8 : 'B',\
                     torch.qint8 : 'b',\
                     torch.quint8 : 'B',\
                     torch.qint32 : 'i'}

## Generate tensor in Kruskal format

In [4]:
rank = 10
shape = (16, 16, 9)

rank_expansion = 1

dtype =  torch.float32
struct_dtype = pytorch_to_struct[dtype]

factors = get_factors_init(shape, rank, dtype = struct_dtype)

for i in range(len(factors)):
    factors[i] = factors[i].repeat(1, rank_expansion)

weights = torch.ones(rank * rank_expansion).type(dtype)
    

krt = KruskalTensor((weights, factors))
t = kruskal_to_tensor(krt)

tnorm = tl.norm(t.type(torch.float32))
print('||factors||: {}, factors type: {}'.format(tnorm, t.dtype))                                                                       


||factors||: 154.82830810546875, factors type: torch.float32


In [5]:
random_init_starts = RANDOM_INIT_STARTS
inits = ['svd'] + ['random'] * random_init_starts
random_states = [None] + [int(torch.randint(high = 11999, size = (1,))) for _ in range(random_init_starts)]


## ALS

In [6]:

rank = rank 

run_id = 1
init = inits[run_id]
random_state = random_states[run_id]

out_als = None
out_als = parafac(t, rank,\
              init=init,\
              random_state=random_state,\
              **params_als_shared,\
              **params_als)

In [7]:
t_als = kruskal_to_tensor(out_als) 
rec_error_als = tl.norm(t - t_als)

print('||tensor - als_factors||: {}'.format(rec_error_als))

||tensor - als_factors||: 0.00046896134153939784


In [8]:
for i in range(3):
    print(tl.norm(out_als.factors[0], 2, axis = 0))
    
out_als.weights

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])


tensor([46.4955, 22.5025, 18.7776, 73.8671, 23.2815, 77.7286, 61.1841, 25.5077,
        46.8260, 48.0008])

## ALS + post quantization

In [9]:
qfactors_post = []
scales_post = []
zero_points_post = []

weights_post = copy.deepcopy(out_als.weights)

for i in range(len(out_als.factors)):
    q, s, z = quantize_qint(out_als.factors[i], **params_qint, return_scale_zeropoint = True)
    qfactors_post.append(q)
    scales_post.append(s)
    zero_points_post.append(z)
    
    weights_post /=  tl.norm(q, order = 2, axis = 0)
    
print(weights_post)


qt_post = kruskal_to_tensor(KruskalTensor((weights_post, qfactors_post)))
rec_error_qpost = tl.norm(t - qt_post)

print('||tensor - als_factors_quantized||: {}'.format(rec_error_qpost))

tensor([46.5547, 22.5584, 18.8045, 74.0972, 23.3725, 78.3342, 61.5168, 25.4934,
        46.8735, 48.2320])
||tensor - als_factors_quantized||: 1.0523691177368164


In [10]:
zero_points_post, scales_post

([tensor([ -5,  -2,  46,  10, -46,  63,   8,  38, -50, -24], dtype=torch.int32),
  tensor([-23, -33,  57,  15,  -1, -28, -32, -26,  -6,  -2], dtype=torch.int32),
  tensor([ -40,   -2,   24,  -37,   10,  -38,  -13,   36, -127,   61],
         dtype=torch.int32)],
 [tensor([0.0032, 0.0030, 0.0031, 0.0035, 0.0031, 0.0037, 0.0035, 0.0029, 0.0036,
          0.0031]),
  tensor([0.0031, 0.0037, 0.0043, 0.0029, 0.0037, 0.0040, 0.0038, 0.0036, 0.0039,
          0.0039]),
  tensor([0.0036, 0.0044, 0.0046, 0.0036, 0.0043, 0.0039, 0.0036, 0.0036, 0.0025,
          0.0038])])

## QALS

In [11]:
(fout, qout), (errors, qerrors), scales, zero_points = quantized_parafac(t, rank,\
                                                                         init=init,\
                                                                         random_state=random_state,\
                                                                         **params_als_shared,\
                                                                         **params_qals,\
                                                                         **params_qint)

t_approx_qals = kruskal_to_tensor(KruskalTensor(qout))
rec_error_qals = tl.norm(t - t_approx_qals)

t_approx_fals = kruskal_to_tensor(KruskalTensor(fout))
rec_error_fals = tl.norm(t - t_approx_fals)

print('||tensor - fals_factors||: {}'.format(rec_error_fals))
print('||tensor - qals_factors||: {}'.format(rec_error_qals))

||tensor - fals_factors||: 21.62333106994629
||tensor - qals_factors||: 21.653894424438477


In [12]:
for i in range(3):
    print(tl.norm(fout.factors[0], 2, axis = 0))
    print(tl.norm(qout.factors[0], 2, axis = 0))
    
    
print(qout.weights)
print(fout.weights)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])
tensor([1.0010, 1.0003, 0.9985, 0.9990, 0.9993, 0.9984, 0.9995, 0.9997, 1.0002,
        1.0005])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])
tensor([1.0010, 1.0003, 0.9985, 0.9990, 0.9993, 0.9984, 0.9995, 0.9997, 1.0002,
        1.0005])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000])
tensor([1.0010, 1.0003, 0.9985, 0.9990, 0.9993, 0.9984, 0.9995, 0.9997, 1.0002,
        1.0005])
tensor([68.9587, 46.9449, 45.4421, 18.2753, 74.2930, 70.0987, 25.3675, 79.4484,
        22.5150, 48.1033])
tensor([68.7769, 46.9411, 45.3775, 18.2146, 74.1132, 70.1068, 25.3691, 79.1590,
        22.4923, 47.9603])


## For simulated annealing

In [30]:
def restore_factors(factors_flatten, shapes):
    factors_restored = []
    l = 0
    for shape in shapes:
        dl = np.prod(shape)
        factors_restored.append(torch.tensor(factors_flatten)[l : l + dl].reshape(shape))
        l += dl
    return factors_restored


def norm_error(factors_flatten, shapes, t, weights = None,\
               scales = None, zero_points = None):
    
    factors_restored = restore_factors(factors_flatten, shapes)
    if scales is not None:
        factors_restored = [(factors_restored[i] - zero_points[i]) * scales[i]\
                            for i in range(len(scales))]

    tensor_restored = kruskal_to_tensor(KruskalTensor((weights, factors_restored)))

    norm_error = tl.norm(tensor_restored - t)
  
    global i
    global gnorm
    i += 1
    gnorm = float(norm_error)
    
    return float(norm_error)

## ALS + post quantization + simulated annealing

In [31]:
qmin=-128
qmax=127
qfactors_post[0]/scales_post[0] + zero_points_post[0]

tensor([[   8.0000,  -74.0000,   48.0000,  -31.0000,  -55.0000,   80.0000,
         -128.0000,   86.0000,  -39.0000,  -51.0000],
        [-115.0000,  -96.0000,   -9.0000,   34.0000,  113.0000,   85.0000,
          -35.0000,  -99.0000,  -15.0000,   -9.0000],
        [ -15.0000,   63.0000,  -55.0000,  101.0000, -125.0000,   92.0000,
            1.0000,  -99.0000,  -30.0000,  -55.0000],
        [  50.0000, -125.0000,   75.0000,  -67.0000,  -10.0000,   44.0000,
          -93.0000,   55.0000,  126.0000,   -1.0000],
        [  52.0000,   13.0000,   23.0000,   -8.0000,   42.0000, -128.0000,
          -61.0000,  109.0000,  -17.0000,  119.0000],
        [  92.0000,  123.0000, -111.0000,  -88.0000,  -52.0000,   21.0000,
           28.0000,   77.0000,    2.0000,   75.0000],
        [ -38.0000,   73.0000,  127.0000,  126.0000,  -70.0000,  -84.0000,
           60.0000,   25.0000,  -41.0000,  -95.0000],
        [ -14.0000,  -82.0000,   83.0000, -128.0000, -128.0000,  126.0000,
           43.0000,   

In [32]:
# qmin = -128
# qmax = 127
qfactors_post_int = [qfactors_post[i]/scales_post[i] + zero_points_post[i] \
                     for i in range(len(qfactors_post))]

shapes = [factor.shape for factor in qfactors_post_int]
factors_flatten_init =  torch.cat([factor.flatten() for factor in qfactors_post_int])
factors_flatten_init = factors_flatten_init.type(torch.int8)


bsa_als = None

# numpy_to_struct = {np.float32 : 'f', np.float64 : 'd',  np.float16 : 'e',\
#                    np.int8 : 'b', np.int16 : 'h', np.int32 : 'i', np.int64 : 'q',\
#                    np.uint8 : 'B', np.uint16 : 'H', np.uint32 : 'I', np.uint64 : 'Q'}


mask = np.zeros(8).astype(bool)
mask[:] = True

neighbor_params = {'nb':1, 'mask': mask}
neighbor = p.sa.MaskedBitsNeighbor(**neighbor_params)

i = 0
gnorm = 0
bsa_als = p.BinarySA(partial(norm_error,
                             shapes = shapes, t = t, weights = weights_post,\
                             scales = scales_post, zero_points = zero_points_post),\
                     np.array(factors_flatten_init).astype(np.int8),\
                     [(-2**7, 2**7)]*len(factors_flatten_init),\
                     'i'*len(factors_flatten_init),\
                     emax = 1e-10,\
                     imax = 10000,\
                     T0 = 1,\
                     rt = 0.8, neighbor = neighbor)

In [None]:
# make sense to run several times
for r in range(10000):
    flatten_factors_opt_als, error_als =  bsa_als()
    if r % 100 == 0:
        print('round {}, i = {}, norm_error = {}'.format(r, i, gnorm))
    i = 0
    

round 0, i = 273, norm_error = 1.1724416017532349
round 100, i = 1190, norm_error = 1.0019264221191406
round 200, i = 1896, norm_error = 0.9990823864936829
round 300, i = 358, norm_error = 0.9930474162101746
round 400, i = 848, norm_error = 0.9800000786781311
round 500, i = 388, norm_error = 0.9783644080162048


In [65]:
flatten_factors_opt_als, error_als =  bsa_als()

Norm error: 0.6827762126922607, type : torch.float32
Norm error: 49.30796813964844, type : torch.float32
Norm error: 0.6827762126922607, type : torch.float32
Norm error: 56.129146575927734, type : torch.float32
Norm error: 0.6827762126922607, type : torch.float32
Norm error: 128.5807342529297, type : torch.float32
Norm error: 0.6827762126922607, type : torch.float32
Norm error: 19.572996139526367, type : torch.float32
Norm error: 0.6827762126922607, type : torch.float32
Norm error: 86.55799865722656, type : torch.float32
Norm error: 0.6827762126922607, type : torch.float32
Norm error: 88.92695617675781, type : torch.float32
Norm error: 0.6827762126922607, type : torch.float32
Norm error: 118.18357849121094, type : torch.float32
Norm error: 0.6827762126922607, type : torch.float32
Norm error: 19.79806900024414, type : torch.float32
Norm error: 0.6827762126922607, type : torch.float32
Norm error: 28.945478439331055, type : torch.float32
Norm error: 0.6827762126922607, type : torch.float3

## QALS + simulated annealing

In [57]:
qout.factors[0]/scales[0], scales, zero_points, qout.weights

(tensor([[  88.0000, -127.0000],
         [  30.0000,   49.0000],
         [  32.0000,  -75.0000],
         [  -4.0000, -127.0000],
         [ 109.0000,   10.0000]]),
 (tensor([0.0068, 0.0048]),
  tensor([0.0060, 0.0046]),
  tensor([0.0048, 0.0069])),
 (tensor([-1, -1], dtype=torch.int32),
  tensor([-1, -1], dtype=torch.int32),
  tensor([-1, -1], dtype=torch.int32)),
 tensor([113.6810,  46.5381]))

In [67]:
qout_factors_int = [qout.factors[i]/scales[i]+zero_points[i] for i in range(len(qout.factors))]

shapes = [factor.shape for factor in qout_factors_int]
factors_flatten_init =  torch.cat([factor.flatten() for factor in qout_factors_int])
factors_flatten_init = factors_flatten_init.type(torch.int8)


bsa_qals = None

# numpy_to_struct = {np.float32 : 'f', np.float64 : 'd',  np.float16 : 'e',\
#                    np.int8 : 'b', np.int16 : 'h', np.int32 : 'i', np.int64 : 'q',\
#                    np.uint8 : 'B', np.uint16 : 'H', np.uint32 : 'I', np.uint64 : 'Q'}


mask = np.zeros(8).astype(bool)
mask[-2:] = True

neighbor_params = {'nb':2, 'mask': mask}
neighbor = p.sa.MaskedBitsNeighbor(**neighbor_params)

bsa_qals = p.BinarySA(partial(norm_error,\
                         shapes = shapes, t = t, weights = qout.weights,\
                         scales = scales, zero_points = zero_points),\
                 np.array(factors_flatten_init).astype(np.int8),\
#                  np.array(flatten_factors_opt),\
                 [(-2**7, 2**7)]*len(factors_flatten_init),\
                 'i'*len(factors_flatten_init),\
                  emax = 1e-10,\
                  imax = 10000,\
                  T0 = 1,\
                  rt = 0.8, neighbor = neighbor)

Norm error: 0.8272514343261719, type : torch.float32


In [None]:
# make sense to run several times
for j in range(100):
    print('=========================={}=========================='.format(j))
    flatten_factors_opt_qals, error_qals =  bsa_qals()

In [73]:
flatten_factors_opt_qals, error_qals =  bsa_qals()

Norm error: 0.5835343599319458, type : torch.float32
Norm error: 0.9785531759262085, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 26.667634963989258, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 83.25546264648438, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 23.067344665527344, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 140.20301818847656, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 113.11274719238281, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 93.16915130615234, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 92.21993255615234, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 1.027391791343689, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float

Norm error: 129.10337829589844, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 115.65867614746094, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 69.82952880859375, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 1.0807952880859375, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 88.44097137451172, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 54.795433044433594, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 72.8763656616211, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 86.60588073730469, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 70.5448226928711, type : torch.float32
Norm error: 0.5835343599319458, type : torch.float32
Norm error: 76.43610382080078, type : torch.float32
N

In [21]:
flatten_factors_opt_qals

(-20,
 59,
 -46,
 71,
 -44,
 -46,
 76,
 -38,
 14,
 48,
 -11,
 -36,
 8,
 55,
 54,
 -60,
 -14,
 4,
 -10,
 -6,
 -20,
 18,
 -16,
 -31,
 -3,
 -54,
 24,
 -4,
 46,
 14,
 22,
 -15,
 -66,
 -8,
 34,
 20,
 -3,
 -32,
 36,
 29,
 76,
 40,
 -60,
 -44,
 9,
 38,
 -83,
 -7,
 -6,
 4,
 8,
 -49,
 -53,
 57,
 8,
 -32,
 -56,
 33,
 -18,
 -4,
 92,
 0,
 -44,
 0,
 68,
 -3,
 24,
 53,
 63,
 -48,
 -80,
 -4,
 20,
 0,
 -23,
 -8,
 27,
 2,
 -29,
 4,
 4,
 87,
 -26,
 0,
 53,
 -28,
 -28,
 56,
 -36,
 12,
 -44,
 -44,
 -15,
 34,
 -26,
 3,
 -34,
 55,
 90,
 7,
 -35,
 12,
 -17,
 -41,
 60,
 65,
 77,
 -69,
 4,
 12,
 -12,
 101,
 -52,
 -42,
 -25,
 -117,
 -61,
 97,
 47,
 9,
 -66,
 -3,
 -16,
 -24,
 35,
 23,
 -23,
 -6,
 -3,
 2,
 -12,
 -52,
 52,
 -16,
 32,
 56,
 -28,
 -33,
 25,
 -36,
 31,
 11,
 -6,
 55,
 50,
 51,
 36,
 -37,
 15,
 -77,
 -8,
 -18,
 -41,
 -112,
 1,
 14,
 -53,
 -38,
 -71,
 -55,
 -35,
 -64,
 -40,
 -50,
 -19,
 -44,
 47,
 28,
 -12,
 36,
 -71,
 28,
 103,
 -15,
 8,
 -41,
 33,
 1,
 26,
 -64,
 -56,
 21,
 58,
 91,
 -15,
 16,
 52,
 

In [3]:
(2**7-1 + 2**7 ) == 2**8 - 1

True

### 