In [1]:
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [2]:
%load_ext blackcellmagic

In [3]:
"""Pre-processing the dataset"""

from __future__ import annotations

from collections import Sequence

import os
from glob import glob
import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader

from scr.utils import pickle_save, pickle_load, replace_ext
from scr.params.sys import RAND_SEED
from scr.params.emb import TRANSFORMER_INFO
from scr.preprocess.seq_loader import SeqLoader
from scr.encoding.encoding_classes import AbstractEncoder, ESMEncoder


def get_mut_name(mut_seq: str, parent_seq: str) -> str:
    """
    Name a variant by comparing it position-by-position with its parent

    Args:
    - mut_seq: str, the full mutant sequence
    - parent_seq: str, the full parent sequence

    Returns:
    - str, one of
        "parent" if the two sequences are identical,
        "indel" if the two sequences differ in length,
        otherwise the substitutions joined by ":" in the format of
        ParentAAMutLocMutAA (1-indexed location), ie. V39L:D40A
    """

    # identical sequences carry no mutation
    if parent_seq == mut_seq:
        return "parent"

    # a length mismatch cannot be described as substitutions only
    if len(parent_seq) != len(mut_seq):
        return "indel"

    # same length: report every position where the residue changed
    return ":".join(
        f"{parent_aa}{loc}{mut_aa}"
        for loc, (parent_aa, mut_aa) in enumerate(zip(parent_seq, mut_seq), start=1)
        if parent_aa != mut_aa
    )


class AddMutInfo:
    """
    Annotate a dataset csv with mutation names and mutation counts

    On construction the csv is read, "mut_name" and "mut_numb" columns are
    added to a copy of it, and the annotated copy is pickled next to the csv
    (same path with the extension replaced by ".pkl").
    """

    def __init__(self, parent_seq_path: str, csv_path: str):
        """
        Args:
        - parent_seq_path: str, path to the fasta file of the parent sequence
        - csv_path: str, path to the csv with a "sequence" column
        """

        # Parent sequence as returned by the project SeqLoader
        # NOTE(review): it is compared against str sequences in get_mut_name,
        # so this presumably behaves like a str - confirm against SeqLoader
        self._parent_seq = SeqLoader(parent_seq_path=parent_seq_path)

        # keep the untouched csv content and annotate a copy of it
        self._init_df = pd.read_csv(csv_path)
        annotated_df = self._init_df.copy()

        # mutant name of every sequence relative to the parent
        annotated_df["mut_name"] = self._init_df["sequence"].apply(
            lambda seq: get_mut_name(seq, parent_seq=self._parent_seq)
        )

        # number of ":"-separated entries in the mutant name
        # NOTE(review): "parent" and "indel" therefore count as 1
        name_parts = annotated_df["mut_name"].str.split(":")
        annotated_df["mut_numb"] = name_parts.map(len, na_action="ignore")

        self._df = annotated_df

        # the pickle lives next to the csv, only the extension differs
        self._pkl_path = replace_ext(input_path=csv_path, ext=".pkl")
        pickle_save(what2save=self._df, where2save=self._pkl_path)

    @property
    def parent_seq(self) -> str:
        """The parent sequence loaded from the fasta file"""
        return self._parent_seq

    @property
    def pkl_path(self) -> str:
        """Where the annotated dataframe was pickled"""
        return self._pkl_path

    @property
    def df(self) -> pd.DataFrame:
        """The annotated dataframe with mut_name and mut_numb columns"""
        return self._df


class TaskProcess:
    """A class for handling different downstream tasks"""

    def __init__(self, data_folder: str = "data/"):
        """
        Args:
        - data_folder: str, a folder path with all the tasks as subfolders where
            all the subfolders have datasets as the subsubfolders, ie

            {data_folder}/
                proeng/
                    aav/
                        one_vs_many.csv
                        two_vs_many.csv
                        P03135.fasta
                    thermo/
                        mixed.csv
        """

        # guarantee a single trailing separator without indexing into the
        # string, so an empty data_folder does not raise IndexError
        self._data_folder = os.path.join(data_folder, "")

        # summarize all files in the data folder
        self._sum_file_df = self.sum_files()

    def sum_files(self) -> pd.DataFrame:
        """
        Summarize all files in the data folder

        For every non-structure dataset folder with a parent fasta, each csv
        that does not have its own pkl yet is processed with AddMutInfo so
        that every returned pkl_path points at a file that exists.

        Returns:
        - A dataframe with "task", "dataset", "split", "csv_path", "fasta_path", "pkl_path" as columns, ie.
            (proeng, gb1, low_vs_high, data/proeng/gb1/low_vs_high.csv, data/proeng/gb1/5LDE_1.fasta)
            note that csv_path is the list of lmdb files for the structure task
            and that fasta_path and pkl_path are "" when there is no parent fasta

        Raises:
        - AssertionError if a non-structure dataset folder has no csv
            or more than one fasta
        """
        # only folders are datasets; sorted for a reproducible row order
        dataset_folders = sorted(
            folder
            for folder in glob(f"{self._data_folder}*/*")
            if os.path.isdir(folder)
        )

        # need a list of tuples in the order of:
        # (task, dataset, split, csv_path, fasta_path, pkl_path)
        list_for_df = []
        for dataset_folder in dataset_folders:
            # the last two path components are the task and the dataset,
            # no matter how deep the data folder itself is nested
            task, dataset = os.path.normpath(dataset_folder).split(os.sep)[-2:]

            if task == "structure":
                structure_file_list = [
                    file_path
                    for file_path in glob(f"{dataset_folder}/*.*")
                    if os.path.basename(os.path.splitext(file_path)[0]).split("_")[-1]
                    in ["train", "valid", "cb513"]
                ]
                list_for_df.append(
                    tuple([task, dataset, "cb513", structure_file_list, "", ""])
                )
                continue

            csv_paths = sorted(glob(f"{dataset_folder}/*.csv"))
            fasta_paths = glob(f"{dataset_folder}/*.fasta")

            assert len(csv_paths) >= 1, "Less than one csv"
            assert len(fasta_paths) <= 1, "More than one fasta"

            for csv_path in csv_paths:
                # if parent seq fasta exists
                if len(fasta_paths) == 1:
                    fasta_path = fasta_paths[0]
                    pkl_path = replace_ext(input_path=csv_path, ext=".pkl")

                    # check the pkl of THIS csv, a pkl of a sibling split
                    # must not suppress the generation
                    if not os.path.isfile(pkl_path):
                        print(f"Adding mutation info to {csv_path}...")
                        pkl_path = AddMutInfo(
                            parent_seq_path=fasta_path, csv_path=csv_path
                        ).pkl_path
                # no parent fasta no pkl file
                else:
                    fasta_path = ""
                    pkl_path = ""

                list_for_df.append(
                    tuple(
                        [
                            task,
                            dataset,
                            os.path.basename(os.path.splitext(csv_path)[0]),
                            csv_path,
                            fasta_path,
                            pkl_path,
                        ]
                    )
                )

        return pd.DataFrame(
            list_for_df,
            columns=["task", "dataset", "split", "csv_path", "fasta_path", "pkl_path"],
        )

    @property
    def sum_file_df(self) -> pd.DataFrame:
        """A summary table for all files in the data folder"""
        return self._sum_file_df


class ProtranDataset(Dataset):

    """A dataset class for processing protein transfer data"""

    def __init__(
        self,
        dataset_path: str,
        subset: str,
        encoder_class: AbstractEncoder,
        encoder_name: str,
        embed_layer: int,
        embed_batch_size: int = 0,
        flatten_emb: bool | str = False,
        embed_path: str = None,
        seq_start_idx: bool | int = False,
        seq_end_idx: bool | int = False,
        **encoder_params,
    ):

        """
        Args:
        - dataset_path: str, full path to the dataset, in pkl or panda readable format
            columns include: sequence, target, set, validation, mut_name (optional), mut_numb (optional)
        - subset: str, train, val, test
        - encoder_class: AbstractEncoder, the encoder class
        - encoder_name: str, the name of the encoder
        - embed_layer: int, the layer number of the embedding
        - embed_batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - embed_path: str = None, path to presaved embedding
        - seq_start_idx: bool | int = False, the index for the start of the sequence,
            False (or 0) starts from the first residue
        - seq_end_idx: bool | int = False, the index for the end of the sequence (exclusive),
            False (or 0) keeps the sequence up to and including its last residue
        - encoder_params: kwarg, additional parameters for encoding
        """

        # pickled dataframes may carry the additional mut_name, mut_numb info
        is_pickled = os.path.splitext(dataset_path)[-1] in [".pkl", ".PKL", ""]
        if is_pickled:
            self._df = pickle_load(dataset_path)
        else:
            self._df = pd.read_csv(dataset_path)

        # only expose the mutation info when both columns are really there,
        # otherwise fall back to the placeholders used for csv input
        self._add_mut_info = is_pickled and {"mut_name", "mut_numb"} <= set(
            self._df.columns
        )

        assert "set" in self._df.columns, f"set is not a column in {dataset_path}"
        assert (
            "validation" in self._df.columns
        ), f"validation is not a column in {dataset_path}"

        # validation rows are the rows of the train set flagged with validation
        is_train_set = self._df["set"] == "train"
        is_validation = self._df["validation"] == True

        self._df_train = self._df.loc[is_train_set & ~is_validation]
        self._df_val = self._df.loc[is_train_set & is_validation]
        self._df_test = self._df.loc[(self._df["set"] == "test")]

        self._df_dict = {
            "train": self._df_train,
            "val": self._df_val,
            "test": self._df_test,
        }

        assert subset in list(
            self._df_dict.keys()
        ), "split can only be 'train', 'val', or 'test'"
        self._subset = subset

        self._subdf_len = len(self._df_dict[self._subset])

        # not specified seq start will be from 0
        self._seq_start_idx = int(seq_start_idx) if seq_start_idx else 0

        # not specified seq end will be the full sequence length
        # None (not -1) is the open slice end, -1 would drop the last residue
        self._seq_end_idx = int(seq_end_idx) if seq_end_idx else None

        # get unencoded string of input sequence
        # will need to convert data type
        self.sequence = self._get_column_value("sequence")

        # check if pregenerated embedding
        if embed_path is not None:
            print(f"Loading pregenerated embeddings from {embed_path}")
            # NOTE(review): loaded as is, so the pickle presumably holds the
            # embeddings of exactly this subset in row order - confirm
            encoded_sequences = pickle_load(embed_path)

        # encode the sequences without the mut_name
        else:
            encoded_sequences = list(
                encoder_class(
                    encoder_name=encoder_name, embed_layer=embed_layer
                ).encode(
                    mut_seqs=self.sequence,
                    batch_size=embed_batch_size,
                    flatten_emb=flatten_emb,
                    **encoder_params,
                )
            )

        # stack the encoded batches along the sample axis
        self.x = torch.tensor(np.vstack(encoded_sequences), dtype=torch.float32)

        # get and format the fitness or secondary structure values
        # can be numbers or string
        # will need to convert data type
        # make 1D tensor 2D
        self.y = np.expand_dims(self._get_column_value("target"), 1)

        # add mut_name and mut_numb for relevant proeng datasets
        if self._add_mut_info:
            self.mut_name = self._get_column_value("mut_name")
            self.mut_numb = self._get_column_value("mut_numb")
        else:
            self.mut_name = [""] * self._subdf_len
            self.mut_numb = [np.nan] * self._subdf_len

    def __len__(self):
        """Return the length of the selected subset of the dataframe"""
        return self._subdf_len

    def __getitem__(self, idx: int):

        """
        Return the item in the order of
        encoded sequence (x), target (y), sequence, mut_name, mut_numb

        mut_name is "" and mut_numb is nan when the dataset has no mutation info
        """

        return (
            self.x[idx],
            self.y[idx],
            self.sequence[idx],
            self.mut_name[idx],
            self.mut_numb[idx],
        )

    def _get_column_value(self, column_name: str) -> np.ndarray:
        """
        Check and return the column values of the selected dataframe subset

        Args:
        - column_name: str, the column to extract, "sequence" is cast to str
            and sliced with the seq start and end index

        Returns:
        - np.ndarray, the values of the column for the selected subset

        Raises:
        - KeyError if the column is not in the loaded dataframe
        """
        # fail loudly instead of handing None to the callers
        if column_name not in self._df.columns:
            raise KeyError(f"{column_name} is not a column in the loaded dataset")

        subset_column = self._df_dict[self._subset][column_name]

        if column_name == "sequence":
            return (
                subset_column.astype(str)
                .str[self._seq_start_idx : self._seq_end_idx]
                .values
            )

        return subset_column.values

    @property
    def df_full(self) -> pd.DataFrame:
        """Return the full loaded dataset"""
        return self._df

    @property
    def df_train(self) -> pd.DataFrame:
        """Return the dataset for training only"""
        return self._df_train

    @property
    def df_val(self) -> pd.DataFrame:
        """Return the dataset for validation only"""
        return self._df_val

    @property
    def df_test(self) -> pd.DataFrame:
        """Return the dataset for testing only"""
        return self._df_test


def split_protrain_loader(
    dataset_path: str,
    encoder_name: str,
    embed_layer: int,
    embed_batch_size: int = 128,
    flatten_emb: bool | str = False,
    embed_path: str | None = None,
    seq_start_idx: bool | int = False,
    seq_end_idx: bool | int = False,
    subset_list: list[str] = ["train", "val", "test"],
    loader_batch_size: int = 64,
    worker_seed: int = RAND_SEED,
    **encoder_params,
):

    """
    A function encode and load the data from a path

    Args:
    - dataset_path: str, full path to the dataset, in pkl or panda readable format
        columns include: sequence, target, set, validation, mut_name (optional), mut_numb (optional)
    - encoder_name: str, the name of the encoder
    - embed_layer: int, the layer number of the embedding
    - embed_batch_size: int, set to 0 to encode all in a single batch
    - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
    - embed_path: str = None, path to presaved embedding
    - seq_start_idx: bool | int = False, the index for the start of the sequence
    - seq_end_idx: bool | int = False, the index for the end of the sequence
    - subset_list: list of str, train, val, test
    - loader_batch_size: int, the batch size for train, val, and test dataloader
    - worker_seed: int, the seed for dataloader shuffling and worker processes
    - encoder_params: kwarg, additional parameters for encoding

    Returns:
    - a generator yielding one DataLoader per entry of subset_list, in order,
        each dataset is only built (and encoded) when its loader is requested

    Raises:
    - AssertionError if subset_list has entries other than train, val, test
    - ValueError if encoder_name is not a key of TRANSFORMER_INFO
    """
    import random

    assert set(subset_list) <= set(
        ["train", "val", "test"]
    ), "subset_list can only contrain terms with in be 'train', 'val', or 'test'"

    # only ESM encoders are in scope of this module, reject anything else here
    # instead of failing with an unbound encoder_class while iterating
    if encoder_name in TRANSFORMER_INFO.keys():
        encoder_class = ESMEncoder
    else:
        raise ValueError(
            f"Unsupported encoder_name {encoder_name!r}, "
            f"expected one of {list(TRANSFORMER_INFO.keys())}"
        )

    def _seed_worker(worker_id: int) -> None:
        """Seed the python, numpy and torch RNGs of one dataloader worker"""
        # numpy only accepts seeds in [0, 2**32)
        seed = (worker_seed + worker_id) % 2 ** 32
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def _build_loader(subset: str) -> DataLoader:
        """Encode one subset and wrap it in a DataLoader"""
        return DataLoader(
            dataset=ProtranDataset(
                dataset_path=dataset_path,
                subset=subset,
                encoder_class=encoder_class,
                encoder_name=encoder_name,
                embed_layer=embed_layer,
                embed_batch_size=embed_batch_size,
                flatten_emb=flatten_emb,
                embed_path=embed_path,
                seq_start_idx=seq_start_idx,
                seq_end_idx=seq_end_idx,
                **encoder_params,
            ),
            batch_size=loader_batch_size,
            # specify no shuffling for validation and test
            shuffle=(subset == "train"),
            # worker_init_fn has to be a callable, not the seed itself
            worker_init_fn=_seed_worker,
            # the seeded generator makes the shuffling order reproducible
            generator=torch.Generator().manual_seed(worker_seed),
        )

    return (_build_loader(subset) for subset in subset_list)

In [14]:
"""Add encoding classes with class methods"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence

import numpy as np
from tqdm import tqdm

import torch
from sequence_models.pretrained import load_model_and_alphabet

from scr.params.emb import TRANSFORMER_INFO, TRANSFORMER_MAX_SEQ_LEN, CARP_INFO
from scr.params.sys import DEVICE


class AbstractEncoder(ABC):
    """
    An abstract encoder class to fill in for different kinds of encoders

    All encoders will have an "encode" function

    Subclasses implement "_encode_batch" and set "self._embed_dim",
    which "flatten_encode" and the "embed_dim" property read.
    """

    def __init__(self, encoder_name: str, embed_layer: int):

        """
        Args:
        - encoder_name: str, the name of the encoder
        - embed_layer: int, the layer number of the embedding
        """

        self._encoder_name = encoder_name
        self._embed_layer = embed_layer

    def encode(
        self,
        mut_seqs: Sequence[str] | str,
        batch_size: int = 0,
        flatten_emb: bool | str = False,
        mut_names: Sequence[str] | str | None = None,
    ) -> Iterable[np.ndarray]:
        """
        A function takes a list of sequences to yield a batch of encoded elements

        Args:
        - mut_seqs: list of str or str, mutant sequences of the same length
        - batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - mut_names: list of str or str or None, mutant names

        Returns:
        - generator: yields, per batch, whatever the subclass "_encode_batch" returns
        """

        if isinstance(mut_seqs, str):
            mut_seqs = [mut_seqs]

        # a single name must be wrapped as well, otherwise the batching
        # below would slice the string character by character
        if isinstance(mut_names, str):
            mut_names = [mut_names]

        # If the batch size is 0, then encode all at once in a single batch
        if batch_size == 0:
            yield self._encode_batch(
                mut_seqs=mut_seqs, flatten_emb=flatten_emb, mut_names=mut_names
            )

        # Otherwise, yield chunks of encoded sequence
        else:

            for i in tqdm(range(0, len(mut_seqs), batch_size)):

                # figure out what mut_names to feed in
                if mut_names is None:
                    mut_name_batch = mut_names
                else:
                    mut_name_batch = mut_names[i : i + batch_size]

                yield self._encode_batch(
                    mut_seqs=mut_seqs[i : i + batch_size],
                    flatten_emb=flatten_emb,
                    mut_names=mut_name_batch,
                )

    def flatten_encode(
        self, encoded_mut_seqs: np.ndarray, flatten_emb: bool | str
    ) -> np.ndarray:
        """
        Flatten the embedding or just return the encoded mutants.

        Args:
        - encoded_mut_seqs: np.ndarray, shape [batch_size, seq_len, embed_dim]
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
            - True -> shape [batch_size, seq_len * embed_dim]
            - "max" or "mean" -> shape [batch_size, embed_dim]
            - False or everything else -> [batch_size, seq_len, embed_dim]

        Returns:
        - np.ndarray, shape depends on flatten_emb parameter

        Raises:
        - AssertionError if the last axis does not match the embedding dim
        """
        assert encoded_mut_seqs.shape[2] == self._embed_dim, "Wrong embed dim"

        if flatten_emb is True:
            # shape [batch_size, seq_len * embed_dim]
            return encoded_mut_seqs.reshape(encoded_mut_seqs.shape[0], -1)

        if isinstance(flatten_emb, str):
            if flatten_emb == "mean":
                # [batch_size, embed_dim]
                return encoded_mut_seqs.mean(axis=1)
            if flatten_emb == "max":
                # [batch_size, embed_dim]
                return encoded_mut_seqs.max(axis=1)

        # False and any unrecognised option, including unknown strings,
        # fall through here so that None is never returned
        print("No embedding flattening")
        # [batch_size, seq_len, embed_dim]
        return encoded_mut_seqs

    @abstractmethod
    def _encode_batch(
        self,
        mut_seqs: Sequence[str] | str,
        flatten_emb: bool | str,
        mut_names: Sequence[str] | str | None = None,
    ) -> np.ndarray:
        """
        Encode a single batch of mut_seqs
        """
        pass

    @property
    def embed_dim(self) -> int:
        """The dim of the embedding"""
        return self._embed_dim

    @property
    def embed_layer(self) -> int:
        """The layer number of the embedding"""
        return self._embed_layer

    @property
    def encoder_name(self) -> str:
        """The name of the encoding method"""
        return self._encoder_name


class ESMEncoder(AbstractEncoder):
    """
    An encoder backed by an ESM transformer loaded through torch.hub
    """

    def __init__(
        self,
        encoder_name: str,
        embed_layer: int,
        iftrimCLS: bool = True,
        iftrimEOS: bool = True,
    ):
        """
        Args
        - encoder_name: str, the name of the encoder, one of the keys of TRANSFORMER_INFO
        - embed_layer: int, the layer number of the embedding
        - iftrimCLS: bool, whether to trim the first classification token
        - iftrimEOS: bool, whether to trim the end of sequence token, if exists
        """

        super().__init__(encoder_name, embed_layer)

        self._iftrimCLS = iftrimCLS
        self._iftrimEOS = iftrimEOS

        # fetch the model and its alphabet from torch.hub
        print(f"Loading {self._encoder_name} using {self._embed_layer} layer embedding")
        hub_model, hub_alphabet = torch.hub.load(
            "facebookresearch/esm:main", model=self._encoder_name
        )
        self.model = hub_model
        self.alphabet = hub_alphabet
        self.batch_converter = self.alphabet.get_batch_converter()

        # inference only, on the configured device
        self.model.eval()
        self.model.to(DEVICE)

        # the first two entries are the embedding dim and the number of layers
        model_info = TRANSFORMER_INFO[self._encoder_name]
        self._embed_dim, self._max_emb_layer, _ = model_info

        assert (
            self._embed_layer <= self._max_emb_layer
        ), f"{self._embed_layer} exceeds {self._max_emb_layer}"

        # the third name field from the end carries the layer count, ie. "t33"
        layer_field = self._encoder_name.split("_")[-3]
        expected_num_layers = int(layer_field[1:])
        assert (
            expected_num_layers == self._max_emb_layer
        ), "Wrong ESM model name or layer"

    def _encode_batch(
        self,
        mut_seqs: Sequence[str] | str,
        flatten_emb: bool | str,
        mut_names: Sequence[str] | str | None = None,
    ) -> np.ndarray:
        """
        Encodes a batch of mutant sequences.

        Args:
        - mut_seqs: list of str or str, mutant sequences of the same length
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - mut_names: list of str or str or None, mutant names

        Returns:
        - dict keyed by layer number (0 to embed_layer), the value is the
            (optionally flattened) np.ndarray embedding of that layer, or a
            tuple(np.ndarray, list[str]) with the batch labels when mut_names is given
        """

        if isinstance(mut_names, str):
            mut_names = [mut_names]

        with_labels = mut_names is not None

        # the batch converter wants (label, sequence) pairs
        if with_labels:
            assert len(mut_names) == len(
                mut_seqs
            ), "mutant_name and mut_seqs different length"
            labeled_seqs = list(zip(mut_names, mut_seqs))
        else:
            labeled_seqs = [("", seq) for seq in mut_seqs]

        # tokenize the raw mutant sequences
        batch_labels, _, batch_tokens = self.batch_converter(labeled_seqs)
        batch_tokens = batch_tokens.to(DEVICE)

        # forward pass without gradients
        with torch.no_grad():
            # tokens beyond the maximum supported length are cut off
            if batch_tokens.shape[1] > TRANSFORMER_MAX_SEQ_LEN:
                print(f"Sequence exceeds {TRANSFORMER_MAX_SEQ_LEN}, chopping the end")
                batch_tokens = batch_tokens[:, :TRANSFORMER_MAX_SEQ_LEN]

            wanted_layers = list(range(self._embed_layer + 1))
            # each value has shape [batch_size, seq_len + pad, embed_dim]
            layer_reprs = self.model(batch_tokens, repr_layers=wanted_layers)[
                "representations"
            ]

        # https://github.com/facebookresearch/esm/blob/main/esm/data.py
        # from_architecture
        model_family = self._encoder_name.split("_")[0]
        # both "ESM-1" and "ESM-1b" have prepend_bos = True
        drop_cls = self._iftrimCLS and model_family in ["esm1", "esm1b"]
        # only "ESM-1b" has append_eos = True
        drop_eos = self._iftrimEOS and model_family == "esm1b"

        # replace every torch representation by its processed numpy version
        for layer in list(layer_reprs):
            layer_emb = layer_reprs[layer].cpu().numpy()

            # trim off initial classification token [CLS]
            if drop_cls:
                layer_emb = layer_emb[:, 1:, :]

            # trim off end-of-sequence token [EOS]
            if drop_eos:
                layer_emb = layer_emb[:, :-1, :]

            flattened = self.flatten_encode(layer_emb, flatten_emb)
            layer_reprs[layer] = (flattened, batch_labels) if with_labels else flattened

        return layer_reprs


class CARPEncoder(AbstractEncoder):
    """
    Build a CARP encoder
    """

    def __init__(
        self,
        encoder_name: str,
        embed_layer: int,
    ):
        """
        Args
        - encoder_name: str, the name of the encoder, one of the keys of CARP_INFO
        - embed_layer: int, the layer number of the embedding
        """

        super().__init__(encoder_name, embed_layer)

        # load the pretrained model and its collater
        print(f"Loading {self._encoder_name} using {self._embed_layer} layer embedding")

        self.model, self.collater = load_model_and_alphabet(self._encoder_name)

        # set model to eval mode
        self.model.eval()
        self.model.to(DEVICE)

        self._embed_dim, self._max_emb_layer = CARP_INFO[self._encoder_name]

        assert (
            self._embed_layer <= self._max_emb_layer
        ), f"{self._embed_layer} exceeds {self._max_emb_layer}"

    def _encode_batch(
        self,
        mut_seqs: Sequence[str] | str,
        flatten_emb: bool | str,
        mut_names: Sequence[str] | str | None = None,
    ) -> np.ndarray:
        """
        Encodes a batch of mutant sequences.

        Args:
        - mut_seqs: list of str or str, mutant sequences of the same length
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - mut_names: list of str or str or None, mutant names,
            accepted for interface compatibility but not used by this encoder

        Returns:
        - dict keyed by layer number (0 to embed_layer) with the (optionally
            flattened) np.ndarray output of that embedder layer as value
        """

        # the collater takes a list of single-sequence lists, its first
        # output is fed to the model, on the same device as the model
        tokens = self.collater([[seq] for seq in mut_seqs])[0].to(DEVICE)

        activation = {}

        def get_activation(name):
            def hook(model, input, output):
                activation[name] = output.detach()

            return hook

        # capture the output of every embedder layer up to embed_layer
        hook_handles = [
            self.model.model.embedder.layers[layer_numb].register_forward_hook(
                get_activation(layer_numb)
            )
            for layer_numb in range(self._embed_layer + 1)
        ]

        try:
            # inference only, only the hooked activations are needed
            with torch.no_grad():
                self.model(tokens)
        finally:
            # remove the hooks so that they do not pile up over the batches
            for handle in hook_handles:
                handle.remove()

        return {
            layer_numb: self.flatten_encode(layer_output.cpu().numpy(), flatten_emb)
            for layer_numb, layer_output in activation.items()
        }

In [7]:
from scr.utils import pickle_load

In [8]:
# Load the pickled GB1 two_vs_rest dataframe (path is relative to the repo root set by %cd)
df = pickle_load("data/proeng/gb1/two_vs_rest.pkl")

In [9]:
# Split the same way ProtranDataset does:
# train = "train" rows not flagged for validation, val = "train" rows flagged for validation
df_train = df.loc[(df["set"] == "train") & (df["validation"] != True)]
df_val = df.loc[(df["set"] == "train") & (df["validation"] == True)]
df_test = df.loc[(df["set"] == "test")]

# sanity check of the subset sizes against the full dataframe
len(df_train), len(df_val), len(df_test), len(df)

(381, 43, 8309, 8733)

In [10]:

# First two validation sequences cut to their first 56 characters,
# each wrapped in its own list (the input format CARPEncoder builds for its collater)
seqs = [[seq] for seq in df_val.sequence.astype(str).str[0 : 56].values[0:2]]
seqs

[['MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVNGEWTYDDATKTFTVTE'],
 ['MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGEYGEWTYDDATKTFTVTE']]

In [16]:
# CARP, no flattening: batch_size defaults to 0, so the generator yields one batch
no_flat_encoder = CARPEncoder(
    encoder_name="carp_600k",
    embed_layer=6,
).encode(mut_seqs=list(df_val.sequence.astype(str).str[0 : 56].values[0:2]))
one_emb = next(no_flat_encoder)
# layers captured and the shape of the layer 0 embedding
one_emb.keys(), one_emb[0].shape

Loading carp_600k using 6 layer embedding
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening


(dict_keys([0, 1, 2, 3, 4, 5, 6]), (2, 56, 128))

In [18]:
# CARP, mean flattening: same two sequences, averaged over the sequence axis
mean_flat_encoder = CARPEncoder(
    encoder_name="carp_600k",
    embed_layer=6,
).encode(mut_seqs=list(df_val.sequence.astype(str).str[0 : 56].values[0:2]),flatten_emb="mean")
one_mean_emb = next(mean_flat_encoder)
# layers captured and the shape of the layer 0 embedding
one_mean_emb.keys(), one_mean_emb[0].shape

Loading carp_600k using 6 layer embedding


(dict_keys([0, 1, 2, 3, 4, 5, 6]), (2, 128))

In [21]:
# ESM-1b, no flattening, with mutant names: every layer value is an (embedding, labels) tuple
no_flat_encoder = ESMEncoder(
    encoder_name="esm1b_t33_650M_UR50S",
    embed_layer=33,
).encode(mut_seqs=list(df_val.sequence), mut_names=list(df_val.mut_name))
one_emb = next(no_flat_encoder)
# [0][0] picks layer 0, then the embedding part of the tuple
one_emb.keys(), one_emb[0][0].shape

Loading esm1b_t33_650M_UR50S using 33 layer embedding


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening


(dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]),
 (43, 265, 1280))

In [22]:
# ESM-1b, mean flattening, with mutant names: every layer value is an (embedding, labels) tuple
mean_flat_encoder = ESMEncoder(
    encoder_name="esm1b_t33_650M_UR50S",
    embed_layer=33,
).encode(mut_seqs=list(df_val.sequence), mut_names=list(df_val.mut_name),flatten_emb="mean")
one_mean_emb = next(mean_flat_encoder)
# [0][0] picks layer 0, then the embedding part of the tuple
one_mean_emb.keys(), one_mean_emb[0][0].shape

Loading esm1b_t33_650M_UR50S using 33 layer embedding


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


(dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]),
 (43, 1280))

In [23]:
# ESM-1b, no flattening, without mutant names: every layer value is the bare embedding array
# (this rebinds no_flat_encoder and one_emb from the earlier cells)
no_flat_encoder = ESMEncoder(
    encoder_name="esm1b_t33_650M_UR50S",
    embed_layer=33,
).encode(mut_seqs=list(df_val.sequence))
one_emb = next(no_flat_encoder)
# layers captured and the shape of the layer 0 embedding
one_emb.keys(), one_emb[0].shape

Loading esm1b_t33_650M_UR50S using 33 layer embedding


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening
No embedding flattening


(dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]),
 (43, 265, 1280))

In [24]:
# Show the whole per-layer embedding dict (large output)
one_emb

{0: array([[[ 0.27928993, -0.00919111,  0.05307703, ...,  0.0052783 ,
          -0.08536895,  0.2144556 ],
         [ 0.86820096,  0.01390768, -0.03625386, ..., -0.07747608,
          -0.06879894,  0.10098855],
         [ 0.9904541 ,  0.02066673,  0.04346693, ..., -0.04951385,
           0.04328922, -0.09477672],
         ...,
         [ 0.01445957,  0.15244195,  0.13133049, ...,  0.00364488,
           0.50209504,  0.2052859 ],
         [ 0.1433187 ,  0.08901688,  0.06859887, ..., -0.04608465,
           0.14192118,  0.12139929],
         [-0.07774346,  0.03318178,  0.14130217, ..., -0.08603121,
           0.30741137,  0.01056057]],
 
        [[ 0.27928993, -0.00919111,  0.05307703, ...,  0.0052783 ,
          -0.08536895,  0.2144556 ],
         [ 0.86820096,  0.01390768, -0.03625386, ..., -0.07747608,
          -0.06879894,  0.10098855],
         [ 0.9904541 ,  0.02066673,  0.04346693, ..., -0.04951385,
           0.04328922, -0.09477672],
         ...,
         [ 0.01445957,  0.1524

In [25]:
# Show the whole per-layer dict of (mean embedding, mutant names) tuples (large output)
one_mean_emb

{0: (array([[-0.17739098, -0.00067187,  0.01667726, ...,  0.00655907,
           0.03614971,  0.09654725],
         [-0.1747657 , -0.00065019,  0.01660449, ...,  0.00632377,
           0.03780067,  0.09638622],
         [-0.1749173 , -0.00061091,  0.01679106, ...,  0.00663596,
           0.03946506,  0.09655192],
         ...,
         [-0.177018  , -0.00064966,  0.0166867 , ...,  0.00618463,
           0.03673364,  0.09675171],
         [-0.1757744 ,  0.00030468,  0.01655487, ...,  0.00629132,
           0.03641494,  0.09631334],
         [-0.17786627, -0.00057332,  0.01687178, ...,  0.00618563,
           0.03663646,  0.09616786]], dtype=float32),
  ['D40N',
   'V39E:D40Y',
   'V39F:V54G',
   'V39K:D40I',
   'V39L:D40A',
   'V39L:D40I',
   'V39L:D40R',
   'V39L:D40S',
   'V39L:D40V',
   'V39M:V54L',
   'V39M:D40Q',
   'V39M:D40T',
   'V39Q:D40H',
   'V39Q:D40S',
   'V39R:D40R',
   'V39R:D40Y',
   'V39T:G41K',
   'V39T:D40K',
   'V39T:D40T',
   'D40A:V54N',
   'D40A:V54T',
   'D40A:V5