In [1]:
import torch
import torch.nn.functional as f
import numpy as np
import tltorch as tlt
import tensorly as ten
from lsr_tensor import *
from lsr_bcd_regression import *
import cProfile
import pstats

In [2]:
def synthesize_data(true_tensor, sample_size, shape, x_stdev, y_stdev):
    x = torch.randn((sample_size, *shape)) * x_stdev
    y = true_tensor(x) + torch.randn_like(true_tensor(x)) * y_stdev
    
    return x, y

In [3]:
def estimation_error(true, estimate):
    return torch.norm(true - estimate) / torch.norm(true)

In [4]:
# Tests
shape, ranks, separation_rank = (64, 64), (4, 4), 2
x_stdev = 1
y_stdev = 0.05
sample_size = 5000

with torch.no_grad():
    true_lsr = LSR_tensor(shape, ranks, separation_rank)
    f.normalize(true_lsr.core_tensor, p=2, dim=0, out=true_lsr.core_tensor)
    true_lsr.core_tensor *= (5 / torch.sqrt(torch.sqrt(torch.prod(torch.tensor(ranks)))))
    xs, ys = synthesize_data(true_lsr, sample_size, shape, x_stdev, y_stdev)

In [None]:
lsr_ten = lsr_bcd_regression(f.mse_loss, xs, ys, shape, ranks, separation_rank, lr=0.01, momentum=0.0,\
                             step_epochs=1, max_iter=100, batch_size=sample_size)

In [42]:
profiler = cProfile.Profile()
profiler.enable()
lsr_ten = lsr_bcd_regression(f.mse_loss, xs, ys, shape, ranks, separation_rank, lr=0.01, momentum=0.0,\
                             step_epochs=5, max_iter=250, batch_size=sample_size, init_zero=False, ortho=False)
profiler.disable()

iteration 0 diff:  tensor(6.7220, grad_fn=<CopyBackwards>)
iteration 1 diff:  tensor(5.1238, grad_fn=<CopyBackwards>)
iteration 2 diff:  tensor(5.4209, grad_fn=<CopyBackwards>)
iteration 3 diff:  tensor(3.2726, grad_fn=<CopyBackwards>)
iteration 4 diff:  tensor(3.2336, grad_fn=<CopyBackwards>)
iteration 5 diff:  tensor(3.8480, grad_fn=<CopyBackwards>)
iteration 6 diff:  tensor(4.0414, grad_fn=<CopyBackwards>)
iteration 7 diff:  tensor(3.8946, grad_fn=<CopyBackwards>)
iteration 8 diff:  tensor(5.7041, grad_fn=<CopyBackwards>)
iteration 9 diff:  tensor(4.0586, grad_fn=<CopyBackwards>)
iteration 10 diff:  tensor(2.5818, grad_fn=<CopyBackwards>)
iteration 11 diff:  tensor(1.8852, grad_fn=<CopyBackwards>)
iteration 12 diff:  tensor(1.2742, grad_fn=<CopyBackwards>)
iteration 13 diff:  tensor(2.7052, grad_fn=<CopyBackwards>)
iteration 14 diff:  tensor(0.7528, grad_fn=<CopyBackwards>)
iteration 15 diff:  tensor(3.1932, grad_fn=<CopyBackwards>)
iteration 16 diff:  tensor(2.4264, grad_fn=<CopyBa

iteration 137 diff:  tensor(0.2484, grad_fn=<CopyBackwards>)
iteration 138 diff:  tensor(7.4320, grad_fn=<CopyBackwards>)
iteration 139 diff:  tensor(5.9183, grad_fn=<CopyBackwards>)
iteration 140 diff:  tensor(5.1031, grad_fn=<CopyBackwards>)
iteration 141 diff:  tensor(2.5001, grad_fn=<CopyBackwards>)
iteration 142 diff:  tensor(3.9860, grad_fn=<CopyBackwards>)
iteration 143 diff:  tensor(4.0347, grad_fn=<CopyBackwards>)
iteration 144 diff:  tensor(4.1841, grad_fn=<CopyBackwards>)
iteration 145 diff:  tensor(4.2837, grad_fn=<CopyBackwards>)
iteration 146 diff:  tensor(2.3982, grad_fn=<CopyBackwards>)
iteration 147 diff:  tensor(1.8501, grad_fn=<CopyBackwards>)
iteration 148 diff:  tensor(0.9081, grad_fn=<CopyBackwards>)
iteration 149 diff:  tensor(3.8992, grad_fn=<CopyBackwards>)
iteration 150 diff:  tensor(3.2299, grad_fn=<CopyBackwards>)
iteration 151 diff:  tensor(2.7910, grad_fn=<CopyBackwards>)
iteration 152 diff:  tensor(3.4973, grad_fn=<CopyBackwards>)
iteration 153 diff:  ten

In [17]:
stats = pstats.Stats(profiler).sort_stats('tottime')
stats.print_stats()

         755856 function calls (709640 primitive calls) in 10.416 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    4.000    0.004    7.424    0.007 /home/fishcat/Federated-LSRTR/lsr_tensor.py:57(bcd_factor_update_x)
     4650    2.579    0.001    2.579    0.001 {built-in method torch.matmul}
      250    1.224    0.005    1.389    0.006 /home/fishcat/Federated-LSRTR/lsr_tensor.py:70(bcd_core_update_x)
     2000    0.527    0.000    0.527    0.000 {built-in method torch.cat}
     1250    0.358    0.000    0.358    0.000 {method 'run_backward' of 'torch._C._EngineBase' objects}
    16050    0.168    0.000    0.168    0.000 {built-in method torch.reshape}
     1250    0.101    0.000    0.101    0.000 {built-in method torch._C._nn.mse_loss}
     2250    0.083    0.000    0.083    0.000 {method 'reduce' of 'numpy.ufunc' objects}
43500/12700    0.080    0.000    3.378    0.000 /home/fishcat/.local/lib/python3.8/site-p

<pstats.Stats at 0x7fbff4c037f0>

In [25]:
print(torch.norm(lsr_ten.expand_to_tensor()))
print(torch.norm(lsr_ten.core_tensor))
for s in lsr_ten.factor_matrices:
    for k in s:
        print(torch.norm(k))

tensor(6.1617, grad_fn=<CopyBackwards>)
tensor(4.1462, grad_fn=<CopyBackwards>)
tensor(2., grad_fn=<CopyBackwards>)
tensor(2.0000, grad_fn=<CopyBackwards>)
tensor(2., grad_fn=<CopyBackwards>)
tensor(2.0000, grad_fn=<CopyBackwards>)


In [26]:
print(torch.norm(true_lsr.expand_to_tensor()))
print(torch.norm(true_lsr.core_tensor))
for s in true_lsr.factor_matrices:
    for k in s:
        print(torch.norm(k))

tensor(7.0940, grad_fn=<CopyBackwards>)
tensor(5., grad_fn=<CopyBackwards>)
tensor(2.0000, grad_fn=<CopyBackwards>)
tensor(2., grad_fn=<CopyBackwards>)
tensor(2.0000, grad_fn=<CopyBackwards>)
tensor(2., grad_fn=<CopyBackwards>)


In [43]:
print(estimation_error(true_lsr.expand_to_tensor(), lsr_ten.expand_to_tensor()))
print(estimation_error(true_lsr.expand_to_tensor(), true_lsr.expand_to_tensor()))

tensor(0.5884, grad_fn=<DivBackward0>)
tensor(0., grad_fn=<DivBackward0>)
