In [None]:
import os
import pandas as pd

output_folder = 'AIzymes_IC50'
os.makedirs(output_folder, exist_ok=True)

# Per landscape: the stem of its CSV file inside `output_folder` and the
# 1-based positions in the reference sequence that its variants replace.
data = {
    'NQEFT': {"filename": "Landscapes_CAZ_NQEFT", "residues": [76, 124, 132, 156, 243]},
    'AFST':  {"filename": "Landscapes_CAZ_AFST", "residues": [33, 72, 212, 213]},
    'LFLI':  {"filename": "Landscapes_CAZ_LFLI", "residues": [67, 142, 158, 215]}
}

reference_seq = 'OXA48.seq'
with open(f'{output_folder}/{reference_seq}', 'r') as f:
    reference_sequence = f.read().strip()


def apply_mutations(reference, residues, mutations):
    """Return `reference` with the characters of `mutations` written at `residues`.

    Args:
        reference: Full reference sequence as a string.
        residues: 1-based positions to overwrite, one per character of `mutations`.
        mutations: String whose i-th character replaces position `residues[i]`.

    Returns:
        The mutated sequence as a new string; `reference` is left untouched.

    Raises:
        ValueError: If `mutations` and `residues` differ in length, or a
            position lies outside the reference sequence.
    """
    if len(mutations) != len(residues):
        raise ValueError(
            f'Mutation string {mutations!r} has {len(mutations)} characters '
            f'but {len(residues)} residue positions were given.'
        )
    sequence = list(reference)
    for residue, amino_acid in zip(residues, mutations):
        # Reject 0 or negative positions explicitly: plain list indexing would
        # silently wrap around to the end of the sequence.
        if not 1 <= residue <= len(sequence):
            raise ValueError(
                f'Residue position {residue} is outside the reference '
                f'sequence (length {len(sequence)}).'
            )
        sequence[residue - 1] = amino_acid
    return ''.join(sequence)


# Process .csv files: add the full mutated sequence and the landscape name to
# every row, then write the result next to the input as "<filename>_updated.csv".
for landscape, spec in data.items():
    df = pd.read_csv(f'{output_folder}/{spec["filename"]}.csv')

    # Column-wise assignment creates both columns even when the CSV has no rows.
    df['sequence'] = df['mutations'].map(
        lambda mutations: apply_mutations(reference_sequence, spec["residues"], mutations)
    )
    df['landscape'] = landscape

    df.to_csv(f'{output_folder}/{spec["filename"]}_updated.csv', index=False)

# `df` is rebound on every iteration, so this shows only the last landscape.
display(df)

Unnamed: 0,mutations,IC50,sequence
0,LFLI,0.013,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...
1,LFLT,0.014,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...
2,LFPI,0.035,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...
3,LFPT,0.029,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...
4,LLLI,0.013,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...
5,LLLT,0.012,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...
6,LLPI,0.046,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...
7,LLPT,0.03,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...
8,IFLI,0.015,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...
9,IFLT,0.014,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...


In [23]:
%run src/plm_trainer_multi_small.py

df_path = [f'{output_folder}/{data[landscape]["filename"]}_updated.csv' for landscape in data]
#df_path = f'{output_folder}/{data["AFST"]["filename"]}_updated.csv'
scores          = ['IC50']

dataset = PLM_trainer(
    output_folder   = output_folder,
    verbose         = False
    )

PLM_trainer.load_dataset(    
    dataset,            
    df_path         = df_path,
    scores          = scores,
    labels          = [],
    select_unique   = False,
    normalize       = 'minmax'
    )

### PLM trainer loaded. ###
Loading dataset: AIzymes_IC50/Landscapes_CAZ_NQEFT_updated.csv
Loading dataset: AIzymes_IC50/Landscapes_CAZ_AFST_updated.csv
Loading dataset: AIzymes_IC50/Landscapes_CAZ_LFLI_updated.csv
### 3 files loaded into one dataset. ###
### Data normalized. ###


In [None]:
# Keyword settings for the run on the combined dataset, gathered in one place
# and unpacked into the call below.
train_settings = dict(
    epochs=200,
    esm2_model_name="facebook/esm2_t6_8M_UR50D",
    p_loss=0.,
    liveplot=False,
    overwrite=True,
    print_testtrain=True,
    validation="test_train",
    figure_name='all_data',
)

PLM_trainer.train_PLM(dataset, **train_settings)

train_df


Unnamed: 0,mutations,IC50,sequence,norm_IC50
15,NHKDP,0.296,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.566866
11,NHEDP,0.159,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.293413
16,TQEFT,0.013,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.001996
39,ALAA,0.389,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.752495
47,VLAA,0.513,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNVHFTEHKSQGVVVL...,1.0
41,VFSA,0.015,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNVHFTEHKSQGVVVL...,0.005988
38,ALAT,0.14,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.255489
33,AFSA,0.012,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.0
42,VFAT,0.014,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNVHFTEHKSQGVVVL...,0.003992
27,THEDP,0.43,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.834331


test_df


Unnamed: 0,mutations,IC50,sequence,norm_IC50
48,LFLI,0.013,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.001996
30,THKDT,0.166,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.307385
23,TQKDP,0.122,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.219561
2,NQEDT,0.031,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.037924
58,IFPI,0.126,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.227545
56,IFLI,0.015,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.005988
37,ALSA,0.177,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.329341
34,AFAT,0.015,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.005988
17,TQEFP,0.031,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.037924
24,THEFT,0.062,MRVLALSAVFLVASIIGMPAVAKEWQENKSWNAHFTEHKSQGVVVL...,0.0998


 31%|███       | 62/200 [07:26<15:27,  6.72s/it]

In [None]:
# One training run per landscape, passing that landscape's name as
# `validation_landscape` (presumably the held-out set when
# validation="landscape" -- confirm against PLM_trainer.train_PLM).
# Iterating over `data` visits 'NQEFT', 'AFST', 'LFLI' in that order, matching
# the three hand-written calls this loop replaces. Those calls were missing the
# comma after the `validation_landscape` argument, so the cell could not parse.
for held_out_landscape in data:
    PLM_trainer.train_PLM(
        dataset,
        epochs               = 200,
        esm2_model_name      = "facebook/esm2_t6_8M_UR50D",
        p_loss               = 0.,
        liveplot             = False,
        overwrite            = True,
        print_testtrain      = True,
        validation           = "landscape",
        validation_landscape = held_out_landscape,
        figure_name          = f'test_set_{held_out_landscape}',
    )