In [1]:
import tensorly as tl
from tensorly.decomposition import parafac, quantized_parafac
from tensorly.kruskal_tensor import kruskal_to_tensor, KruskalTensor
from tensorly.base import unfold
from tensorly.quantization import quantize_qint

import torch
tl.set_backend('pytorch')

# Example 1. Tensor quantization
We use PyTorch build-in tools to perform quantization.

Quantization scheme can be either affine or symmetric.

Scale and zero_point values to perform quantization are computed either per channel or per tensor (i.e. we get either vectors or scalars).

Thus, there are 4 types of quantization scheme:

    ``torch.per_tensor_affine``
    ``torch.per_tensor_symmetric``
    ``torch.per_channel_affine``
    ``torch.per_channel_symmetric``

##### Generate a random tensor

In [2]:
t = torch.randn(256, 256, 9)
print('||float_tensor|| = {}'.format(tl.norm(t)))

dtype = torch.qint8

||float_tensor|| = 767.9274291992188


##### Per channel  quantization

In [3]:
for qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
    print("\nPer channel quantization, dtype: {}, qscheme: {}".format(dtype, qscheme))
    
    for dim in range(len(t.shape)):
        qt, scale, zero_point = quantize_qint(t,\
                                              dtype,\
                                              qscheme,\
                                              dim = dim,\
                                              return_scale_zeropoint=True)

        print('Per dim {}, ||float_tensor - quant_tensor|| = {}'.format(dim, tl.norm(t - qt)))


Per channel quantization, dtype: torch.qint8, qscheme: torch.per_channel_affine
Per dim 0, ||float_tensor - quant_tensor|| = 7.307478427886963
Per dim 1, ||float_tensor - quant_tensor|| = 7.122038841247559
Per dim 2, ||float_tensor - quant_tensor|| = 7.646827220916748

Per channel quantization, dtype: torch.qint8, qscheme: torch.per_channel_symmetric
Per dim 0, ||float_tensor - quant_tensor|| = 7.307478427886963
Per dim 1, ||float_tensor - quant_tensor|| = 7.122038841247559
Per dim 2, ||float_tensor - quant_tensor|| = 7.646827220916748


##### Per tensor quantization

In [4]:
for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
    print("\nPer tensor quantization, dtype: {}, qscheme: {}".format(dtype, qscheme))

    
    qt, scale, zero_point = quantize_qint(t,\
                                          dtype,\
                                          qscheme,\
                                          dim = dim,\
                                          return_scale_zeropoint=True)
    print('Per tensor, ||float_tensor - quant_tensor|| = {}'.format(tl.norm(t - qt)))


Per tensor quantization, dtype: torch.qint8, qscheme: torch.per_tensor_affine
Per tensor, ||float_tensor - quant_tensor|| = 9.011683464050293

Per tensor quantization, dtype: torch.qint8, qscheme: torch.per_tensor_symmetric
Per tensor, ||float_tensor - quant_tensor|| = 9.011683464050293


# Example 2. Quantization of a tensor in Kruskal format:
    a) via quantization of the corresponding full tensor
    b) via quantization of decomposition factors.

##### Generate tensor 

In [5]:
rank = 16
shape = (64, 64, 9)

factors = [torch.randn((i, rank)) for i in shape] 
weights = torch.ones(rank)

# tensor in Kruscal format
krt = KruskalTensor((weights, factors))

# corresponding tensor in full format
t = kruskal_to_tensor(krt)

tnorm = tl.norm(t)
print('||float_factors||: {}'.format(tnorm))

||float_factors||: 713.4921875


##### Choose quantization scheme

In [6]:
dtype = torch.qint8

## Per tensor quantization
qscheme, dim = torch.per_tensor_affine, None

## Uncomment for per channel quantization
# qscheme, dim = torch.per_channel_affine, 0

##### a) Quantize the full tensor

In [7]:
t_quant = quantize_qint(t, dtype, qscheme, dim = dim)
print('||float_factors - float_factors_quantized|| = {}'.format(tl.norm(t - t_quant)))

||float_factors - float_factors_quantized|| = 10.873172760009766


##### b) Quantize several factors

In [8]:
num_factors = len(factors)
for num_quant_factors in range(1, num_factors + 1):
    
    qfactors = [quantize_qint(factors[i], dtype, qscheme, dim = dim)\
                for i in range(num_quant_factors)\
               ] + [factors[i] for i in range(num_quant_factors, num_factors)]

    qkrt = KruskalTensor((weights, qfactors))
    qt = kruskal_to_tensor(qkrt)
    print('\n[{}/{}] factors are quantized'.format(num_quant_factors, num_factors))
    print('||quant_factors - float_factors|| = {}'.format(tl.norm(qt - t)))
#     print('||quant_factors - float_factors_quantized|| = {}'.format(tl.norm(qt - t_quant)))

#     qt_quant = quantize_qint(qt, dtype, qscheme, dim = dim)
#     print('\nquant_factors_quantized - t_quant_factors: {}'.format(tl.norm(qt_quant - qt)/tnorm))
    
#     print('||quant_factors_quantized - float_factors|| = {}'.format(tl.norm(qt_quant - t)/tnorm))
#     print('||quant_factors_quantized - float_factors_quantized|| = {}'.format(tl.norm(qt_quant - t_quant)/tnorm))



[1/3] factors are quantized
||quant_factors - float_factors|| = 12.959274291992188

[2/3] factors are quantized
||quant_factors - float_factors|| = 14.07798957824707

[3/3] factors are quantized
||quant_factors - float_factors|| = 25.923023223876953


# Example 3. Quantized ALS
Compare standard ALS algorithm for finding  CP decomposition with its quantized version, when at the end of each ALS step approximated factor is quantized.

In [45]:
### All in one cell
rank = 12
shape = (128, 128, 9)
stop_criterion = 'rec_error_decrease'
init = 'random'
niter_max = 1000

r_float = 0

## Quantization parameters
dtype = torch.qint8
qscheme, dim = torch.per_tensor_affine, None

## Generate tensor
factors = [torch.randn((i, rank)) for i in shape] 
weights = torch.ones(rank)

# tensor in Kruscal format
krt = KruskalTensor((weights, factors))

# corresponding tensor in full format
t_original = kruskal_to_tensor(krt)
t = t_original

tnorm = tl.norm(t)
print('||float_factors||: {}'.format(tnorm))


## Perform als, post als factors quantization,  quantized als
for _ in range(10):
    random_state = int(torch.randint(high = 11999, size = (1,)))
    print('\n=============== random state {} =================='.format(random_state))

    ## ALS with float factors
    t = t_original
    r = rank - r_float
    out = parafac(t, r,\
                  n_iter_max=niter_max, init=init, svd='numpy_svd',\
                  normalize_factors=False, orthogonalise=False,\
                  tol=None, random_state=random_state,\
                  verbose=0, return_errors=False,\
                  non_negative=False, mask=None,
                  stop_criterion = stop_criterion)

    t_approx_als = kruskal_to_tensor(out) 
    print('||float_ffactors||: {}'.format(tl.norm(t - t_approx_als)))

    ## Post ALS factors quantization
    qqfactors = [quantize_qint(out.factors[i], dtype, qscheme, dim = dim)\
                    for i in range(len(out.factors))]
    qqkrt = KruskalTensor((out.weights, qqfactors))
    # corresponding tensor in full format
    qqt = kruskal_to_tensor(qqkrt)
    print('||float_qqfactors||: {}'.format(tl.norm(t - qqt)))


    ## Quantized ALS
    t = t_original
    r = rank - r_float
    normalize_factors = False
    (fout, qout), (errors, qerrors), scales, zero_points = quantized_parafac(
                                        t, r, n_iter_max=1000,\
                                        init=init, tol= None, svd = 'numpy_svd',\
                                        normalize_factors = normalize_factors,\
                                        qmodes = [0, 1, 2],
                                        warmup_iters = 1,
                                        freeze_every = 1,
                                        qscheme = qscheme, dtype = dtype, dim = dim,
                                        return_scale_zeropoint=True, verbose = False,\
                                        random_state=random_state,\
                                        return_rec_errors=False, return_qrec_errors=False,
                                        stop_criterion = stop_criterion)

    t_approx_qals = kruskal_to_tensor(KruskalTensor(qout))
    print('||float_qfactors||: {}'.format(tl.norm(t - t_approx_qals)))




||float_factors||: 1578.495849609375

||float_ffactors||: 0.0004248009354341775
||float_qqfactors||: 38.49079513549805
||float_qfactors||: 26.33399772644043

||float_ffactors||: 0.0004599247477017343
||float_qqfactors||: 27.231096267700195
||float_qfactors||: 290.12408447265625

||float_ffactors||: 0.0003965498472098261
||float_qqfactors||: 48.92815017700195
||float_qfactors||: 27.728422164916992

||float_ffactors||: 0.00039850908797234297
||float_qqfactors||: 41.202423095703125
||float_qfactors||: 19.78460121154785

||float_ffactors||: 0.0004433225258253515
||float_qqfactors||: 35.17451858520508
||float_qfactors||: 289.84814453125

||float_ffactors||: 0.00046504897181876004
||float_qqfactors||: 28.925840377807617
||float_qfactors||: 27.65139389038086

||float_ffactors||: 288.95501708984375
||float_qqfactors||: 290.4665832519531
||float_qfactors||: 20.928878784179688

||float_ffactors||: 0.0004377279838081449
||float_qqfactors||: 52.634944915771484
||float_qfactors||: 293.1118469238281

##### Generate tensor

In [9]:
rank = 12
shape = (128, 128, 9)

factors = [torch.randn((i, rank)) for i in shape] 
weights = torch.ones(rank)

# tensor in Kruscal format
krt = KruskalTensor((weights, factors))

# corresponding tensor in full format
t_original = kruskal_to_tensor(krt)

tnorm = tl.norm(t)
print('||float_factors||: {}'.format(tnorm))

||float_factors||: 713.4921875


##### Find rank-1 float approximation

In [10]:
r_float = 0
stop_criterion = 'rec_error_decrease'
qstop_criterion = 'rec_error_decrease'
init = 'svd'
random_state = 1110

In [11]:
t = t_original
r = r_float
if r > 0:
    out = parafac(t, r,\
                  n_iter_max=1000, init=init, svd='numpy_svd',\
                  normalize_factors=False, orthogonalise=False,\
                  tol=None, random_state=random_state,\
                  verbose=0, return_errors=False,\
                  non_negative=False, mask=None,
                  stop_criterion = stop_criterion)

    t_approx = kruskal_to_tensor(out)
    print(tl.norm(t - t_approx)) 
else:
    t_approx = 0
    
t_residual = t - t_approx

##### Find an approximation using ALS

In [12]:
t = t_residual
r = rank - r_float
out = parafac(t, r,\
              n_iter_max=1000, init=init, svd='numpy_svd',\
              normalize_factors=False, orthogonalise=False,\
              tol=None, random_state=random_state,\
              verbose=0, return_errors=False,\
              non_negative=False, mask=None,
              stop_criterion = stop_criterion)

t_residual_approx_als = kruskal_to_tensor(out) 
print(tl.norm(t - t_residual_approx_als))  
print(tl.norm(t - t_approx - t_residual_approx_als))

tensor(0.0004)
tensor(0.0004)


##### Find an approximation using quantized ALS

In [13]:
dtype = torch.qint8

## Per tensor quantization
qscheme, dim = torch.per_tensor_affine, None

## Uncomment for per channel quantization
# qscheme, dim = torch.per_channel_affine, 0

In [14]:
t = t_residual
r = rank - r_float
normalize_factors = False
(fout, qout), (errors, qerrors), scales, zero_points = quantized_parafac(
                                    t, r, n_iter_max=1000,\
                                    init=init, tol= None, svd = 'numpy_svd',\
                                    normalize_factors = normalize_factors,\
                                    qmodes = [0, 1, 2],
                                    warmup_iters = 1,
                                    freeze_every = 1,
                                    qscheme = qscheme, dtype = dtype, dim = dim,
                                    return_scale_zeropoint=True,\
                                    return_rec_errors = True,\
                                    return_qrec_errors=True,\
                                    verbose = True,\
                                    random_state=random_state,
                                    stop_criterion = stop_criterion)

t_residual_approx_qals = kruskal_to_tensor(qout)
print(tl.norm(t_residual - t_residual_approx_qals))
print(tl.norm(t - t_approx - t_residual_approx_qals))





In parafac norm_tensor = 1322.6064453125
float factors reconstruction error=0.7177299857139587
quantized factors reconstruction error=0.9999871709581504
iteration 100, float factors reconstraction error: 0.03975190967321396, decrease = -0.003070201724767685, unnormalized = 52.576133728027344
iteration 100, quantized factors reconstraction error: 0.02207577133229692, decrease = -2.884227034566367e-07, unnormalized = 29.19755744934082
iteration 200, float factors reconstraction error: 0.03975910320878029, decrease = -0.003073502331972122, unnormalized = 52.58564376831055
iteration 200, quantized factors reconstraction error: 0.022075778542864504, decrease = -2.985174980749128e-07, unnormalized = 29.197566986083984
iteration 300, float factors reconstraction error: 0.03971414640545845, decrease = -0.0029993392527103424, unnormalized = 52.52618408203125
iteration 300, quantized factors reconstraction error: 0.02207578719554561, decrease = -3.187070873184039e-07, unnormalized = 29.197578430

In [17]:
errors

(tensor(0.7870),
 tensor(0.7881),
 tensor(0.6288),
 tensor(0.5085),
 tensor(0.4267),
 tensor(0.3649),
 tensor(0.3428),
 tensor(0.3297),
 tensor(0.2887),
 tensor(0.2595),
 tensor(0.2317),
 tensor(0.2136),
 tensor(0.2111),
 tensor(0.1964),
 tensor(0.1941),
 tensor(0.1880),
 tensor(0.1836),
 tensor(0.1848),
 tensor(0.1868),
 tensor(0.1859),
 tensor(0.1835),
 tensor(0.1832),
 tensor(0.1858),
 tensor(0.1853),
 tensor(0.1888),
 tensor(0.1869),
 tensor(0.1862),
 tensor(0.1862),
 tensor(0.1861),
 tensor(0.1867),
 tensor(0.1837),
 tensor(0.1862),
 tensor(0.1861),
 tensor(0.1853),
 tensor(0.1874),
 tensor(0.1857),
 tensor(0.1863),
 tensor(0.1858),
 tensor(0.1880),
 tensor(0.1858),
 tensor(0.1870),
 tensor(0.1872),
 tensor(0.1860),
 tensor(0.1871),
 tensor(0.1872),
 tensor(0.1868),
 tensor(0.1869),
 tensor(0.1869),
 tensor(0.1869),
 tensor(0.1869),
 tensor(0.1869),
 tensor(0.1869),
 tensor(0.1869),
 tensor(0.1868),
 tensor(0.1869),
 tensor(0.1868),
 tensor(0.1869),
 tensor(0.1868),
 tensor(0.1869

In [18]:
qerrors

(1.0129274650840132,
 0.7892980153660368,
 0.629156786677094,
 0.5096062751091962,
 0.42674935520532004,
 0.36507604699952506,
 0.34239840430556934,
 0.32995667627454206,
 0.28811078922120226,
 0.2600321850128048,
 0.2321932555660405,
 0.2162646889789725,
 0.20802160221866167,
 0.20069085405829057,
 0.1940647277272738,
 0.18861278802978557,
 0.18535246266936414,
 0.1845577644654968,
 0.18442594538084364,
 0.18434355845293543,
 0.18430631142454038,
 0.1842818801404866,
 0.18426371433144412,
 0.18426446700214352,
 0.1842547229679539,
 0.1842535227633251,
 0.1842383269521776,
 0.18423346510630842,
 0.18421869648663897,
 0.1842343601741672,
 0.18422268360710067,
 0.18421257340878702,
 0.18422386346927813,
 0.184219184705471,
 0.1842086473156793,
 0.18420777259027188,
 0.18421236998427365,
 0.1842139363530265,
 0.18421251238143302,
 0.18421499416049592,
 0.18420992889011342,
 0.18420998991746743,
 0.18421220724466297,
 0.18420651135828908,
 0.18421224792956564,
 0.18420921690431669,
 0.1842

In [15]:
qqfactors = [quantize_qint(out.factors[i], dtype, qscheme, dim = dim)\
                for i in range(len(out.factors))]

qqkrt = KruskalTensor((out.weights, qqfactors))

# corresponding tensor in full format
qqt = kruskal_to_tensor(qqkrt)

print('||float_qqfactors||: {}'.format(tl.norm(t - qqt)))

||float_qqfactors||: 44.78242111206055


In [16]:
scales, zero_points

((tensor(0.1900), tensor(0.0138), tensor(0.0146)),
 (tensor(0, dtype=torch.int32),
  tensor(0, dtype=torch.int32),
  tensor(0, dtype=torch.int32)))

In [17]:
qout.factors[0]

tensor([[  5.3201,  11.7802,  -2.6600,  ...,   2.2800,  -0.3800,   0.0000],
        [ -2.2800,   2.4700,   1.3300,  ...,  -2.2800,  -0.5700,  -1.9000],
        [ -7.0301,  -6.6501,  -8.1701,  ...,  -2.6600,   1.1400,   0.3800],
        ...,
        [  5.5101,  -1.5200,  -0.7600,  ...,   9.8802,  -4.7501,  -0.3800],
        [  7.2201, -13.3002,   5.5101,  ...,  -0.9500,   2.4700,   0.7600],
        [ -1.1400,  -2.4700,  10.2602,  ...,  -4.1801,   2.2800,   0.3800]])

In [18]:
qout.factors[0]/scales[0]

tensor([[ 28.0000,  62.0000, -14.0000,  ...,  12.0000,  -2.0000,   0.0000],
        [-12.0000,  13.0000,   7.0000,  ..., -12.0000,  -3.0000, -10.0000],
        [-37.0000, -35.0000, -43.0000,  ..., -14.0000,   6.0000,   2.0000],
        ...,
        [ 29.0000,  -8.0000,  -4.0000,  ...,  52.0000, -25.0000,  -2.0000],
        [ 38.0000, -70.0000,  29.0000,  ...,  -5.0000,  13.0000,   4.0000],
        [ -6.0000, -13.0000,  54.0000,  ..., -22.0000,  12.0000,   2.0000]])

In [19]:
qout.factors[1]/scales[1]

tensor([[ -1.0000, -17.0000, -15.0000,  ...,  11.0000, -27.0000,  19.0000],
        [ -9.0000,   3.0000,  20.0000,  ..., -38.0000,   1.0000, -24.0000],
        [-21.0000,  39.0000,   0.0000,  ..., -11.0000, -24.0000, -38.0000],
        ...,
        [-15.0000,  -5.0000, -36.0000,  ...,  29.0000,  63.0000,  -2.0000],
        [-37.0000, -38.0000,  10.0000,  ..., -29.0000,  -6.0000,  41.0000],
        [-10.0000,  14.0000, -15.0000,  ...,   5.0000,   8.0000,   5.0000]])

##### Try to quantize 2 of 3 factors

In [None]:
normalize_factors = False
(ft, qt), _ = quantized_parafac(
                                    t, rank, n_iter_max=1001,\
                                    init=init, tol= False, svd = 'numpy_svd',\
                                    normalize_factors = normalize_factors,\
                                    qmodes = [0, 2],
                                    warmup_iters = 1,
                                    freeze_every = 1,
                                    qscheme = qscheme, dtype = dtype, dim = dim,
                                    return_scale_zeropoint=False, random_state = random_state)

In [44]:
tl.norm(t - kruskal_to_tensor(qt))

tensor(27.5654)