In [1]:
import numpy as np
from functools import partial

import sys
sys.path.append('../../')
import peach as p


import tensorly as tl
from tensorly.decomposition import parafac
from tensorly.kruskal_tensor import kruskal_to_tensor, KruskalTensor
from tensorly.base import unfold

from tensorly.decomposition import quantized_parafac
from tensorly.quantization import quantize_qint

import torch
# Select tensorly's PyTorch backend; the cells below pass torch tensors to tl.* / tensorly functions.
tl.set_backend('pytorch')

In [2]:
# Single-character type codes keyed by dtype.
# NOTE: `pytorch_to_struct` is assigned again, with a larger table, in the cell
# that defines `get_factors_init`; that later assignment is the one every
# subsequent lookup sees.
numpy_to_struct = {np.float32: 'f', np.int8: 'i'}
pytorch_to_struct = {torch.float32: 'f', torch.int8: 'i'}

# Global experiment knobs.
N_ITER_MAX = 10000
RANDOM_INIT_STARTS = 10

# Keyword arguments forwarded to both `parafac` and `quantized_parafac`.
params_als_shared = dict(
    n_iter_max=N_ITER_MAX,
    stop_criterion='rec_error_decrease',
    tol=None,
    normalize_factors=True,
    svd='numpy_svd',
    verbose=0,
)

# Keyword arguments forwarded only to plain `parafac`.
params_als = dict(
    orthogonalise=False,
    non_negative=False,
    mask=None,
    return_errors=False,
)

# Keyword arguments describing the quantization; forwarded to `quantize_qint`
# and to `quantized_parafac`.
params_qint = dict(
    dtype=torch.qint8,
    qscheme=torch.per_tensor_affine,
    dim=None,
)

# Keyword arguments forwarded only to `quantized_parafac`.
params_qals = dict(
    qmodes=[0, 1, 2],
    warmup_iters=1,
    return_scale_zeropoint=True,
    return_rec_errors=False,
    return_qrec_errors=False,
)

In [3]:
def get_factors_init(shape, rank, dtype = 'f'):
    """Draw one random factor matrix per tensor mode.

    Parameters
    ----------
    shape : iterable of int
        Mode sizes; one factor of size ``(mode_size, rank)`` is created per entry.
    rank : int
        Number of columns of every factor.
    dtype : str, default 'f'
        Type code selecting the sampling scheme:
        ``'f'`` - standard-normal samples scaled by 10, stored as ``torch.float32``;
        ``'i'`` - ``torch.randint(-256, 256)`` samples cast to ``torch.int8``.

    Returns
    -------
    list of torch.Tensor
        The factors, in the order of ``shape``.

    Raises
    ------
    ValueError
        If ``dtype`` is neither ``'f'`` nor ``'i'`` (previously this fell
        through and failed with an ``UnboundLocalError`` on ``return``).
    """
    if dtype == 'f':
        factors = [torch.randn((i, rank)).type(torch.float32)*10 for i in shape]
    elif dtype == 'i':
        # NOTE(review): the sampled range [-256, 256) is wider than int8, so
        # out-of-range draws are altered by the cast - confirm this is intended.
        factors = [torch.randint(-256, 256, (i, rank)).type(torch.int8) for i in shape]
    else:
        raise ValueError(
            "Unsupported dtype code {!r}; expected 'f' or 'i'".format(dtype))

    return factors

# Single-character type code for every torch dtype used in this notebook.
# Quantized dtypes share the code of the plain integer dtype listed next to them.
pytorch_to_struct = dict([
    # floating point
    (torch.float32, 'f'),
    (torch.float64, 'd'),
    (torch.float16, 'e'),
    # signed integers
    (torch.int8, 'b'),
    (torch.int16, 'h'),
    (torch.int32, 'i'),
    (torch.int64, 'q'),
    # unsigned integers
    (torch.uint8, 'B'),
    # quantized integers
    (torch.qint8, 'b'),
    (torch.quint8, 'B'),
    (torch.qint32, 'i'),
])

## Generate tensor in Kruskal format

In [4]:
# Size of the synthetic problem: CP rank and the three mode sizes.
rank = 4
shape = (16, 16, 9)

# Element type of the factors and its single-character code.
dtype = torch.float32
struct_dtype = pytorch_to_struct[dtype]

# Random factors with unit weights, wrapped as a Kruskal tensor and expanded
# into the full tensor `t` that every later cell approximates.
weights = torch.ones(rank).type(dtype)
factors = get_factors_init(shape, rank, dtype=struct_dtype)
krt = KruskalTensor((weights, factors))
t = kruskal_to_tensor(krt)

# Norm of the full tensor, computed on a float32 view of it.
tnorm = tl.norm(t.type(torch.float32))
summary_template = '||factors||: {}, factors type: {}'
print(summary_template.format(tnorm, t.dtype))


||factors||: 101126.9453125, factors type: torch.float32


In [5]:
# Run 0 uses SVD initialisation without a seed; it is followed by
# `random_init_starts` randomly initialised runs, each with its own seed.
random_init_starts = RANDOM_INIT_STARTS

inits = ['svd']
inits.extend('random' for _ in range(random_init_starts))

random_states = [None]
for _ in range(random_init_starts):
    seed = torch.randint(high=11999, size=(1,))
    random_states.append(int(seed))


## ALS

In [6]:
rank = 4

# Pick the initialisation scheme and seed of the selected run
# (index 0 is the SVD-initialised run).
run_id = 0
init, random_state = inits[run_id], random_states[run_id]

# Plain (float) ALS decomposition of `t`.
out_als = parafac(
    t,
    rank,
    init=init,
    random_state=random_state,
    **params_als_shared,
    **params_als
)

In [7]:
# Rebuild the full tensor from the ALS result and measure how far it is from `t`.
t_als = kruskal_to_tensor(out_als)
rec_error_als = tl.norm(t - t_als)

als_report = '||tensor - als_factors||: {}'.format(rec_error_als)
print(als_report)

||tensor - als_factors||: 0.017842302098870277


## ALS + post quantization

In [8]:
# Quantize each ALS factor independently, keeping the scale and zero point
# returned alongside every quantized factor.
qfactors_post, scales_post, zero_points_post = [], [], []

for factor in out_als.factors:
    qfactor, scale, zero_point = quantize_qint(
        factor, return_scale_zeropoint=True, **params_qint)
    qfactors_post.append(qfactor)
    scales_post.append(scale)
    zero_points_post.append(zero_point)

# Reconstruction error when the ALS weights are combined with the quantized factors.
krt_post = KruskalTensor((out_als.weights, qfactors_post))
qt_post = kruskal_to_tensor(krt_post)
rec_error_qpost = tl.norm(t - qt_post)

print('||tensor - als_factors_quantized||: {}'.format(rec_error_qpost))

||tensor - als_factors_quantized||: 3366.031494140625


In [9]:
# Each of the three post-quantized factors divided by its scale, then the zero points.
tuple(qfactors_post[mode] / scales_post[mode] for mode in range(3)) + (zero_points_post,)

(tensor([[  14.,   66.,  -76.,   80.],
         [  67.,   16.,   32.,   36.],
         [ -71., -118.,  114.,   38.],
         [  11.,  -31.,  -21.,   46.],
         [   8.,   22.,  -47.,   93.],
         [  99.,   -2.,    6., -108.],
         [  92.,  -46.,   12.,  -96.],
         [ -26.,   61.,   35.,   51.],
         [  23., -125.,  -20.,   56.],
         [  37.,    3.,    7.,   14.],
         [  37.,  -50.,    9.,  -54.],
         [  53.,    2.,  -30.,  -25.],
         [ -61.,   27., -127.,   -1.],
         [ 127.,   46.,   88.,  -23.],
         [ -67.,  -83.,   66.,   68.],
         [  39.,  -87.,   88.,   74.]]), tensor([[ -46.,   50.,   19.,   54.],
         [  90.,   73.,  113.,   45.],
         [  70.,   70.,  -36.,  112.],
         [ -15.,   -9.,  114.,    1.],
         [  13.,  -62.,  117.,   43.],
         [  72.,  -60.,  -61.,   48.],
         [  64., -128.,   18.,   29.],
         [  56.,   28.,   42.,   77.],
         [-124.,   50.,  -67., -127.],
         [ -75.,  -58., 

## QALS

In [10]:
# Quantization-aware ALS with the same initialisation as the plain ALS run.
qals_result = quantized_parafac(
    t,
    rank,
    init=init,
    random_state=random_state,
    **params_als_shared,
    **params_qals,
    **params_qint
)
(fout, qout), (errors, qerrors), scales, zero_points = qals_result

# Reconstruction error of the quantized result.
krt_qals = KruskalTensor(qout)
t_approx_qals = kruskal_to_tensor(krt_qals)
rec_error_qals = tl.norm(t - t_approx_qals)

print('||tensor - qals_factors||: {}'.format(rec_error_qals))

||tensor - qals_factors||: 3224.328857421875


In [11]:
# First QALS factor divided by its scale, together with the QALS weights.
(qout.factors[0] / scales[0], qout.weights)

(tensor([[ -14.0000,   66.0000,   76.0000,  -80.0000],
         [ -67.0000,   16.0000,  -32.0000,  -36.0000],
         [  71.0000, -118.0000, -114.0000,  -39.0000],
         [ -11.0000,  -31.0000,   21.0000,  -46.0000],
         [  -8.0000,   22.0000,   47.0000,  -93.0000],
         [ -99.0000,   -2.0000,   -6.0000,  108.0000],
         [ -92.0000,  -46.0000,  -12.0000,   96.0000],
         [  26.0000,   61.0000,  -34.0000,  -50.0000],
         [ -22.0000, -125.0000,   19.0000,  -57.0000],
         [ -37.0000,    3.0000,   -7.0000,  -14.0000],
         [ -37.0000,  -50.0000,   -9.0000,   54.0000],
         [ -53.0000,    2.0000,   30.0000,   25.0000],
         [  61.0000,   27.0000,  127.0000,    1.0000],
         [-128.0000,   46.0000,  -87.0000,   23.0000],
         [  67.0000,  -83.0000,  -67.0000,  -69.0000],
         [ -39.0000,  -87.0000,  -88.0000,  -74.0000]]),
 tensor([49561.2617, 53510.1797, 52058.8203, 49514.7227]))

## Helper functions for simulated annealing

In [12]:
def restore_factors(factors_flatten, shapes):
    """Split a flat sequence of values back into tensors of the given shapes.

    Parameters
    ----------
    factors_flatten : sequence or array-like
        All factor entries concatenated; anything ``torch.tensor`` accepts.
    shapes : iterable of shapes
        Target shape of each factor, in the order the factors were concatenated.

    Returns
    -------
    list of torch.Tensor
        One tensor per entry of ``shapes``; consecutive chunks of
        ``factors_flatten`` reshaped to the requested shapes.
    """
    # Convert once instead of once per factor.
    flat = torch.tensor(factors_flatten)

    factors_restored = []
    offset = 0
    for shape in shapes:
        # int(...) keeps the slice bounds integral even when np.prod returns a
        # float (which it does for an empty shape).
        size = int(np.prod(shape))
        factors_restored.append(flat[offset : offset + size].reshape(shape))
        offset += size
    return factors_restored

def norm_error(factors_flatten, shapes, t, weights = None,\
               scales = None, zero_points = zero_points, verbose = True):
    """Reconstruction error of a flattened set of factors against tensor ``t``.

    Parameters
    ----------
    factors_flatten : sequence or array-like
        All factor entries concatenated (see ``restore_factors``).
    shapes : iterable of shapes
        Shape of each factor.
    t : tensor
        Reference tensor the reconstruction is compared with.
    weights : tensor or None
        Weights passed to ``KruskalTensor`` together with the factors.
    scales : sequence or None
        When given, factor ``i`` is mapped to
        ``(factor - zero_points[i]) * scales[i]`` before reconstruction.
    zero_points : sequence
        Per-factor zero points, used only when ``scales`` is given.
        NOTE: the default is bound when this function is defined, to the
        module-level ``zero_points`` produced by the QALS cell; pass the
        argument explicitly to avoid depending on cell execution order.
    verbose : bool, default True
        Print the error on every call (the original behaviour).

    Returns
    -------
    float
        ``tl.norm(reconstruction - t)``.
    """
    factors_restored = restore_factors(factors_flatten, shapes)
    if scales is not None:
        # Each factor uses its OWN zero point (the previous code applied
        # zero_points[0] to every factor).
        factors_restored = [(factors_restored[i] - zero_points[i])*scales[i]
                            for i in range(len(scales))]

    tensor_restored = kruskal_to_tensor(KruskalTensor((weights, factors_restored)))
    error = tl.norm(tensor_restored - t)

    if verbose:
        print('Norm error: {}'.format(error))

    return float(error)

## ALS + post quantization + simulated annealing

In [13]:
# Inspect the post-quantization starting point: first factor divided by its
# scale, all scales, the matching post-quantization zero points (not the QALS
# `zero_points`), and the ALS weights.
(qfactors_post[0] / scales_post[0], scales_post, zero_points_post, out_als.weights)

(tensor([[  14.,   66.,  -76.,   80.],
         [  67.,   16.,   32.,   36.],
         [ -71., -118.,  114.,   38.],
         [  11.,  -31.,  -21.,   46.],
         [   8.,   22.,  -47.,   93.],
         [  99.,   -2.,    6., -108.],
         [  92.,  -46.,   12.,  -96.],
         [ -26.,   61.,   35.,   51.],
         [  23., -125.,  -20.,   56.],
         [  37.,    3.,    7.,   14.],
         [  37.,  -50.,    9.,  -54.],
         [  53.,    2.,  -30.,  -25.],
         [ -61.,   27., -127.,   -1.],
         [ 127.,   46.,   88.,  -23.],
         [ -67.,  -83.,   66.,   68.],
         [  39.,  -87.,   88.,   74.]]),
 [tensor(0.0040), tensor(0.0038), tensor(0.0051)],
 (tensor(0, dtype=torch.int32),
  tensor(0, dtype=torch.int32),
  tensor(0, dtype=torch.int32)),
 tensor([49523.3867, 53556.0977, 52042.8242, 49521.7383]))

In [27]:
# Recover the integer codes of the post-quantized factors.
# The quotient factor / scale is a float that may sit just below an integer
# (e.g. 13.99999); casting to int8 truncates toward zero, so round first.
# The zero point is added back so the codes match the decoding used by
# `norm_error`: (code - zero_point) * scale.
qfactors_post_int = [torch.round(qfactors_post[i] / scales_post[i]) + zero_points_post[i]
                     for i in range(len(qfactors_post))]

# Shapes are needed to unflatten the vector again inside the objective.
shapes = [factor.shape for factor in qfactors_post_int]
factors_flatten_init = torch.cat([factor.flatten() for factor in qfactors_post_int])
factors_flatten_init = factors_flatten_init.type(torch.int8)

In [28]:
# Drop any annealer left over from a previous run of this cell.
bsa = None

# Eight-position boolean mask with only the last two positions enabled.
mask = np.array([False] * 6 + [True] * 2)

neighbor_params = dict(nb=2, mask=mask)
neighbor = p.sa.MaskedInvertBitsNeighbor(**neighbor_params)

# Objective: reconstruction error of the candidate integer factors, decoded
# with the post-quantization scales / zero points and the ALS weights.
objective = partial(norm_error,
                    shapes=shapes,
                    t=t,
                    weights=out_als.weights,
                    scales=scales_post,
                    zero_points=zero_points_post)

n_entries = len(factors_flatten_init)

# Starting point is the flattened post-quantized factors; to warm-start from
# a previous annealing result, pass np.array(flatten_factors_opt) instead.
# Constructing the annealer already evaluates the objective once, which is
# why this cell prints a norm error.
bsa = p.BinarySA(objective,
                 np.array(factors_flatten_init).astype(np.int8),
                 [(-2**7, 2**7)] * n_entries,
                 'i' * n_entries,
                 emax=1e-10,
                 imax=500000,
                 T0=1,
                 rt=0.99999,
                 neighbor=neighbor)

Norm error: 3372.342041015625


In [29]:
# Run the annealer; it returns the best flattened factors and their error.
(flatten_factors_opt, error) = bsa()

Norm error: 3372.342041015625
Norm error: 24500.578125
Norm error: 3372.342041015625
Norm error: 25505.416015625
Norm error: 3372.342041015625
Norm error: 27728.318359375
Norm error: 3372.342041015625
Norm error: 25356.953125
Norm error: 3372.342041015625
Norm error: 21551.267578125
Norm error: 3372.342041015625
Norm error: 29727.65625
Norm error: 3372.342041015625
Norm error: 3467.399169921875
Norm error: 3372.342041015625
Norm error: 12243.0048828125
Norm error: 3372.342041015625
Norm error: 18669.91796875
Norm error: 3372.342041015625
Norm error: 6146.46484375
Norm error: 3372.342041015625
Norm error: 31934.4765625
Norm error: 3372.342041015625
Norm error: 23334.302734375
Norm error: 3372.342041015625
Norm error: 27726.5546875
Norm error: 3372.342041015625
Norm error: 8090.4052734375
Norm error: 3372.342041015625
Norm error: 22336.333984375
Norm error: 3372.342041015625
Norm error: 3355.71630859375
Norm error: 3355.71630859375
Norm error: 3398.374267578125
Norm error: 3355.716308593

Norm error: 23105.794921875
Norm error: 3320.76806640625
Norm error: 50911.57421875
Norm error: 3320.76806640625
Norm error: 14587.6728515625
Norm error: 3320.76806640625
Norm error: 3398.39453125
Norm error: 3320.76806640625
Norm error: 14344.15234375
Norm error: 3320.76806640625
Norm error: 42805.2578125
Norm error: 3320.76806640625
Norm error: 29794.6484375
Norm error: 3320.76806640625
Norm error: 4580.68115234375
Norm error: 3320.76806640625
Norm error: 7623.27001953125
Norm error: 3320.76806640625
Norm error: 17655.185546875
Norm error: 3320.76806640625
Norm error: 3654.092529296875
Norm error: 3320.76806640625
Norm error: 21841.240234375
Norm error: 3320.76806640625
Norm error: 25893.775390625
Norm error: 3320.76806640625
Norm error: 10946.119140625
Norm error: 3320.76806640625
Norm error: 39369.03515625
Norm error: 3320.76806640625
Norm error: 8677.443359375
Norm error: 3320.76806640625
Norm error: 16478.875
Norm error: 3320.76806640625
Norm error: 4391.921875
Norm error: 3320.7

Norm error: 52557.83984375
Norm error: 3231.42578125
Norm error: 15228.4287109375
Norm error: 3231.42578125
Norm error: 22339.15625
Norm error: 3231.42578125
Norm error: 5468.5625
Norm error: 3231.42578125
Norm error: 45814.86328125
Norm error: 3231.42578125
Norm error: 34612.6953125
Norm error: 3231.42578125
Norm error: 18361.3125
Norm error: 3231.42578125
Norm error: 17723.859375
Norm error: 3231.42578125
Norm error: 45290.65234375
Norm error: 3231.42578125
Norm error: 30554.900390625
Norm error: 3231.42578125
Norm error: 34907.5234375
Norm error: 3231.42578125
Norm error: 44285.16015625
Norm error: 3231.42578125
Norm error: 41452.7578125
Norm error: 3231.42578125
Norm error: 4049.663818359375
Norm error: 3231.42578125
Norm error: 34603.25390625
Norm error: 3231.42578125
Norm error: 19611.78125
Norm error: 3231.42578125
Norm error: 46606.3046875
Norm error: 3231.42578125
Norm error: 35768.890625
Norm error: 3231.42578125
Norm error: 14012.12890625
Norm error: 3231.42578125
Norm error

Norm error: 18204.337890625
Norm error: 3226.97998046875
Norm error: 17229.34765625
Norm error: 3226.97998046875
Norm error: 22765.87890625
Norm error: 3226.97998046875
Norm error: 42152.98046875
Norm error: 3226.97998046875
Norm error: 32689.15234375
Norm error: 3226.97998046875
Norm error: 11193.2392578125
Norm error: 3226.97998046875
Norm error: 21459.0859375
Norm error: 3226.97998046875
Norm error: 38745.5078125
Norm error: 3226.97998046875
Norm error: 3266.01904296875
Norm error: 3226.97998046875
Norm error: 9266.8662109375
Norm error: 3226.97998046875
Norm error: 23806.1640625
Norm error: 3226.97998046875
Norm error: 33498.06640625
Norm error: 3226.97998046875
Norm error: 15741.9921875
Norm error: 3226.97998046875
Norm error: 45500.484375
Norm error: 3226.97998046875
Norm error: 22465.849609375
Norm error: 3226.97998046875
Norm error: 5881.8037109375
Norm error: 3226.97998046875
Norm error: 38322.35546875
Norm error: 3226.97998046875
Norm error: 47125.53515625
Norm error: 3226.97

Norm error: 22312.361328125
Norm error: 3211.5751953125
Norm error: 3207.53857421875
Norm error: 3207.53857421875
Norm error: 30538.359375
Norm error: 3207.53857421875
Norm error: 20043.86328125
Norm error: 3207.53857421875
Norm error: 28611.958984375
Norm error: 3207.53857421875
Norm error: 3720.036376953125
Norm error: 3207.53857421875
Norm error: 40836.296875
Norm error: 3207.53857421875
Norm error: 11802.279296875
Norm error: 3207.53857421875
Norm error: 12281.3955078125
Norm error: 3207.53857421875
Norm error: 24608.3359375
Norm error: 3207.53857421875
Norm error: 32039.91796875
Norm error: 3207.53857421875
Norm error: 12448.2236328125
Norm error: 3207.53857421875
Norm error: 34881.08984375
Norm error: 3207.53857421875
Norm error: 16949.982421875
Norm error: 3207.53857421875
Norm error: 17666.17578125
Norm error: 3207.53857421875
Norm error: 6825.73876953125
Norm error: 3207.53857421875
Norm error: 3272.133056640625
Norm error: 3207.53857421875
Norm error: 4152.69140625
Norm error

Norm error: 41675.015625
Norm error: 3207.53857421875
Norm error: 17908.294921875
Norm error: 3207.53857421875
Norm error: 3549.27001953125
Norm error: 3207.53857421875
Norm error: 8296.3974609375
Norm error: 3207.53857421875
Norm error: 22984.59375
Norm error: 3207.53857421875
Norm error: 40398.01171875
Norm error: 3207.53857421875
Norm error: 35011.640625
Norm error: 3207.53857421875
Norm error: 15655.3486328125
Norm error: 3207.53857421875
Norm error: 8446.4033203125
Norm error: 3207.53857421875
Norm error: 3254.04736328125
Norm error: 3207.53857421875
Norm error: 3479.92578125
Norm error: 3207.53857421875
Norm error: 49847.71875
Norm error: 3207.53857421875
Norm error: 13023.59765625
Norm error: 3207.53857421875
Norm error: 28607.775390625
Norm error: 3207.53857421875
Norm error: 10116.9833984375
Norm error: 3207.53857421875
Norm error: 23798.0546875
Norm error: 3207.53857421875
Norm error: 15422.0849609375
Norm error: 3207.53857421875
Norm error: 14436.7724609375
Norm error: 3207.

Norm error: 3202.480712890625
Norm error: 38925.00390625
Norm error: 3202.480712890625
Norm error: 16454.330078125
Norm error: 3202.480712890625
Norm error: 31871.578125
Norm error: 3202.480712890625
Norm error: 48634.1796875
Norm error: 3202.480712890625
Norm error: 28415.35546875
Norm error: 3202.480712890625
Norm error: 3238.2607421875
Norm error: 3202.480712890625
Norm error: 24862.88671875
Norm error: 3202.480712890625
Norm error: 25006.908203125
Norm error: 3202.480712890625
Norm error: 19273.576171875
Norm error: 3202.480712890625
Norm error: 34737.25390625
Norm error: 3202.480712890625
Norm error: 30290.951171875
Norm error: 3202.480712890625
Norm error: 37338.39453125
Norm error: 3202.480712890625
Norm error: 38162.66796875
Norm error: 3202.480712890625
Norm error: 42338.82421875
Norm error: 3202.480712890625
Norm error: 37794.01953125
Norm error: 3202.480712890625
Norm error: 33358.47265625
Norm error: 3202.480712890625
Norm error: 47114.75
Norm error: 3202.480712890625
Norm 

Norm error: 45666.796875
Norm error: 3202.480712890625
Norm error: 4980.3837890625
Norm error: 3202.480712890625
Norm error: 3233.137451171875
Norm error: 3202.480712890625
Norm error: 9284.314453125
Norm error: 3202.480712890625
Norm error: 41077.01953125
Norm error: 3202.480712890625
Norm error: 29957.806640625
Norm error: 3202.480712890625
Norm error: 11240.69140625
Norm error: 3202.480712890625
Norm error: 42789.4609375
Norm error: 3202.480712890625
Norm error: 23790.83203125
Norm error: 3202.480712890625
Norm error: 18673.052734375
Norm error: 3202.480712890625
Norm error: 18278.345703125
Norm error: 3202.480712890625
Norm error: 8005.94384765625
Norm error: 3202.480712890625
Norm error: 47657.8984375
Norm error: 3202.480712890625
Norm error: 29379.78125
Norm error: 3202.480712890625
Norm error: 10816.8916015625
Norm error: 3202.480712890625
Norm error: 18996.365234375
Norm error: 3202.480712890625
Norm error: 3686.90283203125
Norm error: 3202.480712890625
Norm error: 40981.996093

Norm error: 32807.59765625
Norm error: 3202.523193359375
Norm error: 20175.11328125
Norm error: 3202.523193359375
Norm error: 11882.662109375
Norm error: 3202.523193359375
Norm error: 3458.87646484375
Norm error: 3202.523193359375
Norm error: 24434.84765625
Norm error: 3202.523193359375
Norm error: 16615.90234375
Norm error: 3202.523193359375
Norm error: 12079.5830078125
Norm error: 3202.523193359375
Norm error: 51687.41796875
Norm error: 3202.523193359375
Norm error: 3260.0947265625
Norm error: 3202.523193359375
Norm error: 11983.439453125
Norm error: 3202.523193359375
Norm error: 15954.578125
Norm error: 3202.523193359375
Norm error: 27897.8515625
Norm error: 3202.523193359375
Norm error: 42502.97265625
Norm error: 3202.523193359375
Norm error: 16679.15625
Norm error: 3202.523193359375
Norm error: 31323.25390625
Norm error: 3202.523193359375
Norm error: 28097.326171875
Norm error: 3202.523193359375
Norm error: 31221.662109375
Norm error: 3202.523193359375
Norm error: 25841.361328125


Norm error: 5679.4326171875
Norm error: 3173.232177734375
Norm error: 3899.89111328125
Norm error: 3173.232177734375
Norm error: 29954.525390625
Norm error: 3173.232177734375
Norm error: 14027.4208984375
Norm error: 3173.232177734375
Norm error: 35777.0234375
Norm error: 3173.232177734375
Norm error: 32352.103515625
Norm error: 3173.232177734375
Norm error: 25948.458984375
Norm error: 3173.232177734375
Norm error: 31456.72265625
Norm error: 3173.232177734375
Norm error: 8274.703125
Norm error: 3173.232177734375
Norm error: 3220.289306640625
Norm error: 3173.232177734375
Norm error: 5094.91064453125
Norm error: 3173.232177734375
Norm error: 17267.625
Norm error: 3173.232177734375
Norm error: 23414.75
Norm error: 3173.232177734375
Norm error: 35084.7421875
Norm error: 3173.232177734375
Norm error: 43253.90234375
Norm error: 3173.232177734375
Norm error: 27530.193359375
Norm error: 3173.232177734375
Norm error: 19125.794921875
Norm error: 3173.232177734375
Norm error: 44423.83203125
Norm 

Norm error: 46299.6015625
Norm error: 3165.697509765625
Norm error: 11248.3330078125
Norm error: 3165.697509765625
Norm error: 13684.1337890625
Norm error: 3165.697509765625
Norm error: 23326.677734375
Norm error: 3165.697509765625
Norm error: 46623.46875
Norm error: 3165.697509765625
Norm error: 10443.49609375
Norm error: 3165.697509765625
Norm error: 36930.6328125
Norm error: 3165.697509765625
Norm error: 36414.1171875
Norm error: 3165.697509765625
Norm error: 11872.9111328125
Norm error: 3165.697509765625
Norm error: 44955.16796875
Norm error: 3165.697509765625
Norm error: 3203.260009765625
Norm error: 3165.697509765625
Norm error: 19383.0390625
Norm error: 3165.697509765625
Norm error: 20870.134765625
Norm error: 3165.697509765625
Norm error: 6754.77001953125
Norm error: 3165.697509765625
Norm error: 55387.80859375
Norm error: 3165.697509765625
Norm error: 23957.623046875
Norm error: 3165.697509765625
Norm error: 26038.689453125
Norm error: 3165.697509765625
Norm error: 14212.52929

Norm error: 26059.111328125
Norm error: 3161.117431640625
Norm error: 48761.98828125
Norm error: 3161.117431640625
Norm error: 4008.786376953125
Norm error: 3161.117431640625
Norm error: 18608.181640625
Norm error: 3161.117431640625
Norm error: 36588.3515625
Norm error: 3161.117431640625
Norm error: 19277.548828125
Norm error: 3161.117431640625
Norm error: 33977.625
Norm error: 3161.117431640625
Norm error: 3283.753173828125
Norm error: 3161.117431640625
Norm error: 48283.43359375
Norm error: 3161.117431640625
Norm error: 11738.201171875
Norm error: 3161.117431640625
Norm error: 20251.92578125
Norm error: 3161.117431640625
Norm error: 3175.700439453125
Norm error: 3161.117431640625
Norm error: 8477.421875
Norm error: 3161.117431640625
Norm error: 9943.62890625
Norm error: 3161.117431640625
Norm error: 31010.6328125
Norm error: 3161.117431640625
Norm error: 45560.9453125
Norm error: 3161.117431640625
Norm error: 10250.0908203125
Norm error: 3161.117431640625
Norm error: 25012.775390625


## QALS + simulated annealing

In [17]:
# Inspect the QALS starting point: first factor divided by its scale, all
# scales, all zero points, and the QALS weights.
(qout.factors[0] / scales[0], scales, zero_points, qout.weights)

(tensor([[ -14.0000,   66.0000,   76.0000,  -80.0000],
         [ -67.0000,   16.0000,  -32.0000,  -36.0000],
         [  71.0000, -118.0000, -114.0000,  -39.0000],
         [ -11.0000,  -31.0000,   21.0000,  -46.0000],
         [  -8.0000,   22.0000,   47.0000,  -93.0000],
         [ -99.0000,   -2.0000,   -6.0000,  108.0000],
         [ -92.0000,  -46.0000,  -12.0000,   96.0000],
         [  26.0000,   61.0000,  -34.0000,  -50.0000],
         [ -22.0000, -125.0000,   19.0000,  -57.0000],
         [ -37.0000,    3.0000,   -7.0000,  -14.0000],
         [ -37.0000,  -50.0000,   -9.0000,   54.0000],
         [ -53.0000,    2.0000,   30.0000,   25.0000],
         [  61.0000,   27.0000,  127.0000,    1.0000],
         [-128.0000,   46.0000,  -87.0000,   23.0000],
         [  67.0000,  -83.0000,  -67.0000,  -69.0000],
         [ -39.0000,  -87.0000,  -88.0000,  -74.0000]]),
 (tensor(0.0040), tensor(0.0039), tensor(0.0051)),
 (tensor(0, dtype=torch.int32),
  tensor(0, dtype=torch.int32),
  t

In [30]:
# Recover the integer codes of the QALS factors.
# The quotient factor / scale is a float that may sit just below an integer
# (e.g. -13.99999); casting to int8 truncates toward zero, so round first.
# The zero point is added back so the codes match the decoding used by
# `norm_error`: (code - zero_point) * scale.
qout_factors_int = [torch.round(qout.factors[i] / scales[i]) + zero_points[i]
                    for i in range(len(qout.factors))]

# Shapes are needed to unflatten the vector again inside the objective.
shapes = [factor.shape for factor in qout_factors_int]
factors_flatten_init = torch.cat([factor.flatten() for factor in qout_factors_int])
factors_flatten_init = factors_flatten_init.type(torch.int8)

In [31]:
# Drop any annealer left over from a previous run of this cell.
bsa = None

# Eight-position boolean mask with only the last two positions enabled.
mask = np.array([False] * 6 + [True] * 2)

neighbor_params = dict(nb=2, mask=mask)
neighbor = p.sa.MaskedInvertBitsNeighbor(**neighbor_params)

# Objective: reconstruction error of the candidate integer factors, decoded
# with the QALS scales / zero points and the QALS weights.
objective = partial(norm_error,
                    shapes=shapes,
                    t=t,
                    weights=qout.weights,
                    scales=scales,
                    zero_points=zero_points)

n_entries = len(factors_flatten_init)

# Starting point is the flattened QALS factors; to warm-start from a previous
# annealing result, pass np.array(flatten_factors_opt) instead.
# Constructing the annealer already evaluates the objective once, which is
# why this cell prints a norm error.
bsa = p.BinarySA(objective,
                 np.array(factors_flatten_init).astype(np.int8),
                 [(-2**7, 2**7)] * n_entries,
                 'i' * n_entries,
                 emax=1e-10,
                 imax=500000,
                 T0=10,
                 rt=0.99999,
                 neighbor=neighbor)

Norm error: 3292.44580078125


In [32]:
# Run the annealer; it returns the best flattened factors and their error.
(flatten_factors_opt, error) = bsa()

Norm error: 3292.44580078125
Norm error: 35309.2890625
Norm error: 3292.44580078125
Norm error: 17573.962890625
Norm error: 3292.44580078125
Norm error: 37294.953125
Norm error: 3292.44580078125
Norm error: 23493.826171875
Norm error: 3292.44580078125
Norm error: 4733.12158203125
Norm error: 3292.44580078125
Norm error: 25874.078125
Norm error: 3292.44580078125
Norm error: 8288.728515625
Norm error: 3292.44580078125
Norm error: 35905.83203125
Norm error: 3292.44580078125
Norm error: 24999.466796875
Norm error: 3292.44580078125
Norm error: 20384.576171875
Norm error: 3292.44580078125
Norm error: 6165.01953125
Norm error: 3292.44580078125
Norm error: 21151.591796875
Norm error: 3292.44580078125
Norm error: 3558.0654296875
Norm error: 3292.44580078125
Norm error: 20823.37890625
Norm error: 3292.44580078125
Norm error: 3554.734619140625
Norm error: 3292.44580078125
Norm error: 15002.3544921875
Norm error: 3292.44580078125
Norm error: 35125.09765625
Norm error: 3292.44580078125
Norm error: 

Norm error: 8744.388671875
Norm error: 3158.9169921875
Norm error: 9930.013671875
Norm error: 3158.9169921875
Norm error: 14285.931640625
Norm error: 3158.9169921875
Norm error: 42690.67578125
Norm error: 3158.9169921875
Norm error: 3237.7001953125
Norm error: 3158.9169921875
Norm error: 12411.564453125
Norm error: 3158.9169921875
Norm error: 25019.0546875
Norm error: 3158.9169921875
Norm error: 4928.7138671875
Norm error: 3158.9169921875
Norm error: 35623.94140625
Norm error: 3158.9169921875
Norm error: 32605.173828125
Norm error: 3158.9169921875
Norm error: 3217.743408203125
Norm error: 3158.9169921875
Norm error: 5262.1416015625
Norm error: 3158.9169921875
Norm error: 6371.7412109375
Norm error: 3158.9169921875
Norm error: 51964.41796875
Norm error: 3158.9169921875
Norm error: 17167.41015625
Norm error: 3158.9169921875
Norm error: 22904.380859375
Norm error: 3158.9169921875
Norm error: 25369.00390625
Norm error: 3158.9169921875
Norm error: 37851.69921875
Norm error: 3158.9169921875


Norm error: 21080.646484375
Norm error: 3226.387939453125
Norm error: 21061.169921875
Norm error: 3226.387939453125
Norm error: 5294.1162109375
Norm error: 3226.387939453125
Norm error: 3217.51416015625
Norm error: 3217.51416015625
Norm error: 29401.00390625
Norm error: 3217.51416015625
Norm error: 3917.498046875
Norm error: 3217.51416015625
Norm error: 26632.6171875
Norm error: 3217.51416015625
Norm error: 35570.875
Norm error: 3217.51416015625
Norm error: 31636.23828125
Norm error: 3217.51416015625
Norm error: 29397.458984375
Norm error: 3217.51416015625
Norm error: 7803.38427734375
Norm error: 3217.51416015625
Norm error: 30755.49609375
Norm error: 3217.51416015625
Norm error: 36053.5390625
Norm error: 3217.51416015625
Norm error: 3231.4814453125
Norm error: 3217.51416015625
Norm error: 18703.470703125
Norm error: 3217.51416015625
Norm error: 3246.810791015625
Norm error: 3217.51416015625
Norm error: 13459.9423828125
Norm error: 3217.51416015625
Norm error: 6603.8251953125
Norm erro

Norm error: 12666.791015625
Norm error: 3224.4423828125
Norm error: 45834.640625
Norm error: 3224.4423828125
Norm error: 36898.2578125
Norm error: 3224.4423828125
Norm error: 14478.7568359375
Norm error: 3224.4423828125
Norm error: 9180.6240234375
Norm error: 3224.4423828125
Norm error: 23333.888671875
Norm error: 3224.4423828125
Norm error: 45408.27734375
Norm error: 3224.4423828125
Norm error: 8897.4013671875
Norm error: 3224.4423828125
Norm error: 7725.048828125
Norm error: 3224.4423828125
Norm error: 20288.849609375
Norm error: 3224.4423828125
Norm error: 23549.57421875
Norm error: 3224.4423828125
Norm error: 46968.9765625
Norm error: 3224.4423828125
Norm error: 11945.2177734375
Norm error: 3224.4423828125
Norm error: 32667.51953125
Norm error: 3224.4423828125
Norm error: 25966.68359375
Norm error: 3224.4423828125
Norm error: 36223.55078125
Norm error: 3224.4423828125
Norm error: 10427.826171875
Norm error: 3224.4423828125
Norm error: 50967.8828125
Norm error: 3224.4423828125
Norm 

Norm error: 3132.771240234375
Norm error: 3132.771240234375
Norm error: 7341.546875
Norm error: 3132.771240234375
Norm error: 12159.072265625
Norm error: 3132.771240234375
Norm error: 44420.74609375
Norm error: 3132.771240234375
Norm error: 3212.278076171875
Norm error: 3132.771240234375
Norm error: 4247.734375
Norm error: 3132.771240234375
Norm error: 13000.2841796875
Norm error: 3132.771240234375
Norm error: 23613.966796875
Norm error: 3132.771240234375
Norm error: 32211.40234375
Norm error: 3132.771240234375
Norm error: 59300.60546875
Norm error: 3132.771240234375
Norm error: 3373.0419921875
Norm error: 3132.771240234375
Norm error: 21509.068359375
Norm error: 3132.771240234375
Norm error: 33349.72265625
Norm error: 3132.771240234375
Norm error: 3181.585205078125
Norm error: 3132.771240234375
Norm error: 20876.044921875
Norm error: 3132.771240234375
Norm error: 36811.06640625
Norm error: 3132.771240234375
Norm error: 34793.83984375
Norm error: 3132.771240234375
Norm error: 10559.613

In [33]:
# Flattened factor entries returned by the preceding `bsa()` run (all modes concatenated).
flatten_factors_opt

(-14,
 66,
 76,
 -80,
 -68,
 16,
 -32,
 -36,
 70,
 -118,
 -116,
 -39,
 -11,
 -31,
 21,
 -46,
 -8,
 23,
 47,
 -94,
 -99,
 -2,
 -6,
 108,
 -91,
 -46,
 -12,
 96,
 26,
 61,
 -34,
 -52,
 -22,
 -125,
 19,
 -58,
 -37,
 3,
 -7,
 -14,
 -37,
 -50,
 -9,
 54,
 -53,
 2,
 30,
 25,
 61,
 27,
 126,
 1,
 -128,
 46,
 -87,
 23,
 67,
 -83,
 -67,
 -69,
 -39,
 -88,
 -88,
 -74,
 45,
 49,
 19,
 53,
 -88,
 72,
 111,
 46,
 -68,
 68,
 -36,
 109,
 15,
 -8,
 117,
 1,
 -13,
 -60,
 115,
 40,
 -71,
 -58,
 -60,
 46,
 -63,
 -128,
 17,
 29,
 -55,
 28,
 43,
 75,
 121,
 48,
 -68,
 -124,
 73,
 -56,
 -36,
 18,
 -25,
 -49,
 21,
 -128,
 -21,
 68,
 54,
 14,
 37,
 -41,
 -79,
 46,
 68,
 13,
 -12,
 37,
 -38,
 -120,
 -81,
 -35,
 -108,
 13,
 -13,
 53,
 -95,
 52,
 -5,
 1,
 -27,
 89,
 60,
 -104,
 -69,
 -76,
 -20,
 -56,
 39,
 111,
 -69,
 -117,
 24,
 -14,
 127,
 44,
 78,
 -77,
 32,
 -14,
 -114,
 -41,
 -17,
 5,
 -28,
 47,
 -97,
 -13,
 -47,
 22,
 12,
 94)