# Enhanced patient trial matching

In [1]:
import os
from pathlib import Path


# Get project directory
def get_project_dir():
    """Return the project checkout location under the user's home directory.

    Returns:
        str: ``<home>/github/llm-drug-discovery``, where ``<home>`` is the
        expanded ``~`` of the current user.
    """
    home = Path(os.path.expanduser("~"))
    return "/".join([str(home), "github/llm-drug-discovery"])


# Resolve the project root once; later cells build their data and plot paths from it.
project_dir = get_project_dir()
project_dir  # bare last expression so the notebook displays the resolved path

'/home/mgustineli/github/llm-drug-discovery'

In [2]:
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import seaborn as sns

from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.metrics import precision_recall_fscore_support
from transformers import AutoTokenizer, AutoModel

from dotenv import load_dotenv
from huggingface_hub import login

# Where figures produced by this notebook are written
PLOT_DIR = project_dir + "/plots"

# Authenticate against the Hugging Face Hub with the token stored in `.env`
load_dotenv()
hf_token = os.environ.get("HF_TOKEN")
if hf_token is None or hf_token == "":
    raise ValueError("HF_TOKEN not found in environment (.env).")
login(token=hf_token)

# Fix the global seed so repeated runs are comparable
pl.seed_everything(42, workers=True)

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.
Seed set to 42


42

In [3]:
def read_json(path):
    """Load and return the parsed contents of the JSON file at ``path``.

    The file is opened explicitly as UTF-8 so that non-ASCII characters are
    decoded the same way regardless of the platform's default encoding.

    Args:
        path: Path (str or path-like) of the JSON file to read.

    Returns:
        The deserialized JSON value (whatever ``json.load`` produces).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


# Patient and trial records used for training.
train_patient_data = read_json(f"{project_dir}/data/train_patients.json")
train_trial_data = read_json(f"{project_dir}/data/train_trials.json")
# NOTE(review): the "val" variables are read from the test_*.json files, so the
# same data drives validation/early stopping — confirm this is intended.
val_patient_data = read_json(f"{project_dir}/data/test_patients.json")
val_trial_data = read_json(f"{project_dir}/data/test_trials.json")

# Tokenizer and encoder are loaded from the same Bio_ClinicalBERT checkpoint.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
base_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

## Dataset

In [4]:
class PatientTrialDataset(Dataset):
    """Labelled patient/trial text pairs for a single clinical trial.

    Every patient whose ``id`` appears in ``trial["eligible_patients"]`` becomes
    a positive sample (label 1). Negatives (label 0) are drawn without
    replacement from the remaining patients using NumPy's global RNG; their
    count is ``int(n_positives * neg_ratio)``, capped at the number of
    non-eligible patients.

    Args:
        patient_data: Sequence of patient dicts with keys ``id``,
            ``demographics`` and ``medical_history``.
        trial: Trial dict with keys ``name``, ``type``, ``description``,
            ``eligibility_criteria`` (a dict holding ``text``) and
            ``eligible_patients``.
        tokenizer: Callable applied as ``tokenizer(patient_text, trial_text, ...)``
            that returns a mapping with ``input_ids`` and ``attention_mask``.
        max_length: Length every encoded pair is padded/truncated to.
        neg_ratio: Number of negatives per positive.
    """

    def __init__(self, patient_data, trial, tokenizer, max_length=512, neg_ratio=1.0):
        self.trial = trial
        self.tokenizer = tokenizer
        self.max_length = max_length

        eligible_ids = set(trial["eligible_patients"])
        positives = [p for p in patient_data if p["id"] in eligible_ids]
        negatives = [p for p in patient_data if p["id"] not in eligible_ids]
        n_neg = min(int(len(positives) * neg_ratio), len(negatives))

        # Sample *indices* instead of handing the list of dicts to NumPy: this
        # avoids coercing the patient records into an ndarray while drawing the
        # same selection from the same RNG stream.
        neg_idx = np.random.choice(len(negatives), n_neg, replace=False)
        self.samples = [(p, 1) for p in positives] + [
            (negatives[i], 0) for i in neg_idx
        ]

        # The trial side of every pair is identical, so build its text once
        # here rather than on each __getitem__ call.
        self._trial_text = (
            f"Trial: {trial['name']}. Type: {trial['type']}. "
            f"Description: {trial['description']}. "
            f"Eligibility Criteria: {trial['eligibility_criteria']['text']}"
        )

    def __len__(self):
        """Number of (patient, label) pairs held for this trial."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize the ``idx``-th pair.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (leading batch
            dimension of the tokenizer output squeezed away) and ``labels``
            (scalar ``torch.long`` tensor, 1 = eligible, 0 = not eligible).
        """
        patient, label = self.samples[idx]

        patient_text = (
            f"Demographics: {patient['demographics']}. "
            f"Medical History: {patient['medical_history']}"
        )

        enc = self.tokenizer(
            patient_text,
            self._trial_text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }


def make_dataloader(
    patients,
    trials,
    tokenizer,
    max_length=512,
    neg_ratio=1.0,
    batch_size=16,
    shuffle=False,
    num_workers=4,
):
    """Build a single DataLoader spanning the per-trial datasets of all ``trials``.

    One ``PatientTrialDataset`` is created per trial (all sharing ``patients``,
    ``tokenizer``, ``max_length`` and ``neg_ratio``); they are concatenated and
    wrapped in a ``DataLoader`` configured with ``batch_size``, ``shuffle`` and
    ``num_workers``.
    """
    per_trial = []
    for trial in trials:
        per_trial.append(
            PatientTrialDataset(
                patients, trial, tokenizer, max_length=max_length, neg_ratio=neg_ratio
            )
        )

    combined = torch.utils.data.ConcatDataset(per_trial)
    loader = DataLoader(
        combined,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return loader


# Training loader: shuffled, so batches are drawn across the concatenated
# per-trial datasets rather than trial by trial.
train_loader = make_dataloader(
    train_patient_data,
    train_trial_data,
    tokenizer=tokenizer,
    batch_size=16,
    shuffle=True,
    num_workers=1,
)

# Validation loader: built from the val_* data and iterated in a fixed order.
val_loader = make_dataloader(
    val_patient_data,
    val_trial_data,
    tokenizer=tokenizer,
    batch_size=16,
    shuffle=False,
    num_workers=1,
)

## Model

In [5]:
import torchmetrics


class PatientTrialModule(pl.LightningModule):
    """Two-class patient/trial match classifier: a BERT encoder plus an MLP head.

    Args:
        bert: Encoder module exposing ``config.hidden_size`` and returning an
            output with a ``pooler_output`` attribute.
        lr: Learning rate handed to the AdamW optimizer.
    """

    def __init__(self, bert, lr=2e-5):
        super().__init__()
        # `bert` is excluded from the recorded hyperparameters; `lr` is kept.
        self.save_hyperparameters(ignore=["bert"])
        self.bert = bert

        hidden_size = bert.config.hidden_size
        head_layers = [
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 2),
        ]
        self.classifier = nn.Sequential(*head_layers)
        self.acc = torchmetrics.Accuracy(task="binary", num_classes=2)

    def forward(self, ids, mask):
        """Return the classifier logits for the encoded pair(s)."""
        encoded = self.bert(ids, attention_mask=mask)
        return self.classifier(encoded.pooler_output)

    # shared by training and validation
    def step(self, batch):
        """Compute ``(loss, accuracy)`` for one batch."""
        targets = batch["labels"]
        logits = self(batch["input_ids"], batch["attention_mask"])
        loss = nn.functional.cross_entropy(logits, targets)
        acc = self.acc(logits.argmax(dim=1), targets)
        return loss, acc

    def training_step(self, batch, _):
        """Log training loss/accuracy and return the loss."""
        loss, acc = self.step(batch)
        train_metrics = {"train_loss": loss, "train_acc": acc}
        self.log_dict(train_metrics, prog_bar=True)
        return loss

    def validation_step(self, batch, _):
        """Log validation loss/accuracy and return both in a dict."""
        loss, acc = self.step(batch)
        val_metrics = {"val_loss": loss, "val_acc": acc}
        self.log_dict(val_metrics, prog_bar=True)
        return {"val_loss": loss, "val_acc": acc}

    # optimizer + LR schedule
    def configure_optimizers(self):
        """AdamW over all parameters, with the LR halved when ``val_loss`` plateaus."""
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.hparams.lr, weight_decay=0.01
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", patience=2, factor=0.5
        )
        lr_scheduler_cfg = {"scheduler": scheduler, "monitor": "val_loss"}
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_cfg}

In [None]:
# Train
# Wraps the pretrained encoder loaded earlier in the classification module.
model = PatientTrialModule(base_model)

trainer = pl.Trainer(
    max_epochs=5,
    # Stop early when the monitored val_loss stops improving (patience=2).
    callbacks=[EarlyStopping(monitor="val_loss", patience=2)],
    deterministic=True,
    log_every_n_steps=10,
    # Mixed precision is left disabled.
    # NOTE(review): the saved output of this cell still shows precision=16
    # warnings from an earlier CPU run — re-run the cell to refresh it.
    # precision=16,
)
trainer.fit(model, train_loader, val_loader)
print("Training complete.")

/home/mgustineli/github/llm-drug-discovery/.venv/lib/python3.10/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
/home/mgustineli/github/llm-drug-discovery/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/mgustineli/github/llm-drug-discovery/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:7

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/mgustineli/github/llm-drug-discovery/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.
/home/mgustineli/github/llm-drug-discovery/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=13` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

In [None]:
# Function to match patients with clinical trials
def match_patients_with_trials(
    model, patient_data, trial_data, tokenizer, threshold=0.5, max_length=512
):
    """Score every (patient, trial) pair and keep the pairs predicted as matches.

    The patient and trial texts are built with the same templates the training
    dataset uses, and are encoded through the tokenizer's ``__call__`` (the
    ``encode_plus`` method is deprecated in favour of it).

    Args:
        model: Trained ``torch.nn.Module`` called as ``model(input_ids,
            attention_mask)`` and returning logits with the "match" class at
            index 1. It is switched to eval mode.
        patient_data: Iterable of patient dicts with keys ``id``,
            ``demographics`` and ``medical_history``.
        trial_data: Iterable of trial dicts with keys ``id``, ``name``,
            ``type``, ``description`` and ``eligibility_criteria`` (a dict
            holding ``text``).
        tokenizer: Callable tokenizer accepting a text pair plus the usual
            padding/truncation keyword arguments.
        threshold: Minimum match probability for a pair to be reported.
        max_length: Length each encoded pair is padded/truncated to.

    Returns:
        dict mapping each trial id to a list of
        ``{"patient_id": ..., "match_probability": float}`` dicts, one per
        patient whose match probability is ``>= threshold`` (in the order the
        patients were given).
    """
    model.eval()

    # Inputs must live on the same device as the model's weights; a model
    # without parameters leaves the tensors where the tokenizer created them.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    matches = {}
    with torch.no_grad():
        for trial in trial_data:
            # The trial text does not depend on the patient: build it once per trial.
            trial_text = (
                f"Trial: {trial['name']}. Type: {trial['type']}. "
                f"Description: {trial['description']}. "
                f"Eligibility Criteria: {trial['eligibility_criteria']['text']}"
            )
            trial_matches = []

            for patient in patient_data:
                patient_text = (
                    f"Demographics: {patient['demographics']}. "
                    f"Medical History: {patient['medical_history']}"
                )

                encoding = tokenizer(
                    patient_text,
                    trial_text,
                    add_special_tokens=True,
                    max_length=max_length,
                    padding="max_length",
                    truncation=True,
                    return_attention_mask=True,
                    return_tensors="pt",
                )
                input_ids = encoding["input_ids"]
                attention_mask = encoding["attention_mask"]
                if device is not None:
                    input_ids = input_ids.to(device)
                    attention_mask = attention_mask.to(device)

                logits = model(input_ids, attention_mask)
                # Probability of class 1 (match) for the single pair in the batch.
                match_probability = torch.softmax(logits, dim=1)[0, 1].item()

                if match_probability >= threshold:
                    trial_matches.append(
                        {
                            "patient_id": patient["id"],
                            "match_probability": match_probability,
                        }
                    )

            matches[trial["id"]] = trial_matches

    return matches

In [None]:
# Function to evaluate matching performance
def evaluate_matching(matches, trial_data):
    """Compute per-trial precision/recall/F1 of predicted matches and print the averages.

    Conventions for degenerate trials: precision is 0.0 when nothing was
    predicted, recall is 1.0 when the trial has no eligible patients, and F1
    is 0.0 when precision + recall is 0.

    The printed "overall" figures are unweighted means of the per-trial values
    (NaN when ``trial_data`` is empty).

    Args:
        matches: dict mapping trial id to a list of dicts that each carry a
            ``patient_id`` key; trials missing from the dict count as having
            no predictions.
        trial_data: Iterable of trial dicts with keys ``id`` and
            ``eligible_patients``.

    Returns:
        pandas.DataFrame with one row per trial and the columns ``trial_id``,
        ``precision``, ``recall``, ``f1``, ``num_predicted``, ``num_actual``
        and ``num_correct``.
    """
    columns = [
        "trial_id",
        "precision",
        "recall",
        "f1",
        "num_predicted",
        "num_actual",
        "num_correct",
    ]
    rows = []

    for trial in trial_data:
        trial_id = trial["id"]
        actual = set(trial["eligible_patients"])
        predicted = {m["patient_id"] for m in matches.get(trial_id, [])}
        correct = actual & predicted

        # Float fallbacks keep the metric columns consistently float-typed.
        precision = len(correct) / len(predicted) if predicted else 0.0
        recall = len(correct) / len(actual) if actual else 1.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom > 0 else 0.0

        rows.append(
            {
                "trial_id": trial_id,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "num_predicted": len(predicted),
                "num_actual": len(actual),
                "num_correct": len(correct),
            }
        )

    def _macro_average(key):
        # Averaging an empty list would emit a RuntimeWarning; report NaN quietly.
        if not rows:
            return float("nan")
        return float(np.mean([row[key] for row in rows]))

    print(f"Overall Precision: {_macro_average('precision'):.4f}")
    print(f"Overall Recall: {_macro_average('recall'):.4f}")
    print(f"Overall F1 Score: {_macro_average('f1'):.4f}")

    return pd.DataFrame(rows, columns=columns)


# Plot training metrics.
# PatientTrialModule defines no `plot_metrics` method, so an unconditional call
# raises AttributeError; only call it when the model actually provides one.
if hasattr(model, "plot_metrics"):
    model.plot_metrics()

# Evaluate on the test set. The test_*.json files were loaded earlier into
# val_patient_data / val_trial_data; no `test_*` variables exist in this notebook.
print("Evaluating model on test set...")
matches = match_patients_with_trials(
    model, val_patient_data, val_trial_data, tokenizer
)
evaluation_df = evaluate_matching(matches, val_trial_data)

# Save evaluation results
evaluation_df.to_csv(f"{project_dir}/data/matching_evaluation.csv", index=False)

# Visualize evaluation results
os.makedirs(PLOT_DIR, exist_ok=True)  # savefig does not create missing directories
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x="trial_id", y="f1", data=evaluation_df, ax=ax)
ax.set_title("F1 Score by Trial")
ax.set_xlabel("Trial ID")
ax.set_ylabel("F1 Score")
ax.tick_params(axis="x", labelrotation=45)
fig.tight_layout()
fig.savefig(f"{PLOT_DIR}/f1_by_trial.png")

# Save the trained model (the `f` prefix belongs outside the quotes)
torch.save(model.state_dict(), f"{project_dir}/data/patient_trial_matching_model.pt")

print(
    f"Evaluation complete! Results saved to {project_dir}/data/matching_evaluation.csv"
)
print(f"Visualizations saved to {PLOT_DIR}/")