In [1]:
import os
from pinecone import Pinecone, ServerlessSpec
from mistralai.client import MistralClient
from dotenv import load_dotenv, find_dotenv
import html2text
import xml.etree.ElementTree as ET
import requests
import itertools
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\dlind\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
# Load variables from the .env file located by find_dotenv() into the process
# environment, then read the two API keys. os.getenv returns None when a
# variable is unset, so a missing key only surfaces when a client uses it.
load_dotenv(find_dotenv())
pc_api_key = os.getenv("PINECONE_API_KEY")
ms_api_key = os.getenv("MISTRAL_API_KEY")

In [3]:
# Module-level service clients; the functions defined in later cells read
# `pc` and `mistral_client` as globals.
pc = Pinecone(api_key=pc_api_key)
mistral_client = MistralClient(api_key=ms_api_key)

### Fetching the XML file

In [4]:
def download_xml(url, path, timeout=30):
    """Download the document at ``url`` and save it to ``path``.

    The response body is written as raw bytes so the file keeps exactly the
    encoding announced in its own XML declaration. Going through
    ``response.text`` decodes with the charset guessed from the HTTP headers
    (requests falls back to ISO-8859-1 for ``text/*`` responses that omit a
    charset), which can corrupt non-ASCII characters once the text is
    re-encoded as UTF-8.

    Args:
        url: Address of the XML file to fetch.
        path: Destination file path; an existing file is overwritten.
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests.get`` may block indefinitely.

    Returns:
        True if the file was downloaded and written, False if the HTTP
        request failed (connection error, timeout, or error status).
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as error:
        # Keep the boolean contract, but do not hide why the download failed.
        print(f"Download of {url} failed: {error}")
        return False

    # Only touch the destination once a good response is in hand.
    with open(path, 'wb') as file:
        file.write(response.content)
    return True

# Local file the downloaded corpus is saved to and later parsed from.
XML_PATH = "medlineplus.xml"
# Dated snapshot of the MedlinePlus topics XML, which pins the corpus version.
xml_url = 'https://medlineplus.gov/xml/mplus_topics_2024-03-30.xml'
# Last expression of the cell, so the True/False result is displayed.
# NOTE(review): a False result is not acted on; later cells would then parse
# whatever file (if any) already exists at XML_PATH.
download_xml(xml_url, XML_PATH)


True

In [5]:
def html_to_markdown(html_content):
    """Convert an HTML string to Markdown using html2text.

    The converter is configured with ``ignore_links = True`` before the
    conversion runs.

    Args:
        html_content: HTML markup as a string.

    Returns:
        Whatever ``HTML2Text.handle`` produces for the input.
    """
    converter = html2text.HTML2Text()
    converter.ignore_links = True
    markdown = converter.handle(html_content)
    return markdown

In [6]:
def mediplus_xml_to_dicts(xml_path, language='English'):
    """Parse a MedlinePlus topics XML file into one dict per health topic.

    Args:
        xml_path: Path (or file object) of the XML document, as accepted by
            ``ET.parse``.
        language: Only ``health-topic`` elements whose ``language``
            attribute equals this value are kept. Defaults to 'English'.

    Returns:
        A list of dicts with the keys "Title", "Summary", "Alternate Names",
        "Groups" and "Related Topics". The summary is passed through
        ``html_to_markdown``; when a topic has no ``full-summary`` element
        the placeholder 'No summary available' is converted instead. Each of
        the three list fields falls back to ``['None']`` when empty.
    """
    def child_texts(parent, tag):
        # An empty element (e.g. <also-called/>) has text None; skip it so
        # the metadata lists only ever contain real strings.
        return [child.text for child in parent.findall(tag) if child.text]

    root = ET.parse(xml_path).getroot()
    themes = []

    for health_topic in root.iter('health-topic'):
        if health_topic.get('language') != language:
            continue

        # Look the summary element up once and reuse it.
        summary_element = health_topic.find('full-summary')
        if summary_element is None:
            summary = 'No summary available'
        else:
            summary = "".join(summary_element.itertext())

        themes.append({
            "Title": health_topic.get('title'),
            "Summary": html_to_markdown(summary),
            "Alternate Names": child_texts(health_topic, 'also-called') or ['None'],
            "Groups": child_texts(health_topic, 'group') or ['None'],
            "Related Topics": child_texts(health_topic, 'related-topic') or ['None'],
        })

    return themes

In [7]:
# Parse the downloaded XML into one dict per English health topic.
mediplus_dicts = mediplus_xml_to_dicts(XML_PATH)

In [8]:
from langchain.text_splitter import MarkdownTextSplitter

# Preview cell: split a single summary to inspect what the chunks look like.
# chunk_size / chunk_overlap are in the splitter's own units (presumably
# characters — TODO confirm against the MarkdownTextSplitter docs).
markdown_splitter = MarkdownTextSplitter(chunk_size=512, chunk_overlap=64)
# 510 is an arbitrary sample position (the HPV topic, per the output below).
docs = markdown_splitter.create_documents([mediplus_dicts[510]["Summary"]])
docs

[Document(page_content='### What is HPV?\n\nHuman papillomavirus (HPV) is a group of related viruses. They can cause warts\non different parts of your body. There are more than 200 types. About 40 of\nthem are spread through direct sexual contact with someone who has the virus.\nThey can also spread through other intimate, skin-to-skin contact. Some of\nthese types can cause cancer.'),
 Document(page_content='There are two categories of sexually transmitted HPV. Low-risk HPV can cause\nwarts on or around your genitals, anus, mouth, or throat. High-risk HPV can\ncause various cancers:\n\n  * Cervical cancer\n  * Anal cancer\n  * Some types of oral and throat cancer\n  * Vulvar cancer\n  * Vaginal cancer\n  * Penile cancer'),
 Document(page_content="Most HPV infections go away on their own and don't cause cancer. But sometimes\nthe infections last longer. When a high-risk HPV infection lasts for many\nyears, it can lead to cell changes. If these changes are not treated, they may\nget wor

In [9]:
# NOTE(review): this import and splitter repeat the preview cell above with
# the same settings; split_and_retain_metadata below reads this global.
from langchain.text_splitter import MarkdownTextSplitter

markdown_splitter = MarkdownTextSplitter(chunk_size=512, chunk_overlap=64)

def split_and_retain_metadata(documents, text_key="Summary"):
    """Split every document's text into chunks while keeping its other fields.

    Each input dict is expanded into one dict per chunk produced by the
    module-level ``markdown_splitter``. A chunk dict is a shallow copy of the
    source dict in which ``text_key`` holds the chunk text and
    "Chunk Index" holds the chunk's position within its source document.

    Args:
        documents: Iterable of dicts that all contain ``text_key``.
        text_key: Key of the text field to split.

    Returns:
        A flat list of chunk dicts, in document order and then chunk order.
    """
    pieces = []
    for source in documents:
        fragments = markdown_splitter.create_documents([source[text_key]])
        for position, fragment in enumerate(fragments):
            piece = dict(source)
            piece[text_key] = fragment.page_content
            piece["Chunk Index"] = position
            pieces.append(piece)
    return pieces

# Chunk every topic summary; each chunk record keeps its topic's metadata.
chunked_mediplus_docs = split_and_retain_metadata(mediplus_dicts)

### Initializing the index

In [10]:
def create_index(index_name, dimension, metric='cosine', cloud='aws',
                 region='us-west-2', ready_timeout=60):
    """(Re)create a serverless Pinecone index and return its statistics.

    An existing index with the same name is deleted first, so every call
    starts from an empty index.

    Args:
        index_name: Name of the index to create.
        dimension: Dimension passed to ``pc.create_index``.
        metric: Similarity metric passed to ``pc.create_index``.
        cloud: Cloud provider for the ``ServerlessSpec``.
        region: Region for the ``ServerlessSpec``.
        ready_timeout: Maximum number of seconds to wait for the new index
            to report itself ready.

    Returns:
        The result of ``describe_index_stats()`` on the new index.

    Raises:
        TimeoutError: If the index is not ready within ``ready_timeout``
            seconds.
    """
    import time  # local import: only this function needs it

    if index_name in pc.list_indexes().names():
        pc.delete_index(index_name)

    # create a new index
    pc.create_index(
        index_name,
        dimension=dimension,
        metric=metric,
        spec=ServerlessSpec(cloud=cloud, region=region),
    )

    # Poll until the index reports ready (the pattern shown in Pinecone's
    # own examples) so that callers can upsert right after this returns.
    deadline = time.monotonic() + ready_timeout
    while not pc.describe_index(index_name).status['ready']:
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Index {index_name!r} was not ready after {ready_timeout} seconds"
            )
        time.sleep(1)

    # Return the stats instead of discarding them so the caller can see them.
    return pc.Index(index_name).describe_index_stats()

In [11]:
INDEX_NAME = "mediplus-corpus"
# Re-running this cell deletes any existing index of this name first.
# 1024 must equal the length of the vectors stored later — presumably the
# mistral-embed output size; TODO confirm.
create_index(INDEX_NAME, 1024)

### Embedding the MedlinePlus articles

In [12]:
def batch_embed_mistral(texts, model="mistral-embed", batch_size=50):
    """Lazily embed ``texts`` through the Mistral client, batch by batch.

    Newlines in every text are replaced by spaces before the request. One
    call to ``mistral_client.embeddings`` is made per slice of
    ``batch_size`` texts, and only when the generator reaches that slice.

    Args:
        texts: Sliceable sequence of strings.
        model: Embedding model name forwarded to the client.
        batch_size: Number of texts sent per request.

    Yields:
        The ``embedding`` of each item in each response's ``data``, in
        response order.
    """
    for offset in range(0, len(texts), batch_size):
        flattened = []
        for raw in texts[offset:offset + batch_size]:
            flattened.append(raw.replace("\n", " "))

        response = mistral_client.embeddings(model=model, input=flattened)

        yield from (item.embedding for item in response.data)

In [13]:
# After chunking, the "Summary" field of each record holds the chunk text.
summaries = [article["Summary"] for article in chunked_mediplus_docs]
# Consume the generator fully so every batch request is made here.
embeddings = list(batch_embed_mistral(summaries))
# Sanity check: the two counts should match.
print(len(embeddings), "of", len(chunked_mediplus_docs), "documents embedded!")

5371 of 5371 documents embedded!


In [14]:
# Spot-check one chunk record (arbitrary position) and its metadata.
chunked_mediplus_docs[800]

{'Title': 'Carbohydrates',
 'Summary': '* When eating grains, choose mostly whole grains and not refined grains: \n    * Whole grains are foods like whole-wheat bread, brown rice, whole cornmeal, and oatmeal. They offer lots of nutrients that your body needs, like vitamins, minerals, and fiber. To figure out whether a product has a lot of whole grain, check the ingredients list on the package and see if a whole grain is one of the first few items listed.',
 'Alternate Names': ['Carbs'],
 'Groups': ['Food and Nutrition'],
 'Related Topics': ['Carbohydrate Metabolism Disorders',
  'Diabetic Diet',
  'Dietary Fiber',
  'Dietary Proteins',
  'Nutrition'],
 'Chunk Index': 7}

### Upsert embeddings

In [15]:
def upsert_embeddings(index_name, articles, embeddings, batch_size=100):
    """Upsert article chunks and their vectors into a Pinecone index.

    Record ``i`` gets the id ``id-<i>``, the ``i``-th embedding as its
    vector, and a copy of the ``i``-th article as metadata in which the
    "Summary" field is renamed to "text". The input dicts are not modified.

    Args:
        index_name: Name of the target index.
        articles: Dicts that each contain a "Summary" key.
        embeddings: One vector per article, in the same order.
        batch_size: Number of records sent per ``index.upsert`` call.

    Raises:
        ValueError: If ``articles`` and ``embeddings`` differ in length, or
            if ``batch_size`` is smaller than 1.
    """
    articles = list(articles)
    embeddings = list(embeddings)
    # zip() would silently drop the surplus items, leaving part of the
    # corpus out of the index without any error.
    if len(articles) != len(embeddings):
        raise ValueError(
            f"Got {len(articles)} articles but {len(embeddings)} embeddings"
        )
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    index = pc.Index(index_name)

    # Prepare data for upserting: include embedding and adjust metadata
    records = []
    for i, (article, embedding) in enumerate(zip(articles, embeddings)):
        metadata = article.copy()
        metadata["text"] = metadata.pop("Summary")
        records.append((f'id-{i}', embedding, metadata))

    # Batch upsert the data into Pinecone
    for start in range(0, len(records), batch_size):
        index.upsert(vectors=records[start:start + batch_size])

In [16]:
# Push every chunk record together with its vector into the index.
upsert_embeddings(INDEX_NAME, chunked_mediplus_docs, embeddings)

### Evaluate

In [20]:
from langchain_pinecone import PineconeVectorStore
# NOTE(review): `embeddings` is the list of precomputed vectors built in the
# embedding cell, while the `embedding` argument presumably expects an
# embedding-model object — confirm against the langchain-pinecone docs.
# `vectorstore` is not used by the cells visible below, which query `index`
# directly.
vectorstore = PineconeVectorStore(index_name=INDEX_NAME, embedding=embeddings)
# Direct handle to the same index, used for the raw vector query below.
index = pc.Index(INDEX_NAME)

def get_embedding(text, model="mistral-embed"):
    """Embed a single text with the Mistral client and return its vector.

    Newlines are replaced by spaces before the request, the same
    preprocessing ``batch_embed_mistral`` applies to the corpus, so queries
    and documents are embedded consistently.

    Args:
        text: The string to embed.
        model: Embedding model name forwarded to the client. It should be
            the model that was used to embed the indexed documents.

    Returns:
        The ``embedding`` of the first item in the response's ``data``.
    """
    flattened = text.replace("\n", " ")
    response = mistral_client.embeddings(model=model, input=[flattened])
    return response.data[0].embedding

In [23]:
# Sample query used to sanity-check retrieval.
question = "What is Leukemia?"

In [24]:
# Embed the question and fetch the 3 closest records, returning their
# metadata (which includes the chunk text) but not the stored vector values.
index.query(
  vector=get_embedding(question),
  top_k=3,
  include_values=False,
  include_metadata=True
)

{'matches': [{'id': 'id-3139',
              'metadata': {'Alternate Names': ['None'],
                           'Chunk Index': 0.0,
                           'Groups': ['Cancers',
                                      'Blood, Heart and Circulation'],
                           'Related Topics': ['Acute Lymphocytic Leukemia',
                                              'Acute Myeloid Leukemia',
                                              'Childhood Leukemia',
                                              'Chronic Lymphocytic Leukemia',
                                              'Chronic Myeloid Leukemia'],
                           'Title': 'Leukemia',
                           'text': '### What is leukemia?\n'
                                   '\n'
                                   'Leukemia is a term for cancers of the '
                                   'blood cells. Leukemia starts in blood-\n'
                                   'forming tissues such as the bone marro