'''
    
    Copyright (c) 2022 Idiap Research Institute, http://www.idiap.ch/
    Written by Suhan Shetty <suhan.shetty@idiap.ch>,
   
    This file is part of TTGO.

    TTGO is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License version 3 as
    published by the Free Software Foundation.

    TTGO is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with TTGO. If not, see <http://www.gnu.org/licenses/>.
'''


### Comparision of performance of TT vs NN
In this notebook,  we compare the approximation accuracy and speed of training between TT and NN. NN is a great tool for data-driven function approximation. However, it is not that great when the function to be approximated is given. On the other hand, TT is equipped with powerful technique called TT-Cross that can approximate a given function in TT format more efficiently. It directly takes the function to be approximated as input and outputs the function in TT format. Moreover, TT representation, unlike NN, offers other benefits like fast ways to sample, optimize, do algebra etc.

In [1]:
import torch
from tt_utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:

def gmm(n=2,nmix=3,L=1,mx_coef=None,mu=None,s=0.1, device='cpu'):
    """
        Mixture of spherical Gaussians (un-normalized)
        nmix: number of mixture coefficients
        n: dimension of the domain
        s: variance
        mu: the centers assumed to be in : [-L,L]^n
    """
    n_sqrt = torch.sqrt(torch.tensor([n]).to(device))
    if mx_coef is None: # if centers and mixture coef are not given, generate them randomly
        mx_coef = torch.rand(nmix).to(device)
        mx_coef = mx_coef/torch.sum(mx_coef)
        mu = (torch.rand(nmix,n).to(device)-0.5)*2*L

    def pdf(x):
        result = torch.tensor([0]).to(device)
        for k in range(nmix):
            l = torch.linalg.norm(mu[k]-x, dim=1)/n_sqrt
            result = result + mx_coef[k]*torch.exp(-(l/s)**2)
        return 1.+100*result

    return pdf


In [4]:
dim = 10
L = 1
nmix = 1
s = 0.2

# generate an arbitrary function (gmm with centers and covariances chosen randomly)
pdf = gmm(n=dim,nmix=nmix,L=L,mx_coef=None,mu=None,s=s, device=device) 

In [5]:
# Represent the function in TT format (unsupervised learning and kind of non-parametric)
n_discretization = torch.tensor([200]*dim).to(device)
domain = [torch.linspace(-L,L,n_discretization[i]).to(device) for i in range(dim)] 

import time 
t1 = time.time()
tt_gmm = cross_approximate(fcn=pdf,  max_batch=10**6, domain=domain, 
                        rmax=200, nswp=20, eps=1e-3, verbose=True, 
                        kickrank=3, device=device)
t2 = time.time()
print("time taken: ", t2-t1)


cross device is cuda
Cross-approximation over a 10D domain containing 1.024e+23 grid points:
iter: 0  | tt-error: 1.074e+00, test-error:9.370e-01 | time:   1.3723 | largest rank:   1
iter: 1  | tt-error: 2.653e+00, test-error:6.331e-15 | time:   1.5148 | largest rank:   4
iter: 2  | tt-error: 0.000e+00, test-error:1.367e-14 | time:   1.6755 | largest rank:   7 <- converged: eps < 0.001
Did 218400 function evaluations, which took 0.03448s (6.335e+06 evals/s)

time taken:  1.8815488815307617


In [6]:
# For testing and training NN
ndata_train = 200000
ndata_test = 1000

x_train = 2*L*(-0.5 + torch.rand((ndata_train,dim)).to(device))
y_train = pdf(x_train)

x_test = 2*L*(-0.5 + torch.rand((ndata_test,dim)).to(device))
y_test = pdf(x_test)

data_train = torch.cat((x_train.view(-1,dim),y_train.view(-1,1)),dim=-1)
data_test = torch.cat((x_test.view(-1,dim),y_test.view(-1,1)),dim=-1)

In [7]:
# Test the accuracy of TT over the test set 
y_tt =  get_value(tt_model=tt_gmm, x=x_test.to(device),  domain=domain, 
                    n_discretization=n_discretization , max_batch=10**5, device=device)

mse_tt = (((y_tt.view(-1)-y_test.view(-1))/(1e-9+y_test.view(-1).abs()))**2).mean()
print("mse_tt: ", mse_tt)

mse_tt:  tensor(5.6480e-11, device='cuda:0')


#### Represent the function as a NN 

In [8]:
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F


Matplotlib created a temporary config/cache directory at /tmp/matplotlib-wmlzg2qv because the default path (/idiap/home/sshetty/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [9]:
class NeuralNetwork(nn.Module):
    def __init__(self, dim=2, width=32):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork(dim=dim, width=dim*nmix*10).to(device)
def train_loop(data, model, loss_fn, optimizer, batch_size):
    size = data.shape[0]
    counter = 0
    for i in range(int(size/batch_size)-1):
        # Compute prediction and loss
        next_counter = (counter+batch_size)
        x_data = data[counter:next_counter,:-1]
        y_data = data[counter:next_counter,-1].view(-1,1)
        y_pred = model(x_data)
        loss = loss_fn(y_pred, y_data)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        counter = 1*next_counter

        if (i % int(0.25*size/batch_size)) == 0 :
            loss = loss.item()
            print(f"loss: {loss:>7f}")


def test_loop(data, model, loss_fn):
    x_data = data[:,:-1]
    y_data = data[:,-1]
    with torch.no_grad():
        pred = model(x_data)
        test_loss = loss_fn(pred, y_data).item()
    print(f"Test Error: ", test_loss)

In [10]:
learning_rate = 1e-3
batch_size = 100
epochs = 10

In [11]:
y_nn_0 = model(x_test)
mse_nn_0 = (((y_nn_0.view(-1)-y_test.view(-1))/(1e-9+y_test.view(-1).abs()))**2).mean()
print("mse_nn_0: ", mse_nn_0)

mse_nn_0:  tensor(1.1368, device='cuda:0', grad_fn=<MeanBackward0>)


In [12]:
# NN training (Note: compare the time it takes with TT-Cross)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
batch_size = 100
epochs = 10
t1 = time.time()
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(data_train, model, loss_fn, optimizer, batch_size)
    test_loop(data_test, model, loss_fn)
t2 = time.time()
print("time taken: ", t2-t1)
print("Done!")

Epoch 1
-------------------------------
loss: 1.178049
loss: 0.002727
loss: 0.130106
loss: 0.001618


  return F.mse_loss(input, target, reduction=self.reduction)


Test Error:  0.19129082062105746
Epoch 2
-------------------------------
loss: 0.014067
loss: 0.001272
loss: 0.296887
loss: 0.001081
Test Error:  0.24745713466026872
Epoch 3
-------------------------------
loss: 0.001960
loss: 0.000607
loss: 0.107745
loss: 0.000396
Test Error:  0.2681934520381166
Epoch 4
-------------------------------
loss: 0.003532
loss: 0.000300
loss: 0.015622
loss: 0.000641
Test Error:  0.2947253271853954
Epoch 5
-------------------------------
loss: 0.003216
loss: 0.000287
loss: 0.003205
loss: 0.000881
Test Error:  0.28090638369930526
Epoch 6
-------------------------------
loss: 0.004636
loss: 0.000319
loss: 0.000631
loss: 0.000360
Test Error:  0.3405355712848097
Epoch 7
-------------------------------
loss: 0.002147
loss: 0.000153
loss: 0.000551
loss: 0.000163
Test Error:  0.29791631463540097
Epoch 8
-------------------------------
loss: 0.002407
loss: 0.000159
loss: 0.001045
loss: 0.000148
Test Error:  0.3015412034751193
Epoch 9
-------------------------------


In [13]:
# Test the accuracy over the test set
y_nn = model(x_test)
mse_nn = (((y_nn.view(-1)-y_test.view(-1))/(1e-9+y_test.view(-1).abs()))**2).mean()
print("mse_nn: ", mse_nn)

mse_nn:  tensor(0.0003, device='cuda:0', grad_fn=<MeanBackward0>)


In [14]:
y_nn = model(x_test)
mse_nn = (((y_nn.view(-1)-y_test.view(-1))/(1e-9+y_test.view(-1).abs()))**2).mean()
print("mse_nn: ", mse_nn)

mse_nn:  tensor(0.0003, device='cuda:0', grad_fn=<MeanBackward0>)


In [15]:
mse_tt = (((y_tt.view(-1)-y_test.view(-1))/(1e-9+y_test.view(-1).abs()))**2).mean()
print("mse_tt: ", mse_tt)

mse_tt:  tensor(5.6480e-11, device='cuda:0')


In [16]:
y_nn = model(x_train)
mse_nn = (((y_nn.view(-1)-y_train.view(-1))/(1e-6+y_train.view(-1).abs()))**2).mean()
print("mse_nn: ", mse_nn)

mse_nn:  tensor(0.0006, device='cuda:0', grad_fn=<MeanBackward0>)


In [17]:
# y_nn = model(x_train)
# mse_nn = ((y_nn-y_train)**2).mean()
# print("mse_nn: ", mse_nn)

In [18]:
y_nn = model(x_test)
mse_nn = ((y_nn-y_test)**2).mean()
print("mse_nn: ", mse_nn)

mse_nn:  tensor(0.2952, device='cuda:0', grad_fn=<MeanBackward0>)


In [19]:
mse_tt = (((y_tt.view(-1)-y_test.view(-1)))**2).mean()
print("mse_tt: ", mse_tt)

mse_tt:  tensor(6.4013e-09, device='cuda:0')
