# CARP

In [1]:
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [2]:
%load_ext blackcellmagic

In [3]:
from sequence_models.pretrained import load_model_and_alphabet

In [17]:
model, collater = load_model_and_alphabet('carp_600k')

In [26]:
seqs = [df_val.sequence.astype(str).str[0 : 56].values[0]]
x = collater(seqs)[0]  # (n, max_len)
# rep = model(x)  # (n, max_len, d_model)

In [19]:
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

In [20]:
model.model.embedder.layers[0].register_forward_hook(get_activation("layer0"))

<torch.utils.hooks.RemovableHandle at 0x7f446f7dc280>

In [28]:
rep = model(x)

In [29]:
activation["layer0"]

tensor([[[-3.1794e-01,  1.1626e-01,  6.0831e-02,  7.6150e-02,  7.5652e-02,
           1.4893e-01,  1.9834e-02, -3.0836e-01,  2.5717e-01, -3.4265e-01,
           7.3477e-01,  1.3106e+00,  8.5214e-01, -2.4842e-01,  1.0407e+00,
           2.3907e+00,  2.6350e-01,  7.8502e-01,  1.1730e-01,  7.7125e-02,
          -1.5943e-01, -1.7217e-01, -7.6545e-01, -3.2751e+00, -4.3581e-01,
           1.2820e+00, -2.3568e-01,  6.7061e-02,  8.5233e-01,  1.5238e-01,
           3.1420e-01,  3.3563e-01,  7.9613e-02, -6.2080e-01,  2.9033e-01,
           5.0830e-01, -3.7001e-01, -2.1508e-01, -8.2633e-01,  5.3360e-01,
           2.3634e-02, -5.6023e-01, -1.3095e-01,  1.3029e-01, -6.4150e+00,
          -3.3860e-01, -3.2213e-02,  5.3964e-01, -5.5691e-01, -9.4326e-02,
          -4.9920e-01, -7.1323e-01, -2.7627e-01, -2.5626e-01, -4.8774e-02,
          -2.0168e+00, -3.7258e-01,  5.3224e-01, -1.0354e-02, -1.3404e+00,
          -1.0159e+00, -3.6293e-01,  1.7592e-01, -9.3403e-02, -4.4597e+00,
          -9.5616e-02, -1

In [13]:
"""Add encoding classes with class methods"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence

import numpy as np
from tqdm import tqdm

import torch
from sequence_models.pretrained import load_model_and_alphabet

from scr.params.emb import TRANSFORMER_INFO, TRANSFORMER_MAX_SEQ_LEN, CARP_INFO
from scr.params.sys import DEVICE


class AbstractEncoder(ABC):
    """
    An abstract encoder class to fill in for different kinds of encoders

    All encoders will have an "encode" function
    """

    def __init__(self, encoder_name: str, embed_layer: int):

        """
        Args:
        - encoder_name: str, the name of the encoder
        - embed_layer: int, the layer number of the embedding
        """

        self._encoder_name = encoder_name
        self._embed_layer = embed_layer

    def encode(
        self,
        mut_seqs: Sequence[str] | str,
        batch_size: int = 0,
        flatten_emb: bool | str = False,
        mut_names: Sequence[str] | str | None = None,
    ) -> Iterable[np.ndarray]:
        """
        A function takes a list of sequences to yield a batch of encoded elements

        Args:
        - mut_seqs: list of str or str, mutant sequences of the same length
        - batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - mut_names: list of str or str or None, mutant names
        """

        if isinstance(mut_seqs, str):
            mut_seqs = [mut_seqs]

        # If the batch size is 0, then encode all at once in a single batch
        if batch_size == 0:
            yield self._encode_batch(
                mut_seqs=mut_seqs, flatten_emb=flatten_emb, mut_names=mut_names
            )

        # Otherwise, yield chunks of encoded sequence
        else:

            for i in tqdm(range(0, len(mut_seqs), batch_size)):

                # figure out what mut_names to feed in
                if mut_names is None:
                    mut_name_batch = mut_names
                else:
                    mut_name_batch = mut_names[i : i + batch_size]

                yield self._encode_batch(
                    mut_seqs=mut_seqs[i : i + batch_size],
                    flatten_emb=flatten_emb,
                    mut_names=mut_name_batch,
                )

    def flatten_encode(
        self, encoded_mut_seqs: np.ndarray, flatten_emb: bool | str
    ) -> np.ndarray:
        """
        Flatten the embedding or just return the encoded mutants.

        Args:
        - encoded_mut_seqs: np.ndarray, shape [batch_size, seq_len, embed_dim]
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
            - True -> shape [batch_size, seq_len * embed_dim]
            - "max" or "mean" -> shape [batch_size, embed_dim]
            - False or everything else -> [batch_size, seq_len, embed_dim]

        Returns:
        - np.ndarray, shape depends on flatten_emb parameter
        """
        assert encoded_mut_seqs.shape[2] == self._embed_dim, "Wrong embed dim"

        if flatten_emb is True:
            # shape [batch_size, seq_len * embed_dim]
            return encoded_mut_seqs.reshape(encoded_mut_seqs.shape[0], -1)

        elif isinstance(flatten_emb, str):
            if flatten_emb == "mean":
                # [batch_size, embed_dim]
                return encoded_mut_seqs.mean(axis=1)
            elif flatten_emb == "max":
                # [batch_size, embed_dim]
                return encoded_mut_seqs.max(axis=1)

        else:
            print("No embedding flattening")
            # [batch_size, seq_len, embed_dim]
            return encoded_mut_seqs

    @abstractmethod
    def _encode_batch(
        mut_seqs: Sequence[str] | str,
        flatten_emb: bool | str,
        mut_names: Sequence[str] | str | None = None,
    ) -> np.ndarray:
        """
        Encode a single batch of mut_seqs
        """
        pass

    @property
    def embed_dim(self) -> int:
        """The dim of the embedding"""
        return self._embed_dim

    @property
    def embed_layer(self) -> int:
        """The layer nubmer of the embedding"""
        return self._embed_layer

    @property
    def encoder_name(self) -> str:
        """The name of the encoding method"""
        return self._encoder_name


class ESMEncoder(AbstractEncoder):
    """
    Build an ESM encoder
    """

    def __init__(
        self,
        encoder_name: str,
        embed_layer: int,
        iftrimCLS: bool = True,
        iftrimEOS: bool = True,
    ):
        """
        Args
        - encoder_name: str, the name of the encoder, one of the keys of TRANSFORMER_INFO
        - embed_layer: int, the layer number of the embedding
        - iftrimCLS: bool, whether to trim the first classifification token
        - iftrimEOS: bool, whether to trim the end of sequence token, if exists
        """

        super().__init__(encoder_name, embed_layer)

        self._iftrimCLS = iftrimCLS
        self._iftrimEOS = iftrimEOS

        # load model from torch.hub
        print(f"Loading {self._encoder_name} using {self._embed_layer} layer embedding")
        self.model, self.alphabet = torch.hub.load(
            "facebookresearch/esm:main", model=self._encoder_name
        )
        self.batch_converter = self.alphabet.get_batch_converter()

        # set model to eval mode
        self.model.eval()
        self.model.to(DEVICE)

        self._embed_dim, self._max_emb_layer, _ = TRANSFORMER_INFO[self._encoder_name]

        assert (
            self._embed_layer <= self._max_emb_layer
        ), f"{self._embed_layer} exceeds {self._max_emb_layer}"

        expected_num_layers = int(self._encoder_name.split("_")[-3][1:])
        assert (
            expected_num_layers == self._max_emb_layer
        ), "Wrong ESM model name or layer"

    def _encode_batch(
        self,
        mut_seqs: Sequence[str] | str,
        flatten_emb: bool | str,
        mut_names: Sequence[str] | str | None = None,
    ) -> np.ndarray:
        """
        Encodes a batch of mutant sequences.

        Args:
        - mut_seqs: list of str or str, mutant sequences of the same length
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - mut_names: list of str or str or None, mutant names

        Returns:
        - np.ndarray or a tuple(np.ndarray, list[str]) where the list is batch_labels
        """

        if isinstance(mut_names, str):
            mut_names = [mut_names]

        # pair the mut_names and mut_seqs
        if mut_names is not None:
            assert len(mut_names) == len(
                mut_seqs
            ), "mutant_name and mut_seqs different length"
            mut_seqs = [(n, m) for (n, m) in zip(mut_names, mut_seqs)]
        else:
            mut_seqs = [("", m) for m in mut_seqs]

        # convert raw mutant sequences to tokens
        batch_labels, _, batch_tokens = self.batch_converter(mut_seqs)
        batch_tokens = batch_tokens.to(DEVICE)

        # Turn off gradients and pass the batch through
        with torch.no_grad():
            # shape [batch_size, seq_len + pad, embed_dim]
            if batch_tokens.shape[1] > TRANSFORMER_MAX_SEQ_LEN:
                print(f"Sequence exceeds {TRANSFORMER_MAX_SEQ_LEN}, chopping the end")
                batch_tokens = batch_tokens[:, :TRANSFORMER_MAX_SEQ_LEN]

            encoded_mut_seqs = (
                self.model(batch_tokens, repr_layers=[self._embed_layer])[
                    "representations"
                ][self._embed_layer]
                .cpu()
                .numpy()
            )

        # https://github.com/facebookresearch/esm/blob/main/esm/data.py
        # from_architecture

        # trim off initial classification token [CLS]
        # both "ESM-1" and "ESM-1b" have prepend_bos = True
        if self._iftrimCLS and self._encoder_name.split("_")[0] in ["esm1", "esm1b"]:
            encoded_mut_seqs = encoded_mut_seqs[:, 1:, :]

        # trim off end-of-sequence token [EOS]
        # only "ESM-1b" has append_eos = True
        if self._iftrimEOS and self._encoder_name.split("_")[0] == "esm1b":
            encoded_mut_seqs = encoded_mut_seqs[:, :-1, :]

        if mut_names is not None:
            return self.flatten_encode(encoded_mut_seqs, flatten_emb), batch_labels
        else:
            return self.flatten_encode(encoded_mut_seqs, flatten_emb)


class CARPEncoder(AbstractEncoder):
    """
    Build a CARP encoder
    """

    def __init__(
        self,
        encoder_name: str,
        embed_layer: int,
    ):
        """
        Args
        - encoder_name: str, the name of the encoder, one of the keys of CARP_INFO
        - embed_layer: int, the layer number of the embedding
        """

        super().__init__(encoder_name, embed_layer)

        # load model from torch.hub
        print(f"Loading {self._encoder_name} using {self._embed_layer} layer embedding")

        self.model, self.collater = load_model_and_alphabet(self._encoder_name)

        # set model to eval mode
        self.model.eval()
        self.model.to(DEVICE)

        self._embed_dim, self._max_emb_layer = CARP_INFO[self._encoder_name]

        assert (
            self._embed_layer <= self._max_emb_layer
        ), f"{self._embed_layer} exceeds {self._max_emb_layer}"

    def _encode_batch(
        self,
        mut_seqs: Sequence[str] | str,
        flatten_emb: bool | str,
        mut_names: Sequence[str] | str | None = None,
    ) -> np.ndarray:
        """
        Encodes a batch of mutant sequences.

        Args:
        - mut_seqs: list of str or str, mutant sequences of the same length
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - mut_names: list of str or str or None, mutant names

        Returns:
        - np.ndarray or a tuple(np.ndarray, list[str]) where the list is batch_labels
        """

        mut_seqs = [[m] for m in mut_seqs]
        
        x = self.collater(mut_seqs)[0]

        layer_name = f"layer{str(self._embed_layer)}"

        activation = {}
        def get_activation(name):
            def hook(model, input, output):
                activation[name] = output.detach()

            return hook

        # convert raw mutant sequences to tokens
        self.model.model.embedder.layers[self._embed_layer].register_forward_hook(
            get_activation(layer_name)
        )

        rep = self.model(x)
        
        encoded_mut_seqs = activation[layer_name].cpu().numpy()

        return self.flatten_encode(encoded_mut_seqs, flatten_emb)

In [5]:
from scr.utils import pickle_load

In [6]:
df = pickle_load("data/proeng/gb1/two_vs_rest.pkl")

In [7]:
df_train = df.loc[(df["set"] == "train") & (df["validation"] != True)]
df_val = df.loc[(df["set"] == "train") & (df["validation"] == True)]
df_test = df.loc[(df["set"] == "test")]

len(df_train), len(df_val), len(df_test), len(df)

(381, 43, 8309, 8733)

In [8]:
CARP_INFO.keys()

dict_keys(['carp_600k', 'carp_38M', 'carp_76M', 'carp_640M'])

In [14]:
no_flat_encoder = CARPEncoder(
    encoder_name="carp_600k",
    embed_layer=0,
).encode(mut_seqs=list(df_val.sequence.astype(str).str[0 : 56].values))
one_emb = next(no_flat_encoder)
one_emb.shape

Loading carp_600k using 0 layer embedding
No embedding flattening


(43, 56, 128)

In [11]:
mean_flat_encoder = CARPEncoder(
    encoder_name="carp_600k",
    embed_layer=0,
).encode(mut_seqs=list(df_val.sequence.astype(str).str[0 : 56].values),flatten_emb="mean")
one_mean_emb = next(mean_flat_encoder)
one_mean_emb.shape

Loading carp_600k using 0 layer embedding


(43, 128)

In [30]:
one_emb[0]

array([[ 0.11233792, -0.03096455,  0.09588867, ..., -0.2151581 ,
        -0.41078562,  0.8417188 ],
       [ 0.2050094 , -0.14373404, -0.15035413, ...,  0.13152504,
        -0.12934083,  0.50128305],
       [ 0.4236132 , -0.09426624,  0.04951383, ...,  0.22084078,
        -3.9057856 ,  0.79270667],
       ...,
       [ 0.31486923, -0.00696398,  0.12247562, ..., -0.05261476,
        -4.633561  ,  0.5593643 ],
       [ 0.27489296, -0.01615845,  0.14888933, ..., -0.03519906,
         0.37774575,  0.5697537 ],
       [ 0.34990993,  0.256608  ,  0.21758929, ...,  0.10277098,
        -3.0541036 ,  0.31116045]], dtype=float32)

Bad pipe message: %s [b'"\x19\xfc']
Bad pipe message: %s [b"z\xb5\xbb\xa8z\xf3a\x80\xb0\xa4^/\xbb\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00<\x005\x00/\x00\x9a\x00\x99\xc0\x07\xc0\x11\x00\x96\x00\x05\x00\xff\x01\x00\x00j\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00", b'#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x000\x00.\x04\x03\x05\x03\x06\x03']
Bad pipe message: %s [b'\x08\x08\x08\t\x08\n\x08', b'\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06']
Bad pipe message: %s [b'', b'\x03\x03']
Bad pipe message: %s [b'']
Bad pipe message: %s [b'', b'\x02']
Bad pipe message: %s [b'\x05\x02\x06']
Bad pipe message: %s