In [1]:
import numpy as np
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# User input for selecting folder
import ipywidgets as widgets
from IPython.display import display, clear_output

# Candidate checkpoint folders; 'model_ensemble' is loaded as three folds later on.
folder_options = [
    'models',
    'models_lora',
    'models_lora_large',
    'models_fullfinetune',
    'model_ensemble',
]
folder_selector = widgets.Dropdown(
    description='Folder:',
    options=folder_options,
)


In [3]:
def on_folder_selection(change):
    """Store the newly chosen checkpoint folder in the notebook-global `selected_folder`."""
    global selected_folder
    selected_folder = change.new
    print(f"Selected folder: {selected_folder}")

# `observe` only fires on *changes*, so seed the global with the dropdown's
# initial value; otherwise keeping the default selection leaves
# `selected_folder` undefined when the model-loading cell runs.
selected_folder = folder_selector.value
folder_selector.observe(on_folder_selection, names='value')
display(folder_selector)

Dropdown(description='Folder:', options=('models', 'models_lora', 'models_lora_large', 'models_fullfinetune', …

In [4]:
# User input for the task mode
mode_options = [
    "Predict from protein + ligand pairs",
    "Generate from ligand + protein scaffold",
    "Generate from protein + ligand scaffold"
]
mode_selector = widgets.RadioButtons(options=mode_options, description='Mode:')

def on_mode_selection(change):
    """Store the chosen task mode in the notebook-global `selected_mode`."""
    global selected_mode
    selected_mode = change.new
    print(f"Selected mode: {selected_mode}")
    # NOTE(review): the mode-specific widgets are not refreshed from here; re-run
    # the `update_additional_options()` cell after changing the mode.

# Seed the global with the default choice: `observe` fires only on changes.
selected_mode = mode_selector.value
mode_selector.observe(on_mode_selection, names='value')
display(mode_selector)

RadioButtons(description='Mode:', options=('Predict from protein + ligand pairs', 'Generate from ligand + prot…

In [5]:
# User input for the input source (typed sequence vs. CSV file)
file_options = [
    "From Sequence",
    "From CSV"
]
file_selector = widgets.RadioButtons(options=file_options, description='File:')

def on_file_selection(change):
    """Store the chosen input source in the notebook-global `selected_file`."""
    global selected_file
    selected_file = change.new
    print(f"Selected file source: {selected_file}")

# Seed the global with the default choice: `observe` fires only on changes.
# NOTE(review): no visible cell reads `selected_file` — `load_csv` always reads
# CSV paths — so "From Sequence" currently has no effect.
selected_file = file_selector.value
file_selector.observe(on_file_selection, names='value')
display(file_selector)

RadioButtons(description='File:', options=('From Sequence', 'From CSV'), value='From Sequence')

In [6]:
# User input for additional options based on mode
def update_additional_options():
    """Display the extra input widgets required by the currently selected mode.

    Reads the notebook global ``selected_mode``. Every widget shown mirrors its
    value into a notebook global whenever it is edited:

    - ``selected_predict``: tuple of chosen properties (prediction mode)
    - ``selected_prot`` / ``selected_lig``: protein / ligand CSV paths
    - ``selected_generate_scaffold``: scaffold path (generation modes)
    - ``selected_generate_mask``: masked-positions text (generation modes)
    - ``selected_generate_number``: number of examples (generation modes)
    """
    def bind(widget, global_name, label):
        # Mirror `widget`'s value into the notebook global `global_name` on every
        # change, echoing it as "<label>: <value>". Returns the widget for display().
        def on_change(change):
            globals()[global_name] = change.new
            print(f"{label}: {change.new}")
        widget.observe(on_change, names='value')
        return widget

    if selected_mode == "Predict from protein + ligand pairs":
        kcat_options = ["Kcat", "Km", "Kd", "Ki", "IC50", "EC50", "Functional residue", "Protein + ligand embedding", "Protein + ligand logits"]
        display(
            bind(widgets.SelectMultiple(options=kcat_options, description='Properties:'),
                 'selected_predict', 'Selected properties'),
            bind(widgets.Text(description='protein_path:'), 'selected_prot', 'prot_path'),
            bind(widgets.Text(description='ligand_path:'), 'selected_lig', 'lig_path'),
        )

    elif selected_mode == "Generate from ligand + protein scaffold":
        # NOTE(review): no visible cell consumes the generation-mode globals yet.
        display(
            bind(widgets.Text(description='ligand_path:'), 'selected_lig', 'lig_path'),
            bind(widgets.Text(description='prot_Scaffold_path:'),
                 'selected_generate_scaffold', 'Scaffold path'),
            bind(widgets.Text(description='Masked positions:'),
                 'selected_generate_mask', 'Masked positions'),
            bind(widgets.IntText(description='Number of examples:'),
                 'selected_generate_number', 'Number of examples'),
        )

    elif selected_mode == "Generate from protein + ligand scaffold":
        display(
            bind(widgets.Text(description='protein_path:'), 'selected_prot', 'prot_path'),
            bind(widgets.Text(description='lig_Scaffold_path:'),
                 'selected_generate_scaffold', 'Scaffold path'),
            bind(widgets.Text(description='Masked positions:'),
                 'selected_generate_mask', 'Masked positions'),
            bind(widgets.IntText(description='Number of examples:'),
                 'selected_generate_number', 'Number of examples'),
        )

    

In [7]:
# Render the widgets for the mode chosen above. The mode callback does not
# refresh them automatically, so re-run this cell after changing the mode.
update_additional_options()

SelectMultiple(description='Properties:', options=('Kcat', 'Km', 'Kd', 'Ki', 'IC50', 'EC50', 'Functional resid…

Text(value='', description='protein_path:')

Text(value='', description='ligand_path:')

In [8]:
device_options = ['CPU', 'GPU']
device_selector = widgets.RadioButtons(options=device_options, description='Device:')

def _to_torch_device(choice):
    """Map the radio label ('CPU' / 'GPU') to the torch device string."""
    return 'cuda' if choice == 'GPU' else 'cpu'

def on_device_selection(change):
    """Store the torch device string for the chosen radio option in `selected_device`."""
    global selected_device
    selected_device = _to_torch_device(change.new)
    print(f"Selected device: {selected_device}")

# Seed the global from the default choice: `observe` fires only on changes, and
# the model-loading cell reads `selected_device` unconditionally.
selected_device = _to_torch_device(device_selector.value)
device_selector.observe(on_device_selection, names='value')
display(device_selector)

RadioButtons(description='Device:', options=('CPU', 'GPU'), value='CPU')

In [10]:
# Load the necessary models and heads based on user selections.

import torch

# Checkpoint files expected in every `<folder>/models_fold<k>/` directory:
# the two encoders and their LM heads, the joint layer, one head per target,
# and the position encoding.
model_files = (
    ['chem_model.pt', 'prot_model.pt', 'chem_model_lmhead.pt', 'prot_model_lmhead.pt', 'joint_layer3x.pt']
    + [f'{target}_head.pt' for target in ('kcat', 'km', 'ki', 'kd', 'IC50', 'EC50', 'site')]
    + ['position_encoding.pt']
)

In [13]:
# Load models based on user input
def load_models(selected_folder, device, idx, files=None):
    """Load every checkpoint of `idx` folds from ``./<selected_folder>/models_fold<k>/``.

    Args:
        selected_folder: checkpoint folder name, relative to the working directory.
        device: ``map_location`` passed to ``torch.load`` (e.g. 'cpu' or 'cuda').
        idx: number of folds to load; fold directories are numbered from 1.
        files: checkpoint file names to look for in each fold; defaults to the
            notebook-level ``model_files`` list.

    Returns:
        dict mapping ``f'{file_name}{fold}'`` (e.g. ``'kcat_head.pt1'``) to the
        loaded object. Missing files are reported and skipped, so the returned
        dict may lack keys that later cells index.
    """
    import os

    if files is None:
        files = model_files
    models = {}
    for fold in range(1, idx + 1):
        fold_dir = os.path.join('.', selected_folder, f'models_fold{fold}')
        for model_file in files:
            model_path = os.path.join(fold_dir, model_file)
            # SECURITY: torch.load unpickles the stored objects and can execute
            # arbitrary code — only load checkpoints from trusted sources.
            # NOTE(review): recent torch releases default to weights_only=True,
            # which rejects pickled nn.Module objects; pass weights_only=False
            # explicitly if loading starts failing after a torch upgrade.
            try:
                models[f'{model_file}{fold}'] = torch.load(model_path, map_location=device)
            except FileNotFoundError:
                print(f"{model_file}{fold} not found in {selected_folder}/models_fold{fold}")
                continue
            print(f"Loaded {model_file}{fold} from {selected_folder}/models_fold{fold}")
    return models

# An ensemble checkpoint folder holds three folds; every other folder holds one.
models = load_models(selected_folder, selected_device, 3 if selected_folder == 'model_ensemble' else 1)

  models[f'{model_file}{i+1}'] = torch.load(model_path, map_location=device)


Loaded chem_model.pt1 from models/models_fold1
Loaded prot_model.pt1 from models/models_fold1
Loaded chem_model_lmhead.pt1 from models/models_fold1
Loaded prot_model_lmhead.pt1 from models/models_fold1
Loaded joint_layer3x.pt1 from models/models_fold1
Loaded kcat_head.pt1 from models/models_fold1
Loaded km_head.pt1 from models/models_fold1
Loaded ki_head.pt1 from models/models_fold1
Loaded kd_head.pt1 from models/models_fold1
Loaded IC50_head.pt1 from models/models_fold1
Loaded EC50_head.pt1 from models/models_fold1
Loaded site_head.pt1 from models/models_fold1
Loaded position_encoding.pt1 from models/models_fold1


In [825]:
def load_csv(prot_path=None, lig_path=None):
    """Read paired protein sequences and ligand SMILES from two CSV files.

    Args:
        prot_path: CSV with a ``sequence`` column; defaults to the widget-driven
            notebook global ``selected_prot``.
        lig_path: CSV with a ``smiles`` column; defaults to the notebook global
            ``selected_lig``.

    Returns:
        ``(prot, lig)``: two lists in file row order. Rows are paired by position
        downstream, so both files are expected to have the same number of rows.
    """
    if prot_path is None:
        prot_path = selected_prot
    if lig_path is None:
        lig_path = selected_lig
    prot = pd.read_csv(prot_path)['sequence'].tolist()
    lig = pd.read_csv(lig_path)['smiles'].tolist()
    return prot, lig

In [826]:
# Load the selected CSVs and report how many proteins / ligands were read.
prot, lig = load_csv()
print(f'number of protein is: {len(prot)}')
print(f'number of ligand is: {len(lig)}')
# Rows are paired by position in the tokenization cell, so the counts must agree.
if len(prot) != len(lig):
    print(f'WARNING: {len(prot)} proteins vs {len(lig)} ligands — rows will not pair up one-to-one')

In [829]:
# Tokenize protein and ligand sequences for the prediction mode and cache one
# DataLoader per (ligand, protein) pair under ./dataset/ for the inference cells.
import os

import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

# Both token streams are padded/truncated to this length; the inference cell
# slices the cached rows in 512*2-column chunks, so the two must agree.
MAX_TOKENS = 512 * 2
DATASET_DIR = './dataset'

if selected_mode == "Predict from protein + ligand pairs":
    # NOTE(review): the ligand tokenizer is always read from ./models/, even when
    # another checkpoint folder is selected — confirm that is intended.
    tokenizer_sf = AutoTokenizer.from_pretrained('./models/sf_tokenizer')
    tokenizer_prot = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
    tokenizer_prot.model_max_length = MAX_TOKENS
    tokenizer_sf.model_max_length = MAX_TOKENS

    def tokenize_data(inputs_slf, inputs_prot):
        """Tokenize ligand and protein strings into id / attention-mask tensors.

        Each sequence is padded and truncated to its tokenizer's
        ``model_max_length``. Returns
        ``(input_ids_slf, attention_mask_slf, input_ids_prot, attention_mask_prot)``.
        """
        enc_slf = tokenizer_sf(inputs_slf, padding="max_length", truncation=True, max_length=tokenizer_sf.model_max_length)
        enc_prot = tokenizer_prot(inputs_prot, padding="max_length", truncation=True, max_length=tokenizer_prot.model_max_length)
        return (
            torch.tensor(enc_slf['input_ids']),
            torch.tensor(enc_slf['attention_mask']),
            torch.tensor(enc_prot['input_ids']),
            torch.tensor(enc_prot['attention_mask']),
        )

    train_pr_1, train_sf_1 = load_csv()
    # Pairs are matched by row position, and the inference cell reads one cached
    # file per protein row, so the two lists must line up exactly.
    assert len(train_pr_1) == len(train_sf_1), (
        f'{len(train_pr_1)} proteins vs {len(train_sf_1)} ligands: rows must pair up one-to-one')
    os.makedirs(DATASET_DIR, exist_ok=True)

    for j, (ligand, protein) in enumerate(tqdm(zip(train_sf_1, train_pr_1), total=len(train_sf_1))):
        # NOTE(review): the ligand column is 'smiles' but the tokenizer is named
        # sf_tokenizer — confirm no SMILES -> SELFIES conversion is missing.
        ids_sf, mask_sf, ids_prot, mask_prot = tokenize_data([ligand], [protein])
        # One row laid out as [ligand ids | ligand mask | protein ids | protein mask],
        # MAX_TOKENS columns each; the inference cell slices it back apart.
        combined_input = torch.cat(
            [t.view(-1, MAX_TOKENS) for t in (ids_sf, mask_sf, ids_prot, mask_prot)], dim=1)
        dataset = TensorDataset(combined_input)
        dataloader = DataLoader(dataset, batch_size=2)
        torch.save(dataloader, f'{DATASET_DIR}/saved_dataset_ft{j}')


100%|██████████| 140/140 [00:01<00:00, 81.84it/s] 


In [830]:
# Re-read the CSVs at top level: the tokenization cell only binds
# `train_pr_1` / `train_sf_1` inside its prediction-mode branch, and the
# inference cells use len(train_pr_1) as the number of cached dataset files.
train_pr_1,train_sf_1 = load_csv()

In [851]:
# Prediction heads, in the order run_inference_predict returns them.
PREDICT_HEADS = ('kcat', 'kd', 'ki', 'km', 'IC50', 'EC50')


def run_inference_predict(number_of_batch, idx, device=None):
    """Predict the six regression targets for every cached protein-ligand pair.

    Reads ``./dataset/saved_dataset_ft{i}`` for ``i in range(number_of_batch)``
    and runs both encoders, the joint layer and the six heads of fold ``idx``.
    Modules are taken from the global ``models`` dict, keyed ``f'<file>{idx+1}'``.

    Args:
        number_of_batch: number of cached dataset files to read.
        idx: 0-based fold index.
        device: torch device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        Six lists ``(kcat, kd, ki, km, IC50, EC50)``. Each element is a 2-item
        list per input row (the save cell labels these 'parameter' and 'SD').
    """
    from tqdm import tqdm

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fold = idx + 1
    prot_model = models[f'prot_model.pt{fold}']
    chem_model = models[f'chem_model.pt{fold}']
    joint_layer = models[f'joint_layer3x.pt{fold}']
    position_encoding = models[f'position_encoding.pt{fold}']
    heads = [models[f'{name}_head.pt{fold}'] for name in PREDICT_HEADS]

    # Inference only: freeze, move to `device`, and switch to eval mode.
    # torch.load restores the train/eval flag the module was pickled with, so
    # without .eval() any dropout layers stay active and outputs are random.
    for module in (prot_model, chem_model, joint_layer, position_encoding, *heads):
        module.to(device)
        module.eval()
        for param in module.parameters():
            param.requires_grad = False

    seg = 512 * 2          # columns per stream (ids / mask, ligand / protein)
    joint_len = 2 * seg    # protein tokens followed by ligand tokens
    with torch.no_grad():
        # Position ids 1..joint_len are the same for every batch, so compute once.
        positions = position_encoding(
            torch.linspace(1, joint_len, joint_len, dtype=torch.int32).view(-1, joint_len).to(device)
        ).to(device)

    preds = {name: [] for name in PREDICT_HEADS}
    for i in tqdm(range(number_of_batch)):
        # NOTE: torch.load unpickles the cached DataLoader; only read trusted files.
        dataloader = torch.load(f'./dataset/saved_dataset_ft{i}')
        for (rows,) in dataloader:
            rows = rows.view(-1, 4 * seg)
            # Row layout written by the tokenization cell:
            # [ligand ids | ligand mask | protein ids | protein mask], `seg` columns each.
            input_ids_sf, attention_mask_sf, input_ids_prot, attention_mask_prot = (
                rows[:, k * seg:(k + 1) * seg].to(device).to(dtype=torch.int) for k in range(4)
            )
            with torch.no_grad():
                sf_hidden = chem_model(input_ids_sf, attention_mask=attention_mask_sf).last_hidden_state
                prot_hidden = prot_model(input_ids_prot, attention_mask=attention_mask_prot).last_hidden_state
                combined_input = torch.cat([prot_hidden, sf_hidden], dim=1).to(device)
                # Mask order must match the hidden-state order: protein, then ligand.
                # NOTE(review): the 0/1 mask is passed as-is; confirm joint_layer
                # does not expect an additive (0 / -inf) mask.
                joint_mask = torch.cat([attention_mask_prot, attention_mask_sf], dim=1).view(-1, 1, 1, joint_len)
                mixture = joint_layer(combined_input + positions, attention_mask=joint_mask)
                # NOTE(review): mean pooling here also averages over padding positions.
                pooled = torch.mean(mixture.last_hidden_state, dim=1)
                for name, head in zip(PREDICT_HEADS, heads):
                    preds[name].append(head(pooled))

    # reshape(-1, 2) assumes every head emits two values per row; flatten all
    # batches into one list of [value, value] pairs per head, in input order.
    return tuple(
        [pair for out in preds[name] for pair in out.detach().cpu().numpy().reshape(-1, 2).tolist()]
        for name in PREDICT_HEADS
    )

In [852]:
# Run the single model, or each of the three ensemble folds, over all cached pairs.
if selected_folder != 'model_ensemble':
    (final_p_kcat_f, final_p_kd_f, final_p_ki_f,
     final_p_km_f, final_p_IC50_f, final_p_EC50_f) = run_inference_predict(len(train_pr_1), 0)
else:
    # One list per property; each collects the per-fold prediction lists.
    final_p_kcat_f_, final_p_kd_f_, final_p_ki_f_, final_p_km_f_, final_p_IC50_f_, final_p_EC50_f_ = (
        [] for _ in range(6))
    for i in range(3):
        (final_p_kcat_f, final_p_kd_f, final_p_ki_f,
         final_p_km_f, final_p_IC50_f, final_p_EC50_f) = run_inference_predict(len(train_pr_1), i)
        final_p_kcat_f_ += [final_p_kcat_f]
        final_p_kd_f_ += [final_p_kd_f]
        final_p_ki_f_ += [final_p_ki_f]
        final_p_km_f_ += [final_p_km_f]
        final_p_IC50_f_ += [final_p_IC50_f]
        final_p_EC50_f_ += [final_p_EC50_f]
# TODO(review): ensemble folds are kept side by side here and stacked row-wise
# before saving, not averaged. For an averaged ensemble prediction use
# np.mean(np.array(final_p_<prop>_f_), axis=0) for each property instead.

100%|██████████| 140/140 [00:07<00:00, 18.86it/s]
100%|██████████| 140/140 [00:07<00:00, 18.89it/s]
100%|██████████| 140/140 [00:07<00:00, 18.33it/s]


In [853]:
# Turn each property's predictions into an (n_rows, 2) array. In ensemble mode
# the three folds are stacked one after another along the first axis.
(final_p_kcat_fn, final_p_kd_fn, final_p_ki_fn,
 final_p_km_fn, final_p_IC50_fn, final_p_EC50_fn) = (
    np.array(per_property).reshape(-1, 2)
    for per_property in (
        (final_p_kcat_f_, final_p_kd_f_, final_p_ki_f_, final_p_km_f_, final_p_IC50_f_, final_p_EC50_f_)
        if selected_folder == 'model_ensemble'
        else (final_p_kcat_f, final_p_kd_f, final_p_ki_f, final_p_km_f, final_p_IC50_f, final_p_EC50_f)
    )
)


In [854]:
import os

# Stack the six properties row-wise (kcat, kd, ki, km, IC50, EC50) into one table.
property_blocks = [final_p_kcat_fn, final_p_kd_fn, final_p_ki_fn, final_p_km_fn, final_p_IC50_fn, final_p_EC50_fn]
out = np.concatenate(property_blocks, axis=0)
df = pd.DataFrame(out, columns=['parameter', 'SD'])
# Without a label the stacked rows cannot be told apart once saved; append one.
df['property'] = np.repeat(['kcat', 'kd', 'ki', 'km', 'IC50', 'EC50'], [len(b) for b in property_blocks])

OUTPUT_DIR = './path/to'  # TODO(review): placeholder — point at the real output directory
# `itstr` is not defined in any visible cell; treat it as an optional filename
# suffix so the save does not fail with a NameError.
output_path = f'{OUTPUT_DIR}/pred{selected_folder}{globals().get("itstr", "")}'
os.makedirs(OUTPUT_DIR, exist_ok=True)
df.to_csv(output_path)