In [1]:
# %%
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
from numpy import genfromtxt
torch.set_default_dtype(torch.float64)

In [2]:
# train
x = genfromtxt('data/mdata_train.txt', delimiter=',')
X_train = x.T
y_train = genfromtxt('data/mdata_train_l.txt',delimiter=',')

x_test = genfromtxt('data/mdata_test.txt', delimiter=',')
X_test = x_test.T
y_test = genfromtxt('data/mdata_test_l.txt',delimiter=',')

# Convert to torch
X_train = torch.tensor(X_train, dtype=torch.float64)
y_train = torch.tensor(y_train == 1, dtype=torch.long).unsqueeze(1)
X_test = torch.tensor(X_test, dtype=torch.float64)
y_test = torch.tensor(y_test == 1, dtype=torch.long).unsqueeze(1)

# Move data to GPU
X_train, y_train = X_train.cuda(), y_train.cuda()
X_test, y_test = X_test.cuda(), y_test.cuda()

# Prepare train input
xinp_train = torch.cat([torch.ones(X_train.shape[0], 1, dtype=X_train.dtype, device=X_train.device), X_train], dim=-1)
xinp_test = torch.cat([torch.ones(X_test.shape[0], 1, dtype=X_test.dtype, device=X_test.device), X_test], dim=-1)

In [3]:
from tensor.layers import TensorTrainLayer
from tensor.bregman import KLDivBregman, AutogradBregman, BinaryKLDivBregman

N = 2
r = 2
p = X_train.shape[1]+1
#C = y_train.shape[1]

# Define Bregman function
layer = TensorTrainLayer(N, r, p, output_shape=1).cuda()
y_pred = layer(xinp_train)
w = 1/y_pred.std().item()
del y_pred
bf = BinaryKLDivBregman(w=w)

def convergence_criterion(y_pred, y_true):
    accuracy = (((w*y_pred).sigmoid() > 0.5).long() == y_true).float().mean().item()
    print("Accuracy:", accuracy)
    return accuracy > 0.82

In [4]:
with torch.inference_mode():
    layer.tensor_network.accumulating_swipe(xinp_train, y_train, bf, batch_size=64, lr=1.0, convergence_criterion=convergence_criterion, orthonormalize=False, method='exact', eps=1e-5, verbose=True, num_swipes=10)

Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 1.1968429600293327
Accuracy: 0.7699542045593262


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.798898693073401
Accuracy: 0.47732946276664734


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 4.225444908772426
Accuracy: 0.5430358648300171


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 10.600055020156042
Accuracy: 0.45669251680374146


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 15.009878531773067
Accuracy: 0.5872096419334412


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 11.408387546298432
Accuracy: 0.7433779835700989


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 7.088439739400984
Accuracy: 0.7701597809791565


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 6.34867185161424
Accuracy: 0.41052183508872986


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 16.291556869929174
Accuracy: 0.6976110935211182


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 8.352618698191232
Accuracy: 0.7584353685379028


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 6.6725224907190315
Accuracy: 0.7570772171020508


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 6.710039175872159
Accuracy: 0.40097787976264954


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 16.55518144837794
Accuracy: 0.6894106268882751


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 8.579132578494061
Accuracy: 0.7548307180404663


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 6.772091061945116
Accuracy: 0.7599917650222778


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 6.6295323565531055
Accuracy: 0.4228482246398926


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 15.951076626510924
Accuracy: 0.6631134152412415


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 9.305518166400173
Accuracy: 0.754206657409668


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 6.789327945360886
Accuracy: 0.7652189135551453


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 6.485147556558141
Accuracy: 0.41736412048339844


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 16.102559111729278
Accuracy: 0.6620121598243713


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 9.335936357797571
Accuracy: 0.7510057687759399


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 6.877743417311222
Accuracy: 0.7711802124977112


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 6.320483836255918
Accuracy: 0.4200437664985657


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 16.028541782032367
Accuracy: 0.6661013960838318


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 9.222983451072706
Accuracy: 0.7543094754219055


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 6.786488910983783
Accuracy: 0.7571506500244141


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 6.708011349352038
Accuracy: 0.43292808532714844


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 15.672648675542385
Accuracy: 0.6589214205741882


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 9.421310038825617
Accuracy: 0.7577526569366455


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 6.691381509135428
Accuracy: 0.775188684463501


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 6.209762036356448
Accuracy: 0.4083854556083679


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 16.35056857581268
Accuracy: 0.6742357611656189


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 8.9982948836292
Accuracy: 0.7533403635025024


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 6.813256969899681
Accuracy: 0.7701230645179749


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 6.349685204746469
Accuracy: 0.4172540009021759


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 16.105601126215635
Accuracy: 0.6635245084762573


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 9.294161843156646
Accuracy: 0.7595219016075134


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 6.642509705744195
Accuracy: 0.7592576146125793


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 6.649811259163539
Accuracy: 0.4265703558921814


In [5]:
# Calculate accuracy on train set
y_pred_train = layer(xinp_train)
accuracy_train = (((w*y_pred_train).sigmoid() > 0.5).long() == y_train).float().mean().item()
print('Train Acc:', accuracy_train)

Train Acc: 0.4265703558921814


In [6]:
# Calculate accuracy on test set
y_pred_test = layer(xinp_test)
accuracy_test = (((w*y_pred_test).sigmoid() > 0.5).long() == y_test).float().mean().item()
print('Test Acc:', accuracy_test)

Test Acc: 0.5607324838638306


In [7]:
layer.tensor_network.nodes

[TensorNode(name=A1, shape=torch.Size([51, 2]), labels=['p', 'r2']),
 TensorNode(name=A2, shape=torch.Size([2, 51]), labels=['r2', 'p']),
 TensorNode(name=X1, shape=torch.Size([15346, 51]), labels=['s', 'p']),
 TensorNode(name=X2, shape=torch.Size([15346, 51]), labels=['s', 'p'])]

In [3]:
import xgboost as xgb
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# Convert PyTorch tensors to NumPy arrays
X_train_np = X_train.cpu().numpy()
y_train_np = y_train.cpu().numpy()
X_test_np = X_test.cpu().numpy()
y_test_np = y_test.cpu().numpy()

# Train an XGBoost classifier
xgb_clf = xgb.XGBClassifier(eval_metric='logloss')
xgb_clf.fit(X_train_np, y_train_np)

# Predict on the test set
y_pred_test_xgb = xgb_clf.predict(X_test_np)

# Calculate accuracy
accuracy_test_xgb = balanced_accuracy_score(y_test_np, y_pred_test_xgb)
print('XGBoost Test Accuracy:', accuracy_test_xgb)

XGBoost Test Accuracy: 0.9130282124122588
