In [1]:
# --- Load JSON Config ---
import json
import matplotlib.cm as cm
CONFIG_PATH = "xray_config.json"

with open(CONFIG_PATH, "r") as f:
    CONFIG = json.load(f)

# Model parameters
STANDARD_OVERLAY_SIZE = tuple(CONFIG["model"]["standard_overlay_size"])
WEIGHT = CONFIG["model"]["weight"]

# Thresholds
PATHOLOGY_THRESHOLD = CONFIG["thresholds"]["pathology"]
ANATOMY_THRESHOLD = CONFIG["thresholds"]["anatomy"]

# Disease overlay parameters
DISEASE_BOX_COLOR = tuple(CONFIG["disease_overlay"]["box_color"])
DISEASE_BOX_THICKNESS = CONFIG["disease_overlay"]["box_thickness"]
DISEASE_ALPHA = CONFIG["disease_overlay"]["alpha"]

# Visualization parameters
VISUAL_ALPHAS = CONFIG["visualization"]["alphas"]
VISUAL_COLORMAPS = [getattr(cm, name) for name in CONFIG["visualization"]["colormaps"]]
LABEL_MIN_DIST = CONFIG["visualization"]["label_min_dist"]
LABEL_OFFSET_STEP = CONFIG["visualization"]["label_offset_step"]
LABEL_FONTSIZE = CONFIG["visualization"]["label_fontsize"]

# ------------------------
# XRayProcessor Class
# ------------------------
import os, shutil, datetime
import torch, numpy as np, skimage.io
from skimage.transform import resize
import torchxrayvision as xrv
from safetensors.torch import save_file
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import cv2
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

class ImagePreprocessor:
    @staticmethod
    def load_and_preprocess_for_display(img_path):
        img = skimage.io.imread(img_path)
        if len(img.shape) == 3 and img.shape[2] == 4:
            img = img[..., :3]
        img_normalized = img.astype(np.float32)
        img_normalized = (img_normalized - img_normalized.min()) / (img_normalized.max() - img_normalized.min() + 1e-8)
        return resize(img_normalized, STANDARD_OVERLAY_SIZE, anti_aliasing=True, preserve_range=True)

    @staticmethod
    def load_and_preprocess_for_model(img_path, target_size):
        img = skimage.io.imread(img_path)
        if len(img.shape) == 3:
            img = img.mean(2)
        img = xrv.datasets.normalize(img, 255)
        img = resize(img, (target_size, target_size), anti_aliasing=True, preserve_range=True)
        return img[None, ...]

    @staticmethod
    def convert_dicom_to_jpg(dicom_path, output_dir, jpg_filename):
        import dicom2jpg
        try:
            dicom2jpg.dicom2jpg(dicom_path, output_dir)
            for file in os.listdir(output_dir):
                if file.lower().endswith(".jpg"):
                    shutil.move(os.path.join(output_dir, file), os.path.join(output_dir, jpg_filename))
                    return os.path.join(output_dir, jpg_filename)
            raise FileNotFoundError("No JPG file found after conversion.")
        except Exception as e:
            raise RuntimeError(f"Error converting DICOM to JPG: {e}")

class XRayProcessor:
    def __init__(self, xray_file, output_path, unique_id):
        self.xray_file = xray_file
        self.output_path = output_path
        self.id = str(unique_id)
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        os.makedirs(self.output_path, exist_ok=True)
        print(f"[{now}][API] Output directory ensured: {self.output_path}")
        self.jpg_path = os.path.join(output_path, f"{self.id}.jpg")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"[{now}][API] Using device: {self.device}")
        self.cls_model = xrv.models.DenseNet(weights=WEIGHT).to(self.device).eval()
        self.seg_model = xrv.baseline_models.chestx_det.PSPNet().to(self.device).eval()

    def _get_image_for_processing(self):
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        ext = os.path.splitext(self.xray_file)[-1].lower()
        if ext == ".dcm":
            jpg_path = ImagePreprocessor.convert_dicom_to_jpg(self.xray_file, self.output_path, os.path.basename(self.jpg_path))
            print(f"[{now}][API] Converted DICOM to JPG: {jpg_path}")
            return jpg_path
        elif ext in [".jpg", ".png", ".jpeg"]:
            shutil.copy2(self.xray_file, self.jpg_path)
            print(f"[{now}][API] Copied image: {self.jpg_path}")
            return self.jpg_path
        else:
            raise ValueError("Unsupported file format. Only DICOM, JPG, PNG are supported.")

    def _analyze_pathologies(self, img_path):
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{now}][API] Analyzing image for pathologies: {img_path}")
        img = ImagePreprocessor.load_and_preprocess_for_model(img_path, target_size=224)
        input_tensor = torch.from_numpy(img).unsqueeze(0).float().to(self.device)
        with torch.no_grad():
            outputs = self.cls_model(input_tensor)
        results = dict(zip(self.cls_model.pathologies, map(float, outputs[0].cpu().numpy())))
        print(f"[{now}][API] Pathology analysis complete.")
        return results

    def _generate_disease_overlays(self, img_path, original_image):
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{now}][API] Generating disease overlays.")
        img = ImagePreprocessor.load_and_preprocess_for_model(img_path, target_size=224)
        input_tensor = torch.from_numpy(img).unsqueeze(0).float().to(self.device)
        target_layer = self.cls_model.features.denseblock3.denselayer16.conv2
        cam = GradCAM(model=self.cls_model, target_layers=[target_layer])

        if len(original_image.shape) == 2:
            original_image_rgb = np.stack([original_image]*3, axis=-1)
        else:
            original_image_rgb = original_image.copy()

        overlays = { "original_display_image": torch.from_numpy(original_image_rgb).float() }

        for i, pathology in enumerate(self.cls_model.pathologies):
            grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(i)])[0]
            norm_heatmap = (grayscale_cam - grayscale_cam.min()) / (grayscale_cam.max() - grayscale_cam.min() + 1e-8)
            cam_resized = resize(norm_heatmap, STANDARD_OVERLAY_SIZE, preserve_range=True)
            cam_rgb = np.stack([cam_resized]*3, axis=-1)
            boxed_img = (1-DISEASE_ALPHA)*original_image_rgb + DISEASE_ALPHA*cam_rgb
            mask = cam_resized >= PATHOLOGY_THRESHOLD
            mask_uint8 = (mask.astype(np.uint8)*255)
            contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if contours:
                contour = max(contours, key=cv2.contourArea)
                x, y, w, h = cv2.boundingRect(contour)
                boxed_img_uint8 = (boxed_img*255).astype(np.uint8)
                cv2.rectangle(boxed_img_uint8, (x, y), (x+w, y+h), DISEASE_BOX_COLOR, DISEASE_BOX_THICKNESS)
                overlays[pathology] = torch.from_numpy(boxed_img_uint8/255.0).float()
            else:
                overlays[pathology] = torch.from_numpy(boxed_img).float()
        print(f"[{now}][API] Disease overlays generated.")
        return overlays

    def _generate_anatomical_overlays(self, img_path, original_image):
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{now}][API] Generating anatomical overlays.")
        img = ImagePreprocessor.load_and_preprocess_for_model(img_path, target_size=512)
        input_tensor = torch.from_numpy(img).unsqueeze(0).float().to(self.device)
        overlays = { "original_display_image": torch.from_numpy(original_image).float() }
        with torch.no_grad():
            pred = self.seg_model(input_tensor).cpu().squeeze(0).numpy()
        for i, region in enumerate(self.seg_model.targets):
            heatmap = pred[i]
            norm_heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
            mask = norm_heatmap >= ANATOMY_THRESHOLD
            active_area = norm_heatmap.copy()
            active_area[~mask] = np.nan
            overlays[region] = torch.from_numpy(active_area).float()
        print(f"[{now}][API] Anatomical overlays generated.")
        return overlays

    def process(self):
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        try:
            img_path = self._get_image_for_processing()
            original_image = ImagePreprocessor.load_and_preprocess_for_display(img_path)
            results = self._analyze_pathologies(img_path)
            json_path = os.path.join(self.output_path, f"{self.id}.json")
            with open(json_path, "w") as f:
                json.dump({"id": self.id, "input_img": os.path.basename(self.xray_file), "results": results}, f, indent=4)
            print(f"[{now}][API] Saved pathology results: {json_path}")
            disease_overlays = self._generate_disease_overlays(img_path, original_image)
            save_file(disease_overlays, os.path.join(self.output_path, f"DO-{self.id}.safetensors"))
            anatomical_overlays = self._generate_anatomical_overlays(img_path, original_image)
            save_file(anatomical_overlays, os.path.join(self.output_path, f"AO-{self.id}.safetensors"))
            print(f"[{now}][API] Processing complete for ID: {self.id}")
            return self.id, results
        except Exception as e:
            print(f"[{now}][API] Processing failed: {e}")
            return None

# ------------------------
# XRayVisualizer Class
# ------------------------
from safetensors.torch import load_file
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import matplotlib.patheffects as path_effects

class XRayVisualizer:
    def __init__(self, uuid_str, output_path):
        self.uuid = uuid_str
        self.output_path = output_path
        self.disease_results = None
        self.disease_data = None
        self.anatomical_data = None
        self.base_image_display = None
        self._load_data()

    def _load_data(self):
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        try:
            json_path = os.path.join(self.output_path, f"{self.uuid}.json")
            disease_path = os.path.join(self.output_path, f"DO-{self.uuid}.safetensors")
            anatomy_path = os.path.join(self.output_path, f"AO-{self.uuid}.safetensors")
            with open(json_path, "r") as f:
                self.disease_results = json.load(f)["results"]
            self.disease_data = load_file(disease_path)
            self.anatomical_data = load_file(anatomy_path)
            self.base_image_display = self.disease_data["original_display_image"].cpu().numpy()
            if self.base_image_display.ndim == 2:
                self.base_image_display = np.stack([self.base_image_display]*3, axis=-1)
            print(f"[{now}][API] Data loaded successfully for UUID: {self.uuid}")
        except Exception as e:
            print(f"[{now}][API] Data loading failed: {e}")
            raise

    def get_disease_results(self):
        return self.disease_results

    def show_overlays(self, keys, alphas=None, colormaps=None, return_bytes=False):
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        if self.base_image_display is None:
            print(f"[{now}][API] Visualization failed: No base image")
            return None
        alphas = alphas if alphas else VISUAL_ALPHAS
        colormaps = colormaps if colormaps else VISUAL_COLORMAPS
        fig, ax = plt.subplots(figsize=(10,10))
        composite_image = self.base_image_display.copy()
        ax.imshow(composite_image)
        ax.axis('off')
        label_positions = []

        def is_too_close(x, y):
            for (ex, ey) in label_positions:
                if np.sqrt((x-ex)**2 + (y-ey)**2) < LABEL_MIN_DIST:
                    return True
            return False

        for key, alpha, cmap in zip(keys, alphas, colormaps):
            overlay_tensor = self.disease_data.get(key) or self.anatomical_data.get(key)
            if overlay_tensor is None:
                print(f"[{now}][API] Warning: Key '{key}' not found")
                continue
            overlay_data = overlay_tensor.cpu().numpy()
            if key in self.disease_data:
                heatmap_rgb = overlay_data
                valid_mask = np.any(heatmap_rgb != 0, axis=-1)
            else:
                heatmap_rgba = cmap(overlay_data)
                heatmap_rgb = heatmap_rgba[..., :3]
                valid_mask = (~np.isnan(overlay_data)) & (np.sum(heatmap_rgb, axis=-1) > 0)
            for c in range(3):
                composite_image[...,c][valid_mask] = (1-alpha)*composite_image[...,c][valid_mask] + alpha*heatmap_rgb[...,c][valid_mask]
            if np.any(valid_mask):
                intensity_map = np.mean(overlay_data, axis=2) if overlay_data.ndim==3 else overlay_data
                max_idx = np.nanargmax(intensity_map)
                max_y, max_x = np.unravel_index(max_idx, intensity_map.shape)
                label_x, label_y = max_x, max_y
                attempts=0
                while is_too_close(label_x, label_y):
                    label_x += LABEL_OFFSET_STEP
                    label_y += LABEL_OFFSET_STEP
                    attempts += 1
                    if attempts>10: break
                label_positions.append((label_x,label_y))
                label_text = f"{'Disease' if key in self.disease_data else 'Anatomy'}: {key}"
                txt = ax.text(label_x, label_y, label_text, color='white', fontsize=LABEL_FONTSIZE,
                              weight='bold', ha='center', va='center')
                txt.set_path_effects([path_effects.Stroke(linewidth=2, foreground='black'), path_effects.Normal()])
        ax.imshow(np.clip(composite_image,0,1))
        fig.subplots_adjust(left=0,right=1,top=1,bottom=0)

        if return_bytes:
            from io import BytesIO
            img_io = BytesIO()
            canvas = FigureCanvas(fig)
            canvas.print_png(img_io)
            plt.close(fig)
            img_io.seek(0)
            print(f"[{now}][API] Visualization rendered to bytes")
            return img_io
        else:
            plt.show()
            print(f"[{now}][API] Visualization complete")
            return None


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
INFILE = r"00000022_001.jpg"
OUTFILE = r"xray_output"

# Predefined list of colormaps for overlays in visualization
PREDEFINED_COLORMAPS = [
    cm.jet, cm.viridis, cm.plasma, cm.inferno, cm.magma,
    cm.cividis, cm.cool, cm.hot, cm.spring, cm.summer
]

selected_diseases = ["Lung Opacity", "Pneumonia", "Left Scapula","Right Scapula", "Heart","Right Lung"]

In [3]:
# --- Example Usage (Cell 1: Processing) ---

if __name__ == '__main__':
    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    try:

        processor = XRayProcessor(INFILE, OUTFILE, 'test123')
        processed_uuid, results = processor.process()

        if processed_uuid:
            print(f"\n[{now}] Processing complete. The ID '{processed_uuid}' is ready for visualization.")
            print(f"Copy the ID and run the code in the next cell to visualize the results.")
            print(f"\n[{now}] Overall process status\n----- PASS\n")
        else:
            print(f"[{now}] Overall process status\n----- FAIL: X-ray processing failed.\n")

    except (FileNotFoundError, ValueError, RuntimeError) as e:
        print(f"\n[{now}] An error occurred: {e}\n----- FAIL\n")
    except Exception as e:
        print(f"\n[{now}] An unexpected error occurred: {e}\n----- FAIL\n")


[2025-09-14 03:04:52][API] Output directory ensured: xray_output
[2025-09-14 03:04:52][API] Using device: cpu
[2025-09-14 03:04:54][API] Copied image: xray_output\test123.jpg
[2025-09-14 03:04:54][API] Analyzing image for pathologies: xray_output\test123.jpg
[2025-09-14 03:04:54][API] Pathology analysis complete.
[2025-09-14 03:04:54][API] Saved pathology results: xray_output\test123.json
[2025-09-14 03:04:55][API] Generating disease overlays.
[2025-09-14 03:04:54][API] Processing failed: name 'DIISEASE_ALPHA' is not defined

[2025-09-14 03:04:52] An unexpected error occurred: cannot unpack non-iterable NoneType object
----- FAIL



In [None]:
from IPython.display import display
from PIL import Image
import io

try:
    # Define transparency and colormaps for each overlay key
    alphas = [0.5] * len(selected_diseases)
    colormaps = [PREDEFINED_COLORMAPS[i % len(PREDEFINED_COLORMAPS)] for i in range(len(selected_diseases))]

    # Create visualizer instance
    visualizer = XRayVisualizer(processed_uuid, OUTFILE)
    
    # Generate the overlay image bytes for the selected diseases
    image_io = visualizer.show_overlays(
        keys=selected_diseases,
        alphas=alphas,
        colormaps=colormaps,
        return_bytes=True  # ensure the method returns image bytes
    )

    if image_io:
        # Load image from bytes
        image = Image.open(image_io)
        # Display image inline in Jupyter notebook
        display(image)

except Exception as e:
    print(f"[ERROR] render_overlay failed: {e}")
