<a href="https://colab.research.google.com/github/durml91/Personal/blob/Safety/Transformer_implementation_I.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install einops
!pip install equinox

Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1
Collecting equinox
  Downloading equinox-0.10.6-py3-none-any.whl (125 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jax>=0.4.11 (from equinox)
  Downloading jax-0.4.13.tar.gz (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting jaxtyping>=0.2.20 (from equin

In [2]:
pip install -U jax jaxlib

Collecting jaxlib
  Downloading jaxlib-0.4.13-cp310-cp310-manylinux2014_x86_64.whl (71.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.6/71.6 MB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxlib
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.10+cuda11.cudnn86
    Uninstalling jaxlib-0.4.10+cuda11.cudnn86:
      Successfully uninstalled jaxlib-0.4.10+cuda11.cudnn86
Successfully installed jaxlib-0.4.13


In [3]:
import jax
import jax.random as jr
import jax.numpy as jnp
import einops
import equinox as eqx
import optax
import tqdm

In [5]:
import functools
from typing import Dict, List, Mapping, Optional, Callable

#from datasets import load_dataset

from jaxtyping import Array, Float, Int

from tqdm import notebook as tqdm

import math

In [4]:
key = jr.PRNGKey(2022)



### Transformer modules

GELU

In [6]:
class Lambda(eqx.Module):

    fn: Callable

    def __call__(self, x, *, key=None):
        return self.fn(x)

Attention

In [7]:
class AttentionBlock(eqx.Module):

    attention: eqx.nn.MultiheadAttention
    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout
    num_heads: int = eqx.field(static=True)

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jr.PRNGKey
    ):
        self.num_heads = num_heads
        self.attention = eqx.nn.MultiheadAttention(
            num_heads=num_heads,
            query_size=hidden_size,
            use_query_bias=True,
            use_key_bias=True,
            use_value_bias=True,
            use_output_bias=True,
            dropout_p=attention_dropout_rate,
            key=key,
        )

        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)


    def make_self_attention_mask(
        self, mask: Int[Array, "seq_len"]
    ) -> Float[Array, "num_heads seq_len seq_len"]:

        mask = jnp.multiply(
            jnp.expand_dims(mask, axis=-1), jnp.expand_dims(mask, axis=-2)
        )

        #see if you can do this with einops rearrange or repeat (repeat allows you to add any number of dimensions in new axis)

        mask = jnp.expand_dims(mask, axis=-3)
        mask = jnp.repeat(mask, repeats=self.num_heads, axis=-3)

        return mask.astype(jnp.float32)


    def __call__(
        self,
        input1: Float[Array, "seq_len hidden_size"],
        input2: Float[Array, "seq_len hidden_size"],
        input3: Float[Array, "seq_len hidden_size"],
        mask: Optional[Int[Array, "seq_len"]],
        enable_dropout: bool = False,
        key: "jr.PRNGKey" = None,
    ) -> Float[Array, "seq_len hidden_size"]:

        if mask is not None:
            mask = self.make_self_attention_mask(mask)

        attention_key, dropout_key = (
            (None, None) if key is None else jr.split(key)
        )

        attention_output = self.attention(
              query=input1,
              key_=input2,
              value=input3,
              mask=mask,
              inference=not enable_dropout,
              key=attention_key
        )

        att_drop = self.dropout(attention_output, inference=not enable_dropout, key=dropout_key)
        unn_out = att_drop + inputs
        output = jax.vmap(self.layernorm)(unn_out)

        return output

MLP Block

In [8]:
class FeedForwardBlock(eqx.Module):

    mlp: eqx.nn.Sequential    #could also use MLP if this way is more fiddly

    layernorm: eqx.nn.LayerNorm
    dropout: eqx.nn.Dropout


    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        dropout_rate: float,
        key: jr.PRNGKey,
    ):

        linear1, linear2 = jr.split(key)

        self.mlp = eqx.nn.Sequential([
            eqx.nn.Linear(in_features=intermediate_size, out_features=intermediate_size, key=linear1),
            Lambda(jax.nn.gelu),
            eqx.nn.Linear(in_features=intermediate_size, out_features=hidden_size, key=linear2)
        ])

        self.layernorm = eqx.nn.LayerNorm(shape=hidden_size)
        self.dropout = eqx.nn.Dropout(dropout_rate)

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        enable_dropout: bool = True,
        key: Optional[jr.PRNGKey] = None,
    ) -> Float[Array, "seq_len hidden_size"]:

        feed_out = self.mlp(inputs)

        out_d = self.dropout(feed_out, inference=not enable_dropout, key=key)

        out_unn = out_d + inputs

        output = self.layernorm(out_unn)

        return output


Embedding table

In [9]:
class InputEmbeddings(eqx.Module):

  embedding: eqx.nn.Embedding

  def __init__(
      self,
      d_model: int,
      vocab_size: int,
      key: jr.PRNGKey,
  ):

      self.d_model = d_model
      self.vocab_size = vocab_size
      self.embedding = eqx.nn.Embedding(vocab_size, d_model, key)

  def __call__(self, x):
    return self.embedding(x) * math.sqrt(self.d_model)

Positional encoding

In [134]:
class SinusoidalPosEmb(eqx.Module):
    pos_emb: jax.Array
    dropout: eqx.nn.Dropout


    def __init__(
        self,
        d_model: int,
        seq_len: int,
        dropout_rate: float
    ):
        self.dropout = eqx.nn.Dropout(dropout_rate)

        pe = jnp.zeros((seq_len, d_model))
        position = einops.repeat(jnp.expand_dims(jnp.arange(0 , seq_len), axis=1), " s 1 -> s (r 1)", r=(d_model/2))  #shape [seq_len, d_model/2]
        div_term = jnp.exp(jnp.arange(0, d_model, 2) * -(math.log(10_000) / d_model))  #shape [d_model/2]

        ins = jax.vmap(jnp.multiply, in_axes=(1, 0), out_axes=1)(position, div_term)  #shape [seq_len, d_model/2]



        pe = pe.at[:, 0::2].set(jnp.sin(ins))
        pe = pe.at[:, 1::2].set(jnp.cos(ins))


        self.pos_emb = jnp.expand_dims(pe, axis=0)

        print(self.pos_emb.shape)
    def __call__(
        self,
        x,
        enable_dropout: bool = False,
        key: "jr.PRNGKey" = None,
    ) -> Float[Array, "seq_len d_model"]:


        x = x + self.pos_emb[:, :x.shape[1]]

        return self.dropout(x, inference=not enable_dropout, key=key)



####### need to freeze parameters use eqx.partition or try filter(static=True)

In [135]:
encod_block = SinusoidalPosEmb(d_model=20, seq_len=5000, dropout_rate=0.0)

#pe = encod_block.pos_emb.T

(1, 5000, 20)


In [121]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )

        print(div_term.shape)


        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        print(pe.shape)
        pe = pe.unsqueeze(0)
        print(pe.shape)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

In [122]:
pe = PositionalEncoding(20, 0)

torch.Size([10])
torch.Size([5000, 20])
torch.Size([1, 5000, 20])


In [45]:
torch.arange(0, 20, 2)

tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18])

In [29]:
import pandas as pd
import altair as alt

PositionalEncoding(
  (dropout): Dropout(p=0, inplace=False)
)

In [34]:
RUN_EXAMPLES = True

In [32]:
def show_example(fn, args=[]):
    if __name__ == "__main__" and RUN_EXAMPLES:
        return fn(*args)

In [136]:
def example_positional():
    pe = SinusoidalPosEmb(20, 5000, 0)
    y = pe(jnp.zeros((1, 100, 20), dtype=jnp.float32))

    data = pd.concat(
        [
            pd.DataFrame(
                {
                    "embedding": y[0, :, dim],
                    "dimension": dim,
                    "position": list(range(100)),
                }
            )
            for dim in [4, 5, 6, 7]
        ]
    )

    return (
        alt.Chart(data)
        .mark_line()
        .properties(width=800)
        .encode(x="position", y="embedding", color="dimension:N")
        .interactive()
    )

show_example(example_positional)

(1, 5000, 20)


In [25]:
import matplotlib
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf')
from matplotlib.colors import to_rgb
matplotlib.rcParams['lines.linewidth']=2.0
import seaborn as sns
sns.reset_orig()

  set_matplotlib_formats('svg', 'pdf')


<Figure size 640x480 with 0 Axes>

Transformer Block

In [13]:
class TransformerLayer(eqx.Module):

    attention_block: AttentionBlock
    ff_block: FeedForwardBlock

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jr.PRNGKey,
    ):

        attention_key, ff_key = jr.split(key)

        self.attention_block = AttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=attention_key,
        )

        self.ff_block = FeedForwardBlock(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dropout_rate=dropout_rate,
            key=ff_key,
        )

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        mask: Optional[Int[Array, "seq_len"]] = None,
        *,
        enable_dropout: bool = False,
        key: Optional[jr.PRNGKey] = None,
    ) -> Float[Array, "seq_len hidden_size"]:

        attn_key, ff_key = (None, None) if key is None else jr.split(key)

        attention_output = self.attention_block(
            inputs, inputs, inputs, mask, enable_dropout=enable_dropout, key=attn_key
        )

        mlp_out = self.ff_block(
            inputs, enable_dropout=enable_dropout, key=ff_key
        )

        return mlp_out

Encoder

In [15]:
class Encoder(eqx.Module):

    embedder_block: InputEmbeddings
    pos_embed: SinusoidalPosEmb
    layers: List[TransformerLayer]


    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        seq_len: int,
        intermediate_size: int,
        num_layers: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jr.PRNGKey,
    ):

      embedder_key, layer_key = jr.split(key, num=2)

      self.embedder_block = InputEmbeddings(
          hidden_size, vocab_size, embedder_key
      )

      self.pos_embed = SinusoidalPosEmb(hidden_size, seq_len, dropout_rate)

      layer_keys = jr.split(layer_key, num=num_layers)

      self.layers = [
          TransformerLayer(
              hidden_size=hidden_size, intermediate_size=intermediate_size, num_heads=num_heads, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, key=layer_key,
          )
          for layer_key in layer_keys]



      def __call__(
          self,
          tokens: Int[Array, " seq_len"],
          mask,
          *,
          enable_dropout: bool = False,
          key: Optional[jr.PRNGKey] = None,
      ):

          embed_inputs = self.embedder_block(tokens)
          x = self.pos_embed(embed_inputs, key=key)

          for layer in self.layers:

              x = layer(x, mask)

          return x

Decoder Block

In [19]:
class DecoderLayer(eqx.Module):

    self_att_block: AttentionBlock
    cross_att_block: AttentionBlock
    ff_block: FeedForwardBlock

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_heads: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jr.PRNGKey,
    ):

        self_att_key, cross_att_key, ff_key = jr.split(key, num=3)

        self.self_att_block = AttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=self_att_key,
        )

        self.cross_att_block = AttentionBlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            key=cross_att_key,
        )

        self.ff_block = FeedForwardBlock(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            dropout_rate=dropout_rate,
            key=ff_key,
        )

    def __call__(
        self,
        inputs: Float[Array, "seq_len hidden_size"],
        encoder_output,
        src_mask,
        tgt_mask,
        *,
        enable_dropout: bool = False,
        key,
    ):

        self_attn_key, cross_attn_key, ff_key = (None, None, None) if key is None else jr.split(key, num=3)

        self_attention_output = self.self_att_block(
            inputs, inputs, inputs, tgt_mask, enable_dropout=enable_dropout, key=self_attn_key
        )

        cross_attention_output = self.cross_att_block(
            encoder_output, encoder_output, self_attention_output, src_mask, enable_dropout=enable_dropout, key=cross_attn_key
        )

        mlp_out = self.ff_block(
            cross_attention_output, enable_dropout=enable_dropout, key=ff_key
        )

        return mlp_out

Decoder

In [20]:
class Decoder(eqx.Module):

    embedder_block: InputEmbeddings
    pos_embed: SinusoidalPosEmb
    layers: List[DecoderLayer]

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        seq_len: int,
        intermediate_size: int,
        num_heads: int,
        num_layers: int,
        dropout_rate: float,
        attention_dropout_rate: float,
        key: jr.PRNGKey
    ):


        embedder_key, layer_key = jr.split(key, num=2)

        self.embedder_block = InputEmbeddings(
          hidden_size, vocab_size, embedder_key
        )

        self.pos_embed = SinusoidalPosEmb(hidden_size, seq_len, dropout_rate)

        layer_keys = jr.split(layer_key, num=num_layers)

        self.layers = [
          DecoderLayer(
              hidden_size=hidden_size, intermediate_size=intermediate_size, num_heads=num_heads, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, key=layer_key,
          )
          for layer_key in layer_keys]

    def __call__(
        self,
        x,
        encoder_output,
        src_mask,
        tgt_mask,
        *,
        enable_dropout: bool = False,
        key: Optional[jr.PRNGKey] = None,
    ):

        embed_inputs = self.embedder_block(x)
        x = self.pos_embed(embed_inputs, key=key)


        for layer, subkey in zip(self.layers, jr.split(key, len(self.layers))):

              x = layer(x, encoder_output, src_mask, tgt_mask, subkey)

        return x

Final Layer

In [22]:
class Out_Projection_Layer(eqx.Module):

    proj: eqx.nn.Linear

    def __init__(
        self,
        d_model:int,
        vocab_size: int,
        key: jr.PRNGKey
    ):

        self.proj = eqx.nn.Linear(in_features=d_model, out_features=vocab_size, key=key)

    def __call__(
        self,
        x
    ):
        out = jax.vmap(self.proj)(x)
        return jax.nn.log_softmax(out, axis=-1)


Transformer

In [23]:
class Transformer(eqx.Module):

    encoder: Encoder
    decoder: Decoder

    out_proj: Out_Projection_Layer


    def __init__(
        self,
        config: Mapping,
        key: jr.PRNGKey
    ):

        encoder_key, decoder_key, out_proj_key = jr.split(key, num=3)

        self.encoder = Encoder(
            vocab_size=config["src_vocab_size"],
            hidden_size=config["hidden_size"],
            seq_len=config["src_seq_len"],
            intermediate_size=config["intermediate_size"],
            num_layers=config["num_hidden_layers"],
            num_heads=config["num_attention_heads"],
            dropout_rate=config["hidden_dropout_prob"],
            attention_dropout_rate=config["attention_dropout_prob"],
            key=encoder_key,
        )

        self.decoder = Decoder(
            vocab_size=config["tgt_vocab_size"],
            hidden_size=config["hidden_size"],
            seq_len=config["tgt_seq_len"],
            intermediate_size=config["intermediate_size"],
            num_layers=config["num_hidden_layers"],
            num_heads=config["num_attention_heads"],
            dropout_rate=config["hidden_dropout_prob"],
            attention_dropout_rate=config["attention_dropout_prob"],
            key=decoder_key,
        )

        self.out_proj(
            d_model=config["hidden_size"],
            vocab_size=config["vocab_size"],
            key=out_proj_key)

    def __call__(
        self,
        src,
        src_mask,
        tgt,
        tgt_mask,
        key
    ):
        enc_key, dec_key = jr.split(key, num=2)

        #encode
        enc = self.encoder(src, src_mask, key=enc_key)

        #decode
        dec = self.decoder(tgt, enc, src_mask, tgt_mask, key=dec_key)

        #out projection
        out_proj = self.Out_Projection_Layer(dec)

        return out_proj

Configuration

In [24]:
gpt_config = {
    "src_vocab_size": 11,
    "tgt_vocab_size": 11,
    "src_seq_len": 4,
    "tgt_seq_len": 4,
    "hidden_size": 128,     #d_model - 128
    "num_hidden_layers": 2,   #N - 2
    "num_attention_heads": 2,   #h - 2
    "intermediate_size": 512,    #d_ff - 512
    "hidden_dropout_prob": 0.1,
    "attention_dropout_prob": 0.1,
}

Test

In [None]:
test_model = Transformer(gpt_config, key)