In [1]:
import os
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib

from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.regularizers import l2
from keras.callbacks import EarlyStopping
from keras.models import load_model

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score, log_loss, precision_score, recall_score, roc_auc_score

In [2]:
# Render DataFrames in full: no truncation of rows, columns, total width or cell text.
for _opt in ('display.max_rows', 'display.max_columns', 'display.width', 'display.max_colwidth'):
    pd.set_option(_opt, None)

In [3]:
# Raw labelled results: whitespace-separated columns.
file_path = 'data/results_LSS_updated'

# Column 2 is the peptide sequence with its final character dropped
# (NOTE(review): presumably a trailing delimiter — confirm against the file
# format); the last column is an integer status code. Lines with fewer than
# two fields are skipped.
sequences = []
last_column_values = []

with open(file_path, 'r') as handle:
    for raw_line in handle:
        fields = raw_line.split()
        if len(fields) < 2:
            continue
        status = int(fields[-1])
        sequences.append(fields[1][:-1])
        last_column_values.append(status)

df = pd.DataFrame({'sequence': sequences, 'Status': last_column_values})

In [4]:
# Normalise to a 0..n-1 index, preview the first five rows, report (rows, columns).
df = df.reset_index(drop=True)
print(df.iloc[:5])
df.shape

   sequence  Status
0  FEFKEKFF       0
1  FKFEEKFF       0
2  EFKEKFFF       0
3  FFFKEKFE       0
4  KFFKEFFE       0


(88, 2)

In [5]:
# Keep only status codes 0, 2, 3 and 4 (any other code is dropped). Codes 2/3/4
# are labelled 'fibril'; the remaining kept code (0) is labelled 'non-fibril'.
df_filtered = (
    df[df['Status'].isin([0, 2, 3, 4])]
    .assign(fibril_state=lambda frame: np.where(frame['Status'].isin([2, 3, 4]), 'fibril', 'non-fibril'))
    .reset_index(drop=True)
)
print(df_filtered.head())
df_filtered.shape

   sequence  Status fibril_state
0  FEFKEKFF       0   non-fibril
1  FKFEEKFF       0   non-fibril
2  EFKEKFFF       0   non-fibril
3  FFFKEKFE       0   non-fibril
4  KFFKEFFE       0   non-fibril


(72, 3)

In [6]:
#@title Plot Training History
def plot_training_history(history, fold):
  """Show train/validation accuracy (left) and loss (right) per epoch.

  Args:
    history: Keras ``History``; ``history.history`` must contain 'accuracy',
      'val_accuracy', 'loss' and 'val_loss'.
    fold: fold number appended to the titles when it is an int; any other
      value (e.g. None) gives untitled-by-fold plots.
  """
  fig, axs = plt.subplots(1, 2, figsize=(12, 4))
  suffix = ' (Fold %d)' % fold if isinstance(fold, int) else ''
  panels = (
    ('accuracy', 'Model accuracy', 'Accuracy'),
    ('loss', 'Model loss', 'Loss'),
  )
  for ax, (key, title, ylabel) in zip(axs, panels):
    ax.plot(history.history[key])
    ax.plot(history.history['val_' + key])
    ax.set_title(title + suffix)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    ax.legend(['Train', 'Validation'], loc='upper left')
    ax.grid(linestyle=':')

  plt.tight_layout()
  plt.show()

In [7]:
#@title Plot Sorted Distribution of Fibril / Non-fibril Ratio
def plot_ratio(df):
  """Grouped bar plot of fibril vs non-fibril probability per motif.

  Motifs are ordered by ``df['ratio']`` descending. ``df`` must be indexed by
  (or, after ``reset_index``, expose a column named) ``motifs`` and contain
  ``ratio``, ``fibril_prob`` and ``non_fibril_prob`` columns.
  """
  prob_cols = ['fibril_prob', 'non_fibril_prob']
  motif_order = list(df['ratio'].sort_values(ascending=False).index)

  long_df = df[prob_cols].reset_index().melt(
      id_vars='motifs', value_vars=prob_cols, var_name='fibril_state', value_name='value')
  # Ordered categorical so sorting (and the x axis) follows the ratio ranking.
  long_df['motifs'] = pd.Categorical(long_df['motifs'], categories=motif_order, ordered=True)
  long_df = long_df.sort_values('motifs')

  plt.figure(figsize=(15, 7))
  sns.barplot(data=long_df, x='motifs', y='value', hue='fibril_state', palette=['cyan', 'orange'])
  plt.grid(linewidth=1, linestyle=':', axis='y')
  plt.xticks(rotation=45)
  plt.ylabel('Probability', fontsize=16)
  plt.xlabel('Motifs', fontsize=16)
  plt.title('Probability of Amino Acid Motifs in Fibril and Non-fibril States, Ranked by Fibril/Non-fibril Ratio', fontsize=18)
  plt.tight_layout()
  plt.show()

In [8]:
import torch

# Pretrained ESM-2 protein language model, used frozen as a feature extractor.
# Larger alternative checkpoint: "esm2_t12_35M_UR50D".
ESM2_HUB_REPO = "facebookresearch/esm:main"
ESM2_CHECKPOINT = "esm2_t6_8M_UR50D"
esm2_model, alphabet = torch.hub.load(ESM2_HUB_REPO, ESM2_CHECKPOINT)
# Inference mode (dropout off); last expression so the architecture is displayed.
esm2_model.eval()

Using cache found in C:\Users\alan_/.cache\torch\hub\facebookresearch_esm_main


ESM2(
  (embed_tokens): Embedding(33, 320, padding_idx=1)
  (layers): ModuleList(
    (0-5): 6 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=320, out_features=320, bias=True)
        (v_proj): Linear(in_features=320, out_features=320, bias=True)
        (q_proj): Linear(in_features=320, out_features=320, bias=True)
        (out_proj): Linear(in_features=320, out_features=320, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=320, out_features=1280, bias=True)
      (fc2): Linear(in_features=1280, out_features=320, bias=True)
      (final_layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=120, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((320,), eps=1e-05, elementwis

In [9]:
def get_esm2_embedding(sequence, model, alphabet, avg=True, repr_layer=None):
    """Embed one amino-acid sequence with an ESM-2 model.

    Args:
        sequence: amino-acid string (one-letter codes).
        model: ESM-2 model; called with ``repr_layers=[repr_layer]`` and read
            via ``results["representations"][repr_layer]``.
        alphabet: the model's alphabet; its batch converter tokenises the input.
        avg: if True, mean-pool the per-residue vectors into a single vector.
        repr_layer: which layer's hidden states to use. Defaults to
            ``model.num_layers`` (the final layer; 6 for esm2_t6_8M_UR50D), so a
            different checkpoint does not silently read an intermediate layer.

    Returns:
        numpy array of shape ``(embed_dim,)`` when ``avg`` is True, else
        ``(len(sequence), embed_dim)``.
    """
    if repr_layer is None:
        repr_layer = model.num_layers

    batch_converter = alphabet.get_batch_converter()
    _, _, batch_tokens = batch_converter([("protein", sequence)])

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=False)

    # Skip token 0 (the prepended BOS/CLS token) and stop after the last residue,
    # dropping the trailing EOS token.
    residue_reprs = results["representations"][repr_layer][0, 1:len(sequence) + 1]
    if avg:
        residue_reprs = residue_reprs.mean(0)
    return residue_reprs.cpu().numpy()

In [10]:
def save_results(cv_models, metrics, scalers, save_dir=None):
    """Persist cross-validation artefacts under ``save_dir`` (default 'trained_models').

    Writes one Keras file per fold, ``model_fold_{i}.h5`` with i counting from 0,
    plus ``metrics.pkl`` and ``scalers.pkl`` via joblib. The directory is
    created if it does not exist.
    """
    target_dir = 'trained_models' if save_dir is None else save_dir
    os.makedirs(target_dir, exist_ok=True)

    fold_idx = 0
    for fold_model in cv_models:
        fold_model.save(f'{target_dir}/model_fold_{fold_idx}.h5')
        fold_idx += 1

    joblib.dump(metrics, f'{target_dir}/metrics.pkl')
    joblib.dump(scalers, f'{target_dir}/scalers.pkl')

In [11]:
def load_results(load_dir=None, n_folds=None):
    """Load fold models, fitted scalers and CV metrics from ``load_dir``.

    Args:
        load_dir: directory containing ``model_fold_{i}.h5``, ``metrics.pkl``
            and ``scalers.pkl``; defaults to ``'trained_models'``.
        n_folds: number of fold models to load (``model_fold_0`` ..
            ``model_fold_{n_folds-1}``). Defaults to the number of saved
            scalers, since training saves exactly one scaler per fold. This lets
            artefacts from any number of CV splits be reloaded.

    Returns:
        ``(scalers, cv_models, metrics)`` — the same order that
        ``n_fold_cross_validation`` returns.
    """
    if load_dir is None:
        load_dir = 'trained_models'
    # joblib.load unpickles arbitrary objects: only point this at trusted,
    # locally produced artefacts.
    metrics = joblib.load(f'{load_dir}/metrics.pkl')
    scalers = joblib.load(f'{load_dir}/scalers.pkl')
    if n_folds is None:
        n_folds = len(scalers)
    cv_models = [load_model(f'{load_dir}/model_fold_{i}.h5') for i in range(n_folds)]
    return scalers, cv_models, metrics

In [12]:
def print_stats(metrics):
    """Print ``"<name> mean: <value>"`` (4 decimal places) for each metric.

    ``metrics`` maps a metric name to its sequence of per-fold values; lines are
    printed in the dict's iteration order.
    """
    summary_lines = [f"{name} mean: {np.mean(values):.4f}" for name, values in metrics.items()]
    for text in summary_lines:
        print(text)

In [13]:
#@title N Fold Cross Validation MLP
def _build_fold_mlp(input_dim):
  """Return a freshly compiled binary MLP classifier for one CV fold.

  Four ReLU hidden layers (256-128-64-32), each with L2(0.01) kernel
  regularisation followed by 20% dropout, then one sigmoid output unit.
  """
  model = Sequential([
      Dense(256, activation='relu', input_dim=input_dim, kernel_regularizer=l2(0.01)),
      Dropout(0.2),
      Dense(128, activation='relu', kernel_regularizer=l2(0.01)),
      Dropout(0.2),
      Dense(64, activation='relu', kernel_regularizer=l2(0.01)),
      Dropout(0.2),
      Dense(32, activation='relu', kernel_regularizer=l2(0.01)),
      Dropout(0.2),
      Dense(1, activation='sigmoid')
  ])
  model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  return model


def _fold_scores(labels_true, pred_prob, threshold=0.5):
  """Return {metric name: value} for one fold's true labels and predicted probabilities."""
  pred_binary = (pred_prob > threshold).astype(int)
  # zero_division=0 makes explicit the value sklearn would otherwise return
  # (with a warning) when a fold has no predicted/actual positives.
  return {
    "Accuracy": accuracy_score(labels_true, pred_binary),
    "Precision": precision_score(labels_true, pred_binary, zero_division=0),
    "Recall": recall_score(labels_true, pred_binary, zero_division=0),
    "F1-Scores": f1_score(labels_true, pred_binary, zero_division=0),
    "AUC": roc_auc_score(labels_true, pred_prob),
    "Loss": log_loss(labels_true, pred_prob),
  }


def n_fold_cross_validation(df_filtered, n_splits=5, random_state=42, epochs=200, batch_size=10, plot=True, save_recall=False):
  """Stratified k-fold cross-validation of an MLP on ESM-2 sequence embeddings.

  Args:
    df_filtered: DataFrame with a 'sequence' column and a 'fibril_state'
      column; 'fibril' is the positive class (1), anything else is 0.
    n_splits: number of folds.
    random_state: seed for the fold shuffle only (Keras weight init and
      dropout are not seeded here).
    epochs, batch_size: passed to ``model.fit``. Training stops after 20 epochs
      without val_loss improvement and restores the best weights.
    plot: show each fold's training curves.
    save_recall: write each fold's validation sequences and predicted fibril
      probabilities to ``recall/fold_{k}_fibril_probabilities.xlsx``.

  Returns:
    ``(scalers, models, metrics)``: per-fold fitted StandardScalers, per-fold
    Keras models, and a dict mapping each metric name to its per-fold values.
  """
  kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

  # Positional arrays: kfold.split yields positions, so this is correct
  # whatever index labels df_filtered carries.
  labels = np.array([1 if state == 'fibril' else 0 for state in df_filtered['fibril_state']])
  sequences = df_filtered['sequence'].to_numpy()

  # Embeddings do not depend on the split, so compute them once instead of
  # re-running ESM-2 on every sequence in every fold.
  embeddings = np.array([get_esm2_embedding(seq, esm2_model, alphabet) for seq in sequences])
  input_dim = embeddings.shape[1]  # derived from the data rather than hard-coded

  metric_names = ["Accuracy", "Precision", "Recall", "F1-Scores", "AUC", "Loss"]
  metrics = {name: [] for name in metric_names}
  models = []
  scalers = []

  for fold, (train_index, val_index) in enumerate(kfold.split(embeddings, labels), start=1):
    print(f'Fold {fold}')

    labels_train, labels_val = labels[train_index], labels[val_index]
    sequences_val = sequences[val_index]

    # Scaler is fitted on the training fold only.
    scaler = StandardScaler()
    data_train = scaler.fit_transform(embeddings[train_index])
    data_val = scaler.transform(embeddings[val_index])
    scalers.append(scaler)

    model = _build_fold_mlp(input_dim)
    # NOTE(review): early stopping monitors the same fold that is scored below,
    # so the reported metrics are somewhat optimistic.
    early_stopping = EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
    history = model.fit(data_train, labels_train, epochs=epochs, batch_size=batch_size,
                        validation_data=(data_val, labels_val),
                        callbacks=[early_stopping],
                        verbose=0)
    models.append(model)
    pred_prob = model.predict(data_val).flatten()

    if save_recall:
      save_path = f'recall/fold_{fold}_fibril_probabilities.xlsx'
      os.makedirs(os.path.dirname(save_path), exist_ok=True)
      pd.DataFrame({
          'sequence': sequences_val,
          'fibril_probability': pred_prob
      }).to_excel(save_path, index=False)

    for name, value in _fold_scores(labels_val, pred_prob).items():
      metrics[name].append(value)
    print()

    if plot:
      plot_training_history(history, fold)

  print_stats(metrics)

  return scalers, models, metrics

In [14]:
# Set RETRAIN = True to re-run cross-validation and overwrite the saved
# artefacts in 'trained_models'; the next cell then loads whatever is on disk.
# (Results must be saved here: the load cell below replaces the in-memory ones.)
RETRAIN = False
if RETRAIN:
    scalers, cv_models, metrics = n_fold_cross_validation(df_filtered, epochs=500, batch_size=16, random_state=91, plot=False)
    save_results(cv_models, metrics, scalers, save_dir='trained_models')

In [15]:
# Restore per-fold scalers, models and CV metrics from disk.
scalers, cv_models, metrics = load_results('trained_models')

In [16]:
# Mean of each metric across the CV folds.
print_stats(metrics=metrics)

Accuracy mean: 0.9438
Precision mean: 0.9100
Recall mean: 0.9100
F1-Scores mean: 0.9056
AUC mean: 0.9390
Loss mean: 0.2821


In [17]:
new_seq_file_path = "data/420seq.csv"

# Sequences in the labelled, filtered training set; they are excluded below so
# only sequences the models were not trained on get scored.
used_sequences = df_filtered['sequence'].tolist()

# Headerless CSV with one column of sequences.
all_seq_df = pd.read_csv(new_seq_file_path, header=None)
column_names = ['sequence']
all_seq_df.columns = column_names
all_sequences = all_seq_df['sequence'].to_list()

is_unseen = ~all_seq_df['sequence'].isin(used_sequences)
filtered_all_seq_df = all_seq_df.loc[is_unseen].reset_index(drop=True)

print(filtered_all_seq_df.head())
filtered_all_seq_df.shape

   sequence
0  FFFFKEKE
1  FFFFKEEK
2  FFFFEKKE
3  FFFKFEKE
4  FFFKKFEE


(348, 1)

In [18]:
def fibril_prediction(seq_df, scalers, cv_models, pred_dir=None):
    """Score sequences with every fold model and write the probabilities to CSV.

    Each sequence in ``seq_df['sequence']`` is embedded once with ESM-2. For
    the k-th (model, scaler) pair (k = 1, 2, ...), the embeddings are scaled
    with that fold's scaler and scored by that fold's model. Files written to
    ``pred_dir`` (default ``'pred_fibril_probs'``, created if missing):

    * ``predicted_fibril_state_{k}.csv``: ``sequence`` and
      ``fibril_probability_model_{k}``;
    * ``average_fibril_probability.csv``: ``sequence`` and the mean probability
      over all models, sorted from highest to lowest.

    Raises:
        ValueError: if no models are given, or ``cv_models`` and ``scalers``
            differ in length (``zip`` would otherwise silently drop folds).
    """
    if pred_dir is None:
        pred_dir = 'pred_fibril_probs'
    if not cv_models or len(cv_models) != len(scalers):
        raise ValueError(
            f'need at least one model and one scaler per model; '
            f'got {len(cv_models)} models and {len(scalers)} scalers')
    os.makedirs(pred_dir, exist_ok=True)

    sequences = seq_df['sequence'].to_list()
    # Embeddings are model-independent: compute them once, not once per fold model.
    features = np.array([get_esm2_embedding(seq, esm2_model, alphabet) for seq in sequences])

    all_probabilities_df = pd.DataFrame({'sequence': sequences})
    for index, (model, scaler) in enumerate(zip(cv_models, scalers), start=1):
        predictions = model.predict(scaler.transform(features))
        probabilities = [row[0] for row in predictions]
        prob_col = f'fibril_probability_model_{index}'

        predicted_df = pd.DataFrame({'sequence': sequences, prob_col: probabilities})
        predicted_df.to_csv(f'{pred_dir}/predicted_fibril_state_{index}.csv')

        # Every model scores the rows in the same order, so add the column
        # positionally. Merging on 'sequence' would multiply rows whenever
        # seq_df contains a repeated sequence.
        all_probabilities_df[prob_col] = probabilities

    all_probabilities_df['average_fibril_probability'] = all_probabilities_df.filter(like='fibril_probability_').mean(axis=1)

    sorted_df = all_probabilities_df.sort_values(by='average_fibril_probability', ascending=False)

    sorted_df[['sequence', 'average_fibril_probability']].to_csv(f'{pred_dir}/average_fibril_probability.csv', index=False)

In [19]:
# Score every unseen sequence with the fold ensemble; CSVs go to all_seq_pred/.
fibril_prediction(seq_df=filtered_all_seq_df, scalers=scalers, cv_models=cv_models, pred_dir='all_seq_pred')



In [20]:
# Hand-picked candidate sequences to score with the fold ensemble.
candidate_sequences = [
    "EFKFFKFE", "FKFFKFEE", "KFFKFEFE", "EKFFKFEF",
    "FEFEFKFK", "FKFEKFEF", "KFEKFEFF", "EEKFFKFF",
]
test_seq_dict = {"sequence": candidate_sequences}

df_test_seq = pd.DataFrame(test_seq_dict)
fibril_prediction(df_test_seq, scalers, cv_models, pred_dir='test_seq_pred')

