In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import TensorDataset, DataLoader

import tntorch as tn

import tensorkrowch as tk
from tensorkrowch.decompositions import tt_rss

# Create dataset

In [2]:
# Instance parameters
sq_root_size = 10                # side length (in pixels) of each square image
input_size = sq_root_size ** 2   # number of pixels of a flattened image
dataset_size = int(1e4)          # number of samples to generate

# Seed torch's global RNG so that the random steps of this notebook
# (dataset generation, shuffling, ...) are repeatable after
# "Restart Kernel -> Run All"
seed = 0
torch.manual_seed(seed)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Mini-batch size used by the data loaders
batch_size = 128

In [3]:
# Create dataset
def create_dataset(size, side=None):
    """Generate a random Bars-and-Stripes dataset of flattened binary images.

    Every sample starts from a random 0/1 row of length ``side`` that is
    neither all zeros nor all ones. With probability 0.5 the row is repeated
    along the first axis ("stripes", label 1); otherwise the transposed image
    is used ("bars", label 0).

    Parameters
    ----------
    size : int
        Number of samples to generate. ``0`` yields empty tensors.
    side : int, optional
        Side length of the square images. Defaults to the notebook-level
        ``sq_root_size``.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        ``tensors`` with shape ``(size, side * side)`` and integer 0/1
        entries, and ``labels`` with shape ``(size,)`` and dtype
        ``torch.long``.

    Raises
    ------
    ValueError
        If ``size`` is negative, or if ``side < 2`` (no non-constant pattern
        exists, so the rejection sampling below would never terminate).
    """
    if side is None:
        side = sq_root_size
    if side < 2:
        raise ValueError(f'side must be at least 2, got {side}')
    if size < 0:
        raise ValueError(f'size must be non-negative, got {size}')
    if size == 0:
        # torch.stack cannot handle an empty list, so return empties directly
        return (torch.empty((0, side * side), dtype=torch.long),
                torch.empty(0, dtype=torch.long))

    tensors = []
    labels = []
    for _ in range(size):
        # Rejection sampling: discard the two constant patterns
        while True:
            row = torch.randint(low=0, high=2, size=(1, side))
            if 0 < int(row.sum()) < side:
                break
        t = row.expand(side, -1)

        if torch.rand(1) < 0.5:
            # stripes
            tensors.append(t)
            labels.append(1)
        else:
            # bars
            tensors.append(t.t())
            labels.append(0)

    tensors = torch.stack(tensors, dim=0).reshape(size, -1)
    # Build integer labels directly instead of going through a float tensor
    labels = torch.tensor(labels, dtype=torch.long)

    return tensors, labels

In [4]:
# Folder where the generated dataset is stored; created on first use
dataset_dir = os.path.join(
    *('..', '..', 'results', '1_performance', 'bars_stripes_mps')
)
if not os.path.isdir(dataset_dir):
    os.makedirs(dataset_dir, exist_ok=True)

In [22]:
# Generate the dataset only when no saved copy exists, and persist it so that
# the cell that loads `dataset_{input_size}.pt` also works on a fresh run.
# An existing file is never overwritten.
dataset_path = os.path.join(dataset_dir, f'dataset_{input_size}.pt')
if os.path.exists(dataset_path):
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # files created by this notebook.
    dataset = torch.load(dataset_path, weights_only=False)
else:
    dataset = create_dataset(dataset_size)
    torch.save(dataset, dataset_path)

# Train MPS model on Bars and Stripes dataset

In [6]:
# Load dataset
dataset_file = os.path.join(dataset_dir, f'dataset_{input_size}.pt')
dataset = torch.load(dataset_file, weights_only=False)

# First 80% of the samples go to training, the rest to testing
n_train = int(0.8 * dataset_size)
samples, targets = dataset[0], dataset[1]

train_dataset = TensorDataset(samples[:n_train], targets[:n_train])
test_dataset = TensorDataset(samples[n_train:], targets[n_train:])

# Both loaders share the same options
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=2)
train_loader = DataLoader(dataset=train_dataset, **loader_options)
test_loader = DataLoader(dataset=test_dataset, **loader_options)

In [7]:
# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model, embedding):
    """Return the classification accuracy (in percent) of ``model`` on ``loader``.

    Each batch is moved to the notebook-level ``device``, embedded with
    ``embedding`` and passed through ``model``. The outputs are squared and
    divided by their per-sample L2 norm (the same post-processing used in the
    training loop), and the predicted class is the arg-max over dimension 1.

    Parameters
    ----------
    loader : iterable
        Yields ``(data, labels)`` batches.
    model : callable
        Called as ``model(x, inline_input=False, inline_mats=False)``; must
        provide ``eval()`` and ``train()``.
    embedding : callable
        Maps a raw data batch to the input expected by ``model``.

    Returns
    -------
    float
        Percentage of correctly classified samples.

    Notes
    -----
    The model is switched to evaluation mode for the duration of the call and
    is put back in training mode only if it was training beforehand, so an
    already-evaluating model is not silently flipped to training mode.
    """
    num_correct = 0
    num_samples = 0

    # Remember the caller's mode so it can be restored afterwards
    was_training = getattr(model, 'training', True)
    model.eval()
    try:
        with torch.no_grad():
            for data, labels in loader:
                data = data.to(device)
                labels = labels.to(device)

                scores = model(embedding(data),
                               inline_input=False,
                               inline_mats=False)
                scores = scores.pow(2)
                scores = scores / scores.norm(dim=1, keepdim=True)

                _, preds = scores.max(1)
                num_correct += int((preds == labels).sum())
                num_samples += preds.size(0)
    finally:
        if was_training:
            model.train()

    return float(num_correct) / float(num_samples) * 100

In [8]:
def embedding(x):
    """Embed ``x`` with ``tk.embeddings.basis`` (``dim=2``) and cast to float."""
    return tk.embeddings.basis(x, dim=2).float()

In [9]:
# Hyperparameters
num_epochs = 20
learning_rate = 1e-3
weight_decay = 1e-6

# MPS classifier
# NOTE(review): n_features is input_size + 1 while the traced input has
# input_size sites; presumably the extra site is the output node - confirm
# against the tensorkrowch MPSLayer documentation.
mps_options = dict(n_features=input_size + 1,
                   in_dim=2,
                   out_dim=2,
                   bond_dim=10,
                   init_method='unit',
                   device=device)
model = tk.models.MPSLayer(**mps_options)

# Trace the network once with a dummy single-sample batch, before the
# optimizer collects the parameters
dummy_batch = torch.zeros(1, input_size, 2, device=device)
model.trace(dummy_batch, inline_input=False, inline_mats=False)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)

In [10]:
# Train network
print('* TRAINING MODEL...')
for epoch in range(num_epochs):
    for data, labels in train_loader:
        data, labels = data.to(device), labels.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward: square the raw outputs and normalise them per sample
        raw_scores = model(embedding(data),
                           inline_input=False,
                           inline_mats=False)
        squared_scores = raw_scores.pow(2)
        scores = squared_scores / squared_scores.norm(dim=1, keepdim=True)

        loss = criterion(scores, labels)

        # Backward + gradient descent
        loss.backward()
        optimizer.step()

    # Accuracy over the full loaders; `loss` below is the last batch's loss
    train_acc = check_accuracy(train_loader, model, embedding)
    test_acc = check_accuracy(test_loader, model, embedding)

    print(f'Epoch {epoch + 1}/{num_epochs} => Train Loss: {loss:.4f} '
          f'Train Acc.: {train_acc:.2f}, Test Acc.: {test_acc:.2f}')

# Save model
model.reset()

results_dir = os.path.join(
    *('..', '..', 'results', '1_performance', 'bars_stripes_mps')
)
os.makedirs(results_dir, exist_ok=True)

# torch.save(model.tensors,
#            os.path.join(results_dir,
#                         f'cores_{test_acc:.2f}.pt'))

* TRAINING MODEL...
Epoch 1/20 => Train Loss: 0.3943 Train Acc.: 89.66, Test Acc.: 88.60
Epoch 2/20 => Train Loss: 0.3485 Train Acc.: 96.24, Test Acc.: 94.55
Epoch 3/20 => Train Loss: 0.3345 Train Acc.: 97.95, Test Acc.: 96.90
Epoch 4/20 => Train Loss: 0.3488 Train Acc.: 98.94, Test Acc.: 98.20
Epoch 5/20 => Train Loss: 0.3166 Train Acc.: 99.36, Test Acc.: 99.10
Epoch 6/20 => Train Loss: 0.3157 Train Acc.: 99.56, Test Acc.: 99.45
Epoch 7/20 => Train Loss: 0.3152 Train Acc.: 99.74, Test Acc.: 99.55
Epoch 8/20 => Train Loss: 0.3146 Train Acc.: 99.89, Test Acc.: 99.60
Epoch 9/20 => Train Loss: 0.3146 Train Acc.: 99.91, Test Acc.: 99.60
Epoch 10/20 => Train Loss: 0.3147 Train Acc.: 99.94, Test Acc.: 99.60
Epoch 11/20 => Train Loss: 0.3137 Train Acc.: 99.94, Test Acc.: 99.60
Epoch 12/20 => Train Loss: 0.3135 Train Acc.: 99.95, Test Acc.: 99.60
Epoch 13/20 => Train Loss: 0.3136 Train Acc.: 100.00, Test Acc.: 99.75
Epoch 14/20 => Train Loss: 0.3135 Train Acc.: 100.00, Test Acc.: 99.75
Epoch 1

# TT-RSS

In [16]:
# Folder holding the trained cores and the TT-RSS results
results_dir = os.path.join(
    *('..', '..', 'results', '1_performance', 'bars_stripes_mps')
)
if not os.path.isdir(results_dir):
    os.makedirs(results_dir, exist_ok=True)

In [17]:
# Load cores
# map_location puts the stored tensors on the same device as the input used
# to trace the network below, regardless of the device they were saved from.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load
# files produced by this notebook.
cores = torch.load(os.path.join(results_dir, 'cores_99.75.pt'),
                   map_location=device,
                   weights_only=False)
mps = tk.models.MPSLayer(tensors=cores)

mps.trace(torch.zeros(1, input_size, 2, device=device))

In [19]:
# Number of dataset samples handed to TT-RSS as sketch samples
sketch_size = 50

# One binary domain {0, 1} per input pixel
domain = [torch.arange(2) for _ in range(input_size)]

# Load dataset (only the samples, i.e. element 0 of the saved tuple)
dataset_file = os.path.join(dataset_dir, f'dataset_{input_size}.pt')
dataset = torch.load(dataset_file, weights_only=False)[0]
sketch_samples = dataset[:sketch_size]


def embedding(x):
    """Embed ``x`` with ``tk.embeddings.basis`` (``dim=2``) and cast to float."""
    return tk.embeddings.basis(x, dim=2).float()


def fun(x):
    """Function to decompose: the trained MPS applied to the embedded input."""
    return mps(embedding(x))


# Options forwarded to tt_rss
rss_options = dict(function=fun,
                   embedding=embedding,
                   sketch_samples=sketch_samples,
                   domain=domain,
                   rank=10,
                   cum_percentage=1 - 1e-5,
                   batch_size=500,
                   device=device,
                   verbose=False,
                   return_info=True)
cores_rss, info = tt_rss(**rss_options)

mps.reset()

# Save cores
# torch.save(cores_rss,
#            os.path.join(results_dir, f'cores_rss_{info["total_time"]:.2f}.pt'))

In [20]:
# Display the diagnostics dictionary returned by tt_rss (return_info=True)
info

{'total_time': 19.5962233543396,
 'val_eps': tensor(1.4890e-06, device='cuda:0')}

In [21]:
# Load the cores produced by TT-RSS.
# map_location lets a file saved from a CUDA run be read on a CPU-only host.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load
# files produced by this notebook.
cores_rss = torch.load(os.path.join(results_dir, 'cores_rss_19.60.pt'),
                       map_location=device,
                       weights_only=False)
mps_rss = tk.models.MPS(tensors=[c.to(device) for c in cores_rss])
mps_rss.canonicalize(renormalize=True)

In [22]:
# Norm of the trained MPS and of its TT-RSS reconstruction
# (log_scale=True presumably returns the logarithm of the norm - confirm)
mps_norm = mps.norm(log_scale=True)
mps_rss_norm = mps_rss.norm(log_scale=True)

print(f'MPS: {mps_norm.item()}')
print(f'MPS RSS: {mps_rss_norm.item()}')

# Reset both networks after computing the norms
for network in (mps, mps_rss):
    network.reset()

MPS: 31.648563385009766
MPS RSS: 31.648635864257812


In [23]:
# Walk both networks site by site and connect their 'input' edges
for site, site_rss in zip(mps.mats_env, mps_rss.mats_env):
    site['input'] ^ site_rss['input']

In [24]:
# Running sum of the logarithms of the norms factored out of `result` during
# the contraction below
log_scale = 0

# Contract mps with mps_rss
# Stack the per-site nodes of each network and connect the two stacks
stack = tk.stack(mps.mats_env)
stack_rss = tk.stack(mps_rss.mats_env)
stack ^ stack_rss

# Contract the two stacks in one operation and split the result back into a
# list with one entry per site
mats_results = tk.unbind(stack @ stack_rss)

# Absorb the boundary nodes of both networks into the first and last entries
mats_results[0] = mps.left_node @ (mps_rss.left_node @ mats_results[0])
mats_results[-1] = (mats_results[-1] @ mps.right_node) @ mps_rss.right_node

# Multiply the per-site results from left to right
result = mats_results[0]
for mat in mats_results[1:]:
    result @= mat
    
    # Store the log of the current norm, then rescale `result`
    # (renormalize presumably divides by that norm - confirm in the
    # tensorkrowch docs)
    log_scale += result.norm().log()
    result = result.renormalize()

# NOTE(review): `result` presumably holds the overlap between mps and mps_rss
# up to the factored-out scale; halving its logarithm makes the value
# comparable with the norms printed above - confirm.
approx_mps_norm = (result.tensor.log() + log_scale) / 2
print(approx_mps_norm.item())

31.64859962463379


In [25]:
# exp(2 * overlap term - norm of mps - norm of mps_rss), all in log scale
torch.exp(2 * approx_mps_norm - mps_norm - mps_rss_norm)

tensor(1., device='cuda:0', grad_fn=<ExpBackward0>)

# TT-CROSS

In [26]:
# Folder holding the trained cores and the TT-CROSS results
results_dir = os.path.join(
    *('..', '..', 'results', '1_performance', 'bars_stripes_mps')
)
if not os.path.isdir(results_dir):
    os.makedirs(results_dir, exist_ok=True)

In [27]:
# Load cores
# map_location puts the stored tensors on the same device as the input used
# to trace the network below, regardless of the device they were saved from.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load
# files produced by this notebook.
cores = torch.load(os.path.join(results_dir, 'cores_99.75.pt'),
                   map_location=device,
                   weights_only=False)
mps = tk.models.MPSLayer(tensors=cores)

mps.trace(torch.zeros(1, input_size, 2, device=device))

In [26]:
# One binary domain per variable: input_size pixels plus one extra variable
# that `fun` interprets as the class label
domain = [torch.arange(2, device=device) for _ in range(input_size + 1)]


def embedding(x):
    """Cast ``x`` to int, embed it with ``tk.embeddings.basis`` (``dim=2``)
    and return the result as float."""
    return tk.embeddings.basis(x.int(), dim=2).float()


def fun(x):
    """Evaluate the trained MPS on rows that carry pixels and a label.

    The column at ``out_position`` of ``x`` is taken as the label index; the
    remaining columns are the pixels fed to ``mps``. Returns, for every row,
    the model output selected by that label, as a 1-D tensor.
    """
    # NOTE(review): presumably matches the position of the output node of
    # the MPSLayer - confirm.
    out_position = (input_size + 1) // 2
    data = torch.cat([x[:, :out_position], x[:, (out_position + 1):]], dim=1)
    labels = x[:, out_position:(out_position + 1)]
    out = mps(embedding(data))
    out = out.gather(dim=1, index=labels.long()).flatten()
    return out


tt_cross, info = tn.cross(function=fun,
                          domain=domain,
                          device=device,
                          function_arg='matrix',
                          rmax=10,
                          # max_iter=5,
                          eps=1e-3,
                          verbose=True,
                          return_info=True)

# Work on a copy of the list so that slicing the boundary cores below does not
# modify `tt_cross.cores` in place
cores_cross = list(tt_cross.cores)
# Drop the leading axis of the first core and the trailing axis of the last one
cores_cross[0] = cores_cross[0][0]
cores_cross[-1] = cores_cross[-1][..., 0]

mps.reset()

# Save cores
# torch.save(cores_cross,
#            os.path.join(results_dir, f'cores_cross_{info["total_time"]:.2f}.pt'))

cross device is cuda
Functions that require cross-approximation can be accelerated with the optional maxvolpy package, which can be installed by 'pip install maxvolpy'. More info is available at https://bitbucket.org/muxas/maxvolpy.
Cross-approximation over a 101D domain containing 2.5353e+30 grid points:
iter: 0  | eps: 1.000e+00 | time:   9.9796 | largest rank:   1
iter: 1  | eps: 1.000e+00 | time: 107.2939 | largest rank:   4
iter: 2  | eps: 1.041e+00 | time: 243.8108 | largest rank:   7
iter: 3  | eps: 4.262e-06 | time: 374.3675 | largest rank:  10 <- converged: eps < 0.001
Did 63778 function evaluations, which took 5.301s (1.203e+04 evals/s)



In [27]:
# Display the diagnostics dictionary returned by tn.cross (return_info=True)
info

{'nsamples': 63778,
 'eval_time': 5.300762176513672,
 'val_epss': [tensor(1., device='cuda:0', grad_fn=<DivBackward0>),
  tensor(1.0000, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(1.0411, device='cuda:0', grad_fn=<DivBackward0>),
  tensor(4.2620e-06, device='cuda:0', grad_fn=<DivBackward0>)],
 'min': 0,
 'argmin': None,
 'lsets': [array([[0]]),
  array([[0, 0],
         [0, 1]]),
  array([[0, 0, 0],
         [0, 0, 1],
         [0, 1, 0],
         [0, 1, 1]]),
  array([[0, 0, 0, 0],
         [0, 0, 0, 1],
         [0, 0, 1, 0],
         [0, 0, 1, 1],
         [0, 1, 0, 0],
         [0, 1, 0, 1],
         [0, 1, 1, 0],
         [0, 1, 1, 1]]),
  array([[0, 0, 0, 1, 1],
         [0, 1, 1, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 1, 0],
         [0, 1, 1, 0, 0],
         [0, 1, 0, 1, 1],
         [0, 0, 1, 0, 1],
         [0, 1, 0, 1, 0],
         [0, 1, 0, 0, 0],
         [0, 0, 1, 1, 1]]),
  array([[0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 1, 1, 0,

In [28]:
# Load the cores produced by TT-CROSS.
# map_location lets a file saved from a CUDA run be read on a CPU-only host.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load
# files produced by this notebook.
cores_cross = torch.load(os.path.join(results_dir, 'cores_cross_374.37.pt'),
                         map_location=device,
                         weights_only=False)
mps_cross = tk.models.MPS(tensors=[c.to(device) for c in cores_cross])
mps_cross.canonicalize(renormalize=True)

In [29]:
# Norm of the trained MPS and of its TT-CROSS reconstruction
# (log_scale=True presumably returns the logarithm of the norm - confirm)
mps_norm = mps.norm(log_scale=True)
mps_cross_norm = mps_cross.norm(log_scale=True)

print(f'MPS: {mps_norm.item()}')
# Label fixed: this value belongs to the TT-CROSS network, not the TT-RSS one
print(f'MPS CROSS: {mps_cross_norm.item()}')

mps.reset()
mps_cross.reset()

MPS: 31.648563385009766
MPS RSS: 31.64862823486328


In [30]:
# Walk both networks site by site and connect their 'input' edges
for site, site_cross in zip(mps.mats_env, mps_cross.mats_env):
    site['input'] ^ site_cross['input']

In [31]:
# Running sum of the logarithms of the norms factored out of `result` during
# the contraction below
log_scale = 0

# Contract mps with mps_cross
# Stack the per-site nodes of each network and connect the two stacks
stack = tk.stack(mps.mats_env)
stack_cross = tk.stack(mps_cross.mats_env)
stack ^ stack_cross

# Contract the two stacks in one operation and split the result back into a
# list with one entry per site
mats_results = tk.unbind(stack @ stack_cross)

# Absorb the boundary nodes of both networks into the first and last entries
mats_results[0] = mps.left_node @ (mps_cross.left_node @ mats_results[0])
mats_results[-1] = (mats_results[-1] @ mps.right_node) @ mps_cross.right_node

# Multiply the per-site results from left to right
result = mats_results[0]
for mat in mats_results[1:]:
    result @= mat
    
    # Store the log of the current norm, then rescale `result`
    # (renormalize presumably divides by that norm - confirm in the
    # tensorkrowch docs)
    log_scale += result.norm().log()
    result = result.renormalize()

# NOTE(review): `result` presumably holds the overlap between mps and
# mps_cross up to the factored-out scale; halving its logarithm makes the
# value comparable with the norms printed above - confirm.
approx_mps_norm = (result.tensor.log() + log_scale) / 2
print(approx_mps_norm.item())

31.648603439331055


In [32]:
# exp(2 * overlap term - norm of mps - norm of mps_cross), all in log scale.
# Uses the TT-CROSS norm; the original referenced mps_rss_norm, a leftover
# from the TT-RSS section.
(2*approx_mps_norm - mps_norm - mps_cross_norm).exp()

tensor(1.0000, device='cuda:0', grad_fn=<ExpBackward0>)