# LLM Review of Disease-Centric Splits

In [1]:
import logging
import os
import shutil

import numpy as np
import pandas as pd
from openai import AzureOpenAI
from tqdm import tqdm

from src.config import conf

_logger = logging.getLogger(__name__)

Read in disease splits.

In [2]:
nodes = pd.read_csv(conf.paths.kg.nodes_path, dtype={"node_index": int}, low_memory=False)
edges = pd.read_csv(
    conf.paths.kg.edges_path, dtype={"edge_index": int, "x_index": int, "y_index": int}, low_memory=False
)
disease_splits = pd.read_csv(conf.paths.splits_dir / "disease_splits.csv")

Set up the GPT-4 client.

In [5]:
client = AzureOpenAI(
    azure_endpoint=conf.AZURE_OPENAI_ENDPOINT,
    api_key=conf.AZURE_OPENAI_API_KEY,
    api_version="2024-05-01-preview",
)

Use GPT-4 to evaluate disease splits.

In [None]:
gpt_ranks = []
tokens_used = []

for _i, row in tqdm(disease_splits.iterrows(), desc="GPT-4o evaluation of disease splits", total=len(disease_splits)):
    # Construct message
    split_disease = nodes[nodes["node_index"] == row["disease_split_index"]]["node_name"].values[0]
    candidate_disease = row["node_name"]
    split_disease = split_disease.replace("(disease)", "").strip()
    candidate_disease = candidate_disease.replace("(disease)", "").strip()

    message = [
        {
            "role": "system",
            "content": "You are a helpful biomedical expert with an understanding of disease mechanisms, treatment options for every disease, and deep clinical knowledge of disease symptoms, phenotypes, genotypes, and drug treatments.",
        },
        {
            "role": "user",
            "content": f"Rank on a scale from 1 to 5 how closely related {split_disease} and {candidate_disease} are. 1 is not related at all, 4 is that that they are closely related (e.g., a drug that treats {split_disease} could also treat {candidate_disease}), 5 is that they are the same disease or subtypes of the same disease. Respond with a number from 1-5 only, no other text.",
        },
    ]

    # Submit query for inference
    # Approx. 82 tokens per completion
    response = client.chat.completions.create(
        model="gpt-4o-1120",
        messages=message,
        temperature=0.5,
        max_tokens=1,
    )

    # Append to list
    gpt_ranks.append(response.choices[0].message.content)
    tokens_used.append(response.usage.total_tokens)

# Add to disease splits
disease_splits["gpt_rank"] = gpt_ranks
disease_splits["tokens_used"] = tokens_used

GPT-4o evaluation of disease splits:   0%|                                                                                                                            | 0/1406 [00:00<?, ?it/s]

In [None]:
disease_splits["gpt_rank"] = disease_splits["gpt_rank"].astype(int)
_logger.info(disease_splits["gpt_rank"].value_counts())

1    561
5    387
4    163
2    160
3    135
Name: gpt_rank, dtype: int64


Set self comparisons to `Yes`. Set all non-`Yes`/`No` comparisons to `No`.

In [None]:
# Construct gpt_eval as 'Yes' if gpt_rank >= 3, 'No' if gpt_rank < 3
disease_splits["gpt_eval"] = np.where(disease_splits["gpt_rank"] >= 3, "Yes", "No")

# Set self comparisons to 'Yes'
disease_splits.loc[disease_splits["node_index"] == disease_splits["disease_split_index"], "gpt_eval"] = "Yes"
_logger.info(disease_splits["gpt_eval"].value_counts())

No     721
Yes    685
Name: gpt_eval, dtype: int64


Compute total tokens used.

In [None]:
_logger.info("Total tokens used:", sum(disease_splits["tokens_used"]))

Total tokens used: 203471


Save file to CSV.

In [None]:
disease_splits.to_csv(conf.paths.splits_dir / "disease_splits_GPT.csv", index=False)

## Save Disease Splits

Save each split to its own file.

In [None]:
split_dir = conf.paths.splits_dir / "split_edges_GPT"
if os.path.isdir(split_dir):
    shutil.rmtree(split_dir)
os.mkdir(split_dir)

In [None]:
# Get drug_disease_edges
drug_disease_edges = edges[(edges["x_type"] == "disease") & (edges["y_type"] == "drug")]

# Filter to GPT-4 evaluations of 'Yes'
disease_splits_filtered = disease_splits[disease_splits["gpt_eval"] == "Yes"]
disease_splits_grouped = disease_splits_filtered.groupby("disease_split_index")
edge_count = {}

for disease_split, disease_split_df in tqdm(disease_splits_grouped, desc="Save splits"):
    
    # Get indication edges
    disease_split_edges = drug_disease_edges[drug_disease_edges["x_index"].isin(disease_split_df["node_index"])]
    disease_split_edges = disease_split_edges.reset_index(drop=True)

    # If some edges exist
    if len(disease_split_edges) > 0:

        # Save to CSV
        disease_split_edges.to_csv(split_dir / f"{disease_split}.csv", index=False, encoding="utf-8-sig")
        edge_count[disease_split] = len(disease_split_edges)

    else:

        # Drop from disease_splits
        disease_splits_filtered = disease_splits_filtered[
            ~disease_splits_filtered["disease_split_index"].isin(disease_split_df["node_index"])
        ]
        tqdm.write(f"Removed split {disease_split} ({disease_split_df['disease_split'].values[0]}) as it has no edges.")

# Save all disease splits
all_split_edges = drug_disease_edges[drug_disease_edges["x_index"].isin(disease_splits_filtered["node_index"].unique())]
all_split_edges.to_csv(split_dir / "all.csv", index=False, encoding="utf-8-sig")