In [1]:
# %%
import os
# Expose only physical GPU 1 to this process. This is set before torch is imported so it is
# in place before CUDA initialises; inside this process that GPU is addressed as 'cuda' / 'cuda:0'.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import torch
from numpy import genfromtxt
# Make float64 the default floating-point dtype for tensors created without an explicit dtype.
torch.set_default_dtype(torch.float64)

In [2]:
from torch.nn import functional as F

# All tensors live on the single GPU left visible by CUDA_VISIBLE_DEVICES.
DEVICE = torch.device('cuda')


def load_split(features_path, labels_path, device=DEVICE):
    """Load one data split from comma-separated text files.

    Args:
        features_path: feature file. It is transposed after loading, so it is
            expected to hold one sample per column.
        labels_path: label file with one entry per sample. Entries equal to 1
            become class 1; every other value becomes class 0.
        device: where the returned tensors are placed.

    Returns:
        (X, y): float64 features with one sample per row, and float64 one-hot
        labels with 2 columns.

    Raises:
        AssertionError: if the number of samples and labels differ.
    """
    X = torch.tensor(genfromtxt(features_path, delimiter=',').T, dtype=torch.float64)
    labels = torch.tensor(genfromtxt(labels_path, delimiter=',') == 1, dtype=torch.long)
    y = F.one_hot(labels, num_classes=2).to(dtype=torch.float64)
    # Catch mismatched feature/label files here rather than deep inside training.
    assert X.shape[0] == y.shape[0], (
        f"{features_path}: {X.shape[0]} samples but {labels_path}: {y.shape[0]} labels"
    )
    return X.to(device), y.to(device)


def add_bias_column(X):
    """Return X with a constant column of ones prepended (bias feature)."""
    ones = torch.ones(X.shape[0], 1, dtype=X.dtype, device=X.device)
    return torch.cat([ones, X], dim=-1)


X_train, y_train = load_split('data/mdata_train.txt', 'data/mdata_train_l.txt')
X_test, y_test = load_split('data/mdata_test.txt', 'data/mdata_test_l.txt')
assert X_train.shape[1] == X_test.shape[1], "train/test feature counts differ"

# Model inputs: bias column followed by the raw features.
xinp_train = add_bias_column(X_train)
xinp_test = add_bias_column(X_test)

In [3]:
from tensor.layers import TensorTrainLayer
from tensor.bregman import KLDivBregman, XEAutogradBregman

# Hyper-parameters handed positionally to TensorTrainLayer.
# NOTE(review): N / r are presumably the number of TT cores and the TT rank — confirm in tensor.layers.
N, r = 2, 5
# Input width: the raw features plus the constant bias column prepended to xinp_*.
p = 1 + X_train.shape[1]
# Output scores: one fewer than the one-hot classes; the evaluation code appends
# a fixed zero score for the remaining (last) class.
C = y_train.shape[1] - 1

def classification_accuracy(y_pred, y_true):
    """Return the fraction of rows whose predicted class matches the target.

    Args:
        y_pred: 2-D tensor of model scores, expected to have one column fewer
            than ``y_true``. A zero score is appended as the last column before
            the argmax, i.e. the last class acts as a reference class whose
            score is fixed at 0.
        y_true: 2-D one-hot targets; the target class is taken via argmax.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    # Only the argmax is needed, so detach instead of extending an autograd graph.
    scores = torch.cat((y_pred.detach(), torch.zeros_like(y_pred[:, :1])), dim=1)
    hits = scores.argmax(dim=-1) == y_true.argmax(dim=-1)
    # Average in float64 (the notebook's default dtype) rather than float32 via
    # .float(), so the reported accuracy carries no float32 rounding noise.
    return hits.to(torch.float64).mean().item()


def convergence_criterion(y_pred, y_true, target_accuracy=None):
    """Print the accuracy of ``y_pred`` against ``y_true`` and return a stop flag.

    Passed as the ``convergence_criterion`` callback of ``accumulating_swipe``
    (the logged output shows it running after every pass) and reused for the
    final train/test evaluation.

    Args:
        y_pred, y_true: as for ``classification_accuracy``.
        target_accuracy: optional threshold. When given, return True once the
            accuracy strictly exceeds it. The default None always returns False,
            so all requested sweeps run.

    Returns:
        bool. NOTE(review): presumably True makes accumulating_swipe stop
        early — confirm in tensor.bregman / the tensor network implementation.
    """
    accuracy = classification_accuracy(y_pred, y_true)
    print("Accuracy:", accuracy)
    return target_accuracy is not None and accuracy > target_accuracy

# Build the tensor-train model from the hyper-parameters above (N, r, p, C),
# then move it to the GPU, where the training/test tensors already live.
layer = TensorTrainLayer(N, r, p, output_shape=C)
layer = layer.cuda()

In [4]:
# Weight the Bregman loss by 1 / std of the untrained model's scores on the
# training set. NOTE(review): presumably this normalises the initial score scale — confirm
# against XEAutogradBregman. inference_mode: no autograd state is needed for this pass.
with torch.inference_mode():
    y_pred = layer(xinp_train)
    w = 1 / y_pred.std().item()
    del y_pred
bf = XEAutogradBregman(w=w)

# Alternating left-to-right / right-to-left sweeps over the tensor-train cores,
# printing loss and accuracy after each pass (see the captured output below).
layer.tensor_network.accumulating_swipe(
    xinp_train,
    y_train,
    bf,
    batch_size=64,
    lr=1.0,
    convergence_criterion=convergence_criterion,
    orthonormalize=False,
    method='exact',
    eps=1e-4,
    verbose=True,
    num_swipes=10,
)

Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 0.9622435052195917
Accuracy: 0.7928816676139832


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.483949990438929
Accuracy: 0.8244428038597107


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 0.40118026879141416
Accuracy: 0.8363286852836609


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 0.38197519161031135
Accuracy: 0.8437362313270569


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 0.3704797250305111
Accuracy: 0.8458579182624817


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.3691920097084087
Accuracy: 0.8469665050506592


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 0.366287089024954
Accuracy: 0.847707986831665


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 0.3658157091699136
Accuracy: 0.8491616249084473


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 0.3626640183897214
Accuracy: 0.8498150110244751


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.36201644450856346
Accuracy: 0.8502408266067505


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 0.3608057348273088
Accuracy: 0.850314199924469


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 0.3607612516246867
Accuracy: 0.8503876328468323


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 0.3608846514464747
Accuracy: 0.8508574962615967


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.3600133506552372
Accuracy: 0.8512906432151794


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 0.3591933755278234
Accuracy: 0.8514007925987244


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 0.35913692631153
Accuracy: 0.8513787388801575


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 0.35871505605399107
Accuracy: 0.8516577482223511


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.358669566741379
Accuracy: 0.8516870737075806


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 0.3585465674137915
Accuracy: 0.8517898917198181


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 0.3584334258827094
Accuracy: 0.8517751693725586


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 0.35828225526250057
Accuracy: 0.8519073128700256


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.3582509705343618
Accuracy: 0.8521936535835266


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 0.3580691530225355
Accuracy: 0.852333128452301


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 0.35796150486311384
Accuracy: 0.8521936535835266


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 0.3578704833075499
Accuracy: 0.8521789908409119


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.357817747624383
Accuracy: 0.8524065613746643


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 0.3577440224135541
Accuracy: 0.8524946570396423


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 0.35762951027355305
Accuracy: 0.8525460362434387


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 0.35768238871775554
Accuracy: 0.8526414632797241


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.35765450961913425
Accuracy: 0.852663516998291


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 0.3576251676132071
Accuracy: 0.8528029918670654


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 0.35753664524711676
Accuracy: 0.8526708483695984


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 0.35762295961506
Accuracy: 0.8527589440345764


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.35761022846279655
Accuracy: 0.85270756483078


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 0.3575840202376702
Accuracy: 0.8527736067771912


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 0.35752148190713706
Accuracy: 0.8526781797409058


Left to right pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A1): 0.35760805626339043
Accuracy: 0.8527883291244507


Left to right pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Left loss (A2): 0.3576158118302777
Accuracy: 0.8526561856269836


Right to left pass (A2):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A2): 0.35757342716765134
Accuracy: 0.8527736067771912


Right to left pass (A1):   0%|          | 0/2129 [00:00<?, ?it/s]

Right loss (A1): 0.3575356110388983
Accuracy: 0.8527883291244507


False

In [5]:
# Final train/test accuracy. The forward passes cover the full datasets, so run
# them under inference_mode (as in the weight-scaling cell) to avoid building
# an autograd graph over every sample.
with torch.inference_mode():
    print("Train accuracy:")
    convergence_criterion(layer(xinp_train), y_train)
    print("Test accuracy:")
    convergence_criterion(layer(xinp_test), y_test)

Train accuracy:
Accuracy: 0.8527883291244507
Test accuracy:
Accuracy: 0.8393067121505737
