In [2]:
import torch
import esm

In [3]:
class ESMTokenizer:
    """Thin encode/decode wrapper around a pretrained ESM model's alphabet.

    Parameters
    ----------
    esm_model_name : str
        Name of a loader function in ``esm.pretrained``. It is called with no
        arguments and is expected to return a ``(model, alphabet)`` pair.
    sequences : bool
        Accepted for interface compatibility; currently unused.

    Attributes
    ----------
    model
        The model returned by the loader (not used by the methods below).
    alphabet
        The alphabet returned by the loader; provides ``get_tok`` and the
        batch converter.
    batch_converter
        Result of ``alphabet.get_batch_converter()``.
    padding_idx
        ``alphabet.padding_idx``.
    """

    def __init__(self, esm_model_name='esm2_t12_35M_UR50D', sequences=True):
        loader = getattr(esm.pretrained, esm_model_name)
        self.model, self.alphabet = loader()
        self.batch_converter = self.alphabet.get_batch_converter()
        self.padding_idx = self.alphabet.padding_idx

    @property
    def pad_id(self):
        """Index of the padding token (alias for ``padding_idx``)."""
        return self.padding_idx

    def tokenize(self, seq_tuple):
        """Encode a single sequence and return its token ids as a numpy array.

        ``seq_tuple`` may be a plain string, or a tuple whose first element is
        the sequence string (remaining tuple elements are ignored).
        """
        if isinstance(seq_tuple, tuple):
            seq_str = seq_tuple[0]
        else:
            seq_str = seq_tuple
        # The converter takes (label, sequence) pairs and yields three values;
        # only the token tensor is needed, and only its single row.
        _labels, _strs, batch_tokens = self.batch_converter([("seq", seq_str)])
        return batch_tokens[0].numpy()

    def untokenize(self, tokenized_seq):
        """Decode token ids back into a string.

        Accepts a torch tensor (moved to CPU first) or any iterable of ids.
        Padding ids are skipped, and the '<cls>', '<eos>' and '<pad>' token
        strings are dropped; every other token string is concatenated as-is.
        """
        if torch.is_tensor(tokenized_seq):
            tokenized_seq = tokenized_seq.cpu().numpy()
        dropped = ('<cls>', '<eos>', '<pad>')
        pieces = []
        for tok_id in tokenized_seq:
            if tok_id == self.padding_idx:
                continue
            tok = self.alphabet.get_tok(int(tok_id))
            if tok in dropped:
                continue
            pieces.append(tok)
        return "".join(pieces)

In [4]:
# Build the tokenizer; this calls the esm.pretrained loader for the named model
# (which presumably fetches weights on first use — can be slow).
tokenizer = ESMTokenizer(esm_model_name='esm2_t12_35M_UR50D')

In [5]:
# Example protein sequence (one-letter amino-acid codes) to encode.
protein_seq = "MKGYFGPYGGQYVPEILMGALEELEAAYEGIMKDESFWKEFNDLLRDYAGRPTPLYFARRLSEKYGARVYLKREDLLHTGAHKINNAIGQVLLAKLMGKTRIIAETGAGQHGVATATAAALFGMECVIYMGEEDTIRQKLNVERMKLLGAKVVPVKSGSRTLKDAIDEALRDWITNLQTTYYVFGSVVGPHPYPIIVRNFQKVIGEETKKQIPEKEGRLPDYIVACVSGGSNAAGIFYPFIDSGVKLIGVEAGGEGLETGKHAASLLKGKIGYLHGSKTFVLQDDWGQVQVSHSVSAGLDYSGVGPEHAYWRETGKVLYDAVTDEEALDAFIELSRLEGIIPALESSHALAYLKKINIKGKVVVVNLSGRGDKDLESVLNHPYVRERIR"
tokenized = tokenizer.tokenize(protein_seq)
tokenized

array([ 0, 20, 15,  6, 19, 18,  6, 14, 19,  6,  6, 16, 19,  7, 14,  9, 12,
        4, 20,  6,  5,  4,  9,  9,  4,  9,  5,  5, 19,  9,  6, 12, 20, 15,
       13,  9,  8, 18, 22, 15,  9, 18, 17, 13,  4,  4, 10, 13, 19,  5,  6,
       10, 14, 11, 14,  4, 19, 18,  5, 10, 10,  4,  8,  9, 15, 19,  6,  5,
       10,  7, 19,  4, 15, 10,  9, 13,  4,  4, 21, 11,  6,  5, 21, 15, 12,
       17, 17,  5, 12,  6, 16,  7,  4,  4,  5, 15,  4, 20,  6, 15, 11, 10,
       12, 12,  5,  9, 11,  6,  5,  6, 16, 21,  6,  7,  5, 11,  5, 11,  5,
        5,  5,  4, 18,  6, 20,  9, 23,  7, 12, 19, 20,  6,  9,  9, 13, 11,
       12, 10, 16, 15,  4, 17,  7,  9, 10, 20, 15,  4,  4,  6,  5, 15,  7,
        7, 14,  7, 15,  8,  6,  8, 10, 11,  4, 15, 13,  5, 12, 13,  9,  5,
        4, 10, 13, 22, 12, 11, 17,  4, 16, 11, 11, 19, 19,  7, 18,  6,  8,
        7,  7,  6, 14, 21, 14, 19, 14, 12, 12,  7, 10, 17, 18, 16, 15,  7,
       12,  6,  9,  9, 11, 15, 15, 16, 12, 14,  9, 15,  9,  6, 10,  4, 14,
       13, 19, 12,  7,  5

In [6]:
# Decode the ids back into a residue string (the tokenizer drops <cls>/<eos>/<pad>).
seq = tokenizer.untokenize(tokenized)
# Round-trip check: re-encoding the decoded string must reproduce the same ids,
# so a silent encode/decode mismatch fails here instead of downstream.
assert tokenizer.tokenize(seq).tolist() == tokenized.tolist(), "tokenize/untokenize round-trip mismatch"
seq

'MKGYFGPYGGQYVPEILMGALEELEAAYEGIMKDESFWKEFNDLLRDYAGRPTPLYFARRLSEKYGARVYLKREDLLHTGAHKINNAIGQVLLAKLMGKTRIIAETGAGQHGVATATAAALFGMECVIYMGEEDTIRQKLNVERMKLLGAKVVPVKSGSRTLKDAIDEALRDWITNLQTTYYVFGSVVGPHPYPIIVRNFQKVIGEETKKQIPEKEGRLPDYIVACVSGGSNAAGIFYPFIDSGVKLIGVEAGGEGLETGKHAASLLKGKIGYLHGSKTFVLQDDWGQVQVSHSVSAGLDYSGVGPEHAYWRETGKVLYDAVTDEEALDAFIELSRLEGIIPALESSHALAYLKKINIKGKVVVVNLSGRGDKDLESVLNHPYVRERIR'