In [1]:

from pathlib import Path
from typing import List

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
import tqdm

# show the combined negative control image
from IPython.display import (
  Image as IPImage,
)
from PIL import Image, ImageDraw, ImageFont
from rich import print  # type: ignore # noqa

from readii.image_processing import displayCTSegOverlay, getCroppedImages
from readii.loaders import loadDicomSITK, loadRTSTRUCTSITK
from readii.negative_controls_refactor import NegativeControlManager


In [2]:
# Example usage:
# fig = display_slice_with_slider(your_image)
def create_donut_image(size: int = 64, radius: int = 20, max_thickness: int = 10) -> sitk.Image:
  """Generate a 3D donut-shaped SimpleITK image with varying thickness across slices."""
  grid = np.zeros((size, size, size), dtype=np.uint8)
  center = size // 2
  mid_slice = size // 2

  for z in range(size):
    # Calculate thickness for current slice based on distance from the middle slice
    thickness = max_thickness * (1 - abs(z - mid_slice) / mid_slice)  # Linear scaling
    for x in range(size):
      for y in range(size):
        distance_from_center = np.sqrt((x - center) ** 2 + (y - center) ** 2)
        if radius - thickness <= distance_from_center <= radius + thickness:
          grid[z, x, y] = 255  # Donut drawn along Z-axis (top-down view)

  return sitk.GetImageFromArray(grid)

def create_square_mask(image: sitk.Image, side_length: int = 20) -> sitk.Image:
  """
  Create a square mask for a given SimpleITK image.

  The square is placed in the top-left quadrant of the image, covering part of the donut.

  Parameters
  ----------
  image : sitk.Image
      The input image for which the mask is created.
  side_length : int, default 20
      The length of one side of the square mask in pixels.

  Returns
  -------
  sitk.Image
      A binary mask image with the square region set to 1.
  """
  # Get the dimensions of the input image
  size = sitk.GetArrayFromImage(image).shape  # (z, y, x)

  # Define the position of the square in the top-left quadrant
  # Offset from the center to ensure it covers part of the donut
  start_x = size[2] // 4 - side_length // 2
  start_y = size[1] // 4 - side_length // 2

  # Ensure the square is within bounds
  start_x = max(0, start_x)
  start_y = max(0, start_y)
  end_x = min(size[2], start_x + side_length)
  end_y = min(size[1], start_y + side_length)

  # Create the mask array
  mask = np.zeros(size, dtype=np.uint8)
  mask[:, start_y:end_y, start_x:end_x] = 1  # Apply the square ROI to all slices

  # Convert the mask to a SimpleITK image
  return sitk.GetImageFromArray(mask)

donut = create_donut_image()

donut_array = sitk.GetArrayFromImage(donut)
mask = create_square_mask(donut)
mask_array = sitk.GetArrayFromImage(mask)

def view_slices(z):
  # 1 row, 3 columns
  # first column: donut image
  # second column: mask image
  # third column: donut image with mask overlaid
  plt.figure(figsize=(16, 8))
  plt.subplot(1, 3, 1)
  plt.imshow(donut_array[z], cmap='gray')
  plt.axis('off')
  plt.title('Donut')

  plt.subplot(1, 3, 2)
  plt.imshow(mask_array[z], cmap='gray')
  plt.axis('off')
  plt.title('Mask')

  plt.subplot(1, 3, 3)
  plt.imshow(donut_array[z], cmap='gray')
  plt.imshow(mask_array[z], cmap='gray', alpha=0.5)
  plt.axis('off')
  plt.title('Donut with Mask')


slider = widgets.IntSlider(min=0, max=donut_array.shape[0]-1, step=1, value=donut_array.shape[0]//2)
fig = plt.figure(figsize=(16, 8))
widgets.interactive(view_slices, z=slider)


<Figure size 1600x800 with 0 Axes>

interactive(children=(IntSlider(value=32, description='z', max=63), Output()), _dom_classes=('widget-interact'…

In [3]:

class VisualizeNegativeControl:
    """
    Class for visualizing and saving negative control strategies applied to images.

    Attributes
    ----------
    results_dir : Path
        Directory to save all generated GIFs.
    """

    def __init__(self, results_dir: Path) -> None:
        """
        Initialize the VisualizeNegativeControl class.

        Parameters
        ----------
        results_dir : Path
            Directory to save the results.
        """
        self.results_dir = results_dir
        self.results_dir.mkdir(exist_ok=True)
        
    def save_gif_from_slices(self, ct_image: sitk.Image, seg_image: sitk.Image, output_path: Path, control_name: str = None, region_name: str = None) -> None:
        """
        Save a 3D SimpleITK image as an animated GIF with optional text overlay.

        Parameters
        ----------
        ct_image : sitk.Image
                The 3D SimpleITK image to save as a GIF.
        seg_image : sitk.Image
                The segmentation image to overlay on the CT image.
        output_path : Path
                The file path to save the GIF.
        control_name : str, optional
                The name of the control strategy (to overlay on the image).
        region_name : str, optional
                The name of the region strategy (to overlay on the image).

        Returns
        -------
        None
        """
        array = sitk.GetArrayFromImage(ct_image)
        seg_array = sitk.GetArrayFromImage(seg_image)

        slices = []
        for i in range(array.shape[0]):
            fig, ax = plt.subplots()
            displayCTSegOverlay(
                ctImage=array,
                segImage=seg_array,
                sliceIdx=i,
                cmapCT=plt.cm.Greys_r,
                cmapSeg=plt.cm.brg,
                alpha=0.3,
                crop=False,
            )
            ax.set_title(f"{control_name or ''} | {region_name or ''}".strip(" | "))
            fig.canvas.draw()

            # Convert plot to image
            img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
            slices.append(Image.fromarray(img))
            plt.close(fig)

        # Save slices as GIF
        slices[0].save(
            output_path,
            save_all=True,
            append_images=slices[1:],
            duration=100,
            loop=0,
        )

    def create_large_gif(self, output_path: Path, control_files: List[Path], original_file: Path) -> Path:
        """
        Create a large GIF combining multiple smaller GIFs into a grid.

        Parameters
        ----------
        output_path : Path
            Path to save the resulting large GIF.
        control_files : list of Path
            List of file paths for the control GIFs to arrange in a 3x3 grid.
        original_file : Path
            Path to the original GIF to be placed to the right of the grid.

        Returns
        -------
        None
        """
        control_frames = [Image.open(f) for f in control_files]
        original_frames = Image.open(original_file)

        # Get dimensions of a single frame
        frame_width, frame_height = control_frames[0].size
        num_frames = min(
            [gif.n_frames for gif in control_frames] + [original_frames.n_frames]
        )

        # Define padding
        padding = 10

        # Dimensions for the large GIF
        grid_width = (frame_width * 3) + (padding * 2)
        grid_height = (frame_height * 3) + (padding * 2)
        total_width = grid_width + frame_width + padding

        combined_frames = []
        for i in range(num_frames):
            combined_frame = Image.new("RGB", (total_width, grid_height), "black")
            for j, control_frame in enumerate(control_frames):
                control_frame.seek(i)
                x_offset = (j % 3) * frame_width + ((j % 3) * padding)
                y_offset = (j // 3) * frame_height + ((j // 3) * padding)
                combined_frame.paste(control_frame, (x_offset, y_offset))

            original_frames.seek(i)
            combined_frame.paste(original_frames, (grid_width + padding, 0))

            # Add frame counter
            draw = ImageDraw.Draw(combined_frame)
            text = f"frame: {i + 1}/{num_frames}"
            font = ImageFont.load_default()
            text_width = draw.textlength(text, font=font)
            text_height = font.size
            text_position = (total_width - text_width - 10, grid_height - text_height - 10)
            draw.text(text_position, text, fill="white", font=font)

            combined_frames.append(combined_frame)

        combined_frames[0].save(
            output_path,
            save_all=True,
            append_images=combined_frames[1:],
            duration=100,
            loop=0,
        )
        print(f"Large GIF saved at {output_path}")
        return output_path


In [None]:

# Example usage
result_dir = Path("TRASH", "results")
result_dir.mkdir(exist_ok=True, parents=True)

visualizer = VisualizeNegativeControl(results_dir=result_dir)
control_paths = []

manager = NegativeControlManager.from_strings(
    negative_control_types=["sampled", "randomized", "shuffled"],
    region_types=["roi", "non_roi", "full"],
    random_seed=42,
)

donut_image = create_donut_image(size=64, radius=20, max_thickness=10)
square_mask = create_square_mask(donut_image, side_length=10)

# Assuming donut_image and square_mask are already created and available
visualizer.save_gif_from_slices(donut_image, square_mask, visualizer.results_dir / "original.gif")

# Assuming manager.apply returns the necessary images
for image, control_name, region_name in manager.apply(donut_image, square_mask):
    output_path = Path("TRASH", "results", f"{control_name}_{region_name}.gif")
    visualizer.save_gif_from_slices(image, square_mask, output_path, control_name=control_name, region_name=region_name)
    control_paths.append(output_path)

# Create a large GIF combining all control strategies
output_path = visualizer.create_large_gif(
    output_path=Path("TRASH", "results", "combined.gif"),
    control_files=control_paths,
    original_file=Path("TRASH", "results", "original.gif"),
)

# show the combined negative control image
from IPython.display import Image as IPImage

# Display the combined GIF
IPImage(filename=output_path)


  img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
