In [None]:
%%capture
!pip install haiku
!pip install --upgrade jax

In [None]:
import jax 
import haiku as hk 
from sklearn import datasets
from sklearn.model_selection import train_test_split
from jax import numpy as jnp

from src.light_vision_attention import VisionAttn

# Load the scikit-learn digits dataset as a flat feature matrix X and label vector Y.
X, Y = datasets.load_digits(n_class=10, return_X_y=True, as_frame=False)

# Stratified 80/20 split with a fixed seed so the split is reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, train_size=0.8, stratify=Y, random_state=123)

# Move everything onto JAX arrays as float32.
X_train, X_test, Y_train, Y_test = jnp.array(X_train, dtype=jnp.float32),\
                                   jnp.array(X_test, dtype=jnp.float32),\
                                   jnp.array(Y_train, dtype=jnp.float32),\
                                   jnp.array(Y_test, dtype=jnp.float32)

samples, features = X_train.shape
classes = jnp.unique(Y_test)  # distinct label values; len(classes) sizes the output layer

# load_digits pixel intensities are integers in 0..16 (not 0..255), so dividing
# by 16 is what maps the inputs onto [0, 1]; dividing by 255 squeezed them
# into roughly [0, 0.063].
PIXEL_MAX = 16.0

# Reshape each flat 64-feature row into an 8x8 single-channel image.
X_train = X_train.reshape(-1, 8, 8, 1) / PIXEL_MAX
X_test = X_test.reshape(-1, 8, 8, 1) / PIXEL_MAX

X_train.shape, X_test.shape, Y_train.shape, Y_test.shape

In [None]:
def VisionAttnFn(x):
    """Forward pass: VisionAttn backbone -> flatten -> linear head -> softmax.

    Intended to be wrapped with ``hk.transform``; returns softmax
    probabilities with ``len(classes)`` entries on the last axis.
    """
    # Modules are created in the same order as before so Haiku assigns the
    # same parameter names.
    # NOTE(review): the meaning of the positional arguments and the spelling of
    # `use_fask_attn` come from VisionAttn's own signature, which is not visible
    # here -- verify against src.light_vision_attention.
    backbone = VisionAttn(32, 64, 4, 2, 4, 16, 0.2, use_fask_attn=False)
    classifier = hk.Linear(len(classes))
    to_vector = hk.Flatten()

    flat_features = to_vector(backbone(x))
    logits = classifier(flat_features)
    return jax.nn.softmax(logits)

In [None]:
from jax import value_and_grad

# Hyperparameters for the training run.
epochs = 25
batch_size = 256
learning_rate = jnp.array(1e-4)

# Fixed PRNG key for reproducibility: the model is initialised with the same
# weights on every run.
rng = jax.random.PRNGKey(42)

# Turn the Haiku forward function into an init/apply pair and build the
# parameters from a small sample batch.
model = hk.transform(VisionAttnFn)
params = model.init(rng, X_train[:5])

def CrossEntropyLoss(weights, input_data, actual):
    """Summed categorical cross-entropy of the model on one batch.

    Args:
        weights: Haiku parameter tree handed to ``model.apply``.
        input_data: batch of inputs forwarded to the model.
        actual: class labels for the batch; one-hot encoded here against
            ``len(classes)`` categories.

    Returns:
        Scalar: the negative sum, over the whole batch, of the log-probability
        the model assigns to each true class (a sum, not a mean).
    """
    # NOTE(review): the same module-level `rng` is passed on every call, so any
    # stochastic layer inside the model receives an identical key at each
    # step -- confirm this is intended, or thread a fresh key per step.
    preds = model.apply(weights, rng, input_data)
    one_hot_actual = jax.nn.one_hot(actual, num_classes=len(classes))
    # The model outputs softmax probabilities, which can underflow to exactly 0.
    # log(0) = -inf and 0 * -inf = NaN would poison the loss and every gradient,
    # so bound the probabilities away from zero before taking the log.
    safe_preds = jnp.clip(preds, 1e-7, 1.0)
    log_preds = jnp.log(safe_preds)
    return - jnp.sum(one_hot_actual * log_preds)

def UpdateWeights(weights,gradients):
    """Apply one plain gradient-descent step to a single parameter leaf.

    Returns ``weights`` moved against ``gradients`` by the module-level
    ``learning_rate``.
    """
    step = learning_rate * gradients
    return weights - step

    
# Mini-batch gradient descent over the training set.
# Build the loss-and-gradient function once instead of once per batch.
loss_and_grad = value_and_grad(CrossEntropyLoss)
num_train = X_train.shape[0]

for epoch in range(1, epochs+1):
    losses = [] ## Record loss of each batch

    # Step through the data in strides of `batch_size`. Slicing past the end is
    # clipped, so the last batch is simply shorter, and no empty batch is
    # produced when the sample count is an exact multiple of `batch_size`.
    for start in range(0, num_train, batch_size):
        end = start + batch_size
        X_batch, Y_batch = X_train[start:end], Y_train[start:end] ## Single batch of data

        loss, param_grads = loss_and_grad(params, X_batch, Y_batch)
        # `jax.tree_map` was deprecated and dropped from newer JAX releases;
        # `jax.tree_util.tree_map` is the stable spelling of the same function.
        params = jax.tree_util.tree_map(UpdateWeights, params, param_grads) ## Update Params
        losses.append(loss) ## Record Loss

    # Mean of the per-batch (summed) losses for this epoch.
    print("CrossEntropy Loss : {:.3f}".format(jnp.array(losses).mean()))