# Convert NXML files from PMC to a compact JSON format

* Create two empty directories: <b>raw_data</b> and <b>res_data</b>.
* Use the <b>extract_data_from_zip.sh</b> script to extract the NXML files from the <b>pmc-00.tar.gz</b>, <b>pmc-01.tar.gz</b>, <b>pmc-02.tar.gz</b>, and <b>pmc-03.tar.gz</b> archives (run the script from the directory that contains the archives).
* Use the <b>vector_db.ipynb</b> notebook to convert the NXML files to a compact JSON format, stored in the <b>res_data</b> directory.
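* When you finish, you can delete the <b>raw_data</b> directory. (Converted NXML files are already deleted during conversion; only articles without an abstract remain there.)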


In [None]:
import os
import xml.etree.ElementTree as ET
from pathlib import Path
from tqdm import tqdm

def _element_text(elem):
    """Return all text under *elem*, whitespace-normalized; "" when *elem* is None.

    ``elem.text`` stops at the first child element, and iterating ``elem.iter()``
    never visits ``.tail`` text, so inline markup such as ``<italic>`` or
    ``<xref>`` would truncate titles and drop the words that follow it.
    ``itertext()`` yields every text and tail segment in document order.
    """
    if elem is None:
        return ""
    # Join fragments with a space (as before), then collapse runs of whitespace
    # and newlines left over from the source formatting.
    return " ".join(" ".join(elem.itertext()).split())


def extract_text_from_nxml(nxml_file):
    """Extract title, authors, source, abstract and body text from an NXML file.

    Article-level fields (title, authors, publication date, volume/issue,
    abstract) are looked up inside ``<article-meta>``, and journal fields
    (journal title, publisher) inside ``<journal-meta>``. Same-named tags in the
    reference list, e.g. a cited paper's ``<volume>``, therefore cannot leak in.
    If either section is missing, the whole document is searched instead.

    Args:
        nxml_file: Path to the ``.nxml`` file.

    Returns:
        A tuple ``(title, authors, source, abstract, body)`` of strings:
        ``authors`` is a ``"; "``-joined list of "Given Surname" names and
        ``source`` is a ``"; "``-joined list of journal title, publisher,
        ``year-month-day`` date and "Vol. V, Issue I" (parts that are missing
        are omitted). If anything goes wrong the error is printed and five
        empty strings are returned.
    """
    try:
        root = ET.parse(nxml_file).getroot()

        # Scope lookups to the front-matter sections; fall back to the whole tree.
        article_meta = root.find('.//article-meta')
        if article_meta is None:
            article_meta = root
        journal_meta = root.find('.//journal-meta')
        if journal_meta is None:
            journal_meta = root

        title = _element_text(article_meta.find('.//article-title'))

        # Authors: prefer explicit author contributions; otherwise any <name>.
        author_elems = article_meta.findall('.//contrib[@contrib-type="author"]')
        if not author_elems:
            author_elems = article_meta.findall('.//name')
        authors = []
        for author_elem in author_elems:
            given_names = _element_text(author_elem.find('.//given-names'))
            surname = _element_text(author_elem.find('.//surname'))
            full_name = f"{given_names} {surname}".strip()
            if full_name:
                authors.append(full_name)
        authors_str = "; ".join(authors)

        # Source/journal information, in a fixed order.
        source_info = []

        journal_title = _element_text(journal_meta.find('.//journal-title'))
        if journal_title:
            source_info.append(journal_title)

        publisher = _element_text(journal_meta.find('.//publisher-name'))
        if publisher:
            source_info.append(publisher)

        # Only the first <pub-date> is used.
        pub_date = article_meta.find('.//pub-date')
        if pub_date is not None:
            date_parts = [
                _element_text(pub_date.find(f'.//{part}'))
                for part in ('year', 'month', 'day')
            ]
            date_parts = [part for part in date_parts if part]
            if date_parts:
                source_info.append("-".join(date_parts))

        # Issue is only reported together with a volume.
        volume = _element_text(article_meta.find('.//volume'))
        if volume:
            vol_issue = f"Vol. {volume}"
            issue = _element_text(article_meta.find('.//issue'))
            if issue:
                vol_issue += f", Issue {issue}"
            source_info.append(vol_issue)

        source_str = "; ".join(source_info)

        abstract = _element_text(article_meta.find('.//abstract'))
        body = _element_text(root.find('.//body'))

        return title, authors_str, source_str, abstract, body
    except Exception as e:
        # Deliberately best-effort: one malformed file must not abort a batch of
        # tens of thousands, so report it and return empty fields instead.
        print(f"Error parsing {nxml_file}: {e}")
        return "", "", "", "", ""

# Load documents from your TREC medical dataset
import json

def write_medical_documents(data_dir, res_path, max_docs=50):
    """Convert the NXML articles under *data_dir* into one JSON file per subdirectory.

    For each immediate subdirectory ``<name>`` of *data_dir*, every ``*.nxml``
    file is parsed with ``extract_text_from_nxml``. Articles that have an
    abstract are written to ``<res_path>/pmc-<name>.json`` as
    ``{"documents": [{"id", "title", "authors", "source", "abstract"}, ...]}``,
    where ``id`` is the NXML file name without its extension.

    WARNING: this is destructive. Each NXML file that was written to JSON is
    deleted afterwards; files without an abstract are left in place.

    Args:
        data_dir: Directory whose subdirectories contain ``.nxml`` files.
        res_path: Output directory for the JSON files (created if missing).
        max_docs: Currently ignored; the limit was disabled so that every
            article is converted. Kept so existing calls keep working.
    """
    doc_count = 0

    print(f"Loading documents from {data_dir}...")
    print(os.listdir(data_dir))
    os.makedirs(res_path, exist_ok=True)

    for dir_name in tqdm(os.listdir(data_dir), desc="Processing directories"):
        dir_path = os.path.join(data_dir, dir_name)
        if not os.path.isdir(dir_path):
            continue

        documents = []
        files_to_delete = []
        for nxml_file in os.listdir(dir_path):
            if not nxml_file.endswith('.nxml'):
                continue
            file_path = os.path.join(dir_path, nxml_file)
            title, authors_str, source_str, abstract, _body = extract_text_from_nxml(file_path)
            # Only articles with an abstract are kept; a title alone is not enough.
            if not abstract:
                continue
            documents.append({
                "id": os.path.splitext(nxml_file)[0],
                "title": title,
                "authors": authors_str,
                "source": source_str,
                "abstract": abstract,
            })
            files_to_delete.append(file_path)

        # Nothing to save. Writing anyway would overwrite a JSON file from an
        # earlier run with an empty list -- and that run already deleted the
        # source NXML files, so those documents would be lost for good.
        if not documents:
            continue

        doc_count += len(documents)
        json_file = os.path.join(res_path, f'pmc-{dir_name}.json')
        with open(json_file, 'w', encoding='utf-8') as json_f:
            json.dump({"documents": documents}, json_f, indent=2, ensure_ascii=False)

        # Delete the converted sources only after their JSON was written.
        for file_to_delete in files_to_delete:
            try:
                os.remove(file_to_delete)
            except OSError as e:
                # Best-effort cleanup: report and keep going.
                print(f"Error deleting file {file_to_delete}: {e}")

    print(f"Total documents processed: {doc_count}")
    

In [7]:
# Convert the pmc-03 archive with the converter defined in the first code cell
# (write_medical_documents; no overwrite_medical_documents exists in this notebook).
write_medical_documents('/home/student/project/raw_data/pmc-03', '/home/student/project/res_data/pmc-03', max_docs=1000)

Loading documents from /home/student/project/raw_data/pmc-03...
['71', '70', '58', '63', '72', '62', '59', '65', '68', '57', '61', '64', '69', '60', '66', '67', '56']


Processing directories: 100%|██████████| 17/17 [03:11<00:00, 11.25s/it]

Total documents processed: 76224





In [42]:
# check if file exists
# Spot-check whether one sample NXML file is still present in raw_data.
sample_nxml = Path('/home/student/project/raw_data/pmc-00/01/4560455.nxml')
print("File exists" if sample_nxml.exists() else "File does not exist")

File does not exist


# Create a vector database from the res_data JSON files

### Install Milvus and other packages

In [None]:
! pip install -U pymilvus



### Prepare the encoding function

Using the PubMedBERT-based <b>pritamdeka/S-PubMedBert-MS-MARCO</b> model.

In [None]:
# ! pip uninstall -y torch torchvision torchaudio

# # Install a compatible PyTorch build for CUDA 12.6 (cu126 wheel index)
# ! pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126

# # Install transformers after PyTorch is properly installed
# ! pip install transformers

Found existing installation: torch 2.8.0+cu126
Uninstalling torch-2.8.0+cu126:
  Successfully uninstalled torch-2.8.0+cu126
Found existing installation: torchvision 0.23.0+cu126
Uninstalling torchvision-0.23.0+cu126:
  Successfully uninstalled torchvision-0.23.0+cu126
[0mLooking in indexes: https://download.pytorch.org/whl/cu126
Collecting torch
  Using cached https://download.pytorch.org/whl/cu126/torch-2.8.0%2Bcu126-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting torchvision
  Using cached https://download.pytorch.org/whl/cu126/torchvision-0.23.0%2Bcu126-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Downloading https://download.pytorch.org/whl/cu126/torch-2.8.0%2Bcu126-cp310-cp310-manylinux_2_28_x86_64.whl (821.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m821.9/821.9 MB[0m [31m91.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hUsing cached https://download.pytorch.org/whl/cu126/torchvision-0.23.0%2Bcu126-cp310-cp310-manyli

In [1]:
import torch

# Report the installed PyTorch build and whether a CUDA device is visible.
torch_version = torch.__version__
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch_version}")
print(f"CUDA available: {cuda_ok}")

PyTorch version: 2.8.0+cu126
CUDA available: True


In [2]:
from transformers import AutoTokenizer, AutoModel
import torch

# Prefer the GPU when one is visible; the model and its inputs live on `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Tokenizer and model come from the same checkpoint.
# (Alternative checkpoint tried earlier: "neuml/pubmedbert-base-embeddings".)
encoder_checkpoint = "pritamdeka/S-PubMedBert-MS-MARCO"
tokenizer = AutoTokenizer.from_pretrained(encoder_checkpoint)
model = AutoModel.from_pretrained(encoder_checkpoint).to(device)

def encode_text(title, abstract):
    """Embed an article's title and abstract as a single vector.

    Title and abstract are joined with a space, tokenized (truncated to
    ``max_length`` tokens), run through the module-level ``model`` on
    ``device``, and mean-pooled over the last hidden layer's token embeddings.

    Args:
        title: Article title (may be empty).
        abstract: Article abstract.

    Returns:
        A NumPy array on the CPU; for a single text its length is
        ``model.config.hidden_size``.
    """
    # 500 tokens, not the 512 positions BERT supports: the tokenizer already
    # counts [CLS]/[SEP] inside max_length, so the margin is not required, but
    # it is kept so new embeddings match the vectors already ingested.
    margin = 12
    max_length = 512 - margin
    text = f"{title} {abstract}"
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=max_length)

    # Inputs must be on the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Attention-mask-weighted mean pooling: padding positions (which appear
    # only if several texts share a batch) must not drag the average. For a
    # single unpadded text this equals a plain mean over all tokens.
    token_embeddings = outputs.last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)

    # Move back to the CPU for NumPy conversion.
    embeddings = (summed / counts).squeeze().cpu().numpy()
    return embeddings

embedding_dim = model.config.hidden_size
print(f"Embedding dimension: {embedding_dim}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Embedding dimension: 768


In [4]:
from pymilvus import MilvusClient

# A local file path as the URI opens an embedded, file-backed Milvus Lite database.
MILVUS_DB_PATH = "./milvus_pmc.db"
milvus_client = MilvusClient(uri=MILVUS_DB_PATH)

# Name shared by the index, collection and ingestion cells.
collection_name = "pmc_trec_2016"

In [None]:
index_params = milvus_client.prepare_index_params()

# NOTE(review): per the Milvus Lite docs, a local .db file supports only FLAT
# and falls back to it regardless of index_type; IVF_FLAT/nlist only take
# effect on a Milvus server. The metric (COSINE) still applies.
index_params.add_index(
    field_name="vector",       # name of the vector field to index
    index_type="IVF_FLAT",     # type of index to build
    index_name="vector_index", # name of the index
    metric_type="COSINE",      # similarity metric
    params={
        "nlist": 64,           # number of clusters
    },
)

# This cell runs before the collection is created; the next cell passes these
# params to create_collection, which builds the index. Calling create_index on
# a missing collection fails, so only do it when the collection already exists
# without any index.
if milvus_client.has_collection(collection_name) and not milvus_client.list_indexes(collection_name):
    milvus_client.create_index(
        collection_name=collection_name,
        index_params=index_params,
    )


In [None]:
from pymilvus import FieldSchema, CollectionSchema, MilvusClient, DataType

# Collection layout: the article id as a VARCHAR primary key (the NXML file
# stem, so it must fit in 20 characters), the embedding from encode_text, and
# the article metadata as a JSON document.
fields = [
    FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=20),
    FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
    FieldSchema(name="doc", dtype=DataType.JSON),
]
schema = CollectionSchema(fields)

# Create the collection only once, so re-running the notebook neither clashes
# with nor needs to drop the already-populated collection.
if milvus_client.has_collection(collection_name):
    print(f"Collection '{collection_name}' already exists; skipping creation.")
else:
    milvus_client.create_collection(
        collection_name=collection_name,
        schema=schema,
        metric_type="COSINE",
        consistency_level="Bounded",
        # Builds the vector index prepared in the previous cell.
        index_params=index_params,
    )

In [None]:
import os
import json
from tqdm import tqdm

data_folder = '/home/student/project/res_data'
# A fixed list (rather than scanning data_folder) so only the converted archives are ingested.
data_subfolders = ['pmc-00', 'pmc-01', 'pmc-02', 'pmc-03']

for subfolder in data_subfolders:
    print(f"Processing subfolder: {subfolder}")
    subfolder_path = os.path.join(data_folder, subfolder)
    # os.listdir order is arbitrary; sort so the ingestion order is reproducible.
    for file_name in tqdm(sorted(os.listdir(subfolder_path))):
        if not file_name.endswith('.json'):
            continue
        file_path = os.path.join(subfolder_path, file_name)
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                json_data = json.load(file)

            data = []
            for doc in json_data.get("documents", []):
                title = doc.get("title", "")
                abstract = doc.get("abstract", "")
                data.append({
                    "id": doc.get("id", ""),
                    "vector": encode_text(title, abstract),
                    "doc": {
                        "title": title,
                        "abstract": abstract,
                        "authors": doc.get("authors", ""),
                        "source": doc.get("source", ""),
                    },
                })

            if data:
                # upsert, not insert: Milvus does not deduplicate primary keys
                # on insert, so re-running this cell (e.g. after an interruption)
                # would otherwise store already-ingested articles twice.
                milvus_client.upsert(collection_name=collection_name, data=data)
        except Exception as e:
            # Best-effort per file: report and move on, so one bad file does
            # not abort a multi-hour ingestion run.
            print(f"Error processing file {file_path}: {e}")
            continue

Processing subfolder: pmc-00


100%|██████████| 53/53 [2:04:49<00:00, 141.30s/it]  


Processing subfolder: pmc-01


100%|██████████| 49/49 [1:36:02<00:00, 117.60s/it]


Processing subfolder: pmc-02


100%|██████████| 78/78 [2:45:49<00:00, 127.56s/it]  


Processing subfolder: pmc-03


100%|██████████| 73/73 [2:43:14<00:00, 134.18s/it]  


In [None]:
# Remove XML declaration from all XML files
import os
from tqdm import tqdm

def remove_xml_declaration(res_data_folder):
    """Strip a leading ``<?xml ...?>`` declaration from ``.xml`` files in place.

    Every immediate subdirectory of *res_data_folder* is scanned. For each
    ``.xml`` file, a declaration at the very start of the file (plus the line
    break that follows it) is removed and the file is rewritten. Files without
    a leading declaration are left untouched and reported.

    Args:
        res_data_folder: Directory whose subdirectories contain the ``.xml`` files.
    """
    import re

    # Anchored at the start of the file (\A) and applied at most once, so a
    # matching string later in the document is never touched. Accepts any
    # attribute spelling/quoting and either an LF or a CRLF line ending.
    declaration = re.compile(r'\A<\?xml\b[^?]*\?>(?:\r?\n)?')

    for subfolder in os.listdir(res_data_folder):
        subfolder_path = os.path.join(res_data_folder, subfolder)
        if not os.path.isdir(subfolder_path):
            continue
        print(f"Processing subfolder: {subfolder}")

        for file_name in tqdm(os.listdir(subfolder_path)):
            if not file_name.endswith('.xml'):
                continue
            file_path = os.path.join(subfolder_path, file_name)

            # newline='' preserves the file's line endings exactly, so the
            # rewrite does not silently convert CRLF to LF.
            with open(file_path, 'r', encoding='utf-8', newline='') as f:
                content = f.read()

            cleaned_content, removed = declaration.subn('', content, count=1)
            if removed:
                with open(file_path, 'w', encoding='utf-8', newline='') as f:
                    f.write(cleaned_content)
            else:
                print(f"No XML declaration found in: {file_name}")

# Strip leading XML declarations from every .xml file under res_data.
res_data_folder = os.path.join('/home/student/project', 'res_data')
remove_xml_declaration(res_data_folder)

Processing subfolder: pmc-01


100%|██████████| 49/49 [01:58<00:00,  2.41s/it]
100%|██████████| 49/49 [01:58<00:00,  2.41s/it]


Processing subfolder: pmc-00


100%|██████████| 53/53 [02:09<00:00,  2.45s/it]
100%|██████████| 53/53 [02:09<00:00,  2.45s/it]


Processing subfolder: pmc-03


100%|██████████| 73/73 [03:22<00:00,  2.78s/it]
100%|██████████| 73/73 [03:22<00:00,  2.78s/it]


Processing subfolder: pmc-02


100%|██████████| 78/78 [02:17<00:00,  1.76s/it]
100%|██████████| 78/78 [02:17<00:00,  1.76s/it]
