In [None]:
import time
# Wall-clock timestamp taken at the top of the notebook; the last cell subtracts it to report total runtime.
start_time = time.time()

In [None]:
%%time
import torch
import warnings
# build_vocab, pretrain_trfm and utils (imported further down, once the UniKP directory is on sys.path) come from SMILES Transformer
from transformers import T5EncoderModel, T5Tokenizer
# transformers (imported above) is the Hugging Face library used to load the ProtTrans model
import re
import gc
import numpy as np
import pandas as pd
import pickle
import math
import yaml
import os
import sys
from pathlib import Path
# Suppress every UserWarning for the rest of the session.
warnings.filterwarnings(action='ignore', category=UserWarning)

# Directory that contains inputs.yml (the file name itself is appended below),
# e.g. "custom/path/to/inputs". Leave empty to fall back to the INPUTS
# environment variable.
inputs_path = "../../analyses/iML1515/inputs"
notebook_dir = Path().resolve()

# If inputs_path is empty, use the INPUTS environment variable
if not inputs_path:
    inputs_path = os.getenv("INPUTS")
    if inputs_path is None:
        raise ValueError("The INPUTS environment variable is not set.")

inputs_file = os.path.join(inputs_path, "inputs.yml")
if not os.path.isfile(inputs_file):
    raise FileNotFoundError(f"The 'inputs.yml' file could not be found at {inputs_file}.")

with open(inputs_file, "r") as file:
    data = yaml.safe_load(file)

sbml_model = data["sbml_model"]
test_path = data["output_file_path"]
# unikp_path is resolved relative to the directory the notebook runs in.
model_path = (notebook_dir / data["unikp_path"]).resolve()

# sys.path entries must be plain strings: the import machinery skips any other
# type, so appending the Path object itself would leave the UniKP modules
# (build_vocab, utils, pretrain_trfm) unimportable. The membership test keeps
# re-running this cell from appending the same entry again.
if str(model_path) not in sys.path:
    sys.path.append(str(model_path))
from build_vocab import WordVocab
from utils import split

In [None]:
%%time
# Table of reactions to annotate. Columns read later in this notebook:
# 'Sequence', 'Reaction name', 'Kcat', 'Substrate Smiles' and 'Reaction ID'.
seqs_smiles_df = pd.read_csv(os.path.join(test_path, 'sequences_smiles.csv'))

# Pickled predictor stored in the UniKP directory; its .predict() is called in the prediction cell.
# NOTE(review): pickle.load can execute arbitrary code from the file - only load a model file you trust.
with open(os.path.join(model_path, 'UniKP for kcat.pkl'), "rb") as f:
    model = pickle.load(f)

# Counter and threshold read by the periodic-save check in the prediction cell.
batch = 0
batch_len = 20

# Cache of values already predicted, keyed by (substrate SMILES, sequence).
predicted_pair = {}


def is_assessed(substrate_smiles, sequence):
    """Return True when this (substrate SMILES, sequence) pair is already in predicted_pair."""
    lookup_key = (substrate_smiles, sequence)
    return lookup_key in predicted_pair

# Substrings that flag a reaction name as a transport reaction.
# NOTE(review): 'tranposrt' is kept exactly as found - presumably it matches a misspelling in the data; confirm.
transporters = ['transport', 'symporter', 'diffusion', 'antiport', 'tranposrt']


def contains_keywords(cell):
    """Return True if the text of `cell` contains any transporter keyword, ignoring case.

    `cell` is converted with str() first, so non-string values (e.g. NaN) are accepted.
    """
    text = str(cell).lower()
    for keyword in transporters:
        if keyword.lower() in text:
            return True
    return False

# Collect the rows that still need a kcat prediction.
# sequences, smiles and indices are parallel lists: entry i of each refers to
# the same DataFrame row (indices holds the row label used to write back).
sequences = []
smiles = []
indices = []

for index, row in seqs_smiles_df.iterrows():
    sequence = row['Sequence']
    substrate = row['Substrate Smiles']

    # Empty CSV cells arrive as NaN; only real strings can be tokenised later.
    if not isinstance(sequence, str) or not isinstance(substrate, str):
        continue
    # Transport reactions are excluded from prediction.
    if contains_keywords(row['Reaction name']):
        continue
    # Keep any kcat that is already present. pd.isna also copes with
    # None/object cells, where np.isnan would raise TypeError.
    if not pd.isna(row['Kcat']):
        continue
    if substrate == 'Compound not found':
        continue

    pair_key = (substrate, sequence)
    if is_assessed(*pair_key):
        # NOTE(review): predicted_pair is reset to {} earlier in this cell and
        # nothing fills it inside this loop, so this branch does not fire here.
        seqs_smiles_df.at[index, 'Kcat'] = predicted_pair[pair_key]
    else:
        print(row['Reaction ID'])
        sequences.append(sequence)
        smiles.append(substrate)
        indices.append(index)

In [None]:
%%time
def process_sequence(seq):
    """Cap a sequence at 1000 characters by keeping its first 500 and last 500.

    Sequences of 1000 characters or fewer are returned unchanged.
    """
    if len(seq) <= 1000:
        return seq
    head, tail = seq[:500], seq[-500:]
    return head + tail

def Seq_to_vec(Sequence):
    """Embed protein sequences with the ProtT5 encoder stored under model_path.

    Args:
        Sequence: iterable of amino-acid strings (one-letter codes).

    Returns:
        A list with one tensor per input sequence, in input order. Each tensor
        is the encoder's last hidden state for that sequence's attended
        positions minus the final one (presumably the tokenizer's
        end-of-sequence token - confirm), left on the device the encoder ran
        on. An empty input returns an empty list without loading the model.
    """
    # Shorten over-long sequences, map the rare residues U/Z/O/B to X as the
    # ProtT5 model card prescribes, then space-separate the residues for the
    # tokenizer.
    sequences_Example = [
        ' '.join(re.sub(r"[UZOB]", "X", process_sequence(seq))) for seq in Sequence
    ]
    num_sequences = len(sequences_Example)
    if num_sequences == 0:
        # Nothing to embed; this also avoids range() with a zero step below.
        return []

    prot_t5_dir = os.path.join(model_path, 'prot_t5_xl_uniref50')
    tokenizer = T5Tokenizer.from_pretrained(prot_t5_dir, do_lower_case=False)
    encoder = T5EncoderModel.from_pretrained(prot_t5_dir)
    gc.collect()
    print("Tokenizer and model loaded.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        print("Let's use", gpu_count, "GPUs!")
        encoder = torch.nn.DataParallel(encoder)
        batch_size = min(num_sequences, gpu_count * 8)  # Adjust batch size for multiple GPUs
    else:
        # os.cpu_count() may return None; fall back to a single worker.
        cpu_count = os.cpu_count() or 1
        print("Let's use", cpu_count, "CPUs!")
        batch_size = min(num_sequences, cpu_count)
    encoder = encoder.to(device).eval()
    print("Model moved to device and set to evaluation mode.")

    features = []

    for i in range(0, num_sequences, batch_size):
        print('Processing sequences from', str(i), 'to', str(i + batch_size))
        batch_sequences = sequences_Example[i:i + batch_size]
        batch_ids = tokenizer.batch_encode_plus(batch_sequences, add_special_tokens=True, padding=True)
        input_ids = torch.tensor(batch_ids['input_ids']).to(device)
        attention_mask = torch.tensor(batch_ids['attention_mask']).to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            embedding = encoder(input_ids=input_ids, attention_mask=attention_mask)

        embedding = embedding.last_hidden_state
        for seq_num in range(embedding.size(0)):
            # Number of attended (non-padding) positions for this sequence.
            seq_len = (attention_mask[seq_num] == 1).sum()
            # Drop the last attended position and all padding.
            seq_emd = embedding[seq_num][:seq_len - 1]
            features.append(seq_emd)

    print("Finished processing sequences.")
    return features

# Embed every collected sequence in one call.
# `features` is always bound (empty list when nothing needs predicting) so
# cells that read it later cannot raise NameError on a fresh Run All.
features = Seq_to_vec(sequences) if sequences else []

In [None]:
%%time
def normalize_feature(features):
    """Mean-pool each embedding tensor over dim 0 and return the results as one NumPy array.

    Args:
        features: non-empty sequence of tensors. Each is averaged over dim 0;
            the averaged tensors must share a shape so they can be stacked.

    Returns:
        NumPy array with one row per input tensor.
    """
    pooled = []
    for per_position in features:
        pooled.append(per_position.mean(dim=0))
    # Stacking runs on whatever device the inputs live on; only the stacked
    # result is copied to host memory and converted.
    return torch.stack(pooled, dim=0).cpu().numpy()

# Pool embeddings only when there were sequences to embed. `sequences` is
# always defined by the collection cell, so this check is safe on a fresh run
# in which no row needs a prediction.
if sequences:
    seq_vecs = normalize_feature(features)

In [None]:
%%time
from pretrain_trfm import TrfmSeq2seq

def smiles_to_vec(Smiles, vocab, device, trfm):
    """Encode SMILES strings with the pretrained SMILES transformer.

    Args:
        Smiles: iterable of SMILES strings.
        vocab: vocabulary exposing a `stoi` token-to-id mapping.
        device: torch device on which the id tensor is placed.
        trfm: model exposing `encode`; it receives the id tensor transposed,
            i.e. with the token position on dim 0.

    Returns:
        The output of `trfm.encode` with dim 0 squeezed (dim 0 is only removed
        when it has size 1).
    """
    # Special-token ids.
    pad_index = 0
    unk_index = 1
    eos_index = 2
    sos_index = 3

    seq_len = 220               # fixed id-vector length, including SOS and EOS
    max_tokens = seq_len - 2    # 218 tokens fit between SOS and EOS
    keep = max_tokens // 2      # 109 tokens kept from each end when truncating

    def get_inputs(sm):
        """Turn one space-separated token string into a padded id list of length seq_len."""
        tokens = sm.split()
        if len(tokens) > max_tokens:
            print('SMILES is too long ({:d})'.format(len(tokens)))
            tokens = tokens[:keep] + tokens[-keep:]
        ids = [sos_index] + [vocab.stoi.get(token, unk_index) for token in tokens] + [eos_index]
        return ids + [pad_index] * (seq_len - len(ids))

    # split() is expected to return the tokens as one space-separated string,
    # which get_inputs then breaks apart with str.split().
    x_id = [get_inputs(split(sm)) for sm in Smiles]
    xid = torch.tensor(x_id).to(device)

    # Inference only, so skip autograd bookkeeping (as Seq_to_vec does).
    with torch.no_grad():
        X = trfm.encode(torch.t(xid))
    return X.squeeze(0)

# Always bind smiles_vecs so later cells can reference it even when no row
# needs a prediction.
smiles_vecs = []

if smiles:
    vocab = WordVocab.load_vocab(os.path.join(model_path, 'vocab.pkl'))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trfm = TrfmSeq2seq(len(vocab), 256, len(vocab), 4)
    # map_location remaps the checkpoint's tensors onto `device`, so the
    # weights also load on a machine without a GPU.
    trfm.load_state_dict(torch.load(os.path.join(model_path, 'trfm_12_23000.pkl'), map_location=device))
    trfm.to(device)
    trfm.eval()
    print("Transformer model loaded and set to evaluation mode.")

    # One SMILES per call, so every entry is the squeezed encoding of a single molecule.
    smiles_vecs = [smiles_to_vec([sm], vocab, device, trfm) for sm in smiles]

In [None]:
# Preview the SMILES embeddings; shows an empty frame when no SMILES were collected.
pd.DataFrame(smiles_vecs if smiles else [])

In [None]:
if sequences and smiles:
    %time fused_vectors = np.concatenate((smiles_vecs, seq_vecs), axis=1)
    print("fused_vectors time above")
    %time pre_kcats = model.predict(fused_vectors)
    pd.DataFrame(pre_kcats)
    print("pre_kcats time above")
    %time kcates = [math.pow(10, pre_kcats[i]) for i in range(len(pre_kcats))]
    pd.DataFrame(kcates)
    print("kcates time above")

    for i, index in enumerate(indices):
        seqs_smiles_df.at[index, 'Kcat'] = kcates[i]
        pair_key = (smiles[i], sequences[i])
        predicted_pair[pair_key] = kcates[i]

    # Save the DataFrame periodically
    batch += 1
    if batch == batch_len:
        seqs_smiles_df.to_csv(os.path.join(test_path, 'sequences_smiles_complete.csv'), index=False)
        batch = 0

seqs_smiles_df.to_csv(os.path.join(test_path, 'sequences_smiles_complete.csv'), index=False)

In [None]:
# Report the wall-clock time elapsed since start_time was recorded in the first cell.
end_time = time.time()
elapsed_time = end_time - start_time
print(f'Total execution time: {elapsed_time:.2f} seconds')