In [1]:
from functools import partial   
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union
from flax import linen as nn
from flax.linen.activation import sigmoid
from flax.linen.activation import tanh
from flax.linen.initializers import zeros
from flax.linen.linear import default_kernel_init
from flax.linen.initializers import orthogonal
import jax.numpy as jnp
import numpy as np
import jax

PRNGKey = Any
Shape = Tuple[int]
Dtype = Any  # this could be a real type?
Array = Any

class LSTMCell(nn.Module):  
  gate_fn: Callable[..., Any] = sigmoid
  activation_fn: Callable[..., Any] = tanh
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = orthogonal()
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32

  @nn.compact
  def __call__(self, carry, inputs):
    c, h = carry
    hidden_features = h.shape[-1]
    dense_h = partial(nn.Dense,
                      features=hidden_features,
                      use_bias=True,
                      kernel_init=self.recurrent_kernel_init,
                      bias_init=self.bias_init,
                      dtype=self.dtype,
                      param_dtype=self.param_dtype)
    dense_i = partial(nn.Dense,
                      features=hidden_features,
                      use_bias=False,
                      kernel_init=self.kernel_init,
                      dtype=self.dtype,
                      param_dtype=self.param_dtype)
    i = self.gate_fn(dense_i(name='ii')(inputs) + dense_h(name='hi')(h))
    f = self.gate_fn(dense_i(name='if')(inputs) + dense_h(name='hf')(h))
    g = self.activation_fn(dense_i(name='ig')(inputs) + dense_h(name='hg')(h))
    o = self.gate_fn(dense_i(name='io')(inputs) + dense_h(name='ho')(h))
    new_c = f * c + i * g
    new_h = o * self.activation_fn(new_c)
    return (new_c, new_h), new_h

  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=zeros):
    key1, key2 = jax.random.split(rng)
    mem_shape = batch_dims + (size,)
    return init_fn(key1, mem_shape), init_fn(key2, mem_shape)


In [2]:
class SimpleScan(nn.Module):
    @nn.compact
    def __call__(self, xs,is_training:bool=True):
        LSTM = nn.scan(LSTMCell,
                    variable_broadcast="params",
                    split_rngs={"params": False},
                    in_axes=1,
                    out_axes=1)            
       
        ch = LSTMCell.initialize_carry(jax.random.PRNGKey(0), (xs.shape[0],), 100)
        ch, xs=LSTM()(ch, xs)                
        # xs=nn.Dropout(rate=0.2,deterministic=not is_training)(xs)

        k=jnp.einsum("btv,bvs->bts",xs,xs.transpose(0,2,1))
        k=jax.nn.softmax(k/xs.shape[-1]**0.5,axis=1)
        xs=jnp.einsum("btv,bvs->bts",k,xs)

        ch = LSTMCell.initialize_carry(jax.random.PRNGKey(0), (xs.shape[0],), 200)
        ch, xs=LSTM()(ch, xs)  
        # xs=nn.Dropout(rate=0.2,deterministic=not is_training)(xs)

        k=jnp.einsum("btv,bvs->bts",xs,xs.transpose(0,2,1))
        k=jax.nn.softmax(k/xs.shape[-1]**0.5,axis=1)
        xs=jnp.einsum("btv,bvs->bts",k,xs)

        ch = LSTMCell.initialize_carry(jax.random.PRNGKey(0), (xs.shape[0],), 10)
        ch, xs=LSTM()(ch, xs)        
        # xs=nn.Dropout(rate=0.2,deterministic=not is_training)(xs)
        
        # ch = LSTMCell.initialize_carry(jax.random.PRNGKey(0), (xs.shape[0],), 200)
        # ch, xs=LSTM()(ch, xs)
        # xs=nn.Dropout(rate=0.2,deterministic=not is_training)(xs)
        
        xs=xs[:,-1,:]
        # xs=nn.Dense(features=10)(xs)        
        return xs

lstm = SimpleScan()
key_1, key_2, key_3 = jax.random.split(jax.random.PRNGKey(0), 3)
variables=lstm.init({'params': key_1, 'dropout':key_1}, jnp.ones([100, 28, 28]))

In [3]:
@jax.jit
def apply_model(state, images, labels,dropout_rng):
  def loss_fn(params,labels):
    logits= state.apply_fn({'params': params}, images,is_training=True,rngs={'dropout':dropout_rng})
    
    labels=jax.nn.one_hot(labels,10)
    loss = jnp.mean(jnp.sum(0.5*(logits-labels)**2,axis=-1))   
    
    return loss, (logits)    

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits)), grads = grad_fn(state.params,labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)  
  return grads, loss, accuracy,logits

@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)

In [4]:
def create_train_state(rng):
  lstm = SimpleScan()
  key_1, key_2, key_3 = jax.random.split(jax.random.PRNGKey(0), 3)
  variables=lstm.init({'params': key_1, 'dropout':key_1}, jnp.ones([1, 28, 28]))
  params = variables['params']
  
  import optax
  from flax.training import train_state
  
  tx = optax.adam(learning_rate=0.01)
  state=train_state.TrainState.create(apply_fn=lstm.apply, params=params, tx=tx)
  return state,variables

In [5]:
@jax.jit
def predict(state, variables,image_i):  
  logits= state.apply_fn({'params': state.params},image_i,is_training=False)
  return logits
def test(state, batch_stats,test_ds):
  images = test_ds['image']
  labels = test_ds['label']
  images=images.reshape(-1,28,28)
  batchs=1000
  accuracy=0
  for i in range(0,len(images),batchs):
    image_i=images[i:i+batchs]
    label_i=labels[i:i+batchs]
    logits= predict(state, batch_stats,image_i)
    accuracy += jnp.sum(jnp.argmax(logits, -1) == label_i)  
  return accuracy/len(images)

In [6]:
def train_epoch(state, train_ds,  rng):
  batch_size=1000
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []
  for perm in perms:
    batch_images = train_ds['image'][perm, ...]
    batch_labels = train_ds['label'][perm, ...]
    batch_images=batch_images.reshape(-1,28,28)
    dropout_rng,rng=jax.random.split(rng)
    grads, loss, accuracy,y = apply_model(state, batch_images, batch_labels,dropout_rng=dropout_rng)
    
    state = update_model(state, grads)
    
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy,y

In [7]:
import tensorflow_datasets as tfds
import tensorflow as tf 
def get_datasets():
  with tf.device('/cpu:0'):
    ds_builder = tfds.builder('fashion_mnist')
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
    train_ds['image'] = jnp.float32(train_ds['image']) / 255.
    test_ds['image'] = jnp.float32(test_ds['image']) / 255.
    return train_ds, test_ds

In [8]:
def train_and_evaluate():  
  train_ds, test_ds = get_datasets()
  rng = jax.random.PRNGKey(0)  
  rng, init_rng = jax.random.split(rng)
  state,variables = create_train_state(init_rng)
  for epoch in range(1, 100 + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy,y = train_epoch(state, train_ds,
                                                    rng=input_rng)   
    print("") 
    print("epoch:",epoch)                                         
    print(f"train Error_ratio={train_accuracy:2.4f}")     
    print(f"test Error_ratio={test(state, variables,test_ds):2.4f}")
    # print(y)
  
  

In [9]:
train_and_evaluate()

Instructions for updating:
Use `tf.data.Dataset.get_single_element()`.



epoch: 1
train Error_ratio=0.6364
test Error_ratio=0.7418

epoch: 2
train Error_ratio=0.8011
test Error_ratio=0.8036

epoch: 3
train Error_ratio=0.8353
test Error_ratio=0.8385

epoch: 4
train Error_ratio=0.8490
test Error_ratio=0.8444

epoch: 5
train Error_ratio=0.8623
test Error_ratio=0.8571

epoch: 6
train Error_ratio=0.8714
test Error_ratio=0.8651

epoch: 7
train Error_ratio=0.8747
test Error_ratio=0.8632

epoch: 8
train Error_ratio=0.8808
test Error_ratio=0.8702

epoch: 9
train Error_ratio=0.8882
test Error_ratio=0.8709

epoch: 10
train Error_ratio=0.8917
test Error_ratio=0.8751

epoch: 11
train Error_ratio=0.8938
test Error_ratio=0.8730

epoch: 12
train Error_ratio=0.8990
test Error_ratio=0.8836

epoch: 13
train Error_ratio=0.9031
test Error_ratio=0.8842

epoch: 14
train Error_ratio=0.9042
test Error_ratio=0.8826

epoch: 15
train Error_ratio=0.9069
test Error_ratio=0.8898

epoch: 16
train Error_ratio=0.9088
test Error_ratio=0.8892

epoch: 17
train Error_ratio=0.9110
test Error_ra