# SilkomeGPT: Generative strategies for modeling, design and analysis of spider silk protein sequences for enhanced mechanical properties

Generative strategies for modeling, design and analysis of silk protein sequences for enhanced mechanical properties

Wei Lu, David L. Kaplan, Markus J. Buehler

Massachusetts Institute of Technology, 77 Massachusetts Ave., Cambridge, MA 02139, USA

Contact email: mbuehler@mit.edu

Abstract: Spider silks are remarkable materials characterized by superb mechanical properties such as strength, extensibility and light weight. Yet, to date, few models are available to fully explore sequence-property relationships for analysis and design. Here a custom generative large-language model is proposed to enable the design of novel spider silk protein sequences that meet complex combinations of target mechanical properties. The model, pretrained on a large set of protein sequences, is fine-tuned on ~1,000 major ampullate spidroin (MaSp) sequences for which associated fiber-level mechanical properties exist, to yield an end-to-end forward and inverse generative approach that is applied in a multi-agent strategy. Performance is assessed through: (1) a novelty analysis and protein type classification for generated spidroin sequences through Basic Local Alignment Search Tool (BLAST) searches, (2) property evaluation and comparison with similar sequences, (3) comparison of molecular structures, and (4) a detailed sequence motif analysis. This work generates silk sequences with property combinations that do not exist in nature, and develops a deep understanding of the mechanistic roles of sequence patterns in achieving overarching key mechanical properties (elastic modulus, strength, toughness, failure strain). The model provides an efficient approach to expand the silkome dataset, facilitating further sequence-structure analyses of silks, and establishes a foundation for synthetic silk design and optimization. This work not only shows the capacity of generative transformer models to design complex materials, but also illustrates an effective use of agentic modeling for self-improving design solutions.

In [4]:
import os
from tqdm.notebook import tqdm
import pandas as pd
import torch
import random
import numpy as np
import seaborn as sns
from transformers import get_linear_schedule_with_warmup
import time
import datetime
from matplotlib import pyplot as plt
from transformers import Trainer, TrainingArguments,DataCollatorForLanguageModeling
import re
from itertools import chain

# Prefer the GPU when one is visible; fall back to CPU so the notebook can still
# be loaded (slowly) on a machine without CUDA instead of failing at `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
#os.environ['HUGGINGFACE_HUB_CACHE'] = "/mnt/d/cache_Huggingface/"
#os.environ['HF_HOME'] = "/mnt/d/cache_Huggingface/"

In [6]:
# token = '...............' #insert your HF write token here  
# from huggingface_hub import login
# login(token=token)

In [7]:
# Report the installed PyTorch build, whether CUDA is usable, and how many GPUs are visible.
print(torch.__version__)
print(torch.cuda.is_available(), torch.cuda.device_count())

2.1.0+cu121
True 2


### Helper functions

In [8]:
def params(model):
    """Print how many trainable parameters ``model`` has, in millions.

    Only parameters with ``requires_grad=True`` are counted. Nothing is returned.
    """
    trainable_total = 0
    for tensor in model.parameters():
        if not tensor.requires_grad:
            continue
        trainable_total += np.prod(tensor.size())
    print("Parameters (in M): ", trainable_total / 1e6)

In [9]:
def select_by_length (dfseqonly, max_length_select=128, fieldname='Seq_Len'):
    """Return the rows whose ``fieldname`` value is not greater than ``max_length_select``.

    Args:
        dfseqonly: input DataFrame. It is NOT modified; use the returned frame.
        max_length_select: largest value of ``fieldname`` to keep (inclusive).
        fieldname: column to filter on.

    Returns:
        A new DataFrame with a fresh 0..n-1 index. The shape is printed before returning.
    """
    # Filter positionally with a boolean mask rather than dropping by index label:
    # dropping labels removes *every* row sharing a label, which deletes rows that
    # should be kept when the index contains duplicates.
    # `~(x > max)` (rather than `x <= max`) keeps NaN rows, matching the original drop rule.
    keep = (~(dfseqonly[fieldname] > max_length_select)).to_numpy()
    selected = dfseqonly.loc[keep]
    print(selected.shape)
    return selected.reset_index(drop=True)

def select_by_value (dfseqonly, value_select=128, fieldname='Seq_Len'):
    """Return the rows whose ``fieldname`` value equals ``value_select``.

    Args:
        dfseqonly: input DataFrame. It is NOT modified; use the returned frame.
        value_select: value that ``fieldname`` must equal for a row to be kept.
        fieldname: column to filter on.

    Returns:
        A new DataFrame with a fresh 0..n-1 index. The shape is printed before returning.
    """
    # Compare against `value_select` (the original referenced an undefined name and
    # raised NameError). Use a positional mask so duplicate index labels are handled correctly.
    keep = (dfseqonly[fieldname] == value_select).to_numpy()
    selected = dfseqonly.loc[keep]
    print(selected.shape)
    return selected.reset_index(drop=True)

In [10]:
from sklearn.metrics import r2_score

def format_time(elapsed):
    """Render a duration given in seconds as ``H:MM:SS`` (rounded to whole seconds)."""
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

def remove_start_end_token (string_input, start='@', end='$'):
    """Return ``string_input`` with every occurrence of ``start`` and then ``end`` removed."""
    return string_input.replace(start, "").replace(end, "")

def remove_start_end_token_first (string_input, start='@', end='$'):
    """Return the text between the first ``start`` token and the next ``end`` token.

    * If ``start`` is absent, the slice begins at the start of the string.
    * If no ``end`` follows, the slice runs to the end of the string. The original
      sliced to ``-1`` in that case, silently dropping the last character.
    * ``end`` is searched only *after* ``start``, so an earlier ``end`` is ignored.
    """
    i = string_input.find(start)
    begin = 0 if i == -1 else i + len(start)
    stop = string_input.find(end, begin)
    if stop == -1:
        stop = len(string_input)
    return string_input[begin:stop]
    
def extract_task (string_input, end_task_token=')', ):
    """Return the prefix of ``string_input`` up to and including the first ``end_task_token``.

    When the token is absent, ``find`` yields -1, so the result is the empty string.
    """
    cut = string_input.find(end_task_token) + 1
    return string_input[:cut]
    
def extract_start_and_end (string_input, start_token='[', end_token=']', ):
    """Return the text between the first ``start_token`` and the next ``end_token``.

    Used to pull the bracketed payload (e.g. ``[0.1,0.2,...]`` or ``[SEQUENCE]``)
    out of model output.

    * If ``start_token`` is absent, the slice begins at the start of the string.
    * If no ``end_token`` follows, the slice runs to the end of the string. The
      original sliced to ``-1`` in that case, dropping the last character of e.g.
      a truncated generation.
    * ``end_token`` is searched only after ``start_token``, so a stray earlier
      ``end_token`` no longer produces an empty result.
    """
    i = string_input.find(start_token)
    begin = 0 if i == -1 else i + len(start_token)
    stop = string_input.find(end_token, begin)
    if stop == -1:
        stop = len(string_input)
    return string_input[begin:stop]
def extract_prediction_values (result_untokenized, start_token='/', end_token='|'):
    """Parse the comma-separated numbers between ``start_token`` and ``end_token``.

    Args:
        result_untokenized: decoded text containing the delimited values.
        start_token / end_token: delimiters passed to ``extract_start_and_end``.

    Returns:
        ``np.ndarray`` of floats, one per comma-separated field.

    Raises:
        ValueError: if any field is not a valid float (including an empty span).
    """
    prediction = extract_start_and_end(result_untokenized, start_token=start_token, end_token=end_token)
    return np.array([float(field) for field in prediction.split(',')])


def validate (model, validation_dataloader_text, temperature=0.1,
             num_beams=1,top_k=50,top_p=0.95, 
             plot_ind=False, 
              max_samples=9999999999,
             ):
    """Evaluate ``model`` on text samples for the two "Calculate" tasks.

    Every item of every batch from ``validation_dataloader_text`` is a full training
    string. Its task prefix (up to and including the first ``>``) is used as the
    prompt for ``model.generate``, and the bracketed ``[...]`` part of the generated
    text is compared with the bracketed part of the original string:

    * ``CalculateSilkContent<...>``: predicted and ground-truth value vectors of
      equal length are collected. An overall R2 over all flattened values is
      computed and plotted.
    * ``CalculateSolubility<...>``: exact-match accuracy of the prediction.

    Uses the module-level ``tokenizer`` and ``device``. The model is switched to
    eval mode for the loop and back to train mode before returning.

    Args:
        model: causal LM exposing ``generate``.
        validation_dataloader_text: iterable of indexable batches of strings.
        temperature, num_beams, top_k, top_p: sampling settings for ``generate``.
        plot_ind: print and plot every silk-property prediction.
        max_samples: stop once more than this many items have been seen. This is
            checked only after each complete batch.

    Returns:
        ``(acc_sol_t, silk_R2)``: solubility accuracy and overall silk-property R2.
        Each is -1 when it could not be computed.
    """
    s_correct = 0
    s_total = 0
    c_res = []
    c_GT = []

    acc_sol_t = -1
    silk_R2 = -1

    model.eval()

    sample_count = 0

    print ("Start validate...")
    for X_train_batch in tqdm(validation_dataloader_text):

        for iisample in range (len (X_train_batch)):
            sample_count += 1
            sample_text = X_train_batch[iisample]

            # Prompt = task prefix up to the first '>', e.g. "CalculateSolubility<SEQ>".
            output = extract_task (sample_text, end_task_token='>')

            if re.search('CalculateSilkContent<', output):

                generated = torch.tensor(tokenizer.encode(output, add_special_tokens = False)).unsqueeze(0).to(device)

                sample_outputs = model.generate(
                    inputs=generated,
                    do_sample=True,
                    top_k=top_k,
                    max_new_tokens=64,
                    top_p=top_p,
                    num_return_sequences=1,
                    pad_token_id=tokenizer.eos_token_id,
                    temperature=temperature,
                    num_beams=num_beams,
                    )

                try:
                    # Exactly one sequence is requested; decode it.
                    result = tokenizer.decode(sample_outputs[-1], skip_special_tokens=True)
                    GT_res = extract_prediction_values (sample_text, start_token='[', end_token=']')
                    prediction = extract_start_and_end (result, start_token='[', end_token=']')
                    values = [float(v) for v in prediction.split(',')]

                    if plot_ind:
                        print ("GT values silk:  ", GT_res, result)
                        print ("Prediction    : ", values)

                    if len (values) == len (GT_res):
                        if plot_ind:
                            plt.plot (GT_res, values, 'r.')
                            plt.plot ([0,1], [0,1], 'k')
                            plt.axis('square')
                            plt.xlabel ('GT')
                            plt.ylabel ('Prediction')
                            plt.title ('Calculated silk properties vs GT')
                            plt.show ()
                        c_res.append (values)
                        c_GT.append (GT_res)
                    else:
                        print ("Length mismatch in silk prop....", values, GT_res)

                # Best effort: malformed generations are expected during training.
                # The original referenced a possibly-unbound `R2` here and could crash.
                except Exception as err:
                    print ("Error in silk prop....", err)

            if re.search('CalculateSolubility<', output):

                generated = torch.tensor(tokenizer.encode(output, add_special_tokens = False)).unsqueeze(0).to(device)

                sample_outputs = model.generate(
                    inputs=generated,
                    do_sample=True,
                    top_k=top_k,
                    max_length=generated.shape[1] + 5,
                    top_p=top_p,
                    num_return_sequences=1,
                    pad_token_id=tokenizer.eos_token_id,
                    temperature=temperature,
                    num_beams=num_beams,
                    )

                result = tokenizer.decode(sample_outputs[-1], skip_special_tokens=True)
                prediction = extract_start_and_end (result, start_token='[', end_token=']')

                # Reset per sample so the error message never shows stale values.
                values = None
                GT_res = None
                try:
                    values = [float(v) for v in prediction.split(',')]
                    GT_res = extract_prediction_values (sample_text, start_token='[', end_token=']')
                    # `array == list` inside `if` raises for more than one element;
                    # np.array_equal gives a single boolean.
                    if np.array_equal (GT_res, values):
                        s_correct += 1
                    s_total += 1
                except ValueError:
                    print ("Error: ", result, "processed: ", values, GT_res)

        if sample_count > max_samples:
            break

    model.train()

    if s_total > 0:
        print ("--------------------------")
        acc_sol_t = s_correct / s_total
        print ("Accuracy solubility seq task: ", acc_sol_t)
        print ("Succesful tasks: ", s_total)
        print ("--------------------------")
    else:
        print ("No succesful task completion.")

    if c_res:
        # Samples can have different vector lengths, so flatten each one
        # individually instead of building a rectangular tensor.
        c_res = np.concatenate ([np.asarray (v, dtype=float) for v in c_res])
        c_GT = np.concatenate ([np.asarray (v, dtype=float) for v in c_GT])
        print ("Results shape after flatten: ", c_res.shape, c_GT.shape)

        # r2_score expects (y_true, y_pred); the original had the arguments swapped.
        silk_R2 = r2_score (c_GT, c_res)
        print ('R2 score_overall silk properties ', silk_R2)

        try:
            plt.plot (c_GT, c_res, 'r.')
            plt.plot ([0,1], [0,1], 'k')
            plt.axis('square')
            plt.xlabel ('GT silk properties')
            plt.ylabel ('Prediction silk properties')
            plt.title ("Prediction vs. GT")
            plt.savefig("R2_silk_v2.svg")
            plt.show ()
            sns.jointplot(y=c_res, x=c_GT, kind='kde')
            plt.savefig("R2_silk.svg")
            plt.show()
        except Exception as err:
            # Plotting (e.g. KDE on very few points) must not abort validation.
            print ("Could not plot silk property results:", err)

    return acc_sol_t, silk_R2

### Inference

In [17]:
import transformers
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig,AutoModelForCausalLM

# Load the fine-tuned SilkomeGPT causal LM and its tokenizer from the Hugging Face Hub.
# trust_remote_code=True allows code shipped in the model repository to be executed.
model_name = 'lamm-mit/SilkomeGPT'

model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    trust_remote_code=True
)
# NOTE(review): use_cache=False disables the key/value cache during generation, which
# usually slows generate(); presumably carried over from training — confirm it is needed here.
model.config.use_cache = False

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Reuse EOS ('</s>') as the padding token.
tokenizer.pad_token = tokenizer.eos_token

pytorch_model.bin:   0%|          | 0.00/1.01G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.81M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

In [18]:
# Display the tokenizer configuration (vocabulary size and special tokens).
tokenizer

PreTrainedTokenizerFast(name_or_path='lamm-mit/SilkomeGPT', vocab_size=50000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [None]:
# Move the model to the selected device before generating.
model.to(device)

In [21]:
# Load the reference silk sequences. ALL_SEQ is the list passed to is_novel() below.
df_ALL_SEQ = pd.read_csv('ALL_SILK_SEQ.csv')
df_ALL_SEQ
# astype(str) guarantees string entries (a missing value would become the string 'nan').
ALL_SEQ=df_ALL_SEQ['Sequence'].astype(str).to_list()
ALL_SEQ[:6]

['AGASSAAIARAGNALTTQSSTSRISYNVNSLVGPGGSFNLAALPAIMSNQVQSISASSPESSTCEILVQVLLELIASLVHILGVSNIGDLNYAPTSQTAAYMSSIFGYNGY',
 'GSGGGKLGGAGRGFYGASGAFGGLGAGLGGLGVGAGSSGGGSGGGGGGGGGAGDGGSAGVLAAASSSSAFGAFGPGSGSSGGGGGGGGGAGGGGSGSGGGQLGGAGRGFYGASGAFGGLGAGLGSIGAGAGSSGGGSGGGGGGGGGAGDGGSAGVLAAASSSSAFGAFGPGSGSSGGGGGGGGGVGGGAGGQGQGSSSGFLGASSSSAIGAFGLGSGGGGGGGGSGGGSTGSGGETDSGLFSTSSSSSAIGVFGSPTGGSSGGGGGGGGSASSQTSGVSSSVLLSDEAASRISSAALSLVYGNTLNTNVFPQVISSLYSLLRISNTGLSEYEIVLELLMEVISTLIHLLGSSVIGDVLYASPANAANYVADSFDFILV',
 'DVASATTSITPTTTTTISQTSFDGQTSGDYPRDEYPGNTLFDEQVSTEYPGGPSGSTSVGLDFAASLNSVLNSRNGLKSPQASLRIRSLTSALQQAMGPDGINPTVLSSVIKASLSSLKNSGMSSDSATVEGLMELVTALVQMLGTSRPDPTKSVSLSSSLGITASLAGALSVN',
 'VVGAGVGACAGGGAGAGAGAGAGAGSGAGAGSLASASSGLTSYSAASRIDNAVSNLVGPNGAFNAAVLPNILSNQVASISASSPGLSTCDILVQTLLELLAAMVHLLGSANIGNLNYASTADTASMVSSLIQQSM',
 'TGAQPGALPCSGGQGQGGPAGLYGAASSSSVVGAFGPGSGGSGSGGGGGGGGGGGSGSQQLSISPLNLLMSDDAASRISSTALSLIYGNTLNMNALPEVVASLYSLLTLSNPGLSRYEIIFELLLEVISTLIHLLGSSVIGDVFYAPPANAAQFVAESFNGIIV',
 'AQSSASALASS

In [22]:
def sample_from_text (prompt = "CalculateSolubility<AAAAAIIAIAIAA>",
                      do_sample=True,
                      top_k=500,
                      max_new_tokens = 16,
                      top_p=0.9,
                      num_return_sequences=1,
                      temperature=0.1,
                      verbatim=False,
                      ):
    """Sample continuations of ``prompt`` from the global ``model``.

    The prompt is tokenized without special tokens, moved to ``device`` and passed to
    ``model.generate`` under ``torch.no_grad()`` with the given sampling settings
    (generation stops at ``tokenizer.eos_token_id``). The model is left in eval mode.

    Returns:
        List of decoded strings (prompt plus continuation, special tokens removed),
        one per returned sequence. With ``verbatim`` each one is also printed.
    """
    model.eval()

    prompt_ids = tokenizer.encode(prompt, add_special_tokens = False)
    input_tensor = torch.tensor(prompt_ids).unsqueeze(0)
    input_tensor = input_tensor.to(device)

    generation_kwargs = dict(
        inputs=input_tensor,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=do_sample,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        num_return_sequences=num_return_sequences,
        temperature=temperature,
    )
    with torch.no_grad():
        outputs = model.generate(**generation_kwargs)

    decoded = []
    for idx, token_ids in tqdm(enumerate(outputs)):
        text = tokenizer.decode(token_ids, skip_special_tokens=True)
        decoded.append(text)
        if verbatim:
            print(f"Prediction #{idx}: {text}\n\n")

    return decoded

In [None]:
def generate_silk_sequence_from_properties (prop=  [0, 0, 0, 0, 0, 0, 0, 0],
                                            max_new_tokens=512,
                                            top_k=500,
                                            top_p=0.9,
                                            num_return_sequences=1,
                                            temperature=0.1,
                                            ):
    """Sample silk sequences conditioned on an 8-value property target.

    Builds the prompt ``GenerateSilkContent<p0,...,p7>`` from the first eight entries
    of ``prop`` (each formatted as ``1.3f``; fewer than eight raises IndexError). It then
    samples ``num_return_sequences`` continuations with ``sample_from_text`` and
    returns the bracketed ``[...]`` payload of each, as extracted by ``extract_start_and_end``.
    """
    formatted = ",".join(f"{prop[k]:1.3f}" for k in range(8))
    prompt = f"GenerateSilkContent<{formatted}>"

    completions = sample_from_text (prompt = prompt,
                                    do_sample=True,
                                    max_new_tokens = max_new_tokens,
                                    top_k=top_k,
                                    top_p=top_p,
                                    num_return_sequences=num_return_sequences,
                                    temperature=temperature,
                                    )

    return [extract_start_and_end (text, start_token='[', end_token=']') for text in completions]

In [24]:
# Peek at the first five reference sequences.
ALL_SEQ[:5]

['AGASSAAIARAGNALTTQSSTSRISYNVNSLVGPGGSFNLAALPAIMSNQVQSISASSPESSTCEILVQVLLELIASLVHILGVSNIGDLNYAPTSQTAAYMSSIFGYNGY',
 'GSGGGKLGGAGRGFYGASGAFGGLGAGLGGLGVGAGSSGGGSGGGGGGGGGAGDGGSAGVLAAASSSSAFGAFGPGSGSSGGGGGGGGGAGGGGSGSGGGQLGGAGRGFYGASGAFGGLGAGLGSIGAGAGSSGGGSGGGGGGGGGAGDGGSAGVLAAASSSSAFGAFGPGSGSSGGGGGGGGGVGGGAGGQGQGSSSGFLGASSSSAIGAFGLGSGGGGGGGGSGGGSTGSGGETDSGLFSTSSSSSAIGVFGSPTGGSSGGGGGGGGSASSQTSGVSSSVLLSDEAASRISSAALSLVYGNTLNTNVFPQVISSLYSLLRISNTGLSEYEIVLELLMEVISTLIHLLGSSVIGDVLYASPANAANYVADSFDFILV',
 'DVASATTSITPTTTTTISQTSFDGQTSGDYPRDEYPGNTLFDEQVSTEYPGGPSGSTSVGLDFAASLNSVLNSRNGLKSPQASLRIRSLTSALQQAMGPDGINPTVLSSVIKASLSSLKNSGMSSDSATVEGLMELVTALVQMLGTSRPDPTKSVSLSSSLGITASLAGALSVN',
 'VVGAGVGACAGGGAGAGAGAGAGAGSGAGAGSLASASSGLTSYSAASRIDNAVSNLVGPNGAFNAAVLPNILSNQVASISASSPGLSTCDILVQTLLELLAAMVHLLGSANIGNLNYASTADTASMVSSLIQQSM',
 'TGAQPGALPCSGGQGQGGPAGLYGAASSSSVVGAFGPGSGGSGSGGGGGGGGGGGSGSQQLSISPLNLLMSDDAASRISSTALSLIYGNTLNMNALPEVVASLYSLLTLSNPGLSRYEIIFELLLEVISTLIHLLGSSVIGDVFYAPPANAAQFVAESFNGIIV']

In [25]:
def is_novel (ALL_SEQ,seq):
    """Return True when ``seq`` is not contained in ``ALL_SEQ`` (exact membership test).

    The verdict is also printed.
    """
    novel = seq not in ALL_SEQ
    verdict = "is NOVEL" if novel else "NOT NOVEL"
    print (f"Sequence {seq} {verdict}")
    return novel

In [26]:
# Novelty check for a sequence that is not in ALL_SEQ (reported as NOVEL).
is_novel(ALL_SEQ, 'MKHTIGILGGMGPAATADMLEKFVELRHASCDQQHIPLIVSSIPDIPDRTACLLSGGPSPYRYLERYLHMLEDAGAECIVIPCNTAHYWFDDLQNVAKARMISILDATLGDIPPSARHVGLLATNATLATGLYQKKALARGLTLIQPEDAGQALVMQAIYTLKRGDKTAAQALLLPQIDSLIARGAQAIIMGCTEIPLIVAGHERAIACPMIDSTASLVRAAIRWYESWPDTRASLTGEQRLTA')

Sequence MKHTIGILGGMGPAATADMLEKFVELRHASCDQQHIPLIVSSIPDIPDRTACLLSGGPSPYRYLERYLHMLEDAGAECIVIPCNTAHYWFDDLQNVAKARMISILDATLGDIPPSARHVGLLATNATLATGLYQKKALARGLTLIQPEDAGQALVMQAIYTLKRGDKTAAQALLLPQIDSLIARGAQAIIMGCTEIPLIVAGHERAIACPMIDSTASLVRAAIRWYESWPDTRASLTGEQRLTA is NOVEL


True

In [27]:
# ALL_SEQ[0] with extra serines inserted: the exact-match check reports even this small variant as NOVEL.
is_novel(ALL_SEQ, 'AGASSAAIARAGNALTTQSSTSRISYNVNSSSSSLVGPGGSFNLAALPAIMSNQVQSISASSPESSTCEILVQVLLELIASLVHILGVSNIGDLNYAPTSQTAAYMSSIFGYNGY')

Sequence AGASSAAIARAGNALTTQSSTSRISYNVNSSSSSLVGPGGSFNLAALPAIMSNQVQSISASSPESSTCEILVQVLLELIASLVHILGVSNIGDLNYAPTSQTAAYMSSIFGYNGY is NOVEL


True

In [28]:
# Display names for the 8-value property vectors (mean and SD of toughness, E, strength, strain);
# presumably in the same order as the req/prop vectors used below — confirm against the training data.
barplotlabels=['Toughness', 'Toughness SD', 'E', 'E SD', 'Strength', 'Strength SD', 'Strain', 'Strain SD']

In [None]:
from sklearn.metrics import mean_absolute_error,mean_squared_error
def make_bar_plot (req, prop, fname=None, barplotlabels=None):
    """Draw a grouped bar chart of target vs. predicted property values.

    Args:
        req: target values (blue bars, shifted left of each tick).
        prop: predicted values (red bars, shifted right of each tick).
        fname: if given, the figure is saved to this path before being shown.
        barplotlabels: optional tick labels; the first ``len(req)`` entries are used.
    """
    x1 = np.arange (len (req))
    x2 = np.arange (len (prop))

    fig, ax = plt.subplots()

    ax.bar(x1 - 0.1, req, color='blue', label='Target', width=0.2)
    ax.bar(x2 + 0.1, prop, color='red', label='Predicted', width=0.2)

    # `is not None` instead of `!= None`: the latter compares element-wise (and
    # then fails in `if`) when the labels are a NumPy array.
    if barplotlabels is not None:
        ax.set_xticks (x1)
        ax.set_xticklabels (barplotlabels[:len (x1)], rotation='vertical')

    ax.set_xlabel('Properties')
    ax.set_ylabel('Value')
    # The bars carry labels; without a legend, Target and Predicted cannot be told apart.
    ax.legend()

    if fname is not None:
        fig.savefig(fname)

    plt.show()

def generate_new (
    mum_samples_perstep=64,
    req=[0.455,0.199,0.239,0.194,0.373,0.304,0.327,0.101],
    temperature=1, top_k=500, 
    top_p=.9,   csv_output='output.csv',              
   ):
    """Sample candidate sequences for target properties ``req`` and score the novel ones.

    ``mum_samples_perstep`` sequences are generated with
    ``generate_silk_sequence_from_properties``. Each sequence not found in ``ALL_SEQ``
    is fed back to the model twice:

    * for a solubility prediction (``CalculateSolubility<...>``), and
    * for a property prediction (``CalculateSilkContent<...>``).

    The predicted properties are compared with ``req`` (R2, MAE, MSE) and plotted.
    Candidates whose outputs cannot be parsed are reported and skipped.

    The per-candidate results (Sequence, Novel, R2, MAE, Properties, Solubility) are
    written to ``csv_output``.

    Returns:
        ``(r2_list, MAE_list, MSE_list, novel_list, seq_list, prop_list, sol_list,
        csv_output)``. Each list has one entry per successfully evaluated novel
        sequence, and all lists have equal length.
    """
    seq_list_p = generate_silk_sequence_from_properties (prop=req,
                                            temperature=temperature,
                                            top_k=top_k,
                                            top_p=top_p,
                                            num_return_sequences=mum_samples_perstep,
                                            )
    r2_list = []
    seq_list = []
    novel_list = []
    prop_list = []
    MAE_list = []
    MSE_list = []
    sol_list = []

    for i in range (mum_samples_perstep):
        seq = seq_list_p[i]
        novel = is_novel(ALL_SEQ, seq)
        print (f"Sequence {seq} is NOVEL={novel}")

        # Only novel sequences are scored, so don't spend GPU time predicting for the rest.
        if not novel:
            continue

        sol = sample_from_text (prompt = f"CalculateSolubility<{seq}>",
                                do_sample=True,
                                top_k=500,
                                max_new_tokens = 7,
                                top_p=0.9,
                                num_return_sequences=1,
                                temperature=0.01,
                                )[0]

        prop = sample_from_text (prompt = f"CalculateSilkContent<{seq}>",
                                 do_sample=True,
                                 top_k=500,
                                 max_new_tokens = 64,
                                 top_p=0.9,
                                 num_return_sequences=1,
                                 temperature=0.01,
                                 )[0]
        try:
            prop = extract_prediction_values (prop, start_token='[', end_token=']')
            print ("##########################")
            print ("Sequence: ", seq)
            sol_value = extract_prediction_values (sol, start_token='[', end_token=']')

            print ("Solubility: ", sol_value)

            plt.plot (req, prop, 'ro')
            plt.plot ([0,1], [0,1], 'k--')
            plt.axis('square')
            plt.show ()

            make_bar_plot (req, prop, fname=f'plot_s{i}.svg', barplotlabels=barplotlabels)
            r2 = r2_score (req, prop)
            MAE = mean_absolute_error (req, prop)
            MSE = mean_squared_error (req, prop)
            print (f"R2={r2}, MAE={MAE}....\n###################")

            # Convert before appending anything: if this failed after the other appends
            # (as in the original), the lists would get out of sync and the DataFrame
            # construction below would raise. .item() requires exactly one value.
            sol_int = int (sol_value.item())
        except Exception as err:
            # Best effort: unparsable model output only skips this candidate.
            print ("ERROR: ", req, prop, err)
            continue

        r2_list.append (r2)
        MSE_list.append (MSE)
        MAE_list.append (MAE)
        novel_list.append (novel)
        seq_list.append (seq)
        prop_list.append (prop)
        sol_list.append (sol_int)

    results_df = pd.DataFrame(
        {'Sequence': seq_list,
         'Novel': novel_list,
         'R2': r2_list,
         'MAE': MAE_list,
         'Properties': prop_list,
         'Solubility': sol_list,
        })
    # The original built this frame and discarded it; write it to the requested path.
    results_df.to_csv (csv_output)

    return r2_list,MAE_list,MSE_list,novel_list,seq_list, prop_list, sol_list,csv_output

In [None]:
def generate_new_and_find_best (
                    req=[0.300,0.204,0.279,0.043,0.329,0.127,0.702,0.139],
                    mum_samples_perstep=128,
                    temperature=1.25, top_k=500,
                    top_p=.8,
                    stem=None, best_crit=0, #0=R2, 1=MAE, 2=MSE
                    num_repeat=4,
                    ):
    """Run ``generate_new`` ``num_repeat`` times and report the best novel design.

    Args:
        req: target property vector.
        mum_samples_perstep: candidates sampled per ``generate_new`` call.
        temperature, top_k, top_p: sampling settings.
        stem: prefix for the output files; defaults to ``str(req)``.
        best_crit: selection criterion. 0 picks the highest R2, 1 the lowest MAE,
            and 2 the lowest MSE.
        num_repeat: number of ``generate_new`` calls.

    The best candidate is plotted and printed. All scored candidates are saved to
    ``f'{stem}_{best R2 rounded to 4}_prop.csv'``.

    Returns:
        List of the per-run CSV paths produced by ``generate_new``. The original
        returned the individual characters of those paths.

    Raises:
        ValueError: if ``best_crit`` is not 0, 1 or 2. The check happens before any
            generation.
    """
    if best_crit not in (0, 1, 2):
        raise ValueError (f"best_crit must be 0 (R2), 1 (MAE) or 2 (MSE), got {best_crit!r}")

    if stem is None:
        stem = str (req)

    r2_list_overall = []
    MAE_list_overall = []
    MSE_list_overall = []
    novel_list_overall = []
    seq_list_overall = []
    prop_list_overall = []
    sol_list_overall = []
    csv_output_overall = []

    for i in range (num_repeat):
        r2_list,MAE_list,MSE_list,novel_list,seq_list, prop_list,sol_list,csv_output = generate_new(
          mum_samples_perstep=mum_samples_perstep,
          req=req,
          temperature=temperature,
          top_k=top_k,
          top_p=top_p,
          csv_output=f'{stem}_prop_{i}.csv',
          )

        r2_list_overall.extend (r2_list)
        MAE_list_overall.extend (MAE_list)
        MSE_list_overall.extend (MSE_list)
        novel_list_overall.extend (novel_list)
        seq_list_overall.extend (seq_list)
        prop_list_overall.extend (prop_list)
        sol_list_overall.extend (sol_list)
        # Keep one entry per path; flattening would split the strings into characters.
        csv_output_overall.append (csv_output)

        torch.cuda.empty_cache()

    # Normalize by everything that was sampled, not by a single run's sample count.
    total_generated = mum_samples_perstep * num_repeat
    if total_generated > 0:
        print ("Ratio of novel structures: ", len (novel_list_overall)/total_generated)
    print ("Done, now find best solution and save...")

    if not seq_list_overall:
        print ("No novel solutions generated or other error.")
    else:
        try:
            if best_crit == 0:
                best = int (np.argmax (r2_list_overall))    # higher R2 is better
            elif best_crit == 1:
                best = int (np.argmin (MAE_list_overall))   # lower error is better
            else:
                best = int (np.argmin (MSE_list_overall))   # lower error is better

            plt.plot (req, prop_list_overall[best], 'ro')
            plt.plot ([0,1], [0,1], 'k--')
            plt.show ()
            make_bar_plot (req, prop_list_overall[best], fname=f'{stem}_plot_s_best.svg', barplotlabels=barplotlabels)

            print (f"For req={req} ({barplotlabels})")
            print ("Best sequence: ", seq_list_overall[best], "\nR2=", r2_list_overall[best], "\nMSE=",MSE_list_overall[best])

            results_df = pd.DataFrame(
                {'Sequence': seq_list_overall,
                 'Novel': novel_list_overall,
                 'R2': r2_list_overall,
                 'MAE': MAE_list_overall,
                 'Properties': prop_list_overall,
                 'Solubility': sol_list_overall,
                })
            results_df.to_csv (f'{stem}_{round(r2_list_overall[best],4)}_prop.csv')

        except Exception as err:
            # Best effort: a plotting/saving failure must not lose the returned paths.
            print ("No novel solutions generated or other error.", err)

    print (csv_output_overall)

    return csv_output_overall

In [None]:
# Show the property labels used in the plots and printouts.
barplotlabels

### Design examples

In [None]:
# Sampling budget for the design runs below: candidates per generate_new() call
# ("mum" is this notebook's spelling of "num") and number of generate_new() calls.
mum_samples_perstep = 32
num_repeat = 64

In [None]:
# Design run 1: search for novel sequences matching the target vector `req`, using the
# sampling budget from the previous cell (the per-step count was hard-coded before,
# so the printed value and the value actually used could disagree).
print(f'num_samples_perstep={mum_samples_perstep}/128, num_repeat={num_repeat}')

generate_new_and_find_best (req=[0.300,0.20,0.2,0.1,0.3,0.1,0.20,0.1],
                    mum_samples_perstep=mum_samples_perstep,
                    temperature=1.25, top_k=500,
                    top_p=.8,
                    stem=None,
                    num_repeat=num_repeat,
                      )

In [None]:
# Release PyTorch's cached, unused CUDA memory between the design runs.
torch.cuda.empty_cache()

In [None]:
# Design run 2: a different target vector, with the same sampling budget as above
# (the per-step count is passed through instead of being hard-coded to 32).
print(f'num_samples_perstep={mum_samples_perstep}/128, num_repeat={num_repeat}')
generate_new_and_find_best (req=[0.600,0.204,0.279,0.043,0.329,0.127,0.502,0.139],
                  mum_samples_perstep=mum_samples_perstep,
                  temperature=1.25, top_k=500,
                  top_p=.8,
                  stem=None,
                  num_repeat=num_repeat,
                    )

### Other ways to generate/calculate

In [None]:
# Predict solubility for a single sequence. The prompt must be an f-string so the
# sequence is actually inserted (the original sent the literal text "{seq}").
# The first reference sequence serves as the example; replace it with your own.
example_seq = ALL_SEQ[0]
sample_from_text (prompt = f"CalculateSolubility<{example_seq}>",
                  do_sample=True,
                  top_k=500,
                  max_new_tokens = 16,
                  top_p=0.9,
                  num_return_sequences=1,
                  temperature=0.1,
                  )

In [None]:
# Sample four sequences for an 8-value property target at a higher temperature (1.5)
# than the function default; returns the list of extracted sequences.
generate_silk_sequence_from_properties (prop=  [0.1,0.17,0.272,0.216,0.462,0.051,0.189,0.9],
                                        temperature=1.5,
                                        num_return_sequences=4,
                                       )

In [None]:
# Re-check the serine-insertion variant of ALL_SEQ[0]; the exact-match test reports it as novel.
is_novel(ALL_SEQ, 'AGASSAAIARAGNALTTQSSTSRISYNVNSSSSSLVGPGGSFNLAALPAIMSNQVQSISASSPESSTCEILVQVLLELIASLVHILGVSNIGDLNYAPTSQTAAYMSSIFGYNGY')