## 0. Libraries 📚

In [None]:
import pandas as pd
import ast

## 1. Load data 📥

In [None]:
def _parse_list_cell(cell):
    """Parse a cell holding a Python-literal string; any non-string cell becomes ``[]``."""
    if isinstance(cell, str):
        return ast.literal_eval(cell)
    return []


# Read the ground-truth table, then turn the two columns that are stored as
# stringified Python literals back into Python objects.
diagnoses_df = pd.read_csv("data/ground_truth_df.csv")
for list_column in ("Codigos_diagnosticos", "Diagnosticos_estandar"):
    diagnoses_df[list_column] = diagnoses_df[list_column].map(_parse_list_cell)
diagnoses_df

## 2. Similarity over embeddings

In [None]:
# Pick one row by its index label and show its free-text description next to
# the codes and standard diagnoses recorded for it.
index = 1
diagnosis_description = diagnoses_df.loc[index, "Descripcion_diagnosticos_limpio"]
true_diagnosis_standard = diagnoses_df.loc[index, "Diagnosticos_estandar"]

print(
    f"Diagnosis description: {diagnosis_description}",
    f"Diagnosis codes: {diagnoses_df.loc[index, 'Codigos_diagnosticos']}",
    f"Standard diagnoses: {true_diagnosis_standard}",
    sep="\n",
)

In [None]:
# Candidate label set: every distinct standard diagnosis found in the dataset.
# dict.fromkeys de-duplicates while keeping first-seen order, so the list (and
# the row order of any embedding matrix built from it) is the same on every
# run. Iterating a set of strings is not reproducible across interpreter
# sessions because string hashing is randomized by default.
diagnoses_list = list(
    dict.fromkeys(
        diagnosis
        for row_diagnoses in diagnoses_df["Diagnosticos_estandar"]
        for diagnosis in row_diagnoses
    )
)

In [None]:
# Checkpoint identifier passed to the tokenizer/model `from_pretrained` calls.
# The commented-out lines are alternative checkpoints; keep exactly one
# `model_name` assignment active at a time.
# model_name = "dmis-lab/biobert-v1.1"
# model_name = "medicalai/ClinicalBERT"
# model_name = "bionlp/bluebert_pubmed_uncased_L-12_H-768_A-12"
# model_name = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" ### PubMedBERT
# model_name = "PlanTL-GOB-ES/bsc-bio-es"
# model_name = "yikuan8/Clinical-Longformer"
# model_name = "yikuan8/Clinical-BigBird"
# model_name = "PlanTL-GOB-ES/roberta-base-biomedical-clinical-es"
# model_name = "PlanTL-GOB-ES/roberta-base-biomedical-es"
model_name = "intfloat/multilingual-e5-large"

In [None]:
from transformers import AutoConfig, AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the tokenizer and the encoder for the selected checkpoint, then move
# the encoder onto the chosen device.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model = model.to(device)

# Truncation length: the smaller of the tokenizer's declared maximum length
# and the `max_position_embeddings` value in the checkpoint's config.
max_tokens = min(
    tokenizer.model_max_length,
    AutoConfig.from_pretrained(model_name).max_position_embeddings,
)

def get_embedding(text, max_len=512, *, pooling="cls"):
    """Encode ``text`` into one embedding vector per input.

    Uses the module-level ``tokenizer``, ``model`` and ``device``. Inference
    runs under ``torch.no_grad()`` and the result is moved to the CPU.

    Args:
        text: Input handed unchanged to the tokenizer.
        max_len: Maximum tokenized length; longer inputs are truncated.
        pooling: How the per-token hidden states are reduced to one vector.
            ``"cls"`` (default, the previous behaviour) takes the hidden state
            at sequence position 0. ``"mean"`` averages the hidden states of
            the positions marked by the tokenizer's attention mask, so padding
            does not contribute.

    Returns:
        A 2-D CPU tensor with one row per encoded input.

    Raises:
        ValueError: If ``pooling`` is neither ``"cls"`` nor ``"mean"``.

    NOTE(review): the model card of the E5 checkpoint selected in this
    notebook describes average pooling and "query: " / "passage: " input
    prefixes -- confirm whether position-0 pooling without prefixes is the
    intended setup for that model.
    """
    if pooling not in ("cls", "mean"):
        raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_len,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state

    if pooling == "cls":
        # Hidden state of the first token as the sentence representation.
        pooled = hidden[:, 0, :]
    else:
        # Masked mean: zero out padded positions, then divide by the number of
        # attended tokens (clamped so an all-zero mask cannot divide by zero).
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return pooled.cpu()

# Embed every candidate diagnosis and concatenate the resulting rows, in the
# same order as `diagnoses_list`.
emb_diagnosticos = torch.cat(
    [get_embedding(candidate, max_tokens) for candidate in diagnoses_list],
    dim=0,
)

In [None]:
# Embed the selected description and score it against every candidate.
emb_desc = get_embedding(diagnosis_description, max_tokens)
cos_sim = F.cosine_similarity(emb_desc, emb_diagnosticos)

# Report the best-scoring candidate and its score.
top_idx = cos_sim.argmax().item()
print("Closest diagnosis:", diagnoses_list[top_idx])
print("Similarity:", cos_sim[top_idx].item())

# Score the same description against its own ground-truth diagnoses.
true_diagnosis_embeddings = torch.cat(
    [get_embedding(label, max_tokens) for label in true_diagnosis_standard],
    dim=0,
)
true_cos_sim = F.cosine_similarity(emb_desc, true_diagnosis_embeddings)

# One line per ground-truth diagnosis.
for diagnosis, score in zip(true_diagnosis_standard, true_cos_sim.tolist()):
    print(f"Similarity with '{diagnosis}': {score}")

In [None]:
from tqdm import tqdm

# Top-1 accuracy: a row counts as correct when the candidate with the highest
# cosine similarity to its description is one of that row's standard diagnoses.
correct_count = 0

for idx, row in tqdm(diagnoses_df.iterrows(), total=diagnoses_df.shape[0]):
    diagnosis_description = row["Descripcion_diagnosticos_limpio"]
    true_diagnosis_standard = row["Diagnosticos_estandar"]

    # Score this description against the precomputed candidate embeddings.
    emb_desc = get_embedding(diagnosis_description, max_tokens)
    cos_sim = F.cosine_similarity(emb_desc, emb_diagnosticos)

    # Nearest candidate is the prediction.
    top_idx = cos_sim.argmax().item()
    predicted_diagnosis = diagnoses_list[top_idx]

    # A hit adds 1, a miss adds 0.
    correct_count += int(predicted_diagnosis in true_diagnosis_standard)

# Fraction of rows whose prediction was a hit.
accuracy = correct_count / diagnoses_df.shape[0]
print(f"Accuracy: {accuracy:.4f}")

In [None]:
# Debug output for whatever these names currently hold; when the cells are run
# in order, that is the last row handled by the evaluation loop.
# The f-string `=` specifier prints the expression text followed by its value,
# so the expressions below are part of the output and are left untouched.
print(f"{diagnosis_description=}")
print(f"{true_diagnosis_standard=}")
# Similarity scores against all candidates, highest first.
print(f"{cos_sim.sort(descending=True).values=}")