# Train Mondo Annotations

In [6]:
from transformers import AutoTokenizer, AutoModel
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from fetch import *
import pandas as pd

## Load model and tokenizer

In [2]:
def split_data(embeddings, labels):
    # Split data into 80% training and 20% test
    embeddings_train, embeddings_temp, labels_train, labels_temp = train_test_split(
        embeddings, labels, test_size=0.20, random_state=42)

    # Split the 20% test into 15% validation and 5% test
    embeddings_val, embeddings_test, labels_val, labels_test = train_test_split(
        embeddings_temp, labels_temp, test_size=0.25, random_state=42)  # 0.25 * 0.20 = 0.05

    return embeddings_train, labels_train, embeddings_val, labels_val, embeddings_test, labels_test

## Prepare dataset

In [3]:
class ProteinDataset(Dataset):
    def __init__(self, embeddings, labels):
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]

## Create a classifier

In [4]:
class NaiveProteinTaggingModel(nn.Module):
    def __init__(self, embedding_size, num_labels, dropout_rate=0.1):
        super(NaiveProteinTaggingModel, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(embedding_size, 128)  # Reduce the size and complexity
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(128, num_labels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, embeddings):
        x = self.dropout(embeddings)
        x = self.relu(self.fc1(x))
        logits = self.classifier(x)
        predictions = self.sigmoid(logits)
        return predictions

## Train model

In [7]:
# load dataset
embedding_type = 'func_embedding'
embeddings, labels, annotations_vocab = fetch_data_multi(embedding_type, include_empty=True)

# split data
embeddings_train, labels_train, embeddings_val, labels_val, embeddings_test, labels_test = split_data(embeddings, labels)

# dataloaders
train_dataset = ProteinDataset(embeddings_train, labels_train)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

val_dataset = ProteinDataset(embeddings_val, labels_val)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [8]:
# Setup DataLoader, Model, Loss, and Optimizer
dataset = ProteinDataset(embeddings, labels)

model = NaiveProteinTaggingModel(embedding_size=embeddings.shape[1], num_labels=labels.shape[1])

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs):
    for epoch in range(epochs):
        model.train()
        total_train_loss = 0
        for embeddings, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(embeddings)
            loss = criterion(outputs, labels.float())
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for embeddings, labels in val_loader:
                outputs = model(embeddings)
                loss = criterion(outputs, labels.float())
                total_val_loss += loss.item()

        print(f'Epoch {epoch+1}, Training Loss: {total_train_loss / len(train_loader)}, Validation Loss: {total_val_loss / len(val_loader)}')

In [9]:
# Start training
train_and_validate(model, train_loader, val_loader, criterion, optimizer, epochs=60)

Epoch 1, Training Loss: 0.038854435833393416, Validation Loss: 0.01694928246568765
Epoch 2, Training Loss: 0.01562373916877999, Validation Loss: 0.01399101394641873
Epoch 3, Training Loss: 0.013394658546431444, Validation Loss: 0.012492694645281265
Epoch 4, Training Loss: 0.012103700135783548, Validation Loss: 0.011425377122613769
Epoch 5, Training Loss: 0.011068092218938456, Validation Loss: 0.010505832605285953
Epoch 6, Training Loss: 0.010205182747275972, Validation Loss: 0.009747958455284108
Epoch 7, Training Loss: 0.009493617834384316, Validation Loss: 0.009107810102822567
Epoch 8, Training Loss: 0.008886558259721494, Validation Loss: 0.008565446969417745
Epoch 9, Training Loss: 0.008373719731586711, Validation Loss: 0.008108226249352197
Epoch 10, Training Loss: 0.00793885933733383, Validation Loss: 0.007712452212928344
Epoch 11, Training Loss: 0.007574724283680987, Validation Loss: 0.007346264058909076
Epoch 12, Training Loss: 0.007254518706939966, Validation Loss: 0.007080345911

## Evaluation

In [10]:
def evaluate_model(model, test_loader):
    model.eval()
    predictions = []
    truths = []
    with torch.no_grad():
        for embeddings, labels in test_loader:
            outputs = model(embeddings)
            predicted = torch.round(outputs)
            predictions.extend(predicted.cpu().numpy())
            truths.extend(labels.cpu().numpy())

    accuracy = accuracy_score(truths, predictions)
    precision = precision_score(truths, predictions, average='macro', zero_division=1)
    recall = recall_score(truths, predictions, average='macro', zero_division=1)
    f1 = f1_score(truths, predictions, average='macro', zero_division=1)

    return accuracy, precision, recall, f1

test_dataset = ProteinDataset(embeddings_test, labels_test)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
accuracy, precision, recall, f1 = evaluate_model(model, test_loader)
print(f'Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1 Score: {f1}')

Accuracy: 0.32275132275132273, Precision: 0.9725407783723214, Recall: 0.9386037764766582, F1 Score: 0.9305249751181706


## Inference

In [11]:
def create_reverse_vocab(annotations_vocab):
    return {idx: label for label, idx in annotations_vocab.items()}

def convert_to_labels(binary_vector, reverse_vocab):
    labels = [reverse_vocab[i] for i, value in enumerate(binary_vector) if value == 1]
    return labels

def intersection_and_difference(true_labels, predicted_labels):
    # Calculate intersection and differences of two label lists
    true_set = set(true_labels)
    predicted_set = set(predicted_labels)
    intersection = true_set & predicted_set
    incorrect = predicted_set - true_set
    return len(intersection), len(true_set), len(predicted_set), len(incorrect)

import torch

def predict_and_evaluate(model, test_loader, annotations_vocab, num_samples=10):
    model.eval()
    total_intersection = 0
    total_true_labels = 0
    total_predicted_labels = 0
    total_incorrect = 0
    num_zero_prediction = 0  # To track samples with zero predicted labels
    sample_data = []
    reverse_vocab = create_reverse_vocab(annotations_vocab)

    with torch.no_grad():
        for embeddings, labels in test_loader:
            outputs = model(embeddings)
            predicted = torch.round(outputs)  # Using 0.5 as a threshold
            # Convert binary vectors to MONDO names
            predicted_labels = [convert_to_labels(pred, reverse_vocab) for pred in predicted.cpu().numpy()]
            true_labels = [convert_to_labels(true, reverse_vocab) for true in labels.cpu().numpy()]
            
            for true, pred in zip(true_labels, predicted_labels):
                inter, true_count, pred_count, incorrect = intersection_and_difference(true, pred)
                total_intersection += inter
                total_true_labels += true_count
                total_predicted_labels += pred_count
                total_incorrect += incorrect
                if len(pred) == 0:
                    num_zero_prediction += 1  # Increment if no labels were predicted
            
            # Collect samples for display
            if len(sample_data) < num_samples:
                sample_data.extend(zip(embeddings, true_labels, predicted_labels))
            if len(sample_data) >= num_samples:
                break
    
    # Calculate averages
    num_samples = len(test_loader.dataset)
    average_correct = (total_intersection / num_samples) * 100
    average_incorrect = (total_incorrect / num_samples) * 100
    print(f"Average percentage of correct MONDO names per sample: {average_correct:.2f}%")
    print(f"Average percentage of incorrect MONDO names per sample: {average_incorrect:.2f}%")
    print(f"Number of samples with zero predicted labels: {num_zero_prediction}")

    return sample_data, num_zero_prediction

# Example usage assuming model, test_loader, and annotations_vocab are defined
sample_data, num_zero_prediction = predict_and_evaluate(model, test_loader, annotations_vocab, num_samples=10)

# Print the samples and the total number of zero predicted labels
for idx, (embedding, true_label, predicted_label) in enumerate(sample_data):
    print(f"Sample {idx+1}")
    print(f"True MONDO Names: {true_label}")
    print(f"Predicted MONDO Names: {predicted_label}\n")
print(f"Total samples with zero predictions: {num_zero_prediction}")

Average percentage of correct MONDO names per sample: 40.87%
Average percentage of incorrect MONDO names per sample: 2.55%
Number of samples with zero predicted labels: 9
Sample 1
True MONDO Names: ['ovarian cancer', 'autism spectrum disorder', 'lung carcinoma', 'gastroesophageal reflux disease', 'lung cancer', 'endogenous depression', 'osteosarcoma', 'bone Paget disease', 'cutaneous lupus erythematosus', 'Crohn disease', 'squamous cell lung carcinoma', 'ulcerative colitis', 'Sjogren syndrome', 'osteoporosis', 'pilocytic astrocytoma', 'attention deficit hyperactivity disorder, inattentive type', 'postmenopausal osteoporosis', 'chronic obstructive pulmonary disease', 'dental caries']
Predicted MONDO Names: ['gastroesophageal reflux disease', 'psoriasis', 'bone Paget disease', 'cutaneous lupus erythematosus', 'Crohn disease', 'squamous cell lung carcinoma', 'ulcerative colitis', 'osteoporosis', 'postmenopausal osteoporosis', 'dental caries']

Sample 2
True MONDO Names: ['large cell medul