In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  
os.environ["CUDA_VISIBLE_DEVICES"]="2"

In [2]:
import tensorly as tl
from tensorly.decomposition import parafac, quantized_parafac
from tensorly.kruskal_tensor import kruskal_to_tensor, KruskalTensor
from tensorly.base import unfold
from tensorly.quantization import quantize_qint

import torch
tl.set_backend('pytorch')

import logging
import sys

# logging.basicConfig(format='%(asctime)s | %(levelname)s : %(message)s',
#                      level=logging.INFO, filename = 'logs.txt')
# logging.basicConfig(format='%(asctime)s | %(levelname)s : %(message)s',
#                      level=logging.INFO, stream=sys.stdout)

# Example 1. Tensor quantization
We use PyTorch build-in tools to perform quantization.

Quantization scheme can be either affine or symmetric.

Scale and zero_point values to perform quantization are computed either per channel or per tensor (i.e. we get either vectors or scalars).

Thus, there are 4 types of quantization scheme:

    ``torch.per_tensor_affine``
    ``torch.per_tensor_symmetric``
    ``torch.per_channel_affine``
    ``torch.per_channel_symmetric``

##### Generate a random tensor

In [3]:
t = torch.randn(256, 256, 9)
logging.info('||float_tensor|| = {}'.format(tl.norm(t)))

dtype = torch.qint8

##### Per channel  quantization

In [4]:
for qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
    print("\nPer channel quantization, dtype: {}, qscheme: {}".format(dtype, qscheme))
    
    for dim in range(len(t.shape)):
        qt, scale, zero_point = quantize_qint(t,\
                                              dtype,\
                                              qscheme,\
                                              dim = dim,\
                                              return_scale_zeropoint=True)

        print('Per dim {}, ||float_tensor - quant_tensor|| = {}'.format(dim, tl.norm(t - qt)))


Per channel quantization, dtype: torch.qint8, qscheme: torch.per_channel_affine
Per dim 0, ||float_tensor - quant_tensor|| = 6.027401924133301
Per dim 1, ||float_tensor - quant_tensor|| = 6.024443626403809
Per dim 2, ||float_tensor - quant_tensor|| = 7.571700572967529

Per channel quantization, dtype: torch.qint8, qscheme: torch.per_channel_symmetric
Per dim 0, ||float_tensor - quant_tensor|| = 6.364619255065918
Per dim 1, ||float_tensor - quant_tensor|| = 6.355539321899414
Per dim 2, ||float_tensor - quant_tensor|| = 7.846454620361328


##### Per tensor quantization

In [5]:
for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
    print("\nPer tensor quantization, dtype: {}, qscheme: {}".format(dtype, qscheme))

    
    qt, scale, zero_point = quantize_qint(t,\
                                          dtype,\
                                          qscheme,\
                                          dim = dim,\
                                          return_scale_zeropoint=True)
    print('Per tensor, ||float_tensor - quant_tensor|| = {}'.format(tl.norm(t - qt)))


Per tensor quantization, dtype: torch.qint8, qscheme: torch.per_tensor_affine
Per tensor, ||float_tensor - quant_tensor|| = 8.609950065612793

Per tensor quantization, dtype: torch.qint8, qscheme: torch.per_tensor_symmetric
Per tensor, ||float_tensor - quant_tensor|| = 9.130134582519531


# Example 2. Quantization of a tensor in Kruskal format:
    a) via quantization of the corresponding full tensor
    b) via quantization of decomposition factors.

##### Generate tensor 

In [6]:
rank = 16
shape = (64, 64, 9)

factors = [torch.randn((i, rank)) for i in shape] 
weights = torch.ones(rank)

# tensor in Kruscal format
krt = KruskalTensor((weights, factors))

# corresponding tensor in full format
t = kruskal_to_tensor(krt)

tnorm = tl.norm(t)
print('||float_factors||: {}'.format(tnorm))

||float_factors||: 756.0740356445312


##### Choose quantization scheme

In [7]:
dtype = torch.qint8

## Per tensor quantization
qscheme, dim = torch.per_tensor_affine, None

## Uncomment for per channel quantization
# qscheme, dim = torch.per_channel_affine, 0

##### a) Quantize the full tensor

In [8]:
t_quant = quantize_qint(t, dtype, qscheme, dim = dim)
print('||float_factors - float_factors_quantized|| = {}'.format(tl.norm(t - t_quant)))

||float_factors - float_factors_quantized|| = 10.60978889465332


##### b) Quantize several factors

In [9]:
num_factors = len(factors)
for num_quant_factors in range(1, num_factors + 1):
    
    qfactors = [quantize_qint(factors[i], dtype, qscheme, dim = dim)\
                for i in range(num_quant_factors)\
               ] + [factors[i] for i in range(num_quant_factors, num_factors)]

    qkrt = KruskalTensor((weights, qfactors))
    qt = kruskal_to_tensor(qkrt)
    print('\n[{}/{}] factors are quantized'.format(num_quant_factors, num_factors))
    print('||quant_factors - float_factors|| = {}'.format(tl.norm(qt - t)))
#     print('||quant_factors - float_factors_quantized|| = {}'.format(tl.norm(qt - t_quant)))

#     qt_quant = quantize_qint(qt, dtype, qscheme, dim = dim)
#     print('\nquant_factors_quantized - t_quant_factors: {}'.format(tl.norm(qt_quant - qt)/tnorm))
    
#     print('||quant_factors_quantized - float_factors|| = {}'.format(tl.norm(qt_quant - t)/tnorm))
#     print('||quant_factors_quantized - float_factors_quantized|| = {}'.format(tl.norm(qt_quant - t_quant)/tnorm))



[1/3] factors are quantized
||quant_factors - float_factors|| = 5.461402893066406

[2/3] factors are quantized
||quant_factors - float_factors|| = 8.082331657409668

[3/3] factors are quantized
||quant_factors - float_factors|| = 10.411945343017578


In [10]:
import numpy as np
16/2, 16/3, np.sqrt(16)

(8.0, 5.333333333333333, 4.0)

# Example 3. Quantized ALS
Compare standard ALS algorithm for finding  CP decomposition with its quantized version, when at the end of each ALS step approximated factor is quantized.

##### Generate tensor

In [18]:
rank = 100
shape = (128, 128, 9)

factors = [torch.randn((i, rank)) for i in shape] 
weights = torch.ones(rank)

# tensor in Kruscal format
krt = KruskalTensor((weights, factors))

# corresponding tensor in full format
t_original = kruskal_to_tensor(krt)
t_original

tnorm = tl.norm(t)
print('||float_factors||: {}'.format(tnorm))

||float_factors||: 756.0740356445312


##### Find rank-1 float approximation

In [37]:
r_float = 0
stop_criterion = 'rec_error_decrease'
qstop_criterion = 'rec_error_decrease'
init = 'svd'
random_state = 1110
tol = 1e-18
n_iter_max = 50000
normalize_factors = True
device = 'cuda'

In [38]:
t = t_original.to(device)
r = r_float
if r > 0:
    out = parafac(t, r,\
                  n_iter_max=n_iter_max,\
                  init=init,\
                  svd='numpy_svd',\
                  normalize_factors=normalize_factors,
                  orthogonalise=False,\
                  tol=tol, random_state=random_state,\
                  verbose=0,\
                  return_errors=False,\
                  non_negative=False,\
                  mask=None,\
                  stop_criterion = stop_criterion)

    t_approx = kruskal_to_tensor(out)
    print(tl.norm(t - t_approx)) 
else:
    t_approx = 0
    
t_residual = t - t_approx

##### Find an approximation using ALS

In [39]:
t = t_residual.to(device)
r = rank - r_float
out = parafac(t, r,\
              n_iter_max=n_iter_max,\
              init=init,\
              svd='numpy_svd',\
              normalize_factors=normalize_factors,\
              orthogonalise=False,\
              tol=tol,\
              random_state=random_state,\
              verbose=0,\
              return_errors=False,\
              non_negative=False,\
              mask=None,
              stop_criterion = stop_criterion)

t_residual_approx_als = kruskal_to_tensor(out) 
print(tl.norm(t - t_residual_approx_als))  
print(tl.norm(t - t_approx - t_residual_approx_als))

tensor(0.0015, device='cuda:0')
tensor(0.0015, device='cuda:0')


In [40]:
tl.norm(t)

tensor(3884.5942, device='cuda:0')

##### Find an approximation using quantized ALS

In [75]:
dtype = torch.qint8

## Per tensor quantization
qscheme, dim = torch.per_tensor_affine, None

## Uncomment for per channel quantization
# qscheme, dim = torch.per_channel_affine, 0

In [34]:
t = t_residual.to(device)
r = rank - r_float
normalize_factors = True
qout, qerrors, scales, zero_points = quantized_parafac(
                                    t, r,n_iter_max=n_iter_max,\
                                    init=init, tol= tol, svd = 'numpy_svd',\
                                    normalize_factors = normalize_factors,\
                                    qmodes = [0, 1, 2],
                                    qscheme = qscheme, dtype = dtype, dim = dim,
                                    return_scale_zeropoint=True,\
                                    verbose = True,\
                                    random_state=random_state,
                                    stop_criterion = qstop_criterion)

t_residual_approx_qals = kruskal_to_tensor(qout)
print(tl.norm(t_residual - t_residual_approx_qals))
print(tl.norm(t - t_approx - t_residual_approx_qals))





reconstruction error=0.8406111001968384
iteration 1,  reconstraction error: 0.6533171534538269, decrease = 0.18729394674301147, unnormalized = 2537.8720703125
iteration 2,  reconstraction error: 0.564849317073822, decrease = 0.08846783638000488, unnormalized = 2194.21044921875
iteration 3,  reconstraction error: 0.5103333592414856, decrease = 0.054515957832336426, unnormalized = 1982.43798828125
iteration 4,  reconstraction error: 0.46816015243530273, decrease = 0.04217320680618286, unnormalized = 1818.6123046875
iteration 5,  reconstraction error: 0.43217477202415466, decrease = 0.03598538041114807, unnormalized = 1678.82373046875
iteration 6,  reconstraction error: 0.39820757508277893, decrease = 0.03396719694137573, unnormalized = 1546.8748779296875
iteration 7,  reconstraction error: 0.3651961386203766, decrease = 0.033011436462402344, unnormalized = 1418.638916015625
iteration 8,  reconstraction error: 0.3344716429710388, decrease = 0.03072449564933777, unnormalized = 1299.2866210

iteration 67,  reconstraction error: 0.06140148267149925, decrease = 1.4156103134155273e-07, unnormalized = 238.5198516845703
iteration 68,  reconstraction error: 0.06140134856104851, decrease = 1.341104507446289e-07, unnormalized = 238.5193328857422
iteration 69,  reconstraction error: 0.061401233077049255, decrease = 1.1548399925231934e-07, unnormalized = 238.51889038085938
iteration 70,  reconstraction error: 0.0614011213183403, decrease = 1.1175870895385742e-07, unnormalized = 238.51844787597656
iteration 71,  reconstraction error: 0.06140100955963135, decrease = 1.1175870895385742e-07, unnormalized = 238.5180206298828
iteration 72,  reconstraction error: 0.061400916427373886, decrease = 9.313225746154785e-08, unnormalized = 238.5176544189453
iteration 73,  reconstraction error: 0.06140083074569702, decrease = 8.568167686462402e-08, unnormalized = 238.51731872558594
iteration 74,  reconstraction error: 0.061400752514600754, decrease = 7.82310962677002e-08, unnormalized = 238.517013

iteration 132,  reconstraction error: 0.06139881908893585, decrease = 1.862645149230957e-08, unnormalized = 238.50950622558594
iteration 133,  reconstraction error: 0.061398785561323166, decrease = 3.3527612686157227e-08, unnormalized = 238.50938415527344
iteration 134,  reconstraction error: 0.061398763209581375, decrease = 2.2351741790771484e-08, unnormalized = 238.50929260253906
iteration 135,  reconstraction error: 0.061398740857839584, decrease = 2.2351741790771484e-08, unnormalized = 238.5092010498047
iteration 136,  reconstraction error: 0.061398718506097794, decrease = 2.2351741790771484e-08, unnormalized = 238.5091094970703
iteration 137,  reconstraction error: 0.061398692429065704, decrease = 2.60770320892334e-08, unnormalized = 238.50901794433594
iteration 138,  reconstraction error: 0.06139867380261421, decrease = 1.862645149230957e-08, unnormalized = 238.50894165039062
iteration 139,  reconstraction error: 0.061398640275001526, decrease = 3.3527612686157227e-08, unnormaliz

iteration 198,  reconstraction error: 0.06139707937836647, decrease = 2.60770320892334e-08, unnormalized = 238.50274658203125
iteration 199,  reconstraction error: 0.061397045850753784, decrease = 3.3527612686157227e-08, unnormalized = 238.50262451171875
iteration 200,  reconstraction error: 0.06139702722430229, decrease = 1.862645149230957e-08, unnormalized = 238.50254821777344
iteration 201,  reconstraction error: 0.061396997421979904, decrease = 2.9802322387695312e-08, unnormalized = 238.50242614746094
iteration 202,  reconstraction error: 0.06139696016907692, decrease = 3.725290298461914e-08, unnormalized = 238.50228881835938
iteration 203,  reconstraction error: 0.06139693409204483, decrease = 2.60770320892334e-08, unnormalized = 238.50218200683594
iteration 204,  reconstraction error: 0.06139690428972244, decrease = 2.9802322387695312e-08, unnormalized = 238.5020751953125
iteration 205,  reconstraction error: 0.06139687821269035, decrease = 2.60770320892334e-08, unnormalized = 23

iteration 263,  reconstraction error: 0.06139492616057396, decrease = 4.470348358154297e-08, unnormalized = 238.494384765625
iteration 264,  reconstraction error: 0.06139489263296127, decrease = 3.3527612686157227e-08, unnormalized = 238.49424743652344
iteration 265,  reconstraction error: 0.06139485538005829, decrease = 3.725290298461914e-08, unnormalized = 238.49411010742188
iteration 266,  reconstraction error: 0.06139481067657471, decrease = 4.470348358154297e-08, unnormalized = 238.4939422607422
iteration 267,  reconstraction error: 0.06139477342367172, decrease = 3.725290298461914e-08, unnormalized = 238.49378967285156
iteration 268,  reconstraction error: 0.06139473617076874, decrease = 3.725290298461914e-08, unnormalized = 238.49365234375
iteration 269,  reconstraction error: 0.06139469891786575, decrease = 3.725290298461914e-08, unnormalized = 238.49349975585938
iteration 270,  reconstraction error: 0.06139465048909187, decrease = 4.842877388000488e-08, unnormalized = 238.4933

iteration 329,  reconstraction error: 0.06139184162020683, decrease = 5.960464477539063e-08, unnormalized = 238.48240661621094
iteration 330,  reconstraction error: 0.06139178201556206, decrease = 5.960464477539063e-08, unnormalized = 238.482177734375
iteration 331,  reconstraction error: 0.06139172613620758, decrease = 5.587935447692871e-08, unnormalized = 238.48194885253906
iteration 332,  reconstraction error: 0.061391670256853104, decrease = 5.587935447692871e-08, unnormalized = 238.4817352294922
iteration 333,  reconstraction error: 0.06139161065220833, decrease = 5.960464477539063e-08, unnormalized = 238.48150634765625
iteration 334,  reconstraction error: 0.061391547322273254, decrease = 6.332993507385254e-08, unnormalized = 238.48126220703125
iteration 335,  reconstraction error: 0.06139148399233818, decrease = 6.332993507385254e-08, unnormalized = 238.48101806640625
iteration 336,  reconstraction error: 0.061391428112983704, decrease = 5.587935447692871e-08, unnormalized = 238

iteration 394,  reconstraction error: 0.06138681247830391, decrease = 1.043081283569336e-07, unnormalized = 238.46286010742188
iteration 395,  reconstraction error: 0.06138670817017555, decrease = 1.043081283569336e-07, unnormalized = 238.46246337890625
iteration 396,  reconstraction error: 0.061386603862047195, decrease = 1.043081283569336e-07, unnormalized = 238.46205139160156
iteration 397,  reconstraction error: 0.06138649210333824, decrease = 1.1175870895385742e-07, unnormalized = 238.4616241455078
iteration 398,  reconstraction error: 0.061386384069919586, decrease = 1.0803341865539551e-07, unnormalized = 238.46119689941406
iteration 399,  reconstraction error: 0.06138627603650093, decrease = 1.0803341865539551e-07, unnormalized = 238.46078491210938
iteration 400,  reconstraction error: 0.06138616427779198, decrease = 1.1175870895385742e-07, unnormalized = 238.46034240722656
iteration 401,  reconstraction error: 0.06138605624437332, decrease = 1.0803341865539551e-07, unnormalized

iteration 459,  reconstraction error: 0.06137663871049881, decrease = 2.3096799850463867e-07, unnormalized = 238.42333984375
iteration 460,  reconstraction error: 0.061376407742500305, decrease = 2.3096799850463867e-07, unnormalized = 238.42245483398438
iteration 461,  reconstraction error: 0.0613761730492115, decrease = 2.3469328880310059e-07, unnormalized = 238.42153930664062
iteration 462,  reconstraction error: 0.0613759346306324, decrease = 2.384185791015625e-07, unnormalized = 238.4206085205078
iteration 463,  reconstraction error: 0.0613756962120533, decrease = 2.384185791015625e-07, unnormalized = 238.419677734375
iteration 464,  reconstraction error: 0.0613754466176033, decrease = 2.4959444999694824e-07, unnormalized = 238.41871643066406
iteration 465,  reconstraction error: 0.061375200748443604, decrease = 2.4586915969848633e-07, unnormalized = 238.41775512695312
iteration 466,  reconstraction error: 0.06137494742870331, decrease = 2.5331974029541016e-07, unnormalized = 238.4

iteration 524,  reconstraction error: 0.061352748423814774, decrease = 5.550682544708252e-07, unnormalized = 238.33053588867188
iteration 525,  reconstraction error: 0.06135218217968941, decrease = 5.662441253662109e-07, unnormalized = 238.32833862304688
iteration 526,  reconstraction error: 0.06135160103440285, decrease = 5.811452865600586e-07, unnormalized = 238.32608032226562
iteration 527,  reconstraction error: 0.061351023614406586, decrease = 5.774199962615967e-07, unnormalized = 238.32383728027344
iteration 528,  reconstraction error: 0.06135042384266853, decrease = 5.997717380523682e-07, unnormalized = 238.32151794433594
iteration 529,  reconstraction error: 0.061349812895059586, decrease = 6.109476089477539e-07, unnormalized = 238.3191375732422
iteration 530,  reconstraction error: 0.06134919449687004, decrease = 6.183981895446777e-07, unnormalized = 238.31674194335938
iteration 531,  reconstraction error: 0.0613485611975193, decrease = 6.332993507385254e-07, unnormalized = 23

iteration 589,  reconstraction error: 0.04694616422057152, decrease = 5.289912223815918e-07, unnormalized = 182.36680603027344
iteration 590,  reconstraction error: 0.04694565013051033, decrease = 5.140900611877441e-07, unnormalized = 182.36480712890625
iteration 591,  reconstraction error: 0.04694513976573944, decrease = 5.103647708892822e-07, unnormalized = 182.36282348632812
iteration 592,  reconstraction error: 0.04694463685154915, decrease = 5.029141902923584e-07, unnormalized = 182.36087036132812
iteration 593,  reconstraction error: 0.04694413021206856, decrease = 5.066394805908203e-07, unnormalized = 182.35890197753906
iteration 594,  reconstraction error: 0.04694361239671707, decrease = 5.178153514862061e-07, unnormalized = 182.3568878173828
iteration 595,  reconstraction error: 0.04694307968020439, decrease = 5.327165126800537e-07, unnormalized = 182.35482788085938
iteration 596,  reconstraction error: 0.04694253206253052, decrease = 5.476176738739014e-07, unnormalized = 182.

iteration 653,  reconstraction error: 9.736553693073802e-06, decrease = 9.76232513494324e-07, unnormalized = 0.03782256320118904
iteration 654,  reconstraction error: 8.873859769664705e-06, decrease = 8.626939234090969e-07, unnormalized = 0.03447134420275688
iteration 655,  reconstraction error: 8.106047971523367e-06, decrease = 7.678117981413379e-07, unnormalized = 0.03148870915174484
iteration 656,  reconstraction error: 7.421967438858701e-06, decrease = 6.84080532664666e-07, unnormalized = 0.02883133292198181
iteration 657,  reconstraction error: 6.8083786572969984e-06, decrease = 6.135887815617025e-07, unnormalized = 0.02644778974354267
iteration 658,  reconstraction error: 6.25775646767579e-06, decrease = 5.506221896212082e-07, unnormalized = 0.02430884540081024
iteration 659,  reconstraction error: 5.762326964031672e-06, decrease = 4.954295036441181e-07, unnormalized = 0.02238430269062519
iteration 660,  reconstraction error: 5.312694156600628e-06, decrease = 4.4963280743104406e-

iteration 3,  reconstraction error: 0.008281714282929897, decrease = 9.313225746154785e-10, unnormalized = 32.17110061645508
iteration 4,  reconstraction error: 0.008281714282929897, decrease = 0.0, unnormalized = 32.17110061645508
PARAFAC has stopped after iteration 4
reconstruction error=0.012153036892414093
iteration 1,  reconstraction error: 0.012153036892414093, decrease = 0.0, unnormalized = 47.209617614746094
PARAFAC has stopped after iteration 1
tensor(51.8905, device='cuda:0')
tensor(51.8905, device='cuda:0')


In [None]:
qerrors

In [35]:
qqfactors = [quantize_qint(out.factors[i].cpu(), dtype, qscheme, dim = dim).to(device)\
                for i in range(len(out.factors))]

qqkrt = KruskalTensor((out.weights, qqfactors))

# corresponding tensor in full format
qqt = kruskal_to_tensor(qqkrt)

print('||float_qqfactors||: {}'.format(tl.norm(t - qqt)))

||float_qqfactors||: 244.42758178710938


In [36]:
scales, zero_points

((tensor(0.0027), tensor(0.0027), tensor(0.0067)),
 (tensor(-12, dtype=torch.int32),
  tensor(-17, dtype=torch.int32),
  tensor(-4, dtype=torch.int32)))

In [26]:
qout.factors[0]

tensor([[ 0.0160, -0.0748,  0.1870,  ..., -0.0134,  0.0240, -0.0214],
        [-0.0962,  0.1309, -0.1309,  ...,  0.0080,  0.1549,  0.0935],
        [-0.0454,  0.1015,  0.2004,  ..., -0.2591, -0.0080, -0.0187],
        ...,
        [ 0.0614,  0.0347,  0.0107,  ..., -0.1496,  0.0695,  0.0427],
        [ 0.0561, -0.0588, -0.0374,  ...,  0.0027, -0.1229,  0.0427],
        [-0.0695,  0.1229,  0.1309,  ...,  0.0588, -0.0721,  0.0347]],
       device='cuda:0')

In [27]:
qout.factors[0]/scales[0]

tensor([[  6.0000, -28.0000,  70.0000,  ...,  -5.0000,   9.0000,  -8.0000],
        [-36.0000,  49.0000, -49.0000,  ...,   3.0000,  58.0000,  35.0000],
        [-17.0000,  38.0000,  75.0000,  ..., -97.0000,  -3.0000,  -7.0000],
        ...,
        [ 23.0000,  13.0000,   4.0000,  ..., -56.0000,  26.0000,  16.0000],
        [ 21.0000, -22.0000, -14.0000,  ...,   1.0000, -46.0000,  16.0000],
        [-26.0000,  46.0000,  49.0000,  ...,  22.0000, -27.0000,  13.0000]],
       device='cuda:0')

In [28]:
qout.factors[1]/scales[1]

tensor([[  0.0000, -14.0000,   4.0000,  ...,  20.0000, -20.0000,  33.0000],
        [-18.0000,   1.0000, -13.0000,  ...,  36.0000,   7.0000, -16.0000],
        [ 40.0000, -64.0000,  21.0000,  ..., -19.0000,  -7.0000, -14.0000],
        ...,
        [ 30.0000,  21.0000,   1.0000,  ..., -12.0000,  -4.0000,  23.0000],
        [ 58.0000,  38.0000,  17.0000,  ...,  -3.0000, -35.0000,  53.0000],
        [-39.0000, -19.0000,  54.0000,  ..., -14.0000, -11.0000, -33.0000]],
       device='cuda:0')

##### Try to quantize 2 of 3 factors

In [30]:
qt, _ = quantized_parafac(
                        t.to(device), rank, n_iter_max=n_iter_max,\
                        init=init, tol= tol, svd = 'numpy_svd',\
                        normalize_factors = normalize_factors,\
                        qmodes = [0, 1],
                        qscheme = qscheme, dtype = dtype, dim = dim,
                        return_scale_zeropoint=False, random_state = random_state,\
                        stop_criterion = qstop_criterion)

In [31]:
tl.norm(t - kruskal_to_tensor(qt))

tensor(47.6132, device='cuda:0')