In [1]:
from transformers import EvalPrediction, TrainingArguments, Trainer, EarlyStoppingCallback, EsmForMaskedLM, DataCollatorForLanguageModeling, EsmTokenizer
import torch
import random
from torch.utils.data import Dataset as TorchDataset
from datasets import load_dataset
import numpy as np
from collections import Counter

In [2]:
dataset = load_dataset('nikraf/uniref128-256AA')
dataset = dataset['train'].select(range(15))
print(dataset)
class Dataset(TorchDataset):
    
    def __init__(self, dataset):
        self.seqs = dataset['seqs']
        self.lengths = [len(seq) for seq in self.seqs]

    def __len__(self):
        return len(self.seqs)
    
    def __avg__(self):
        return sum(self.lengths) / len(self.lengths)

    def __getitem__(self, idx):
        seq = self.seqs[idx]
        return {'seqs': seq}

Dataset({
    features: ['seqs'],
    num_rows: 15
})


In [3]:
def initialize_tokenizer():
    tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t30_150M_UR50D')
    new_tokens =  {'additional_special_tokens': []}

    AA_tokens = ['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
    shuffle_tokens = [f'<{AA}s>' for AA in AA_tokens]
    new_tokens['additional_special_tokens'].extend(shuffle_tokens)

    #Update tokenizer
    if new_tokens['additional_special_tokens']:
        total_tokens = len(tokenizer) + len(list(new_tokens.values())[0])
        tokenizer.add_special_tokens(new_tokens)

    return tokenizer, total_tokens, new_tokens

tokenizer, total_tokens, new_tokens = initialize_tokenizer()
model = EsmForMaskedLM.from_pretrained('facebook/esm2_t30_150M_UR50D', vocab_size=total_tokens, ignore_mismatched_sizes=True)


with torch.no_grad():
    model.resize_token_embeddings(total_tokens)
    try:
        cls_token_embedding = model.embeddings.word_embeddings.weight[tokenizer.cls_token_id, :].clone()
        for token in new_tokens['additional_special_tokens']:
            model.embeddings.word_embeddings.weight[tokenizer._convert_token_to_id(token), :] = cls_token_embedding.clone()
    except AttributeError:
        cls_token_embedding = model.esm.embeddings.word_embeddings.weight[tokenizer.cls_token_id, :].clone()
        for token in new_tokens['additional_special_tokens']:
            model.esm.embeddings.word_embeddings.weight[tokenizer._convert_token_to_id(token), :] = cls_token_embedding.clone()

Some weights of EsmForMaskedLM were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized because the shapes did not match:
- esm.embeddings.word_embeddings.weight: found shape torch.Size([33, 640]) in the checkpoint and torch.Size([53, 640]) in the model instantiated
- lm_head.bias: found shape torch.Size([33]) in the checkpoint and torch.Size([53]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
class DataCollatorForShuffling(DataCollatorForLanguageModeling):

    def __init__(self, tokenizer: EsmTokenizer, shuffle_type: str = 'regular', **kwargs):
        super().__init__(tokenizer=tokenizer, **kwargs)
        self.return_tensors = 'pt'
        self.tokenizer = tokenizer
        self.shuffle_type = shuffle_type

    def shuffle_seq(self, seq):
        if self.shuffle_type == 'sectional':
            return self.sectional_shuffle(seq)
        elif self.shuffle_type == 'regular':
            return self.regular_shuffle(seq)
        else:
            raise ValueError("Invalid shuffle_type. Choose 'sectional' or 'regular'.")
    '''

PREVIOUS APPROACH -- DID NOT WORK BECAUSE WE WERE PICKING FROM set(seq), see new approach below
    def regular_shuffle(self, seq):
        seq = list(seq)
        #per = np.clip(np.random.normal(0.3, 0.12), 0.0, 1.0)
        per = 0.3 
        seq_len = len(seq)
        num_to_shuffle = int(seq_len * per)
        
        shuffled_indices = random.sample(range(seq_len), num_to_shuffle)
        new_tokens = []

        for index in shuffled_indices:
            original_char = seq[index]
            possible_tokens = [f"<{token}s>" for token in set(seq) - {original_char}]
            new_token = random.choice(possible_tokens)
            new_tokens.append(new_token)


        random.shuffle(new_tokens)

        for new_token, seq_index in zip(new_tokens, shuffled_indices):
            seq[seq_index] = new_token
        
        return ''.join(seq)
    '''

    def regular_shuffle(self, seq):
        original_seq = list(seq)  # Save the original sequence

        seq = list(seq)
        per = np.clip(np.random.normal(0.3, 0.12), 0.0, 1.0)
        seq_len = len(seq)
        num_to_shuffle = int(seq_len * per)
        
        shuffled_indices = random.sample(range(seq_len), num_to_shuffle)
        shuffled_indices_set = set(shuffled_indices)
        
        index_mapping = {}
        new_tokens = []

        for original_index in shuffled_indices:
            # Pick a unique index, exclude original
            new_index = random.choice(list(shuffled_indices_set - {original_index}))
            shuffled_indices_set.remove(new_index)  # Remove picked index to prevent reuse
            
            # Mapping of original index to new index
            index_mapping[original_index] = new_index

            # Add <s> 
            new_token = f"<{seq[new_index]}s>"
            new_tokens.append((new_token, original_index))

        # Shuffle
            
        random.shuffle(new_tokens)

        # Replace with new tokens
        for new_token, original_index in new_tokens:
            seq[original_index] = new_token
        
        sorted_index_mapping = dict(sorted(index_mapping.items(), key=lambda item: item[1]))
        print(sorted_index_mapping)
        amino_acid_pairs = [(original_seq[original], seq[new]) 
                            for original, new in sorted_index_mapping.items()]

        print("Amino Acid Pairs:")
        for pair in amino_acid_pairs:
            print(pair)

        return ''.join(seq)

    def sectional_shuffle(self, seq):
        original_seq = list(seq)
        seq_len = len(original_seq)
        per = np.clip(np.random.normal(0.3, 0.12), 0.0, 1.0)
        section_length = int(seq_len * per)

        start_index = random.randrange(seq_len)
        end_index = (start_index + section_length) % seq_len

        if end_index > start_index:
            shuffle_section = original_seq[start_index:end_index]
        else:  
            shuffle_section = original_seq[start_index:] + original_seq[:end_index]

        new_section_tokens = []
        all_tokens = set(shuffle_section)

        for index in shuffle_section:       
            possible_tokens = [f"<{token}s>" for token in all_tokens - {index}]
            new_section_token = random.choice(possible_tokens)
            new_section_tokens.append(new_section_token)

        random.shuffle(new_section_tokens)

        
        if end_index > start_index:
            original_seq[start_index:end_index] = new_section_tokens
        else:
            original_seq[start_index:] = new_section_tokens[:len(original_seq) - start_index]
            original_seq[:end_index] = new_section_tokens[len(original_seq) - start_index:]

        return ''.join(original_seq)

    def torch_call(self, seqs):
        shuffled_seqs = [self.shuffle_seq(seq) for seq in seqs]

        labels = self.tokenizer(seqs, return_tensors=self.return_tensors, padding='longest', truncation=False).input_ids

        tokens = self.tokenizer(shuffled_seqs, return_tensors=self.return_tensors, padding='longest', truncation=False, return_token_type_ids=False)

        labels[labels == self.tokenizer.pad_token_id] = -100

        tokens['labels'] = labels

        return tokens

In [5]:
data_collator = DataCollatorForShuffling(return_tensors='pt', tokenizer=tokenizer)

torch_dataset = Dataset(dataset)
examples = [torch_dataset[i] for i in range(3)]
batch = data_collator.torch_call([example['seqs'] for example in examples])

decoded_shuffled_seqs = [tokenizer.decode(ids, skip_special_tokens=False) for ids in batch['input_ids']]
decoded_original_seqs = [tokenizer.decode(ids, skip_special_tokens=False) for ids in batch['labels']]

print("\nDecoded Shuffled Sequences (from label IDs):")
for seq in decoded_shuffled_seqs:
    print(seq)

print("\nDecoded Original Sequences (from label IDs):")
for seq in decoded_original_seqs:
    print(seq)


{91: 2, 131: 37, 63: 39, 75: 40, 2: 61, 123: 63, 61: 67, 165: 68, 141: 70, 145: 73, 70: 75, 156: 81, 97: 86, 102: 91, 73: 97, 40: 101, 106: 102, 86: 106, 172: 117, 124: 123, 101: 124, 37: 131, 117: 141, 39: 145, 81: 146, 68: 156, 146: 165, 67: 172}
Amino Acid Pairs:
('G', '<Gs>')
('I', '<Is>')
('G', '<Es>')
('R', '<Gs>')
('E', '<Ds>')
('R', '<Rs>')
('G', '<Es>')
('E', '<Es>')
('E', '<Rs>')
('E', '<Ss>')
('D', '<Es>')
('E', '<Es>')
('S', '<Gs>')
('R', '<Es>')
('G', '<Ks>')
('E', '<Gs>')
('G', '<Gs>')
('K', '<Rs>')
('E', '<Es>')
('G', '<Gs>')
('G', '<Rs>')
('R', '<Rs>')
('D', '<Ds>')
('R', '<Gs>')
('Y', '<Es>')
('R', '<Ys>')
('E', '<Rs>')
('D', '<Ds>')
{78: 0, 103: 1, 70: 2, 42: 3, 1: 4, 100: 6, 141: 7, 40: 8, 87: 11, 128: 12, 115: 15, 36: 16, 55: 17, 94: 18, 32: 19, 16: 20, 91: 21, 81: 22, 113: 24, 17: 26, 126: 27, 101: 29, 97: 30, 24: 32, 86: 35, 60: 36, 111: 37, 52: 38, 35: 40, 106: 41, 27: 42, 7: 43, 82: 44, 112: 45, 0: 49, 125: 50, 119: 51, 142: 52, 45: 53, 20: 55, 123: 58, 73: 59, 