In [1]:
"""Load reports and embed them into the Vector store so they can be searched."""

# Imports & Basic Setup

import re
import sys
import random
from typing import Protocol, List, Optional, Type
import os

import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

from sentence_transformers import SentenceTransformer
from tokenizers import Tokenizer
from transformers import AutoTokenizer, AutoModel
import torch
import pyarrow as pa
import pandas as pd

from pydantic import BaseModel, Field


In [2]:

# Define data models


class Report(BaseModel):
    """Represents the entire text of a radiology report."""
    id: str
    text: str

class Fragment(BaseModel):
    """Represents a chunk of text (one section or smaller) from a report."""
    report_id: str
    section: Optional[str]
    sequence_number: int
    text: str
    vector: Optional[List[float]] = None

In [3]:
class SectionSplitter:
    """
    Splits text into sections based on naive headings:
    e.g. "Header:", "Findings:", "Impression:".
    """
    def __init__(self):
        self.known_sections = ["Header:", "Findings:", "Impression:"]

    def split_into_sections(self, report_text: str):
        """
        Returns a list of tuples (section_label, section_text).
        If no headings are found, returns one tuple with (None, entire_text).
        """
        pattern = r"(" + "|".join(map(re.escape, self.known_sections)) + r")"
        parts = re.split(pattern, report_text)
        results = []
        current_section_label = None
        current_text_chunks = []

        for part in parts:
            part_stripped = part.strip()
            if not part_stripped:
                continue

            if part in self.known_sections:
                # Save the previous chunk
                if current_section_label and current_text_chunks:
                    combined_text = " ".join(current_text_chunks).strip()
                    results.append((current_section_label, combined_text))
                # Update the label
                current_section_label = part_stripped
                current_text_chunks = []
            else:
                current_text_chunks.append(part_stripped)

        # Final chunk
        if current_section_label and current_text_chunks:
            combined_text = " ".join(current_text_chunks).strip()
            results.append((current_section_label, combined_text))

        if not results and report_text.strip():
            # No recognized sections, return entire text
            results.append((None, report_text.strip()))

        return results

    def create_smaller_fragments(self, section_fragments):
        smaller_fragments = []
        for label, text in section_fragments:
            if ':' in text:
                fragments = text.split(':')
                for i in range(1, len(fragments)):
                    fragment_text = fragments[i-1].split()[-1] + ': ' + fragments[i].strip()
                    smaller_fragments.append((label, fragment_text.strip()))
            else:
                smaller_fragments.append((label, text.strip()))
        return smaller_fragments

# NOTE: Maybe refactor this? DRY violation.??
# Function to create fragments from a SINGLE report
def create_fragments_from_report(
    report: Report,
    section_splitter: SectionSplitter
) -> List[Fragment]:
    """
    1) Split the entire report text into sections.
    2) Chunk each section by tokens.
    3) Return a list of Fragment objects.
    """
    fragments = []
    sections = section_splitter.split_into_sections(report.text)
    smaller_fragments = section_splitter.create_smaller_fragments(sections)
    seq_num = 0

    for section_label, section_text in smaller_fragments:
        fragments.append(
            Fragment(
                report_id=report.id,
                section=section_label,
                sequence_number=seq_num,
                text=section_text,
                vector=None
            )
        )
        seq_num += 1

    return fragments
# Function to create fragments from a folder of reports
def create_fragments_from_reports(
    reports: List[Report],
    section_splitter: SectionSplitter
) -> List[Fragment]:
    """
    Process a list of Report objects and return a list of Fragment objects.
    """
    all_fragments = []
    for report in reports:
        fragments = create_fragments_from_report(report, section_splitter)
        all_fragments.extend(fragments)
    return all_fragments



In [4]:
# ------------------------------
# 3. Load the embedding model
# ------------------------------
# Example: "BAAI/bge-en-icl" – you can pick a different one if you want
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-en-icl")
model = AutoModel.from_pretrained("BAAI/bge-en-icl")
def embed_text(text: str):
    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Get the embeddings from the model
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the last hidden state as the embedding
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings.squeeze().tolist()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
def read_reports_from_folder(folder_path: str) -> List[Report]:
    reports = []
    for filename in os.listdir(folder_path):
        if filename.endswith(".txt"):
            with open(os.path.join(folder_path, filename), 'r', encoding='utf-8') as file:
                text = file.read()
                # The "id" of the report is just the filename in this example
                reports.append(Report(id=filename, text=text))
    return reports

reports_dir = "./reports"  # Folder containing .txt files
reports_list = read_reports_from_folder(reports_dir)

In [6]:
# ------------------------------
# 6. Create all fragments
# ------------------------------
splitter = SectionSplitter()
all_fragments = create_fragments_from_reports(reports_list, splitter)


In [7]:
manual_inserts = []
for frag in all_fragments:
    # Manually embed the text
    vec = embed_text(frag.text)

    # Build the LanceDB-compatible object
    # We need a dict with the Pydantic fields:
    doc = {
        "report_id": frag.report_id,
        "section": frag.section,
        "sequence_number": frag.sequence_number,
        "text": frag.text,
        "vector": vec
    }
    manual_inserts.append(doc)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [13]:
print("Finished embedding all fragments.")
print(manual_inserts)


Finished embedding all fragments.
[{'report_id': '3.txt', 'section': 'Header:', 'sequence_number': 0, 'text': 'Chest CT without contrast, conducted to assess lung findings.', 'vector': [0.22026760876178741, 1.1508272886276245, 1.2065538167953491, 4.365844249725342, -2.035968542098999, -1.9974439144134521, -2.4552254676818848, -0.07327519357204437, -0.7504902482032776, 0.8640574216842651, -6.253283500671387, -2.4446847438812256, 2.207645893096924, -4.157927989959717, 5.051506996154785, 0.1632586419582367, 3.948500871658325, -0.38817113637924194, -1.6236820220947266, -4.712957859039307, 2.5640480518341064, -2.609605073928833, -3.4489054679870605, 0.8079095482826233, -2.5157949924468994, 0.43806585669517517, 3.24462628364563, -3.1169192790985107, 1.4397822618484497, -2.9550304412841797, 5.454566955566406, 3.3199005126953125, -3.8266241550445557, 0.07335971295833588, 0.4662316143512726, -0.4557182788848877, -0.24913910031318665, 0.03170201927423477, -3.8224236965179443, -1.0516469478607178

In [9]:
# ------------------------------
# 4. Create LanceDB connection
# ------------------------------
db_path = "./data/lancedb"  # or wherever you'd like
os.makedirs(db_path, exist_ok=True)

db = lancedb.connect(db_path)

In [14]:
table = db.create_table("manual",data=manual_inserts)

In [15]:
table

LanceTable(connection=LanceDBConnection(/Users/yz/dev/embed_reports_lancedb/data/lancedb), name="manual")

In [16]:
print(table.head(10))

pyarrow.Table
report_id: string
section: string
sequence_number: int64
text: string
vector: fixed_size_list<item: float>[4096]
  child 0, item: float
----
report_id: [["3.txt","3.txt","3.txt","3.txt","3.txt","3.txt","3.txt","3.txt","3.txt","3.txt"]]
section: [["Header:","Findings:","Findings:","Findings:","Findings:","Findings:","Findings:","Findings:","Findings:","Impression:"]]
sequence_number: [[0,1,2,3,4,5,6,7,8,9]]
text: [["Chest CT without contrast, conducted to assess lung findings.","Devices/Tubes/Lines: None.
Lungs","Lungs: Persistent 5 mm left lower lobe nodule with no significant growth. Scattered subcentimeter calcified granulomas. No new consolidation or masses. Mild bronchiectasis in the right middle lobe. Clear central airways.
Pleura","Pleura: Trace right pleural effusion. No pneumothorax.
Mediastinum","Mediastinum: Normal size and configuration of the heart. Minimal calcifications in the aortic root. No lymphadenopathy.
Lymph Nodes","Nodes: No axillary or hilar lymphad