In [None]:
import numpy as np
import numpy as np
# np.random.seed(420)
import jax.numpy as jnp
import functools, itertools
import matplotlib.pyplot as plt
from IPython.display import clear_output
import time

import os
import sys

# Directory that is appended to sys.path before the project-local import below
# (presumably where `tenmul7` lives — confirm).  The location can be overridden
# per machine with the TNL_MODULE_PATH environment variable; when the variable
# is unset the original hard-coded default is used, so existing setups behave
# exactly as before.
module_path = os.environ.get('TNL_MODULE_PATH', '/home/chao/Pcode/TNL_2024/')
# Add the directory to sys.path (only once, so re-running the cell is idempotent)
if module_path not in sys.path:
    sys.path.append(module_path)

from tenmul7 import NeuroTN


In [None]:

def generate_TT_adj_matrix(order, rank, dim_mode):
    """Build the symmetric adjacency matrix of a chain-structured network.

    Args:
        order: Number of nodes; the result is an ``order x order`` matrix.
        rank: Value written on the first off-diagonals, i.e. on the entry
            linking node ``i`` with node ``i + 1``.  Anything broadcastable
            to shape ``(order - 1,)`` is accepted.
        dim_mode: Value written on the main diagonal (passed straight to
            ``np.fill_diagonal``).

    Returns:
        ``np.ndarray`` with ``dim_mode`` on the diagonal, ``rank`` on the
        super- and sub-diagonal and zeros everywhere else.
    """
    bond_ranks = np.full((order - 1,), rank)
    # Strictly upper-triangular chain; mirroring it yields the symmetric form.
    upper = np.diag(bond_ranks, k=1)
    adjacency = upper + upper.T
    np.fill_diagonal(adjacency, dim_mode)
    return adjacency

def generate_TR_adj_matrix(order, rank, dim_mode):
    """Build the symmetric adjacency matrix of a ring-structured network.

    The matrix is the chain layout (node ``i`` linked to node ``i + 1``) plus
    one extra link between the first and the last node that closes the ring.

    Args:
        order: Number of nodes; the result is an ``order x order`` matrix.
        rank: Either a scalar (every link gets that value) or a 1-D sequence
            of exactly ``order`` values.  In the sequence form the first
            ``order - 1`` entries are the chain links ``(i, i + 1)`` and the
            last entry is the closing link ``(0, order - 1)``.
        dim_mode: Value written on the main diagonal (passed straight to
            ``np.fill_diagonal``).

    Returns:
        ``np.ndarray`` of shape ``(order, order)``.

    Raises:
        ValueError: If ``order`` is smaller than 1, or if ``rank`` is a
            sequence that is not 1-D with exactly ``order`` entries.
    """
    if order < 1:
        raise ValueError("order must be a positive integer")

    ranks = np.asarray(rank)
    if ranks.ndim == 0:
        # Scalars (including 0-d arrays, which np.isscalar does not recognise)
        # are expanded so both input forms share the code path below.
        ranks = np.full((order,), rank)
    elif ranks.ndim != 1 or ranks.shape[0] != order:
        # A wrong-length sequence would otherwise give a matrix of the wrong
        # size with the closing link written at an unrelated position.
        raise ValueError(
            "rank must be a scalar or a 1-D sequence with exactly `order` entries"
        )

    adjm = np.diag(ranks[:-1], 1)
    # Closing link of the ring (for order == 2 it coincides with the chain link).
    adjm[0, order - 1] = ranks[-1]
    adjm = adjm + adjm.transpose()
    np.fill_diagonal(adjm, dim_mode)

    return adjm

def index_to_onehot(indices, num_class):
    """One-hot encode a 2-D table of class indices.

    Args:
        indices: 2-D array-like (nested list/tuple or array) of integer class
            indices with shape ``(N, M)``.
        num_class: Number of classes, i.e. the length of every one-hot vector.

    Returns:
        Float ``np.ndarray`` of shape ``(N, M, num_class)`` in which
        ``out[n, m, indices[n, m]] == 1`` and every other entry is 0.

    Raises:
        ValueError: If ``indices`` is not two-dimensional.
        IndexError: If any index lies outside ``[0, num_class)``.
    """
    # Convert unconditionally so tuples and other array-likes work, not only lists.
    idx = np.asarray(indices)

    if idx.ndim != 2:
        raise ValueError("indices must be a 2D list or array")

    if idx.size == 0:
        # Empty nested lists are typed as float by NumPy, which cannot be used
        # for indexing; there is nothing to mark, so an integer cast is safe.
        idx = idx.astype(np.intp)
    elif idx.min() < 0 or idx.max() >= num_class:
        # Without this check negative values would wrap around silently and
        # flag the wrong class instead of failing.
        raise IndexError(
            "indices must lie in the range [0, num_class)"
        )

    N, M = idx.shape

    one_hot_encoded = np.zeros((N, M, num_class), dtype=float)

    # Advanced indexing: the row and column grids broadcast against idx, so a
    # single assignment sets exactly one entry per (n, m) position.
    one_hot_encoded[np.arange(N)[:, None], np.arange(M), idx] = 1

    return one_hot_encoded



In [None]:
# Parameters
order_tensor = 8 # Order of the tensor
rank_tensor = 10 # Rank of the tensor
dim_tensor = 2 # Mode dimension
percentage_of_train = 0.7
percentage_of_test = 0.3
alpha = 0.5

# The two fractions may add up to exactly 1; a small tolerance keeps
# floating-point round-off in the sum from triggering a spurious error.
if percentage_of_train + percentage_of_test > 1 + 1e-9:
    raise ValueError('The total percentage should not exceed 1')

# Use NeuroTN to generate tensor
adjm = generate_TT_adj_matrix(order_tensor,rank_tensor,dim_tensor)
print('adjm_data:\n', adjm)


# Data generation
# NOTE: np.random is not seeded in this cell, so every run draws different
# cores and a different train/test permutation.
output_dim = np.array([0] * (order_tensor-1)+[0])
# Zero-mean normal initializer with standard deviation alpha / sqrt(rank_tensor).
init_TN = functools.partial(np.random.normal, loc=0.0, scale=alpha/(np.sqrt(rank_tensor)))
# init_TN = functools.partial(np.random.normal, loc=0.0, scale=1)
DATA =  NeuroTN(adjm, output_dim, activation=lambda x:x, initializer = init_TN, core_mode=2)

# siz_data, siz_cores = compression_ratio(adjm, output_dim)

# print('siz_data:', siz_data)
# print('siz_cores:', siz_cores)

# Enumerate every index tuple of the tensor (dim_tensor ** order_tensor rows)
# and one-hot encode each mode index.
idx_data = [list(combo) for combo in itertools.product(range(dim_tensor), repeat=order_tensor)]
idx_onehot = index_to_onehot(idx_data, num_class=dim_tensor)
# Presumably one value per enumerated index row — confirm against NeuroTN.
values = DATA.network_contraction(idx_onehot, return_contraction=True)

permuted_idx = np.random.permutation(idx_onehot.shape[0])
length_training = int(len(permuted_idx)*percentage_of_train)
# Truncating both sizes independently used to drop a sample (179 + 76 = 255 of
# 256 for a 0.7 / 0.3 split).  The test size is rounded instead and capped at
# what is left after the training split, so the two splits never overlap.
length_test = min(int(round(len(permuted_idx)*percentage_of_test)),
                  len(permuted_idx) - length_training)

data_training = idx_onehot[permuted_idx[:length_training]]
values_training = values[permuted_idx[:length_training]]
data_test = idx_onehot[permuted_idx[length_training:(length_test+length_training)]]
values_test = values[permuted_idx[length_training:(length_test+length_training)]]

print(len(permuted_idx))
print(length_training)

print('data_test.shape', data_test.shape)
print('values_test.shape', values_test.shape)
print('Mean of data_train', jnp.mean(values_training))
print('Variance of data_train', jnp.var(values_training))

# hist, bin_edges = np.histogram(values_training, bins=50)

# plt.hist(values_training, bins=10, edgecolor='black')
# plt.show()

In [None]:
# Optimisation settings.  `size_data` counts every enumerated index row of
# `idx_onehot` (training and test rows alike); `batch_size` is set to the same
# number.
size_data = len(idx_onehot)
batch_size = size_data
learning_rate = 1e-5

# Model network: chain-shaped adjacency matrix with bond value 300 on the
# off-diagonals and the same mode dimension as the data.
adjm_decomp = generate_TT_adj_matrix(order=order_tensor, rank=300, dim_mode=dim_tensor)
# Plain alias: output_decomp is the very same array object as output_dim.
output_decomp = output_dim

print('adjm_decomp:\n', adjm_decomp)
print('========================')

TN = NeuroTN(
    adjm_decomp,
    output_decomp,
    activation=lambda x: x,
    initializer=init_TN,
    core_mode=2,
)
TN.target_shape = None

# print(data_training.shape)
# print(values_training.shape)
# print(type(data_training))


In [None]:
%matplotlib inline
# Histories filled by the loop below:
#   loss_training - one entry per epoch (return value of TN.iteration)
#   loss_test     - one entry per monitoring step (every 100 epochs)
#   eign_values   - one eigenvalue array per monitoring step
loss_training = []
loss_test = []
eign_values = []
for epoch in range(100000):
    # One optimisation step on the full training split; whatever TN.iteration
    # returns is recorded as the training loss for this epoch.
    loss_training.append(TN.iteration(learning_rate, data_training, values_training, verbose=True))
    if epoch % 100 == 0:
        # Monitoring step: recompute the NTK on the training inputs and
        # eigendecompose its first entry.
        # NOTE(review): ntk_tensor[0] is assumed to be a symmetric matrix, as
        # required by eigh — confirm against NeuroTN.ntk.
        ntk_tensor = TN.ntk(data_training, opt_path='optimal')
        eign_value, eign_vector = jnp.linalg.eigh(ntk_tensor[0])
        eign_values.append(eign_value)
        # Replace the previous figure/printout instead of stacking outputs.
        clear_output(wait=True)
        plt.subplot(2, 2, 1)
        # Every stored spectrum is redrawn on each refresh, normalised by its
        # largest eigenvalue so each curve tops out at log(1) = 0.
        for ev in eign_values:
            plt.plot(jnp.log(ev/jnp.max(ev)))
        plt.title('eignValue of NTK')

        plt.subplot(2,2,3)
        # print(eign_vector[:,-1].shape)
        # eigh returns eigenvalues in ascending order, so the last column is
        # the eigenvector belonging to the largest eigenvalue.
        plt.plot(eign_vector[:,-1])
        plt.title('1st eignVector')
        

        # Test loss: squared error summed over all non-batch axes of each
        # sample, then averaged over the test samples.
        predicted = TN.network_contraction(data_test, return_contraction=True)
        loss_test.append(np.mean(np.sum(np.square(predicted - values_test).reshape(predicted.shape[0],-1), axis=-1)))
        plt.subplot(2,2,2)
        # x-axis is the epoch index (one point per epoch).
        plt.plot(np.log(loss_training))
        plt.title('Training loss (log)')
        plt.subplot(2,2,4)
        # x-axis is the monitoring-step index (one point per 100 epochs), so it
        # is not on the same scale as the training-loss panel.
        plt.plot(np.log(loss_test))
        plt.title('Test loss (log)')
        plt.show()
        print('Epoch: ', epoch, 'Training Loss: ', loss_training[-1], '; Testing Loss: ', loss_test[-1])