# Classification for Bars-Stripes dataset

Follow tutorial from $\rightarrow$ https://pennylane.ai/qml/demos/tutorial_tn_circuits/
- check on images 4x4, 8x8, 16x16
- on MPS - acc = 50%
- check on TTNs - **TODO**

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax.nn.initializers import *
import optax
from tn4ml.embeddings import *
from tn4ml.util import *
from tn4ml.models.model import *
from tn4ml.models.smpo import *
from tn4ml.initializers import *
from tn4ml.loss import *

In [None]:
def generate_bars_and_stripes(n_samples, height, width, noise_std):
    """Data generation procedure for 'bars and stripes'.

    Args:
        n_samples (int): number of data samples to produce
        height (int): number of pixels for image height
        width (int): number of pixels for image width
        noise_std (float): standard deviation of Gaussian noise added to the pixels
    """
    X = np.ones([n_samples, 1, height, width]) * -1
    y = []

    for i in range(len(X)):
        if np.random.rand() > 0.5:
            rows = np.where(np.random.rand(width) > 0.5)[0]
            X[i, 0, rows, :] = 1.0
            y.append([0, 1])
        else:
            columns = np.where(np.random.rand(height) > 0.5)[0]
            X[i, 0, :, columns] = 1.0
            y.append([1, 0])
        X[i, 0] = X[i, 0] + np.random.normal(0, noise_std, size=X[i, 0].shape)

    return X, np.array(y)

train_images, train_labels = generate_bars_and_stripes(1000, 4, 4, 0.5)
test_images, test_labels = generate_bars_and_stripes(200, 4, 4, 0.5)

In [None]:
n_classes=2

In [None]:
train_images = np.squeeze(train_images)
test_images = np.squeeze(test_images)

In [None]:
fig, axes = plt.subplots(ncols=4, figsize=(8,8))

axes[0].imshow(np.reshape(-train_images[0], (4,4)), cmap='gray')
axes[1].imshow(np.reshape(-train_images[4], (4,4)), cmap='gray')
axes[2].imshow(np.reshape(-train_images[6], (4,4)), cmap='gray')
axes[3].imshow(np.reshape(-train_images[3], (4,4)), cmap='gray')

In [None]:
# # Generate N samples
# N = 1000  # For example, generate 10 samples
# n_classes = 2
# train_size = int(N * 0.9)
# # train
# #train_data = sample_bars_and_stripes(int(N*0.9)) # 90% of the data for training
# # train_images = np.array([data[0] for data in train_data])
# # train_labels = np.array([data[1] for data in train_data])
# train_data = 

# # test
# #test_data = sample_bars_and_stripes(N)
# test_images = np.array([data[0] for data in test_data])
# test_labels = np.array([data[1] for data in test_data])

In [None]:
def visualize_patterns(dataset, num_patterns=10):
    fig, axes = plt.subplots(1, num_patterns, figsize=(num_patterns * 2, 2))
    for i, ax in enumerate(axes):
        pattern = dataset[i]
        ax.imshow(pattern, cmap='binary')
    plt.tight_layout()
    plt.show()
visualize_patterns(train_images, 16)


In [1]:
import quimb as qu

ttn = qu.experimental.merabuilder.TTN_randtree_rand(16, 5, phys_dim=2, group_size=2, iso=False, seed=None)

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


AttributeError: module 'quimb' has no attribute 'experimental'

**Define TN model**

In [None]:
# model parameters
L = 16
initializer = noise_init(1e-2, dtype=jnp.float64)
key = jax.random.key(42)
shape_method = 'noteven'
bond_dim = 4
phys_dim = (2, n_classes)
spacing = L

In [None]:
model = SMPO_initialize(L=L,
                        initializer=initializer,
                        key=key,
                        shape_method=shape_method,
                        spacing=spacing,
                        bond_dim=bond_dim,
                        phys_dim=phys_dim,
                        cyclic=False)

In [None]:
model

In [None]:
# training parameters
optimizer = optax.adam
strategy = 'global'
loss = loss_wrapper_optax(optax.softmax_cross_entropy)
train_type ='supervised'
embedding = basis_quantum_encoding(basis={0: np.array([1, 0]), 1: np.array([0, 1])})
learning_rate = 1e-3

# Exponential decay of the learning rate.
scheduler = optax.exponential_decay(
    init_value=0.0001,
    transition_steps=1000,
    decay_rate=0.99)

# Combining gradient transforms using `optax.chain`.
gradient_transforms = [
    optax.clip_by_global_norm(1.0),  # Clip by the gradient by the global norm.
    optax.scale_by_adam(),  # Use the updates from adam.
    optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
    # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
    optax.scale(-1.0)
]

In [None]:
model.configure(gradient_transforms=gradient_transforms, strategy=strategy, loss=loss, train_type=train_type, learning_rate=learning_rate)

In [None]:
epochs = 100
batch_size = 32

In [None]:
# early stopping from flax
from flax.training.early_stopping import EarlyStopping

earlystop = EarlyStopping(min_delta=0, patience=5)

In [None]:
history = model.train(train_images.reshape(train_images.shape[0], 16),
                    targets = train_labels,
                    epochs = epochs,
                    batch_size = batch_size,
                    embedding = embedding,
                    earlystop=earlystop,
                    normalize = True,

                    dtype = jnp.float64)

In [None]:
plt.figure()
plt.plot(range(len(history['loss'])), history['loss'], label='train')
#plt.plot(range(len(history['val_loss'])), history['val_loss'], label='val')
plt.legend()
plt.show()
#plt.savefig('tests/mnist_supervised_model6/loss.pdf')

**Evaluate**

In [None]:
from tn4ml.models.model import _batch_iterator

In [None]:
batch_size = 10
correct_predictions = 0; total_loss = 0

for batch_data in _batch_iterator(test_images.reshape(test_images.shape[0], 16), test_labels, batch_size=batch_size):
    x, y = batch_data
    x = jnp.array(x, dtype=jnp.float64)
    y = jnp.array(y)

    y_pred = jnp.squeeze(jnp.array(jax.vmap(model.predict, in_axes=(0, None, None))(x, embedding, False)[0]))
    y_pred
    predicted = jnp.argmax(y_pred, axis=-1)
    true = jnp.argmax(y, axis=-1)

    correct_predictions += jnp.sum(predicted == true).item() / batch_size

accuracy = correct_predictions / (len(test_images)//batch_size)
print(f"Accuracy: {accuracy}")