# Tensor decomposition using Autograd

In [8]:
import autograd.numpy as np
import tensorly
from tensorly.decomposition import parafac

In [9]:
t = np.arange(24).reshape(2, 3, 4).astype('float32')

In [10]:
t

array([[[  0.,   1.,   2.,   3.],
        [  4.,   5.,   6.,   7.],
        [  8.,   9.,  10.,  11.]],

       [[ 12.,  13.,  14.,  15.],
        [ 16.,  17.,  18.,  19.],
        [ 20.,  21.,  22.,  23.]]], dtype=float32)

In [11]:
A, B, C = parafac(t, rank=2)

In [12]:
A

array([[ 35.86886727,   4.33025058],
       [ 61.91392143,   1.1737134 ]])

In [13]:
B

array([[ 0.48010198, -1.33704111],
       [ 0.58706644, -0.98212501],
       [ 0.6940309 , -0.62720892]])

In [14]:
C

array([[ 0.48671098,  1.46524661],
       [ 0.51060759,  1.35143039],
       [ 0.53450419,  1.23761417],
       [ 0.5584008 ,  1.12379794]])

In [15]:
def khatri(A, B, C):
    return np.einsum('il,jl,kl->ijk',A,B,C)

In [16]:
khatri(A, B, C)

array([[[ -0.10186127,   0.96861972,   2.03910072,   3.10958171],
        [  4.01740161,   5.00464518,   5.99188875,   6.97913232],
        [  8.13666448,   9.04067063,   9.94467678,  10.84868293]],

       [[ 12.16806623,  13.05700484,  13.94594346,  14.83488207],
        [ 16.00173049,  17.0015141 ,  18.0012977 ,  19.0010813 ],
        [ 19.83539476,  20.94602335,  22.05665194,  23.16728053]]])

In [17]:
from autograd import multigrad

In [18]:
def cost(A, B, C):
    pred = khatri(A, B, C)
    gt = t
    mask = ~np.isnan(t)
    error = (pred-gt)[mask].flatten()
    return np.sqrt((error**2).mean())

In [19]:
mg = multigrad(cost, argnums=[0, 1, 2])

In [25]:
rank = 2
lr = 0.01

m, n, o = t.shape
a = np.random.randn(m, rank)
b = np.random.randn(n, rank)
c = np.random.randn(o, rank)

for i in range(6000):
    del_a, del_b, del_c = mg(a, b, c)
    a-=lr*del_a
    b-=lr*del_b
    c-=lr*del_c
    if i%100==0:
        print(cost(a, b, c))
    

13.3896087303
13.1403995288
12.6784145997
11.3836338072
6.17119881855
1.59367647945
1.57712522522
1.56148801277
1.54164416832
1.51193851273
1.46214715129
1.37100312768
1.19169887199
0.845201820036
0.500801838486
0.432763033457
0.405959724161
0.381344857275
0.358236500737
0.337202582543
0.318398281673
0.301684474737
0.286785494852
0.273390956387
0.261207363766
0.249978499386
0.239489810964
0.229565918191
0.220065957556
0.210878955224
0.201920118594
0.193128309099
0.18446465096
0.175912044612
0.167614292608
0.193258651365
0.194960439551
0.196512645059
0.19793277622
0.199234721931
0.200430564816
0.201531169287
0.20254637573
0.203485060521
0.204355160172
0.205163699619
0.205916838771
0.206619938584
0.207277641767
0.207893961044
0.208472368124
0.209015877975
0.209527124858
0.210008428301
0.210461848605
0.210889232381
0.211292249209
0.211672420708
0.212031143384
0.212369706492


In [26]:
np.round(khatri(a, b, c))

array([[[ -0.,   1.,   2.,   3.],
        [  4.,   5.,   6.,   7.],
        [  8.,   9.,  10.,  11.]],

       [[ 12.,  13.,  14.,  15.],
        [ 16.,  17.,  18.,  19.],
        [ 20.,  21.,  22.,  23.]]])

### Decomposing with missing entries

In [28]:
t[0, 0, 0] = np.NAN
t[1, 1, 2] = np.NAN
t

array([[[ nan,   1.,   2.,   3.],
        [  4.,   5.,   6.,   7.],
        [  8.,   9.,  10.,  11.]],

       [[ 12.,  13.,  14.,  15.],
        [ 16.,  17.,  nan,  19.],
        [ 20.,  21.,  22.,  23.]]], dtype=float32)

In [29]:
rank = 2
lr = 0.01

m, n, o = t.shape
a = np.random.randn(m, rank)
b = np.random.randn(n, rank)
c = np.random.randn(o, rank)

for i in range(6000):
    del_a, del_b, del_c = mg(a, b, c)
    a-=lr*del_a
    b-=lr*del_b
    c-=lr*del_c
    if i%100==0:
        print(cost(a, b, c))

13.3274290767
12.9615767378
12.231506557
10.1565095586
3.53174838987
0.767466918209
0.453413573575
0.439333159898
0.429231706735
0.417799776519
0.405029785513
0.391291035855
0.377112481341
0.363057905783
0.349606472752
0.337081420301
0.325641873772
0.31531668423
0.306050716626
0.297744840329
0.290283565845
0.283551390652
0.277441118095
0.271857223272
0.266716469829
0.261947158412
0.257487799033
0.253285629231
0.249295184027
0.245477004759
0.241796511742
0.238223035429
0.234728987458
0.231289148475
0.2278800493
0.224479423418
0.221065710751
0.217617594509
0.214113554359
0.210531420269
0.206847911951
0.203038149334
0.199075120024
0.194929091091
0.190567657628
0.210202329142
0.211834653577
0.21338624314
0.214863837549
0.216272751307
0.217617803003
0.218903682835
0.220135110565
0.22131689393
0.22245392217
0.223551101138
0.224613224724
0.225644773567
0.226649634947
0.227630748398


In [31]:
np.round(khatri(a, b, c))

array([[[  0.,   1.,   2.,   3.],
        [  4.,   5.,   6.,   7.],
        [  8.,   9.,  10.,  11.]],

       [[ 12.,  13.,  14.,  15.],
        [ 16.,  17.,  18.,  19.],
        [ 20.,  21.,  22.,  23.]]])

In [32]:
a

array([[-2.44052782,  2.14861234],
       [-0.30563334,  3.45964636]])

In [33]:
b

array([[-2.34061748, -1.97795096],
       [-1.81765871, -2.4921937 ],
       [-1.30148193, -3.01197557]])

In [34]:
c

array([[-1.4562426 , -1.9755641 ],
       [-1.35248743, -2.0831725 ],
       [-1.25381021, -2.18718612],
       [-1.15631012, -2.29792478]])