'''
    
    Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/
    Written by Suhan Shetty <suhan.shetty@idiap.ch>,
   
    This file is part of TTGO.

    TTGO is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License version 3 as
    published by the Free Software Foundation.

    TTGO is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with TTGO. If not, see <http://www.gnu.org/licenses/>.
'''


### Comparision of performance of TT vs NN vs BGMM
In this notebook,  we compare the approximation accuracy and speed of training between TT and NN. NN is a great tool for data-driven function approximation. However, it is not that great when the function to be approximated is given. On the other hand, TT is equipped with powerful technique called TT-Cross that can approximate a given function in TT format more efficiently. It directly takes the function to be approximated as input and outputs the function in TT format. Moreover, TT representation, unlike NN, offers other benefits like fast ways to sample, optimize, do algebra etc.

We also compare it against Bayesian GMM. Note that unlike, NN and TT, BGMM requires exact samples from the reference pdf to be fit. 

In [1]:
import torch
from tt_utils import *
from fcn_approx_utils import GMM, NeuralNetwork, BGMM
import time 


In [2]:
device = "cpu"#torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
dim = 5
L = 1
nmix = 20
s = 0.2

# generate an arbitrary function (gmm with centers and covariances chosen randomly)
gmm = GMM(n=dim,nmix=nmix,L=L,mx_coef=None,mu=None,s=s, device=device) 
pdf = gmm.pdf

In [4]:
# For testing and training NN
ndata_train = int(1e5)
ndata_test = 10000

x_train = 2*L*(-0.5 + torch.rand((ndata_train,dim)).to(device))
y_train = pdf(x_train)

x_test = 2*L*(-0.5 + torch.rand((ndata_test,dim)).to(device))
y_test = pdf(x_test)

data_train = torch.cat((x_train.view(-1,dim),y_train.view(-1,1)),dim=-1)
data_test = torch.cat((x_test.view(-1,dim),y_test.view(-1,1)),dim=-1)

### Fit TT Model

In [5]:
# Represent the function in TT format (unsupervised learning and kind of non-parametric)
n_discretization = torch.tensor([200]*dim).to(device)
domain = [torch.linspace(-L,L,n_discretization[i]).to(device) for i in range(dim)] 

t1 = time.time()
tt_gmm = cross_approximate(fcn=pdf,  max_batch=10**6, domain=domain, 
                        rmax=200, nswp=20, eps=1e-3, verbose=True, 
                        kickrank=3, device=device)
t2 = time.time()
print("time taken: ", t2-t1)


cross device is cpu
Cross-approximation over a 5D domain containing 3.2e+11 grid points:
iter: 0  | tt-error: 1.000e+00, test-error:9.807e-01 | time:   0.0699 | largest rank:   1
iter: 1  | tt-error: 2.078e+00, test-error:8.744e-01 | time:   0.2167 | largest rank:   4
iter: 2  | tt-error: 1.188e+00, test-error:7.371e-01 | time:   0.3070 | largest rank:   7
iter: 3  | tt-error: 7.334e-01, test-error:5.693e-01 | time:   0.4282 | largest rank:  10
iter: 4  | tt-error: 5.292e-01, test-error:3.752e-01 | time:   0.5595 | largest rank:  13
iter: 5  | tt-error: 1.260e-01, test-error:3.545e-01 | time:   0.7153 | largest rank:  16
iter: 6  | tt-error: 3.123e-01, test-error:7.426e-03 | time:   0.9527 | largest rank:  19
iter: 7  | tt-error: 5.737e-03, test-error:1.889e-15 | time:   1.2702 | largest rank:  22
iter: 8  | tt-error: 2.039e-08, test-error:1.663e-15 | time:   1.6888 | largest rank:  25 <- converged: eps < 0.001
Did 2543400 function evaluations, which took 1.37s (1.856e+06 evals/s)

tim

In [6]:
# Test the accuracy of TT over the test set 
y_tt =  get_value(tt_model=tt_gmm, x=x_test.to(device),  domain=domain, 
                    n_discretization=n_discretization , max_batch=10**5, device=device)

mse_tt = ((y_tt.view(-1)-y_test.view(-1))**2).mean()
print("mse_tt: ", mse_tt)

mse_tt:  tensor(5.2491e-09)


### Fit NN Model

In [7]:
# Fit NN
lr= 1e-3
batch_size = 128
epochs = 1
nn = NeuralNetwork(dim, width=64, lr=1e-3, device=device)
nn.load_data(data_train, data_test)
t1 = time.time()
# nn.train(num_epochs=epochs, batch_size=batch_size, verbose=True)
t2 = time.time()
print("time taken: ", t2-t1)
print("Done!")

time taken:  3.4809112548828125e-05
Done!


In [8]:
# Test the accuracy of NN over the test set
y_nn = nn.model(x_test)
mse_nn = ((y_nn.view(-1)-y_test.view(-1))**2).mean().detach()
print("mse_nn: ", mse_nn)

mse_nn:  tensor(0.0167)


### Fit BGMM Model

In [9]:
# Sample data and Train BGMM
X_sample = gmm.generate_sample(x_train.shape[0]) # sample from reference distribution
bgmm = BGMM(nmix=nmix)
X_numpy = X_sample.detach().cpu().numpy()
bgmm.load_data(X_numpy)
bgmm.fit()

# Test bgmm
y_bgmm = bgmm.pdf(x_test.detach().cpu().numpy())
y_test_numpy = y_test.detach().cpu().numpy()
mse_bgmm = ((y_bgmm.reshape(-1)-y_test_numpy.reshape(-1))**2).mean()
print("mse_bgmm: ", mse_bgmm)

mse_bgmm:  tensor(0.0002)




In [10]:
print(f" mse_tt:{mse_tt},\n mse_bgmm:{mse_bgmm},\n mse_nn:{mse_nn}")

 mse_tt:5.249069685752172e-09,
 mse_bgmm:0.00016431352560326084,
 mse_nn:0.016652875400446403
