# Clinical Trials ranking project

This notebook analyzes and processes clinical trial data so that trials can be ranked against an individual patient's information. The pipeline covers data cleaning, feature extraction, and embedding-based similarity calculations that match clinical trials to a patient profile.

## Objectives:
1. **Data Cleaning and Preprocessing**: Handle missing values, normalize data, and prepare structured and unstructured fields for analysis.
2. **Embedding Generation**: Utilize pre-trained language models to create vector embeddings of unstructured text fields, enabling semantic similarity calculations.
3. **Patient-Specific Filtering**: Apply filters based on patient attributes (e.g., age, gender, medical condition) to refine the set of relevant trials.
4. **Ranking Clinical Trials**: Rank the filtered trials based on similarity to patient data using cosine similarity scores.


**Project by:** Mouna NAIM  

**Email:** mounanaim5001@gmail.com


To install the packages used in this notebook, run:
```
pip install pandas numpy torch transformers sentence-transformers tqdm nltk scikit-learn
```

In [1]:
# libraries
import pandas as pd
import numpy as np
from collections import Counter
import re
import warnings

# Machine Learning and NLP libraries
import torch
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util

# Progress 
from tqdm.auto import tqdm

# NLTK for tokenization
import nltk
from nltk.tokenize import sent_tokenize

# Download necessary NLTK resources
nltk.download('punkt')  # Download Punkt tokenizer models for sentence splitting and word tokenization





[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\red-y\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

### Explore and clean the data

In [4]:
# Load Data
file_path = 'trials.csv'  #Path of the data 
data = pd.read_csv(file_path, delimiter=';')
data.head()

index,nct_id,last_update_submitted_date,overall_status,brief_title,official_title,phase,acronym,study_type,description,minimum_age_num,minimum_age_unit,maximum_age_num,maximum_age_unit,criteria
0,NCT00002462,2023-02-14,"Active, not recruiting",RT or No RT Following Chemotherapy in Treating...,Phase III Randomized Trial of Adjuvant Involve...,Phase 3,,Interventional,RATIONALE: Drugs used in chemotherapy use diff...,15.0,Years,70.0,Years,DISEASE CHARACTERISTICS: Histologically proven...
1,NCT00003980,2012-07-10,Suspended,BIBX 1382 in Treating Patients With Solid Tumors,Phase I and Pharmacokinetics Study to Determin...,Phase 1,,Interventional,RATIONALE: Drugs used in chemotherapy use diff...,18.0,Years,,,DISEASE CHARACTERISTICS: Histologically or cyt...
2,NCT00005584,2023-02-14,"Active, not recruiting",Combination Chemotherapy With or Without Radia...,Prospective Controlled Trial in Clinical Stage...,Phase 3,,Interventional,RATIONALE: Drugs used in chemotherapy use diff...,15.0,Years,70.0,Years,DISEASE CHARACTERISTICS:\n\nRandomized groups:...
3,NCT00049595,2023-11-08,Completed,Comparison of Two Combination Chemotherapy Reg...,BEACOPP (4 Cycles Escalated + 4 Cycles Baselin...,Phase 3,,Interventional,RATIONALE: Drugs used in chemotherapy use diff...,16.0,Years,60.0,Years,DISEASE CHARACTERISTICS:\n\nHistologically con...
4,NCT00055731,2023-01-17,Completed,Hormone Therapy With or Without Docetaxel And ...,Phase III Randomized Study Of Adjuvant Hormona...,Phase 3,,Interventional,RATIONALE: Androgens can stimulate the growth ...,0.0,Years,79.0,Years,DISEASE CHARACTERISTICS:\n\nHistologically con...


In [5]:
# Display Missing Values
missing_data = data.isna().sum()
print("Missing Data Summary:")
print(missing_data)

Missing Data Summary:
nct_id                          0
last_update_submitted_date      0
overall_status                  0
brief_title                     0
official_title                  3
phase                           0
acronym                       379
study_type                      0
description                     0
minimum_age_num                29
minimum_age_unit               29
maximum_age_num               787
maximum_age_unit              787
criteria                        0
dtype: int64


The data contains missing values in the `official_title`, `acronym`, and age columns.

### Dealing with missing values

In [None]:
# Replace NaN acronyms with 'No Acronym'
data['acronym'] = data['acronym'].fillna('No Acronym')

# Replace NaN values in the official_title column with the corresponding values from the brief_title column
data['official_title'] = data['official_title'].fillna(data['brief_title'])

In [8]:
units = data['minimum_age_unit']
count_units_min =set(units)
print(f"units for minimum age is {count_units_min}")

units_max = data['maximum_age_unit']
count_units_max =set(units_max)
print(f"units for maximum age is {count_units_max}")

units for minimum age is {nan, 'Year', 'Years', 'Months', 'Month', 'Day'}
units for maximum age is {'Years', nan}



The minimum-age data uses five different unit labels (`Year`, `Years`, `Month`, `Months`, `Day`), which cover only three actual units, while the maximum age is always given in years. To make the ages comparable, all of them are normalized to years. Two columns are then created: `minimum_age` contains the minimum age in years (without units), and `maximum_age` contains the maximum age in years (without units).

In [31]:
# Convert an age to years, whatever its original unit
def normalize_age(num, unit):
    if pd.isna(num) or pd.isna(unit):
        return None
    unit_lower = unit.lower()
    # Substring checks cover both singular and plural labels ('Year', 'Years', ...)
    if 'year' in unit_lower:
        return num
    elif 'month' in unit_lower:
        return num / 12
    elif 'day' in unit_lower:
        return num / 365  # approximation: leap years are ignored
    else:
        return None

# Add two columns, minimum_age and maximum_age, that contain ages in years only
data['minimum_age'] = data.apply(lambda x: normalize_age(x['minimum_age_num'], x['minimum_age_unit']), axis=1)
data['maximum_age'] = data.apply(lambda x: normalize_age(x['maximum_age_num'], x['maximum_age_unit']), axis=1)

Up to this point, the columns `minimum_age` and `maximum_age` have been added to the dataset. These two columns replace the four columns: `minimum_age_num`, `maximum_age_num`, `minimum_age_unit`, and `maximum_age_unit`, which will be dropped from the original dataset in the next cell.


In [32]:
# Dropping minimum_age_num, maximum_age_num, minimum_age_unit, maximum_age_unit
columns_to_drop = ['minimum_age_num', 'maximum_age_num', 'minimum_age_unit', 'maximum_age_unit']
data = data.drop(columns=columns_to_drop, errors='ignore')
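
The unit normalization above can also be written without a row-wise `apply`. The following is a minimal sketch, not the notebook's method: the lookup table `YEARS_PER_UNIT` and the helper `normalize_age_vectorized` are hypothetical names, and the toy values are made up for illustration. It assumes the unit column holds strings (possibly mixed with NaN).

```python
import numpy as np
import pandas as pd

# Hypothetical lookup table: number of years represented by one unit.
# Labels are lower-cased and stripped of a trailing "s" before the lookup.
YEARS_PER_UNIT = {"year": 1.0, "month": 1.0 / 12, "day": 1.0 / 365}

def normalize_age_vectorized(num: pd.Series, unit: pd.Series) -> pd.Series:
    """Convert ages to years; unknown or missing units yield NaN."""
    factor = unit.str.lower().str.rstrip("s").map(YEARS_PER_UNIT)
    return num * factor

# Toy values, not taken from trials.csv
toy = pd.DataFrame({
    "minimum_age_num": [18.0, 6.0, 30.0, np.nan],
    "minimum_age_unit": ["Years", "Months", "Day", np.nan],
})
toy["minimum_age"] = normalize_age_vectorized(toy["minimum_age_num"], toy["minimum_age_unit"])
print(toy)
```

Like the row-wise version, this sketch approximates a year as 365 days and leaves missing ages as NaN, which the age filter later treats as an open bound.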

In [33]:
data.columns

Index(['nct_id', 'last_update_submitted_date', 'overall_status', 'brief_title',
       'official_title', 'phase', 'acronym', 'study_type', 'description',
       'criteria', 'gender', 'inclusion_criteria', 'exclusion_criteria',
       'organs', 'keywords', 'combined_text', 'text_embedding', 'minimum_age',
       'maximum_age'],
      dtype='object')

## Ranking function
### Step 1: Classification of gender 
Given patient data (age, gender, organ, and any additional information), we need to identify the trials whose content matches the patient's profile. To rank the clinical trials, I use a Transformer-based approach that relies on the `description` column, which gives insight into whether a trial is suitable for men, women, or both.

To do this, I use a pretrained zero-shot classification Transformer to label each description according to whether the trial applies to men, women, or both. This label then determines whether a clinical trial is suitable for a given patient.

First, I tried a simple classification (in the function below) using keywords such as male, man, men or female, woman, women, but the results were not satisfactory.

In [11]:
tqdm.pandas()
def classify_gender_simple(description_text):
    """
    Classify gender requirement using keyword matching

    """
    if pd.isna(description_text) or description_text.strip() == '':
        return 'Both'

    description_text_lower = description_text.lower()
    if re.search(r'\b(male|man|men)\b', description_text_lower) and not re.search(r'\b(female|woman|women)\b', description_text_lower):
        return 'Male'
    elif re.search(r'\b(female|woman|women)\b', description_text_lower) and not re.search(r'\b(male|man|men)\b', description_text_lower):
        return 'Female'
    else:
        return 'Both'

genders = data['description'].fillna('').progress_apply(classify_gender_simple)
#print(genders)


I therefore used Transformers to classify the gender eligibility of each trial.

The cell below determines which gender (Male, Female, or Both) is admitted to each clinical trial. It combines a regular expression (regex) pre-check with zero-shot classification: descriptions that contain no gendered keyword are labeled `Both` directly, and the remaining ones are passed to a natural language inference model, whose prediction is kept only when its confidence score is at least 0.9.

In [None]:
# Suppress the specific warning about length mismatch
warnings.filterwarnings("ignore", message="Length of IterableDataset .*")

# Initialize tqdm for pandas
tqdm.pandas()

# Choose the device (GPU or CPU)
device = 0 if torch.cuda.is_available() else -1
print(f"Using device: {'GPU' if device == 0 else 'CPU'}")

# Initialize the zero-shot classification pipeline
classifier = pipeline(
    'zero-shot-classification',
    model='valhalla/distilbart-mnli-12-1', # Model I found on hugging face library https://huggingface.co/tasks/zero-shot-classification
    device=device
)

# Collect the row indices whose description mentions a gendered keyword
needs_classification = []
print("Performing regex pre-check on 'description'...")
for index, row in tqdm(data.iterrows(), total=data.shape[0]):
    # Lower-case the description before matching
    description_text = str(row['description']).lower()

    # Assign 'Both' when neither male nor female keywords are found
    if not re.search(r'\b(male|man|men)\b', description_text) and not re.search(r'\b(female|woman|women)\b', description_text):
        data.at[index, 'gender'] = 'Both'
    else:
        needs_classification.append(index)

# Process the entries that need classification in batches
if needs_classification:
    batch_size = 8  #For GPU
    # Update candidate labels to more relevant phrases
    candidate_labels = ['only male candidates', 'only female candidates']
    texts_to_classify = data.loc[needs_classification, 'description'].fillna('').tolist()
    indices = needs_classification

    print("Classifying Gender with Zero-Shot Classifier...")
    genders = []
    for i in tqdm(range(0, len(texts_to_classify), batch_size), desc="Processing Batches"):
        batch_texts = texts_to_classify[i:i+batch_size]
        # Pre-truncate texts to prevent splitting
        batch_texts = [text[:1000] for text in batch_texts] 
        try:
            results = classifier(
                batch_texts,
                candidate_labels,
                truncation=True,
                max_length=512
            )
            for result in results:
                predicted_label = result['labels'][0]
                score = result['scores'][0]

                if score >= 0.9:
                    if predicted_label == 'only male candidates':
                        genders.append('Male')
                    elif predicted_label == 'only female candidates':
                        genders.append('Female')
                    else:
                        # In case the predicted label is not in expected labels
                        genders.append('Both')
                else:
                    # If the confidence score is below 0.9, assign 'Both' 
                    genders.append('Both')
        except Exception as e:
            print(f"Classification error: {e}")
            genders.extend(['Both'] * len(batch_texts))

    # Assign the results back to the DataFrame
    for idx, gender in zip(indices, genders):
        data.at[idx, 'gender'] = gender
else:
    print("No entries required classification with the zero-shot classifier.")

print("Gender classification completed.")
print(genders)

Using device: GPU
Performing regex pre-check on 'description'...



Classifying Gender with Zero-Shot Classifier...



Gender classification completed.
['Both', 'Both', 'Female', 'Both', 'Both', 'Female', 'Both', 'Female', 'Female', 'Both', 'Male', 'Both', 'Both', 'Both', 'Both', 'Both', 'Female', 'Both', 'Both', 'Both', 'Both', 'Both', 'Both', 'Both', 'Both', 'Female', 'Female', 'Male', 'Female', 'Both', 'Both', 'Both', 'Both', 'Both', 'Both', 'Male', 'Both', 'Both', 'Female', 'Male', 'Both', 'Male', 'Female', 'Both', 'Both', 'Both', 'Both', 'Both']
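
The decision rule applied to each classifier result above can be isolated into a small helper, which makes the 0.9 threshold easy to test without loading a model. This is a sketch under assumptions: `gender_from_zero_shot` is a hypothetical name, and the two result dictionaries are hand-written stand-ins that only mimic the `labels`/`scores` shape the cell above reads, not real model output.

```python
def gender_from_zero_shot(result: dict, threshold: float = 0.9) -> str:
    """Map one zero-shot result to 'Male', 'Female', or 'Both'.

    `result` is assumed to hold parallel 'labels' and 'scores' lists,
    sorted by descending score, as read by the notebook cell above.
    """
    label_to_gender = {
        "only male candidates": "Male",
        "only female candidates": "Female",
    }
    top_label, top_score = result["labels"][0], result["scores"][0]
    if top_score < threshold:
        # Not confident enough to restrict the trial to one gender
        return "Both"
    # Unexpected labels also fall back to 'Both'
    return label_to_gender.get(top_label, "Both")

# Hand-written results standing in for classifier output (hypothetical scores)
confident = {"labels": ["only female candidates", "only male candidates"], "scores": [0.97, 0.03]}
uncertain = {"labels": ["only male candidates", "only female candidates"], "scores": [0.62, 0.38]}

print(gender_from_zero_shot(confident))  # Female
print(gender_from_zero_shot(uncertain))  # Both
```

Because the fallback is always `Both`, a low-confidence prediction never excludes a patient from a trial; it only makes the gender filter less selective.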


### Step 2: Classifying inclusion and exclusion criteria

The `criteria` column contains both inclusion and exclusion criteria. In this section, I separate them using zero-shot classification. The function below parses the free-text `criteria` of a clinical trial line by line and splits it into two distinct fields, `inclusion_criteria` and `exclusion_criteria`.

In [None]:
# tqdm for pandas
tqdm.pandas()

# Choose the device (GPU or CPU)
device = 0 if torch.cuda.is_available() else -1
print(f"Using device: {'GPU' if device == 0 else 'CPU'}")

# Initialize the zero-shot classification pipeline
classifier = pipeline(
    'zero-shot-classification',
    model='valhalla/distilbart-mnli-12-1', 
    device=device
)

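def parse_criteria_line_by_line(criteria_text):
    """
    Split a criteria text into inclusion and exclusion parts.

    Each line is scored by the zero-shot classifier. A line scored >= 0.9
    sets the current section; less confident lines inherit the section of
    the last confident line. Lines before the first confident prediction
    are dropped.
    """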
    if pd.isna(criteria_text) or criteria_text.strip() == '':
        return pd.Series({'inclusion_criteria': '', 'exclusion_criteria': ''})
    
    # Split the text into lines
    lines = criteria_text.strip().split('\n')

    # Remove empty lines and strip whitespace
    lines = [line.strip() for line in lines if line.strip()]
    
    # Prepare lists to hold inclusion and exclusion criteria
    inclusion_lines = []
    exclusion_lines = []
    
    # Initialize the last predicted label 
    last_label = None
    started = False  # Start adding lines after finding a line with score >= 0.9
    
    # Define labels
    candidate_labels = ['inclusion criteria', 'exclusion criteria']
    
    # Process lines in batches
    batch_size = 1000  
    for i in range(0, len(lines), batch_size):
        batch_lines = lines[i:i+batch_size]
        # Skip very short lines
        batch_lines = [line for line in batch_lines if len(line) >= 3]
        if not batch_lines:
            continue
        try:
            # Classify the batch of lines
            results = classifier(
                batch_lines,
                candidate_labels,
                truncation=True,
                max_length=512
            )
            for result, line in zip(results, batch_lines):
                predicted_label = result['labels'][0]
                score = result['scores'][0]
                
                if score >= 0.9:
                    last_label = predicted_label
                    started = True  # Start adding lines once we have a confident prediction
                
                if started and last_label:
                    if last_label == 'inclusion criteria':
                        inclusion_lines.append(line)
                    elif last_label == 'exclusion criteria':
                        exclusion_lines.append(line)
            torch.cuda.empty_cache()  # Clear cache after processing each batch
        except Exception as e:
            # I skip the batch if there is an error
            print(f"Classification error: {e}")
            continue
    
    # Combine the lines back into text
    inclusion_text = '\n'.join(inclusion_lines)
    exclusion_text = '\n'.join(exclusion_lines)
    
    return pd.Series({'inclusion_criteria': inclusion_text, 'exclusion_criteria': exclusion_text})  # return a Series


Using device: GPU


#### Test of the `parse_criteria_line_by_line` function

In [18]:
# Test on a sample
sample_text = data['criteria'].iloc[100]  # I choose randomly the index 100 to test the function
parsed_criteria = parse_criteria_line_by_line(sample_text)
print(parsed_criteria)
print(data['criteria'][100])
print("Inclusion Criteria:")
print(parsed_criteria['inclusion_criteria'])
print("\nExclusion Criteria:")
print(parsed_criteria['exclusion_criteria'])


inclusion_criteria    Inclusion Criteria:\nMales and females of 18 y...
exclusion_criteria    Exclusion Criteria:\nAny significant medical c...
dtype: object
Inclusion Criteria:

Males and females of 18 years of age to 80 years of age.
Understand and voluntarily sign an informed consent document prior to any study related assessments/procedures are conducted.
Able to adhere to the study visit schedule and other protocol requirements.

Patients with histologically proven peripheral T-cell lymphoma (PTCL), not previously treated; the following subtypes as defined by the World Health Organization (WHO) classification (2008;2011) may be included, whatever the Ann Arbor stage (I - IV):

a. Nodal types: i. PTCL, not otherwise specified ii. Angioimmunoblastic T-cell lymphoma iii. Anaplastic large cell lymphoma, anaplastic lymphoma kinase (ALK)-negative type

b. Extra-nodal types: i. Enteropathy-associated T-cell lymphoma ii. Hepato-splenic T-cell lymphoma iii. Subcutaneous panniculitis-like T
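
The sample above contains explicit `Inclusion Criteria:` and `Exclusion Criteria:` header lines. For texts of that shape, a header-based split is a cheap alternative to classifying every line. The sketch below is a different technique from the zero-shot approach, not a description of it: `split_by_headers` and the header pattern are assumptions, and the exclusion line in the sample is made up. Texts without such headers (for example those starting with `DISEASE CHARACTERISTICS:`) come back empty and would still need the classifier.

```python
import re

# Assumed header spellings; real criteria texts may use other wording.
_HEADER = re.compile(r"^(inclusion|exclusion)\s+criteria\s*:?$", re.IGNORECASE)

def split_by_headers(criteria_text: str) -> dict:
    """Assign each non-empty line to the section named by the last header seen."""
    sections = {"inclusion": [], "exclusion": []}
    current = None  # lines before the first header are ignored
    for raw_line in criteria_text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        match = _HEADER.match(line)
        if match:
            current = match.group(1).lower()
        elif current is not None:
            sections[current].append(line)
    return {
        "inclusion_criteria": "\n".join(sections["inclusion"]),
        "exclusion_criteria": "\n".join(sections["exclusion"]),
    }

# Inclusion lines follow the sample output above; the exclusion line is a toy value.
sample = """Inclusion Criteria:

Males and females of 18 years of age to 80 years of age.
Able to adhere to the study visit schedule and other protocol requirements.

Exclusion Criteria:

Pregnant or breastfeeding.
"""
parsed = split_by_headers(sample)
print(parsed["inclusion_criteria"])
print(parsed["exclusion_criteria"])
```

Unlike `parse_criteria_line_by_line`, this version does not copy the header lines themselves into the output.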

In [19]:
# Apply the parsing function to the data with a progress bar
print("Parsing criteria texts...")
criteria_df = data['criteria'].progress_apply(parse_criteria_line_by_line)
# Concatenate the new columns to the data
data = pd.concat([data, criteria_df], axis=1)


Parsing criteria texts...



You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


The function below extracts the organs mentioned in `brief_title`, `official_title`, `description`, and `criteria`. The result is stored in a new `organs` column. Matching is done by substring search on the lower-cased text.

In [20]:
organs = ['Lung', 'Bladder', 'Breast', 'Prostate', 'Gastric', 'Brain', 'Skin',
          'Colon', 'Vulva', 'Thyroid', 'Kidney', 'Pleura']
organs_lower = [organ.lower() for organ in organs]

def extract_organs_from_text(row):
    """
    Identify organs mentioned in all relevant columns.
    """
    text = ' '.join([
        str(row['brief_title']),
        str(row['official_title']),
        str(row['description']),
        str(row['criteria'])
    ]).lower()
    found_organs = [organ for organ in organs_lower if organ in text]
    return list(set(found_organs))

data['organs'] = data.apply(extract_organs_from_text, axis=1)
data['organs']

0                        [skin]
1                       [brain]
2                        [skin]
3                        [skin]
4              [prostate, skin]
                 ...           
995    [breast, prostate, skin]
996              [breast, skin]
997                  [prostate]
998      [breast, kidney, skin]
999              [breast, skin]
Name: organs, Length: 1000, dtype: object
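
Substring matching can produce false positives, since `'colon' in text` is also true for a text that only mentions "colony-stimulating factor". A stricter variant matches whole words only. The following is a sketch, assuming the same organ list as above; `ORGAN_PATTERN` and `find_organs` are hypothetical names and the example sentences are made up.

```python
import re

ORGANS = ["lung", "bladder", "breast", "prostate", "gastric", "brain", "skin",
          "colon", "vulva", "thyroid", "kidney", "pleura"]

# One compiled pattern; \b keeps "colon" from matching inside "colony".
# The optional "s" also accepts simple plurals such as "lungs".
ORGAN_PATTERN = re.compile(r"\b(" + "|".join(ORGANS) + r")s?\b", re.IGNORECASE)

def find_organs(text: str) -> list:
    """Return the sorted, de-duplicated organs mentioned as whole words."""
    return sorted({m.group(1).lower() for m in ORGAN_PATTERN.finditer(text)})

print(find_organs("Granulocyte colony-stimulating factor in breast cancer"))  # ['breast']
print(find_organs("Tumors of the lungs or the colon"))  # ['colon', 'lung']
```

The trade-off is recall: whole-word matching also ignores derived forms such as "pleural", which the substring search would have caught.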

The function below extracts keywords from `brief_title`, `official_title`, and `description` by removing stop words and words shorter than three characters from these columns. The result is stored in a new `keywords` column.

In [None]:
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

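def extract_keywords(row):
    """
    Return the unique non-stop-word tokens (longer than two characters)
    found in the title and description columns.
    """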
    text = ' '.join([
        str(row['brief_title']),
        str(row['official_title']),
        str(row['description'])
    ]).lower()
    # Re to extract words
    words = re.findall(r'\b\w+\b', text)
    # Remove stop words
    stopwords = set(ENGLISH_STOP_WORDS)
    keywords = [word for word in words if word not in stopwords and len(word) > 2]
    return list(set(keywords))

data['keywords'] = data.apply(extract_keywords, axis=1)
data['keywords']

0      [phase, ways, different, mopp, purpose, combin...
1      [phase, ways, determine, different, solid, adm...
2      [phase, high, ways, prognostic, different, ray...
3      [phase, ways, different, lymphoma, known, cycl...
4      [phase, high, androgens, ways, prostate, diffe...
                             ...                        
995    [17p, intergroupe, rate, interventional, respo...
996    [phase, operated, test, chronic, ttfields, mal...
997    [617, radical, coupled, prostate, scanner, pos...
998    [phase, high, early, her2, older, combination,...
999    [phase, data, preclinical, filo, impact, newly...
Name: keywords, Length: 1000, dtype: object

In [23]:
data.columns

Index(['nct_id', 'last_update_submitted_date', 'overall_status', 'brief_title',
       'official_title', 'phase', 'acronym', 'study_type', 'description',
       'minimum_age_num', 'minimum_age_unit', 'maximum_age_num',
       'maximum_age_unit', 'criteria', 'gender', 'inclusion_criteria',
       'exclusion_criteria', 'organs', 'keywords'],
      dtype='object')

In [None]:
# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load embedding model
embedding_model = SentenceTransformer('all-mpnet-base-v2', device=device)

Using device: cuda


I divide my features into structured and unstructured data.

In [25]:
structured_data = ['age', 'gender', 'organ', 'phase']  # fields used for exact filtering
unstructured_data = ['keywords', 'inclusion_criteria', 'description']  # text fields used for embeddings


In the cell below, I define functions to filter my data. For a given patient, I remove the rows that do not match the patient's data. This filtering is applied to the `structured_data`.


In [26]:
# Filter trials based on structured data ['age', 'gender', 'organ', 'phase']
def filter_by_age(df, age):
    """Keep trials whose age range contains the patient's age (missing bounds are treated as open)."""
    return df[
        ((df['minimum_age'].isna()) | (df['minimum_age'] <= age)) &
        ((df['maximum_age'].isna()) | (df['maximum_age'] >= age))
    ]

def filter_by_gender(df, gender):
    """Keep trials open to the patient's gender; unknown genders match only 'Both'."""
    if gender.lower() == 'male':
        return df[df['gender'].isin(['Male', 'Both'])]
    elif gender.lower() == 'female':
        return df[df['gender'].isin(['Female', 'Both'])]
    else:
        return df[df['gender'] == 'Both']

def filter_by_organ(df, organ):
    """Keep trials whose `organs` list contains the organ of interest."""
    organ_lower = organ.lower()
    return df[df['organs'].apply(lambda organs_list: organ_lower in organs_list)]

def filter_by_phase(df, phases):
    """Keep trials in one of the requested phases; no filtering if `phases` is empty."""
    if not phases:
        return df
    phases = [phase.lower() for phase in phases]
    return df[df['phase'].str.lower().isin(phases)]
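
To see how the structured filters interact, the sketch below applies the age, gender, and organ conditions as boolean masks on a toy DataFrame. It is an illustration only: `apply_structured_filters` is a hypothetical helper that mirrors the logic of the functions above, and the trial identifiers and values are made up rather than taken from `trials.csv`.

```python
import numpy as np
import pandas as pd

def apply_structured_filters(df, age, gender, organ):
    """Keep trials whose age range, gender, and organs match the patient."""
    # Missing bounds (NaN) are treated as "no limit"
    min_ok = df["minimum_age"].isna() | (df["minimum_age"] <= age)
    max_ok = df["maximum_age"].isna() | (df["maximum_age"] >= age)
    gender_ok = df["gender"].isin([gender.capitalize(), "Both"])
    organ_ok = df["organs"].apply(lambda organs: organ.lower() in organs)
    return df[min_ok & max_ok & gender_ok & organ_ok]

# Toy trials with made-up identifiers
toy_trials = pd.DataFrame({
    "nct_id": ["T1", "T2", "T3", "T4"],
    "minimum_age": [18.0, 18.0, np.nan, 60.0],
    "maximum_age": [70.0, np.nan, 50.0, 80.0],
    "gender": ["Both", "Male", "Female", "Female"],
    "organs": [["breast"], ["breast"], ["breast", "skin"], ["breast"]],
})

matches = apply_structured_filters(toy_trials, age=53, gender="Female", organ="Breast")
print(matches["nct_id"].tolist())  # ['T1']
```

Here T2 is dropped by the gender condition, T3 by its maximum age of 50, and T4 by its minimum age of 60, so only T1 remains for the similarity ranking.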

The cell below generates **text embeddings** for the clinical trial data. It combines the unstructured text fields (`unstructured_data = ['keywords', 'inclusion_criteria', 'description']`) into a single text string per trial, applies a pre-trained embedding model (`SentenceTransformer`), and stores the resulting embeddings in the dataset. The embeddings are precomputed so that similarity comparisons can be run later without re-encoding the trials.


In [None]:
# Load the embedding model
# Embeddings are computed on the unstructured data
embedding_model = SentenceTransformer('all-mpnet-base-v2', device=device)
print("Embedding model loaded.")

# Compute and store embeddings for all trials
def batch_embedding(text_list, model, batch_size=16):
    
    embeddings = []
    for i in tqdm(range(0, len(text_list), batch_size), desc="Generating Embeddings"):
        batch = text_list[i:i + batch_size]
        batch_embeddings = model.encode(batch, convert_to_tensor=True, show_progress_bar=False)
        embeddings.extend(batch_embeddings.cpu().numpy())
    return embeddings

# Prepare data for embedding
def combine_text_columns(row):
    text = ' '.join([str(row[field]) for field in unstructured_data])
    return text

data['combined_text'] = data.apply(combine_text_columns, axis=1)

print("Generating embeddings for all trials...")
data['text_embedding'] = batch_embedding(data['combined_text'].tolist(), embedding_model, batch_size=16)
print("Embeddings generated and stored.")


Embedding model loaded.
Generating embeddings for all trials...



Embeddings generated and stored.
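
One detail of the text combination is worth checking: `keywords` is a list, so `str()` renders it with brackets and quotes (for example `"['phase', 'ways']"`), and that punctuation ends up in the embedded text. The sketch below shows one possible way to flatten list fields before joining. `field_to_text` and `combine_fields` are hypothetical helpers, and the row is a toy example. Since the `all-mpnet-base-v2` model card documents truncation of inputs longer than 384 word pieces, the order of the fields may also matter for long trials.

```python
import pandas as pd

def field_to_text(value) -> str:
    """Render one field as plain text; lists are joined with spaces."""
    if isinstance(value, (list, tuple, set)):
        return " ".join(str(item) for item in value)
    if pd.isna(value):
        return ""
    return str(value)

def combine_fields(row: pd.Series, fields: list) -> str:
    """Concatenate the selected fields, skipping empty ones."""
    parts = [field_to_text(row[field]) for field in fields]
    return " ".join(part for part in parts if part)

fields = ["keywords", "inclusion_criteria", "description"]
# Toy row, not taken from trials.csv
toy_row = pd.Series({
    "keywords": ["lymphoma", "chemotherapy"],
    "inclusion_criteria": None,
    "description": "Toy description of the trial.",
})
print(combine_fields(toy_row, fields))  # lymphoma chemotherapy Toy description of the trial.
```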


In [28]:
data.columns

Index(['nct_id', 'last_update_submitted_date', 'overall_status', 'brief_title',
       'official_title', 'phase', 'acronym', 'study_type', 'description',
       'minimum_age_num', 'minimum_age_unit', 'maximum_age_num',
       'maximum_age_unit', 'criteria', 'gender', 'inclusion_criteria',
       'exclusion_criteria', 'organs', 'keywords', 'combined_text',
       'text_embedding'],
      dtype='object')

## `rank_trials` Function: Ranking Clinical Trials

This function ranks clinical trials based on their relevance to a given patient's data. It takes a dictionary of patient data, which includes information such as age, gender, organ, keywords related to the patient, and any additional details.

### Steps of Calculation

1. **Generate Patient Embedding**  
   An embedding of the patient's description is created using the `embedding_model`. This embedding vector is later compared to the embeddings of the clinical trials (stored in the `text_embedding` column).

2. **Filter Clinical Trials**  
   The clinical trial data is filtered to exclude trials that do not match the patient's criteria. The following filters are applied:
   - **Age Filter**: Matches the patient’s age with the trial's age requirements.
   - **Gender Filter**: Ensures the trial is open to the patient’s gender.
   - **Organ Filter**: Checks if the trial is relevant to the specified organ.
   - **Phase Filter** (if provided): Filters trials based on specific trial phases.

3. **Calculate Similarity Scores**  
   For the remaining trials, the cosine similarity between the patient embedding and each trial embedding (`text_embedding`, computed from the combined keywords, inclusion criteria, and description) is calculated. The similarity score reflects how closely a trial aligns with the patient's information.

4. **Rank Trials**  
   The trials are ranked in descending order of their similarity scores, with the most relevant trials appearing first.
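
Steps 3 and 4 can be illustrated independently of the embedding model. Cosine similarity is the dot product of two vectors divided by the product of their norms, so it depends only on direction, not on length. The sketch below uses NumPy on two-dimensional toy vectors. `cosine_rank` is a hypothetical helper, and the vectors are made up rather than real embeddings.

```python
import numpy as np

def cosine_rank(query: np.ndarray, candidates: np.ndarray) -> tuple:
    """Return (indices sorted by descending cosine similarity, similarities)."""
    query_unit = query / np.linalg.norm(query)
    candidate_units = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    scores = candidate_units @ query_unit  # dot products of unit vectors
    order = np.argsort(-scores)            # highest similarity first
    return order, scores

patient_vec = np.array([1.0, 0.0])
trial_vecs = np.array([
    [0.0, 2.0],  # orthogonal to the patient vector
    [3.0, 0.0],  # same direction, different length
    [1.0, 1.0],  # 45 degrees away
])
order, scores = cosine_rank(patient_vec, trial_vecs)
print(order.tolist())                # [1, 2, 0]
print(np.round(scores, 3).tolist())  # [0.0, 1.0, 0.707]
```

The second toy trial ranks first even though its vector is three times longer than the patient's, which is why cosine similarity is a common choice for comparing embeddings of texts with very different lengths.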




In [29]:
def rank_trials(patient_data, trials_data, embedding_model):
    """
    Rank trials by the similarity between their text embedding and the patient's description.
    """
    # Create patient description for embedding
    patient_description = " ".join([
        f"Age: {patient_data.get('age', '')}",
        f"Gender: {patient_data.get('gender', '')}",
        f"Organ: {patient_data.get('organ', '')}",
        f"Keywords: {' '.join(patient_data.get('keywords', []))}",  # keywords related to the patient's condition
        f"Inclusion Criteria: {patient_data.get('patient_information', '')}",  # free-text patient information, matched against the trials' inclusion criteria
    ])
    
    # Embed the patient's data
    patient_embedding = embedding_model.encode(patient_description, convert_to_tensor=True)

    # Apply filters
    filtered_trials = trials_data.copy()

    # Age filter
    filtered_trials = filter_by_age(filtered_trials, patient_data['age'])

    # Gender filter
    filtered_trials = filter_by_gender(filtered_trials, patient_data['gender'])
    
    # Organ filter
    filtered_trials = filter_by_organ(filtered_trials, patient_data['organ'])

    # Phase filter (applied only if the patient specifies phases)
    filtered_trials = filter_by_phase(filtered_trials, patient_data.get('phase', []))

    if filtered_trials.empty:
        print("No trials found matching the criteria.")
        return pd.DataFrame()

    # Compute similarity scores
    trial_embeddings = np.vstack(filtered_trials['text_embedding'].values)
    print("Computing similarity scores...")
    trial_embeddings_tensor = torch.tensor(trial_embeddings, device=patient_embedding.device)
    similarity_scores = util.pytorch_cos_sim(patient_embedding, trial_embeddings_tensor).cpu().numpy()[0]

    # Add similarity scores to the DataFrame
    filtered_trials = filtered_trials.reset_index(drop=True)
    filtered_trials['similarity_score'] = similarity_scores

    # Rank the trials
    ranked_data = filtered_trials.sort_values(by='similarity_score', ascending=False)

    return ranked_data


In [None]:
patient_data = {
    "gender": "Female",
    "age": 53,
    "organ": "Breast",
    #"phase": ["Phase 2"],  # Optional
    "keywords": [""],  # Optional
    "patient_information": "patient is pregnent",  # Optional
}

# Rank trials based on patient data
ranked_trials = rank_trials(patient_data, data, embedding_model)

# Display Top Trials
top_n = 10
top_trials = ranked_trials.head(top_n)

results_table = top_trials[['nct_id', 'similarity_score']]
# Show the results
for idx, row in top_trials.iterrows():
    print(f"Trial ID: {row['nct_id']}")  
    print(f"Similarity Score: {row['similarity_score']}")
    print(f"Title: {row['brief_title']}")
    print(f"Phase: {row['phase']}")
    print(f"Organ(s): {row['organs']}")
    #print(f"inclusion: {row['inclusion_criteria']}")
    # print(f"Keywords: {row['keywords']}")
    # print('-' * 80)


Computing similarity scores...
Trial ID: NCT00251433
Similarity Score: 0.7048949599266052
Title: GW572016 With Docetaxel and Trastuzumab for the Treatment Of Untreated ErbB2 Over-Expressing Metastatic Breast Cancer
Phase: Phase 1
Organ(s): ['breast']
Trial ID: NCT02759133
Similarity Score: 0.693721354007721
Title: Preoperative Localization of Infraclinical Breast Tumors: Isotopic Localization by iodine125 Seed Versus Standard Localization Using a Metal Wire
Phase: Not Applicable
Organ(s): ['breast']
Trial ID: NCT03498716
Similarity Score: 0.6801649928092957
Title: A Study Comparing Atezolizumab (Anti PD-L1 Antibody) In Combination With Adjuvant Anthracycline/Taxane-Based Chemotherapy Versus Chemotherapy Alone In Patients With Operable Triple-Negative Breast Cancer
Phase: Phase 3
Organ(s): ['breast', 'skin']
Trial ID: NCT00470236
Similarity Score: 0.6764039993286133
Title: Radiation Doses and Fractionation Schedules in Non-low Risk Ductal Carcinoma In Situ (DCIS) of the Breast
Phase: No

In [36]:
results_table

index,nct_id,similarity_score
0,NCT00251433,0.704895
172,NCT02759133,0.693721
398,NCT03498716,0.680165
3,NCT00470236,0.676404
433,NCT03584334,0.674406
175,NCT02775903,0.673988
336,NCT03321981,0.672651
54,NCT01977274,0.670022
14,NCT01358877,0.662582
92,NCT02285179,0.650272
