### Not fully working yet!

In [1]:
import numpy as np
from random import randint
import orbax.checkpoint
import optax
import jax
from flax.training import train_state, orbax_utils
import jax.numpy as jnp
import jax.profiler
from tn4ml.models.lotenet import loTeNet
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# Enable 64-bit (double precision) types in JAX for the rest of the notebook.
jax.config.update("jax_enable_x64", True)

# Setup
- make sure you have JAX and CUDA versions set up correctly
- refer to:
    - [Jax for GPU](https://jax.readthedocs.io/en/latest/installation.html#nvidia-gpu)
    - [Jax releases with CUDA](https://storage.googleapis.com/jax-releases/jax_cuda_releases.html)


In [None]:
# Query the GPU devices known to JAX; the bare name on the last line
# displays the result as the cell output.
gpus_to_use = jax.devices('gpu')
gpus_to_use

In [None]:
# for saving checkpoints
# NOTE(review): machine-specific absolute path. The save/restore cells
# visible in this notebook use `save_dir` instead -- confirm whether this
# variable is still needed.
ckpt_dir = '/Users/emapuljak/WorkDirs/medical_study'

## Define training parameters

In [None]:
# Training hyper-parameters collected in a single dictionary.
config = {
    'batch_size': 4,
    'epochs': 5,
    'learning_rate': 1e-3,
    'seed': 42,
    'validation_split': 0.2,
}

## Define model parameters

In [None]:
# Model hyper-parameters passed to the loTeNet constructor.
config_model = {
    'kernel': 2,
    'output_dim': 10,  # number of classes
    'bond_dim': 5,
}
# The virtual dimension is tied to the bond dimension.
config_model['virtual_dim'] = config_model['bond_dim']

## Define dataset

In [None]:
# Number of samples kept from the train and test splits, respectively.
train_size, test_size = 12, 100

In [None]:
# Load the Fashion-MNIST train and test splits as (image, label) pairs
# for a first check of the pipeline.

train_ds, test_ds = tfds.load(
    'fashion_mnist',
    split=['train', 'test'],
    as_supervised=True,
    data_dir='/eos/user/e/epuljak/',
)

In [None]:
def zero_pad_image(image, target_size):
    """Zero pad a 2D or 3D image so that its shape becomes ``target_size``.

    The padding of every axis is split between both sides; when the amount
    is odd, the extra element goes after the data.

    Args:
        image: Array-like object exposing ``.shape`` with 2 or 3 dimensions.
        target_size: Sequence with one target length per image axis.

    Returns:
        A new ``numpy.ndarray`` of shape ``target_size`` containing ``image``
        surrounded by zeros.

    Raises:
        ValueError: If the image is not 2D or 3D, or if any target length is
            smaller than the corresponding image length.
    """
    if len(image.shape) not in (2, 3):
        # The original built this exception without raising it.
        raise ValueError("Image must be 2D or 3D!")

    current_size = np.array(image.shape)
    pad_width = np.asarray(target_size) - current_size
    if np.any(pad_width < 0):
        raise ValueError("Target size must not be smaller than the image size!")

    # Calculate padding for each side
    pad_before = pad_width // 2
    pad_after = pad_width - pad_before

    # One (before, after) pair per axis covers the 2D and 3D case alike.
    pad_pairs = tuple(zip(pad_before.tolist(), pad_after.tolist()))
    return np.pad(image, pad_pairs, mode='constant', constant_values=0)

In [None]:
def check_image_shape(image, k):
    """Zero pad ``image`` so that every dimension is a power of ``k``.

    Each axis length is rounded up to the smallest power of ``k`` that is
    not below it (a length of 0 or 1 is left unchanged). The powers are
    computed with integer arithmetic so the result is exact for every ``k``.

    Args:
        image: Array exposing ``.shape``.
        k: Integer base of the power, at least 2.

    Returns:
        ``image`` itself when its shape already matches, otherwise the
        zero-padded array returned by ``zero_pad_image``.

    Raises:
        ValueError: If ``k`` is smaller than 2.
    """
    if k < 2:
        raise ValueError("k must be at least 2!")

    new_shape = []
    for dim in image.shape:
        size = min(dim, 1)
        while size < dim:
            size *= k
        new_shape.append(size)

    # Compare as tuples: a list never equals a tuple, so the original
    # comparison was always true and every image was re-padded.
    if tuple(new_shape) != tuple(image.shape):
        image = zero_pad_image(image, new_shape)
    return image

In [None]:
# Convert every training sample to a NumPy image and collect it with its label.
train_images, train_labels = [], []
for image, labels in train_ds:
    # divide the pixel values by 255.0, then zero pad if needed
    image = check_image_shape(np.array(image) / 255.0, config_model['kernel'])
    train_images.append(image)
    train_labels.append(labels)

In [None]:
# Convert every test sample to a NumPy image and collect it with its label.
test_images, test_labels = [], []
for image, labels in test_ds:
    # divide the pixel values by 255.0, then zero pad if needed
    image = check_image_shape(np.array(image) / 255.0, config_model['kernel'])
    test_images.append(image)
    test_labels.append(labels)

In [None]:
# Keep only the first `train_size` training samples as NumPy arrays.
train_images = np.array(train_images[:train_size])
train_labels = np.array(train_labels[:train_size])

# Keep only the first `test_size` test samples as NumPy arrays.
test_images = np.array(test_images[:test_size])
test_labels = np.array(test_labels[:test_size])

In [None]:
import tensorflow as tf

# Images and labels are batched by two separate pipelines. Both use the
# same shuffle seed and buffer size.
# NOTE(review): image/label alignment relies on both shuffles producing the
# same order -- confirm.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(buffer_size=len(train_images), seed=1234)
    .batch(config['batch_size'], drop_remainder=True)
)
train_dataset_labels = (
    tf.data.Dataset.from_tensor_slices(train_labels)
    .shuffle(buffer_size=len(train_labels), seed=1234)
    .batch(config['batch_size'], drop_remainder=True)
)

## Define model loTeNet

In [None]:
# Build the loTeNet model; the input dimension is the shape of the first
# training image.
model = loTeNet(
    input_dim=train_images[0].shape,
    output_dim=config_model['output_dim'],
    bond_dim=config_model['bond_dim'],
    kernel=config_model['kernel'],
    virtual_dim=config_model['virtual_dim'],
)

## Training pipeline

In [None]:
# PRNG key built from the configured seed (later passed to `create_train_step`).
key = jax.random.PRNGKey(config['seed'])
# Adam optimiser with the configured learning rate.
optimiser = optax.adam(learning_rate=config['learning_rate'])

In [None]:
def create_train_step(key, model, optimiser, image_shape):
    """Initialise the model and build a jitted training step.

    The original helper wrapped ``jax.vmap``/``jax.pmap`` in a ``jax.jit``
    function that received a Python callable and a device list as
    arguments. ``jax.jit`` cannot trace such arguments, so the step could
    not run. The forward pass is now a plain ``jax.vmap`` over the batch
    axis, the same form the evaluation ``loss_fn`` in this notebook uses.

    NOTE: the step runs on a single device. Multi-device data parallelism
    would need the batch sharded across devices and the whole step wrapped
    in ``jax.pmap``; it is not attempted here.

    Args:
        key: PRNG key used to initialise the model parameters.
        model: Flax module providing ``init`` and ``apply``.
        optimiser: Optax gradient transformation stored in the train state.
        image_shape: Shape of a single (unbatched) input sample.

    Returns:
        A tuple ``(train_step, state)``. ``state`` is the initial
        ``TrainState`` and ``train_step(state, data, y_true)`` returns the
        updated state together with the batch loss.
    """
    # Dummy input for initialisation of the model parameters.
    dummy_input = jnp.ones(shape=image_shape)
    params = model.init(key, dummy_input)
    state = train_state.TrainState.create(apply_fn=model.apply,
                                          params=params['params'],
                                          tx=optimiser)
    # `apply_fn` does not change during training, so it is captured once.
    apply_fn = state.apply_fn

    def loss_fn(params, data, y_true):
        # vmap for batching: parameters are shared, data is mapped over axis 0.
        y_pred = jax.vmap(apply_fn, in_axes=(None, 0))({'params': params}, data)
        return optax.softmax_cross_entropy_with_integer_labels(y_pred, y_true).mean()

    @jax.jit
    def train_step(state, data, y_true):
        loss, grads = jax.value_and_grad(loss_fn)(state.params, data, y_true)
        state = state.apply_gradients(grads=grads)
        return state, loss

    return train_step, state

def model_train(n_epochs, train_dataset, targets, state, train_step_func):
    """Run the training loop and record the mean loss of every epoch.

    Args:
        n_epochs: Number of passes over the dataset.
        train_dataset: Batched inputs exposing ``as_numpy_iterator()``.
        targets: Batched labels exposing ``as_numpy_iterator()``; iterated
            in lockstep with ``train_dataset``.
        state: Initial training state handed to ``train_step_func``.
        train_step_func: Callable ``(state, data, y_true) -> (state, loss)``.

    Returns:
        A tuple ``(history, state)`` where ``history['loss']`` holds the mean
        batch loss of each epoch and ``state`` is the final training state.

    Raises:
        ValueError: If an epoch yields no batch at all (previously this
            surfaced as a ``ZeroDivisionError``).
    """
    history = {'loss': []}
    for epoch in range(n_epochs):
        loss_batch = 0
        batch_num = 0

        # Iterate lazily instead of materialising both datasets as lists.
        batches = zip(train_dataset.as_numpy_iterator(), targets.as_numpy_iterator())
        for batch_x, batch_y in batches:
            state, loss_curr = train_step_func(state, jnp.asarray(batch_x), jnp.asarray(batch_y))
            loss_batch += loss_curr
            batch_num += 1

        if batch_num == 0:
            raise ValueError(
                "No batches to train on: the datasets yielded no (input, target) pairs."
            )

        epoch_loss = loss_batch / batch_num
        print(f'Epoch: {epoch}, loss = {epoch_loss}')
        history['loss'].append(epoch_loss)
    return history, state

In [None]:
# Initialise the model state and build the training-step function for a
# single (unbatched) image shape.
train_step, state = create_train_step(key, model, optimiser, train_images[0].shape)

In [None]:
# Train for the configured number of epochs; returns the per-epoch loss
# history and the final training state.
history, state = model_train(config['epochs'], train_dataset, train_dataset_labels, state, train_step)

In [None]:
# Plot the recorded training loss against the epoch index.
loss_curve = history['loss']
epoch_index = range(len(loss_curve))
plt.figure()
plt.plot(epoch_index, loss_curve)
plt.show()

## Save model checkpoints 
- **TODO**: check whether this works

In [None]:
# Directory under which the checkpoint is written (here: the current
# working directory).
save_dir = '.'

In [None]:
# NOTE(review): `get_trace` does not appear in the documented `jax.profiler`
# API (which lists `start_trace`, `stop_trace` and `trace`) -- confirm this
# call exists in the installed JAX version.
trace_data = jax.profiler.get_trace()
print(trace_data)

In [None]:
# Bundle the final training state and the loss history into one checkpoint tree.
ckpt = {'model': state, 'history': history}

In [None]:
# Save the checkpoint tree to `<save_dir>/test_save` with Orbax.
# NOTE(review): depending on the Orbax version, `save` may require an
# absolute path and may refuse to overwrite an existing directory unless
# `force=True` is passed -- confirm before re-running this cell.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save(f'{save_dir}/test_save', ckpt, save_args=save_args)

In [None]:
# restore checkpoints
# No target structure is passed; the cells below index the result like a
# nested dict (`raw_restored['model']['params']`).
raw_restored = orbax_checkpointer.restore(f'{save_dir}/test_save')

In [None]:
# Inspect the top-level parameter names of the restored model.
raw_restored['model']['params'].keys()

In [None]:
def loss_fn(params, data, y_true):
    """Return the mean softmax cross-entropy of the model over a batch.

    The forward pass uses ``state.apply_fn`` from the enclosing notebook
    scope, mapped over the leading axis of ``data`` with shared parameters.

    Args:
        params: Model parameters, wrapped as ``{'params': params}`` for the call.
        data: Batch of inputs; axis 0 is the batch axis.
        y_true: Integer class labels for the batch.

    Returns:
        The cross-entropy loss averaged over the batch.
    """
    batched_apply = jax.vmap(state.apply_fn, in_axes=(None, 0))
    logits = batched_apply({'params': params}, data)
    per_sample = optax.softmax_cross_entropy_with_integer_labels(logits, y_true)
    return per_sample.mean()

In [None]:
# Evaluate the restored parameters on the test subset; the loss is shown as
# the cell output.
loss_fn(raw_restored['model']['params'], test_images, test_labels)