# CARP

In [1]:
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [2]:
%load_ext blackcellmagic

In [16]:
"""Add encoding classes with class methods"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence

import numpy as np
from tqdm import tqdm

import torch
from sequence_models.pretrained import load_model_and_alphabet

from scr.params.emb import TRANSFORMER_INFO, TRANSFORMER_MAX_SEQ_LEN, CARP_INFO
from scr.params.sys import DEVICE


class AbstractEncoder(ABC):
    """
    An abstract encoder class to fill in for different kinds of encoders

    All encoders will have an "encode" function.

    Subclasses must implement "_encode_batch" and must set "self._embed_dim",
    which is read by "flatten_encode" and the "embed_dim" property but is
    not assigned in this base class.
    """

    def __init__(self, encoder_name: str, embed_layer: int):

        """
        Args:
        - encoder_name: str, the name of the encoder
        - embed_layer: int, the layer number of the embedding
        """

        self._encoder_name = encoder_name
        self._embed_layer = embed_layer

    def encode(
        self,
        mut_seqs: Sequence[str] | str,
        batch_size: int = 0,
        flatten_emb: bool | str = False,
        mut_names: Sequence[str] | str | None = None,
    ) -> Iterable[np.ndarray]:
        """
        A function takes a list of sequences to yield a batch of encoded elements

        This is a generator: nothing is encoded until it is iterated.

        Args:
        - mut_seqs: list of str or str, mutant sequences of the same length
        - batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - mut_names: list of str or str or None, mutant names

        Yields:
        - whatever "_encode_batch" returns for each batch
        """

        if isinstance(mut_seqs, str):
            mut_seqs = [mut_seqs]

        # A single name must be wrapped as well: slicing a bare str below
        # would cut the name itself into characters instead of selecting names
        if isinstance(mut_names, str):
            mut_names = [mut_names]

        # If the batch size is 0, then encode all at once in a single batch
        if batch_size == 0:
            yield self._encode_batch(
                mut_seqs=mut_seqs, flatten_emb=flatten_emb, mut_names=mut_names
            )

        # Otherwise, yield chunks of encoded sequence
        else:

            for i in tqdm(range(0, len(mut_seqs), batch_size)):

                # figure out what mut_names to feed in
                if mut_names is None:
                    mut_name_batch = None
                else:
                    mut_name_batch = mut_names[i : i + batch_size]

                yield self._encode_batch(
                    mut_seqs=mut_seqs[i : i + batch_size],
                    flatten_emb=flatten_emb,
                    mut_names=mut_name_batch,
                )

    def flatten_encode(
        self, encoded_mut_seqs: np.ndarray, flatten_emb: bool | str
    ) -> np.ndarray:
        """
        Flatten the embedding or just return the encoded mutants.

        Args:
        - encoded_mut_seqs: np.ndarray, shape [batch_size, seq_len, embed_dim]
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
            - True -> shape [batch_size, seq_len * embed_dim]
            - "max" or "mean" -> shape [batch_size, embed_dim]
            - False or everything else -> [batch_size, seq_len, embed_dim]

        Returns:
        - np.ndarray, shape depends on flatten_emb parameter
        """
        assert encoded_mut_seqs.shape[2] == self._embed_dim, "Wrong embed dim"

        if flatten_emb is True:
            # shape [batch_size, seq_len * embed_dim]
            return encoded_mut_seqs.reshape(encoded_mut_seqs.shape[0], -1)

        if isinstance(flatten_emb, str):
            if flatten_emb == "mean":
                # [batch_size, embed_dim]
                return encoded_mut_seqs.mean(axis=1)
            if flatten_emb == "max":
                # [batch_size, embed_dim]
                return encoded_mut_seqs.max(axis=1)

        # False, or any unrecognised value (including unknown strings, which
        # previously fell through and returned None): leave the array as is
        print("No embedding flattening")
        # [batch_size, seq_len, embed_dim]
        return encoded_mut_seqs

    @abstractmethod
    def _encode_batch(
        self,
        mut_seqs: Sequence[str] | str,
        flatten_emb: bool | str,
        mut_names: Sequence[str] | str | None = None,
    ) -> np.ndarray:
        """
        Encode a single batch of mut_seqs

        Args:
        - mut_seqs: list of str or str, mutant sequences of the same length
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - mut_names: list of str or str or None, mutant names
        """
        pass

    @property
    def embed_dim(self) -> int:
        """The dim of the embedding"""
        return self._embed_dim

    @property
    def embed_layer(self) -> int:
        """The layer number of the embedding"""
        return self._embed_layer

    @property
    def encoder_name(self) -> str:
        """The name of the encoding method"""
        return self._encoder_name


class ESMEncoder(AbstractEncoder):
    """
    Encoder backed by a pretrained ESM model fetched through torch.hub
    """

    def __init__(
        self,
        encoder_name: str,
        embed_layer: int,
        iftrimCLS: bool = True,
        iftrimEOS: bool = True,
    ):
        """
        Load the named ESM model and check the requested layer against it.

        Args
        - encoder_name: str, the name of the encoder, one of the keys of TRANSFORMER_INFO
        - embed_layer: int, the layer number of the embedding
        - iftrimCLS: bool, whether to trim the first classification token
        - iftrimEOS: bool, whether to trim the end of sequence token, if exists
        """

        super().__init__(encoder_name, embed_layer)

        self._iftrimCLS = iftrimCLS
        self._iftrimEOS = iftrimEOS

        # fetch the model and its alphabet from torch.hub
        print(f"Loading {self._encoder_name} using {self._embed_layer} layer embedding")
        hub_model, hub_alphabet = torch.hub.load(
            "facebookresearch/esm:main", model=self._encoder_name
        )
        self.model = hub_model
        self.alphabet = hub_alphabet
        self.batch_converter = self.alphabet.get_batch_converter()

        # inference only: eval mode, on the configured device
        self.model.eval()
        self.model.to(DEVICE)

        model_info = TRANSFORMER_INFO[self._encoder_name]
        self._embed_dim, self._max_emb_layer, _ = model_info

        assert (
            self._embed_layer <= self._max_emb_layer
        ), f"{self._embed_layer} exceeds {self._max_emb_layer}"

        # the third-from-last "_" separated field of the name is read as
        # "<one char><layer count>" and must agree with TRANSFORMER_INFO
        layer_field = self._encoder_name.split("_")[-3]
        expected_num_layers = int(layer_field[1:])
        assert (
            expected_num_layers == self._max_emb_layer
        ), "Wrong ESM model name or layer"

    def _encode_batch(
        self,
        mut_seqs: Sequence[str] | str,
        flatten_emb: bool | str,
        mut_names: Sequence[str] | str | None = None,
    ) -> np.ndarray:
        """
        Run one batch of mutant sequences through the model.

        Args:
        - mut_seqs: list of str or str, mutant sequences of the same length
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - mut_names: list of str or str or None, mutant names

        Returns:
        - np.ndarray when mut_names is None, otherwise
            a tuple(np.ndarray, list[str]) where the list is batch_labels
        """

        if isinstance(mut_names, str):
            mut_names = [mut_names]

        # the batch converter takes (name, sequence) pairs
        if mut_names is None:
            labeled_seqs = [("", seq) for seq in mut_seqs]
        else:
            assert len(mut_names) == len(
                mut_seqs
            ), "mutant_name and mut_seqs different length"
            labeled_seqs = list(zip(mut_names, mut_seqs))

        # tokenize and move the tokens next to the model
        batch_labels, _, batch_tokens = self.batch_converter(labeled_seqs)
        batch_tokens = batch_tokens.to(DEVICE)

        # no gradients are needed for embedding extraction
        with torch.no_grad():
            too_long = batch_tokens.shape[1] > TRANSFORMER_MAX_SEQ_LEN
            if too_long:
                print(f"Sequence exceeds {TRANSFORMER_MAX_SEQ_LEN}, chopping the end")
                batch_tokens = batch_tokens[:, :TRANSFORMER_MAX_SEQ_LEN]

            model_out = self.model(batch_tokens, repr_layers=[self._embed_layer])
            # shape [batch_size, seq_len + pad, embed_dim]
            layer_repr = model_out["representations"][self._embed_layer]
            encoded_mut_seqs = layer_repr.cpu().numpy()

        # https://github.com/facebookresearch/esm/blob/main/esm/data.py
        # from_architecture
        model_family = self._encoder_name.split("_")[0]

        # drop the leading classification token [CLS]
        # both "ESM-1" and "ESM-1b" have prepend_bos = True
        if self._iftrimCLS and model_family in ["esm1", "esm1b"]:
            encoded_mut_seqs = encoded_mut_seqs[:, 1:, :]

        # drop the trailing end-of-sequence token [EOS]
        # only "ESM-1b" has append_eos = True
        if self._iftrimEOS and model_family == "esm1b":
            encoded_mut_seqs = encoded_mut_seqs[:, :-1, :]

        flattened = self.flatten_encode(encoded_mut_seqs, flatten_emb)

        if mut_names is None:
            return flattened
        return flattened, batch_labels


class CARPEncoder(AbstractEncoder):
    """
    Build a CARP encoder
    """

    def __init__(
        self,
        encoder_name: str,
        embed_layer: int,
    ):
        """
        Args
        - encoder_name: str, the name of the encoder, one of the keys of CARP_INFO
        - embed_layer: int, the layer number of the embedding
        """

        super().__init__(encoder_name, embed_layer)

        # load the pretrained model and its collater with sequence_models
        print(f"Loading {self._encoder_name} using {self._embed_layer} layer embedding")

        self.model, self.collater = load_model_and_alphabet(self._encoder_name)

        # set model to eval mode
        self.model.eval()
        self.model.to(DEVICE)

        self._embed_dim, self._max_emb_layer = CARP_INFO[self._encoder_name]

        assert (
            self._embed_layer <= self._max_emb_layer
        ), f"{self._embed_layer} exceeds {self._max_emb_layer}"

    def _encode_batch(
        self,
        mut_seqs: Sequence[str] | str,
        flatten_emb: bool | str,
        mut_names: Sequence[str] | str | None = None,
    ) -> np.ndarray:
        """
        Encodes a batch of mutant sequences.

        The embedding is the output of embedder layer number "embed_layer",
        captured with a temporary forward hook during one forward pass.

        Args:
        - mut_seqs: list of str or str, mutant sequences of the same length
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - mut_names: list of str or str or None, mutant names;
            accepted to match the other encoders but not used here

        Returns:
        - np.ndarray, shape depends on flatten_emb (see flatten_encode)
        """

        # a bare str would otherwise be iterated character by character
        if isinstance(mut_seqs, str):
            mut_seqs = [mut_seqs]

        # the collater is fed one single-element list per sequence
        collated_input = [[m] for m in mut_seqs]

        # keep the tokens on the same device the model was moved to
        x = self.collater(collated_input)[0].to(DEVICE)

        layer_name = f"layer{self._embed_layer}"

        activation = {}

        def hook(model, input, output):
            activation[layer_name] = output.detach()

        target_layer = self.model.model.embedder.layers[self._embed_layer]
        handle = target_layer.register_forward_hook(hook)

        try:
            # only the hooked activation is needed, so skip autograd and
            # discard the model output itself
            with torch.no_grad():
                self.model(x)
        finally:
            # remove the hook so repeated calls do not pile up hooks on the layer
            handle.remove()

        encoded_mut_seqs = activation[layer_name].cpu().numpy()

        return self.flatten_encode(encoded_mut_seqs, flatten_emb)

In [3]:
from scr.encoding.encoding_classes import CARPEncoder

In [4]:
from scr.utils import pickle_load

In [5]:
# Load the pickled GB1 "two_vs_rest" data; the cells below index it with
# .loc and read its "set", "validation" and "sequence" columns.
df = pickle_load("data/proeng/gb1/two_vs_rest.pkl")

In [6]:
# Row masks shared by the splits below
in_train_set = df["set"] == "train"
in_validation = df["validation"] == True
not_in_validation = df["validation"] != True

# "train" rows are divided by the validation flag; "test" rows stand alone
df_train = df.loc[in_train_set & not_in_validation]
df_val = df.loc[in_train_set & in_validation]
df_test = df.loc[df["set"] == "test"]

# Show the size of each split next to the full table
len(df_train), len(df_val), len(df_test), len(df)

(381, 43, 8309, 8733)

In [8]:

# Take the first validation sequence, cut to its first 56 characters,
# and wrap it in a one-element list to use as encoder input.
seqs = [df_val.sequence.astype(str).str[0 : 56].values[0]]
seqs

['MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVNGEWTYDDATKTFTVTE']

In [9]:
# Encode the single sequence with layer 0 of "carp_600k", without flattening.
# encode() is consumed with next(), i.e. it is used as an iterator of batches;
# with the default arguments the first batch is presumably the only one.
no_flat_encoder = CARPEncoder(
    encoder_name="carp_600k",
    embed_layer=0,
).encode(mut_seqs=seqs)
one_emb = next(no_flat_encoder)
# Display the embedding together with its shape
one_emb, one_emb.shape

Loading carp_600k using 0 layer embedding
No embedding flattening


(array([[[ 0.11233783, -0.03096454,  0.09588867, ..., -0.21515825,
          -0.41078544,  0.8417189 ],
         [ 0.20500943, -0.14373404, -0.15035412, ...,  0.13152511,
          -0.12934107,  0.50128317],
         [ 0.4236132 , -0.09426615,  0.04951378, ...,  0.22084077,
          -3.9057853 ,  0.7927066 ],
         ...,
         [ 0.31486928, -0.00696401,  0.12247568, ..., -0.05261481,
          -4.633561  ,  0.55936426],
         [ 0.27489296, -0.01615845,  0.14888927, ..., -0.03519902,
           0.37774587,  0.5697537 ],
         [ 0.34991002,  0.25660813,  0.21758929, ...,  0.10277108,
          -3.0541043 ,  0.31116033]]], dtype=float32),
 (1, 56, 128))

In [13]:
# Same encoder settings as the unflattened run, but asking encode() to
# flatten with "mean"; next() pulls the first batch from the iterator.
mean_flat_encoder = CARPEncoder(
    encoder_name="carp_600k",
    embed_layer=0,
).encode(mut_seqs=seqs,flatten_emb="mean")
one_mean_emb = next(mean_flat_encoder)
# Display the flattened embedding together with its shape
one_mean_emb, one_mean_emb.shape

Loading carp_600k using 0 layer embedding


(array([[ 1.64436132e-01,  6.57109246e-02,  8.94564912e-02,
          3.09289932e-01,  1.32049993e-01,  1.55501708e-01,
          1.28102107e-02,  1.46981403e-01, -9.05814171e-02,
         -1.82426244e-01,  1.15854070e-01,  4.80966233e-02,
          2.20418856e-01,  2.25005895e-01,  9.72223431e-02,
          5.32191932e-01,  7.49111250e-02,  3.54565412e-01,
          7.33187646e-02,  1.70479506e-01,  3.80291134e-01,
         -8.53402689e-02, -6.54226467e-02,  3.57502520e-01,
          1.94397107e-01, -5.21607280e-01, -2.41224188e-02,
          3.47265989e-01, -1.94647282e-01,  5.34702502e-02,
          3.03542484e-02,  7.43700191e-02,  6.19242154e-02,
         -1.18623592e-01,  1.64853826e-01,  9.36362669e-02,
         -9.65878293e-02,  1.05360612e-01,  5.56887463e-02,
          1.70761794e-01,  1.93100184e-01,  1.66344896e-01,
         -5.06169535e-02, -2.64716838e-02, -1.86315969e-01,
          9.10815150e-02, -3.17426980e-03, -1.68990362e-02,
         -2.34090220e-02,  1.71465054e-0

In [15]:
# Manual check: capture the layer-0 embedder output directly with a forward
# hook, for comparison with the CARPEncoder result shown earlier.
model, collater = load_model_and_alphabet('carp_600k')

# seqs is a one-element list, so [seqs] is a list holding one single-sequence list
x = collater([seqs])[0]  # (n, max_len)
# rep = model(x)  # (n, max_len, d_model)

# Filled in by the hook below, keyed by the name given to get_activation
activation = {}
def get_activation(name):
    """Return a forward hook that stores the detached layer output under `name`."""
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

# NOTE(review): the hook is never removed; here the model is reloaded at the
# top of the cell, so re-running the cell starts from a fresh model.
model.model.embedder.layers[0].register_forward_hook(get_activation("layer0"))

# The forward pass triggers the hook; rep itself is not used afterwards in this cell
rep = model(x)

# Display the captured activation together with its shape
activation["layer0"], activation["layer0"].shape

(tensor([[[ 0.1123, -0.0310,  0.0959,  ..., -0.2152, -0.4108,  0.8417],
          [ 0.2050, -0.1437, -0.1504,  ...,  0.1315, -0.1293,  0.5013],
          [ 0.4236, -0.0943,  0.0495,  ...,  0.2208, -3.9058,  0.7927],
          ...,
          [ 0.3149, -0.0070,  0.1225,  ..., -0.0526, -4.6336,  0.5594],
          [ 0.2749, -0.0162,  0.1489,  ..., -0.0352,  0.3777,  0.5698],
          [ 0.3499,  0.2566,  0.2176,  ...,  0.1028, -3.0541,  0.3112]]]),
 torch.Size([1, 56, 128]))