In [1]:
import numpy as np
from functools import partial

import sys
sys.path.append('../../')
import peach as p


import tensorly as tl
from tensorly.decomposition import parafac
from tensorly.kruskal_tensor import kruskal_to_tensor, KruskalTensor
from tensorly.base import unfold

from tensorly.decomposition import quantized_parafac
from tensorly.quantization import quantize_qint

import torch
tl.set_backend('pytorch')

In [117]:
numpy_to_struct = {np.float32 : 'f', np.int8 : 'i'}
pytorch_to_struct = {torch.float32 : 'f', torch.int8 : 'i'}

N_ITER_MAX = 10000
RANDOM_INIT_STARTS = 10


params_als_shared = {'n_iter_max' : N_ITER_MAX,\
                    'stop_criterion' : 'rec_error_decrease',\
                    'tol' : None,\
                    'normalize_factors' : True,\
                    'svd' : 'numpy_svd',\
                    'verbose' : 0}

params_als = {'orthogonalise' : False,\
              'non_negative' : False,\
              'mask' : None,\
              'return_errors' : False}

params_qint = {'dtype' : torch.qint8,\
               'qscheme' : torch.per_tensor_affine,\
               'dim' : None,\
               'return_scale_zeropoint' : True}

params_qals = {'qmodes' : [0, 1, 2],\
               'warmup_iters' : 1,\
               'return_scale_zeropoint' : True,\
               'return_rec_errors' : False,\
               'return_qrec_errors' : False}

## Generate tensor in Kruskal format

In [17]:
rank = 2
shape = (5, 5, 3)

dtype =  torch.float32
struct_dtype = pytorch_to_struct[dtype]

factors = get_factors_init(shape, rank, dtype = struct_dtype)
weights = torch.ones(rank).type(dtype)

krt = KruskalTensor((weights, factors))
t = kruskal_to_tensor(krt)

tnorm = tl.norm(t.type(torch.float32))
print('||factors||: {}, factors type: {}'.format(tnorm, t.dtype))                                                                       


||factors||: 10141.3798828125, factors type: torch.float32


In [18]:
random_init_starts = RANDOM_INIT_STARTS
inits = ['svd'] + ['random'] * random_init_starts
random_states = [None] + [int(torch.randint(high = 11999, size = (1,))) for _ in range(random_init_starts)]


## ALS

In [19]:
rank = 2

run_id = 0
init = inits[run_id]
random_state = random_states[run_id]

out_als = parafac(t, rank,\
              init=init,\
              random_state=random_state,\
              **params_als_shared,\
              **params_als)

In [20]:
t_als = kruskal_to_tensor(out_als) 
rec_error_als = tl.norm(t - t_als)

print('||tensor - als_factors||: {}'.format(rec_error_als))

||tensor - als_factors||: 0.0013729545753449202


## ALS + post quantization

In [127]:
qfactors_post = []
scales_post = []
zero_points_post = []

for i in range(len(out_als.factors)):
    q, s, z = quantize_qint(out_als.factors[i], **params_qint)
    qfactors_post.append(q)
    scales_post.append(s)
    zero_points_post.append(z)

qt_post = kruskal_to_tensor(KruskalTensor((out_als.weights, qfactors_post)))
rec_error_qpost = tl.norm(t - qt_post)

print('||tensor - als_factors_quantized||: {}'.format(rec_error_qpost))

||tensor - als_factors_quantized||: 1934.88818359375


In [128]:
qfactors_post[0]/scales_post[0]

tensor([[  40., -128.],
        [-128.,   90.],
        [ -15.,  100.],
        [  51.,  -48.],
        [-128.,   54.]])

## QALS

In [25]:
(fout, qout), (errors, qerrors), scales, zero_points = quantized_parafac(t, rank,\
                                                                         init=init,\
                                                                         random_state=random_state,\
                                                                         **params_als_shared,\
                                                                         **params_qals,\
                                                                         **params_qint)

t_approx_qals = kruskal_to_tensor(KruskalTensor(qout))
rec_error_qals = tl.norm(t - t_approx_qals)

print('||tensor - qals_factors||: {}'.format(rec_error_qals))

||tensor - qals_factors||: 1631.13037109375


In [29]:
qout.factors[0]/scales[0], qout.weights

(tensor([[ -33.0000, -127.0000],
         [ 123.0000,   84.0000],
         [  12.0000,   81.0000],
         [ -43.0000,  -43.0000],
         [ 127.0000,   55.0000]]), tensor([7097.9902, 6765.4238]))

## For simulated annealing

In [78]:
def restore_factors(factors_flatten, shapes):
    factors_restored = []
    l = 0
    for shape in shapes:
        dl = np.prod(shape)
        factors_restored.append(torch.tensor(factors_flatten)[l : l + dl].reshape(shape))
        l += dl
    return factors_restored

def norm_error(factors_flatten, shapes, t, weights = None,\
               scales = None, zero_points = zero_points):
    
    factors_restored = restore_factors(factors_flatten, shapes)
    if scales is not None:
        factors_restored = [factors_restored[i]*scales[i] + zero_points[0] for i in range(len(scales))]
        
    tensor_restored = kruskal_to_tensor(KruskalTensor((weights, factors_restored)))

    tensor_restored.shape

    norm_error = tl.norm(tensor_restored - t)
    
    print('Norm error: {}'.format(norm_error))
    
    return float(norm_error)

## ALS + post quantization + simulated annealing

In [131]:
qfactors_post[0]/scales_post[0], scales_post, zero_points, out_als.weights

(tensor([[  40., -128.],
         [-128.,   90.],
         [ -15.,  100.],
         [  51.,  -48.],
         [-128.,   54.]]),
 [tensor(0.0046), tensor(0.0059), tensor(0.0070)],
 (tensor(0, dtype=torch.int32),
  tensor(0, dtype=torch.int32),
  tensor(0, dtype=torch.int32)),
 tensor([7182.6895, 6582.5713]))

In [146]:
qfactors_post_int = [qfactors_post[i]/scales_post[i]\
                     for i in range(len(qfactors_post))]

shapes = [factor.shape for factor in qfactors_post_int]
factors_flatten_init =  torch.cat([factor.flatten() for factor in qfactors_post_int])
factors_flatten_init = factors_flatten_init.type(torch.int8)

In [147]:
bsa = None

# numpy_to_struct = {np.float32 : 'f', np.float64 : 'd',  np.float16 : 'e',\
#                    np.int8 : 'b', np.int16 : 'h', np.int32 : 'i', np.int64 : 'q',\
#                    np.uint8 : 'B', np.uint16 : 'H', np.uint32 : 'I', np.uint64 : 'Q'}


mask = np.zeros(8).astype(bool)
mask[-1] = True

neighbor_params = {'nb':2, 'mask': mask}
neighbor = p.sa.MaskedInvertBitsNeighbor(**neighbor_params)

bsa = p.BinarySA(partial(norm_error,\
                         shapes = shapes, t = t, weights = out_als.weights,\
                         scales = scales_post, zero_points = zero_points_post),\
                 np.array(factors_flatten_init).astype(np.int8),\
#                  np.array(flatten_factors_opt),\
                 [(-2**32, 2**32)]*len(factors_flatten_init),\
                 'f'*len(factors_flatten_init),\
                  emax = 1e-12,\
                  imax = 50000,\
                  T0 = 1000,\
                  rt = 0.999, neighbor = neighbor)

Norm error: 1934.88818359375


In [148]:
flatten_factors_opt, error =  bsa()

Norm error: 1934.88818359375
Norm error: 10966.8486328125
Norm error: 1934.88818359375
Norm error: 1934.877197265625
Norm error: 1934.877197265625
Norm error: 1934.876708984375
Norm error: 1934.876708984375
Norm error: 5176.64990234375
Norm error: 1934.876708984375
Norm error: 6046.2802734375
Norm error: 1934.876708984375
Norm error: 1934.8243408203125
Norm error: 1934.8243408203125
Norm error: 1934.810302734375
Norm error: 1934.810302734375
Norm error: 5025.10595703125
Norm error: 1934.810302734375
Norm error: 12302.564453125
Norm error: 1934.810302734375
Norm error: 1931.247802734375
Norm error: 1931.247802734375
Norm error: 4149.4375
Norm error: 1931.247802734375
Norm error: 1931.0098876953125
Norm error: 1931.0098876953125
Norm error: 4138.51025390625
Norm error: 1931.0098876953125
Norm error: 6141.02294921875
Norm error: 1931.0098876953125
Norm error: 1931.0045166015625
Norm error: 1931.0045166015625
Norm error: 10974.037109375
Norm error: 1931.0045166015625
Norm error: 1931.00146

## QALS + simulated annealing

In [132]:
qout.factors[0]/scales[0], scales, zero_points, qout.weights

(tensor([[ -33.0000, -127.0000],
         [ 123.0000,   84.0000],
         [  12.0000,   81.0000],
         [ -43.0000,  -43.0000],
         [ 127.0000,   55.0000]]),
 (tensor(0.0054), tensor(0.0050), tensor(0.0071)),
 (tensor(0, dtype=torch.int32),
  tensor(0, dtype=torch.int32),
  tensor(0, dtype=torch.int32)),
 tensor([7097.9902, 6765.4238]))

In [141]:
qout_factors_int = [qout.factors[i]/scales[i] for i in range(len(qout.factors))]

shapes = [factor.shape for factor in qout_factors_int]
factors_flatten_init =  torch.cat([factor.flatten() for factor in qout_factors_int])
factors_flatten_init = factors_flatten_init.type(torch.int8)

In [144]:
bsa = None

# numpy_to_struct = {np.float32 : 'f', np.float64 : 'd',  np.float16 : 'e',\
#                    np.int8 : 'b', np.int16 : 'h', np.int32 : 'i', np.int64 : 'q',\
#                    np.uint8 : 'B', np.uint16 : 'H', np.uint32 : 'I', np.uint64 : 'Q'}


mask = np.zeros(8).astype(bool)
mask[-1] = True

neighbor_params = {'nb':2, 'mask': mask}
neighbor = p.sa.MaskedInvertBitsNeighbor(**neighbor_params)

bsa = p.BinarySA(partial(norm_error,\
                         shapes = shapes, t = t, weights = qout.weights,\
                         scales = scales, zero_points = zero_points),\
                 np.array(factors_flatten_init).astype(np.int8),\
#                  np.array(flatten_factors_opt),\
                 [(-2**32, 2**32)]*len(factors_flatten_init),\
                 'f'*len(factors_flatten_init),\
                  emax = 1e-12,\
                  imax = 50000,\
                  T0 = 1000,\
                  rt = 0.999, neighbor = neighbor)

Norm error: 1634.4222412109375


In [145]:
flatten_factors_opt, error =  bsa()

Norm error: 1634.4222412109375
Norm error: 1627.099365234375
Norm error: 1627.099365234375
Norm error: 1627.099365234375
