In [1]:
import os
import re
import xml.etree.ElementTree as ET
import time
import glob
from PIL import Image
import numpy as np
import torch
import faiss # Using CPU version now
from tqdm import tqdm
import gc
import math
import random
from torch.utils.data import DataLoader as TorchDataLoader
import pandas as pd
import base64
import io
import json # For embedding JS data and parsing LLM output


In [2]:

# Ollama Client Library (optional dependency)
# When the package is missing, `ollama` is bound to None and OLLAMA_AVAILABLE is
# False; the Ollama-backed classes below check that flag before creating a client.
try:
    import ollama
except ImportError:
    ollama = None
    print("WARNING: Ollama library not found. Triple extraction/generation skipped. Install with: pip install ollama")
OLLAMA_AVAILABLE = ollama is not None

# Hugging Face Libraries
from transformers import (
    AutoProcessor, AutoModel, AutoTokenizer,
    pipeline, BitsAndBytesConfig
)
# Sentence Transformers no longer needed for fine-tuning

# Scikit-learn for cosine similarity
from sklearn.metrics.pairwise import cosine_similarity


# Configuration
# ------------------------------------------
# Probe CUDA once; both GPU-related settings below are derived from this value.
_CUDA_AVAILABLE = torch.cuda.is_available()

CONFIG = {
    # --- Input data and output location ---
    "scan_dir": r"D:\NLP apps\Scans",
    "report_dir": r"D:\NLP apps\Reports",
    "num_reports_to_process": 150,  # kept low while testing the KG visualisation
    "max_reports_total": 3999,
    "output_dir": r"D:\NLP apps\radiology_rag_kg_vis_output_v3",

    # --- Triple extraction (Ollama) ---
    "triple_extractor_model": "llama3",
    "ollama_base_url": "http://localhost:11434",  # default Ollama API endpoint

    # --- Embedding model (CLIP, used for retrieval and evaluation) ---
    "embedding_model_name": "openai/clip-vit-base-patch32",

    # --- RAG: generator (Ollama) and retrieval depth ---
    "generator_type": "ollama",
    "ollama_generator_model": "llava-llama3",  # separate model from the triple extractor
    "ollama_num_ctx": 4096,
    "top_k_retrieval": 3,

    # --- Evaluation: weights of the combined score ---
    "eval_embedding_weight": 0.7,         # embedding similarity
    "eval_graph_similarity_weight": 0.3,  # graph similarity

    # --- Hardware ---
    "use_gpu": _CUDA_AVAILABLE,
    "embedding_device": "cuda" if _CUDA_AVAILABLE else "cpu",
    "faiss_use_gpu": False,            # Faiss stays on CPU
    "generator_device": "cpu",         # Ollama runs externally
    "triple_extractor_device": "cpu",  # Ollama runs externally
}
# ------------------------------------------


In [3]:

# --- Helper Functions ---
def cleanup_memory():
    """Run a garbage-collection pass and, when CUDA is available, empty torch's CUDA cache."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()

def extract_text_from_xml(xml_path):
    """Extract a report's text from an XML file as one whitespace-normalised string.

    The direct text of every ``AbstractText``, ``FINDINGS``, ``IMPRESSION``,
    ``REPORT_TEXT`` and ``paragraph`` element is collected, tag by tag in that
    order. If none of them yields any text, the direct text of every element in
    the document is used instead.

    Args:
        xml_path: Path of the XML report to parse.

    Returns:
        The collected text with every whitespace run collapsed to a single space
        (may be an empty string), or ``None`` when the file cannot be read or
        parsed.
    """
    section_tags = ('AbstractText', 'FINDINGS', 'IMPRESSION', 'REPORT_TEXT', 'paragraph')
    try:
        root = ET.parse(xml_path).getroot()

        texts = []
        for tag in section_tags:
            for elem in root.findall(f'.//{tag}'):
                if not elem.text:
                    continue
                cleaned_text = re.sub(r'\s+', ' ', elem.text.strip())
                if cleaned_text:
                    texts.append(cleaned_text)

        if not texts:
            # Fallback: no known section tag carried text, so take whatever
            # direct element text the document has.
            all_text = ' '.join(
                node.text.strip() for node in root.iter() if node.text and node.text.strip()
            )
            if all_text:
                texts.append(re.sub(r'\s+', ' ', all_text).strip())

        return re.sub(r'\s+', ' ', "\n".join(texts)).strip()
    except Exception as exc:
        # Still best-effort (callers treat None as "skip this report"), but say
        # why, so unreadable files are not dropped without a trace.
        print(f"Warning: could not extract text from {xml_path}: {exc}")
        return None

def find_images_for_report(report_id, scan_dir):
    """Locate the front and side scan images that belong to a report.

    Lookup order:
      1. PNG files named ``CXR<id>_*_IM-*-4*`` (front) and ``CXR<id>_*_IM-*-3*`` (side).
      2. If neither exists: the first two PNG files named ``CXR<id>_*``.
      3. If there are still none: the first two ``.jpg`` / ``.jpeg`` files named
         ``CXR<id>_*``.

    Every candidate list is sorted before the first entry is taken, so the
    result does not depend on directory listing order. ``scan_dir`` is
    glob-escaped, so a directory name containing ``[``, ``*`` or ``?`` is
    matched literally.

    Args:
        report_id: Report identifier as it appears in the image file names.
        scan_dir: Directory containing the scan images.

    Returns:
        ``(front_image_path, side_image_path)``; either element is ``None``
        when no matching file was found.
    """
    prefix = os.path.join(glob.escape(scan_dir), f"CXR{report_id}_")

    def _sorted_matches(*name_patterns):
        # All files matching any of the patterns, in a stable (sorted) order.
        found = []
        for name_pattern in name_patterns:
            found.extend(glob.glob(prefix + name_pattern))
        return sorted(found)

    def _first_two(paths):
        # Pad with None so the caller can always unpack two values.
        padded = list(paths[:2]) + [None, None]
        return padded[0], padded[1]

    front_images = _sorted_matches("*_IM-*-4*.[Pp][Nn][Gg]")
    side_images = _sorted_matches("*_IM-*-3*.[Pp][Nn][Gg]")
    front_image_path = front_images[0] if front_images else None
    side_image_path = side_images[0] if side_images else None

    if not front_image_path and not side_image_path:
        front_image_path, side_image_path = _first_two(_sorted_matches("*.[Pp][Nn][Gg]"))
        if not front_image_path and not side_image_path:
            front_image_path, side_image_path = _first_two(
                _sorted_matches("*.[Jj][Pp][Gg]", "*.[Jj][Pp][Ee][Gg]")
            )
    return front_image_path, side_image_path

def encode_image_to_base64(image_path):
    """Return the image at ``image_path`` re-encoded as JPEG and base64-encoded.

    Images that are not already in RGB mode are converted to RGB first.
    Returns the base64 text as a ``str``, or ``None`` (after printing the
    error) if anything goes wrong.
    """
    try:
        with Image.open(image_path) as source_img:
            rgb_img = source_img if source_img.mode == 'RGB' else source_img.convert('RGB')
            jpeg_buffer = io.BytesIO()
            rgb_img.save(jpeg_buffer, format="JPEG")
            return base64.b64encode(jpeg_buffer.getvalue()).decode('utf-8')
    except Exception as e:
        print(f"Error encoding image {image_path}: {e}")
        return None

def calculate_graph_similarity(triples1, triples2):
    """Score two triple collections by the mean of two Jaccard similarities.

    One Jaccard score is computed over the entities (all subjects and objects)
    and one over the predicates; the result is their unweighted average.
    Returns 0.0 when either collection is empty.
    """
    if not (triples1 and triples2):
        return 0.0

    def _entities_and_predicates(triples):
        entities, predicates = set(), set()
        for subj, pred, obj in triples:
            entities.update((subj, obj))
            predicates.add(pred)
        return entities, predicates

    def _jaccard(left, right):
        union_size = len(left | right)
        return len(left & right) / union_size if union_size > 0 else 0

    entities_a, predicates_a = _entities_and_predicates(triples1)
    entities_b, predicates_b = _entities_and_predicates(triples2)
    entity_score = _jaccard(entities_a, entities_b)
    predicate_score = _jaccard(predicates_a, predicates_b)
    # Entities and predicates are weighted equally.
    return (entity_score + predicate_score) / 2.0


In [4]:

# --- Core Classes ---

class DataLoader:
    """Collects report text and the matching scan image paths from disk."""

    def __init__(self, report_dir, scan_dir, num_to_load, max_total):
        self.report_dir = report_dir
        self.scan_dir = scan_dir
        self.max_total = max_total
        # The number of reports to load is capped by max_total.
        self.num_to_load = min(num_to_load, max_total)

    def load_data(self):
        """Build one record per usable report.

        A report is usable when its file name is ``<digits>.xml``, text can be
        extracted from it, and a front image is found for it. Loading stops
        once ``num_to_load`` records have been collected.

        Returns:
            A list of dicts with keys ``report_id``, ``report_path``,
            ``report_text``, ``front_image_path`` and ``side_image_path``.

        Raises:
            FileNotFoundError: If ``report_dir`` contains no XML files.
        """
        xml_glob = os.path.join(self.report_dir, "*.[Xx][Mm][Ll]")
        report_files = sorted(glob.glob(xml_glob))
        if not report_files:
            raise FileNotFoundError(f"No XML reports found in {self.report_dir}")
        print(f"Found {len(report_files)} reports. Processing up to {self.num_to_load}...")

        records = []
        skipped_count = 0
        for report_path in tqdm(report_files, desc="Loading Reports"):
            if len(records) >= self.num_to_load:
                break

            id_match = re.match(r"(\d+)\.[Xx][Mm][Ll]", os.path.basename(report_path), re.IGNORECASE)
            if not id_match:
                # Files that are not named "<digits>.xml" are ignored, not counted as skipped.
                continue
            report_id = id_match.group(1)

            report_text = extract_text_from_xml(report_path)
            if not report_text:
                skipped_count += 1
                continue

            front_img, side_img = find_images_for_report(report_id, self.scan_dir)
            if not front_img:
                # A front image is mandatory; the side image is optional.
                skipped_count += 1
                continue

            records.append({
                "report_id": report_id,
                "report_path": report_path,
                "report_text": report_text,
                "front_image_path": front_img,
                "side_image_path": side_img,
            })

        print(f"Successfully loaded {len(records)} reports with front images. Skipped {skipped_count}.")
        if len(records) < self.num_to_load:
            print(f"Warning: Loaded fewer reports ({len(records)}) than requested.")
        return records


In [5]:

class TripleExtractor:
    """Extracts (subject, predicate, object) triples from report text via an Ollama model.

    ``config`` must contain ``triple_extractor_model``; ``ollama_base_url`` and
    ``ollama_num_ctx`` are optional. When the ``ollama`` package is missing or
    the server cannot be reached, ``self.client`` stays ``None`` and
    ``extract_triples`` returns an empty list.
    """
    def __init__(self, config):
        self.config = config
        self.model_name = config["triple_extractor_model"]
        self.base_url = config.get("ollama_base_url", "http://localhost:11434")
        self.client = None
        self._initialize_client()

    @staticmethod
    def _list_model_names(listing):
        """Return the model names found in a ``client.list()`` response.

        Accepts plain dicts as well as objects exposing ``get``/attributes, and
        reads either the ``name`` or the ``model`` field of each entry, because
        the field name differs between ollama-python releases.
        """
        if hasattr(listing, 'get'):
            entries = listing.get('models')
        else:
            entries = getattr(listing, 'models', None)
        names = []
        for entry in entries or []:
            if hasattr(entry, 'get'):
                name = entry.get('name') or entry.get('model')
            else:
                name = getattr(entry, 'name', None) or getattr(entry, 'model', None)
            if name:
                names.append(str(name))
        return names

    def _initialize_client(self):
        """Initializes Ollama client and warns if the configured model is not pulled."""
        if not OLLAMA_AVAILABLE:
            print("Ollama library not available for Triple Extractor.")
            return
        try:
            print(f"Initializing Ollama client for Triple Extraction (Model: '{self.model_name}') at {self.base_url}...")
            client = ollama.Client(host=self.base_url)
            listing = client.list()  # doubles as the connectivity check
            self.client = client
            print("Ollama client for Triple Extraction initialized.")
        except Exception as e:
            print(f"Error initializing Ollama client for triples: {e}")
            self.client = None
            return

        # The model check is advisory only: a failure to read the listing must
        # not throw away a client that is already connected.
        try:
            available_models = self._list_model_names(listing)
        except Exception as e:
            print(f"Warning: Could not read the Ollama model list: {e}")
            return
        if not any(m.startswith(self.model_name) for m in available_models):
            print(f"Warning: Triple extractor model '{self.model_name}' not found in Ollama. Run `ollama pull {self.model_name}`.")

    def _create_extraction_prompt(self, text):
        """ Creates the prompt for asking Llama3 to extract triples. """
        prompt = f"""
Analyze the following radiology report text. Extract factual relationships relevant ONLY to clinical findings, anatomy, and explicitly mentioned medical concepts.
Present the relationships as a JSON list of lists, where each inner list is a triple: [Subject, Predicate, Object].

Rules:
- Subjects and Objects MUST be specific clinical entities found in the text (e.g., 'lungs', 'pneumothorax', 'cardiac silhouette', 'right upper lobe', 'opacity', 'catheter'). Normalize terms (e.g., lowercase).
- Predicates SHOULD reflect the action or state described in the text, using verbs or short descriptive phrases where possible (e.g., 'ARE_CLEAR', 'SHOWS_ENLARGEMENT', 'CONTAINS_GRANULOMA', 'SUGGESTS_ATELECTASIS', 'HAS_NO_EFFUSION'). Use uppercase snake_case. Prefer predicates derived from the text's verbs.
- Extract ONLY relationships explicitly stated or strongly implied in the text. Do not infer relationships not present.
- Focus ONLY on medical facts relevant to the patient's condition as described. Ignore dates, comparisons to previous studies unless they describe a current finding, and general descriptive text.
- If a finding is explicitly negated (e.g., "no pneumothorax"), use a predicate reflecting negation (e.g., ['chest', 'HAS_NO_PNEUMOTHORAX', 'pneumothorax'] or ['pneumothorax', 'IS_ABSENT', '']).
- Output ONLY the JSON list of lists, nothing else. If no relevant triples are found, output an empty list [].

Radiology Report Text:
\"\"\"
{text}
\"\"\"

JSON Output:
"""
        return prompt

    def extract_triples(self, text):
        """Ask the Ollama model for triples and return the ones that pass validation.

        The model output is searched for the outermost ``[...]`` span, which is
        parsed as JSON. Each 3-item list is normalised (subject/object
        lower-cased, predicate upper-cased with spaces turned into ``_``) and
        kept only if all three parts are non-empty and the predicate contains
        an underscore.

        Args:
            text: Report text to analyse.

        Returns:
            A list of ``(subject, predicate, object)`` tuples; empty when
            ``text`` is empty, the client is unavailable, or the call or the
            parsing fails.
        """
        if not text or self.client is None:
            return []
        prompt = self._create_extraction_prompt(text)
        triples = []
        try:
            # Read the context size from this instance's config (as Generator
            # does) rather than from the module-level CONFIG.
            response = self.client.generate(model=self.model_name, prompt=prompt, stream=False,
                                            options={'temperature': 0.1, 'num_ctx': self.config.get('ollama_num_ctx', 2048)})
            raw_output = (response.get('response') or '').strip()
            try:
                json_start = raw_output.find('[')
                json_end = raw_output.rfind(']') + 1
                if json_start != -1 and json_end > json_start:
                    parsed_output = json.loads(raw_output[json_start:json_end])
                    if isinstance(parsed_output, list):
                        for item in parsed_output:
                            if not (isinstance(item, list) and len(item) == 3):
                                continue
                            subj = str(item[0]).lower().strip()
                            pred = str(item[1]).upper().strip().replace(" ", "_")
                            obj = str(item[2]).lower().strip()
                            # NOTE(review): the prompt offers
                            # ['pneumothorax', 'IS_ABSENT', ''] as a valid negation
                            # form, but the non-empty-object check below always
                            # drops it - confirm which side should change.
                            if subj and pred and obj and pred.isupper() and '_' in pred:
                                triples.append((subj, pred, obj))
            except json.JSONDecodeError:
                print(f"Warning: Failed to parse JSON from LLM output: {raw_output}")
            except Exception as parse_e:
                print(f"Error parsing triples from LLM output: {parse_e}\nOutput: {raw_output}")
        except Exception as e:
            print(f"Error during Ollama call for triple extraction: {e}")
        return triples

    def unload_pipeline(self): pass


In [6]:

class EmbeddingManager:
    """Handles CLIP embeddings and the Faiss index used for retrieval and evaluation.

    The CLIP model, processor and tokenizer are loaded lazily on first use from
    ``config['embedding_model_name']`` and placed on ``config['embedding_device']``
    (default ``"cpu"``). ``report_id_map[i]`` is the report id of row ``i`` of
    the Faiss index.
    """
    def __init__(self, config):
        self.config = config
        self.device = config.get("embedding_device", "cpu")
        print(f"EmbeddingManager (CLIP) using device: {self.device}")
        self.model = None
        self.processor = None
        self.tokenizer = None
        self.faiss_index = None
        self.report_id_map = []
        self.loaded_model_path = None

    def _load_clip_model(self):
        """Loads CLIP model, processor, tokenizer (no-op if already loaded).

        Raises:
            Exception: Re-raises whatever the underlying loaders raise, after
                resetting all model attributes to ``None``.
        """
        model_name_or_path = self.config['embedding_model_name']
        if self.model is not None and self.loaded_model_path == model_name_or_path:
            return
        print(f"Loading CLIP model/processor/tokenizer: {model_name_or_path}...")
        try:
            self.processor = AutoProcessor.from_pretrained(model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.model = AutoModel.from_pretrained(model_name_or_path).to(self.device)
            self.model.eval()
            self.loaded_model_path = model_name_or_path
            print("CLIP components loaded.")
        except Exception as e:
            print(f"Error loading CLIP {model_name_or_path}: {e}")
            self.model = None
            self.processor = None
            self.tokenizer = None
            self.loaded_model_path = None
            raise

    def create_report_text_embeddings(self, reports_data):
        """Generates CLIP embeddings for full report texts (for RAG index).

        Also replaces ``self.report_id_map`` with the ids of ``reports_data``,
        in the same order as the returned rows.

        Args:
            reports_data: Sequence of dicts with ``report_text`` and ``report_id``.

        Returns:
            A float32 numpy array with one L2-normalised row per report, or
            ``None`` if embedding fails (including when there are no reports).

        Raises:
            RuntimeError: If the CLIP model/tokenizer is unavailable after loading.
        """
        self._load_clip_model()
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("CLIP model/tokenizer failed.")
        print(f"Generating CLIP report text embeddings using: {self.loaded_model_path}")
        report_texts = [item["report_text"] for item in reports_data]
        self.report_id_map = [item["report_id"] for item in reports_data]
        batch_size = 128
        all_embeddings_list = []
        try:
            self.model.eval()
            for i in tqdm(range(0, len(report_texts), batch_size), desc="Embedding Reports (CLIP)"):
                batch_texts = report_texts[i:i+batch_size]
                inputs = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(self.device)
                with torch.no_grad():
                    text_features = self.model.get_text_features(**inputs)
                    text_features = torch.nn.functional.normalize(text_features, p=2, dim=1)
                    all_embeddings_list.append(text_features.cpu())
            if not all_embeddings_list:
                raise ValueError("No embeddings generated.")
            embeddings_tensor = torch.cat(all_embeddings_list, dim=0)
            print(f"Generated {embeddings_tensor.shape[0]} CLIP report text embeddings.")
            return embeddings_tensor.numpy().astype('float32')
        except Exception as e:
            print(f"Error during report text embedding: {e}")
            cleanup_memory()
            return None

    # --- Added method for single text embedding ---
    def create_single_text_embedding(self, text: str):
        """Generates CLIP embedding for a single piece of text.

        Returns:
            A float32 numpy array holding one L2-normalised row, or ``None``
            for empty text or on any error.
        """
        if not text:
            return None
        self._load_clip_model()
        if self.model is None or self.tokenizer is None:
            print("Error: CLIP model/tokenizer not loaded for single embedding.")
            return None
        try:
            self.model.eval()
            inputs = self.tokenizer([text], padding=True, truncation=True, return_tensors="pt").to(self.device)
            with torch.no_grad():
                text_features = self.model.get_text_features(**inputs)
                text_features = torch.nn.functional.normalize(text_features, p=2, dim=1)
            return text_features.cpu().numpy().astype('float32')
        except Exception as e:
            print(f"Error generating single text embedding: {e}")
            return None
    # --- End Added method ---

    def embed_query_image(self, image_path):
        """Generates query image embedding using CLIP.

        Returns:
            A float32 numpy array holding one L2-normalised row, or ``None``
            when the model is unavailable, the path is invalid, or embedding fails.
        """
        self._load_clip_model()
        if self.model is None or self.processor is None:
            print("Error: CLIP model/processor unavailable.")
            return None
        if not image_path or not os.path.exists(image_path):
            print(f"Error: Invalid image path: {image_path}")
            return None
        try:
            # Close the file handle as soon as the pixel data has been converted.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            with torch.no_grad():
                inputs = self.processor(images=image, return_tensors="pt").to(self.device)
                image_features = self.model.get_image_features(**inputs)
                image_features = torch.nn.functional.normalize(image_features, p=2, dim=1)
            return image_features.cpu().numpy().astype('float32')
        except Exception as e:
            print(f"Error embedding query image {image_path}: {e}")
            return None

    def build_faiss_index(self, embeddings):
        """Builds a CPU inner-product Faiss index over ``embeddings``.

        The rows are L2-normalised before being added. When ``embeddings`` is
        already a C-contiguous float32 array it is normalised in place.

        Returns:
            The new index (also stored in ``self.faiss_index``), or ``None``
            when there are no embeddings.
        """
        if embeddings is None or embeddings.shape[0] == 0:
            print("Error: No embeddings for Faiss.")
            return None
        # Faiss requires C-contiguous float32 input; this is a no-op (no copy)
        # when the array already has that layout.
        embeddings = np.ascontiguousarray(embeddings, dtype='float32')
        dimension = embeddings.shape[1]
        num_embeddings = embeddings.shape[0]
        print(f"Building Faiss index for {num_embeddings} embeddings (Dim: {dimension})...")
        self.faiss_index = faiss.IndexFlatIP(dimension)
        print("Using CPU for Faiss index.")
        self.config["faiss_use_gpu"] = False
        faiss.normalize_L2(embeddings)
        self.faiss_index.add(embeddings)
        print(f"Faiss index built. Size: {self.faiss_index.ntotal}")
        return self.faiss_index

    def save_index(self, index_path, map_path):
        """Saves Faiss index and map."""
        if self.faiss_index and self.report_id_map:
            print(f"Saving Faiss index ({self.faiss_index.ntotal}) to {index_path}")
            faiss.write_index(self.faiss_index, index_path)
            # NOTE(review): np.save appends ".npy" when map_path lacks that
            # suffix, while load_index looks for map_path verbatim - make sure
            # callers pass a path ending in ".npy".
            print(f"Saving report ID map ({len(self.report_id_map)}) to {map_path}")
            np.save(map_path, np.array(self.report_id_map, dtype=object))
        else:
            print("Index or map empty, nothing to save.")

    def load_index(self, index_path, map_path):
        """Loads Faiss index and map.

        Returns:
            ``True`` when both files were loaded and their sizes agree;
            ``False`` when a file is missing or the sizes differ (in which
            case the index and map are reset).
        """
        if not (os.path.exists(index_path) and os.path.exists(map_path)):
            print(f"Index/map file not found.")
            return False
        print(f"Loading Faiss index from {index_path}")
        self.faiss_index = faiss.read_index(index_path)
        # SECURITY: allow_pickle=True executes pickled code from the file; only
        # load maps written by this pipeline.
        print(f"Loading report ID map from {map_path}")
        self.report_id_map = np.load(map_path, allow_pickle=True).tolist()
        print(f"Loaded index ({self.faiss_index.ntotal}) and map ({len(self.report_id_map)}).")
        # 512 is the dimension this check has always assumed; an optional
        # "embedding_dim" config entry overrides it for other embedding models.
        expected_dim = self.config.get("embedding_dim", 512)
        if self.faiss_index.d != expected_dim:
            print(f"WARNING: Index dim ({self.faiss_index.d}) != expected ({expected_dim}).")
        if self.faiss_index.ntotal != len(self.report_id_map):
            print(f"FATAL: Index size != map size.")
            self.faiss_index = None
            self.report_id_map = []
            return False
        print("Keeping loaded Faiss index on CPU.")
        self.config["faiss_use_gpu"] = False
        return True

    def unload_model(self):
        """Unloads loaded models"""
        if self.model is not None:
            print("Unloading CLIP AutoModel components...")
            self.model = None
            self.processor = None
            self.tokenizer = None
        self.loaded_model_path = None
        cleanup_memory()


In [7]:

class Retriever:
    """Retrieves relevant reports using CLIP embeddings.

    ``embedding_manager`` supplies ``faiss_index`` and ``report_id_map``;
    ``report_lookup`` maps a report id to its record from ``all_reports_data``.
    """
    def __init__(self, embedding_manager, all_reports_data):
        self.embed_manager = embedding_manager
        self.report_lookup = {item['report_id']: item for item in all_reports_data}
        if len(self.embed_manager.report_id_map) > 0 and len(self.report_lookup) != len(self.embed_manager.report_id_map):
            print(f"Retriever Warning: Lookup/map size mismatch. Rebuilding lookup.")
            self.rebuild_lookup(all_reports_data)

    def rebuild_lookup(self, all_reports_data):
        """Rebuild ``report_lookup`` so it only holds ids present in the index map.

        For an id that occurs more than once in ``all_reports_data`` the first
        record wins. Ids in the map without a record are left out.
        """
        # One pass to index the records instead of rescanning the whole list
        # for every id in the map.
        first_record_by_id = {}
        for item in all_reports_data:
            first_record_by_id.setdefault(item['report_id'], item)
        self.report_lookup = {
            report_id: first_record_by_id[report_id]
            for report_id in self.embed_manager.report_id_map
            if first_record_by_id.get(report_id) is not None
        }
        print(f"Rebuilt lookup size: {len(self.report_lookup)}")

    def retrieve(self, query_image_embedding, k):
        """Finds top-k reports using CLIP embeddings (Image to Text).

        Args:
            query_image_embedding: Query vector as a numpy array, either one
                row of shape ``(1, dim)`` or a flat ``(dim,)`` vector. It is
                not modified.
            k: Maximum number of reports to return.

        Returns:
            ``(reports, ids)``: copies of the matching report records, each
            with an added ``retrieval_score``, sorted by descending score, and
            the ids in Faiss result order. Both are empty on any error.
        """
        index = self.embed_manager.faiss_index
        if index is None:
            print("Error: Faiss index not ready.")
            return [], []
        if query_image_embedding is None:
            print("Error: Invalid query embedding.")
            return [], []
        if index.ntotal == 0:
            print("Error: Faiss index empty.")
            return [], []
        k_actual = min(k, index.ntotal)
        try:
            # Normalise a private float32 copy so the caller's array is left untouched.
            query = np.array(query_image_embedding, dtype='float32', order='C')
            if query.ndim == 1:
                query = query.reshape(1, -1)
            faiss.normalize_L2(query)
            distances, indices = index.search(query, k_actual)
        except Exception as e:
            print(f"Error during Faiss search: {e}")
            return [], []
        if len(indices) == 0 or len(distances) == 0 or len(indices[0]) == 0:
            print("Warning: Faiss search returned empty results.")
            return [], []
        if indices[0][0] == -1:
            print("Warning: Faiss search returned no valid neighbors (-1 index).")
            return [], []

        retrieved_reports_data, retrieved_ids = [], []
        id_map = self.embed_manager.report_id_map
        for rank, idx in enumerate(indices[0]):
            # Skip out-of-range slots (Faiss uses -1 for "no neighbour").
            if not (0 <= idx < len(id_map)):
                continue
            report_id = id_map[idx]
            if report_id not in self.report_lookup:
                continue
            report_data = self.report_lookup[report_id].copy()
            report_data['retrieval_score'] = float(distances[0][rank])
            retrieved_reports_data.append(report_data)
            retrieved_ids.append(report_id)
        retrieved_reports_data.sort(key=lambda x: x['retrieval_score'], reverse=True)
        print(f"Retrieved {len(retrieved_reports_data)} reports.")
        return retrieved_reports_data, retrieved_ids


In [8]:

class Generator:
    """Generates radiology reports using Ollama.

    ``config`` must contain ``ollama_generator_model``; ``ollama_base_url`` and
    ``ollama_num_ctx`` are optional. When the ``ollama`` package is missing or
    the server cannot be reached, ``self.client`` stays ``None`` and
    ``generate`` returns an error string.
    """
    def __init__(self, config):
        self.config = config
        self.model_name = config["ollama_generator_model"]  # Use specific key
        self.base_url = config.get("ollama_base_url", "http://localhost:11434")
        self.client = None
        self._initialize_client()

    @staticmethod
    def _list_model_names(listing):
        """Return the model names found in a ``client.list()`` response.

        Accepts plain dicts as well as objects exposing ``get``/attributes, and
        reads either the ``name`` or the ``model`` field of each entry, because
        the field name differs between ollama-python releases.
        """
        if hasattr(listing, 'get'):
            entries = listing.get('models')
        else:
            entries = getattr(listing, 'models', None)
        names = []
        for entry in entries or []:
            if hasattr(entry, 'get'):
                name = entry.get('name') or entry.get('model')
            else:
                name = getattr(entry, 'name', None) or getattr(entry, 'model', None)
            if name:
                names.append(str(name))
        return names

    def _initialize_client(self):
        """Initializes Ollama client and warns if the configured model is not pulled."""
        if not OLLAMA_AVAILABLE:
            print("Ollama library not available for Generator.")
            return
        try:
            print(f"Initializing Ollama client for GENERATION (Model: '{self.model_name}') at {self.base_url}...")
            client = ollama.Client(host=self.base_url)
            listing = client.list()  # doubles as the connectivity check
            self.client = client
            print("Ollama client for GENERATION initialized.")
        except Exception as e:
            print(f"Error initializing Ollama client for generation: {e}")
            self.client = None
            return

        # The model check is advisory only: a failure to read the listing must
        # not throw away a client that is already connected.
        try:
            available_models = self._list_model_names(listing)
        except Exception as e:
            print(f"Warning: Could not read the Ollama model list: {e}")
            return
        if not any(m.startswith(self.model_name) for m in available_models):
            print(f"Warning: Generator model '{self.model_name}' not found in Ollama.")

    def format_prompt(self, image_path, retrieved_reports, retrieved_triples_map):
        """Creates text prompt for Ollama Generator.

        Args:
            image_path: Unused here; kept for interface compatibility.
            retrieved_reports: Report records (``report_id``, ``report_text``,
                optional ``retrieval_score``). The list is not modified.
            retrieved_triples_map: Maps report id to its list of
                ``(subject, predicate, object)`` triples.

        Returns:
            The prompt: 120-character snippets of the retrieved reports in
            descending score order, up to five triples per report, and the
            generation instructions.
        """
        context_str = ""
        if retrieved_reports:
            # Rank a copy so the caller's list keeps its original order.
            ranked_reports = sorted(retrieved_reports, key=lambda x: x.get('retrieval_score', -1), reverse=True)
            context_str += "Context from similar reports (higher score is more similar):\n"
            texts = [f"- {r['report_text'][:120]}..." for r in ranked_reports]
            context_str += "\n".join(texts) + "\n"
            # Add triples context
            context_str += "Extracted facts from similar reports:\n"
            for i, r in enumerate(ranked_reports):
                triples = retrieved_triples_map.get(r['report_id'], [])
                if triples:
                    facts = '; '.join([f'({s}-{p}->{o})' for s, p, o in triples[:5]])  # Show top 5
                    context_str += f"  Report {i+1} Facts: {facts}\n"

        final_prompt = (
            f"{context_str}\nGiven the provided chest X-ray image, and using the context above (report snippets and extracted facts) internally if helpful, "
            "generate a radiology report. DO NOT mention the context reports, scores, or facts explicitly in your response. "
            "The report should contain ONLY a 'Findings:' section and an 'Impression:' section. "
            "Start the report directly with 'Findings:'."
        )
        return final_prompt

    def generate(self, image_path, retrieved_reports, retrieved_triples_map):
        """Generates report text using Ollama API.

        Returns:
            The generated report text, or a string starting with ``"Error"``
            when the client is unavailable, the image path is invalid, the
            image cannot be encoded, or the API call fails.
        """
        if self.client is None:
            print("Error: Ollama generator client not initialized.")
            return "Error: Ollama client not available."
        if not image_path or not os.path.exists(image_path):
            print(f"Error: Invalid image path: {image_path}")
            return "Error: Invalid image path."
        base64_image = encode_image_to_base64(image_path)
        if base64_image is None:
            return "Error: Failed to encode image."
        prompt_text = self.format_prompt(image_path, retrieved_reports, retrieved_triples_map)
        print(f"Sending request to Ollama generator model: {self.model_name}...")
        start_time = time.time()
        generated_text = "Error: Ollama API call failed."
        try:
            ollama_options = {'num_ctx': self.config.get('ollama_num_ctx', 2048)}
            response = self.client.chat(model=self.model_name, messages=[{'role': 'user', 'content': prompt_text, 'images': [base64_image]}], options=ollama_options)
            if response and 'message' in response and 'content' in response['message']:
                generated_text = response['message']['content'].strip()
                end_time = time.time()
                print(f"Ollama Generation took {end_time - start_time:.2f} seconds.")
                print("\n--- Generated Report (Ollama) ---"); print(generated_text); print("---------------------------------\n")
            else:
                print(f"Error: Unexpected response from Ollama: {response}")
                generated_text = "Error: Unexpected Ollama response."
        except Exception as e:
            print(f"Error during Ollama API call: {e}")
            if "connection refused" in str(e).lower():
                print(">>> Is the Ollama server running? <<<")
            generated_text = f"Error during Ollama API call: {e}"
        # Plain return instead of `finally: return`, which would also swallow
        # KeyboardInterrupt/SystemExit raised during the call.
        return generated_text


In [9]:

# --- Visualization Function ---
def create_kg_visualization_html(output_filename, query_info, retrieved_info, generated_info, evaluation_info):
    """Write an HTML page with the RAG result, evaluation and one vis.js graph per report.

    Args:
        output_filename: Destination path of the HTML file.
        query_info: Dict read for ``report_id``, ``image_path`` and
            ``image_html`` (the latter is inserted as raw HTML).
        retrieved_info: Dict read for ``reports`` (list of dicts with
            ``report_id``, ``retrieval_score``, ``report_text``) and
            ``triples_map`` (report id -> list of triples).
        generated_info: Dict read for ``model_name``, ``text`` and ``triples``.
        evaluation_info: Dict read for ``embedding_similarity``,
            ``graph_similarity``, ``combined_similarity``,
            ``ground_truth_text`` and ``ground_truth_triples``. The three
            metrics may be numbers or placeholder strings such as ``"N/A"``.

    Returns:
        None. The page is written to ``output_filename``; a failure to write
        is printed rather than raised.
    """
    # Local import: `html` is not among the notebook's top-level imports.
    from html import escape

    def esc(value):
        """HTML-escape any value after converting it to text."""
        return escape(str(value))

    def fmt_metric(value):
        """Render a metric with 4 decimals, or as escaped text if it is not numeric."""
        try:
            return f"{value:.4f}"
        except (TypeError, ValueError):
            return esc(value)

    def script_json(payload):
        """Serialize to JSON that cannot terminate the surrounding <script> element."""
        return json.dumps(payload).replace("</", "<\\/")

    query_report_id_html = esc(query_info.get("report_id", "N/A"))
    query_image_path_html = esc(query_info.get("image_path", "N/A"))
    query_image_html = query_info.get("image_html", "<p>Image not available.</p>")

    html_content = f"""
<!DOCTYPE html>
<html>
<head>
    <title>Radiology RAG Report with KG Visualizations & Evaluation</title>
    <script type="text/javascript" src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
    <style type="text/css">
        body {{ font-family: sans-serif; line-height: 1.6; margin: 20px; }}
        .kg-container {{ margin-bottom: 30px; padding-bottom: 20px; border-bottom: 1px solid #ddd; }}
        .vis-network {{ width: 95%; height: 450px; border: 1px solid lightgray; margin-top: 10px; }}
        h1, h2, h3 {{ border-bottom: 1px solid #eee; padding-bottom: 5px; }}
        pre {{ white-space: pre-wrap; word-wrap: break-word; background-color: #f8f8f8; border: 1px solid #ddd; padding: 10px; border-radius: 4px; }}
        .evaluation {{ background-color: #eef; padding: 15px; border: 1px solid #dde; border-radius: 5px; margin-top: 20px; }}
        .evaluation strong {{ display: inline-block; min-width: 180px; }}
    </style>
</head>
<body>
    <h1>RAG Report with KG Visualizations & Evaluation</h1>
    <p><strong>Query Report ID:</strong> {query_report_id_html}</p>
    <p><strong>Query Image Path:</strong> {query_image_path_html}</p>
    <div>{query_image_html}</div>

    <h2>Retrieved Reports</h2>
"""
    # sorted() makes a copy so the caller's list keeps its original order.
    retrieved_reports = sorted(
        retrieved_info.get("reports") or [],
        key=lambda r: r.get('retrieval_score', -1),
        reverse=True,
    )
    if retrieved_reports:
        html_content += "<ol>\n"
        for r in retrieved_reports:
            snippet = esc(str(r.get('report_text') or 'N/A')[:150])
            score = fmt_metric(r.get('retrieval_score', 'N/A'))
            html_content += f"<li>ID={esc(r.get('report_id', 'N/A'))}, Score={score}<br>Text Snippet: <em>{snippet}...</em></li>\n"
        html_content += "</ol>\n"
    else:
        html_content += "<p>No reports retrieved.</p>\n"

    html_content += f"<h2>Generated Report (Ollama: {esc(generated_info.get('model_name', 'N/A'))})</h2>\n"
    html_content += f"<pre>{esc(generated_info.get('text', 'Error: Generation failed.'))}</pre>\n"

    # --- Evaluation section ---
    html_content += "<hr><h2>Evaluation vs Ground Truth</h2>\n"
    html_content += "<div class='evaluation'>\n"
    html_content += f"<p><strong>Embedding Similarity (Cosine):</strong> {fmt_metric(evaluation_info.get('embedding_similarity', 'N/A'))}</p>\n"
    html_content += f"<p><strong>KG Triple Similarity (Jaccard):</strong> {fmt_metric(evaluation_info.get('graph_similarity', 'N/A'))}</p>\n"
    html_content += f"<p><strong>Combined Similarity Score:</strong> {fmt_metric(evaluation_info.get('combined_similarity', 'N/A'))}</p>\n"
    # Ground truth text is shown for side-by-side comparison with the generated report.
    html_content += f"<h3>Ground Truth Report Text:</h3>\n<pre>{esc(evaluation_info.get('ground_truth_text', 'N/A'))}</pre>\n"
    html_content += "</div>\n"

    html_content += "<hr><h2>Knowledge Graph Visualizations</h2>"

    def generate_graph_section(div_id, title, triples, group_color='blue'):
        """Build the HTML + script for one graph from (subject, predicate, object) triples."""
        nodes, edges, node_ids = [], [], {}

        def add_node(name):
            # Nodes are de-duplicated case-insensitively; the first spelling is kept as label.
            label = str(name)
            key = label.lower()
            if key not in node_ids:
                node_ids[key] = len(nodes)
                nodes.append({"id": node_ids[key], "label": label})
            return node_ids[key]

        skipped = 0
        for triple in triples or []:
            if not isinstance(triple, (list, tuple)) or len(triple) != 3:
                skipped += 1
                continue
            subj, pred, obj = triple
            subj_id = add_node(subj)
            obj_id = add_node(obj)
            edges.append({"from": subj_id, "to": obj_id, "label": str(pred), "arrows": "to", "color": group_color})
        if skipped:
            print(f"Warning: skipped {skipped} malformed triple(s) for '{title}'.")
        if not nodes:
            nodes.append({"id": 0, "label": "No triples extracted", "color": "grey"})

        nodes_json = script_json(nodes)
        edges_json = script_json(edges)
        section_html = f"""
        <div class="kg-container">
            <h3>{esc(title)}</h3>
            <div id="{div_id}" class="vis-network"></div>
            <script type="text/javascript">
              (function() {{ // IIFE to avoid variable conflicts
                var nodes_{div_id} = new vis.DataSet({nodes_json}); var edges_{div_id} = new vis.DataSet({edges_json});
                var container_{div_id} = document.getElementById('{div_id}'); var data_{div_id} = {{ nodes: nodes_{div_id}, edges: edges_{div_id} }};
                var options_{div_id} = {{ nodes: {{ shape: 'box', size: 16, margin: 10 }}, edges: {{ font: {{ size: 10, align: 'middle' }}, smooth: {{ type: "continuous" }} }}, physics: {{ stabilization: {{ iterations: 150 }}, solver: 'repulsion' }}, interaction: {{ tooltipDelay: 200 }} }};
                var network_{div_id} = new vis.Network(container_{div_id}, data_{div_id}, options_{div_id});
              }})();
            </script>
        </div>"""
        return section_html

    # --- KG for the generated report ---
    html_content += generate_graph_section(div_id="gen_network", title="Generated Report KG",
                                           triples=generated_info.get("triples", []), group_color='blue')

    # --- KG for the ground truth report ---
    html_content += generate_graph_section(div_id="gt_network", title="Ground Truth Report KG",
                                           triples=evaluation_info.get("ground_truth_triples", []), group_color='purple')

    # --- KGs for the retrieved reports (same order as the list above) ---
    retrieved_triples_map = retrieved_info.get("triples_map") or {}
    retrieved_colors = ["red", "green", "orange"]  # cycled if more reports than colors
    for i, report_data in enumerate(retrieved_reports):
        report_id = report_data.get("report_id", "N/A")
        triples = retrieved_triples_map.get(report_id, [])
        html_content += generate_graph_section(div_id=f"retrieved{i+1}_network", title=f"Retrieved Report {i+1} (ID: {report_id}) KG",
                                               triples=triples, group_color=retrieved_colors[i % len(retrieved_colors)])

    html_content += "</body>\n</html>"

    # Saving is best-effort: a write failure is reported, not raised.
    try:
        with open(output_filename, "w", encoding="utf-8") as f:
            f.write(html_content)
        print(f"Saved HTML report with KG visualization to: {output_filename}")
    except Exception as e:
        print(f"Error saving HTML output file {output_filename}: {e}")



In [10]:

# --- Main Execution ---
if __name__ == "__main__":
    # End-to-end demo run: load data, build/load the CLIP Faiss index, retrieve
    # reports for one query image, generate a report with Ollama, extract
    # triples, score against the ground truth and write an HTML page.
    #
    # Early termination uses `raise SystemExit(1)` rather than `exit()`: the
    # latter is an interactive-shell helper and is not meant for program flow.
    print("Starting Radiology RAG System with KG Vis & Evaluation...")
    os.makedirs(CONFIG["output_dir"], exist_ok=True)
    cleanup_memory()

    # --- Check Dependencies ---
    if not OLLAMA_AVAILABLE:
        print("Ollama not available. Cannot proceed.")
        raise SystemExit(1)

    # 1. Load Data
    print("\n--- Stage 1: Loading Data ---")
    data_loader = DataLoader(CONFIG["report_dir"], CONFIG["scan_dir"], CONFIG["num_reports_to_process"], CONFIG["max_reports_total"])
    all_data = data_loader.load_data()
    if not all_data:
        print("No data loaded. Exiting.")
        raise SystemExit(1)
    # Quick lookup of ground truth text by report ID
    ground_truth_lookup = {item['report_id']: item['report_text'] for item in all_data}

    # 2. Initialize Managers
    print("\n--- Stage 2: Initializing Managers ---")
    embed_manager = EmbeddingManager(CONFIG)  # Handles CLIP embeddings
    triple_extractor = TripleExtractor(CONFIG)  # Handles Ollama triple extraction
    if triple_extractor.client is None:
        print("Triple extractor failed to initialize. Exiting.")
        raise SystemExit(1)

    # 3. Create/Load RAG Index using BASE CLIP model
    print("\n--- Stage 3: Creating/Loading RAG Index (Base CLIP) ---")
    index_file = os.path.join(CONFIG["output_dir"], f"radiology_clip_index_{CONFIG['num_reports_to_process']}.faiss")
    map_file = os.path.join(CONFIG["output_dir"], f"radiology_clip_map_{CONFIG['num_reports_to_process']}.npy")
    model_path_for_indexing = CONFIG['embedding_model_name']
    print(f"Using embedding model path for index: {model_path_for_indexing}")
    if embed_manager.load_index(index_file, map_file):
        print("Loaded existing RAG index and map.")
    else:
        print("Building RAG index from scratch...")
        report_text_embeddings_np = embed_manager.create_report_text_embeddings(all_data)
        if report_text_embeddings_np is None or report_text_embeddings_np.shape[0] == 0:
            print("Failed to create report text embeddings. Cannot build RAG index. Exiting.")
            raise SystemExit(1)
        embed_manager.build_faiss_index(report_text_embeddings_np)
        embed_manager.save_index(index_file, map_file)

    # 4. Initialize Retriever
    print("\n--- Stage 4: Initializing Retriever ---")
    retriever = Retriever(embed_manager, all_data)

    # 5. Initialize Generator (Ollama availability was already verified above)
    print("\n--- Stage 5: Initializing Ollama Generator ---")
    generator = None
    try:
        generator = Generator(CONFIG)
        if generator.client is None:
            raise RuntimeError("Ollama client failed.")
    except Exception as e:
        print(f"Failed to initialize Ollama Generator: {e}")
        generator = None
        cleanup_memory()

    # 6. Initialize Triple Extractor (reuse the instance created in stage 2)
    print("\n--- Stage 6: Initializing Triple Extractor for RAG ---")
    triple_extractor_rag = triple_extractor
    extractor_ready = bool(triple_extractor_rag and triple_extractor_rag.client)
    if not extractor_ready:
        print("Triple extractor not available. KG extraction will be skipped.")

    # --- Example RAG Pipeline Execution ---
    print("\n--- Running Example RAG with KG Extraction, Visualization & Evaluation ---")

    # --- Select Query Item ---
    TARGET_REPORT_ID = "173"  # <<< CHANGE THIS ID TO TEST A DIFFERENT REPORT, or set to None for random
    query_item = None
    if TARGET_REPORT_ID:
        query_item = next((item for item in all_data if item['report_id'] == TARGET_REPORT_ID), None)
        if not query_item:
            print(f"Error: Target report ID '{TARGET_REPORT_ID}' not found. Falling back...")
            TARGET_REPORT_ID = None
        elif not query_item.get("front_image_path") or not os.path.exists(query_item["front_image_path"]):
            print(f"Error: Target report ID '{TARGET_REPORT_ID}' image missing. Falling back...")
            TARGET_REPORT_ID = None
            query_item = None
    if query_item is None:  # Fallback if target ID failed or wasn't specified
        valid_items = [item for item in all_data if item.get("front_image_path") and os.path.exists(item["front_image_path"])]
        if not valid_items:
            print("Could not find any valid query items. Exiting.")
            raise SystemExit(1)
        query_item = random.choice(valid_items)
        print(f"Using report ID: {query_item['report_id']}")
    # --- End Select Query Item ---

    query_report_id = query_item["report_id"]
    query_image_path = query_item["front_image_path"]
    ground_truth_text = ground_truth_lookup.get(query_report_id, "")
    print(f"Querying with Report ID: {query_report_id}, Image: {query_image_path}")

    # a. Embed query image using CLIP
    query_embedding = embed_manager.embed_query_image(query_image_path)

    if query_embedding is not None:
        # b. Retrieve relevant reports
        retrieved_reports_data, retrieved_ids = retriever.retrieve(query_embedding, k=CONFIG["top_k_retrieval"])

        # c. Extract triples for retrieved reports (using Ollama Llama3)
        retrieved_triples_map = {}
        if retrieved_ids and extractor_ready:
            print(f"Extracting triples from retrieved reports: {retrieved_ids}")
            for report_data in tqdm(retrieved_reports_data, desc="Extracting Retrieved Triples"):
                triples = triple_extractor_rag.extract_triples(report_data["report_text"])
                if triples:
                    retrieved_triples_map[report_data["report_id"]] = triples
        else:
            print("Skipping triple extraction for retrieved reports.")

        # d. Generate report (using Ollama LLaVA)
        generated_report_text = "Error: Generator not available."
        if generator and generator.client:
            generated_report_text = generator.generate(query_image_path, retrieved_reports_data, retrieved_triples_map)
        else:
            print("Skipping generation as Ollama generator is not available.")
        # The generator reports failures as text. Both prefixes are checked
        # because its exception path can return "Error during Ollama API call: ...".
        generation_failed = generated_report_text.startswith(("Error:", "Error during"))

        # e. Extract triples from generated report AND ground truth report
        generated_triples = []
        ground_truth_triples = []
        if not generation_failed and extractor_ready:
            print("Extracting triples from generated report...")
            generated_triples = triple_extractor_rag.extract_triples(generated_report_text)
        else:
            print("Skipping triple extraction for generated report.")

        if ground_truth_text and extractor_ready:
            print("Extracting triples from ground truth report...")
            ground_truth_triples = triple_extractor_rag.extract_triples(ground_truth_text)
        else:
            print("Skipping triple extraction for ground truth report.")

        # --- f. Perform Evaluation ---
        print("\n--- Performing Evaluation vs Ground Truth ---")
        eval_results = {"embedding_similarity": 0.0, "graph_similarity": 0.0, "combined_similarity": 0.0,
                        "ground_truth_text": ground_truth_text, "ground_truth_triples": ground_truth_triples}

        if not generation_failed and ground_truth_text:
            # i. Embedding Similarity
            print("Calculating embedding similarity...")
            gen_embedding = embed_manager.create_single_text_embedding(generated_report_text)
            gt_embedding = embed_manager.create_single_text_embedding(ground_truth_text)
            if gen_embedding is not None and gt_embedding is not None:
                sim = cosine_similarity(gen_embedding, gt_embedding)[0][0]
                eval_results["embedding_similarity"] = float(sim)  # Ensure plain float
                print(f"  Embedding Similarity: {sim:.4f}")
            else:
                print("  Could not calculate embedding similarity (embedding failed).")
                eval_results["embedding_similarity"] = "Error"

            # ii. Graph (Triple) Similarity
            print("Calculating graph similarity...")
            graph_sim = calculate_graph_similarity(generated_triples, ground_truth_triples)
            eval_results["graph_similarity"] = graph_sim
            print(f"  Graph Similarity (Jaccard): {graph_sim:.4f}")

            # iii. Combined Metric (only when the embedding similarity is a real number)
            if isinstance(eval_results["embedding_similarity"], float):
                combined_sim = (CONFIG["eval_embedding_weight"] * eval_results["embedding_similarity"] +
                                CONFIG["eval_graph_similarity_weight"] * eval_results["graph_similarity"])
                eval_results["combined_similarity"] = combined_sim
                print(f"  Combined Similarity: {combined_sim:.4f}")
            else:
                eval_results["combined_similarity"] = "Error"
                print("  Combined Similarity: Error (due to embedding error)")
        else:
            print("Skipping evaluation because generated report or ground truth is missing/invalid.")
            # Metrics become "N/A"; ground truth text and triples are kept as-is.
            eval_results.update(embedding_similarity="N/A", graph_similarity="N/A", combined_similarity="N/A")

        # Unload models
        print("Unloading models...")
        if triple_extractor_rag:
            triple_extractor_rag.unload_pipeline()
        embed_manager.unload_model()
        cleanup_memory()

        # g. Prepare data for HTML output
        # Encoding may fail and yield None; fall back to a placeholder instead
        # of embedding the literal text "None" in the data URI.
        query_image_b64 = encode_image_to_base64(query_image_path) if os.path.exists(query_image_path) else None
        if query_image_b64:
            query_image_html = f'<img src="data:image/jpeg;base64,{query_image_b64}" alt="Query Image {query_report_id}" style="max-width: 400px; height: auto; border: 1px solid #ccc; margin-bottom: 10px;"><br>'
        else:
            query_image_html = "<p>Query image not found.</p>"
        query_info = {"report_id": query_report_id, "image_path": query_image_path, "image_html": query_image_html}
        retrieved_info = {"reports": retrieved_reports_data, "triples_map": retrieved_triples_map}
        generated_info = {"model_name": CONFIG["ollama_generator_model"], "text": generated_report_text, "triples": generated_triples}

        # h. Create HTML visualization including evaluation
        output_filename_html = os.path.join(CONFIG["output_dir"], f"generated_report_eval_kg_vis_{CONFIG['ollama_generator_model'].replace(':','-')}_{query_report_id}.html")
        create_kg_visualization_html(output_filename_html, query_info, retrieved_info, generated_info, eval_results)

    else:
        print("Failed to generate query image embedding. Cannot proceed.")

    print("\nRadiology RAG+KG Vis System Finished.")
    cleanup_memory()


Starting Radiology RAG System with KG Vis & Evaluation...

--- Stage 1: Loading Data ---
Found 393 reports. Processing up to 150...


Loading Reports:  39%|███▉      | 154/393 [00:01<00:01, 146.29it/s]


Successfully loaded 150 reports with front images. Skipped 4.

--- Stage 2: Initializing Managers ---
EmbeddingManager (CLIP) using device: cuda
Initializing Ollama client for Triple Extraction (Model: 'llama3') at http://localhost:11434...
Ollama client for Triple Extraction initialized.

--- Stage 3: Creating/Loading RAG Index (Base CLIP) ---
Using embedding model path for index: openai/clip-vit-base-patch32
Loading Faiss index from D:\NLP apps\radiology_rag_kg_vis_output_v3\radiology_clip_index_150.faiss
Loading report ID map from D:\NLP apps\radiology_rag_kg_vis_output_v3\radiology_clip_map_150.npy
Loaded index (150) and map (150).
Keeping loaded Faiss index on CPU.
Loaded existing RAG index and map.

--- Stage 4: Initializing Retriever ---

--- Stage 5: Initializing Ollama Generator ---
Initializing Ollama client for GENERATION (Model: 'llava-llama3') at http://localhost:11434...
Ollama client for GENERATION initialized.

--- Stage 6: Initializing Triple Extractor for RAG ---

---

Extracting Retrieved Triples: 100%|██████████| 3/3 [11:07<00:00, 222.46s/it]


Sending request to Ollama generator model: llava-llama3...
Ollama Generation took 65.77 seconds.

--- Generated Report (Ollama) ---
Findings:
- The chest X-ray reveals that the cardiomediastinal silhouette is within normal limits.
- There are no visible signs of pneumothorax or pleural effusion.
- Focal consolidation and rib fractures were not observed.

Impression:
The findings indicate a relatively normal chest scan with no significant abnormalities.
---------------------------------

Extracting triples from generated report...
Extracting triples from ground truth report...

--- Performing Evaluation vs Ground Truth ---
Calculating embedding similarity...
  Embedding Similarity: 0.8741
Calculating graph similarity...
  Graph Similarity (Jaccard): 0.6333
  Combined Similarity: 0.8018
Unloading models...
Unloading CLIP AutoModel components...
Saved HTML report with KG visualization to: D:\NLP apps\radiology_rag_kg_vis_output_v3\generated_report_eval_kg_vis_llava-llama3_173.html

Radiol