In [10]:
import jax
import jax.numpy as jnp
from flax import linen as nn  # Linen API

In [13]:
class dataloader:
    """Character-level batch loader over a plain-text corpus.

    Builds a character vocabulary from the corpus, encodes the whole text as a
    1-D ``jnp`` integer array, and serves consecutive, non-overlapping
    ``(x, y)`` batches of shape ``(B, T)``, where ``y`` is ``x`` shifted one
    character ahead (next-character targets).

    Args:
        B: Batch size (rows per batch).
        T: Sequence length (characters per row).
        path: Text file to load; defaults to ``'shakespeare.txt'``.

    Raises:
        ValueError: If the corpus has fewer than ``B*T + 1`` characters, so not
            even one batch (inputs plus shifted targets) can be formed.
    """

    def __init__(self, B, T, path='shakespeare.txt'):
        self.B = B
        self.T = T
        # `with` guarantees the file handle is closed after reading.
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        # Sorted unique characters give a deterministic char <-> id mapping.
        self.stoi = {ch: i for i, ch in enumerate(sorted(set(text)))}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.encode = lambda s: [self.stoi[ch] for ch in s]
        # int() lets decode accept JAX arrays, whose elements are unhashable.
        self.decode = lambda ids: ''.join(self.itos[int(i)] for i in ids)
        vcb = len(self.stoi)
        print(f"Vocab length is: {vcb}")
        self.data = jnp.array(self.encode(text))
        if len(self.data) < B * T + 1:
            raise ValueError(
                f"corpus has {len(self.data)} tokens; "
                f"need at least B*T + 1 = {B * T + 1}"
            )
        self.train_counter = 0

    def train(self):
        """Return the next ``(x, y)`` training batch and advance the cursor.

        Returns:
            Tuple ``(x, y)`` of integer arrays, each of shape ``(B, T)``.
        """
        B, T = self.B, self.T
        span = B * T
        start = self.train_counter
        # One extra token so targets can be the inputs shifted by one.
        buf = self.data[start:start + span + 1]
        x = buf[:-1].reshape(B, T)
        y = buf[1:].reshape(B, T)
        self.train_counter += span
        # Restart from the beginning when the *next* batch would run past the
        # end; the short leftover tail is skipped for this pass over the data.
        if self.train_counter + span + 1 > len(self.data):
            self.train_counter = 0
        return x, y

In [14]:
# Character-level loader: 64 sequences of 64 characters per batch.
dl = dataloader(B=64, T=64)

Vocab length is: 61


In [23]:
from dataclasses import dataclass
@dataclass()
class Config:
    """Hyper-parameters for the ShakeRNN character model and its training.

    Every field has a default, so ``Config()`` yields the baseline setup.
    """

    # Width of each character embedding vector.
    n_embd: int = 128
    # Width of the hidden Dense layers.
    hidden_dim: int = 256
    # NOTE(review): not read by ShakeRNN in the visible code — confirm intended use.
    n_hidden: int = 2
    # Embedding-table rows / number of output logits; should cover every
    # token id the data loader produces.
    vocab_size: int = 64
    # Mirrors the (B, T) = (64, 64) used when constructing the data loader.
    batch_size: int = 64
    seq_len: int = 64
    # Presumably the optimizer learning rate (no optimizer is shown yet).
    lr: float = 1e-3

class ShakeRNN(nn.Module):
    """Character-level recurrent language model written with Flax linen.

    Hyper-parameters come from a fresh ``Config()`` built in ``setup``. At each
    time step the model embeds the current token and combines it with the
    previous hidden state. It then applies a residual GELU layer and projects
    the result to vocabulary logits.
    """

    def setup(self):
        self.config = Config()
        self.vcb = self.config.vocab_size
        self.n_embd = self.config.n_embd
        self.embd = nn.Embed(num_embeddings=self.vcb, features=self.n_embd)
        # Plain dict of submodules: linen registers dict-valued attributes
        # assigned in setup (it has no PyTorch-style `ModuleDict`).
        self.RNN = dict(
            inputs=nn.Dense(features=self.config.hidden_dim),
            hidden=nn.Dense(features=self.config.hidden_dim),
            hidden2=nn.Dense(features=self.config.hidden_dim),
            output=nn.Dense(features=self.vcb),
        )

    def __call__(self, x, hidden_state=None, target=None):
        """Run the recurrence over a batch of token sequences.

        Args:
            x: Integer token ids, shape ``(B, T)``.
            hidden_state: Initial hidden state of shape ``(B, hidden_dim)``.
                Alternatively, the ``(B, T, hidden_dim)`` array returned by a
                previous call, whose last time step is then used. ``None``
                starts from zeros.
            target: Optional integer next-token ids, shape ``(B, T)``.

        Returns:
            Tuple ``(logits, hidden_states, loss)``:
            - logits: shape ``(B, T, vocab_size)``.
            - hidden_states: per-step hidden states, shape ``(B, T, hidden_dim)``.
            - loss: mean next-token cross-entropy, or ``None`` when ``target``
              is ``None``.
        """
        hidden_dim = self.config.hidden_dim
        batch, steps = x.shape
        if hidden_state is None:
            h = jnp.zeros((batch, hidden_dim))
        elif hidden_state.ndim == 3:
            # Continue from the final step of a previous call's output.
            h = hidden_state[:, -1]
        else:
            h = hidden_state
        # The input projection does not depend on h, so compute it for all
        # steps at once: (B, T, hidden_dim).
        inp = self.RNN['inputs'](self.embd(x))
        states = []
        # Plain Python loop (unrolled under jit): bound Flax submodules cannot
        # be called directly inside jax.lax.scan (nn.scan is the lifted form).
        for t in range(steps):
            h = jax.nn.gelu(inp[:, t] + self.RNN['hidden'](h))
            h = h + jax.nn.gelu(self.RNN['hidden2'](h))
            states.append(h)
        hidden_states = jnp.stack(states, axis=1)
        output = self.RNN['output'](hidden_states)
        loss = None
        if target is not None:
            log_probs = jax.nn.log_softmax(output, axis=-1)
            # Log-probability of each target id; negate for the NLL.
            picked = jnp.take_along_axis(log_probs, target[..., None], axis=-1)
            loss = -jnp.mean(picked)
        return output, hidden_states, loss

In [24]:
# Construct the model object only. Presumably its parameters still have to be
# created with `rnn.init(...)` before use (not shown in this notebook).
rnn = ShakeRNN()