In [1]:
%load_ext autoreload
%autoreload 2

# Imports

In [2]:
import haiku as hk
import os
import sys
import jax
import jax.numpy as jnp
from typing import Union, Optional, Dict, List, Callable, Tuple


# When executed without a package context (``__package__`` is None), make the
# parent directory importable and adopt its basename as the package name.
# NOTE(review): presumably this is what lets the later ``src.*`` imports
# resolve — confirm that ``src`` actually lives in the parent directory.
if __package__ is None:

    # Absolute path of the directory one level above the working directory.
    module_path = os.path.abspath(os.path.join('..'))
    sys.path.append(module_path)

    # Once set, ``__package__`` is no longer None, so re-running this code in
    # the same namespace does not append the path a second time.
    __package__ = os.path.basename(module_path)

# Relative path to the parent directory, kept as a module-level name.
root_dir = '..'


ModuleNotFoundError: No module named 'src'

In [None]:

from src.models.layers import (
    ESMLearnedPositionalEmbeddings,
    RobertaLMHead,
    SelfAttentionBlock,
    TokensDropout,
)
from src.models.nucleotide_transformer import NucleotideTransformerConfig
from src.models.types import (
    AttentionMask,
    Embedding,
    Tokens,
    TransformerOutput,
)


In [None]:

def cross_entropy_loss(x):
    """Placeholder loss.

    Args:
        x: Ignored.

    Returns:
        The constant ``1``, whatever ``x`` is. No cross-entropy is computed.
    """
    placeholder_loss = 1
    return placeholder_loss

def loss_func(x, y):
    """Return the loss for the predictions ``x``.

    Args:
        x: Model output, forwarded unchanged to ``cross_entropy_loss``.
        y: Targets. Accepted for interface purposes but not used.

    Returns:
        Whatever ``cross_entropy_loss(x)`` returns.
    """
    loss = cross_entropy_loss(x)
    return loss

def compute_grads(params, x, y, model):
    """Run the model forward on ``x`` and return the loss against ``y``.

    Despite its name, this function takes no gradient itself: it is the
    scalar objective that ``train_step`` hands to ``jax.grad``.

    Args:
        params: Model parameters, passed as the first argument of ``model.apply``.
        x: Model input.
        y: Targets, forwarded to ``loss_func``.
        model: Object exposing ``apply(params, rng, x)``.

    Returns:
        The value of ``loss_func`` on the model output and ``y``.
    """
    # No random key is supplied to the forward pass.
    rng = None
    predictions = model.apply(params, rng, x)
    return loss_func(predictions, y)

def update(params, grads, lr = 0.1):
    """Apply one plain gradient-descent step to every leaf of ``params``.

    Args:
        params: Pytree of parameters.
        grads: Pytree of gradients with the same structure as ``params``.
        lr: Step size multiplying each gradient leaf.

    Returns:
        A new pytree with the structure of ``params`` where each leaf is
        ``param - grad * lr``.
    """

    def sgd_step(param, grad):
        # Move the parameter against its gradient, scaled by the step size.
        return param - grad * lr

    return jax.tree_util.tree_map(sgd_step, params, grads)

def train_step(params, x, y, model):
    """Perform a single optimisation step and return the new parameters.

    Differentiates ``compute_grads`` with respect to its first argument
    (``params``) and feeds the resulting gradients to ``update`` with that
    function's default learning rate.

    Args:
        params: Current model parameters.
        x: Input batch.
        y: Target batch.
        model: Object passed through to ``compute_grads``.

    Returns:
        The parameters returned by ``update``.
    """
    grad_fn = jax.grad(compute_grads)
    return update(params, grad_fn(params, x, y, model))

def train(iterations, params, xdata, ydata, model):
    """Run the training loop and return the final parameters.

    Args:
        iterations: Either an integer number of steps to run, or any iterable
            whose length determines the number of steps (its items are not
            used). A non-positive integer runs zero steps.
        params: Initial model parameters.
        xdata: Iterator yielding one input batch per ``next()`` call.
        ydata: Iterator yielding one target batch per ``next()`` call.
        model: Object passed through to ``train_step``.

    Returns:
        The parameters after the last step (``params`` itself when no step runs).

    Raises:
        StopIteration: If ``xdata`` or ``ydata`` is exhausted before the
            requested number of steps has been performed.
    """
    # An integer step count cannot be iterated directly; turn it into a range
    # so that both ``train(100, ...)`` and ``train(range(100), ...)`` work.
    steps = range(iterations) if isinstance(iterations, int) else iterations
    for _ in steps:
        x = next(xdata)
        y = next(ydata)
        params = train_step(params, x, y, model)
    return params



In [3]:

class OKConfig:
    """Configuration type accepted by the ``OK`` module's constructor.

    Only the ``ok`` field is declared here, as a bare annotation.
    NOTE(review): ``OK`` reads further attributes from its config
    (``attention_heads``, ``embed_dim``, ``key_size``, ``ffn_embed_dim``,
    ``num_layers``, ``embeddings_layers_to_save``) that are not declared on
    this class — confirm whether they should be added.
    """

    # Annotation only: no default value is assigned in this class.
    ok: int

class OK(hk.Module):
    """
    Skeleton transformer: a token embedding followed by a stack of
    self-attention blocks.

    The config object must expose ``alphabet_size``, ``embed_dim``,
    ``attention_heads``, ``key_size``, ``ffn_embed_dim``, ``num_layers`` and
    ``embeddings_layers_to_save``. ``attention_maps_to_save`` is optional and
    defaults to "save nothing".
    """

    def __init__(
        self,
        config: OKConfig,
        name: Optional[str] = None,
    ) -> None:
        """
        Store the config and declare the layers used by ``__call__``.

        Args:
            config: Object holding the model hyperparameters (see class docstring).
            name: Optional Haiku module name.
        """

        self._config = config
        super().__init__(name=name)

        # Layer declarations incl norms
        # Same construction as in OkTransformer: vocabulary size -> embed_dim.
        self._embed_layer = hk.Embed(self._config.alphabet_size, self._config.embed_dim)

        # Process config if needed
        # Group the requested attention maps by layer. Each entry has the form
        # (layer, map, ...); entries sharing a layer are merged.
        maps_per_layer: Dict[int, List[int]] = {}
        for layer, *maps in getattr(config, "attention_maps_to_save", []):
            maps_per_layer.setdefault(layer, []).extend(maps)
        self._attention_layers_to_save = list(maps_per_layer)
        self._attention_maps_per_layer_to_save = maps_per_layer

        # Check config compatibility

    @hk.transparent
    def _attention_block(self, layer_idx: int) -> SelfAttentionBlock:
        """
        Build the self-attention block for one layer.

        Args:
            layer_idx: Zero-based layer index, used only in the module name.

        Returns:
            A ``SelfAttentionBlock`` named ``attention_layer_{layer_idx}``.
        """

        return SelfAttentionBlock(  # type: ignore
            num_heads=self._config.attention_heads,
            embed_dim=self._config.embed_dim,
            key_size=self._config.key_size,
            ffn_embed_dim=self._config.ffn_embed_dim,
            name=f"attention_layer_{layer_idx}",
        )

    @hk.transparent
    def apply_attention_blocks(
        self,
        x: Embedding,
        outs: Dict[str, Embedding],
        attention_mask: Optional[AttentionMask] = None,
    ) -> Tuple[Embedding, Dict[str, Embedding]]:
        """
        Create the blocks of attention layers and applies them.

        Args:
            x: The sequence embedding.
            outs: A dictionary to carry through the attention layers which stores the
                intermediate sequence embedding and attention maps. It is
                updated in place and also returned.
            attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).

        Returns:
            The output sequence embedding.
            The optional intermediate results (embeddings of the layer and attention
                weights).
        """

        layers: List[Callable] = [
            self._attention_block(layer_idx)
            for layer_idx in range(self._config.num_layers)
        ]
        for layer_idx, layer in enumerate(layers):
            output = layer(
                x=x,
                attention_mask=attention_mask,
            )
            x = output["embeddings"]
            # Layers are numbered from 1 in the save requests and output keys.
            layer_number = layer_idx + 1
            # Save intermediate embeddings if needed
            if layer_number in self._config.embeddings_layers_to_save:
                outs[f"embeddings_{layer_number}"] = output["embeddings"]
            # Save intermediate attention maps if needed
            if layer_number in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[layer_number]:
                    dkey = f"attention_map_layer_{layer_number}_number_{map_number}"
                    # NOTE(review): index is map_number + 1, not map_number —
                    # confirm the intended head numbering.
                    outs[dkey] = output["attention_weights"][:, map_number + 1]

        return x, outs

    def __call__(
        self,
        inputs: jnp.ndarray,
        attention_mask: Optional[AttentionMask] = None,
    ) -> Dict[str, jnp.ndarray]:
        """
        Embed the inputs and run them through the attention stack.

        Args:
            inputs: Token ids fed to the embedding layer.
            attention_mask: Optional mask forwarded unchanged to every
                attention block; ``None`` means no mask is passed.

        Returns:
            Dictionary of the intermediate embeddings / attention maps that the
            config asked to save.
        """

        # Prepare outputs dict
        outs: Dict[str, jnp.ndarray] = {}

        # Compute embeddings
        x = self._embed_layer(inputs)

        # TODO: if masked, scale
        # TODO: if requirements on inputs, check
        # TODO: positional embeddings

        # Construct layers, eg attention / Head
        # NOTE(review): the final embedding ``x`` is computed but not part of
        # the returned dictionary — confirm whether it should be exposed.
        x, outs = self.apply_attention_blocks(
            x=x,
            outs=outs,
            attention_mask=attention_mask,
        )

        return outs
        
        


['/home/wadh6511/Kode/env_evo/lib/python3.9/site-packages/haiku']

In [None]:
class OkTransformer(hk.Module):
    """
    Jax implementation of Nucleotide Transformer models.

    Pipeline implemented by ``__call__``: token embedding, optional token
    dropout, embedding scaling, optional layer norm, a stack of self-attention
    blocks and a RoBERTa language-model head. The learned positional embedding
    is disabled in this variant.
    """

    def __init__(
        self,
        config: NucleotideTransformerConfig,
        name: Optional[str] = None,
    ):
        """
        Initialize a Nucleotide Transformer model.

        Args:
            config: Dataclass containing model hyperparameters.
            name: Name for module (custom will break weight loading).

        Raises:
            ValueError: If attention maps are requested for a layer beyond
                ``config.num_layers`` or for a map number beyond
                ``config.attention_heads``.
        """

        self._config = config
        super().__init__(name=name)

        self._embed_layer = hk.Embed(self._config.alphabet_size, self._config.embed_dim)

        # self._pos_embed_layer = ESMLearnedPositionalEmbeddings(
        #     config.max_positions, config.embed_dim, config.pad_token_id
        # )

        self._lm_head = RobertaLMHead(
            embed_dim=self._config.embed_dim,
            alphabet_size=self._config.alphabet_size,
            name="roberta_lm_head",
        )

        if self._config.emb_layer_norm_before:
            self.emb_ln_before = hk.LayerNorm(
                axis=-1,
                create_scale=True,
                create_offset=True,
                name="emb_layer_norm_before",
            )

        # Process attention maps to save requirement into more suitable format.
        # Each entry has the form (layer, map, ...). Entries that share a layer
        # are merged, so that e.g. [(3, 1), (3, 2)] saves both maps of layer 3
        # instead of only the last one listed.
        maps_per_layer: Dict[int, List[int]] = {}
        for entry in config.attention_maps_to_save:
            maps_per_layer.setdefault(entry[0], []).extend(entry[1:])
        self._attention_layers_to_save = list(maps_per_layer)
        self._attention_maps_per_layer_to_save = maps_per_layer

        # Checking user request can be executed, raise error otherwise.
        # The extra [0] keeps max() well-defined when nothing is requested.
        max_layer = max(self._attention_layers_to_save + [0])
        if max_layer > config.num_layers:
            raise ValueError(
                f"You are requiring attention maps for layer {max_layer}, "
                f"while the model has {config.num_layers} layers only."
            )

        for layer, maps in self._attention_maps_per_layer_to_save.items():
            if not maps:
                # A layer listed without any map number has nothing to check
                # (and nothing will be saved for it).
                continue
            max_map = max(maps)
            if max_map > config.attention_heads:
                raise ValueError(
                    f"You are requiring attention maps number {max_map} "
                    f"at layer {layer}, while the model has {config.attention_heads} "
                    f"only."
                )

    @hk.transparent
    def apply_attention_blocks(
        self,
        x: Embedding,
        outs: Dict[str, Embedding],
        attention_mask: Optional[AttentionMask] = None,
    ) -> Tuple[Embedding, Dict[str, Embedding]]:
        """
        Create the blocks of attention layers and applies them.

        Args:
            x: The sequence embedding.
            outs: A dictionary to carry through the attention layers which stores the
                intermediate sequence embedding and attention maps. It is
                updated in place and also returned.
            attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).

        Returns:
            The output sequence embedding.
            The optional intermediate results (embeddings of the layer and attention
                weights).
        """

        layers: List[Callable] = [
            self._attention_block(layer_idx)
            for layer_idx in range(self._config.num_layers)
        ]

        if self._config.use_gradient_checkpointing:
            # the remat-ed function cannot take control flow arguments
            layers = [hk.remat(layer) for layer in layers]
        for layer_idx, layer in enumerate(layers):
            output = layer(
                x=x,
                attention_mask=attention_mask,
            )
            x = output["embeddings"]
            # Layers are numbered from 1 in the save requests and output keys.
            layer_number = layer_idx + 1
            # Save intermediate embeddings if needed
            if layer_number in self._config.embeddings_layers_to_save:
                outs[f"embeddings_{layer_number}"] = output["embeddings"]
            # Save intermediate attention maps if needed
            if layer_number in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[layer_number]:
                    dkey = f"attention_map_layer_{layer_number}_number_{map_number}"
                    # NOTE(review): index is map_number + 1 while __init__
                    # accepts map numbers up to attention_heads — confirm the
                    # intended head numbering.
                    outs[dkey] = output["attention_weights"][:, map_number + 1]

        return x, outs

    @hk.transparent
    def _attention_block(self, layer_idx: int) -> SelfAttentionBlock:
        """
        Build the self-attention block for one layer.

        Args:
            layer_idx: Zero-based layer index, used only in the module name.

        Returns:
            A ``SelfAttentionBlock`` named ``attention_layer_{layer_idx}``.
        """

        return SelfAttentionBlock(  # type: ignore
            num_heads=self._config.attention_heads,
            embed_dim=self._config.embed_dim,
            key_size=self._config.key_size,
            ffn_embed_dim=self._config.ffn_embed_dim,
            name=f"attention_layer_{layer_idx}",
        )

    def __call__(
        self,
        tokens: Tokens,
        attention_mask: Optional[AttentionMask] = None,
    ) -> TransformerOutput:
        """
        Computes the embeddings based on the input tokens.

        Args:
            tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
            attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
                If no mask is provided, a mask by default which equals 1 over all non
                pad tokens and 0 over pad tokens is computed.

        Returns:
            Dictionary containing the final embeddings and logits.

        Raises:
            AssertionError: If the sequence length exceeds
                ``config.max_positions``.
        """
        # Prepare outputs dict
        outs: Dict[str, jnp.ndarray] = {}

        # Compute embeddings
        x = self._embed_layer(tokens)
        # Tokens dropout if needed
        if self._config.token_dropout:
            x = TokensDropout(
                embed_dim=self._config.embed_dim,
                mask_token_id=self._config.mask_token_id,
                pad_token_id=self._config.pad_token_id,
                masking_ratio=self._config.masking_ratio,
                masking_prob=self._config.masking_prob,
            )(x, tokens)

        # RoBERTa's mask scaling factor
        x = self._config.embed_scale * x

        # Add check that the sequence fed into the transformer is not longer
        # than the max positions used to instantiate the learned positional
        # embeddings layer
        assert tokens.shape[1] <= self._config.max_positions, (
            "Inputs to the learned positional embeddings layer have a length "
            f"{x.shape[1]} greater than the max positions used to instantiate "
            f"it: {self._config.max_positions}"
        )

        # !! Removed pos emb
        # Positional Embedding
        # x = x + self._pos_embed_layer(tokens)

        if self._config.emb_layer_norm_before:
            x = self.emb_ln_before(x)

        # Attention mask
        if attention_mask is None:
            # NOTE(review): build_padding_attention_mask is neither defined nor
            # imported in the visible cells — confirm it is in scope before
            # calling the model without an explicit mask.
            attention_mask = build_padding_attention_mask(
                tokens=tokens, pad_token_id=self._config.pad_token_id
            )

        # construct a tower of attention layers
        x, outs = self.apply_attention_blocks(
            x=x,
            outs=outs,
            attention_mask=attention_mask,
        )

        # Language Model Head
        lm_head_outs = self._lm_head(x)
        outs["logits"] = lm_head_outs["logits"]

        embeddings = lm_head_outs["embeddings"]
        # Save final embeddings if needed. This replaces the entry stored for
        # the last layer inside apply_attention_blocks with the LM head's
        # embeddings.
        if self._config.num_layers in self._config.embeddings_layers_to_save:
            outs[f"embeddings_{self._config.num_layers}"] = embeddings

        return outs  # type: ignore