In [1]:
"""
@Author: Magnus Graham
1/3/2025

This notebook matches free text to a set of predefined symptoms.
It uses spaCy to preprocess the text, and uses BioBERT and SBERT
to map its meaning to the closest possible match in the symptoms list.

"""

'\n@Author: Magnus Graham\n1/3/2025\n\nThis notebook matches free text to a set of predefined symptoms.\nIt uses spaCy to preprocess the text, and uses BioBERT and SBERT\nto map its meaning to the closest possible match in the symptoms list.\n\n'

In [None]:
!pip install sentence-transformers
!pip install torch
!pip install spacy
!python -m spacy download en_core_web_sm  # if not already installed



In [None]:
import spacy
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
from sentence_transformers import SentenceTransformer

# Sentence-BERT encoder used to embed clauses and symptoms
model = SentenceTransformer('all-MiniLM-L6-v2')

# spaCy English pipeline used for tokenization and lemmatization
nlp = spacy.load("en_core_web_sm")

# BioBERT tokenizer fetched from the Hugging Face hub.
# NOTE(review): `tokenizer` and `BertModel` are not referenced by the other
# cells visible in this notebook -- confirm whether they are still needed.
tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-v1.1")


In [None]:
def normalize_text(text):
    """Return ``text`` lowercased, with leading/trailing whitespace removed.

    Punctuation and interior whitespace are left untouched.

    Args:
        text (str): Raw input text.

    Returns:
        str: The lowercased, stripped text.
    """
    lowered = text.lower()
    return lowered.strip()

def preprocess_input(user_input):
    """Normalize free text and split it into individual clauses.

    The text is normalized with ``normalize_text`` and then split on commas,
    periods and line breaks. Surrounding whitespace is stripped from each
    clause and empty clauses are discarded.

    Note that every period is a separator, so a decimal such as "38.5"
    is split as well.

    Args:
        user_input (str): Free text describing one or more symptoms.

    Returns:
        list of str: The non-empty clauses, in their original order.
    """
    normalized_input = normalize_text(user_input)

    # Treat periods and line breaks like commas so that multi-line or
    # multi-sentence input does not get glued into a single clause.
    for delimiter in (".", "\n"):
        normalized_input = normalized_input.replace(delimiter, ",")

    clauses = [clause.strip() for clause in normalized_input.split(",") if clause.strip()]
    print("processing")
    return clauses

def process_clauses(clauses, lemmatize=True):
    """Strip stop words and non-alphabetic tokens from each clause.

    Every clause is run through the module-level spaCy pipeline ``nlp``.
    Tokens that are alphabetic and not stop words are kept and joined
    with single spaces.

    Args:
        clauses (iterable of str): Clauses to clean.
        lemmatize (bool): When True, kept tokens are replaced by their
            lemma; when False, their original text is used.

    Returns:
        list of str: One cleaned string per input clause, in input order.
            A clause with no surviving tokens yields an empty string.
    """
    print("Processing clauses")

    def _render(token):
        # Surface form of a kept token, depending on the lemmatize flag
        return token.lemma_ if lemmatize else token.text

    cleaned = []
    for raw_clause in clauses:
        kept_tokens = (tok for tok in nlp(raw_clause) if tok.is_alpha and not tok.is_stop)
        cleaned.append(" ".join(_render(tok) for tok in kept_tokens))
    return cleaned



In [None]:
def get_bio_bert_embeddings(sentences):
    """
    Encode one or more sentences with the module-level ``model``.

    Despite the function name, the embeddings come from whatever encoder
    ``model`` refers to (its ``encode`` method is called directly).

    Args:
        sentences (str or list of str): A single sentence or a list of sentences.
            A bare string is wrapped in a one-element list.

    Returns:
        The result of ``model.encode(..., convert_to_tensor=True)``; per the
        original documentation, a torch.Tensor of shape (batch_size, hidden_size).
    """
    # A lone string is treated as a batch of one
    batch = [sentences] if isinstance(sentences, str) else sentences
    return model.encode(batch, convert_to_tensor=True)


In [None]:
import heapq
import torch.nn.functional as F

# Function to compute pairwise cosine similarity
def cosine_similarity_matrix(embeddings1, embeddings2):
    """
    Pairwise cosine similarity between the rows of two embedding matrices.

    Each row is scaled to unit L2 norm, after which a matrix product of the
    first set with the transpose of the second yields the cosine similarities.

    Args:
        embeddings1 (torch.Tensor): Tensor of shape (n1, hidden_size).
        embeddings2 (torch.Tensor): Tensor of shape (n2, hidden_size).

    Returns:
        torch.Tensor: Similarity matrix of shape (n1, n2); entry (i, j) compares
            row i of ``embeddings1`` with row j of ``embeddings2``.
    """
    unit_rows_1 = F.normalize(embeddings1, p=2, dim=1)
    unit_rows_2 = F.normalize(embeddings2, p=2, dim=1)
    return torch.mm(unit_rows_1, unit_rows_2.T)

# Function to find the most similar predefined symptoms for each clause
def compare_clauses_to_symptoms(clauses, predefined_symptoms, top_n=10):
    """
    Matches clauses to predefined symptoms using cosine similarity of SBERT embeddings.

    Both inputs are embedded with ``get_bio_bert_embeddings``, compared with
    ``cosine_similarity_matrix``, and the ``top_n`` highest-scoring symptoms
    are printed for every clause.

    Args:
        clauses (list of str): User input clauses to compare.
        predefined_symptoms (list of str): List of predefined symptoms.
        top_n (int): Number of most similar symptoms to return for each clause.
            Fewer are printed when there are fewer symptoms than ``top_n``.

    Returns:
        None: Prints the top N matches for each clause. When either input is
            empty, a short notice is printed and nothing is compared.
    """
    print("Comparing...")

    # Encoding an empty batch does not give a 2-D tensor, so the row-wise
    # normalization would fail; bail out early instead of crashing.
    if len(clauses) == 0 or len(predefined_symptoms) == 0:
        print("Nothing to compare: no clauses or no predefined symptoms were provided.")
        return

    # Get SBERT embeddings for clauses and symptoms
    clause_embeddings = get_bio_bert_embeddings(clauses)
    symptom_embeddings = get_bio_bert_embeddings(predefined_symptoms)

    # Row i holds the similarity of clause i to every symptom
    similarity_matrix = cosine_similarity_matrix(clause_embeddings, symptom_embeddings)

    for clause, scores in zip(clauses, similarity_matrix.tolist()):
        # Pair each symptom with its score and keep the N best-scoring pairs
        top_similar_symptoms = heapq.nlargest(
            top_n, zip(predefined_symptoms, scores), key=lambda pair: pair[1]
        )

        print(f"Clause: '{clause}'")
        for symptom, similarity in top_similar_symptoms:
            print(f"  - Symptom: '{symptom}' - Similarity: {similarity:.2f}")



In [None]:
import pandas as pd

# Load the CSV file into a DataFrame.
# NOTE: this is an absolute, machine-specific path -- point it at your local copy of the CSV.
file_path = "/Users/magnusgraham/Downloads/disease-symptom/DiseaseAndSymptoms.csv"
df = pd.read_csv(file_path)

# Create a set to store all unique symptom values
predefined_symptoms = set()
disease_names = set()

for col in df.columns[1:19]:  # Adjust indices if needed
    # dropna() keeps empty cells from being added as a bogus "nan" symptom
    for value in df[col].dropna().unique():
        # strip() guards against padded values producing near-duplicate symptoms
        symptom = str(value).replace("_", " ").strip()
        if symptom:
            predefined_symptoms.add(symptom)
# Sort so the symptom order (and tie-breaking in the ranking) is reproducible across runs
predefined_symptoms = sorted(predefined_symptoms)

# Add user-entered symptoms
user_input = """chills, dehydration, fatigue, fever, flushing, loss of appetite, body ache, or sweating
Nasal: congestion, runny nose, or sneezing
Also common: chest pressure, head congestion, headache, nausea, shortness of breath, sore throat, or swollen lymph nodes"""
# Step 1: Preprocess the input
clauses = preprocess_input(user_input)


# Step 2: Tokenize and lemmatize the clauses.
# Entries made up only of stop words / non-alphabetic tokens come back as "",
# so they are dropped rather than embedded and ranked.
processed_clauses = [clause for clause in process_clauses(clauses) if clause]
processed_symptoms = [symptom for symptom in process_clauses(predefined_symptoms) if symptom]

# Step 3: Compare each clause to the predefined symptoms using SBERT embeddings
compare_clauses_to_symptoms(processed_clauses, processed_symptoms)