In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
from typing import Callable, Tuple, Any
import numpy as np
from functools import partial
import pdb
import tensorflow as tf

import matplotlib.pyplot as plt
import seaborn

In [None]:
# Load the MNIST dataset, scale pixels to [0, 1] and append a trailing axis:
# images become (N, 28, 28, 1) and labels (N, 1).
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
X_train = jnp.expand_dims(jnp.array(X_train / 255), axis=-1)
Y_train = jnp.expand_dims(jnp.array(Y_train), axis=-1)
X_test = jnp.expand_dims(jnp.array(X_test / 255), axis=-1)
Y_test = jnp.expand_dims(jnp.array(Y_test), axis=-1)

In [None]:
jnp.shape(Y_train)  # labels carry a trailing singleton axis

(60000, 1)

In [None]:
@jax.jit
def get_sample_and_label(idx, inputs, labels):
    """Return ``(inputs[idx], labels[idx])``.

    In this notebook it is wrapped in ``jax.vmap(..., in_axes=(0, None, None))``
    so a vector of indices gathers a whole batch while the full ``inputs`` /
    ``labels`` arrays are shared across the batch.
    """
    sample = inputs[idx]
    label = labels[idx]
    return sample, label

# Gather four training examples at once to check the batched lookup shapes.
arr = jnp.array([32, 67, 432, 2])
batched_lookup = jax.vmap(get_sample_and_label, in_axes=(0, None, None))
xb, yb = batched_lookup(arr, X_train, Y_train)
xb.shape, yb.shape

((4, 28, 28, 1), (4, 1))

In [None]:
class S6_Unet(nn.Module):
    """Recurrent state-space-style block that scans over the sequence axis.

    The input ``x`` must be 3-D, ``(batch, seq_len, hidden_dim)`` (enforced by
    the shape unpacking in ``__call__``). Only the final state is read out,
    so the output is ``(batch, 1, hidden_dim)``.

    Attributes:
        embed_dim: size E of the per-feature state vector.
        n_latent_dim: number N of parallel latent channels; they are summed
            at the read-out.
    """
    embed_dim: int = 64
    n_latent_dim: int = 8

    @nn.compact
    def __call__(self, x, h=0):
        """Run the recurrence over ``x``.

        Args:
            x: array of shape (batch, seq_len, hidden_dim).
            h: initial state. It is multiplied by a ones array of shape
                (batch, n_latent_dim, hidden_dim, embed_dim), so it must
                broadcast to that shape; the default 0 gives a zero state.

        Returns:
            Array of shape (batch, 1, hidden_dim).
        """
        batch_size, _, hidden_dim = x.shape
        N, E = self.n_latent_dim, self.embed_dim
        # Parameter names, shapes and creation order are unchanged:
        # A, B, C, D, S.
        A = -1 * self.param('A', nn.initializers.ones, (1, N, E, E))
        B = self.param('B', nn.initializers.ones, (1, N, 1, E))
        C = self.param('C', jax.random.normal, (1, N, E, 1))
        D = self.param('D', jax.random.normal, (1, N, hidden_dim, 1))
        S = -0.5 * self.param('S', nn.initializers.ones, (1, N, 1, E))
        h = h * jnp.ones((batch_size, N, hidden_dim, E))
        return self.ssm(x, A, B, C, D, S, h)

    def ssm(self, x, A, B, C, D, S, h):
        """Unrolled recurrence over the sequence axis of ``x``.

        Shapes (Ba=batch, L=seq_len, N=n_latent_dim, H=hidden_dim, E=embed_dim):
            x: (Ba, L, H)    A: (1, N, E, E)   B: (1, N, 1, E)
            C: (1, N, E, 1)  D: (1, N, H, 1)   S: (1, N, 1, E)
            h: (Ba, N, H, E)

        Returns:
            Array of shape (Ba, 1, H).
        """
        def compute_delta(x_k, state, A, B):
            # State term: (Ba, N, H, E) @ (1, N, E, E) -> (Ba, N, H, E),
            # then RMS-normalised. A new RMSNorm submodule is created on every
            # call, i.e. one per sequence position.
            state_contribution = nn.RMSNorm()(state @ A)
            # Input term: (Ba, 1, H, 1) @ (1, N, 1, E) -> (Ba, N, H, E).
            input_contribution = x_k @ B
            return state_contribution + input_contribution

        # Python loop, unrolled at trace time: one update per sequence position.
        for k in range(x.shape[1]):
            x_k = jnp.expand_dims(x[:, k, :], axis=(1, 3))  # (Ba, 1, H, 1)
            delta = compute_delta(x_k, h, A, B)
            # (Ba, N, H, E) * (Ba, 1, H, 1) * (1, N, 1, E) -> (Ba, N, H, E)
            h += delta * x_k * S

        # Read-out: (Ba, N, H, E) @ (1, N, E, 1) + (Ba, 1, H, 1) * (1, N, H, 1)
        # -> (Ba, N, H, 1). The skip term uses only the last sequence position.
        y = h @ C + jnp.expand_dims(x[:, -1, :], axis=(1, 3)) * D
        # Sum over the latent axis N and move the singleton axis to position 1:
        # (Ba, N, H, 1) -> (Ba, 1, H). Letter subscripts only; digits such as
        # '1' are not valid einsum subscripts in NumPy.
        return jnp.einsum('bndk->bkd', y)

In [None]:
class VisionEncoder(nn.Module):
    """Conv feature extractor -> S6_Unet sequence block -> 10-way linear head.

    Three conv + max-pool stages shrink the image. The remaining spatial grid
    is flattened into a sequence and passed through ``S6_Unet``, which keeps
    only its final state. The result is projected to 10 logits. With the
    default hyper-parameters, a (batch, 28, 28, 1) input gives a
    (batch, 3, 3, 128) feature map, a length-9 sequence, and logits of shape
    (batch, 1, 10).

    Attributes:
        embed_dim: state size forwarded to ``S6_Unet`` as its ``embed_dim``.
        patch_size: used both as the conv kernel size and as the max-pool
            window / stride.
        hidden_dim: conv width of the first stage; stage i uses
            ``hidden_dim * 2**i`` features.
        verbose: print intermediate shapes. Under ``jax.jit`` this happens
            only at trace time.
    """
    embed_dim: int = 64
    patch_size: int = 2
    hidden_dim: int = 32
    verbose: bool = True

    @nn.compact
    def __call__(self, x):
        """Map images of shape (batch, height, width, channels) to logits."""
        window = (self.patch_size, self.patch_size)

        # Convolutional feature extraction; each stage doubles the width.
        for stage in range(3):
            x = nn.Conv(
                features=self.hidden_dim * (2 ** stage),
                kernel_size=window,
                strides=(1, 1),
            )(x)
            x = nn.max_pool(x, window_shape=window, strides=window)
        if self.verbose:
            print("After conv", x.shape)

        # Flatten the spatial grid into a sequence: (batch, H*W, channels).
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        if self.verbose:
            print("After reshape", x.shape)

        x = S6_Unet(embed_dim=self.embed_dim)(jax.nn.silu(nn.RMSNorm()(x)))
        if self.verbose:
            print("After S6", x.shape)

        return nn.Dense(10)(x)

# Initialise the vision model on a single dummy 28x28 grayscale image.
model = VisionEncoder()
sample_input = jnp.ones((1, 28, 28, 1))
params = model.init(jax.random.PRNGKey(451), sample_input)
# Count every scalar in the parameter pytree.
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"Total number of parameters: {n_params:_}")


After conv (1, 3, 3, 128)
After reshape (1, 9, 128)
After S6 (1, 1, 128)
Total number of parameters: 78_634


In [None]:
jnp.shape(model.apply(params, jnp.ones((1, 28, 28, 1))))  # logits shape for one image

After conv (1, 3, 3, 128)
After reshape (1, 9, 128)
After S6 (1, 1, 128)


(1, 1, 10)

In [None]:
# AdamW with a fixed learning rate of 1e-3.
# NOTE(review): the `learning_rate` / `momentum` variables defined in the
# training cell are not used by this optimizer.
opt = optax.adamw(learning_rate=1e-3)

opt_state = opt.init(params)

In [None]:
def loss_fun(params, x, y):
    """Mean softmax cross-entropy and accuracy of the global ``model`` on a batch.

    Args:
        params: parameter pytree passed to ``model.apply``.
        x: input batch passed to ``model.apply``.
        y: integer labels; ``optax.softmax_cross_entropy_with_integer_labels``
            expects their shape to equal ``logits.shape[:-1]``.

    Returns:
        ``(loss, accuracy)``. ``loss`` is the differentiable output;
        ``accuracy`` is auxiliary (the training step uses ``has_aux=True``).

    NOTE(review): this reads the module-level ``model``. Rebinding ``model``
    (a later cell assigns a ``TextEncoder`` to it) changes what this computes.
    """
    logits = model.apply(params, x)
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == y)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=y).mean()
    return loss, accuracy


@jax.jit
def eval_step(params, x, y):
    """Jitted evaluation returning the same ``(loss, accuracy)`` as ``loss_fun``.

    Delegates to ``loss_fun`` instead of duplicating its body, so training and
    evaluation metrics cannot drift apart.
    """
    return loss_fun(params, x, y)

In [None]:
%%time
key = jax.random.PRNGKey(42)  # Replace 42 with any random seed
BATCH_SIZE = 64
learning_rate = 0.1
momentum = 0.9
train_length = len(X_train)//BATCH_SIZE
kernel_s = 2
eval_iters = 1000

all_train_losses = []
all_eval_losses = []

all_train_accuracy =  []
all_test_accuracy = []

eval_iters = 50000

@jax.jit
def step(idx, params, opt_state):
    xb, yb = jax.vmap(get_sample_and_label, in_axes=(0,None,None))(idx, X_train, Y_train)
    (loss, train_accuracy), grad = jax.value_and_grad(loss_fun, has_aux=True)(params, xb, yb)
    updates, opt_state = opt.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, train_accuracy

for i in range(eval_iters):
    key, subkey = jax.random.split(key)
    iix = jax.random.randint(subkey, shape=(BATCH_SIZE,), minval=0, maxval=len(X_train))
    params, opt_state, loss, train_accuracy = step(iix, params, opt_state)

    # once every N_FREQ_EVAL we compute loss on the validation set
    if i%100 == 0:
        key, subkey = jax.random.split(key)
        indxs = jax.random.randint(subkey, shape=(BATCH_SIZE,), minval=0, maxval=len(X_test))
        xt, yt = jax.vmap(get_sample_and_label, in_axes=(0,None,None))(indxs, X_test, Y_test)
        eval_loss, eval_accuracy = eval_step(params, xt, yt)
        all_train_losses.append(loss)
        all_eval_losses.append(eval_loss)
        all_train_accuracy.append(train_accuracy)
        all_test_accuracy.append(eval_accuracy)
        print('####################################################################################################')
        print("Step: ", i,"\t\t Train Loss: ", format(loss, ".6f"),"\t\t Train Accuracy: ", format(train_accuracy, ".2%"))
        print("Step: ", i,"\t\t Eval Loss: ", format(eval_loss, ".6f"),"\t\t Eval Accuracy: ", format(eval_accuracy, ".2%"))

After conv (64, 3, 3, 128)
After reshape (64, 9, 128)
After S6 (64, 1, 128)
After conv (64, 3, 3, 128)
After reshape (64, 9, 128)
After S6 (64, 1, 128)
####################################################################################################
Step:  0 		 Train Loss:  626.939819 		 Train Accuracy:  3.12%
Step:  0 		 Eval Loss:  418.483856 		 Eval Accuracy:  4.69%
####################################################################################################
Step:  100 		 Train Loss:  1.832799 		 Train Accuracy:  34.38%
Step:  100 		 Eval Loss:  1.944350 		 Eval Accuracy:  40.62%
####################################################################################################
Step:  200 		 Train Loss:  2.074398 		 Train Accuracy:  29.69%
Step:  200 		 Eval Loss:  2.050316 		 Eval Accuracy:  28.12%
####################################################################################################
Step:  300 		 Train Loss:  2.727752 		 Train Accuracy:  15.62%
Step:  300 

In [None]:
# Plot loss and accuracy in subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax2.set_title('Accuracy')

# Metrics are recorded once every 100 training steps by the training loop;
# keep this in sync with that interval so the x-axis is in training steps.
log_interval = 100
eval_steps = np.arange(len(all_train_losses)) * log_interval

ax1.plot(eval_steps, all_train_losses, label='train_loss')
ax1.plot(eval_steps, all_eval_losses, label='eval_loss')
# The first recorded losses can be orders of magnitude above later ones.
ax1.set_yscale('log')
ax1.set(xlabel='training step', ylabel='cross-entropy (log scale)')

ax2.plot(eval_steps, all_train_accuracy, label='train_accuracy')
ax2.plot(eval_steps, all_test_accuracy, label='eval_accuracy')
ax2.set(xlabel='training step', ylabel='accuracy')

ax1.legend()
ax2.legend()
plt.show()
# Release this figure; plt.clf() would instead clear (or create) a new current figure.
plt.close(fig)

In [None]:
# Deliberate stop (presumed intent of this cell): halts "Run All" here so the
# experimental TextEncoder cells below only run when executed manually.
raise RuntimeError("Execution stopped on purpose; run the cells below manually.")

NameError: name 'asfdghywqdw' is not defined

In [None]:
# Define the Text Encoder
class TextEncoder(nn.Module):
    """Token embedding -> S6_Unet -> stack of attention + LayerNorm -> Dense.

    Attributes:
        embed_dim: token embedding width; also the attention qkv width and
            the output width.
        vocab_size: number of distinct token ids accepted by the embedding.
        max_length: NOTE(review): currently unused by ``__call__``.
        num_layers: number of attention + LayerNorm layers (default 6).
        num_heads: attention heads per layer (default 8); Flax requires it to
            divide ``embed_dim``.
    """
    embed_dim: int = 128
    vocab_size: int = 10
    max_length: int = 77
    num_layers: int = 6
    num_heads: int = 8

    @nn.compact
    def __call__(self, x):
        """Encode integer token ids of shape (batch, sequence_length).

        Returns:
            Array of shape (batch, embed_dim).
        """
        x = nn.Embed(self.vocab_size, self.embed_dim)(x)  # (batch, seq, embed_dim)

        # S6_Unet keeps only its final state: (batch, seq, d) -> (batch, 1, d).
        x = S6_Unet()(nn.RMSNorm()(x))

        # NOTE(review): after S6_Unet the sequence has length 1, so each
        # attention layer attends over a single position. There are also no
        # residual connections around the attention layers.
        for _ in range(self.num_layers):
            x = nn.MultiHeadDotProductAttention(
                num_heads=self.num_heads,
                qkv_features=self.embed_dim
            )(x)
            x = nn.LayerNorm()(x)

        # Project the single remaining position. No CLS token is prepended;
        # x[:, 0] is simply the S6 summary after the attention stack.
        x = nn.Dense(self.embed_dim)(x[:, 0])
        return x

In [None]:
key = jax.random.PRNGKey(42)
# Dummy batch: 64 sequences of 20 tokens, all with token id 1.
# int32, not int64: in JAX's default 32-bit mode int64 is unavailable, and
# requesting it only triggers a truncation warning before producing int32.
x = jnp.ones((64, 20), dtype=jnp.int32)

model = TextEncoder()

params = model.init(jax.random.PRNGKey(45), x)
print(params.keys())
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))

print(f"Total number of parameters: {n_params:_}")

output = model.apply(params, x)
print(output.shape)

In [None]:
jnp.shape(params['params']['Embed_0']['embedding'])  # (vocab_size, embed_dim) of the token table