In [1]:
import pandas as pd
import numpy as np
from Bio import SeqIO
import os 
import torch
from transformers import BertForMaskedLM, BertTokenizer, pipeline, set_seed
import re
import requests
from tqdm.auto import tqdm


  from .autonotebook import tqdm as notebook_tqdm


### functions

In [2]:
## Function to parse fasta to dictionary ##
def fasta2dict (infile):
    tmp = SeqIO.parse(infile,'fasta')
    fasta_dict = {'ID':[],'Seq':[]}
    for record in tmp:
        fasta_dict['ID'].append(record.id)
        fasta_dict['Seq'].append(str(record.seq))
    return fasta_dict  

In [3]:
## Shannon Entropy Functions ##
def parseMSA(msa_file, alnformat="fasta"):
    """Parse in the MSA file using Biopython's AlignIO"""
    print('Parsing MSA...')
    from Bio import AlignIO
    alignment = AlignIO.read(msa_file, alnformat)

    # Do a little sanity checking:
    seq_lengths_list = []
    for record in alignment:
        seq_lengths_list.append(len(record))

    seq_lengths = set(seq_lengths_list)

    print("Alignment length is:" + str(list(seq_lengths)))

    if len(seq_lengths) != 1:
        sys.stderr.write("Your alignment lengths aren't equal. Check your alignment file.")
        sys.exit(1)

    index = range(1, list(seq_lengths)[0]+1)

    return alignment, list(seq_lengths), index


def shannon_entropy(list_input, normalized=False):
    """Calculate Shannon's Entropy per column of the alignment (H=-\sum_{i=1}^{M} P_i\,log_2\,P_i)"""
    
    import math
    unique_base = set(list_input)  # max: 21
    M = len(list_input) # number of sequences in MSA
    entropy_list = []
    # Number of residues in column
    for base in unique_base:
        n_i = list_input.count(base) # Number of residues of type i
        P_i = n_i/float(M) # n_i(Number of residues of type i) / M(Number of residues in column)
        entropy_i = P_i*(math.log(P_i,2)) if not normalized else P_i*(math.log(P_i,M))
        entropy_list.append(entropy_i)

    sh_entropy = -(sum(entropy_list))

    return sh_entropy
    
def shannon_entropy_list_msa(alignment, drop_col_if_gaps_more_than:float=None, normalized=False):
    """Calculate Shannon Entropy across the whole MSA"""
    # alginment is a 2D array where each row is a sequence, each column is position in the msa.
    print('Calculating shannon entropies..')
    shannon_entropy_list = []
    for col_no in range(len(list(alignment[0]))):

        list_input = list(alignment[:, col_no])  # this column from all sequences

        if drop_col_if_gaps_more_than is not None:
            if len([x for x in list_input if x == '-'])/len(list_input) > drop_col_if_gaps_more_than:
                continue
        
        shannon_entropy_list.append(shannon_entropy(list_input, normalized=normalized))

    return shannon_entropy_list


def entropy_from_msa(entropy , msa_seq):
    final = []
    for index ,letter in enumerate(msa_seq):
        if letter != '-':
            final.append(entropy[index])
    return final

In [4]:
def get_scores(seq, vocab_size=30):
    s=''
    l = []
 
    chars = []
    list_of_scores=[]
    for index in range(len(seq)):
        scores=[] 
        dict = {}
        if seq[index] == ' ':
            continue
        s = seq[:index] + "[MASK]" + seq[index + 1:]     # mask tokens by replacing token with [MASK] and add previous and next tokens
        t = []
        char = []
        # I have modified the second loop as it reduce computations 
        predicted_score = unmasker(s)
        for i in range(21):    
            t = predicted_score[i]['score']
            char = predicted_score[i]['token_str']
            scores.append(t)
            chars.append(char)
            dict[char]=t              # for each key add its score
        list_of_scores.append(scores)      # list of lists (scores of each position)
        l.append(dict)              # add each position possible token and its suggested score 
    return list_of_scores,l

In [5]:
def list_of_scores_to_logits_tensor(list_of_scores, dict_of_scores_and_aa, aa_to_id, vocab_size=21):
    logits_tensor = torch.zeros((len(list_of_scores), vocab_size), dtype=torch.float32)
    for tok_idx in range(len(list_of_scores)):
        for aa, prob_ in dict_of_scores_and_aa[tok_idx].items():
            logits_tensor[tok_idx, aa_to_id[aa]] = prob_
            
    return torch.log(logits_tensor)

In [6]:
def typical_sampling(scores,ent):
    mass = 0.9
    filter_value= -float("Inf")
    min_tokens_to_keep = 1
    scores=torch.from_numpy(scores) 
    # calculate entropy
    normalized = torch.nn.functional.log_softmax(scores, dim=-1)
    # p = torch.exp(normalized)
    #ent = -(normalized * p).nansum(-1, keepdim=True)

    # shift and sort
    shifted_scores = torch.abs((-normalized) - ent)     
    sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) 
    sorted_logits = scores.gather(-1, sorted_indices)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove tokens with cumulative mass above the threshold
    last_ind = (cumulative_probs < mass).sum(dim=1)     
    last_ind[last_ind < 0] = 0
    sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
    if min_tokens_to_keep > 1:
        # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    scores = scores.masked_fill(indices_to_remove,filter_value)
    #scores = scores.cpu().detach().numpy()

    return scores

### main

In [7]:
set_seed(42)

In [8]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [9]:
msa_path = "/home/debi3evolm/out_protein_seqs/out_protein_seqs_msa.fasta"
drop_level= 1 # .75
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert").to(device)

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
unmasker = pipeline('fill-mask', model=model, tokenizer=tokenizer, top_k = 30, device=0 if device=='cuda' else -1)

In [11]:
#Generate input to pre-trained model
seq='MVK'
msa_seq ='-MV---K'
final_seq = ''
for ch in seq:
    final_seq = final_seq + ch + ' '

##Shannon Entropies ##
alnformat = "fasta" 

# Start calling functions to do the heavy lifting
alignment, seq_lengths, index_ = parseMSA(msa_path, alnformat)
sel_ = shannon_entropy_list_msa(alignment, drop_col_if_gaps_more_than=drop_level)
final_entropy = entropy_from_msa(sel_,msa_seq)

for i in range(len(final_entropy)):
    final_entropy[i]=[final_entropy[i]]
final_entropy = torch.FloatTensor(final_entropy)

print(final_entropy)

Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
tensor([[0.0050],
        [0.0054],
        [0.0129]])


In [12]:
scores,l = get_scores(final_seq)
scores = np.array(scores)

In [13]:
print(scores.shape) # tokens, vocab_size

(3, 21)


In [14]:
id_to_aa = {v:k for k,v in tokenizer.vocab.items()}

In [15]:
scores_tensor = list_of_scores_to_logits_tensor(list_of_scores=scores, 
                                dict_of_scores_and_aa=l,
                                aa_to_id=tokenizer.vocab,
                                vocab_size=30)

In [16]:
warped_scores = typical_sampling(scores_tensor.numpy(),final_entropy)

In [17]:
real_prob = torch.softmax(warped_scores, dim=1 ).numpy()
n_rows = real_prob.shape[0]
vocab_size =real_prob.shape[1]

aas_out = []

for r in range(n_rows):
    aas_out.append( np.random.choice(a=np.arange(vocab_size), p=real_prob[r, :]) )
    
''.join([ id_to_aa[x] for x in aas_out ])

'EMQ'

In [23]:

def generate_by_typical():
    generated_by_typical = []
    seq = 'MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDLSHIPTAVKIKGFSGEDATPALEGADVVLISAGVARKPGMDRSDLFNVNAGIVKNLVQQVAKTCPKACIGIITNPVNTTVAIAAEVLKKAGVYDKNKLFGVTTLDIIRSNTFVAELKGKQPGEVEVPVIGGHSGVTILPLLSQVPGVSFTEQEVADLTKRIQNAGTEVVEAKAGGGSATLSMGQAAARFGLSLVRALQGEQGVVECAYVEGDGQYARFFSQPLLLGKNGVEERKSIGTLSAFEQNALEGMLDTLKKDIALGEEFVNK'
    msa_seq = '-------------------------------------------------------MKVAVL--G-----AAGG-------------IGQALALLLKTQLP---SG---SELSL-----YD-IA--P------------------VTPGVAVD---------LS-HIPTAVKIK----------GFSGE----DA-------TPALEGADVVLISAGV-ARK------------PG------MDRSDLFNVNAGIVKNLVQQVA--KT--CPKAC-IGIITNPVNTTVAIAAEVLKK--------A--------G---VYD---------K-NKLFGVT-TLDIIRSNTFVAELKG-----\
K-Q------PGEVEVP-VIGGHSG-VTILPL--LSQV-P--------G----V-----------SFT---------------------------EQEVADLTKRIQNAGTEVVEAKA-GGGSATLS-----MGQAAARFGLSLVRALQ---------GEQGV-VEC-A----YV---EGD-------------GQ--YAR-FFS-QPLLLG-K----NGV----------EERKSIGTLS-A-FEQN--------------------------ALEGMLDTLKK----DIAL-GEEFVNK-----------------------------------------------------------'
    
    final_seq = ''
    for ch in seq:
        final_seq = final_seq + ch + ' '

        ##Shannon Entropies ##
    alnformat = "fasta" 
    for i in range(500):
        # Start calling functions to do the heavy lifting
        alignment, seq_lengths, index_ = parseMSA(msa_path, alnformat)
        sel_ = shannon_entropy_list_msa(alignment, drop_col_if_gaps_more_than=drop_level)
        final_entropy = entropy_from_msa(sel_,msa_seq)

        for i in range(len(final_entropy)):
            final_entropy[i]=[final_entropy[i]]
        final_entropy = torch.FloatTensor(final_entropy)

        scores,l = get_scores(final_seq)
        scores = np.array(scores)

        id_to_aa = {v:k for k,v in tokenizer.vocab.items()}

        scores_tensor = list_of_scores_to_logits_tensor(list_of_scores=scores, 
                                dict_of_scores_and_aa=l,
                                aa_to_id=tokenizer.vocab,
                                vocab_size=30)

        warped_scores = typical_sampling(scores_tensor.numpy(),final_entropy)

        real_prob = torch.softmax(warped_scores, dim=1 ).numpy()
        n_rows = real_prob.shape[0]
        vocab_size =real_prob.shape[1]

        aas_out = []

        for r in range(n_rows):
            aas_out.append( np.random.choice(a=np.arange(vocab_size), p=real_prob[r, :]) )
            
        generated_by_typical.append(''.join([ id_to_aa[x] for x in aas_out ]))
    return generated_by_typical

generate_samples_by_typical = generate_by_typical()
generate_samples_by_typical[:3]

Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..




Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculating shannon entropies..
Parsing MSA...
Alignment length is:[719]
Calculatin

KeyboardInterrupt: 

In [24]:
generate_samples_by_typical[:3]

['MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDLSHIPTAVKIKGFSGEDATPALEGADVVLISAGVARKPGMDRSDLFNVNAGIVKNLVQQIAKTCPKACIGIITNPVNTTVAIAAEVLKKAGVYDKNKLFGVTTLDIIRSNTFVAELKGKQPGEVEVPVIGGHSGVTILPLLSQIPGVSFTEQEVADLTKRIQNAGTEVVEAKAGGGSATLSMGQAAARFGLSLVRALQGEQGVVECAYVEGDGQYARFFSQPLLLGKNGVEERKSIGTLSAFEQHALEGMLDTLKKDIALGEEFVNK',
 'MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDLSHIPTDVKIKGFSGEDATPALEGADVVLISAGVARKPGMDRSDLFNVNAGIVKNLVQQIAKTCPKACIGIITNPVNTTVAIAAEVLKKAGVYDKNKLFGVTTLDIIRSNTFVAELKGKQPGEVEVPVIGGHSGVTILPLLSQIPGVSFTEQEVADLTKRIQNAGTEVVEAKAGGGSATLSMGQAAARFGLSLVRALQGEQGVVECAYVEGDGQHARFFSQPLLLGKNGVEERKSIGTLSAFEQNALEGMLDTLKKDITLGEEFVNK',
 'MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDLSHIPTDVKIKGFSGEDATPALEGADVVLISAGVARKPGMDRSDLFNVNAGIVKNLVQQIAKTCPKACIGIITNPVNTTVAIAAEVLKKAGVYDKNKLFGVTTLDIIRSNTFVAELKGKQPGEVEVPVIGGHSGVTILPLLSQVPGVSFSEQEVADLTKRIQNAGTEVVEAKAGGGSATLSMGQAAARFGLSLVRAMQGEQGVVECAYVEGDGQYARFFSQPLLLGKNGVEERRSIGTLSAFEQNALEGMLDTLKKDIALGEEFVNK']

ABSCEG --> aa_to_id --> [0, 1, 6, 4, 10 .. ] --> embeddings --> output --> argmax([0.25, 0.1, 0.5 ...]) --> 2 ---> id_to_aa[2] =

string --> tokenizer --> model(tokenize_input) --> logits --> typica_decoding --> argmax

1. Generate Many sequences ---> batch inference first (for speed)
2. Evaluate Genearted Sequences using Evaluation pipeline
3. Generate other sequences: using temperature, greedy
4. Evaluate other results

1 - Using Genenrate method to generate sequences for temperature based sampling because after all we try to 
generate new sequences so instead of masking token by token we can generate a whole new sequence and then 
compare it to typical sampling   

2 - we can compare them through sequence samples if we don't have scores or logits

In [3]:
import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer,pipeline
import re
import os
import requests
from tqdm.auto import tqdm

In [14]:
tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False)

In [15]:
model = XLNetLMHeadModel.from_pretrained("Rostlab/prot_xlnet")

In [16]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [17]:
model = model.to(device)
model = model.eval()

In [None]:
sequences_Example = "M K V A V L G A A G"

In [None]:
with open('protein_seqs.fasta', 'r') as fasta:
    seq = fasta
seq

<_io.TextIOWrapper name='protein_seqs.fasta' mode='r' encoding='UTF-8'>

In [None]:
ids = tokenizer.encode(sequences_Example, add_special_tokens=False)

I stopped using GPU because model raise an error about that there is two devices
GPU(cuda:0) and CPU 

In [None]:
input_ids = torch.tensor(ids).unsqueeze(0).to(device)

In [None]:
input_ids

tensor([[34, 28, 22, 19, 22, 17, 20, 19, 19, 20]], device='cuda:0')

in order to apply greedy we just have to set temperature parameter to 1.0

In [None]:
max_length = 100
temperature = 0.7
k = 0
p = 0.9
repetition_penalty = 1.0
num_return_sequences = 1

In [None]:
output_ids = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=k,
        top_p=p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        output_scores=True,
        return_dict_in_generate=True
    )

In [None]:
output_ids

SampleDecoderOnlyOutput(sequences=tensor([[34, 28, 22, 19, 22, 17, 20, 19, 19, 20, 22, 21, 19, 34, 20, 22, 34, 20,
         22, 21, 20, 34, 20, 22, 21, 20, 34, 20, 22, 21, 20, 34, 20, 22, 21, 20,
         34, 20, 22, 21, 20, 34, 20, 22, 21, 20, 34, 20, 22, 21, 20, 34, 20, 22,
         21, 20, 34, 20, 22, 21, 20, 34, 20, 22, 21, 20, 34, 20, 22, 21, 20, 34,
         20, 22, 21, 20, 34, 20, 22, 21, 20, 34, 20, 22, 21, 20, 34, 20, 22, 21,
         20, 34, 20, 22, 21, 20, 34, 20, 22, 21]], device='cuda:0'), scores=(tensor([[  -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 2.3946,
         1.6008, 4.3507, 1.8596,   -inf, 3.5959,   -inf,   -inf,   -inf,   -inf,
         1.3735,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf, 1.7379,   -inf,
           -inf]], device='cuda:0'), tensor([[  -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,   -inf,
           -inf,   -inf,   -inf,   -inf,  

In [None]:
output_sequences = [" ".join(" ".join(tokenizer.decode(output_id)).split()) for output_id in output_ids['sequences']]
print('Generated Sequences\n')
for output_sequence in output_sequences:
  print(output_sequence)

Generated Sequences

M K V A V L G A A G V E A M G V M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E G M G V E


#### greedy sampling 

In [None]:
max_length = 100
temperature = .01
k = 0
p = 0.9
repetition_penalty = 1.0
num_return_sequences = 1

In [None]:
output_ids = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=k,
        top_p=p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        output_scores=True,
        return_dict_in_generate=True
    )

In [None]:
output_ids

SampleDecoderOnlyOutput(sequences=tensor([[34, 28, 22, 19, 22, 17, 20, 19, 19, 20, 19, 21, 22, 34, 20, 19, 19, 20,
         19, 21, 22, 34, 20, 19, 19, 20, 19, 21, 22, 34, 20, 19, 19, 20, 19, 21,
         22, 34, 20, 19, 19, 20, 19, 21, 22, 34, 20, 19, 19, 20, 19, 21, 22, 34,
         20, 19, 19, 20, 19, 21, 22, 34, 20, 19, 19, 20, 19, 21, 22, 34, 20, 19,
         19, 20, 19, 21, 22, 34, 20, 19, 19, 20, 19, 21, 22, 34, 20, 19, 19, 20,
         19, 21, 22, 34, 20, 19, 19, 20, 19, 21]], device='cuda:0'), scores=(tensor([[    -inf,     -inf,     -inf,     -inf,     -inf,     -inf,     -inf,
             -inf,     -inf,     -inf,     -inf,     -inf,     -inf,     -inf,
             -inf,     -inf,     -inf,     -inf,     -inf, 304.5489,     -inf,
             -inf,     -inf,     -inf,     -inf,     -inf,     -inf,     -inf,
             -inf,     -inf,     -inf,     -inf,     -inf,     -inf,     -inf,
             -inf,     -inf]], device='cuda:0'), tensor([[    -inf,     -inf,     -inf,  

In [None]:
output_sequences = [" ".join(" ".join(tokenizer.decode(output_id)).split()) for output_id in output_ids['sequences']]
print('Generated Sequences\n')
for output_sequence in output_sequences:
  print(output_sequence)

Generated Sequences

M K V A V L G A A G A E V M G A A G A E V M G A A G A E V M G A A G A E V M G A A G A E V M G A A G A E V M G A A G A E V M G A A G A E V M G A A G A E V M G A A G A E V M G A A G A E V M G A A G A E


as you can see that the two methods (greedy and temperature) generate sequences that is repititive 

In [1]:
import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer,pipeline
import re
import os
import requests
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = XLNetTokenizer.from_pretrained("Rostlab/prot_xlnet", do_lower_case=False)
model = XLNetLMHeadModel.from_pretrained("Rostlab/prot_xlnet")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model = model.eval()

In [3]:
seq = []
with open('protein_seqs.fasta', 'r') as fasta:
    for i in fasta:
        seq.append(i.replace('\n', ''))
seq[:5]

['MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDLSHIPTAVKIKGFSGEDATPALEGADVVLISAGVARKPGMDRSDLFNVNAGIVKNLVQQVAKTCPKACIGIITNPVNTTVAIAAEVLKKAGVYDKNKLFGVTTLDIIRSNTFVAELKGKQPGEVEVPVIGGHSGVTILPLLSQVPGVSFTEQEVADLTKRIQNAGTEVVEAKAGGGSATLSMGQAAARFGLSLVRALQGEQGVVECAYVEGDGQYARFFSQPLLLGKNGVEERKSIGTLSAFEQNALEGMLDTLKKDIALGEEFVNK',
 'MKVAVLGAAGGIGQALALLLKLQLPAGTDLSLYDIAPVTPGVAVDVSHIPTAVNVKGFSGEDPTPALEGADVVLISAGVARKPGMDRSDLFNINAGIVRGLIEKVAVTCPKACVGIITNPVNTTVAIAAEVLKKAGVYDKRKLFGVTTLDVLRSETFVAELKGLNVSRTSVPVIGGHSGVTILPLLSQVQYAKWNEDEIEPLTKRIQNAGTEVLNAKAGGGSATLSMAQAAARFARSLVKGLSGETVVECTYVEGDGKYARFFSQPVRLGKEGVEEILPIGPLSNFEQQALENMLPTLRADIELGEKFING',
 'MKVAVIGAAGGIGQALALLLKNRLPAGSDLALYDIAPVTPGVAADLSHIPTPVTIKGYAGEDPTPALEGADVVLVSAGVARKPGMDRADLFNVNAGIVKALAEKIAVVCPKACVGIITNPVNTTVPIAAEVLKKAGVYDKRKLFGVTTLDVIRSETFVAALKDKDPGQVRVPVIGGHSGVTILPLLSQVEGVSFTDEEVAALTKRIQNAGTEVVEAKAGGGSATLSMGQAACRFGLALVKALQGESDVVEYAYVEGEGEYAPFFAQPIKLGKNGVEALLDIGKLSAYEQAALDGMLDTLKGDIQIGVEFVK',
 'MKVAVLGAAGGIGQALALLLKTQLPAGSKLSLYDIAPVTPGVAVDLSHI

In [4]:
if __name__ == "__main__":
    
    max_prompt_len = len(max(seq))
    # seqs = sample_n_seqs(csv_path=args['file'], nsamples=nsamples , max_len=512, delimiter=',')
    prompt_lens = [int(torch.randint(low=0, high=150, size=(1,1))) for _ in seq]
    prompts_seqs = [sq[:prmpt_len] for sq, prmpt_len in zip(seq, prompt_lens)]
    # nsamples = len(seq)
    
    print(f"inferring sequences ..", end='')
    generated_seqs = []

    bs = 100
    with torch.no_grad():
        for i in range(0, 3000, bs):
            # --> encode
            encoded = tokenizer.batch_encode_plus(prompts_seqs[i:min(i+bs, len(prompts_seqs)-1)], add_special_tokens=True, padding=True, return_tensors='pt')
            input_ids = encoded['input_ids'].to(device)
            attention_mask = encoded['attention_mask'].to(device)

            # --> infer (makram: you can replace this with model(outputs) followed by score extraction and warping )
            temperature = .7
            outputs = model.generate(input_ids=input_ids, 
                                max_length=max_prompt_len,
                                temperature = temperature,
                                do_sample=True)

            # --> decode
            generated_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)

            generated_seqs.extend( [a+b for a,b in zip(prompts_seqs[i:i+bs], generated_tokens)] )

            print(f"done {len(generated_seqs)}/{300}..")
                
        print(f"done all.")


    # --> save for later evaluation
    print(f"saving to files ..", end='')

inferring sequences ..done 100/300..
done 200/300..
done 300/300..
done 400/300..
done 500/300..
done 600/300..
done 700/300..
done 800/300..
done 900/300..
done 1000/300..
done 1100/300..
done 1200/300..
done 1300/300..
done 1400/300..
done 1500/300..
done 1600/300..
done 1700/300..
done 1800/300..
done 1900/300..
done 2000/300..
done 2100/300..
done 2200/300..
done 2300/300..
done 2400/300..
done 2500/300..
done 2600/300..
done 2700/300..
done 2800/300..
done 2900/300..
done 3000/300..
done all.
saving to files ..

In [10]:
for i in prompts_seqs[:3]:
    print(i)

MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDLSHIPTAVKIKGFSGEDATPALEGADVVLISAGVARKPGMDRSDLFNVNAGIVKNLVQQVAKTCPKACIGIITNPVNTTVAIAAEVLKKAGVYD
MKVAVLGAAGGIGQALALLLKLQLPAGTDLSLYDI
MKVAVIGAAGGIGQALALLLKNRLPAGSDLALYDIAPVTPGVAADLSHIPTPVTIKGYAGEDPTPALEGADVVLVSAGVARKPGMDRADLFNVNAGIVKALAEKIAVVCPKACVGIITNPVNTTVPIAAEVLKKAGVYDKRKLF


In [11]:
for i in seq[:3]:
    print(i)

MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDLSHIPTAVKIKGFSGEDATPALEGADVVLISAGVARKPGMDRSDLFNVNAGIVKNLVQQVAKTCPKACIGIITNPVNTTVAIAAEVLKKAGVYDKNKLFGVTTLDIIRSNTFVAELKGKQPGEVEVPVIGGHSGVTILPLLSQVPGVSFTEQEVADLTKRIQNAGTEVVEAKAGGGSATLSMGQAAARFGLSLVRALQGEQGVVECAYVEGDGQYARFFSQPLLLGKNGVEERKSIGTLSAFEQNALEGMLDTLKKDIALGEEFVNK
MKVAVLGAAGGIGQALALLLKLQLPAGTDLSLYDIAPVTPGVAVDVSHIPTAVNVKGFSGEDPTPALEGADVVLISAGVARKPGMDRSDLFNINAGIVRGLIEKVAVTCPKACVGIITNPVNTTVAIAAEVLKKAGVYDKRKLFGVTTLDVLRSETFVAELKGLNVSRTSVPVIGGHSGVTILPLLSQVQYAKWNEDEIEPLTKRIQNAGTEVLNAKAGGGSATLSMAQAAARFARSLVKGLSGETVVECTYVEGDGKYARFFSQPVRLGKEGVEEILPIGPLSNFEQQALENMLPTLRADIELGEKFING
MKVAVIGAAGGIGQALALLLKNRLPAGSDLALYDIAPVTPGVAADLSHIPTPVTIKGYAGEDPTPALEGADVVLVSAGVARKPGMDRADLFNVNAGIVKALAEKIAVVCPKACVGIITNPVNTTVPIAAEVLKKAGVYDKRKLFGVTTLDVIRSETFVAALKDKDPGQVRVPVIGGHSGVTILPLLSQVEGVSFTDEEVAALTKRIQNAGTEVVEAKAGGGSATLSMGQAACRFGLALVKALQGESDVVEYAYVEGEGEYAPFFAQPIKLGKNGVEALLDIGKLSAYEQAALDGMLDTLKGDIQIGVEFVK


In [5]:
generated_seqs_lines = []
for i in generated_seqs:
    generated_seqs_lines.append(i + '\n')

for i in range(len(generated_seqs_lines)):
    generated_seqs_lines[i] = generated_seqs_lines[i].replace(' ', '')

with open('generated_protein_by_temp.txt', 'w') as temp:
        temp.writelines(generated_seqs_lines)

In [13]:
generated_seqs_lines[:5]

['MKVAVLGAAGGIGQA"H\n',
 'MKVAVLGAAGGIGQALALLLKLQLPAGTDLSLYDIAPVTPGVAVDVSHIGGYNTT–ATTNTTTVTTGTTTTMTTTTTTVTTMTTTTNTVTVTTTATTT\n',
 'MKVAVIGAAGGIGQALALLLKNRLPAGSDLALY\n',
 'MKVAVLGAAGGIGQALALLLKTQLPAGSKLSLYDIAP–Y"."..-\n',
 'MKVAVLGAAGGIGQALALLLKTQLPAGSELSLYDIAPVTPGVAVDLSHIPTDVTITGFSGIDPTAALVGADVVLISAGVARKPGMDRSDLFNINAGIIKNLASKCAEVCPTACIGIITNPVNTTKA\n']

### Greedy Search Uing Batches

In [6]:
if __name__ == "__main__":
    
    max_prompt_len = len(max(seq))
    # seqs = sample_n_seqs(csv_path=args['file'], nsamples=nsamples , max_len=512, delimiter=',')
    prompt_lens = [int(torch.randint(low=0, high=150, size=(1,1))) for _ in seq]
    prompts_seqs = [sq[:prmpt_len] for sq, prmpt_len in zip(seq, prompt_lens)]
    # nsamples = len(seq)
    
    print(f"inferring sequences ..", end='')
    generated_seqs = []

    bs = 100
    with torch.no_grad():
        for i in range(0, 3000, bs):
            # --> encode
            encoded = tokenizer.batch_encode_plus(prompts_seqs[i:min(i+bs, len(prompts_seqs)-1)], add_special_tokens=True, padding=True, return_tensors='pt')
            input_ids = encoded['input_ids'].to(device)
            attention_mask = encoded['attention_mask'].to(device)

            # --> infer (makram: you can replace this with model(outputs) followed by score extraction and warping )
            temperature = 0.001
            outputs = model.generate(input_ids=input_ids, 
                                max_length=max_prompt_len,
                                temperature = temperature,
                                do_sample=True)

            # --> decode
            generated_tokens = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)

            generated_seqs.extend( [a+b for a,b in zip(prompts_seqs[i:i+bs], generated_tokens)] )

            print(f"done {len(generated_seqs)}/{300}..")
                
        print(f"done all.")


    # --> save for later evaluation
    print(f"saving to files ..", end='')

inferring sequences ..done 100/300..
done 200/300..
done 300/300..
done 400/300..
done 500/300..
done 600/300..
done 700/300..
done 800/300..
done 900/300..
done 1000/300..
done 1100/300..
done 1200/300..
done 1300/300..
done 1400/300..
done 1500/300..
done 1600/300..
done 1700/300..
done 1800/300..
done 1900/300..
done 2000/300..
done 2100/300..
done 2200/300..
done 2300/300..
done 2400/300..
done 2500/300..
done 2600/300..
done 2700/300..
done 2800/300..
done 2900/300..
done 3000/300..
done all.
saving to files ..

In [7]:
generated_seqs_lines_greedy = []
for i in generated_seqs:
    generated_seqs_lines_greedy.append(i + '\n')


for i in range(len(generated_seqs_lines_greedy)):
    generated_seqs_lines_greedy[i] = generated_seqs_lines_greedy[i].replace(' ', '')

with open('generated_protein_by_greedy.txt', 'w') as temp:
        temp.writelines(generated_seqs_lines_greedy)

In [12]:
generated_seqs_lines_greedy[:5]

['MKVAVLGAAGGIGQALALLLKTQLPSGSELSLYDIAPVTPGVAVDLSHIPTAVKIKGFSGEDATPALEGADVVLISAGVARKPGMDRSDLFNVNAGIVKNLVQQVAKTCPKACIGIITNPVNTTVAIAAEVLKKAGVYDTMTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT\n',
 'MKVAVLGAAGGIGQALALLLKLQLPAGTDLSLYDITMTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT\n',
 'MKVAVIGAAGGIGQALALLLKNRLPAGSDLALYDIAPVTPGVAADLSHIPTPVTIKGYAGEDPTPALEGADVVLVSAGVARKPGMDRADLFNVNAGIVKALAEKIAVVCPKACVGIITNPVNTTVPIAAEVLKKAGVYDKRKLFTMTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTTT