# evoprotgrad


In [None]:
from typing import List, Tuple, Optional

import torch
from torch import nn

import evo_prot_grad.common.tokenizers as tokenizers
import evo_prot_grad.common.utils as utils
from evo_prot_grad import DirectedEvolution
from evo_prot_grad.common.embeddings import OneHotEmbedding
from evo_prot_grad.common.tokenizers import ExpertTokenizer, OneHotTokenizer
from evo_prot_grad.common.utils import CANONICAL_ALPHABET
from evo_prot_grad.experts.base_experts import AttributeExpert, Expert

In [None]:
class SaprotOneHotTokenizer(OneHotTokenizer):
    """One-hot tokenizer for strings that carry extra (structure) characters.

    Before encoding, every element whose first character is not in the
    alphabet is dropped. After decoding, a '#' is appended to every
    character (presumably SaProt's placeholder for an unknown structure
    token -- confirm against the SaProt vocabulary).
    """

    def __init__(self, alphabet: List[str]=CANONICAL_ALPHABET):
        """
        Args:
            alphabet (List[str]): A list of amino acid characters.
        """
        super().__init__(alphabet)

    def remove_structure_tokens(self, seq):
        """Return the alphabet characters of `seq`, joined into one string.

        Each element of `seq` contributes its first character, and only
        when that character is in `self.alphabet`.
        """
        kept = []
        for token in seq:
            residue = token[0]
            if residue in self.alphabet:
                kept.append(residue)
        return ''.join(kept)

    def __call__(self, seqs: List[str]) -> torch.FloatTensor:
        """One-hot encode `seqs` after stripping non-alphabet characters."""
        stripped = list(map(self.remove_structure_tokens, seqs))
        return super().__call__(seqs=stripped)

    def decode(self, ohs: torch.Tensor) -> List[str]:
        """Decode with the parent tokenizer, then re-insert '#' markers."""
        decoded = super().decode(ohs)
        return list(map(self.add_structure_tokens, decoded))

    def add_structure_tokens(self, seq):
        """Append '#' after every character of `seq`."""
        return ''.join(aa + '#' for aa in seq)

In [None]:
class SaprotRegressionExpert(AttributeExpert):
    """Attribute expert wrapping a model that exposes `esm.embeddings`.

    The model's word-embedding layer is replaced by a `OneHotEmbedding`
    wrapper so the one-hot inputs of the most recent forward pass can be
    read back through `_get_last_one_hots`.
    """

    def __init__(self,
                 temperature: float,
                 model: nn.Module,
                 scoring_strategy: str,
                 device: str,
                 tokenizer: Optional[tokenizers.ExpertTokenizer] = None):
        """
        Args:
            temperature (float): Hyperparameter for re-scaling this expert in the Product of Experts.
            model (nn.Module): The model to use for the expert.
            scoring_strategy (str): The approach used to score mutations with this expert.
            device (str): The device to use for the expert.
            tokenizer (ExpertTokenizer): The tokenizer to use for the expert.
        """
        super().__init__(
            temperature,
            model,
            scoring_strategy,
            device,
            tokenizer=tokenizer
            )
        # Wrap the original embedding layer; the wrapper exposes `.one_hots`,
        # which `_get_last_one_hots` reads after each forward pass.
        self.model.esm.embeddings.word_embeddings = OneHotEmbedding(model.esm.embeddings.word_embeddings)

    def tokenize(self, inputs: List[str]):
        """Tokenizes a list of protein sequences.

        Args:
            inputs (List[str]): A list of protein sequences.
        Returns:
            A dict mapping each tokenizer output name to its value moved
            to `self.device`.
        """
        tokenized = self.tokenizer(inputs, return_tensors="pt", padding=True)
        return {k: v.to(self.device) for k, v in tokenized.items()}

    def get_model_output(self, inputs: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns both the onehot-encoded inputs and model's predictions.

        Args:
            inputs (List[str]): A list of protein sequence strings of len [parallel_chains].
        Returns:
            x_oh: (torch.Tensor) the one-hot inputs captured by the embedding
                wrapper; presumably [parallel_chains, seq_len, vocab_size].
            attribute_values: (torch.Tensor) the model's predictions;
                presumably [parallel_chains, 1] for a single-output
                regression head -- confirm against the model.
        """
        encoded_inputs = self.tokenize(inputs)
        output = self.model(**encoded_inputs)
        # Models that return an output object carrying `.logits` (e.g. a
        # SequenceClassifierOutput) must be unwrapped to the tensor that the
        # return annotation promises; models that already return a tensor
        # pass through unchanged.
        attribute_values = getattr(output, "logits", output)
        oh = self._get_last_one_hots()
        return oh, attribute_values

    def _get_last_one_hots(self) -> torch.Tensor:
        """ Returns the one-hot tensors *most recently passed* as input.
        """
        return self.model.esm.embeddings.word_embeddings.one_hots
    
class SaprotDirectedEvolution(DirectedEvolution):
    """DirectedEvolution that uses `SaprotOneHotTokenizer` for its chains.

    NOTE(review): the parent initializer is not called. This initializer
    presumably mirrors it with the chain tokenizer swapped -- confirm
    against the installed evo_prot_grad version. `reset()` is inherited.
    """

    def __init__(self,
                 experts: List[Expert],
                 parallel_chains: int,
                 n_steps: int,
                 max_mutations: int,
                 output: str = 'last',
                 preserved_regions: Optional[List[Tuple[int, int]]] = None,
                 wt_protein: Optional[str] = None,
                 wt_fasta: Optional[str] = None,
                 verbose: bool = False,
                 random_seed: Optional[int] = None):
        """
        Args:
            experts (List[Expert]): List of experts
            parallel_chains (int): number of parallel chains
            n_steps (int): number of steps to run directed evolution
            max_mutations (int): maximum mutation distance from WT, disable by setting to -1.
            output (str): output type, either 'best', 'last' or 'all'. Default is 'last'.
            preserved_regions (List[Tuple[int,int]]): list of tuples of (start, end) of preserved regions. Default is None.
            wt_protein (str): wt sequence as a string. Must provide one of wt_protein or wt_fasta.
            wt_fasta (str): path to fasta file containing wt sequence.
                Must provide one of wt_protein or wt_fasta.
            verbose (bool): whether to print verbose output. Default is False.
            random_seed (int): random seed for reproducibility. Default is None.
        Raises:
            ValueError: if `n_steps` < 1.
            ValueError: if neither `wt_protein` nor `wt_fasta` is provided.
            ValueError: if a fasta file is passed to `wt_protein` argument.
            ValueError: if `output` is not one of 'best', 'last' or 'all'.
            ValueError: if no experts are provided.
            ValueError: if any preserved region has `end` < `start`.
        """
        # Checks. These run before `experts[0]` is read so that an empty
        # expert list raises the documented ValueError instead of IndexError.
        if n_steps < 1:
            raise ValueError("`n_steps` must be >= 1")
        if wt_protein is None and wt_fasta is None:
            raise ValueError("Must provide one of `wt_protein` or `wt_fasta`")
        if output not in ['best', 'last', 'all']:
            raise ValueError("`output` must be one of 'best', 'last' or 'all'")
        if len(experts) < 1:
            raise ValueError("Must provide at least one expert")

        self.experts = experts
        self.parallel_chains = parallel_chains
        self.n_steps = n_steps
        self.max_mutations = max_mutations
        self.output = output
        self.preserved_regions = preserved_regions
        self.wt_protein = wt_protein
        self.wt_fasta = wt_fasta
        self.verbose = verbose
        self.random_seed = random_seed
        # The first expert's device is used for the whole run.
        self.device = self.experts[0].device

        if random_seed is not None:
            utils.set_seed(random_seed)
        if self.preserved_regions is not None:
            for start, end in self.preserved_regions:
                if end - start < 0:
                    raise ValueError("Preserved regions must be at least 1 amino acid long")

        # maintains a tokenizer with canonical alphabet
        # for the one-hot encoded chains
        self.canonical_chain_tokenizer = SaprotOneHotTokenizer()

        if self.wt_protein is not None:
            if '.fasta' in self.wt_protein:
                raise ValueError("Did you mean to use the `wt_fasta` argument instead of `wt_protein`?")
            self.wtseq = self.wt_protein
            # Add spaces between each amino acid if necessary
            if ' ' not in self.wtseq:
                self.wtseq = ' '.join(self.wtseq)
        # Otherwise read the wild-type sequence from the fasta file
        elif self.wt_fasta is not None:
            self.wtseq = utils.read_fasta(self.wt_fasta)
        if self.verbose:
            print(f">Wildtype sequence: {self.wtseq}")
        self.reset()

        ### Hyperparams
        self.max_pas_path_length = 2

In [None]:
# Build the expert from the model and tokenizer stored in `models`
# (defined outside this view), with arguments in the constructor's order.
yfp_expert = SaprotRegressionExpert(
    temperature=1,
    model=models['saprot_yfp'],
    scoring_strategy='attribute_value',
    device='cuda',
    tokenizer=models['saprot_tokenizer'],
)

In [None]:
# Starting sequence: the entry stored under 'variant_90' in `top_variants`
# (defined outside this view).
s = top_variants['variant_90']

In [None]:
# Append '#' after every character of `s`, then one extra trailing '#'.
# NOTE(review): the trailing '#' leaves the string with an unpaired marker
# at the end -- confirm this is intended.
sa = ''.join(residue + '#' for residue in s) + '#'

In [None]:
# Notebook expression cell: evaluating the tokenizer entry makes the
# notebook display it.
models['saprot_tokenizer']

'M#V#T#R#L#E#I#H#Y#T#G#E#I#P#V#R#Y#N#L#K#A#D#F#E#G#S#R#Y#T#V#E#G#K#G#T#V#N#P#A#T#G#K#L#T#L#R#L#V#C#T#T#G#D#L#P#V#Y#W#P#T#L#V#T#T#F#G#Y#G#L#Q#C#F#A#E#E#Q#K#G#N#R#I#Y#P#F#M#G#S#W#G#P#R#K#K#V#L#T#R#H#I#T#D#G#K#D#I#V#D#A#T#F#A#F#E#G#N#V#L#V#T#D#V#N#L#Y#A#D#K#G#A#I#N#G#A#I#M#R#K#L#L#K#K#Q#E#R#P#Y#L#H#H#W#R#Y#D#P#E#R#Q#G#F#M#G#A#Q#R#V#F#Q#H#L#K#N#G#K#E#A#E#V#L#E#A#I#E#I#V#K#T#D#N#F#G#H#G#R#P#S#E#Y#V#T#K#Y#T#S#Y#L#G#H#H#A#D#L#L#E#D#A#I#E#I#E#V#A#L#E#Q#F#G#A#D#S#N#G#L#I#A#R#L#G#S#D##'

In [None]:
# Sanity check: run the expert on the single sequence and display the first
# element of the returned pair (the one-hot inputs, per get_model_output).
yfp_expert.get_model_output([sa])[0]

torch.Size([1, 242, 446])

In [None]:
# Tokenize the sequence directly with the tokenizer, requesting PyTorch
# tensors ('pt') and no special tokens.
encoded = models['saprot_tokenizer'](sa, return_tensors='pt', add_special_tokens=False)

In [None]:
# Move every tokenizer output onto the GPU, keeping the same keys.
encoded = {name: tensor.to('cuda') for name, tensor in encoded.items()}

In [None]:
# Call the model directly on the encoded inputs to inspect its raw output.
models['saprot_yfp'](**encoded)

SequenceClassifierOutput(loss=None, logits=tensor([[-0.3582]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [None]:
# Set up the sampler. `sa` is passed as the wild-type sequence string
# itself, not as a path to a FASTA file.
directed_evolution = SaprotDirectedEvolution(
    experts=[yfp_expert],   # list of experts to compose
    parallel_chains=1,      # number of parallel chains to run
    n_steps=20,             # number of MCMC steps per chain
    max_mutations=10,       # maximum number of mutations per variant
    output='all',           # return best, last, or all variants
    wt_protein=sa,          # wild-type sequence string
    verbose=True,           # print debug info to command line
)

>Wildtype sequence: M # V # T # R # L # E # I # H # Y # T # G # E # I # P # V # R # Y # N # L # K # A # D # F # E # G # S # R # Y # T # V # E # G # K # G # T # V # N # P # A # T # G # K # L # T # L # R # L # V # C # T # T # G # D # L # P # V # Y # W # P # T # L # V # T # T # F # G # Y # G # L # Q # C # F # A # E # E # Q # K # G # N # R # I # Y # P # F # M # G # S # W # G # P # R # K # K # V # L # T # R # H # I # T # D # G # K # D # I # V # D # A # T # F # A # F # E # G # N # V # L # V # T # D # V # N # L # Y # A # D # K # G # A # I # N # G # A # I # M # R # K # L # L # K # K # Q # E # R # P # Y # L # H # H # W # R # Y # D # P # E # R # Q # G # F # M # G # A # Q # R # V # F # Q # H # L # K # N # G # K # E # A # E # V # L # E # A # I # E # I # V # K # T # D # N # F # G # H # G # R # P # S # E # Y # V # T # K # Y # T # S # Y # L # G # H # H # A # D # L # L # E # D # A # I # E # I # E # V # A # L # E # Q # F # G # A # D # S # N # G # L # I # A # R # L # G # S # D # #


IndexError: tuple index out of range