<a href="https://colab.research.google.com/github/deterministic-algorithms-lab/Jax-Journey/blob/main/haiku_basics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/deepmind/dm-haiku

In [None]:
import haiku as hk
import jax.numpy as jnp
import jax

# LSTM

**Key points:**

* Modules accept a ```name``` argument in ```__init__``` and must pass it on by calling ```super().__init__(name=name)```.
* ```__call__``` can take any arguments and return any values.
* The same code (here ```__call__```) serves both ```init_fn()``` and ```apply_fn()```: during init, ```hk.get_parameter``` creates each parameter; during apply, it returns the stored value.


In [None]:
class HaikuLSTMCell(hk.Module):
    """A single LSTM step on unbatched vectors (gate order: i, f, g, o)."""

    def __init__(self, in_dim, out_dim, name=None):
        super().__init__(name=name or "lstmcell")
        self.in_dim = in_dim
        self.out_dim = out_dim

    def __call__(self, inputs, h, c):
        # inputs: (in_dim,); h and c: (out_dim,)
        weights_ih = hk.get_parameter("weights_ih",
                                      (4 * self.out_dim, self.in_dim),
                                      init=hk.initializers.UniformScaling())
        weights_hh = hk.get_parameter("weights_hh",
                                      (4 * self.out_dim, self.out_dim),
                                      init=hk.initializers.UniformScaling())
        bias = hk.get_parameter("bias",
                                (4 * self.out_dim,),
                                init=hk.initializers.Constant(0.0))

        # Pre-activations of all four gates, stacked into one (4 * out_dim,) vector.
        ifgo = weights_ih @ inputs + weights_hh @ h + bias
        i, f, g, o = jnp.split(ifgo, indices_or_sections=4, axis=-1)

        i = jax.nn.sigmoid(i)  # input gate
        f = jax.nn.sigmoid(f)  # forget gate
        g = jnp.tanh(g)        # candidate cell values
        o = jax.nn.sigmoid(o)  # output gate

        new_c = f * c + i * g
        new_h = o * jnp.tanh(new_c)

        return (new_h, new_c)
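
For reference, the cell above computes the standard LSTM update written out below, where $x$ is ```inputs```, $W_{ih}$ and $W_{hh}$ are ```weights_ih``` and ```weights_hh```, $b$ is ```bias```, and $\odot$ is elementwise multiplication. The gate pre-activations are stacked in the order i, f, g, o, the same order PyTorch's ```nn.LSTM``` uses, though this cell has a single bias where PyTorch has two.

$$
\begin{aligned}
\begin{bmatrix} z_i \\ z_f \\ z_g \\ z_o \end{bmatrix} &= W_{ih}\,x + W_{hh}\,h + b \\
i &= \sigma(z_i), \quad f = \sigma(z_f), \quad g = \tanh(z_g), \quad o = \sigma(z_o) \\
c' &= f \odot c + i \odot g \\
h' &= o \odot \tanh(c')
\end{aligned}
$$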

The following code looks even more like PyTorch:
* You can define all submodules and parameters inside ```__init__()```. This is valid because everything that runs inside the function passed to ```hk.transform()```, including module constructors, is traced by Haiku.

In [None]:
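class HaikuLSTMLM(hk.Module):
    def __init__(self, vocab_size, dim, name=None):
        super().__init__(name=name or "lstmlm")
        # Learned initial cell state; the initial hidden state is tanh(c_0).
        _c0 = hk.get_parameter(name="c_0",
                               shape=(dim,),
                               init=hk.initializers.TruncatedNormal(stddev=0.1))
        self.hc_0 = (jnp.tanh(_c0), _c0)
        self.embeddings = hk.Embed(vocab_size, dim)
        self.cell = HaikuLSTMCell(dim, dim)

    def forward(self, seq, hc):
        loss = 0
        for idx in seq:
            # Score every token against the current hidden state (the embedding
            # matrix doubles as the output layer) and add the NLL of idx.
            loss -= jax.nn.log_softmax(self.embeddings.embeddings @ hc[0])[idx]
            # Feed the embedding of idx into the cell to get the next (h, c).
            hc = self.cell(self.embeddings(idx), *hc)
        return loss, hc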

* It doesn't matter to ```hk.transform()``` where submodules are defined, as long as they are created inside the function being transformed, so that Haiku can turn them into pure functions. So the definitions above and below describe the same model, provided each submodule is instantiated only once.

* The second style lets model sizes depend on the inputs received in ```forward()```.

* The ```forward()``` method can have any name; some people use ```__call__``` instead. An ```hk.Embed``` instance can be applied directly to an index, as in ```embed(idx)```, because ```hk.Embed``` implements its lookup in ```__call__```. Had it been implemented in a method named ```forward()```, we would have to write ```embed.forward(idx)``` instead.

* Every ```hk.Embed(...)``` call creates a new module with its own parameters, so the embedding below is created once at the top of ```forward()``` and reused inside the loop.

In [None]:
class HaikuLSTMLM(hk.Module):
    def __init__(self, vocab_size, dim, name=None):
        super().__init__(name=name or "lstmlm")
        _c0 = hk.get_parameter(name="c_0",
                               shape=(dim,),
                               init=hk.initializers.TruncatedNormal(stddev=0.1))
        self.hc_0 = (jnp.tanh(_c0), _c0)
        self.vocab_size = vocab_size
        self.dim = dim
        self.cell = HaikuLSTMCell(dim, dim)

    def forward(self, seq, hc):
        # Create the embedding once: calling hk.Embed(...) inside the loop would
        # create a new, separately parameterized module on every call.
        embed = hk.Embed(self.vocab_size, self.dim)
        loss = 0
        for idx in seq:
            loss -= jax.nn.log_softmax(embed.embeddings @ hc[0])[idx]
            hc = self.cell(embed(idx), *hc)
        return loss, hc

In [None]:
def impure_forward_fn(vocab_size, dim, seq, hc=None):
    lm = HaikuLSTMLM(vocab_size, dim)
    # Fall back to the learned initial state when no (h, c) is supplied.
    return lm.forward(seq, hc if hc is not None else lm.hc_0)

In [None]:
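init_fn, nojit_pure_forward_fn = hk.transform(impure_forward_fn)
# vocab_size and dim determine parameter shapes, so jit must treat them as static.
pure_forward_fn = jax.jit(nojit_pure_forward_fn, static_argnames=("vocab_size", "dim"))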

* ```init_fn()``` takes a random key followed by the inputs for the transformed function. It returns the parameters as a nested dictionary (module name → parameter name → array).

* ```nojit_pure_forward_fn()``` takes the ```params``` returned by ```init_fn()```, then an ```rng``` key, then the arguments of the transformed function. It returns whatever ```impure_forward_fn()``` returns. The same ```rng``` key and inputs always give the same result.

In [None]:
rng = jax.random.PRNGKey(0)
params = init_fn(rng, vocab_size = 20, dim = 10, seq=jnp.array([0]))

In [None]:
print(params)

In [None]:
loss, hc = nojit_pure_forward_fn(params, rng, vocab_size = 20, dim=10, seq=jnp.array([0]))

In [None]:
print(loss, hc)

2.9562287 (DeviceArray([ 0.19030678, -0.04981524, -0.1435111 ,  0.14797553,
              0.01645921, -0.01669403,  0.11530687, -0.10629394,
             -0.02137115,  0.07460269], dtype=float32), DeviceArray([ 0.37595972, -0.08241095, -0.2591579 ,  0.3729893 ,
              0.0248227 , -0.03331303,  0.19235653, -0.24751279,
             -0.04453837,  0.15290585], dtype=float32))
