<a href="https://colab.research.google.com/github/durml91/Personal/blob/main/Transformer_implementation_I.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install einops
!pip install equinox

Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1
Collecting equinox
  Downloading equinox-0.10.8-py3-none-any.whl (130 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m130.2/130.2 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jax>=0.4.13 (from equinox)
  Downloading jax-0.4.13.tar.gz (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting jaxtyping>=0.2.20 (from equin

In [2]:
pip install -U jax jaxlib

Collecting jaxlib
  Downloading jaxlib-0.4.13-cp310-cp310-manylinux2014_x86_64.whl (71.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.6/71.6 MB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.10+cuda11.cudnn86
    Uninstalling jaxlib-0.4.10+cuda11.cudnn86:
      Successfully uninstalled jaxlib-0.4.10+cuda11.cudnn86
Successfully installed jaxlib-0.4.13


In [3]:
import jax
import jax.random as jr
import jax.numpy as jnp
import einops
import equinox as eqx
import optax
import tqdm

In [11]:
import functools
from typing import Dict, List, Mapping, Optional, Callable

#from datasets import load_dataset

from jaxtyping import Array, Float, Int

from tqdm import notebook as tqdm

import math

In [6]:
# Root PRNG key for the notebook, seeded with 2022.
key = jr.PRNGKey(2022)



### Transformer modules

GELU (applied through a generic `Lambda` wrapper module)

In [None]:
class Lambda(eqx.Module):
    """Wrap a plain callable so it can be used as a layer.

    The wrapped function is stored in ``fn`` and applied to the input
    unchanged; the module adds no state of its own.
    """

    fn: Callable

    def __call__(self, x, *, key=None):
        """Return ``fn(x)``. ``key`` is accepted but ignored."""
        result = self.fn(x)
        return result

Attention

In [9]:
class AttentionBlock(eqx.Module):
    """Self-attention sub-layer: attention -> dropout -> residual add -> layer norm.

    ``num_heads`` is kept as a static field because it is needed again when
    the padding mask is expanded to one copy per attention head.
    """

    attention: eqx.nn.MultiheadAttention
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    num_heads: int = eqx.field(static=True)

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jr.PRNGKey
    ):
        """Build the attention, dropout and layer-norm sub-modules.

        Args:
            hidden_size: Feature size of each token; used as the attention
                query size and as the layer-norm shape.
            num_heads: Number of attention heads.
            dropout_rate: Dropout probability applied to the attention output.
            attention_dropout_rate: Dropout probability inside the attention
                layer itself.
            key: PRNG key used to initialise the attention layer.
        """
        self.num_heads = num_heads
        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

        # Every projection of the attention layer carries a bias term.
        bias_flags = {
            "use_query_bias": True,
            "use_key_bias": True,
            "use_value_bias": True,
            "use_output_bias": True,
        }
        self.attention = eqx.nn.MultiheadAttention(
            num_heads=num_heads,
            query_size=hidden_size,
            dropout_p=attention_dropout_rate,
            key=key,
            **bias_flags,
        )

    def make_self_attention_mask(
        self, mask: Int[Array, "seq_len"]
    ) -> Float[Array, "num_heads seq_len seq_len"]:
        """Expand a per-token mask into a per-head pairwise mask.

        Entry ``[h, i, j]`` of the result is ``mask[i] * mask[j]`` for every
        head ``h``, returned as float32.
        """
        # TODO: see whether this can be expressed with einops rearrange/repeat
        # (repeat can add any number of entries along a new axis).
        token_mask = jnp.asarray(mask)

        # Outer product of the mask with itself: (seq_len, seq_len).
        pairwise = token_mask[..., :, None] * token_mask[..., None, :]

        # Insert a head axis and copy the pairwise mask once per head.
        per_head = jnp.repeat(
            pairwise[..., None, :, :], repeats=self.num_heads, axis=-3
        )
        return per_head.astype(jnp.float32)

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Optional[Int[Array, "seq_len"]],
        enable_dropout: bool = False,
        key: "jr.PRNGKey" = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        """Run self-attention over ``inputs``.

        Args:
            inputs: Token representations; used as query, key and value.
            mask: Optional per-token mask; ``None`` disables masking.
            enable_dropout: When False both dropout layers run in inference
                mode.
            key: Optional PRNG key; when given it is split between the
                attention dropout and the output dropout.

        Returns:
            Layer-normalised sum of the dropped-out attention output and
            ``inputs``.
        """
        attn_mask = None if mask is None else self.make_self_attention_mask(mask)

        if key is None:
            attention_key = dropout_key = None
        else:
            attention_key, dropout_key = jr.split(key)

        inference = not enable_dropout

        attended = self.attention(
            query=inputs,
            key_=inputs,
            value=inputs,
            mask=attn_mask,
            inference=inference,
            key=attention_key,
        )

        residual = self.dropout(attended, inference=inference, key=dropout_key) + inputs

        # LayerNorm is defined for one token, so map it over the sequence axis.
        return jax.vmap(self.layernorm)(residual)

MLP Block

In [10]:
class FeedForwardBlock(eqx.Module):
    """Position-wise feed-forward sub-layer: MLP -> dropout -> residual add -> layer norm.

    The MLP expands each token from ``hidden_size`` to ``intermediate_size``,
    applies GELU, and projects back to ``hidden_size``.
    """

    mlp: eqx.nn.Sequential

    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout_rate: float,
        key: jr.PRNGKey,
    ):
        """Build the two-layer MLP, the layer norm and the dropout layer.

        Args:
            hidden_size: Feature size of each input/output token.
            intermediate_size: Width of the hidden layer inside the MLP.
            dropout_rate: Dropout probability applied to the MLP output.
            key: PRNG key, split between the two linear layers.
        """
        linear1, linear2 = jr.split(key)

        self.mlp = eqx.nn.Sequential([
            # The first projection must consume hidden_size features; it
            # previously declared in_features=intermediate_size.
            eqx.nn.Linear(in_features=hidden_size, out_features=intermediate_size, key=linear1),
            Lambda(jax.nn.gelu),
            eqx.nn.Linear(in_features=intermediate_size, out_features=hidden_size, key=linear2)
        ])

        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        enable_dropout: bool = True,
        key: Optional[jr.PRNGKey] = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        """Apply the feed-forward sub-layer to every token of a sequence.

        Args:
            inputs: Token representations of shape ``(seq_len, hidden_size)``.
            enable_dropout: When False the dropout layer runs in inference
                mode. NOTE: this defaults to True (unlike ``AttentionBlock``),
                so callers that do not pass it must supply ``key``.
            key: PRNG key for dropout; required when dropout is enabled.

        Returns:
            Layer-normalised sum of the dropped-out MLP output and ``inputs``.
        """
        # Linear layers act on a single feature vector, so map the MLP over
        # the sequence axis.
        feed_out = jax.vmap(self.mlp)(inputs)

        out_d = self.dropout(feed_out, inference=not enable_dropout, key=key)

        out_unn = out_d + inputs

        # Normalise each token separately, matching AttentionBlock.
        return jax.vmap(self.layernorm)(out_unn)


Embedding table

In [None]:
class InputEmbeddings(eqx.Module):
    """Token embedding table whose outputs are scaled by ``sqrt(d_model)``."""

    embedding: eqx.nn.Embedding
    # Equinox modules are frozen dataclasses: every attribute assigned in
    # __init__ must be declared. These are plain ints, so keep them static.
    d_model: int = eqx.field(static=True)
    vocab_size: int = eqx.field(static=True)

    def __init__(self, d_model: int, vocab_size: int, key):
        """Create the embedding table.

        Args:
            d_model: Size of each embedding vector.
            vocab_size: Number of rows in the embedding table.
            key: PRNG key used to initialise the table.
        """
        self.d_model = d_model
        self.vocab_size = vocab_size
        # `key` must be passed by keyword; the third positional parameter of
        # eqx.nn.Embedding is the weight matrix, not the key.
        self.embedding = eqx.nn.Embedding(vocab_size, d_model, key=key)

    def __call__(self, x):
        """Embed token id(s) ``x`` and scale by ``sqrt(d_model)``.

        ``x`` may be a scalar id or an integer array of any rank; the result
        has the shape of ``x`` with a trailing ``d_model`` axis.
        """
        x = jnp.asarray(x)

        # The embedding layer looks up one scalar index, so add one vmap per
        # input axis.
        lookup = self.embedding
        for _ in range(x.ndim):
            lookup = jax.vmap(lookup)

        return lookup(x) * math.sqrt(self.d_model)

Positional encoding

In [12]:
class SinusoidalPosEmb(eqx.Module):
    """Fixed sinusoidal position encoding.

    ``emb`` stores ``dim // 2`` geometrically spaced frequencies running from
    1 down to 1/10000. The encoding of a position is the concatenation of the
    sines and the cosines of ``position * emb``, so its width is
    ``2 * (dim // 2)``.
    """

    emb: jax.Array

    def __init__(self, dim):
        """Precompute the frequency table for an encoding of width ``dim``."""
        half_dim = dim // 2
        # Guard the denominator: half_dim - 1 is zero when dim is 2 or 3.
        scale = math.log(10_000) / max(half_dim - 1, 1)
        self.emb = jnp.exp(jnp.arange(half_dim) * -scale)

    def __call__(self, positions=None):
        """Return the sinusoidal encoding.

        Args:
            positions: Optional scalar or array of positions. When given, the
                result has shape ``positions.shape + (2 * (dim // 2),)``.
                When omitted, the stored frequencies are used directly as the
                angles, which equals the encoding of position 1.

        Returns:
            Sines followed by cosines along the last axis.
        """
        # The table is a constant, not a learnable parameter: block gradients
        # so that it is never updated during training.
        freqs = jax.lax.stop_gradient(self.emb)

        if positions is None:
            angles = freqs
        else:
            angles = jnp.asarray(positions)[..., None] * freqs

        # concatenate takes ONE sequence of arrays as its first argument.
        return jnp.concatenate((jnp.sin(angles), jnp.cos(angles)), axis=-1)



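####### Note: `emb` holds fixed sinusoid frequencies and must not be trained; exclude it from gradient updates, e.g. with `eqx.partition` and a filter spec, `jax.lax.stop_gradient`, or by trying a static field (`eqx.field(static=True)`).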

Transformer Block

In [None]:
class TransformerLayer(eqx.Module):
    """One encoder layer: a self-attention block followed by a feed-forward block."""

    attention_block: AttentionBlock
    ff_block: FeedForwardBlock

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jr.PRNGKey,
    ):
        """Build the attention and feed-forward sub-layers.

        Args:
            hidden_size: Feature size of each token.
            intermediate_size: Hidden width of the feed-forward MLP.
            num_heads: Number of attention heads.
            dropout_rate: Dropout probability used by both sub-layers.
            attention_dropout_rate: Dropout probability inside attention.
            key: PRNG key, split between the two sub-layers.
        """
        attention_key, ff_key = jr.split(key)

        self.attention_block = AttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=attention_key,
        )

        self.ff_block = FeedForwardBlock(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dropout_rate=dropout_rate,
            key=ff_key,
        )

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Optional[Int[Array, "seq_len"]] = None,
        *,
        enable_dropout: bool = False,
        key: Optional[jr.PRNGKey] = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        """Apply attention and then the feed-forward block.

        Args:
            inputs: Token representations of shape ``(seq_len, hidden_size)``.
            mask: Optional per-token mask forwarded to the attention block.
            enable_dropout: Whether dropout is active in both sub-layers.
            key: Optional PRNG key; when given it is split between the two
                sub-layers.

        Returns:
            The output of the feed-forward block.
        """
        attn_key, ff_key = (None, None) if key is None else jr.split(key)

        attention_output = self.attention_block(
            inputs, mask, enable_dropout=enable_dropout, key=attn_key
        )

        # Feed the attention output (not the raw inputs) into the MLP block,
        # and pass the key by keyword: a positional argument may not follow a
        # keyword argument.
        mlp_out = self.ff_block(
            attention_output, enable_dropout=enable_dropout, key=ff_key
        )

        return mlp_out

Encoder

In [None]:
class Encoder(eqx.Module):
    """Transformer encoder: token embedding + sinusoidal positions + a stack of layers.

    ``hidden_size`` should be even: the positional table has
    ``2 * (hidden_size // 2)`` columns and is added to the token embeddings.
    """

    embedder_block: InputEmbeddings
    pos_embed: SinusoidalPosEmb
    layers: List[TransformerLayer]

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        intermediate_size: int,
        num_layers: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jr.PRNGKey,
    ):
        """Build the embedding table, the position module and the layer stack.

        Args:
            vocab_size: Number of distinct token ids.
            hidden_size: Feature size of each token.
            intermediate_size: Hidden width of each feed-forward MLP.
            num_layers: Number of transformer layers.
            num_heads: Number of attention heads per layer.
            dropout_rate: Dropout probability used by every layer.
            attention_dropout_rate: Dropout probability inside attention.
            key: PRNG key, split between the embedding table and the layers.
        """
        embedder_key, layer_key = jr.split(key, num=2)

        # InputEmbeddings takes (d_model, vocab_size, key).
        self.embedder_block = InputEmbeddings(hidden_size, vocab_size, embedder_key)

        self.pos_embed = SinusoidalPosEmb(hidden_size)

        layer_keys = jr.split(layer_key, num=num_layers)

        self.layers = [
            TransformerLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                key=per_layer_key,
            )
            for per_layer_key in layer_keys
        ]

    def __call__(
        self,
        tokens: Int[Array, " seq_len"],
        *,
        enable_dropout: bool = False,
        key: Optional[jr.PRNGKey] = None,
        mask: Optional[Int[Array, " seq_len"]] = None,
    ) -> Float[Array, "seq_len hidden_size"]:
        """Encode a sequence of token ids.

        Args:
            tokens: Integer token ids of shape ``(seq_len,)``.
            enable_dropout: Whether dropout is active in the layers.
            key: Optional PRNG key; when given, one sub-key is derived per
                layer.
            mask: Optional per-token mask forwarded to every layer.

        Returns:
            The output of the last layer, shape ``(seq_len, hidden_size)``.
        """
        embed_inputs = self.embedder_block(tokens)

        # Build the (seq_len, hidden) table directly from the stored
        # frequencies so that every position gets its own encoding. The
        # frequencies are constants, so gradients are blocked.
        freqs = jax.lax.stop_gradient(self.pos_embed.emb)
        angles = jnp.arange(embed_inputs.shape[0])[:, None] * freqs
        pos_enc = jnp.concatenate((jnp.sin(angles), jnp.cos(angles)), axis=-1)

        x = embed_inputs + pos_enc

        if key is None:
            layer_keys = [None] * len(self.layers)
        else:
            layer_keys = jr.split(key, num=len(self.layers))

        for layer, per_layer_key in zip(self.layers, layer_keys):
            x = layer(x, mask, enable_dropout=enable_dropout, key=per_layer_key)

        return x


