# Train Mondo Annotations

In [167]:
from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from fetch import *
import pandas as pd

## Load model and tokenizer

In [168]:
def split_data(embeddings, labels):
    # Split data into 80% training and 20% test
    embeddings_train, embeddings_temp, labels_train, labels_temp = train_test_split(
        embeddings, labels, test_size=0.20, random_state=42)

    # Split the 20% test into 15% validation and 5% test
    embeddings_val, embeddings_test, labels_val, labels_test = train_test_split(
        embeddings_temp, labels_temp, test_size=0.25, random_state=42)  # 0.25 * 0.20 = 0.05

    return embeddings_train, labels_train, embeddings_val, labels_val, embeddings_test, labels_test

## Prepare dataset

In [169]:
class ProteinDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]

## Create a classifier

In [170]:
class ProteinTaggingModel(nn.Module):
    def __init__(self, embedding_size, num_labels, dropout_rate=0.1):
        super(ProteinTaggingModel, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(embedding_size, embedding_size)  # Adjust this to your combined embedding size
        self.relu = nn.ReLU()
        self.norm1 = nn.LayerNorm(embedding_size)
        self.fc2 = nn.Linear(embedding_size, embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)
        self.classifier = nn.Linear(embedding_size, num_labels)
        self.output_fc = nn.Linear(num_labels, num_labels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, embeddings):
        x = self.dropout(embeddings)
        x = self.relu(self.fc1(x))
        x = self.norm1(x + embeddings)  # Ensure `embeddings` is broadcastable with `x`
        x = self.relu(self.fc2(x))
        x = self.norm2(x + embeddings)  # Ensure `embeddings` is broadcastable with `x`
        logits = self.classifier(x)
        logits = self.output_fc(logits)
        predictions = self.sigmoid(logits)
        return predictions

## Train model

In [171]:
# load dataset
embedding_type = 'esm2_embedding'
embeddings, labels, annotations_vocab = fetch_data_multi(embedding_type, include_empty=True)

# split data
embeddings_train, labels_train, embeddings_val, labels_val, embeddings_test, labels_test = split_data(embeddings, labels)

# dataloaders
train_dataset = ProteinDataset(embeddings_train, labels_train)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

val_dataset = ProteinDataset(embeddings_val, labels_val)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [172]:
# Setup DataLoader, Model, Loss, and Optimizer
dataset = ProteinDataset(embeddings, labels)

model = ProteinTaggingModel(embedding_size=embeddings.shape[1], num_labels=labels.shape[1])

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs):
    best_val_loss = float('inf')
    best_model = None

    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        for embeddings, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(embeddings)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for embeddings, labels in val_loader:
                outputs = model(embeddings)
                loss = criterion(outputs, labels.float())
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        print(f'Epoch {epoch+1}, Training Loss: {total_train_loss / len(train_loader)}, Validation Loss: {avg_val_loss}')

        # Save the model if it has a lower validation loss than the current best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model = model

    return best_model

In [173]:
# Start training
best_model = train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs=60)

Epoch 1, Training Loss: 0.014411321849288854, Validation Loss: 0.008011250075150207
Epoch 2, Training Loss: 0.006219276098478863, Validation Loss: 0.004987335706820453
Epoch 3, Training Loss: 0.004191519761173082, Validation Loss: 0.004228843046839515
Epoch 4, Training Loss: 0.0033631189656791404, Validation Loss: 0.003225133604746287
Epoch 5, Training Loss: 0.0028443158781144362, Validation Loss: 0.0027205962818794395
Epoch 6, Training Loss: 0.002313376286264569, Validation Loss: 0.0027003002635184797
Epoch 7, Training Loss: 0.002427091601215516, Validation Loss: 0.002797685613676856
Epoch 8, Training Loss: 0.001902282802852575, Validation Loss: 0.0022725988628916343
Epoch 9, Training Loss: 0.0016140988558419026, Validation Loss: 0.0023721069816826836
Epoch 10, Training Loss: 0.0015807100518479764, Validation Loss: 0.002255753052362321
Epoch 11, Training Loss: 0.0015717698091007445, Validation Loss: 0.002279694847044048
Epoch 12, Training Loss: 0.0013243764546521433, Validation Loss: 

## Evaluation

In [174]:
def evaluate_model(best_model, test_loader):
    best_model.eval()
    predictions = []
    truths = []
    with torch.no_grad():
        for embeddings, labels in test_loader:
            outputs = best_model(embeddings)
            predicted = torch.round(outputs)
            predictions.extend(predicted.cpu().numpy())
            truths.extend(labels.cpu().numpy())

    accuracy = accuracy_score(truths, predictions)
    precision = precision_score(truths, predictions, average='macro', zero_division=1)
    recall = recall_score(truths, predictions, average='macro', zero_division=1)
    f1 = f1_score(truths, predictions, average='macro', zero_division=1)

    return accuracy, precision, recall, f1

test_dataset = ProteinDataset(embeddings_test, labels_test)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
accuracy, precision, recall, f1 = evaluate_model(best_model, test_loader)
print(f'Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1 Score: {f1}')

Accuracy: 0.8674136321195145, Precision: 0.9920378615188716, Recall: 0.9801372117104146, F1 Score: 0.9771470404427797


## Inference

In [175]:
def create_reverse_vocab(annotations_vocab):
    return {idx: label for label, idx in annotations_vocab.items()}

def convert_to_labels(binary_vector, reverse_vocab):
    labels = [reverse_vocab[i] for i, value in enumerate(binary_vector) if value == 1]
    return labels

def intersection_and_difference(true_labels, predicted_labels):
    # Calculate intersection and differences of two label lists
    true_set = set(true_labels)
    predicted_set = set(predicted_labels)
    intersection = true_set & predicted_set
    incorrect = predicted_set - true_set
    return len(intersection), len(true_set), len(predicted_set), len(incorrect)

def predict_and_evaluate(best_model, test_loader, annotations_vocab, num_samples=10):
    best_model.eval()
    total_intersection = 0
    total_true_labels = 0
    total_predicted_labels = 0
    total_incorrect = 0
    num_zero_prediction = 0  # To track samples with zero predicted labels
    sample_data = []
    reverse_vocab = create_reverse_vocab(annotations_vocab)

    with torch.no_grad():
        for embeddings, labels in test_loader:
            outputs = best_model(embeddings)
            predicted = torch.round(outputs)  # Using 0.5 as a threshold
            # Convert binary vectors to MONDO names
            predicted_labels = [convert_to_labels(pred, reverse_vocab) for pred in predicted.cpu().numpy()]
            true_labels = [convert_to_labels(true, reverse_vocab) for true in labels.cpu().numpy()]
            
            for true, pred in zip(true_labels, predicted_labels):
                inter, true_count, pred_count, incorrect = intersection_and_difference(true, pred)
                total_intersection += inter
                total_true_labels += true_count
                total_predicted_labels += pred_count
                total_incorrect += incorrect
                if len(pred) == 0:
                    num_zero_prediction += 1  # Increment if no labels were predicted
            
            # Collect samples for display
            if len(sample_data) < num_samples:
                sample_data.extend(zip(embeddings, true_labels, predicted_labels))
            if len(sample_data) >= num_samples:
                break
    
    # Calculate averages
    num_samples = len(test_loader.dataset)
    average_correct = (total_intersection / num_samples) * 100
    average_incorrect = (total_incorrect / num_samples) * 100
    print(f"Average percentage of correct MONDO names per sample: {average_correct:.2f}%")
    print(f"Average percentage of incorrect MONDO names per sample: {average_incorrect:.2f}%")
    print(f"Number of samples with zero predicted labels: {num_zero_prediction}")

    return sample_data, num_zero_prediction

# Example usage assuming model, test_loader, and annotations_vocab are defined
sample_data, num_zero_prediction = predict_and_evaluate(best_model, test_loader, annotations_vocab, num_samples=10)

# Print the samples and the total number of zero predicted labels
for idx, (embedding, true_label, predicted_label) in enumerate(sample_data):
    print(f"Sample {idx+1}")
    print(f"True MONDO Names: {true_label}")
    print(f"Predicted MONDO Names: {predicted_label}\n")
print(f"Total samples with zero predictions: {num_zero_prediction}")

Average percentage of correct MONDO names per sample: 49.30%
Average percentage of incorrect MONDO names per sample: 0.59%
Number of samples with zero predicted labels: 4
Sample 1
True MONDO Names: []
Predicted MONDO Names: []

Sample 2
True MONDO Names: ['osteosarcoma', 'ovarian cancer']
Predicted MONDO Names: ['osteosarcoma', 'ovarian cancer']

Sample 3
True MONDO Names: ['prostate carcinoma', 'colorectal cancer', 'classic maple syrup urine disease', 'intermittent maple syrup urine disease', 'myocardial infarction', 'intermediate maple syrup urine disease', 'heart disease', 'maple syrup urine disease', 'inborn errors of metabolism', 'intellectual disability', 'inherited organic acidemia', 'osteoarthritis']
Predicted MONDO Names: ['prostate carcinoma', 'colorectal cancer', 'classic maple syrup urine disease', 'intermittent maple syrup urine disease', 'myocardial infarction', 'intermediate maple syrup urine disease', 'heart disease', 'maple syrup urine disease', 'inborn errors of metab

## Save model as Torchfile

In [176]:
best_model.eval()
example = torch.rand(1, 1280)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("esm_model.pt")

### Ensure model was properly saved and is working

In [177]:
def create_reverse_vocab(annotations_vocab):
    return {idx: label for label, idx in annotations_vocab.items()}

def convert_to_labels(binary_vector, reverse_vocab):
    binary_vector = binary_vector.tolist()
    labels = [reverse_vocab[i] for i, value in enumerate(binary_vector) if value == 1]
    return labels

def predict_single_sample(best_model, sample_embedding, annotations_vocab):
    best_model.eval()
    reverse_vocab = create_reverse_vocab(annotations_vocab)

    with torch.no_grad():
        outputs = best_model(sample_embedding.unsqueeze(0))  # Add an extra dimension for batch
        predicted = torch.round(outputs)  # Using 0.5 as a threshold
        # Convert binary vectors to MONDO names
        predicted_labels = [convert_to_labels(pred, reverse_vocab) for pred in predicted.cpu().numpy()]

    return predicted_labels[0]  # Return the first (and only) element

# Load the saved model
loaded_model = torch.jit.load("esm_model.pt")

# Extract a single batch from the test loader
embeddings, labels = next(iter(test_loader))

# Extract a single sample from the batch
sample_index = 0  # Replace with the index of the desired sample
sample_embedding = embeddings[sample_index]

# Perform inference on the sample
predicted_labels = predict_single_sample(model, sample_embedding, annotations_vocab)

print(predicted_labels)

[]


## Save Testing Dataset as CSV

In [178]:
def get_test_data(embeddings_test):
    """
    For each document in the MongoDB database where the embedding in embeddings_test corresponds to func_embedding, save the 'function' and 'mondo_names' to a pandas DataFrame.
    :param embeddings_test: The embeddings used for testing
    :return: A pandas DataFrame containing the 'function' and 'mondo_names' for each document
    """
    # Load environment variables
    load_dotenv()
    MONGO_URI = os.getenv("MONGODB_URI")
    MONGO_DB = "proteinExplorer"
    MONGO_COLLECTION = "protein_embeddings"

    # Connect to MongoDB
    client = MongoClient(MONGO_URI)
    db = client[MONGO_DB]
    collection = db[MONGO_COLLECTION]

    ids = []
    function = []
    labels = []

    for embedding in embeddings_test:
        document = collection.find_one({"esm2_embedding": embedding.tolist()})
        ids.append(document["_id"])
        function.append(document["sequence"])
        labels.append(document["mondo_names"])

    data = {"id": ids, "function": function, "mondo_names": labels}

    return pd.DataFrame(data)

In [179]:
test_df = get_test_data(embeddings_test)
test_df.head()

Unnamed: 0,id,function,mondo_names
0,29574,MASKTKASEALKVVARCRPLSRKEEAAGHEQILTMDVKLGQVTLRN...,[]
1,967,MKSSVAQIKPSSGHDRRENLNSYQRNSSPEDRYEEQERSPRDRDYF...,"[osteosarcoma, ovarian cancer]"
2,4363,MAVAIAAARVWRLNRGLSQAALLLLRQPGARGLARSHPPRQQQQFS...,"[maple syrup urine disease, classic maple syru..."
3,1333,MATSSEEVLLIVKKVRQKKQDGALYLMAERIAWAPEGKDRFTISHM...,"[attention deficit hyperactivity disorder, ina..."
4,871,MACTIQKAEALDGAHLMQILWYDEEESLYPAVWLRDNCPCSDCYLD...,"[astrocytic tumor, breast carcinoma, glioblast..."


In [180]:
# convert test_df to a csv file
test_df.to_csv("esm_test_data.csv", index=False)

## Save Reverse Vocab

In [186]:
import json

def create_reverse_vocab(annotations_vocab):
    reverse_vocab = {idx: label for label, idx in annotations_vocab.items()}
    
    # Save reverse_vocab to a JSON file
    with open('esm_reverse_vocab.json', 'w') as f:
        json.dump(reverse_vocab, f)

create_reverse_vocab(annotations_vocab)

## Convert model to ONNX format

In [181]:
import onnxruntime as ort
import numpy as np

# Extract a single sample from the test loader
embeddings, labels = next(iter(test_loader))
sample_embedding = embeddings[0]
onnx_program = torch.onnx.dynamo_export(best_model, sample_embedding)
onnx_program.save("esm_model.onnx")



In [182]:
import onnxruntime as ort
import numpy as np
import torch

# Load the model
session = ort.InferenceSession("esm_model.onnx")

# Get the name of the model input
input_name = session.get_inputs()[0].name
idx = 6

# Convert PyTorch tensor to numpy array if needed
if isinstance(embeddings[1], torch.Tensor):
    input_tensor = embeddings[idx].numpy()
else:
    input_tensor = embeddings[idx]

# Check if input_tensor is already 1-dimensional
if len(input_tensor.shape) > 1:
    # If not, we flatten the tensor to make it 1-dimensional
    input_tensor = input_tensor.flatten()

# Prepare the input dictionary
input_dict = {input_name: input_tensor}

# Run the model and get the output
output = session.run(None, input_dict)

# Assuming output[0] is the logits and your model outputs logits directly without activation
# Assuming output[0] is the logits and your model outputs logits directly without activation
logits = output[0]  # Adjust if your model outputs are structured differently
probabilities = torch.sigmoid(torch.from_numpy(logits))  # Applying sigmoid since it's a multi-label classification
rounded_predictions = probabilities.round().numpy()  # Rounding the probabilities to get binary predictions

# Load the reverse vocabulary and convert predictions to labels
reverse_vocab = create_reverse_vocab(annotations_vocab)
predicted_labels = convert_to_labels(rounded_predictions, reverse_vocab)

print("Predicted labels:", predicted_labels)
len(predicted_labels)

Predicted labels: ['leukoencephalopathy, progressive, with ovarian failure', 'keratoconus', 'pancreatic agenesis', 'alcohol-related disorders', 'cystinuria', 'spondylolisthesis', '46,XX sex reversal 1', 'cocaine abuse', 'vertebral, cardiac, renal, and limb defects syndrome 2', 'gastric adenocarcinoma', 'epidermolysis bullosa simplex with circinate migratory erythema', 'Diamond-Blackfan anemia', 'myeloid leukemia', 'classic phenylketonuria', 'thyrotoxicosis', 'arrhythmogenic right ventricular cardiomyopathy', 'intellectual disability, X-linked 97', 'pulmonary fibrosis', 'autosomal dominant polycystic kidney disease', 'diverticulitis', 'familial hypertryptophanemia', 'neuroblastoma', 'familial schizencephaly', 'granulomatous disease, chronic, X-linked', 'hereditary spherocytosis type 5', 'carcinosarcoma', 'fatal mitochondrial disease due to combined oxidative phosphorylation defect type 3', 'renal tubular dysgenesis', 'autosomal recessive cutis laxa type 2C', 'Helicobacter pylori infecti

4549

In [183]:
# get true labels from embeddings[1]
true_labels = [convert_to_labels(labels[idx].numpy(), reverse_vocab)]
print("True labels:", true_labels[0])
len(true_labels[0])

True labels: ['breast cancer', 'pancreatic ductal adenocarcinoma', 'juvenile dermatomyositis', 'plasma cell myeloma', 'Duchenne muscular dystrophy', 'swine influenza', 'chronic kidney disease', 'medulloblastoma', 'osteosarcoma', 'leukemia, acute lymphocytic, susceptibility to, 1', 'rheumatoid arthritis', 'ovarian cancer', 'intellectual disability, autosomal dominant 42', 'depressive disorder', 'endogenous depression', 'neurotic disorder', 'myelodysplastic syndrome', 'melancholia', 'intellectual disability', 'anxiety disorder', 'malignant pancreatic neoplasm', 'lung adenocarcinoma', 'Sjogren syndrome', 'subependymal giant cell astrocytoma', 'myopathy', 'childhood acute lymphoblastic leukemia']


26

In [184]:
# get the intersection and difference between the true and predicted labels
intersection, true_count, pred_count, incorrect = intersection_and_difference(true_labels[0], predicted_labels)
print(f"Intersection: {intersection}, Number of Correct: {true_count}, Total Tags: {pred_count}, Number of Incorrect: {incorrect}")

Intersection: 26, Number of Correct: 26, Total Tags: 4549, Number of Incorrect: 4523
