In this notebook, we want to evaluate the pretrained BGC model.

The finetuning was done:
  1. Concatanating all proteins in a BGC using a 'B' character
  2. Replacing one keyword
  3. Run ~1000 proteins (512aa sliding window, with overlapping 300?) from one BGC family
  4. How many Epochs?

We found:
  - No 'B' is generated. This is likely due to the fact that when we fine-tune the foward module, only the embedding layer and the softmax weights are updated. 'B' was never trained as it is not in the data for pretrain (likely embedded to 0s).
  - Setting top k=1 will evently lead to repetitive sequences when the generated sequences exceed certain length. Using topk>1 increases the randomness, but may increasingly lose the ability to generate "real" artificial sequences.

In [1]:
#from google.colab import drive
#drive.mount('/content/drive')

#import sys, os

#progen_path = '/content/drive/Shareddrives/jgi_collaborations/BGC/jgiBGC_finetuned_6/lrc-deep-learning/progen_original/progen_code'
#sys.path.append(progen_path)

In [2]:
from __future__ import print_function
from __future__ import division
import os
#os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
import sys
import torch
import tqdm
import pdb
import numpy as np
import platform
import hashlib
import pytorch_transformer
import re
import argparse
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
import torch.nn.functional as F
#from torch.utils.tensorboard import SummaryWriter
#from transformProtein import transformProtein
from ProteinDataset_uid import ProteinDataset
from torch.utils.data import Dataset, DataLoader
import pickle
import time
import matplotlib.pyplot as plt
import collections

2023-10-09 15:48:24.763996: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
CUDA_DEVICE_NUM = 0
DEVICE = torch.device(f'cuda:{CUDA_DEVICE_NUM}' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cuda:0


In [6]:
seq_length = 511
embedding_dim = 1280
num_layers = 36
vocab_loc = './mapping_files/vocab.txt'

use_py3 = platform.python_version()[0] == '3'
vocab = open(vocab_loc).readlines() if not use_py3 else open(vocab_loc, encoding='utf-8').read().split('\n')[:-1]
vocab = list(map(lambda x: x.split(' ')[0], vocab))
vocab_size = len(vocab)
print('-----vocab size',vocab_size,'------')

-----vocab size 129407 ------


In [7]:
class TiedEmbeddingSoftmax(torch.nn.Module):

  def __init__(self, vocab_size=vocab_size, embedding_size=embedding_dim, **kwargs):
    super(TiedEmbeddingSoftmax, self).__init__()
    self.w = torch.nn.Parameter(torch.normal(0., 1e-2, size=(vocab_size, embedding_size)))
    self.b = torch.nn.Parameter(torch.zeros(vocab_size))

  def forward(self, inputs, embed=True):
    if embed:
      return torch.nn.functional.embedding(inputs, self.w)
    else:
      return torch.tensordot(inputs, self.w.t(), 1) + self.b

class CTRLmodel(torch.nn.Module):
  def __init__(self):
    super(CTRLmodel,self).__init__()
    self.tied_embedding_softmax = TiedEmbeddingSoftmax()
    self.encoder = pytorch_transformer.Encoder(device=DEVICE)
    self.checkpoint = None

  def forward(self, inputs):
    x = self.tied_embedding_softmax(inputs, embed = True)
    x = self.encoder(x)
    x = self.tied_embedding_softmax(x, embed = False)
    return x

  def loadCheckpoint(self, model_path, num_layers):
    #pytorch_model_hash = hashlib.md5(model_path.encode('utf-8')).hexdigest()
    pytorch_model_hash = model_path

    if os.path.exists(pytorch_model_hash):
      print('Found PyTorch checkpoint @', pytorch_model_hash)
      print('Loading instead of converting from TensorFlow')
      checkpoint = torch.load(pytorch_model_hash)
      self.checkpoint = checkpoint

      #self.tied_embedding_softmax.load_state_dict(checkpoint['softmax'])
      #self.encoder.load_state_dict(checkpoint['encoder'])
      ## load state dict has KeyError, because checkpoint is ready the state_dict
      ## can load checkpoint directly
      ## https://discuss.pytorch.org/t/keyerror-state-dict/18220/5
      if 'model_state_dict' in checkpoint:
        new_state_dict = collections.OrderedDict()
        for k, v in checkpoint['model_state_dict'].items():
            name = k.replace("module.", '') # remove `module.` prefix from pickel file generated by DistributedDataParallel
            new_state_dict[name] = v
        self.load_state_dict(new_state_dict)
      else:
        self.load_state_dict(checkpoint)

      #self.tied_embedding_softmax.to('cuda')
      #self.encoder.to('cuda')
      self.tied_embedding_softmax.to(DEVICE)
      self.encoder.to(DEVICE)

    else:
      print('Error: Could not find PyTorch checkpoint')
      sys.exit(1)

  def get_checkpoint(self):
    return self.checkpoint

In [31]:
model = CTRLmodel()
print('model initialized')

#curr_model_path = '/content/drive/Shareddrives/jgi_collaborations/BGC/jgiBGC_finetuned_6/finetune_progen_multi_node_GPU_demo_V2_backup.pth' # Finetuned model on JGI large data epoch=1024
#curr_model_path =  '/content/drive/Shareddrives/jgi_collaborations/BGC/jgiBGC_finetuned_6/finetune_progen_multi_node_GPU_demo_V3.pth'

#curr_model_path = './checkpoints_cur/finetune_progen_multi_node_GPU_demo_V3.pth' # Finetuned model on JGI large data epoch=2048
curr_model_path = "./checkpoints_cur/finetune_progen_multi_node_GPU_demo_allpram.pth"

reader = model.loadCheckpoint(model_path=curr_model_path, num_layers = num_layers)
print('previous checkpoint loaded')

#model = model.cuda()
model = model.to(DEVICE)

optimizer = torch.optim.Adam(model.parameters()) #lr, betas

model initialized
Found PyTorch checkpoint @ ./checkpoints_cur/finetune_progen_multi_node_GPU_demo_allpram.pth
Loading instead of converting from TensorFlow
previous checkpoint loaded


In [41]:
ckpoint = model.get_checkpoint()
if 'model_state_dict' in ckpoint:
    print(' epoch {0}\n loss {1}'.format(ckpoint['epoch'], ckpoint['loss']))

 epoch 149
 loss 5.028169631958008


In [43]:
model

CTRLmodel(
  (tied_embedding_softmax): TiedEmbeddingSoftmax()
  (encoder): Encoder(
    (layer0): EncoderLayer(
      (multi_head_attention): MultiHeadAttention(
        (Wq): Linear(in_features=1280, out_features=1280, bias=True)
        (Wk): Linear(in_features=1280, out_features=1280, bias=True)
        (Wv): Linear(in_features=1280, out_features=1280, bias=True)
        (dense): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (ffn): Sequential(
        (0): Linear(in_features=1280, out_features=8192, bias=True)
        (1): ReLU()
        (2): Linear(in_features=8192, out_features=1280, bias=True)
      )
      (layernorm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (layernorm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
    (layer1): EncoderLayer(
      (multi_head_attention): MultiHeadAttention(
        (Wq): Linear(in_features

In [51]:
for p in model.named_parameters():
    if 'encoder.layernorm' not in p[0]:
        print(p[0])

tied_embedding_softmax.w
tied_embedding_softmax.b
encoder.layer0.multi_head_attention.Wq.weight
encoder.layer0.multi_head_attention.Wq.bias
encoder.layer0.multi_head_attention.Wk.weight
encoder.layer0.multi_head_attention.Wk.bias
encoder.layer0.multi_head_attention.Wv.weight
encoder.layer0.multi_head_attention.Wv.bias
encoder.layer0.multi_head_attention.dense.weight
encoder.layer0.multi_head_attention.dense.bias
encoder.layer0.ffn.0.weight
encoder.layer0.ffn.0.bias
encoder.layer0.ffn.2.weight
encoder.layer0.ffn.2.bias
encoder.layer0.layernorm1.weight
encoder.layer0.layernorm1.bias
encoder.layer0.layernorm2.weight
encoder.layer0.layernorm2.bias
encoder.layer1.multi_head_attention.Wq.weight
encoder.layer1.multi_head_attention.Wq.bias
encoder.layer1.multi_head_attention.Wk.weight
encoder.layer1.multi_head_attention.Wk.bias
encoder.layer1.multi_head_attention.Wv.weight
encoder.layer1.multi_head_attention.Wv.bias
encoder.layer1.multi_head_attention.dense.weight
encoder.layer1.multi_head_att

In [33]:
with open(os.path.join('mapping_files/','taxa_to_lineage.p'),'rb') as handle:
    taxa_to_lineage = pickle.load(handle)
with open('mapping_files/taxa_to_ctrl_idx.p','rb') as handle:
    taxa_to_ctrl_idx = pickle.load(handle)
with open('mapping_files/kw_to_ctrl_idx.p','rb') as handle:
    kw_to_ctrl_idx = pickle.load(handle)
# with open('mapping_files/aa_to_ctrl_idx.p','rb') as handle:
#     aa_to_ctrl_idx = pickle.load(handle)
with open('mapping_files/aa_to_ctrl_idx_seqonly.p','rb') as handle:
    aa_to_ctrl_idx = pickle.load(handle)
    
with open('mapping_files/kw_to_name.p2','rb') as handle:
    kw_to_name = pickle.load(handle)
# with open('mapping_files/taxid_to_name.p2','rb') as handle:
#     taxid_to_name = pickle.load(handle)

In [34]:
def flipdict(my_map):
    return {v: k for k, v in my_map.items()}
ctrl_idx_to_aa = flipdict(aa_to_ctrl_idx)
ctrl_idx_to_kw = flipdict(kw_to_ctrl_idx)
ctrl_idx_to_taxa = flipdict(taxa_to_ctrl_idx)

In [35]:
def predict_fn(inputs):
    with torch.no_grad():
        #inputs = torch.tensor(inputs).cuda()
        inputs = torch.tensor(inputs).to(DEVICE)
        output = model(inputs)
        output = output[:,:,-26:-1] # remove non-AA token logits
        return output

In [36]:
taxid = 9606 # homo sapiens taxonomy id from NCBI: https://www.ncbi.nlm.nih.gov/taxonomy
tax_lineage = taxa_to_lineage[taxid] # make lineage in ncbi ids
print(tax_lineage)
tax_lineage = [taxa_to_ctrl_idx[ite] for ite in tax_lineage] # now translated as ctrl code indices
print(tax_lineage)

#kw_lineage = [677,9] # UniprotKB keywords from https://www.uniprot.org/docs/keywlist
kw_lineage = [472] # see kw mapping in data_preprocess.ipynb
print(kw_lineage)
kw_lineage = [kw_to_ctrl_idx[ite] for ite in kw_lineage] # now translated to ctrl code indices
print(kw_lineage)

[33208, 7711, 40674, 9443, 9604, 9605, 9606]
[11177, 5756, 14034, 6957, 7068, 7069, 7070]
[472]
[0]


In [37]:
# Try to generate longer sequences
# example_seq = 'YMIQEEEWDRDLLLDPAWEKQQRKTFTAWCNSHLRKAGTQIENIEEDFRNGLKLMLLLEVISGERLPKPDRGKMRFHKIANVNKALDYIASKGVKLVSIGAEEIVDGNVKMTLGMIWTIILRFAIQDISVEETSAKEGLLLWCQRKTAPYRNVNIQNFHTSWKDGLGLCALIHRHRPDLIDYSKLNKDDPIGNINLAMEIAEKHLDIPKMLDAEDIVNTPKPDERAIMTYVSCFYHAFAGAEQAETAANRICKVLAVNQENERLMEEYERLASELLEWIRRTIPWLENRTPAATMQAMQKKLEDFRDYRRKHKPPKVQEKCQLEINFNTLQTKLRISNRPAFMPSEGKMVSDIAGAWQRLEQAEKGYEEWLLNEIRRLERLEHLAEKFRQKASTHETWAYGKEQILLQKDYESASLTEVRALLRKHEAFESDLAAHQDRVEQIAAIAQELNELDYHDAVNVNDRCQKICDQWDRLGTLTQKRREALERMEKLLETIDQLHLEFAKRAAPFNNWMEGAMEDLQDMFIVHSIEEIQSLITAHEQFKATLPEADGERQSIMAIQNEVEKVIQSYNIRISSSNPYSTVTMDELRTKWDKVKQLVPIRDQSLQEELARQHANERLRRQFAAQANAIGPWIQNKMEEIARSSIQITGALEDQMNQLKQYEHNIINYKNNIDKLEGDHQLIQEALVFDNKHTNYTMEHIRVGWELLLTTIARTINEVETQILTRDAKGITQEQMNEFRASFNHFDRRKNGLMDHEDFRACLISMGYDLGEAEFARIMTLVDPNGQGTVTFQSFIDFMTRETADTDTAEQVIASFRILASDKPYILAEELRRELPPDQAQYCIKRMPAYSGPGSVPGALDYAAFSSALYGESDL'
# use a BGC from training
example_seq = 'MNGKRNIFTCISIIGIGLASFSSFSFAANVTDNSVQNSIPVVNQQVAAAKEMKPFPQQVNYAGVIKPTHVTQESLNASVRSYYDNWKKKYLKNDLSSLPGGYYVKGEITGDADGFKPLGTSEGQGYGMIITVLMAGYDSNAQKIYDGLFKTARTFKSSQNPNLMGWVVADSKKAQGHFDSATDGDLDIAYSLLLAHKQWGSNGTVNYLKEAQDMITKGIKASNVTNNSRLNLGDWDSKNSLDTRPSDWMMSHLRAFYEFTGDKTWLTVINNLYDVYTQFSNKYSPNTGLISDFVVKNPPQPAPKDFLEESEYTNAYYYNASRVPLRIVMDYAMYGEKRSKVISDKVSSWIQNKTNGNPSKIVDGYQLNGSNIGSYSTAVFVSPFIAASITSSNNQKWVNSGWDWMKNKRESYFSDSYNLLTMLFITGNWWKPVPDDKKIQNQINDAIYEGYDNBMEKVLFFGDPGIDDSFAIMYGLLHPEIEIVGIVTGYGNVEHIHAAHNAAYILQLANRQ'

#  when generate longer sequences, we ended up in repeats
# let's do it 200 at a time, update prefix everytime

max_len = 5000 * 2
gen_step = 200
total_generated = ''
prefix = example_seq[:150]

while 1:
    if len(total_generated) > max_len:
        break

    ref = example_seq[150:200]
    penalty = 1.2
    topk = 10

    seed_seq = [aa_to_ctrl_idx[ii] for ii in prefix]
    # generate_num = len(kw_lineage+tax_lineage)+len(prefix+ref)
    generate_num = gen_step + len(kw_lineage+tax_lineage)
    seq_length = min(generate_num, 511)
    text = tax_lineage+kw_lineage+seed_seq
    padded_text = text + [0] * (generate_num - len(text))
    tokens_generated = np.tile(padded_text, (1,1))

    for token in range(len(text)-1, generate_num-1):
        prompt_logits = predict_fn(tokens_generated[:, :seq_length]).squeeze()
        _token = token if token < seq_length else -1
        prompt_logits = prompt_logits.cpu().detach().numpy()

        if penalty>0:
            penalized_so_far = set()
            for _ in range(token-3,token+1):
                generated_token = tokens_generated[0][_] - (vocab_size-26) # added
                if generated_token in penalized_so_far:
                    continue
                penalized_so_far.add(generated_token)
                prompt_logits[_token][generated_token] /= penalty

        # compute probabilities from logits
        prompt_probs = np.exp(prompt_logits[_token])
        prompt_probs = prompt_probs / sum(prompt_probs)
        pruned_list = np.argsort(prompt_probs)[::-1]

        if topk==1:
            idx = pruned_list[0]
        else:
            pruned_list = pruned_list[:topk]
            chosen_idx = torch.distributions.categorical.Categorical(logits=torch.tensor(np.expand_dims(prompt_logits[_token][pruned_list],0))).sample().numpy()[0]
            idx = pruned_list[chosen_idx]

        # assign the token for generation
        idx += (vocab_size-26) # added to convert 0 AA to original ctrl idx
        tokens_generated[0][token+1] = idx


    tokens_generated_so_far = tokens_generated[0].squeeze()[:token+2]
    tokens_generated_so_far = tokens_generated_so_far[(tokens_generated_so_far>=(vocab_size-26)) & (tokens_generated_so_far<(vocab_size-1))]
    tokens_generated_so_far = ''.join([ctrl_idx_to_aa[c] for c in tokens_generated_so_far])
    # save generated, update prefix before the next round
    total_generated += tokens_generated_so_far
    prefix = tokens_generated_so_far[-150:]
    print(prefix)
    tokens_generated_so_far = ''
    tokens_generated = ''



EMKPFPQQVNYAGVIKPTHVTQESLNASVRSYYDNWKKKYLKNDLSSLPGGYYVKGEITGDADGFKPLGTSEGQGYGMIITVLMAGYDSNAQKIYDGLFKDINFIKVTHPADVQESARLKIEHATKVDNSRPIGKNTIEGKNVSFKAPSG
GYYVKGEITGDADGFKPLGTSEGQGYGMIITVLMAGYDSNAQKIYDGLFKDINFIKVTHPADVQESARLKIEHATKVDNSRPIGKNTIEGKNVSFKAPSGFYIEAVQPGKIALYGHSLADGSLVIEASLDKEFAVTKSGNYVLTDAKGYD
DINFIKVTHPADVQESARLKIEHATKVDNSRPIGKNTIEGKNVSFKAPSGFYIEAVQPGKIALYGHSLADGSLVIEASLDKEFAVTKSGNYVLTDAKGYDVELSKDNVSILKNSEFKNGEDSTVIALNEFKPASIGEKVFYDKGQAISGY
FYIEAVQPGKIALYGHSLADGSLVIEASLDKEFAVTKSGNYVLTDAKGYDVELSKDNVSILKNSEFKNGEDSTVIALNEFKPASIGEKVFYDKGQAISGYEKDGNLEAKTYDASLVGEDLSTAGIYKTFDGESNAVIKDGTLYAEKFTSG
VELSKDNVSILKNSEFKNGEDSTVIALNEFKPASIGEKVFYDKGQAISGYEKDGNLEAKTYDASLVGEDLSTAGIYKTFDGESNAVIKDGTLYAEKFTSGDVKFNILAKGQDSTEKVRAESGKAQISFTVISEKLVPDTKSIFTENSKVE
EKDGNLEAKTYDASLVGEDLSTAGIYKTFDGESNAVIKDGTLYAEKFTSGDVKFNILAKGQDSTEKVRAESGKAQISFTVISEKLVPDTKSIFTENSKVEIGTVSKANISVKATSNVLASQDGVYFKDNSIVGTAFYDSQTKASLDEKTL
DVKFNILAKGQDSTEKVRAESGKAQISFTVISEKLVPDTKSIFTENSKVEIGTVSKANISVKATSNVLASQDGVYFKDNSIVGTAFYDSQTKAS

In [38]:
len(total_generated)

10200

In [39]:
'B' in total_generated

False

In [40]:
total_generated[150:]

'DINFIKVTHPADVQESARLKIEHATKVDNSRPIGKNTIEGKNVSFKAPSGEMKPFPQQVNYAGVIKPTHVTQESLNASVRSYYDNWKKKYLKNDLSSLPGGYYVKGEITGDADGFKPLGTSEGQGYGMIITVLMAGYDSNAQKIYDGLFKDINFIKVTHPADVQESARLKIEHATKVDNSRPIGKNTIEGKNVSFKAPSGFYIEAVQPGKIALYGHSLADGSLVIEASLDKEFAVTKSGNYVLTDAKGYDGYYVKGEITGDADGFKPLGTSEGQGYGMIITVLMAGYDSNAQKIYDGLFKDINFIKVTHPADVQESARLKIEHATKVDNSRPIGKNTIEGKNVSFKAPSGFYIEAVQPGKIALYGHSLADGSLVIEASLDKEFAVTKSGNYVLTDAKGYDVELSKDNVSILKNSEFKNGEDSTVIALNEFKPASIGEKVFYDKGQAISGYDINFIKVTHPADVQESARLKIEHATKVDNSRPIGKNTIEGKNVSFKAPSGFYIEAVQPGKIALYGHSLADGSLVIEASLDKEFAVTKSGNYVLTDAKGYDVELSKDNVSILKNSEFKNGEDSTVIALNEFKPASIGEKVFYDKGQAISGYEKDGNLEAKTYDASLVGEDLSTAGIYKTFDGESNAVIKDGTLYAEKFTSGFYIEAVQPGKIALYGHSLADGSLVIEASLDKEFAVTKSGNYVLTDAKGYDVELSKDNVSILKNSEFKNGEDSTVIALNEFKPASIGEKVFYDKGQAISGYEKDGNLEAKTYDASLVGEDLSTAGIYKTFDGESNAVIKDGTLYAEKFTSGDVKFNILAKGQDSTEKVRAESGKAQISFTVISEKLVPDTKSIFTENSKVEVELSKDNVSILKNSEFKNGEDSTVIALNEFKPASIGEKVFYDKGQAISGYEKDGNLEAKTYDASLVGEDLSTAGIYKTFDGESNAVIKDGTLYAEKFTSGDVKFNILAKGQDSTEKVRAESGKAQISFTVISEKLVPDTKSIFTENSKV

In [25]:
# >BGC0000001|c1|15054-32399|+|AEK75502.1|type_1_polyketide_synthase|AEK75502.1
example_seq = 'MSAPDEPVAVVGLACRLPGAADPEAFWALLRDGREAITDPPASRRDPDGRARRGGFLDAVDLFDAEFFGVPPREAAAMDPQQRLVLELSWEALEDARIRPDALAGSRTGVFVGAISDDYATLLRRRGPDAIGPHSLTGTNRGIIANRVSYHLGLHGPSITVDSAQSSALVAVHVAAESLRRGESELALAGGVNLNLAPESTLGAERFGALSPDGRCHTFDARANGYVRGEGGGLVVLKPLDRALADGDRVHAVLLGSAVNNDGVTDGLTVPGSDGQREVIRLAHERAGTTPGEVDYVELHGTGTVVGDPVEAAALGAELGQLRDTPLLVGSAKTNVGHLEGAAGIVGFVKAVLCVRHRTLPPSLNFATPNPRIPLNELNLRVVTESHTVPRPLVVGVSSFGMGGTNAHAVLTEAPLRTARKPAPAARPLTWVVSGHTPQALRAQAGQLTSLAADPADVAFSLATTRATLPYRAAVVGETAADLRAGMAAVATGTPHPGTVTGSPAGTLAFVFTGQGSQRAGMGRELAARFPVYAQAFAEVAAALDPHLGRPLDEVLDDADALDRTEFAQPALFAVEVALFRLLTHWGLRPDAVAGHSVGEIAAAHVAGVLDLPDAARLVAARGRLMQALPTGGAMVALSAGEEEVRPLLRPGADLAAVNAAESVVVAGDDDAVSAIEETVRGWGRRTSRLRVSHAFHSARMDPMRVSFAQALADIEFAQPTIPVVSALTDPDVTDAEHWVRHVRDTVRFADAVRDLRDRGVRTVLEVGPDAVLTALAHDVAELAAVAVLRRDRPEPDTAVTALATAFTRGAAVDWTALLGARQAVDLPRYAFQRSRHWLDQDNPAIAAPEVTRAPDGTPRRSDEELLDLVRTAVAVAHGRVGPAAIDPDTTFRDLGLDSVTSVEFRDRLAAATGVPLSPGLVYDHPTPRAVVAHLRTLTGGGPADPEQESGYRDEPVAVIGMACRYPGGVGSPDDLWQLVRDGRDATGPFPTDRGWDLDALYDPDPGTPGRTYVRRGGFLDGAAEFDADFFGISPREASAMDPQQRLLLHTAWEALEHGRLNPESLRGTRTGVFVGVVDNDYGPRLHEPVEGTEGYLLTGTTASVASGRVAYALGLTGPAVTVDTACSSSLVALHLAAQALRQGECTLALAGGATVLATPGMFLEFSRQRGLAPDGRCKAFAATADGTAWAEGAGLVVLERLSDARRNGHPVLAVLRGSAINQDGASNGLTAPSGPSQERVIRRALAVAGLAPSDVDLMEAHGTGTALGDPIEARAILATYGQRRDTPLHLGSLKSNIGHTQAAAGIAGVIKVVQAMQHGTLPATLHVDEPTPHVDWAEGQVSLLTEATPWPDTGRPRRAAVSSFGISGTNAHVILEHGDPQPAPPRRTSTGHVAWLVSAREPELVAEQAGRLHRFVRDNPELDPADVALSLATTRPLLEHRAAVVGADRDELLAGLAELESGRRRAEAIRPGKVAFLFAGQGTQRLNMGRQLYDTNPTFAHALDTVTNALNPHLNQPLLDIIFGTDPHLLNRTENAQPALFAIETALYHLLTHHGIHPDYLLGHSLGEITAAHAAGILTLTDAATLVTTRAKLMQTATPGGAMIAIEATETEIQPTLHPTVTIAAINTPTTTVISGDHHHTHAIAHHWRQQGRRTTTLTVSHAFHSPHMDPILDTFHTTTQTLTHHPPHTPLITNLTGQPLTNPTPEHWTHHLRQPVRYHDATTTLTHHGVTHTIEIGPDTTLTTLTKTNHPTLTTTPTLRPHHNENHTLTHTLATTPTTNWASLYPHARPVRLPTTAFRRDRYWLTGGRATPGADSGVREVDHPLLGAAVTLADDSTVYTGRLSRRTAPWLADHVVLGRALLPGTALLEYALWAGRDVGLPRVAELTLEAPLVLPDEGVTQVRVTVGPPGEQRTVAVHARADDAEQWTRHASGMLTAAPPASPVPPRVDGPPVDVDDLYERLAGKGYEYGPAFRLATAARHGQHVVAQLAAPAGPDGFVLHPAQVDAALHPIVLDGDETLLPFSWSGVSVFRRPSGALHAYWTPERALVLTDADGVVATADSLHLRPARMPAPTDLHRIRWVPAEDARRQIRVEPVADATAALAMLHERLDATEPTALVVPHLDRTGAAGLVRSAQAEHPGRFVLIHADDPVRTVPDGEPEAAWRDGSWWVPRLARVAPVDPGLPLSGTVLVTGGTGALGALVARHLVRAHRVRDLVLVSRRGADAPGAAALADELAGHGARVDLRACDVADREALACLLADLPTLDAVVHAAGVVRDATVSALTVEQVRAAATKAESAWHLHELTRDRPLRAFVLFSSISGLLGTAGQGAYAAANAALDALAAHRHALGLPALSLAWGLWEDTGMGAGLSAADVARWRRDGLPPLTVEQGVALFDAALSHEGPVLAPVRLDLAALRGRDVLPAALRGLVTRRAVPPAGSRPRDEAELREVVRSVVAEVLGYPSAAGVDSARPFRDLGLDSLGGVELRNRLAAATGLPVPATLVFDHPTPDAVVAHLLGATTSAQPAPTPTVATRTDEPIAIVGMACRYPGGVSSPEDLWRLVADGVDAIGEFPTDRGWDLGRLYDPDPEHAGTSYTRHGGFLYDAADFDAGFFALSPREATATDPQQRLLLEVAWEAFERAGIDPTAVRGSRTGVFAGVMYGDYGTRWRTAPEGFEGHLLTGNTSSVVSGRVAYSFGLEGPAVTVDTACSSSLVALHLAAQSLRSGECDLALAGGVTVMATPHTFVEFSRQRGLSPDGRCRSFSAAANGTGWSEGAGLLLVERLSDARANGHHVLAILRGSAVNQDGASNGLTAPNGPAQQRVIRTALTNAHLQPTDIDLVEAHGTGTRLGDPIEAQALIATYGHHRNTPLHLGSLKSNIGHTQAAAGVAGVIKVIQAMQHGTLPATLHVNEPTPHVNWADSQVTLLTEATPWPDTGRPRRAAVSSFGISGTNAHVILEHGDPRPVPPEETDPPAPVPLVISARSAGALRDQAARVRTALGSGLPVRDVAYTLGAARARHPHQAVVVGEGRAELLAGLDAVADGTVPGAVATPGKVAFLFAGQGTQRLNMGRQLYDTNPTFAHALDTVTNALNPHLNQPLLDIIFGTDPHLLNRTENAQPALFAIETALYHLLTHHGIHPDYLLGHSLGEITAAHAAGILTLTDAATLVTTRAKLMQTATPGGAMIAIEATETEIQPTLHPTVTIAAINTPTTTVISGDHHHTHAIAHHWRQQGRRTTTLTVSHAFHSPHMDPILDTFHTTTQTLTHHPPHTPLITNLTGQPLTNPTPEHWTHHLRQPVRYHDATTTLTHHGVTHTIEIGPDTTLTTLTKTNHPTLTTTPTLRPHHNENHTLTHTLATTPTTNWKTLLPHATVIDLPTYPFQRQRYWLDGPAADTGLDGSGHPLLPGVVDLADGGLVLTGTVSADSHPWLAGHRIGGATLLPATAVVEAVAHAASRVGLDVDELVLTAAVPVDAPVRLRLTVGPASDDDSRAVHLHGNTDDGEWLPYATGRLAVVTQSPAADLATWPPTDAEPVDVVDLYDRLAEGGYGYHGLFQGLRALWRRGDETFAEVRPDESPTGGFAPHPALWDAALHPLAWDAAERGQVEIPFEWRSVRRHGPGAPALRVRLARRDDAVSVDVADDAGRPIASAGALRLRSTGTAPTTVLEPDWEPVPTDGEWTGRYATVVAPRTGADASAAYAAVTWALDALRQHEGDEPLVVRTVDDPAGAAVRGLVRTAQTEQPGRFVLFTGSGDPEPALVRAALASGEPEVALRDGTLMAPRLSRIPVAPGPLPFASGSTVLVTGGTGALGALVARHLVVRHGVRRLLLTSRRGPAADGAAELVDELTAAGAEVEVVACDVADRPAVAALLASIPEEHPLTAVIHTAGVLDDGALTSLTEERLARVLRPKAEAAWHLHEFTRDRPLTAFVLFSSITGITGTAGQANYAAANAYLDALARHRRNLGLPGVSLAWGLWGATGMASGLGAADLDRLARSGITPLSPQEGLDLFDACLVADRPVLAPARVDLSTTRRQRRRAASAAATVTSREGLRELVRAQVAAVLGHTDATEVSTDVAFTGLGLDSLTAVELRNRIAERTGLRLSSTVVFDHPSVDALTDHLVAELAGARPVETPQPVTQPADEPIAIVGMACRYPGGVSSPEDLWRLVADGVDAIGEFPTDRGWGEIHDPDPDRPGHSYTRHGGFLYAAGDFDAELFGMSPREALTTDPQQRLLLEVAWEAFERAGLPPGSLRGSRTGVFTGVMYNDYGARLHQAGTPAPGYEGYLVSGSAGSVASGRVAYSFGLEGPAVTVDTACSSSLVALHLAAQSLRSGECDLALAGGVTVMASPATFVEFSRQRGLAPDGRCKPFAAAADGTGWSEGAGLLLVERLSDARANGHHVLAILRGSAVNQDGASNGLTAPNGPAQQRVIRTALTNAHLQPTDIDLVEAHGTGTRLGDPIEAQALIATYGHHRNTPLHLGSLKSNIGHTQAAAGVAGVIKVIQAMQHGTLPATLHVNEPTPHVNWADSQVTLLTEATPWPDTGRPRRAAVSSFGISGTNAHVILEHGDAVDTGTGTGTVAPGTAVVPWLLSGTSRQALTAYARLLGEVDAAPVDIAATLALGRSPLALRASVVGRDRPEFPTARLQPVQPVDGPTAFAFTGQGSQRAGMGLGLAARFPQFADALASVAEALDPHLPSPLLDVLADGDLLERTEYAQPAIFAVEVALFRLLAHYGVTPHVLLGHSVGELAAAHVAGVLDLPDAATLVAARGRLMGGLVPGGAMAAVRAGEDEVLALLVPGAEIAAVNADDAVVVSGDAEAVAAVTQALRDAGRRVTPLRVSHAFHSARMDPVLEEFRAVAATLRFSEPTIPLISLLPGSPTDPGYWVRHLREAVRFGDGVRSLAEWGVRRVLEVGPDAALTPVTGPTGIATLRRDHDEESAFVTALAALHDTGATVDWATFFGELGARRVPLPTYPFQRRRYWLTPTAPRTTGGSGHPLLDAAVELPEGAVLFTGRVAAEDADWLADHVVLGQTVVSGATLLSLVLHAAAAAGRPTVRRLTLHAPLVLPDDGGAADLRVGVDEQGQVTVYARPAGGGWTRHASGTLDTVEQPAEALGSWPPAGAEPLDVDYTRLADAGYAYGPGCRRVRAAWRLGDDLYAEVGPVDADGHAAPHPALLDAALHPLALDLLDDEQTRVPHVWSSVTVHATGATTLRARIRRTGTDRVALTLTDTDDRPVATADLTVRAVARGLPDLYAVRLTPVRPATGGTVWPSVGRDVGLPRYAELSTSTDDIVERAHDRVTEVAELLRRWLAQGPPEARLVVATDQVTDPADGVLWGLVRAAQTEHPDRFVLLDSDGDPRSRTLVPGALATGEPQLVVRDGRITVPRLARTAAAPQPPRLDPDGTVLVTGAGGALGSLTARRLVTHHGVRRLLLLGRRGGMQPLAAELTALGATVRVAACDAANRAALARVLDTVPAAHPLTAVVHAAGVVSDGPLATLTPQRFAEVLRPKVDAAWHLHELTCEQDLAAFVLFSSLAGLVGNAGQANYAAANTGLDALAAYRRAAGLPAVSLAWGLWDAPGMGAALDETQRARIARTGVAPLPVERGLALFDACLGAREALLVPAALQPERATRVAPVLAGLAPATTATTPQQDWPRRLAGRGAAEQHRLLLELVRSTIVEVLGHSSVAAVAPDRGLMDLGFDSLTAVELAGRLGADTGVRTPSTVVFDHPTPTALAHYLRHELVGEEAADDEKPHELDEVSDEDLFALIDTELGER'
#  when generate longer sequences, we ended up in repeats
# let's do it 200 at a time, update prefix everytime

max_len = 5000
gen_step = 200
total_generated = ''
prefix = example_seq[:150]

while 1:
    if len(total_generated) > max_len:
        break

    ref = example_seq[150:200]
    penalty = 1.2
    topk = 10

    seed_seq = [aa_to_ctrl_idx[ii] for ii in prefix]
    # generate_num = len(kw_lineage+tax_lineage)+len(prefix+ref)
    generate_num = gen_step + len(kw_lineage+tax_lineage)
    seq_length = min(generate_num, 511)
    text = tax_lineage+kw_lineage+seed_seq
    padded_text = text + [0] * (generate_num - len(text))
    tokens_generated = np.tile(padded_text, (1,1))

    for token in range(len(text)-1, generate_num-1):
        prompt_logits = predict_fn(tokens_generated[:, :seq_length]).squeeze()
        _token = token if token < seq_length else -1
        prompt_logits = prompt_logits.cpu().detach().numpy()

        if penalty>0:
            penalized_so_far = set()
            for _ in range(token-3,token+1):
                generated_token = tokens_generated[0][_] - (vocab_size-26) # added
                if generated_token in penalized_so_far:
                    continue
                penalized_so_far.add(generated_token)
                prompt_logits[_token][generated_token] /= penalty

        # compute probabilities from logits
        prompt_probs = np.exp(prompt_logits[_token])
        prompt_probs = prompt_probs / sum(prompt_probs)
        pruned_list = np.argsort(prompt_probs)[::-1]

        if topk==1:
            idx = pruned_list[0]
        else:
            pruned_list = pruned_list[:topk]
            chosen_idx = torch.distributions.categorical.Categorical(logits=torch.tensor(np.expand_dims(prompt_logits[_token][pruned_list],0))).sample().numpy()[0]
            idx = pruned_list[chosen_idx]

        # assign the token for generation
        idx += (vocab_size-26) # added to convert 0 AA to original ctrl idx
        tokens_generated[0][token+1] = idx


    tokens_generated_so_far = tokens_generated[0].squeeze()[:token+2]
    tokens_generated_so_far = tokens_generated_so_far[(tokens_generated_so_far>=(vocab_size-26)) & (tokens_generated_so_far<(vocab_size-1))]
    tokens_generated_so_far = ''.join([ctrl_idx_to_aa[c] for c in tokens_generated_so_far])
    # save generated, update prefix before the next round
    total_generated += tokens_generated_so_far
    prefix = tokens_generated_so_far[-150:]
    print(prefix)
    tokens_generated_so_far = ''
    tokens_generated = ''




ARRGGFLDAVDLFDAEFFGVPPREAAAMDPQQRLVLELSWEALEDARIRPDALAGSRTGVFVGAISDDYATLLRRRGPDAIGPHSLTGTNRGIIANRVSYHLDAGSMAVLDTARSDPVRHLSKRDFATSLPNTRASDLVGADITVQGRSP
DALAGSRTGVFVGAISDDYATLLRRRGPDAIGPHSLTGTNRGIIANRVSYHLDAGSMAVLDTARSDPVRHLSKRDFATSLPNTRASDLVGADITVQGRSPLEVIKRAPSDETAPNRGALTKIAGRLVADRGSPNELVTIDLGRSEIADSP
HLDAGSMAVLDTARSDPVRHLSKRDFATSLPNTRASDLVGADITVQGRSPLEVIKRAPSDETAPNRGALTKIAGRLVADRGSPNELVTIDLGRSEIADSPRVFIAPTELDIVRNLAGREDFSLAIHRVSGPLRNSAFDLRNGVATRIDNT
LEVIKRAPSDETAPNRGALTKIAGRLVADRGSPNELVTIDLGRSEIADSPRVFIAPTELDIVRNLAGREDFSLAIHRVSGPLRNSAFDLRNGVATRIDNTQASELTGFRESLADRGYEATSGVFLASGRNDFTVSEINAGRQNFGTVLRK
RVFIAPTELDIVRNLAGREDFSLAIHRVSGPLRNSAFDLRNGVATRIDNTQASELTGFRESLADRGYEATSGVFLASGRNDFTVSEINAGRQNFGTVLRKGAVQLTNGPVSLTQANGRISAETRNAVESGAVLNQGSALETGIRLVGPQI
QASELTGFRESLADRGYEATSGVFLASGRNDFTVSEINAGRQNFGTVLRKGAVQLTNGPVSLTQANGRISAETRNAVESGAVLNQGSALETGIRLVGPQIANGTVQAIDGRQVKITGLDVYSRLAPNQDISYAGPLVAIGQSANLGSTDY
GAVQLTNGPVSLTQANGRISAETRNAVESGAVLNQGSALETGIRLVGPQIANGTVQAIDGRQVKITGLDVYSRLAPNQDISYAGPLVAIGQSAN

In [26]:
'B' in total_generated

False

In [20]:
# >BGC0000001|c1|15054-32399|+|AEK75502.1|type_1_polyketide_synthase|AEK75502.1
example_seq = 'MSAPDEPVAVVGLACRLPGAADPEAFWALLRDGREAITDPPASRRDPDGRARRGGFLDAVDLFDAEFFGVPPREAAAMDPQQRLVLELSWEALEDARIRPDALAGSRTGVFVGAISDDYATLLRRRGPDAIGPHSLTGTNRGIIANRVSYHLGLHGPSITVDSAQSSALVAVHVAAESLRRGESELALAGGVNLNLAPESTLGAERFGALSPDGRCHTFDARANGYVRGEGGGLVVLKPLDRALADGDRVHAVLLGSAVNNDGVTDGLTVPGSDGQREVIRLAHERAGTTPGEVDYVELHGTGTVVGDPVEAAALGAELGQLRDTPLLVGSAKTNVGHLEGAAGIVGFVKAVLCVRHRTLPPSLNFATPNPRIPLNELNLRVVTESHTVPRPLVVGVSSFGMGGTNAHAVLTEAPLRTARKPAPAARPLTWVVSGHTPQALRAQAGQLTSLAADPADVAFSLATTRATLPYRAAVVGETAADLRAGMAAVATGTPHPGTVTGSPAGTLAFVFTGQGSQRAGMGRELAARFPVYAQAFAEVAAALDPHLGRPLDEVLDDADALDRTEFAQPALFAVEVALFRLLTHWGLRPDAVAGHSVGEIAAAHVAGVLDLPDAARLVAARGRLMQALPTGGAMVALSAGEEEVRPLLRPGADLAAVNAAESVVVAGDDDAVSAIEETVRGWGRRTSRLRVSHAFHSARMDPMRVSFAQALADIEFAQPTIPVVSALTDPDVTDAEHWVRHVRDTVRFADAVRDLRDRGVRTVLEVGPDAVLTALAHDVAELAAVAVLRRDRPEPDTAVTALATAFTRGAAVDWTALLGARQAVDLPRYAFQRSRHWLDQDNPAIAAPEVTRAPDGTPRRSDEELLDLVRTAVAVAHGRVGPAAIDPDTTFRDLGLDSVTSVEFRDRLAAATGVPLSPGLVYDHPTPRAVVAHLRTLTGGGPADPEQESGYRDEPVAVIGMACRYPGGVGSPDDLWQLVRDGRDATGPFPTDRGWDLDALYDPDPGTPGRTYVRRGGFLDGAAEFDADFFGISPREASAMDPQQRLLLHTAWEALEHGRLNPESLRGTRTGVFVGVVDNDYGPRLHEPVEGTEGYLLTGTTASVASGRVAYALGLTGPAVTVDTACSSSLVALHLAAQALRQGECTLALAGGATVLATPGMFLEFSRQRGLAPDGRCKAFAATADGTAWAEGAGLVVLERLSDARRNGHPVLAVLRGSAINQDGASNGLTAPSGPSQERVIRRALAVAGLAPSDVDLMEAHGTGTALGDPIEARAILATYGQRRDTPLHLGSLKSNIGHTQAAAGIAGVIKVVQAMQHGTLPATLHVDEPTPHVDWAEGQVSLLTEATPWPDTGRPRRAAVSSFGISGTNAHVILEHGDPQPAPPRRTSTGHVAWLVSAREPELVAEQAGRLHRFVRDNPELDPADVALSLATTRPLLEHRAAVVGADRDELLAGLAELESGRRRAEAIRPGKVAFLFAGQGTQRLNMGRQLYDTNPTFAHALDTVTNALNPHLNQPLLDIIFGTDPHLLNRTENAQPALFAIETALYHLLTHHGIHPDYLLGHSLGEITAAHAAGILTLTDAATLVTTRAKLMQTATPGGAMIAIEATETEIQPTLHPTVTIAAINTPTTTVISGDHHHTHAIAHHWRQQGRRTTTLTVSHAFHSPHMDPILDTFHTTTQTLTHHPPHTPLITNLTGQPLTNPTPEHWTHHLRQPVRYHDATTTLTHHGVTHTIEIGPDTTLTTLTKTNHPTLTTTPTLRPHHNENHTLTHTLATTPTTNWASLYPHARPVRLPTTAFRRDRYWLTGGRATPGADSGVREVDHPLLGAAVTLADDSTVYTGRLSRRTAPWLADHVVLGRALLPGTALLEYALWAGRDVGLPRVAELTLEAPLVLPDEGVTQVRVTVGPPGEQRTVAVHARADDAEQWTRHASGMLTAAPPASPVPPRVDGPPVDVDDLYERLAGKGYEYGPAFRLATAARHGQHVVAQLAAPAGPDGFVLHPAQVDAALHPIVLDGDETLLPFSWSGVSVFRRPSGALHAYWTPERALVLTDADGVVATADSLHLRPARMPAPTDLHRIRWVPAEDARRQIRVEPVADATAALAMLHERLDATEPTALVVPHLDRTGAAGLVRSAQAEHPGRFVLIHADDPVRTVPDGEPEAAWRDGSWWVPRLARVAPVDPGLPLSGTVLVTGGTGALGALVARHLVRAHRVRDLVLVSRRGADAPGAAALADELAGHGARVDLRACDVADREALACLLADLPTLDAVVHAAGVVRDATVSALTVEQVRAAATKAESAWHLHELTRDRPLRAFVLFSSISGLLGTAGQGAYAAANAALDALAAHRHALGLPALSLAWGLWEDTGMGAGLSAADVARWRRDGLPPLTVEQGVALFDAALSHEGPVLAPVRLDLAALRGRDVLPAALRGLVTRRAVPPAGSRPRDEAELREVVRSVVAEVLGYPSAAGVDSARPFRDLGLDSLGGVELRNRLAAATGLPVPATLVFDHPTPDAVVAHLLGATTSAQPAPTPTVATRTDEPIAIVGMACRYPGGVSSPEDLWRLVADGVDAIGEFPTDRGWDLGRLYDPDPEHAGTSYTRHGGFLYDAADFDAGFFALSPREATATDPQQRLLLEVAWEAFERAGIDPTAVRGSRTGVFAGVMYGDYGTRWRTAPEGFEGHLLTGNTSSVVSGRVAYSFGLEGPAVTVDTACSSSLVALHLAAQSLRSGECDLALAGGVTVMATPHTFVEFSRQRGLSPDGRCRSFSAAANGTGWSEGAGLLLVERLSDARANGHHVLAILRGSAVNQDGASNGLTAPNGPAQQRVIRTALTNAHLQPTDIDLVEAHGTGTRLGDPIEAQALIATYGHHRNTPLHLGSLKSNIGHTQAAAGVAGVIKVIQAMQHGTLPATLHVNEPTPHVNWADSQVTLLTEATPWPDTGRPRRAAVSSFGISGTNAHVILEHGDPRPVPPEETDPPAPVPLVISARSAGALRDQAARVRTALGSGLPVRDVAYTLGAARARHPHQAVVVGEGRAELLAGLDAVADGTVPGAVATPGKVAFLFAGQGTQRLNMGRQLYDTNPTFAHALDTVTNALNPHLNQPLLDIIFGTDPHLLNRTENAQPALFAIETALYHLLTHHGIHPDYLLGHSLGEITAAHAAGILTLTDAATLVTTRAKLMQTATPGGAMIAIEATETEIQPTLHPTVTIAAINTPTTTVISGDHHHTHAIAHHWRQQGRRTTTLTVSHAFHSPHMDPILDTFHTTTQTLTHHPPHTPLITNLTGQPLTNPTPEHWTHHLRQPVRYHDATTTLTHHGVTHTIEIGPDTTLTTLTKTNHPTLTTTPTLRPHHNENHTLTHTLATTPTTNWKTLLPHATVIDLPTYPFQRQRYWLDGPAADTGLDGSGHPLLPGVVDLADGGLVLTGTVSADSHPWLAGHRIGGATLLPATAVVEAVAHAASRVGLDVDELVLTAAVPVDAPVRLRLTVGPASDDDSRAVHLHGNTDDGEWLPYATGRLAVVTQSPAADLATWPPTDAEPVDVVDLYDRLAEGGYGYHGLFQGLRALWRRGDETFAEVRPDESPTGGFAPHPALWDAALHPLAWDAAERGQVEIPFEWRSVRRHGPGAPALRVRLARRDDAVSVDVADDAGRPIASAGALRLRSTGTAPTTVLEPDWEPVPTDGEWTGRYATVVAPRTGADASAAYAAVTWALDALRQHEGDEPLVVRTVDDPAGAAVRGLVRTAQTEQPGRFVLFTGSGDPEPALVRAALASGEPEVALRDGTLMAPRLSRIPVAPGPLPFASGSTVLVTGGTGALGALVARHLVVRHGVRRLLLTSRRGPAADGAAELVDELTAAGAEVEVVACDVADRPAVAALLASIPEEHPLTAVIHTAGVLDDGALTSLTEERLARVLRPKAEAAWHLHEFTRDRPLTAFVLFSSITGITGTAGQANYAAANAYLDALARHRRNLGLPGVSLAWGLWGATGMASGLGAADLDRLARSGITPLSPQEGLDLFDACLVADRPVLAPARVDLSTTRRQRRRAASAAATVTSREGLRELVRAQVAAVLGHTDATEVSTDVAFTGLGLDSLTAVELRNRIAERTGLRLSSTVVFDHPSVDALTDHLVAELAGARPVETPQPVTQPADEPIAIVGMACRYPGGVSSPEDLWRLVADGVDAIGEFPTDRGWGEIHDPDPDRPGHSYTRHGGFLYAAGDFDAELFGMSPREALTTDPQQRLLLEVAWEAFERAGLPPGSLRGSRTGVFTGVMYNDYGARLHQAGTPAPGYEGYLVSGSAGSVASGRVAYSFGLEGPAVTVDTACSSSLVALHLAAQSLRSGECDLALAGGVTVMASPATFVEFSRQRGLAPDGRCKPFAAAADGTGWSEGAGLLLVERLSDARANGHHVLAILRGSAVNQDGASNGLTAPNGPAQQRVIRTALTNAHLQPTDIDLVEAHGTGTRLGDPIEAQALIATYGHHRNTPLHLGSLKSNIGHTQAAAGVAGVIKVIQAMQHGTLPATLHVNEPTPHVNWADSQVTLLTEATPWPDTGRPRRAAVSSFGISGTNAHVILEHGDAVDTGTGTGTVAPGTAVVPWLLSGTSRQALTAYARLLGEVDAAPVDIAATLALGRSPLALRASVVGRDRPEFPTARLQPVQPVDGPTAFAFTGQGSQRAGMGLGLAARFPQFADALASVAEALDPHLPSPLLDVLADGDLLERTEYAQPAIFAVEVALFRLLAHYGVTPHVLLGHSVGELAAAHVAGVLDLPDAATLVAARGRLMGGLVPGGAMAAVRAGEDEVLALLVPGAEIAAVNADDAVVVSGDAEAVAAVTQALRDAGRRVTPLRVSHAFHSARMDPVLEEFRAVAATLRFSEPTIPLISLLPGSPTDPGYWVRHLREAVRFGDGVRSLAEWGVRRVLEVGPDAALTPVTGPTGIATLRRDHDEESAFVTALAALHDTGATVDWATFFGELGARRVPLPTYPFQRRRYWLTPTAPRTTGGSGHPLLDAAVELPEGAVLFTGRVAAEDADWLADHVVLGQTVVSGATLLSLVLHAAAAAGRPTVRRLTLHAPLVLPDDGGAADLRVGVDEQGQVTVYARPAGGGWTRHASGTLDTVEQPAEALGSWPPAGAEPLDVDYTRLADAGYAYGPGCRRVRAAWRLGDDLYAEVGPVDADGHAAPHPALLDAALHPLALDLLDDEQTRVPHVWSSVTVHATGATTLRARIRRTGTDRVALTLTDTDDRPVATADLTVRAVARGLPDLYAVRLTPVRPATGGTVWPSVGRDVGLPRYAELSTSTDDIVERAHDRVTEVAELLRRWLAQGPPEARLVVATDQVTDPADGVLWGLVRAAQTEHPDRFVLLDSDGDPRSRTLVPGALATGEPQLVVRDGRITVPRLARTAAAPQPPRLDPDGTVLVTGAGGALGSLTARRLVTHHGVRRLLLLGRRGGMQPLAAELTALGATVRVAACDAANRAALARVLDTVPAAHPLTAVVHAAGVVSDGPLATLTPQRFAEVLRPKVDAAWHLHELTCEQDLAAFVLFSSLAGLVGNAGQANYAAANTGLDALAAYRRAAGLPAVSLAWGLWDAPGMGAALDETQRARIARTGVAPLPVERGLALFDACLGAREALLVPAALQPERATRVAPVLAGLAPATTATTPQQDWPRRLAGRGAAEQHRLLLELVRSTIVEVLGHSSVAAVAPDRGLMDLGFDSLTAVELAGRLGADTGVRTPSTVVFDHPTPTALAHYLRHELVGEEAADDEKPHELDEVSDEDLFALIDTELGER'
#  when generate longer sequences, we ended up in repeats
# let's do it 200 at a time, update prefix everytime

max_len = 5000
gen_step = 200
total_generated = ''
prefix = example_seq[:150]

while 1:
    if len(total_generated) > max_len:
        break

    ref = example_seq[150:200]
    penalty = 5.0
    topk = 1

    seed_seq = [aa_to_ctrl_idx[ii] for ii in prefix]
    # generate_num = len(kw_lineage+tax_lineage)+len(prefix+ref)
    generate_num = gen_step + len(kw_lineage+tax_lineage)
    seq_length = min(generate_num, 511)
    text = tax_lineage+kw_lineage+seed_seq
    padded_text = text + [0] * (generate_num - len(text))
    tokens_generated = np.tile(padded_text, (1,1))

    for token in range(len(text)-1, generate_num-1):
        prompt_logits = predict_fn(tokens_generated[:, :seq_length]).squeeze()
        _token = token if token < seq_length else -1
        prompt_logits = prompt_logits.cpu().detach().numpy()

        if penalty>0:
            penalized_so_far = set()
            for _ in range(token-3,token+1):
                generated_token = tokens_generated[0][_] - (vocab_size-26) # added
                if generated_token in penalized_so_far:
                    continue
                penalized_so_far.add(generated_token)
                prompt_logits[_token][generated_token] /= penalty

        # compute probabilities from logits
        prompt_probs = np.exp(prompt_logits[_token])
        prompt_probs = prompt_probs / sum(prompt_probs)
        pruned_list = np.argsort(prompt_probs)[::-1]

        if topk==1:
            idx = pruned_list[0]
        else:
            pruned_list = pruned_list[:topk]
            chosen_idx = torch.distributions.categorical.Categorical(logits=torch.tensor(np.expand_dims(prompt_logits[_token][pruned_list],0))).sample().numpy()[0]
            idx = pruned_list[chosen_idx]

        # assign the token for generation
        idx += (vocab_size-26) # added to convert 0 AA to original ctrl idx
        tokens_generated[0][token+1] = idx


    tokens_generated_so_far = tokens_generated[0].squeeze()[:token+2]
    tokens_generated_so_far = tokens_generated_so_far[(tokens_generated_so_far>=(vocab_size-26)) & (tokens_generated_so_far<(vocab_size-1))]
    tokens_generated_so_far = ''.join([ctrl_idx_to_aa[c] for c in tokens_generated_so_far])
    # save generated, update prefix before the next round
    total_generated += tokens_generated_so_far
    prefix = tokens_generated_so_far[-150:]
    print(prefix)
    tokens_generated_so_far = ''
    tokens_generated = ''

ARRGGFLDAVDLFDAEFFGVPPREAAAMDPQQRLVLELSWEALEDARIRPDALAGSRTGVFVGAISDDYATLLRRRGPDAIGPHSLTGTNRGIIANRVSYAFDLRGPSVALDTGCSATLVGIAELRMAVESLRGDAPVLSAGCVSLQAPT
DALAGSRTGVFVGAISDDYATLLRRRGPDAIGPHSLTGTNRGIIANRVSYAFDLRGPSVALDTGCSATLVGIAELRMAVESLRGDAPVLSAGCVSLQAPTSLAQRVLDAGRPDAVLTDFAGHLRTGDAVLRDGALVRGDAERVLADGRLA
AFDLRGPSVALDTGCSATLVGIAELRMAVESLRGDAPVLSAGCVSLQAPTSLAQRVLDAGRPDAVLTDFAGHLRTGDAVLRDGALVRGDAERVLADGRLASVRLAGPDLAGVTRLAGVPLARVGLARVGDAPLRVAGLRDAVLRGAVPLR
SLAQRVLDAGRPDAVLTDFAGHLRTGDAVLRDGALVRGDAERVLADGRLASVRLAGPDLAGVTRLAGVPLARVGLARVGDAPLRVAGLRDAVLRGAVPLRAGPDLAVGPLAVRGPAVLRGAVPLRGAPLRVAGDRLAGVRLAGPDLAGVR
SVRLAGPDLAGVTRLAGVPLARVGLARVGDAPLRVAGLRDAVLRGAVPLRAGPDLAVGPLAVRGPAVLRGAVPLRGAPLRVAGDRLAGVRLAGPDLAGVRLAGPDLAGVTRLAGPDLAGVTRLAGPDLAGVTRLAGPDLAGVTRLAPDGV
AGPDLAVGPLAVRGPAVLRGAVPLRGAPLRVAGDRLAGVRLAGPDLAGVRLAGPDLAGVTRLAGPDLAGVTRLAGPDLAGVTRLAGPDLAGVTRLAPDGVALTGPDLAGVTRLAGPDLAGVTRLAGPDLAGVTRLAGPDLAGVTRLAGPD
LAGPDLAGVTRLAGPDLAGVTRLAGPDLAGVTRLAGPDLAGVTRLAPDGVALTGPDLAGVTRLAGPDLAGVTRLAGPDLAGVTRLAGPDLAGVT