In [2]:
from functools import partial   
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple, Union
from flax import linen as nn
from flax.linen.activation import sigmoid
from flax.linen.activation import tanh
from flax.linen.initializers import zeros
from flax.linen.linear import default_kernel_init
from flax.linen.initializers import orthogonal
import jax.numpy as jnp
import numpy as np
import jax

PRNGKey = Any
Shape = Tuple[int]
Dtype = Any  # this could be a real type?
Array = Any

class LSTMCell(nn.Module):  
  gate_fn: Callable[..., Any] = sigmoid
  activation_fn: Callable[..., Any] = tanh
  kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init
  recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = orthogonal()
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros
  dtype: Optional[Dtype] = None
  param_dtype: Dtype = jnp.float32

  @nn.compact
  def __call__(self, carry, inputs):
    c, h = carry
    hidden_features = h.shape[-1]
    dense_h = partial(nn.Dense,
                      features=hidden_features,
                      use_bias=True,
                      kernel_init=self.recurrent_kernel_init,
                      bias_init=self.bias_init,
                      dtype=self.dtype,
                      param_dtype=self.param_dtype)
    dense_i = partial(nn.Dense,
                      features=hidden_features,
                      use_bias=False,
                      kernel_init=self.kernel_init,
                      dtype=self.dtype,
                      param_dtype=self.param_dtype)
    i = self.gate_fn(dense_i(name='ii')(inputs) + dense_h(name='hi')(h))
    f = self.gate_fn(dense_i(name='if')(inputs) + dense_h(name='hf')(h))
    g = self.activation_fn(dense_i(name='ig')(inputs) + dense_h(name='hg')(h))
    o = self.gate_fn(dense_i(name='io')(inputs) + dense_h(name='ho')(h))
    new_c = f * c + i * g
    new_h = o * self.activation_fn(new_c)
    return (new_c, new_h), new_h

  @staticmethod
  def initialize_carry(rng, batch_dims, size, init_fn=zeros):
    key1, key2 = jax.random.split(rng)
    mem_shape = batch_dims + (size,)
    return init_fn(key1, mem_shape), init_fn(key2, mem_shape)


In [3]:
class SimpleScan(nn.Module):
    @nn.compact
    def __call__(self, xs,is_training:bool=True):
        LSTM = nn.scan(LSTMCell,
                    variable_broadcast="params",
                    split_rngs={"params": False},
                    in_axes=1,
                    out_axes=1)            
       
        ch = LSTMCell.initialize_carry(jax.random.PRNGKey(0), (xs.shape[0],), 100)
        ch, xs=LSTM()(ch, xs)                
        xs=nn.Dropout(rate=0.2,deterministic=not is_training)(xs)

        ch = LSTMCell.initialize_carry(jax.random.PRNGKey(0), (xs.shape[0],), 100)
        ch, xs=LSTM()(ch, xs)  
        xs=nn.Dropout(rate=0.2,deterministic=not is_training)(xs)

        ch = LSTMCell.initialize_carry(jax.random.PRNGKey(0), (xs.shape[0],), 200)
        ch, xs=LSTM()(ch, xs)        
        xs=nn.Dropout(rate=0.2,deterministic=not is_training)(xs)
        
        ch = LSTMCell.initialize_carry(jax.random.PRNGKey(0), (xs.shape[0],), 200)
        ch, xs=LSTM()(ch, xs)
        xs=nn.Dropout(rate=0.2,deterministic=not is_training)(xs)

        xs=xs[:,-1,:]
        xs=nn.Dense(features=6)(xs)        
        return xs

In [4]:
from dataset_pytorch import StocksSet
from torch.utils.data import DataLoader
trainset=StocksSet()
test_dataset=StocksSet(is_train=False)
training_generator=DataLoader(trainset, batch_size=1000, num_workers=4)
test_generator = DataLoader(test_dataset, batch_size=700, num_workers=4)

In [5]:


@jax.jit
def apply_model(state, images, labels,old_variables,dropout_rng):
  def loss_fn(params,old_variables):
    logits= state.apply_fn({'params': params}, images,is_training=True,rngs={'dropout':dropout_rng})
    mutated_vars =None
    loss = jnp.mean(jnp.sum(0.5*(logits-labels)**2,axis=-1))   
    return loss, (logits,mutated_vars)    

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits,new_batch_stats)), grads = grad_fn(state.params,old_variables)
  logits=logits*(trainset.data_max-trainset.data_min)+trainset.data_min
  labels=labels*(trainset.data_max-trainset.data_min)+trainset.data_min
  accuracy =jnp.sum(jnp.average(jnp.abs(logits-labels)/labels*100,axis=-1))  
  return grads, loss, accuracy,new_batch_stats,logits

@jax.jit
def update_model(state, grads):
  return state.apply_gradients(grads=grads)

In [6]:
def create_train_state(rng):
  lstm = SimpleScan()
  key_1, key_2, key_3 = jax.random.split(jax.random.PRNGKey(0), 3)
  variables=lstm.init({'params': key_1, 'dropout':key_1}, jnp.ones([1, 60, 6]))
  params = variables['params']
  
  import optax
  from flax.training import train_state
  
  tx = optax.adam(learning_rate=0.001)
  state=train_state.TrainState.create(apply_fn=lstm.apply, params=params, tx=tx)
  return state,variables

In [7]:
@jax.jit
def predict(state, variables,image_i):  
  logits= state.apply_fn({'params': state.params},image_i,is_training=False)
  return logits
def test(state, variables,test_generator):  
  accuracy=0
  for (x,y) in test_generator:
    x=np.array(x)
    y=np.array(y)
    logits= predict(state, variables,x)
    logits=logits*(trainset.data_max-trainset.data_min)+trainset.data_min
    label_i=y*(trainset.data_max-trainset.data_min)+trainset.data_min
    accuracy += jnp.sum(jnp.average(jnp.abs(logits-label_i)/label_i*100,axis=-1))  
  return accuracy/trainset.test_len

In [11]:
def train_epoch(state, training_generator, rng,variables):
  epoch_loss = []
  sum_accuracy = 0
  for (x,y) in training_generator:
    x=np.array(x)
    y=np.array(y)
    rng, dropout_rng = jax.random.split(rng)
    grads, loss, accuracy ,variables,y= apply_model(state, x, y,variables,dropout_rng)
    state = update_model(state, grads)
    
    epoch_loss.append(loss)
    sum_accuracy +=accuracy
  train_loss = np.mean(epoch_loss)
  train_accuracy = sum_accuracy/trainset.train_len
  return state, train_loss, train_accuracy,variables,y

In [12]:
def train_and_evaluate():  
  rng = jax.random.PRNGKey(0)  
  rng, init_rng = jax.random.split(rng)
  state,variables = create_train_state(init_rng)
  for epoch in range(1, 100 + 1):
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy,variables,y = train_epoch(state, training_generator,
                                                    rng=input_rng,variables=variables)   
    print("") 
    print("epoch:",epoch)                                         
    print(f"train Error_ratio={train_accuracy:2.1f}%")     
    print(f"test Error_ratio={test(state, variables,test_generator):2.1f}%")
    # print(y)
  
  return state

In [13]:
train_and_evaluate()


epoch: 1
train Error_ratio=51.2%
test Error_ratio=38.9%

epoch: 2
train Error_ratio=97.0%
test Error_ratio=49.5%

epoch: 3
train Error_ratio=31.7%
test Error_ratio=22.8%

epoch: 4
train Error_ratio=44.5%
test Error_ratio=27.1%

epoch: 5
train Error_ratio=26.2%
test Error_ratio=19.6%

epoch: 6
train Error_ratio=23.6%
test Error_ratio=22.1%

epoch: 7
train Error_ratio=22.8%
test Error_ratio=22.1%

epoch: 8
train Error_ratio=20.7%
test Error_ratio=11.7%

epoch: 9
train Error_ratio=18.9%
test Error_ratio=14.7%

epoch: 10
train Error_ratio=21.9%
test Error_ratio=20.3%

epoch: 11
train Error_ratio=22.8%
test Error_ratio=10.4%

epoch: 12
train Error_ratio=20.5%
test Error_ratio=13.5%

epoch: 13
train Error_ratio=21.5%
test Error_ratio=15.2%

epoch: 14
train Error_ratio=19.6%
test Error_ratio=10.5%

epoch: 15
train Error_ratio=19.6%
test Error_ratio=16.0%

epoch: 16
train Error_ratio=19.7%
test Error_ratio=13.0%

epoch: 17
train Error_ratio=19.3%
test Error_ratio=13.2%

epoch: 18
train Error_

TrainState(step=DeviceArray(700, dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of SimpleScan()>, params=FrozenDict({
    Dense_0: {
        bias: DeviceArray([0.04237481, 0.04425671, 0.04616028, 0.0451198 , 0.04073124,
                     0.03462777], dtype=float32),
        kernel: DeviceArray([[-0.1499559 ,  0.02018046, -0.02613124, -0.05212953,
                      -0.0960778 , -0.02873716],
                     [-0.15572481, -0.08805181,  0.02136124,  0.08609159,
                      -0.01363001,  0.01203101],
                     [-0.07379939,  0.0256145 ,  0.09215546,  0.10036606,
                      -0.06341026,  0.01690302],
                     ...,
                     [ 0.04462855,  0.1297037 , -0.1342365 ,  0.1099339 ,
                       0.11303055,  0.02191444],
                     [ 0.09413638,  0.00934959,  0.05945162, -0.01935158,
                       0.13688673, -0.0268121 ],
                     [ 0.04522821,  0.08156604, -0.05282593,  