In [1]:
import torch
import torch.nn.functional as f
import numpy as np
import tltorch as tlt
import tensorly as ten
from lsr_tensor import *
from lsr_bcd_regression import *
import cProfile
import pstats

In [2]:
def synthesize_data(true_tensor, sample_size, shape, x_stdev, y_stdev):
    """Draw a synthetic regression dataset from a ground-truth model.

    Inputs are i.i.d. standard normal samples scaled by ``x_stdev``; targets are
    the model's response to those inputs plus i.i.d. Gaussian noise scaled by
    ``y_stdev``.

    Args:
        true_tensor: callable mapping a batch of inputs of shape
            ``(sample_size, *shape)`` to a tensor of responses (e.g. an
            ``LSR_tensor`` instance).
        sample_size: number of samples to draw.
        shape: shape of a single input sample.
        x_stdev: scale (standard deviation) of the Gaussian inputs.
        y_stdev: scale (standard deviation) of the additive target noise.

    Returns:
        ``(x, y)`` where ``x`` has shape ``(sample_size, *shape)`` and ``y`` has
        the same shape as ``true_tensor(x)``.
    """
    x = torch.randn((sample_size, *shape)) * x_stdev
    # Evaluate the model once and reuse the result both as the clean signal and
    # as the shape/dtype template for the noise (the forward pass is expensive).
    clean_y = true_tensor(x)
    y = clean_y + torch.randn_like(clean_y) * y_stdev

    return x, y

In [3]:
def estimation_error(true, estimate):
    """Relative estimation error ``||true - estimate|| / ||true||``.

    Both norms are ``torch.norm`` defaults, i.e. the 2-norm taken over all
    entries (Frobenius norm for matrices). The ratio is undefined (nan/inf)
    when ``true`` is identically zero.
    """
    residual = true - estimate
    reference_norm = torch.norm(true)
    return torch.norm(residual) / reference_norm

In [20]:
# Ground-truth setup: a random LSR tensor plus a synthetic regression dataset drawn from it.
SEED = 0
torch.manual_seed(SEED)  # make both the ground truth and the sampled data reproducible

shape, ranks, separation_rank = (128, 128), (8, 8), 2
x_stdev = 1
y_stdev = 0.05
sample_size = 5000

with torch.no_grad():
    true_lsr = LSR_tensor(shape, ranks, separation_rank)
    # Normalize in place so every vector along dim 0 of the core has unit L2 norm
    # (NOTE(review): the core's exact layout is defined in lsr_tensor — confirm dim 0 is intended).
    f.normalize(true_lsr.core_tensor, p=2, dim=0, out=true_lsr.core_tensor)
    # Rescale the core by 5 / prod(ranks) ** (1/4).
    true_lsr.core_tensor *= (5 / torch.sqrt(torch.sqrt(torch.prod(torch.tensor(ranks)))))
    xs, ys = synthesize_data(true_lsr, sample_size, shape, x_stdev, y_stdev)

dataset = torch.utils.data.TensorDataset(xs, ys)

In [26]:
# Fit an LSR regression model to `dataset` with lsr_bcd_regression, using mini-batches of 256.
lsr_ten = lsr_bcd_regression(
    f.mse_loss, dataset, shape, ranks, separation_rank,
    lr=0.01, momentum=0.9,
    step_epochs=1, max_iter=500, batch_size=256, threshold=1e-4,
    init_zero=False, ortho=True, debug=False, verbose=True,
)

Iteration 0 | Delta: 14.43218994140625, Training Loss: 110.47994995117188
Iteration 10 | Delta: 6.6682329177856445, Training Loss: 51.56616973876953
Iteration 20 | Delta: 6.837027072906494, Training Loss: 35.990665435791016


KeyboardInterrupt: 

In [None]:
# Profile a fit with batch_size=sample_size (one batch spanning the whole dataset).
# The call uses the same (loss, dataset, shape, ...) signature as the training cell;
# passing xs and ys separately would shift every later positional argument.
profiler = cProfile.Profile()
profiler.enable()
try:
    lsr_ten = lsr_bcd_regression(
        f.mse_loss, dataset, shape, ranks, separation_rank,
        lr=0.01, momentum=0.9,
        step_epochs=5, max_iter=200, batch_size=sample_size, threshold=1e-3,
        init_zero=False, ortho=True, debug=False,
    )
finally:
    # Always stop profiling, even if the fit raises or is interrupted.
    profiler.disable()

In [None]:
# Show only the most expensive functions by internal time ("tottime");
# the unrestricted table can be very long.
PROFILE_TOP_N = 30
stats = pstats.Stats(profiler).sort_stats('tottime')
stats.print_stats(PROFILE_TOP_N)

In [None]:
# Norms of the learned model: full expanded tensor, core tensor, then every factor matrix.
print(torch.norm(lsr_ten.expand_to_tensor()))
print(torch.norm(lsr_ten.core_tensor))
# factor_matrices is iterated as a nested collection (groups of matrices) — TODO confirm layout in lsr_tensor.
for factor in (mat for group in lsr_ten.factor_matrices for mat in group):
    print(torch.norm(factor))

In [None]:
# Norms of the ground-truth model: full expanded tensor, core tensor, then every factor matrix.
print(torch.norm(true_lsr.expand_to_tensor()))
print(torch.norm(true_lsr.core_tensor))
# factor_matrices is iterated as a nested collection (groups of matrices) — TODO confirm layout in lsr_tensor.
for factor in (mat for group in true_lsr.factor_matrices for mat in group):
    print(torch.norm(factor))

In [19]:
# Relative error of the learned model against the ground truth. This is pure
# evaluation, so skip autograd bookkeeping and expand the true tensor only once.
with torch.no_grad():
    true_full = true_lsr.expand_to_tensor()
    print(estimation_error(true_full, lsr_ten.expand_to_tensor()))
    # Sanity check: the ground truth compared against itself must give exactly 0.
    print(estimation_error(true_full, true_full))

tensor(0.0047, grad_fn=<DivBackward0>)
tensor(0., grad_fn=<DivBackward0>)
