<a href="https://colab.research.google.com/github/durml91/Personal/blob/main/Transformer_implementation_I.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install einops
!pip install equinox

Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1
Collecting equinox
  Downloading equinox-0.10.6-py3-none-any.whl (125 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jax>=0.4.11 (from equinox)
  Downloading jax-0.4.13.tar.gz (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m47.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting jaxtyping>=0.2.20 (from equin

In [2]:
pip install -U jax jaxlib

Collecting jaxlib
  Downloading jaxlib-0.4.13-cp310-cp310-manylinux2014_x86_64.whl (71.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.6/71.6 MB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.10+cuda11.cudnn86
    Uninstalling jaxlib-0.4.10+cuda11.cudnn86:
      Successfully uninstalled jaxlib-0.4.10+cuda11.cudnn86
Successfully installed jaxlib-0.4.13


In [3]:
import jax
import jax.random as jr
import jax.numpy as jnp
import einops
import equinox as eqx
import optax
import tqdm

In [4]:
key = jr.PRNGKey(2022)



### Transformer modules

Embedding table

In [7]:
class InputEmbeddings(eqx.Module):

  embedding: eqx.nn.Embedding

  def __init__(self, d_model: int, vocab_size: int, key):

      self.d_model = d_model
      self.vocab_size = vocab_size
      self.embedding = eqx.nn.Embedding(vocab_size, d_model, key)

  def __call__(self, x):
    return self.embedding(x)* jnp.sqrt(self.d_model)

Positional encoding

In [10]:
class PositionalEncoding(eqx.Module):

    dropout: eqx.nn.Dropout

    def __init__(self, d_model: int, seq_len: int, dropout_rate:float, key) -> None:

        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = eqx.nn.Dropout(dropout_rate)

        pe = jnp.zeros((seq_len, d_model), dtype = float)

        position = jnp.arange(0, seq_len, dtype=float)
        position = jnp.expand_dims(position, 1)

        div_term = jnp.exp(jnp.arange(0, d_model, 2) * (-jnp.log(10000.0) / d_model))

        pe[:, 0::2] = jnp.sin(position * div_term)
        pe[:, 1::2] = jnp.cos(position * div_term)

        pe = jnp.expand_dims(pe, 0)

    def __call__(self, x):

        x = x + (self.pe[:, :x.shape[1], :])

        return self.dropout(x)


####### need to freeze parameters

In [None]:
class LayerNormalisation(eqx.Module):

    def __init__(self, eps: float)