In [None]:
# %%
import os

# Restrict CUDA to physical GPU 1. The variable is set before `import torch`
# so that torch never initialises CUDA with a different device list.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import torch
from numpy import genfromtxt

# Tensors created without an explicit dtype will be float64.
torch.set_default_dtype(torch.float64)

# Seed torch's RNGs (CPU and CUDA) so that any random initialisation or
# sampling done later in the notebook is reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

In [None]:
from torch.nn import functional as F

def _load_split(features_path, labels_path, device='cuda'):
    """Load one data split from comma-separated text files.

    Args:
        features_path: CSV file of features. It is transposed after loading so
            that rows index samples (the file itself presumably stores one
            sample per column -- TODO confirm against the data export).
        labels_path: CSV file of labels. Entries equal to 1 become class 1 and
            every other value becomes class 0.
        device: Device the returned tensors are moved to.

    Returns:
        ``(X, y, xinp)``: float64 features with one row per sample, float64
        one-hot labels with two columns, and ``X`` with a leading column of
        ones (bias term) prepended.

    Raises:
        AssertionError: if the number of feature rows and labels differ.
    """
    X = torch.tensor(genfromtxt(features_path, delimiter=',').T, dtype=torch.float64)
    labels = torch.tensor(genfromtxt(labels_path, delimiter=',') == 1, dtype=torch.long)
    y = F.one_hot(labels, num_classes=2).to(dtype=torch.float64)
    # Catch a mis-oriented or mismatched file here rather than deep inside training.
    assert X.shape[0] == y.shape[0], (
        f"{features_path} has {X.shape[0]} samples but {labels_path} has {y.shape[0]} labels"
    )
    X, y = X.to(device), y.to(device)
    # Bias column: ones with the same dtype/device as X.
    xinp = torch.cat([torch.ones_like(X[:, :1]), X], dim=-1)
    return X, y, xinp


X_train, y_train, xinp_train = _load_split('data/mdata_train.txt', 'data/mdata_train_l.txt')
X_test, y_test, xinp_test = _load_split('data/mdata_test.txt', 'data/mdata_test_l.txt')

In [None]:
from tensor.layers import TensorTrainLayer
from tensor.bregman import KLDivBregman, XEAutogradBregman
from sklearn.metrics import balanced_accuracy_score

# Structural hyper-parameters passed straight to TensorTrainLayer.
# NOTE(review): presumably N is the number of tensor-train cores and r the TT
# rank -- confirm against tensor.layers.TensorTrainLayer.
N, r = 2, 5
# Input width: every feature column plus the leading bias column of xinp_train.
p = 1 + X_train.shape[1]
# Output width: one fewer than the number of one-hot classes. The remaining
# class gets an implicit zero score (see convergence_criterion).
C = y_train.shape[1] - 1

def convergence_criterion(y_pred, y_true, threshold=None):
    """Print balanced accuracy and report whether training should stop.

    Used as the ``convergence_criterion`` callback of ``accumulating_swipe``
    and for the final train/test evaluation.

    Args:
        y_pred: Model outputs, one column per class except the last. A column
            of zeros is appended as the last class's score, and the predicted
            class is the argmax over all columns.
        y_true: One-hot targets; the true class is their argmax.
        threshold: Optional balanced-accuracy level at which convergence is
            signalled. ``None`` (the default) never signals convergence, so
            training runs for the full number of swipes.

    Returns:
        ``True`` only if ``threshold`` is given and the balanced accuracy
        reaches it; otherwise ``False``.
    """
    # Materialise the last class's implicit zero score so argmax can choose
    # among all classes.
    scores = torch.cat((y_pred, torch.zeros_like(y_pred[:, :1])), dim=1)
    balanced_acc = balanced_accuracy_score(
        y_true.argmax(dim=-1).cpu().numpy(),
        scores.argmax(dim=-1).cpu().numpy(),
    )
    print("Balanced Accuracy:", balanced_acc)
    if threshold is None:
        return False
    return bool(balanced_acc >= threshold)

# Build the tensor-train layer mapping the bias-augmented inputs (p columns) to
# C outputs, and move it to the GPU. The Bregman function is built separately,
# in the training cell, once the initial output scale is known.
layer = TensorTrainLayer(N, r, p, output_shape=C).cuda()

In [None]:
# Inspect the nodes of the freshly constructed layer's tensor network; Jupyter
# renders the value of this final expression.
layer.tensor_network.nodes

In [None]:
# Weight for the cross-entropy Bregman function: the reciprocal of the standard
# deviation of the untrained layer's outputs on the training set.
# NOTE(review): presumably this normalises the initial logits to unit scale --
# confirm against XEAutogradBregman.
with torch.inference_mode():
    init_std = layer(xinp_train).std().item()
if not 0.0 < init_std < float('inf'):
    # A zero, NaN or infinite spread would make w infinite/NaN (or raise
    # ZeroDivisionError), so fail early with a clear message instead.
    raise ValueError(f"Cannot derive Bregman weight: initial output std is {init_std}")
w = 1 / init_std
bf = XEAutogradBregman(w=w)

# Training hyper-parameters for accumulating_swipe.
BATCH_SIZE = 64
LR = 1.0
EPS = 1e-4
NUM_SWIPES = 10

layer.tensor_network.accumulating_swipe(
    xinp_train, y_train, bf,
    batch_size=BATCH_SIZE,
    lr=LR,
    convergence_criterion=convergence_criterion,
    orthonormalize=False,
    method='exact',
    eps=EPS,
    verbose=True,
    num_swipes=NUM_SWIPES,
)

In [None]:
# Report balanced accuracy of the trained layer on both splits. No gradients are
# needed, so inference_mode avoids building an autograd graph over the full
# datasets. The `with` block also produces no cell output, so no trailing `None`
# is needed to hide convergence_criterion's return value.
with torch.inference_mode():
    print("Train accuracy:")
    convergence_criterion(layer(xinp_train), y_train)
    print("Test accuracy:")
    convergence_criterion(layer(xinp_test), y_test)

In [None]:
from tensor.utils import visualize_tensornetwork
# Visualise the trained layer's tensor network with the project helper.
# NOTE(review): the output format (figure vs. text) depends on
# tensor.utils.visualize_tensornetwork -- confirm there.
visualize_tensornetwork(layer.tensor_network)