In [1]:
# --- Setup and Imports ---
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import re # Still needed for splitting phases and basic cleaning
import math # For isnan

# --- Configuration ---

# Excel workbook holding the trial records (rename here if your file differs).
XLSX_FILENAME = 'trials.xlsx'

# The only column that is embedded and compared semantically with the query.
SEMANTIC_MATCH_COLUMN = 'Conditions'

# Structured filter: the primary outcome text must contain this term.
FILTER_PRIMARY_OUTCOME_COLUMN = 'Primary Outcome Measures'
FILTER_PRIMARY_OUTCOME_TERM = 'Overall Survival'

# Structured filter: the study type must equal this value.
FILTER_STUDY_TYPE_COLUMN = 'Study Type'
FILTER_STUDY_TYPE_VALUE = 'INTERVENTIONAL'

# Structured filter: phase labels (or label combinations) that are acceptable.
FILTER_PHASES_COLUMN = 'Phases'
ACCEPTABLE_PHASES_STR = ['PHASE1|PHASE2', 'PHASE2', 'PHASE2|PHASE3', 'PHASE3', 'PHASE4']

# Every single phase that occurs anywhere in ACCEPTABLE_PHASES_STR, obtained by
# splitting each entry on '|' or '/'. Trials are matched against this set one
# phase at a time, so e.g. PHASE2 is recognised inside 'PHASE2|PHASE3'.
# NOTE(review): splitting 'PHASE1|PHASE2' also puts bare 'PHASE1' in this set,
# although 'PHASE1' alone is not listed above - confirm that this is intended.
ACCEPTABLE_INDIVIDUAL_PHASES = {
    single_phase.strip().upper()
    for combined_label in ACCEPTABLE_PHASES_STR
    for single_phase in re.split(r'[|/]+', combined_label)
}

# Minimum cosine similarity (inclusive) between the query and a trial's
# Conditions text for the trial to be reported. Example value - tune it for
# your data and the strictness you want.
RELEVANCE_SCORE_THRESHOLD = 0.35


# --- Helper Functions ---

# Minimal text cleaning: lowercase and strip whitespace
def clean_text(text):
    """Normalise a text value for matching.

    Runs of whitespace are collapsed to a single space, leading and trailing
    whitespace is removed, and the result is lowercased.

    Args:
        text: The value to normalise.

    Returns:
        The normalised string, or '' when ``text`` is not a ``str``
        (for example None or a float NaN).
    """
    if not isinstance(text, str):
        return ''
    single_spaced = re.sub(r'\s+', ' ', text)
    return single_spaced.strip().lower()


# Helper for sorting with None/NaN values (None/NaN goes to the end)
def sort_key_with_none(value, reverse=True):
    """Return a sort key in which None and float NaN become an infinity.

    Args:
        value: The value to be used as a sort key.
        reverse: When true the placeholder is -inf, otherwise +inf. The
            intent is that missing values land at the end of a sort that is
            run with the same ``reverse`` flag.

    Returns:
        ``value`` itself, or the infinity placeholder when ``value`` is None
        or a float NaN.
    """
    is_missing = value is None or (isinstance(value, float) and math.isnan(value))
    if not is_missing:
        return value
    if reverse:
        return float('-inf')
    return float('inf')

# Function to check if a trial's phases match the acceptable list (checking individual phases)
def check_phases(trial_phases_raw):
    """Tell whether a trial's phase value contains an acceptable phase.

    The value is normalised with ``clean_text``, uppercased and split on '|',
    '/' or whitespace. The trial is accepted as soon as one of the resulting
    pieces is a member of ``ACCEPTABLE_INDIVIDUAL_PHASES``.

    Args:
        trial_phases_raw: Raw content of the trial's phases cell.

    Returns:
        True if at least one individual phase is acceptable; False otherwise,
        including when the value is not a string (e.g. NaN).
    """
    if not isinstance(trial_phases_raw, str):
        return False

    normalised = clean_text(trial_phases_raw).upper()
    pieces = re.split(r'[|\s/]+', normalised)
    # Splitting can leave empty pieces at the edges; the explicit length test
    # skips them before the membership check.
    return any(len(piece) > 0 and piece in ACCEPTABLE_INDIVIDUAL_PHASES for piece in pieces)

# Function to get a sort value for phases
def get_phase_sort_value(phases_raw):
    """Return a numeric rank for a trial's phase value (higher = later phase).

    The value is uppercased and split on '|', '/' or whitespace, and each
    piece is reduced to the phase name it contains (PHASE1..PHASE4). The
    resulting *set* of phases is then ranked, so the order of the phases and
    the separator used between them do not influence the result:

        PHASE4 -> 5, PHASE3 -> 4, PHASE2+PHASE3 -> 3, PHASE2 -> 2,
        PHASE1+PHASE2 -> 1, PHASE1 -> 0.5

    Any other combination is ranked by its highest single phase on that same
    scale, so an unlisted combination can never rank below or level with an
    earlier single phase merely because a different scale was applied.

    Args:
        phases_raw: Raw content of the trial's phases cell.

    Returns:
        The rank as a number; 0 when the value is not a string or contains
        no recognisable phase.
    """
    if not isinstance(phases_raw, str):
        return 0

    # Rank of a single phase. The key order matters: a piece is assigned the
    # first name it contains, checked from PHASE4 down to PHASE1.
    single_phase_rank = {'PHASE4': 5, 'PHASE3': 4, 'PHASE2': 2, 'PHASE1': 0.5}

    # Rank of the known combinations, keyed by the set of phases involved.
    combination_rank = {
        frozenset({'PHASE4'}): 5,
        frozenset({'PHASE3'}): 4,
        frozenset({'PHASE2', 'PHASE3'}): 3,
        frozenset({'PHASE2'}): 2,
        frozenset({'PHASE1', 'PHASE2'}): 1,
        frozenset({'PHASE1'}): 0.5,
    }

    phases_found = set()
    for piece in re.split(r'[|\s/]+', phases_raw.upper()):
        for phase_name in single_phase_rank:
            # Substring test so decorated labels (e.g. a prefixed or
            # punctuated phase name) are still recognised.
            if phase_name in piece:
                phases_found.add(phase_name)
                break

    if not phases_found:
        return 0

    known_rank = combination_rank.get(frozenset(phases_found))
    if known_rank is not None:
        return known_rank

    # Unlisted combination: fall back to its highest individual phase.
    return max(single_phase_rank[phase_name] for phase_name in phases_found)


# --- Data Loading and Preprocessing ---

# Load the trial table. A missing file is fatal for everything below, so stop
# with a non-zero status. `raise SystemExit` is used instead of `exit()`
# because `exit` is an interactive helper injected by the `site` module.
print(f"Loading data from {XLSX_FILENAME}...")
try:
    df = pd.read_excel(XLSX_FILENAME)
except FileNotFoundError:
    print(f"Error: {XLSX_FILENAME} not found. Please make sure the XLSX file is in the same directory.")
    raise SystemExit(1) from None
else:
    print("Data loaded successfully.")
    print(f"Initial data shape: {df.shape}")

# --- Preprocess the SEMANTIC_MATCH_COLUMN ('Conditions') ---
if SEMANTIC_MATCH_COLUMN not in df.columns:
    print(f"Error: '{SEMANTIC_MATCH_COLUMN}' column not found in the file, which is required for semantic matching.")
    raise SystemExit(1)

# Build a string version of the column. Missing cells are detected with
# isna() on the original values (covers NaN, None, NaT and pd.NA) and turned
# into '' rather than being recognised by their string form after astype(str).
_semantic_raw = df[SEMANTIC_MATCH_COLUMN]
df[SEMANTIC_MATCH_COLUMN + '_str'] = (
    _semantic_raw.astype(str).str.strip().mask(_semantic_raw.isna(), '')
)
df[SEMANTIC_MATCH_COLUMN + '_cleaned'] = df[SEMANTIC_MATCH_COLUMN + '_str'].apply(clean_text)


# Load the Sentence Transformer model used for both trial and query embeddings.
print("\nLoading Medical Domain Sentence Transformer model...")
try:
    # model = SentenceTransformer('sentence-transformers/pubmedbert-base-uncased')
    model = SentenceTransformer('all-MiniLM-L6-v2')
except Exception as e:
    # Top-level boundary: report whatever went wrong and stop, since nothing
    # below can run without a model.
    print(f"Error loading Sentence Transformer model: {e}")
    print("Please ensure you have internet access or the model files are cached, or try a different model name.")
    raise SystemExit(1) from None
else:
    print("Model loaded successfully.")

# Generate embeddings specifically for the cleaned SEMANTIC_MATCH_COLUMN
print(f"\nGenerating text embeddings for trial data ({SEMANTIC_MATCH_COLUMN} column)...")

# Only rows whose cleaned text is non-empty are embedded; the mapping below
# translates a DataFrame index label into the row of the embedding matrix.
_has_semantic_text = df[SEMANTIC_MATCH_COLUMN + '_cleaned'].str.strip() != ''
non_empty_semantic_indices = df.index[_has_semantic_text].tolist()
non_empty_semantic_texts = df.loc[non_empty_semantic_indices, SEMANTIC_MATCH_COLUMN + '_cleaned'].tolist()

if not non_empty_semantic_texts:
    print(f"Warning: No non-empty text found in the '{SEMANTIC_MATCH_COLUMN}' column after cleaning. Cannot generate embeddings.")
    semantic_trial_embeddings = np.array([])
    semantic_index_to_embedding_index = {}
else:
    semantic_trial_embeddings = model.encode(non_empty_semantic_texts, show_progress_bar=True, convert_to_numpy=True)
    semantic_index_to_embedding_index = {
        original_idx: emb_idx for emb_idx, original_idx in enumerate(non_empty_semantic_indices)
    }
    print(f"{SEMANTIC_MATCH_COLUMN} embeddings generated for {len(non_empty_semantic_indices)} trials.")


# --- Search Function ---

def find_relevant_trials(df: pd.DataFrame, semantic_trial_embeddings: np.ndarray, semantic_index_to_embedding_index: dict,
                         model: SentenceTransformer,
                         user_cancer_type_raw: str, user_stage_raw: str, user_biomarkers_raw: str,
                         relevance_threshold: float = RELEVANCE_SCORE_THRESHOLD):
    """
    Finds, ranks and prints clinical trials that match a patient profile.

    A trial is kept only if it passes every structured filter (primary outcome
    contains FILTER_PRIMARY_OUTCOME_TERM, phases accepted by ``check_phases``,
    study type equals FILTER_STUDY_TYPE_VALUE), has an embedding, and its
    cosine similarity to the query reaches ``relevance_threshold``.

    Results are ordered by similarity (highest first), then by phase rank
    (highest first), then by study status (alphabetical, missing last).

    Args:
        df: The pre-processed DataFrame.
        semantic_trial_embeddings: Pre-calculated embeddings for the trial's semantic match column text.
        semantic_index_to_embedding_index: Mapping from DataFrame index to semantic embedding index.
        model: The loaded Sentence Transformer model.
        user_cancer_type_raw: The raw string input for cancer type.
        user_stage_raw: The raw string input for stage.
        user_biomarkers_raw: The raw string input for biomarkers (comma-separated).
            A non-string value (e.g. None) is treated as "no biomarkers".
        relevance_threshold: The minimum semantic similarity score for a trial to be considered relevant.

    Returns:
        A list of dictionaries, each representing a relevant trial result with
        details and scores. The list is empty when the query is empty, the
        query cannot be embedded, or no trial passes the filters.
    """
    # Apply minimal cleaning to user inputs
    user_cancer_type_cleaned = clean_text(user_cancer_type_raw)
    user_stage_cleaned = clean_text(user_stage_raw)
    biomarkers_text = user_biomarkers_raw if isinstance(user_biomarkers_raw, str) else ''
    user_biomarkers_cleaned_list = [
        cleaned
        for cleaned in (clean_text(item) for item in biomarkers_text.split(','))
        if cleaned
    ]

    # Build the query from the non-empty parts only, so an omitted field does
    # not leave stray spaces in the text that is embedded and printed.
    query_parts = [user_cancer_type_cleaned, user_stage_cleaned, *user_biomarkers_cleaned_list]
    user_full_query_text = ' '.join(part for part in query_parts if part)

    print(f"\n--- Searching for trials for profile: {user_full_query_text} ---")
    if not user_full_query_text:
        print("Warning: User query is empty after cleaning. Cannot perform search.")
        return []

    # Generate embedding for the user query
    try:
        user_embedding = model.encode(user_full_query_text, convert_to_numpy=True)
    except Exception as e:
        print(f"Error generating user query embedding: {e}")
        return []

    # Loop-invariant normalised filter values.
    outcome_term = FILTER_PRIMARY_OUTCOME_TERM.lower()
    required_study_type = FILTER_STUDY_TYPE_VALUE.upper()

    potential_results = []

    for index, row in df.iterrows():

        # --- Structured filters (cheap checks first) ---

        # 1. Primary outcome must mention the configured term. str() makes
        #    non-string cells (e.g. NaN) comparable.
        primary_outcome_text = str(row.get(FILTER_PRIMARY_OUTCOME_COLUMN, '')).lower()
        if outcome_term not in primary_outcome_text:
            continue

        # 2. Phases; check_phases rejects non-string values itself.
        if not check_phases(row.get(FILTER_PHASES_COLUMN, '')):
            continue

        # 3. Study type must match exactly (case-insensitive).
        if str(row.get(FILTER_STUDY_TYPE_COLUMN, '')).upper() != required_study_type:
            continue

        # --- Semantic relevance filter ---
        # Rows without an embedding had no usable text in the semantic column.
        if index not in semantic_index_to_embedding_index:
            continue

        semantic_emb_index = semantic_index_to_embedding_index[index]
        overall_semantic_sim = cosine_similarity(
            [user_embedding], [semantic_trial_embeddings[semantic_emb_index]]
        )[0][0]

        if overall_semantic_sim < relevance_threshold:
            continue

        # --- All filters passed ---
        potential_results.append({
            'index': index,
            'overall_semantic_similarity': overall_semantic_sim,
            # Original data for display/explanation, using .get() for safety
            'NCT Number': row.get('NCT Number', 'N/A'),
            'Study Title': row.get('Study Title', 'N/A'),
            'Study Status': row.get('Study Status', 'N/A'),
            'Conditions': row.get('Conditions', 'N/A'),
            'Interventions': row.get('Interventions', 'N/A'),
            'Phases': row.get('Phases', 'N/A'),
            'Brief Summary': row.get('Brief Summary', 'N/A'),
            'Primary Outcome Measures': row.get(FILTER_PRIMARY_OUTCOME_COLUMN, 'N/A'),
        })

    # --- Ranking ---
    # The numeric keys are negated instead of sorting with reverse=True, so
    # that similarity and phase are descending while study status stays
    # ascending (alphabetical). A non-string status (e.g. NaN) is flagged so
    # it sorts last and is never compared with a string.
    def _ranking_key(result):
        status = result.get('Study Status')
        status_missing = not isinstance(status, str)
        return (
            -float(result['overall_semantic_similarity']),
            -get_phase_sort_value(result.get('Phases')),
            status_missing,
            '' if status_missing else status,
        )

    potential_results.sort(key=_ranking_key)

    # --- Present Results ---
    print(f"\nFound {len(potential_results)} relevant trials:")
    print(f"(Filtered by: Primary Outcome contains '{FILTER_PRIMARY_OUTCOME_TERM}', Phases in {ACCEPTABLE_PHASES_STR}, Study Type is '{FILTER_STUDY_TYPE_VALUE}')")
    print(f"(Filtered by Overall Semantic Similarity to '{SEMANTIC_MATCH_COLUMN}' >= {relevance_threshold:.2f})")

    if not potential_results:
        # Return here: there is nothing to format, and the formatted lines
        # below only exist when there is at least one result.
        print("\nNo relevant trials found for this profile based on the specified filters and thresholds.")
        return potential_results

    formatted_output = []
    for position, result in enumerate(potential_results, start=1):
        formatted_output.extend([
            f"\n--- Result {position} ---",
            f"NCT Number: {result['NCT Number']}",
            f"Study Title: {result['Study Title']}",
            f"Status: {result['Study Status']}",
            f"Phases: {result['Phases']}",
            f"Interventions: {result['Interventions']}",
            f"Conditions: {result['Conditions']}",
            # Brief Summary is deliberately not printed: the semantic match
            # is computed on Conditions only.
            f"Primary Outcome: {result['Primary Outcome Measures']}",
            f"Explanation: Overall Semantic Sim (to Conditions): {result['overall_semantic_similarity']:.4f}",
            "-" * 50,  # Separator
        ])

    print("\n".join(formatted_output))
    return potential_results



Loading data from trials.xlsx...
Data loaded successfully.
Initial data shape: (7498, 256)

Loading Medical Domain Sentence Transformer model...
Model loaded successfully.

Generating text embeddings for trial data (Conditions column)...


Batches:   0%|          | 0/231 [00:00<?, ?it/s]

Conditions embeddings generated for 7388 trials.


In [6]:

# --- Manual Input Section ---
# The patient profile is hard-coded below (nothing is read interactively).
# Edit the three variables, or swap in one of the commented-out examples,
# then re-run the cell.

print("\nPlease enter patient information for clinical trial matching.")

# --- Example 1: NSCLC, Stage IV, specific mutations (active) ---
user_cancer_type = "Non-Small Cell Lung Cancer"
user_stage = "Stage 4"
user_biomarkers = "EGFR mutation, PD-L1 positive"  # comma-separated list

# Prints the ranked matches and returns them as a list of dicts.
find_relevant_trials(df, semantic_trial_embeddings, semantic_index_to_embedding_index,
                     model, user_cancer_type, user_stage, user_biomarkers)

# print("\n" + "="*80 + "\n") # Separator for multiple searches

# --- Example 2: Breast Cancer, metastatic, HER2-low ---
# user_cancer_type = "Breast Cancer"
# user_stage = "metastatic"
# user_biomarkers = "HER2 low"

# find_relevant_trials(df, semantic_trial_embeddings, semantic_index_to_embedding_index,
#                      model, user_cancer_type, user_stage, user_biomarkers)

# print("\n" + "="*80 + "\n") # Separator for multiple searches


# --- Example 3: Prostate cancer, mCRPC, PSMA positive ---
# Uses an example Conditions value directly, so a high similarity is expected.
# user_cancer_type = "Prostate-specific membrane antigen (PSMA)-positive metastatic castration-resistant prostate cancer (mCRPC)"
# user_stage = "" # Stage is included in the Conditions text
# user_biomarkers = "" # Biomarkers are included in the Conditions text

# find_relevant_trials(df, semantic_trial_embeddings, semantic_index_to_embedding_index,
#                      model, user_cancer_type, user_stage, user_biomarkers)

# print("\n" + "="*80 + "\n") # Separator for multiple searches

# --- Example 4: Urothelial Carcinoma (from your sample data) ---
# user_cancer_type = "Urothelial Carcinoma" # Or try 'bladder cancer' in type
# user_stage = "operable high-risk"
# user_biomarkers = "" # No specific biomarkers mentioned in the sample brief summary

# find_relevant_trials(df, semantic_trial_embeddings, semantic_index_to_embedding_index,
#                      model, user_cancer_type, user_stage, user_biomarkers)

# print("\n" + "="*80 + "\n") # Separator for multiple searches


Please enter patient information for clinical trial matching.

--- Searching for trials for profile: non-small cell lung cancer stage 4 egfr mutation pd-l1 positive ---

Found 56 relevant trials:
(Filtered by: Primary Outcome contains 'Overall Survival', Phases in ['PHASE1|PHASE2', 'PHASE2', 'PHASE2|PHASE3', 'PHASE3', 'PHASE4'], Study Type is 'INTERVENTIONAL')
(Filtered by Overall Semantic Similarity to 'Conditions' >= 0.35)

--- Result 1 ---
NCT Number: NCT06712355
Study Title: Safety and Effectiveness of BNT327, an Investigational Therapy in Combination With Chemotherapy for Patients With Untreated Small-cell Lung Cancer
Status: RECRUITING
Phases: PHASE3
Interventions: DRUG: BNT327|DRUG: Atezolizumab|DRUG: Etoposide|DRUG: Carboplatin
Conditions: Extensive-stage Small-cell Lung Cancer
Primary Outcome: Overall survival (OS), OS defined as the time from randomization to death from any cause., Up to approximately 39 months
Explanation: Overall Semantic Sim (to Conditions): 0.6199
------

[{'index': 4219,
  'overall_semantic_similarity': np.float32(0.61991215),
  'NCT Number': 'NCT06712355',
  'Study Title': 'Safety and Effectiveness of BNT327, an Investigational Therapy in Combination With Chemotherapy for Patients With Untreated Small-cell Lung Cancer',
  'Study Status': 'RECRUITING',
  'Conditions': 'Extensive-stage Small-cell Lung Cancer',
  'Interventions': 'DRUG: BNT327|DRUG: Atezolizumab|DRUG: Etoposide|DRUG: Carboplatin',
  'Phases': 'PHASE3',
  'Brief Summary': 'This is a Phase III, multisite, randomized, double-blinded study to investigate BNT327 combined with chemotherapy (etoposide/carboplatin) compared to atezolizumab combined with chemotherapy (etoposide/carboplatin) for the treatment of participants with previously untreated extensive-stage small-cell lung cancer (ES-SCLC).',
  'Primary Outcome Measures': 'Overall survival (OS), OS defined as the time from randomization to death from any cause., Up to approximately 39 months'},
 {'index': 4217,
  'overall