# RAG on clinical trial data

ClinicalTrials.gov is a publicly accessible database maintained by the U.S. National Library of Medicine (NLM) at the National Institutes of Health (NIH). It provides information on both privately and publicly funded clinical studies conducted around the world. We can download the database of clinical trials [here](https://classic.clinicaltrials.gov/api/gui/ref/download_all).
 
Each trial is captured as a JSON file that mixes useful semantic context (the trial description, etc.) with trial metadata (such as the trial ID `NCTId` and the trial phases). This data is well suited to semantic search with metadata filtering: some fields are indexed for semantic search and others are used as metadata filters.
 
We will use `JSONLoader` (docs [here](https://python.langchain.com/docs/modules/data_connection/document_loaders/json)) to load the JSON data and specify metadata fields by passing the `extract_metadata` function to `metadata_func`.

## Data Loading

Download the database of clinical trials [here](https://classic.clinicaltrials.gov/api/gui/ref/download_all).

There are `471,942` JSON files in the downloaded directory `AllAPIJSON` and its subdirectories.
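
A count like this can be checked with a short directory walk. The following is a minimal sketch; the helper name `count_json_files` and the relative path in the commented-out call are assumptions for illustration, not part of the original notebook.

```python
import os

def count_json_files(root: str) -> int:
    """Count files ending in .json under root, including all subdirectories."""
    total = 0
    for _dirpath, _dirnames, filenames in os.walk(root):
        total += sum(1 for name in filenames if name.endswith(".json"))
    return total

# Hypothetical location of the unzipped download; adjust to your machine.
# print(count_json_files("AllAPIJSON"))
```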

Let's grab 5 from one of the sub-directories as a sample.

In [None]:
# Define the directory path, selecting one of the sub-directories to read from
import os 
path_to_db = "/Users/rlm/Desktop/Clinical-Trials/AllAPIJSON/NCT0000xxxx"

# List all files in the directory
all_files = os.listdir(path_to_db)

# Filter only JSON files
json_files = [file for file in all_files if file.endswith('.json')]

# Sort and select the first 5 JSON files (you can customize the sorting if needed)
first_5_jsons = sorted(json_files)[:5]

Let's look at the structure of each record.

In [17]:
import json
import pandas as pd

# first_5_jsons holds file names, so join the first one with the directory path
with open(os.path.join(path_to_db, first_5_jsons[0]), 'r') as f:
    data = json.load(f)

# Normalize the JSON data to a DataFrame
df = pd.json_normalize(data)

# Print the columns to understand the structure
print(df.columns)

Index(['FullStudy.Rank',
       'FullStudy.Study.ProtocolSection.IdentificationModule.NCTId',
       'FullStudy.Study.ProtocolSection.IdentificationModule.OrgStudyIdInfo.OrgStudyId',
       'FullStudy.Study.ProtocolSection.IdentificationModule.SecondaryIdInfoList.SecondaryIdInfo',
       'FullStudy.Study.ProtocolSection.IdentificationModule.Organization.OrgFullName',
       'FullStudy.Study.ProtocolSection.IdentificationModule.Organization.OrgClass',
       'FullStudy.Study.ProtocolSection.IdentificationModule.BriefTitle',
       'FullStudy.Study.ProtocolSection.StatusModule.StatusVerifiedDate',
       'FullStudy.Study.ProtocolSection.StatusModule.OverallStatus',
       'FullStudy.Study.ProtocolSection.StatusModule.ExpandedAccessInfo.HasExpandedAccess',
       'FullStudy.Study.ProtocolSection.StatusModule.StudyFirstSubmitDate',
       'FullStudy.Study.ProtocolSection.StatusModule.StudyFirstSubmitQCDate',
       'FullStudy.Study.ProtocolSection.StatusModule.StudyFirstPostDateStruct.Stud

We can [load each JSON record](https://python.langchain.com/docs/modules/data_connection/document_loaders/json).

`JSONLoader` passes the parsed JSON data (a dictionary) to `extract_metadata`, which returns the metadata to attach to the document.

Here, it corresponds to the `ProtocolSection` of the study.

The function first attempts to access `IdentificationModule` within the `sample` dictionary using the `get` method.

If `IdentificationModule` is present, `get` returns its value; otherwise it returns an empty dictionary (`{}`).

Next, the function attempts to access `NCTId` from the previously fetched value.

If `NCTId` is present, its value is returned; otherwise `None` is returned.

We can perform this for each desired metadata field.
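
The two-step lookup described above can be tried on plain dictionaries before involving the loader. This is a small sketch: the helper `get_nct_id` and the two trimmed sample records are made up for illustration and are not real trial data.

```python
from typing import Any, Dict, Optional

def get_nct_id(sample: Dict[str, Any]) -> Optional[str]:
    # First lookup: fall back to {} so the second .get() is always valid
    identification = sample.get("IdentificationModule", {})
    # Second lookup: fall back to None when the field is absent
    return identification.get("NCTId", None)

# Hypothetical, heavily trimmed records for illustration only
complete = {"IdentificationModule": {"NCTId": "NCT00000102"}}
missing_module = {"DesignModule": {"StudyType": "Interventional"}}

print(get_nct_id(complete))        # NCT00000102
print(get_nct_id(missing_module))  # None
```

Note that the `{}` default only covers a missing key. If a record stored `IdentificationModule` with an explicit `null` value, the first `get` would return `None` and the second call would fail.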

In [3]:
from typing import Any, Dict
from langchain.docstore.document import Document
from langchain.document_loaders import JSONLoader

def extract_metadata(sample: Dict[str, Any], default_metadata: Dict[str, Any]) -> Dict[str, Any]:
    nctid = sample.get('IdentificationModule', {}).get('NCTId', None)
    study_type = sample.get('DesignModule', {}).get('StudyType', None)
    phase_list = sample.get('DesignModule', {}).get('PhaseList', {}).get('Phase', None)
    metadata = {
        **default_metadata,
        'NCTId': nctid,
        'StudyType': study_type,
        'PhaseList': str(phase_list)  # Stringify the list: metadata values need to be str
    }
    return metadata

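def load_json(file_path: str) -> Document:
    """Load one trial JSON file and return its ProtocolSection as a Document."""
    loader = JSONLoader(
        file_path=file_path,
        jq_schema='.FullStudy.Study.ProtocolSection',
        metadata_func=extract_metadata,
        text_content=False
    )
    # The jq schema selects a single object, so we take the first Document
    data = loader.load()
    return data[0]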

Read the JSON objects.

In [4]:
# Process each of the selected JSON files
clinical_trial_data_sample = []
for json_file in first_5_jsons:
    file_path = os.path.join(path_to_db, json_file)
    data = load_json(file_path)
    clinical_trial_data_sample.append(data)

Split the documents into chunks for embedding and storage in the vectorstore.

In [5]:
# Split documents
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=0)
all_splits = text_splitter.split_documents(clinical_trial_data_sample)

In [6]:
all_splits[0].metadata

{'source': '/Users/rlm/Desktop/GENE-workshop/AllAPIJSON/NCT0000xxxx/NCT00000102.json',
 'seq_num': 1,
 'NCTId': 'NCT00000102',
 'StudyType': 'Interventional',
 'PhaseList': "['Phase 1', 'Phase 2']"}
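
Because `extract_metadata` stores `PhaseList` via `str(phase_list)`, the value shown above is a string that merely looks like a list. If the individual phases are needed later, one option is to parse it back with `ast.literal_eval`. This is a sketch; the helper name `parse_phase_list` is an assumption and not part of the original notebook.

```python
import ast
from typing import List

def parse_phase_list(raw: str) -> List[str]:
    """Turn the stringified PhaseList metadata back into a list of phases."""
    value = ast.literal_eval(raw)
    if value is None:  # str(None) == "None" when a trial has no PhaseList
        return []
    return list(value)

print(parse_phase_list("['Phase 1', 'Phase 2']"))  # ['Phase 1', 'Phase 2']
print(parse_phase_list("None"))                    # []
```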

Embed the splits and add them to the vector database.

We persist this vectorstore to the `./vectorstore` directory for future use.

In [7]:
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
vectorstore = Chroma.from_documents(
    documents=all_splits,
    collection_name="rag-biomedical",
    persist_directory='./vectorstore',
    embedding=OpenAIEmbeddings(),
)

In [8]:
# Test
retriever = vectorstore.as_retriever()
docs = retriever.get_relevant_documents("What was the focus of trial NCT00000102?")
[doc.metadata["NCTId"] for doc in docs]

['NCT00000104', 'NCT00000105', 'NCT00000102', 'NCT00000107']

We get a mix of results: semantic search alone returns chunks from several trials, not only from `NCT00000102`.

[Chroma](https://python.langchain.com/docs/integrations/vectorstores/chroma) allows for metadata filtering. 

* [LangChain guide](https://python.langchain.com/docs/integrations/vectorstores/chroma)
* [Chroma guide](https://docs.trychroma.com/usage-guide#filtering-by-metadata)
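
Chroma filters are plain dictionaries keyed by metadata field. Below is a minimal sketch of a helper that builds one from exact-match conditions. The name `build_where` is hypothetical, and wrapping several conditions in `$and` reflects our understanding of Chroma's filter syntax, so check it against the Chroma guide above before relying on it.

```python
from typing import Any, Dict

def build_where(conditions: Dict[str, Any]) -> Dict[str, Any]:
    """Build a Chroma-style `where` filter from exact-match conditions."""
    clauses = [{field: value} for field, value in conditions.items()]
    if not clauses:
        raise ValueError("at least one condition is required")
    if len(clauses) == 1:
        return clauses[0]
    # Assumption: multiple conditions are combined with the $and operator
    return {"$and": clauses}

print(build_where({"NCTId": "NCT00000102"}))
# {'NCTId': 'NCT00000102'}
print(build_where({"StudyType": "Interventional", "NCTId": "NCT00000102"}))
# {'$and': [{'StudyType': 'Interventional'}, {'NCTId': 'NCT00000102'}]}
```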

In [21]:
docs = vectorstore.get(where={"NCTId": "NCT00000102"})
docs['metadatas']

[{'NCTId': 'NCT00000102',
  'PhaseList': "['Phase 1', 'Phase 2']",
  'StudyType': 'Interventional',
  'seq_num': 1,
  'source': '/Users/rlm/Desktop/GENE-workshop/AllAPIJSON/NCT0000xxxx/NCT00000102.json'},
 {'NCTId': 'NCT00000102',
  'PhaseList': "['Phase 1', 'Phase 2']",
  'StudyType': 'Interventional',
  'seq_num': 1,
  'source': '/Users/rlm/Desktop/GENE-workshop/AllAPIJSON/NCT0000xxxx/NCT00000102.json'},
 {'NCTId': 'NCT00000102',
  'PhaseList': "['Phase 1', 'Phase 2']",
  'StudyType': 'Interventional',
  'seq_num': 1,
  'source': '/Users/rlm/Desktop/GENE-workshop/AllAPIJSON/NCT0000xxxx/NCT00000102.json'}]

We can build a retriever that reasons about metadata filter(s) from the user question.

In [22]:
from langchain.llms import OpenAI
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chains.query_constructor.base import AttributeInfo

# Provide context about the metadata
metadata_field_info = [
    
    AttributeInfo(
        name="NCTId",
        description="The unique identifier assigned to each clinical trial when registered on ClinicalTrials.gov. ",
        type="string",
    ),
    AttributeInfo(
        name="StudyType",
        description="The nature of the study, indicating whether participants receive specific interventions or are merely observed for specific outcomes.",
        type="string",
    ),
    AttributeInfo(
        name="PhaseList",
        description="This pertains to the phase of the study in drug trials.",
        type="string",
    )
]

# Overall context for the data
document_content_description = "Information about clinical trials on ClinicalTrials.gov"

# LLM
llm = OpenAI(temperature=0)

# Retriever
retriever_self_query = SelfQueryRetriever.from_llm(
    llm, 
    vectorstore, 
    document_content_description, 
    metadata_field_info, 
    verbose=True
)

In [23]:
docs = retriever_self_query.get_relevant_documents("What was the focus of trial NCT00000102?")
[doc.metadata["NCTId"] for doc in docs]

['NCT00000102', 'NCT00000102', 'NCT00000102']
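
For comparison, when a question names a trial ID explicitly, a similar filter can be derived without an LLM. The sketch below is a rule-based stand-in, not how `SelfQueryRetriever` works internally. It assumes IDs follow the `NCT` plus eight digits pattern seen in this dataset, and the helper name `filter_from_question` is hypothetical.

```python
import re
from typing import Dict, Optional

# Assumption: trial IDs look like "NCT" followed by exactly eight digits
NCT_ID_PATTERN = re.compile(r"\bNCT\d{8}\b")

def filter_from_question(question: str) -> Optional[Dict[str, str]]:
    """Return an exact-match metadata filter if the question names a trial ID."""
    match = NCT_ID_PATTERN.search(question)
    if match is None:
        return None
    return {"NCTId": match.group(0)}

print(filter_from_question("What was the focus of trial NCT00000102?"))
# {'NCTId': 'NCT00000102'}
print(filter_from_question("Which trials studied hypertension?"))
# None
```

The resulting dictionary has the same shape as the `where` argument passed to `vectorstore.get` earlier. Unlike the self-query retriever, this approach cannot infer filters such as study type or phase from free-form wording.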