In [None]:
import cv2
import pytesseract
import numpy as np
import os

from detextify.text_detector import TesseractTextDetector
from detextify.inpainter import LocalSDInpainter
from detextify.detextifier import Detextifier

# Remove burned-in text from one PNG: Tesseract finds the text boxes and a local
# Stable Diffusion inpainter fills them in.
# NOTE(review): "/usr/include/tesseract" is usually a C header directory rather than
# the tesseract executable — confirm what TesseractTextDetector expects here.
png_dir = "/home/nicholasjprimiano/ML/img2img-turbo/data/CTH_CBF/test_A_black"
source_png = os.path.join(png_dir, "REDMOND_JONATHAN_000002.png")
cleaned_png = os.path.join(png_dir, "REDMOND_JONATHAN_000002_notext.png")

text_detector = TesseractTextDetector("/usr/include/tesseract")
detextifier = Detextifier(text_detector, LocalSDInpainter())
detextifier.detextify(source_png, cleaned_png)

In [None]:
!pip install nibabel

In [None]:
import nibabel as nib
import numpy as np

# Load the .nii file
nii_path = '/mnt/d/CTH_archive/CBF_COLORED_NIFTI_RESIZED_Processed/BATTLE_MARIA.nii'

# Load the NIfTI file
nii = nib.load(nii_path)
data = nii.get_fdata()

# Expected layout is (X, Y, slices, 1, RGB), e.g. (512, 512, 49, 1, 3).
assert data.ndim == 5 and data.shape[3:] == (1, 3), f"Unexpected data shape: {data.shape}"

# Bounding box of the burned-in "rCBF" label as (row, col); end indices are exclusive.
top_left = (206, 43)
bottom_right = (267, 67)

# "White" depends on how the volume is scaled: 255 for [0, 255] data, 1.0 for [0, 1]
# data (the same max > 1 test the inpainting cells use). A hard-coded 255 would push a
# [0, 1] volume far out of range.
white = 255.0 if data.max() > 1 else 1.0

# White out the box on every slice at once: (rows, cols, all slices, volume 0, RGB).
data[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1], :, 0, :] = white

# Create a new NIfTI image from the modified data
new_nii = nib.Nifti1Image(data, nii.affine)

# Save the new .nii file
output_nii_path = '/mnt/d/CTH_archive/CBF_COLORED_NIFTI_RESIZED_Processed/BATTLE_MARIA_white_box_final.nii'
nib.save(new_nii, output_nii_path)


In [None]:
import nibabel as nib
import numpy as np
import cv2

# Load the .nii file
nii_path = '/mnt/d/CTH_archive/CBF_COLORED_NIFTI_RESIZED_Processed/BATTLE_MARIA.nii'

# Load the NIfTI file
nii = nib.load(nii_path)
data = nii.get_fdata()
assert data.shape[3:] == (1, 3), f"Unexpected data shape: {data.shape}"

# Updated coordinates for the region of interest as (row, col); end indices are exclusive.
top_left = (205, 43)
bottom_right = (267, 67)
roi_rows = slice(top_left[0], bottom_right[0])
roi_cols = slice(top_left[1], bottom_right[1])

# The volume may be stored in [0, 1] or [0, 255]. Decide once, so the uint8 image fed to
# Canny and the red marker both use the data's own scale. Assuming [0, 1] (roi * 255)
# saturates almost every non-zero pixel of a [0, 255] volume to 255.
scale = 255.0 if data.max() > 1 else 1.0
red = np.array([scale, 0.0, 0.0])

# Loop through each slice in the 3D volume
for i in range(data.shape[2]):
    # Basic slicing returns a view, so writes through `roi` land directly in `data`.
    roi = data[roi_rows, roi_cols, i, 0, :]

    # Grayscale uint8 version of the ROI for edge detection (rounded, not truncated).
    roi_uint8 = np.clip(np.rint(roi / scale * 255), 0, 255).astype(np.uint8)
    roi_gray = cv2.cvtColor(roi_uint8, cv2.COLOR_RGB2GRAY)

    # Apply edge detection to find the text
    edges = cv2.Canny(roi_gray, 100, 200)

    # Highlight the detected edge pixels in red
    roi[edges > 0] = red

# Create a new NIfTI image from the modified data
new_nii = nib.Nifti1Image(data, nii.affine)

# Save the new .nii file
output_nii_path = '/mnt/d/CTH_archive/CBF_COLORED_NIFTI_RESIZED_Processed/BATTLE_MARIA_text_detection.nii'
nib.save(new_nii, output_nii_path)


In [None]:
import nibabel as nib
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os

def create_rectangle_mask(image_shape, top_left, bottom_right):
    """Build a single-channel uint8 mask containing a filled rectangle of 255s.

    Only ``image_shape[:2]`` is used, so an (H, W, C) image shape works too.
    The corners are OpenCV (x, y) points; any two opposite corners give the same box.
    """
    canvas = np.zeros(image_shape[:2], dtype=np.uint8)
    cv2.rectangle(canvas, top_left, bottom_right, color=255, thickness=cv2.FILLED)
    return canvas

def verify_text_detection(nii_path, output_dir, num_samples=3,
                          top_left=(67, 205), bottom_right=(43, 267)):
    """Preview the text-removal rectangle on a few evenly spaced slices.

    The rectangle is painted solid red on ``num_samples`` slices. The slices are shown
    side by side and each is saved as ``slice_<idx>_verified.png`` in ``output_dir``.

    Args:
        nii_path: NIfTI file whose data has shape (X, Y, Z, 1, 3).
        output_dir: existing directory that receives the preview PNGs.
        num_samples: number of slices to preview (must be >= 1).
        top_left, bottom_right: two opposite rectangle corners as OpenCV (x, y) points.
            The defaults are really the top-right and bottom-left corners;
            cv2.rectangle fills the same box either way.

    Returns:
        The ``(top_left, bottom_right)`` corners used, ready for ``remove_text_and_inpaint``.

    Raises:
        ValueError: if ``num_samples < 1``.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    nii_img = nib.load(nii_path)
    data = nii_img.get_fdata()
    assert data.shape[3:] == (1, 3), "Unexpected data shape"

    total_slices = data.shape[2]
    sample_indices = np.linspace(0, total_slices-1, num_samples, dtype=int)

    # Same max > 1 scale test as the inpainting step, decided once for the whole volume.
    scale = 255.0 if data.max() > 1 else 1.0
    # The box is identical on every slice, so the mask is built once.
    mask = create_rectangle_mask(data.shape[:2], top_left, bottom_right)

    # squeeze=False keeps axes 2-D even when num_samples == 1.
    fig, axes = plt.subplots(1, num_samples, figsize=(5*num_samples, 5), squeeze=False)

    for ax, slice_idx in zip(axes[0], sample_indices):
        # Clip so imsave never receives floats outside [0, 1]; np.clip returns a copy.
        slice_draw = np.clip(data[:, :, slice_idx, 0, :] / scale, 0, 1)
        slice_draw[mask > 0] = [1, 0, 0]  # Red color for the box

        ax.imshow(slice_draw)
        ax.set_title(f'Slice {slice_idx}')
        ax.axis('off')

        plt.imsave(os.path.join(output_dir, f'slice_{slice_idx}_verified.png'), slice_draw)

    plt.tight_layout()
    plt.show()

    return top_left, bottom_right

def remove_text_and_inpaint(nii_path, output_dir, top_left, bottom_right):
    """Inpaint a rectangular text region on every slice of an RGB NIfTI volume.

    Writes ``slice_<i>_processed.png`` for every slice and ``processed.nii`` into
    ``output_dir``. The saved volume is float32 in [0, 1].

    Args:
        nii_path: NIfTI file whose data has shape (X, Y, Z, 1, 3).
        output_dir: existing directory that receives the outputs.
        top_left, bottom_right: two opposite rectangle corners as OpenCV (x, y) points.
    """
    nii_img = nib.load(nii_path)
    data = nii_img.get_fdata()
    assert data.shape[3:] == (1, 3), "Unexpected data shape"

    # Create mask for the text region
    mask = create_rectangle_mask(data.shape[:2], top_left, bottom_right)

    # Decide the intensity scale once per volume. A per-slice test mis-scales dark
    # slices of a [0, 255] volume whose own max happens to be <= 1.
    scale = 255.0 if data.max() > 1 else 1.0

    for i in range(data.shape[2]):
        slice_rgb = data[:, :, i, 0, :] / scale

        # Round rather than truncate: astype() truncates, which biases every value
        # down and can turn a k that computes as k - epsilon into k - 1.
        slice_uint8 = np.clip(np.rint(slice_rgb * 255), 0, 255).astype(np.uint8)

        # Inpaint the masked region
        result_uint8 = cv2.inpaint(slice_uint8, mask, 3, cv2.INPAINT_TELEA)

        # Convert back to float in [0, 1] and store it in the volume
        result = result_uint8.astype(float) / 255.0
        data[:, :, i, 0, :] = result

        plt.imsave(os.path.join(output_dir, f'slice_{i}_processed.png'), result)

    # Store as float32 so an integer dtype inherited from the source header does not
    # requantise the [0, 1] data on save.
    new_img = nib.Nifti1Image(data.astype(np.float32), nii_img.affine, nii_img.header)
    new_img.set_data_dtype(np.float32)
    nib.save(new_img, os.path.join(output_dir, 'processed.nii'))

# Usage
nii_path = '/mnt/d/CTH_archive/CBF_COLORED_NIFTI_RESIZED_Processed/BATTLE_MARIA.nii'
output_dir = '/mnt/d/CTH_archive/TEXT_REMOVAL'

os.makedirs(output_dir, exist_ok=True)

# Verify text detection: preview the box and get back the corners to use.
top_left, bottom_right = verify_text_detection(nii_path, output_dir)

# Inpainting runs right after the preview. Set this to False to stop after checking
# the preview images.
RUN_INPAINTING = True
if RUN_INPAINTING:
    remove_text_and_inpaint(nii_path, output_dir, top_left, bottom_right)

In [None]:
import nibabel as nib
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os

def create_rectangle_mask(image_shape, top_left, bottom_right):
    """Build a single-channel uint8 mask containing a filled rectangle of 255s.

    Only ``image_shape[:2]`` is used, so an (H, W, C) image shape works too.
    The corners are OpenCV (x, y) points; any two opposite corners give the same box.
    """
    canvas = np.zeros(image_shape[:2], dtype=np.uint8)
    cv2.rectangle(canvas, top_left, bottom_right, color=255, thickness=cv2.FILLED)
    return canvas

def verify_text_detection(nii_path, output_dir, num_samples=3,
                          top_left=(67, 205), bottom_right=(43, 267)):
    """Preview the text-removal rectangle on a few evenly spaced slices.

    The rectangle is painted solid red on ``num_samples`` slices. The slices are shown
    side by side and each is saved as ``slice_<idx>_verified.png`` in ``output_dir``.

    Args:
        nii_path: NIfTI file whose data has shape (X, Y, Z, 1, 3).
        output_dir: existing directory that receives the preview PNGs.
        num_samples: number of slices to preview (must be >= 1).
        top_left, bottom_right: two opposite rectangle corners as OpenCV (x, y) points.
            The defaults are really the top-right and bottom-left corners;
            cv2.rectangle fills the same box either way.

    Returns:
        The ``(top_left, bottom_right)`` corners used, ready for ``remove_text_and_inpaint``.

    Raises:
        ValueError: if ``num_samples < 1``.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    nii_img = nib.load(nii_path)
    data = nii_img.get_fdata()
    assert data.shape[3:] == (1, 3), "Unexpected data shape"

    total_slices = data.shape[2]
    sample_indices = np.linspace(0, total_slices-1, num_samples, dtype=int)

    # Same max > 1 scale test as the inpainting step, decided once for the whole volume.
    scale = 255.0 if data.max() > 1 else 1.0
    # The box is identical on every slice, so the mask is built once.
    mask = create_rectangle_mask(data.shape[:2], top_left, bottom_right)

    # squeeze=False keeps axes 2-D even when num_samples == 1.
    fig, axes = plt.subplots(1, num_samples, figsize=(5*num_samples, 5), squeeze=False)

    for ax, slice_idx in zip(axes[0], sample_indices):
        # Clip so imsave never receives floats outside [0, 1]; np.clip returns a copy.
        slice_draw = np.clip(data[:, :, slice_idx, 0, :] / scale, 0, 1)
        slice_draw[mask > 0] = [1, 0, 0]  # Red color for the box

        ax.imshow(slice_draw)
        ax.set_title(f'Slice {slice_idx}')
        ax.axis('off')

        plt.imsave(os.path.join(output_dir, f'slice_{slice_idx}_verified.png'), slice_draw)

    plt.tight_layout()
    plt.show()

    return top_left, bottom_right

def remove_text_and_inpaint(nii_path, output_dir, top_left, bottom_right):
    """Inpaint a rectangular text region on every slice of an RGB NIfTI volume.

    Writes ``slice_<i>_processed.png`` for every slice and ``processed.nii.gz`` into
    ``output_dir``. The saved volume is float32 in [0, 1], with the input's affine and
    header.

    Args:
        nii_path: NIfTI file whose data has shape (X, Y, Z, 1, 3).
        output_dir: existing directory that receives the outputs.
        top_left, bottom_right: two opposite rectangle corners as OpenCV (x, y) points.
    """
    nii_img = nib.load(nii_path)
    data = nii_img.get_fdata()
    assert data.shape[3:] == (1, 3), "Unexpected data shape"

    # Create mask for the text region
    mask = create_rectangle_mask(data.shape[:2], top_left, bottom_right)

    # Decide the intensity scale once per volume. A per-slice test mis-scales dark
    # slices of a [0, 255] volume whose own max happens to be <= 1.
    scale = 255.0 if data.max() > 1 else 1.0

    for i in range(data.shape[2]):
        slice_rgb = data[:, :, i, 0, :] / scale

        # Round rather than truncate: astype() truncates, which biases every value
        # down and can turn a k that computes as k - epsilon into k - 1.
        slice_uint8 = np.clip(np.rint(slice_rgb * 255), 0, 255).astype(np.uint8)

        # Inpaint the masked region
        result_uint8 = cv2.inpaint(slice_uint8, mask, 3, cv2.INPAINT_TELEA)

        # Convert back to float in [0, 1] and store it in the volume
        result = result_uint8.astype(float) / 255.0
        data[:, :, i, 0, :] = result

        plt.imsave(os.path.join(output_dir, f'slice_{i}_processed.png'), result)

    # Ensure data is in the correct format and range for NIfTI
    data = np.clip(data, 0, 1).astype(np.float32)

    # Reuse the original header, but declare float32 so the [0, 1] data is not
    # requantised to an integer dtype inherited from the source.
    new_img = nib.Nifti1Image(data, nii_img.affine, header=nii_img.header)
    new_img.set_data_dtype(np.float32)

    # Save the processed NIfTI
    nib.save(new_img, os.path.join(output_dir, 'processed.nii.gz'))

# Usage
nii_path = '/mnt/d/CTH_archive/CBF_COLORED_NIFTI_RESIZED_Processed/BATTLE_MARIA.nii'
output_dir = '/mnt/d/CTH_archive/TEXT_REMOVAL'

os.makedirs(output_dir, exist_ok=True)

# Verify text detection: preview the box and get back the corners to use.
top_left, bottom_right = verify_text_detection(nii_path, output_dir)

# Inpainting runs right after the preview. Set this to False to stop after checking
# the preview images.
RUN_INPAINTING = True
if RUN_INPAINTING:
    remove_text_and_inpaint(nii_path, output_dir, top_left, bottom_right)

In [None]:
import nibabel as nib
import numpy as np
import cv2
import os
from tqdm import tqdm

def create_rectangle_mask(image_shape, top_left, bottom_right):
    """Build a single-channel uint8 mask containing a filled rectangle of 255s.

    Only ``image_shape[:2]`` is used, so an (H, W, C) image shape works too.
    The corners are OpenCV (x, y) points; any two opposite corners give the same box.
    """
    canvas = np.zeros(image_shape[:2], dtype=np.uint8)
    cv2.rectangle(canvas, top_left, bottom_right, color=255, thickness=cv2.FILLED)
    return canvas

def remove_text_and_inpaint(nii_path, output_path, top_left=(67, 205), bottom_right=(43, 267)):
    """Inpaint the burned-in text box on every slice of an RGB NIfTI volume and save it.

    Args:
        nii_path: input NIfTI whose data has shape (X, Y, Z, 1, 3).
        output_path: file to write; its parent directory must already exist.
        top_left, bottom_right: two opposite corners of the text box as OpenCV (x, y)
            points. The defaults are really the top-right and bottom-left corners;
            cv2.rectangle fills the same box either way.

    The saved volume is float32 in [0, 1], with the input's affine and header.
    """
    nii_img = nib.load(nii_path)
    data = nii_img.get_fdata()
    assert data.shape[3:] == (1, 3), "Unexpected data shape"

    # Create mask for the text region
    mask = create_rectangle_mask(data.shape[:2], top_left, bottom_right)

    # Decide the intensity scale once per volume. A per-slice test mis-scales dark
    # slices of a [0, 255] volume whose own max happens to be <= 1.
    scale = 255.0 if data.max() > 1 else 1.0

    for i in range(data.shape[2]):
        slice_rgb = data[:, :, i, 0, :] / scale

        # Round rather than truncate: astype() truncates, which biases every value
        # down and can turn a k that computes as k - epsilon into k - 1.
        slice_uint8 = np.clip(np.rint(slice_rgb * 255), 0, 255).astype(np.uint8)

        # Inpaint the masked region, then store it back as float in [0, 1]
        result_uint8 = cv2.inpaint(slice_uint8, mask, 3, cv2.INPAINT_TELEA)
        data[:, :, i, 0, :] = result_uint8.astype(float) / 255.0

    # Ensure data is in the correct format and range for NIfTI
    data = np.clip(data, 0, 1).astype(np.float32)

    # Reuse the original header, but declare float32 so the data is not requantised
    new_img = nib.Nifti1Image(data, nii_img.affine, header=nii_img.header)
    new_img.set_data_dtype(np.float32)

    # Save the processed NIfTI
    nib.save(new_img, output_path)

def process_all_nifti_files(input_dir, output_dir):
    """Run ``remove_text_and_inpaint`` on every ``.nii`` file directly inside input_dir.

    Each result keeps its file name and is written to output_dir, which is created if
    needed. An exception on any file propagates and stops the batch.
    """
    os.makedirs(output_dir, exist_ok=True)

    nii_files = list(filter(lambda name: name.endswith('.nii'), os.listdir(input_dir)))

    for nii_file in tqdm(nii_files, desc="Processing NIfTI files"):
        source = os.path.join(input_dir, nii_file)
        target = os.path.join(output_dir, nii_file)
        remove_text_and_inpaint(source, target)

# Usage: strip the text box from every volume in the processed CBF folder.
archive_root = '/mnt/d/CTH_archive'
input_dir = os.path.join(archive_root, 'CBF_COLORED_NIFTI_RESIZED_Processed')
# NOTE: "NIFIT" (sic) is the spelling the registration viewers read back
# ('..._NOTEXT_REG'); keep the two in sync rather than fixing only one.
output_dir = os.path.join(archive_root, 'CBF_COLORED_NIFIT_NOTEXT')

process_all_nifti_files(input_dir, output_dir)

In [None]:
import nibabel as nib
import numpy as np
import cv2
import os
from tqdm import tqdm

def create_rectangle_mask(image_shape, top_left, bottom_right):
    """Build a single-channel uint8 mask containing a filled rectangle of 255s.

    Only ``image_shape[:2]`` is used, so an (H, W, C) image shape works too.
    The corners are OpenCV (x, y) points; any two opposite corners give the same box.
    """
    canvas = np.zeros(image_shape[:2], dtype=np.uint8)
    cv2.rectangle(canvas, top_left, bottom_right, color=255, thickness=cv2.FILLED)
    return canvas

def remove_text_and_inpaint(nii_path, output_path, top_left=(67, 205), bottom_right=(43, 267)):
    """Inpaint the burned-in text box on every slice of an RGB NIfTI volume and save it.

    Args:
        nii_path: input NIfTI whose data has shape (X, Y, Z, 1, 3).
        output_path: file to write; its parent directory is created if missing.
        top_left, bottom_right: two opposite corners of the text box as OpenCV (x, y)
            points. The defaults are really the top-right and bottom-left corners;
            cv2.rectangle fills the same box either way.

    The saved volume is float32 in [0, 1], with the input's affine and header.
    """
    nii_img = nib.load(nii_path)
    data = nii_img.get_fdata()
    assert data.shape[3:] == (1, 3), "Unexpected data shape"

    # Create mask for the text region
    mask = create_rectangle_mask(data.shape[:2], top_left, bottom_right)

    # Decide the intensity scale once per volume. A per-slice test mis-scales dark
    # slices of a [0, 255] volume whose own max happens to be <= 1.
    scale = 255.0 if data.max() > 1 else 1.0

    for i in range(data.shape[2]):
        slice_rgb = data[:, :, i, 0, :] / scale

        # Round rather than truncate: astype() truncates, which biases every value
        # down and can turn a k that computes as k - epsilon into k - 1.
        slice_uint8 = np.clip(np.rint(slice_rgb * 255), 0, 255).astype(np.uint8)

        # Inpaint the masked region, then store it back as float in [0, 1]
        result_uint8 = cv2.inpaint(slice_uint8, mask, 3, cv2.INPAINT_TELEA)
        data[:, :, i, 0, :] = result_uint8.astype(float) / 255.0

    # Ensure data is in the correct format and range for NIfTI
    data = np.clip(data, 0, 1).astype(np.float32)

    # Reuse the original header, but declare float32 so the data is not requantised
    new_img = nib.Nifti1Image(data, nii_img.affine, header=nii_img.header)
    new_img.set_data_dtype(np.float32)

    # Ensure the output directory exists. dirname() is '' for a bare file name, and
    # os.makedirs('') would raise.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    # Save the processed NIfTI
    nib.save(new_img, output_path)

def process_all_nifti_files(input_dir, output_dir):
    """Run ``remove_text_and_inpaint`` on every ``.nii`` file directly inside input_dir.

    Results keep their file names in output_dir, which is created if needed. A file
    that raises is reported with print() and skipped, so the batch keeps going.
    """
    os.makedirs(output_dir, exist_ok=True)

    nii_files = [name for name in os.listdir(input_dir) if name.endswith('.nii')]

    for nii_file in tqdm(nii_files, desc="Processing NIfTI files"):
        try:
            remove_text_and_inpaint(
                os.path.join(input_dir, nii_file),
                os.path.join(output_dir, nii_file),
            )
        except Exception as exc:
            print(f"Error processing {nii_file}: {str(exc)}")

# Usage: strip the text box from every volume in the processed CBF folder.
archive_root = '/mnt/d/CTH_archive'
input_dir = os.path.join(archive_root, 'CBF_COLORED_NIFTI_RESIZED_Processed')
# NOTE: "NIFIT" (sic) is the spelling the registration viewers read back
# ('..._NOTEXT_REG'); keep the two in sync rather than fixing only one.
output_dir = os.path.join(archive_root, 'CBF_COLORED_NIFIT_NOTEXT')

process_all_nifti_files(input_dir, output_dir)

In [None]:
import nibabel as nib
import numpy as np
import cv2
import os
import shutil
from tqdm import tqdm
import logging
import re

# Set up logging: INFO and above, each record prefixed with a timestamp and the level.
# NOTE(review): per the logging docs, basicConfig does nothing if the root logger
# already has handlers.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

def create_rectangle_mask(image_shape, top_left, bottom_right):
    """Build a single-channel uint8 mask containing a filled rectangle of 255s.

    Only ``image_shape[:2]`` is used, so an (H, W, C) image shape works too.
    The corners are OpenCV (x, y) points; any two opposite corners give the same box.
    """
    canvas = np.zeros(image_shape[:2], dtype=np.uint8)
    cv2.rectangle(canvas, top_left, bottom_right, color=255, thickness=cv2.FILLED)
    return canvas

def sanitize_filename(filename):
    """Return ``filename`` with every character outside ``[A-Za-z0-9_.-]`` replaced by '_'.

    Only ASCII letters and digits survive; accented or other non-ASCII characters are
    replaced too.
    """
    return "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "_.-" else "_"
        for ch in filename
    )

def remove_text_and_inpaint(nii_path, output_path, top_left=(67, 205), bottom_right=(43, 267)):
    """Inpaint the burned-in text box on every slice of an RGB NIfTI volume and save it.

    If saving fails, the function retries under a sanitised file name. If that also
    fails, it copies the untouched input to ``output_path``. That copy STILL CONTAINS
    THE TEXT, so it is logged as a warning. If the copy fails as well, its error is
    re-raised.

    Args:
        nii_path: input NIfTI whose data has shape (X, Y, Z, 1, 3).
        output_path: file to write; its parent directory is created if missing.
        top_left, bottom_right: two opposite corners of the text box as OpenCV (x, y)
            points. The defaults are really the top-right and bottom-left corners;
            cv2.rectangle fills the same box either way.

    Raises:
        FileNotFoundError: if ``nii_path`` does not exist.
    """
    logging.info(f"Processing file: {os.path.basename(nii_path)}")
    logging.info(f"Output path: {output_path}")

    if not os.path.exists(nii_path):
        raise FileNotFoundError(f"Input file not found: {nii_path}")

    nii_img = nib.load(nii_path)
    data = nii_img.get_fdata()
    assert data.shape[3:] == (1, 3), f"Unexpected data shape: {data.shape}"

    mask = create_rectangle_mask(data.shape[:2], top_left, bottom_right)

    # Decide the intensity scale once per volume. A per-slice test mis-scales dark
    # slices of a [0, 255] volume whose own max happens to be <= 1.
    scale = 255.0 if data.max() > 1 else 1.0

    for i in range(data.shape[2]):
        slice_rgb = data[:, :, i, 0, :] / scale
        # Round rather than truncate: astype() truncates, which biases every value down.
        slice_uint8 = np.clip(np.rint(slice_rgb * 255), 0, 255).astype(np.uint8)
        result_uint8 = cv2.inpaint(slice_uint8, mask, 3, cv2.INPAINT_TELEA)
        data[:, :, i, 0, :] = result_uint8.astype(float) / 255.0

    data = np.clip(data, 0, 1).astype(np.float32)
    new_img = nib.Nifti1Image(data, nii_img.affine, header=nii_img.header)
    new_img.set_data_dtype(np.float32)

    # dirname() is '' for a bare file name, and os.makedirs('') would raise.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    logging.info(f"Saving processed file to: {output_path}")
    try:
        nib.save(new_img, output_path)
        logging.info(f"File saved successfully: {output_path}")
    except Exception as e:
        logging.error(f"Error saving file: {str(e)}")
        alternative_output_path = os.path.join(os.path.dirname(output_path), sanitize_filename(os.path.basename(output_path)))
        logging.info(f"Attempting to save with sanitized filename: {alternative_output_path}")
        try:
            nib.save(new_img, alternative_output_path)
            logging.info(f"File saved successfully with sanitized filename: {alternative_output_path}")
        except Exception as e2:
            logging.error(f"Error saving file with sanitized filename: {str(e2)}")
            logging.info("Attempting to copy the original file")
            try:
                shutil.copy2(nii_path, output_path)
                # The copy is the UNPROCESSED input; make that impossible to miss downstream.
                logging.warning(f"Copied original file (text NOT removed) to: {output_path}")
            except Exception as e3:
                logging.error(f"Error copying original file: {str(e3)}")
                raise

def _log_path_diagnostics(input_path, output_path):
    """Log filesystem facts that usually explain a failed read or write."""
    out_dir = os.path.dirname(output_path)
    logging.error(f"Input path: {input_path}")
    logging.error(f"Output path: {output_path}")
    logging.error(f"Input path exists: {os.path.exists(input_path)}")
    logging.error(f"Output directory exists: {os.path.exists(out_dir)}")
    logging.error(f"Can write to output directory: {os.access(out_dir, os.W_OK)}")

    # os.stat / os.listdir can fail themselves (that may be the very problem), so each
    # is guarded; a diagnostic must never abort the remaining batch.
    try:
        logging.error(f"File permissions of input file: {oct(os.stat(input_path).st_mode)[-3:]}")
    except OSError as stat_error:
        logging.error(f"Cannot stat input file: {stat_error}")
    try:
        logging.error(f"File permissions of output directory: {oct(os.stat(out_dir).st_mode)[-3:]}")
    except OSError as stat_error:
        logging.error(f"Cannot stat output directory: {stat_error}")
    try:
        logging.info(f"Contents of output directory: {os.listdir(out_dir)}")
    except Exception as dir_error:
        logging.error(f"Error listing output directory contents: {str(dir_error)}")


def process_all_nifti_files(input_dir, output_dir):
    """Remove the text box from every ``.nii`` file directly inside ``input_dir``.

    Results keep their file names in ``output_dir``, which is created if needed. A
    file that fails is logged together with path/permission diagnostics and skipped,
    so the rest of the batch still runs. Diagnostics used to be emitted for one
    hard-coded patient only.
    """
    os.makedirs(output_dir, exist_ok=True)

    nii_files = [f for f in os.listdir(input_dir) if f.endswith('.nii')]

    for nii_file in tqdm(nii_files, desc="Processing NIfTI files"):
        input_path = os.path.join(input_dir, nii_file)
        output_path = os.path.join(output_dir, nii_file)
        try:
            remove_text_and_inpaint(input_path, output_path)
        except Exception as e:
            logging.error(f"Error processing {nii_file}: {str(e)}")
            logging.error(f"Detailed error for {nii_file}:")
            _log_path_diagnostics(input_path, output_path)

# Usage: strip the text box from every volume in the processed CBF folder.
archive_root = '/mnt/d/CTH_archive'
input_dir = os.path.join(archive_root, 'CBF_COLORED_NIFTI_RESIZED_Processed')
# NOTE: "NIFIT" (sic) is the spelling the registration viewers read back
# ('..._NOTEXT_REG'); keep the two in sync rather than fixing only one.
output_dir = os.path.join(archive_root, 'CBF_COLORED_NIFIT_NOTEXT')

process_all_nifti_files(input_dir, output_dir)

In [None]:
import os
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, fixed, widgets, Layout
from IPython.display import display, clear_output
from collections import defaultdict

def get_patient_pairs(cth_dir, cbf_dir, start=0, batch_size=6):
    """Pair CTH and CBF ``.nii`` files by patient and return one batch of patients.

    A patient name is a file name up to its first '.'. Within each patient, the sorted
    CTH files are paired positionally with the sorted CBF files. Patients present in
    only one directory are skipped.

    Args:
        cth_dir: directory of CTH ``.nii`` files.
        cbf_dir: directory of CBF ``.nii`` files.
        start: index (into the sorted patient names) of the first patient in the batch;
            wraps to 0 when it is past the end.
        batch_size: maximum number of patients in the batch.

    Returns:
        ``(batch_pairs, start, end, total_patients)``. ``batch_pairs`` maps each
        patient name to a list of ``(cth_path, cbf_path)`` tuples, covering the
        patients in ``[start, end)`` of the sorted patient list.
    """
    print("Indexing patient pairs...")

    def _group_by_patient(directory):
        # patient name -> that patient's .nii file names, in sorted order
        groups = defaultdict(list)
        for name in sorted(f for f in os.listdir(directory) if f.endswith('.nii')):
            groups[name.split('.')[0]].append(name)
        return groups

    cth_groups = _group_by_patient(cth_dir)
    cbf_groups = _group_by_patient(cbf_dir)

    # Pair within each patient instead of zipping the two whole listings. With a zip,
    # one file missing from either directory shifted every later pair out of step,
    # and all of those patients were silently dropped.
    patient_pairs = {}
    for patient_name in sorted(cth_groups.keys() & cbf_groups.keys()):
        patient_pairs[patient_name] = [
            (os.path.join(cth_dir, cth), os.path.join(cbf_dir, cbf))
            for cth, cbf in zip(cth_groups[patient_name], cbf_groups[patient_name])
        ]

    total_patients = len(patient_pairs)
    patient_names = list(patient_pairs)  # already in sorted order

    if start >= total_patients:
        start = 0  # Wrap around to the beginning

    batch_patients = patient_names[start:start+batch_size]
    batch_pairs = {name: patient_pairs[name] for name in batch_patients}

    print(f"Loaded {len(batch_patients)} patients (batch starting at patient {start+1}).")
    return batch_pairs, start, min(start+batch_size, total_patients), total_patients

def load_image(file_path):
    """Load a NIfTI file as a float array with every length-1 axis squeezed away.

    The base file name is printed first as a progress message.
    """
    print(f"Loading image: {os.path.basename(file_path)}")
    volume = nib.load(file_path).get_fdata()
    return volume.squeeze()

def visualize_registration(cth_dir, cbf_dir, batch_size=6):
    """Interactively compare CTH volumes with their registered CBF maps.

    Shows the CTH slice, the CBF slice and a CBF-over-CTH overlay. Sliders choose the
    patient (within the current batch), the image pair for that patient, the slice and
    the overlay opacity. Buttons page through the patients ``batch_size`` at a time.

    Args:
        cth_dir: directory of CTH ``.nii`` files.
        cbf_dir: directory of registered CBF ``.nii`` files, paired by patient name.
        batch_size: number of patients per batch.
    """
    cth_img, cbf_img = None, None
    loaded_paths = None  # (cth_path, cbf_path) that cth_img / cbf_img were loaded from
    current_start = 0
    patient_pairs, start, end, total_patients = get_patient_pairs(cth_dir, cbf_dir, current_start, batch_size)
    patient_names = list(patient_pairs.keys())

    patient_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=max(0, len(patient_names) - 1),  # an empty batch must not give max < min
        step=1,
        description='Patient:',
        continuous_update=False
    )

    image_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=0,
        step=1,
        description='Image:',
        continuous_update=False
    )

    slice_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=100,
        step=1,
        description='Slice:',
        continuous_update=False
    )

    opacity_slider = widgets.FloatSlider(
        value=0.5,
        min=0,
        max=1,
        step=0.1,
        description='Overlay Opacity:',
        continuous_update=False
    )

    def update(patient_index, image_index, slice_index, opacity):
        nonlocal cth_img, cbf_img, loaded_paths

        if len(patient_names) == 0:
            print("No patients available in the current batch.")
            return

        # The slider value can briefly exceed a batch that has just shrunk.
        patient_index = min(patient_index, len(patient_names) - 1)
        patient_name = patient_names[patient_index]
        patient_images = patient_pairs[patient_name]

        if image_index >= len(patient_images):
            image_index = 0
            image_slider.value = 0

        image_slider.max = len(patient_images) - 1

        # Cache on the file paths, not the slider indices: index (0, 0) of a new batch
        # is a different patient and must be reloaded.
        paths = patient_images[image_index]
        if cth_img is None or cbf_img is None or paths != loaded_paths:
            cth_img = load_image(paths[0])
            cbf_img = load_image(paths[1])
            loaded_paths = paths

            print(f"CTH shape: {cth_img.shape}, CBF shape: {cbf_img.shape}")

        # Keep the slice slider inside the slices both volumes have. Clamp the argument
        # as well, because it still holds the value from before the slider was clamped.
        max_slice = min(cth_img.shape[2], cbf_img.shape[2]) - 1
        if slice_slider.max != max_slice:
            slice_slider.max = max_slice
            slice_slider.value = min(slice_slider.value, max_slice)
        slice_index = min(slice_index, max_slice)

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

        # CTH image
        ax1.imshow(cth_img[:, :, slice_index], cmap='gray')
        ax1.set_title('CTH')
        ax1.axis('off')

        # CBF image
        if cbf_img.ndim == 3:
            cbf_slice = cbf_img[:, :, slice_index]
            ax2.imshow(cbf_slice, cmap='jet')
        else:  # If it's a color image (4D)
            cbf_slice = cbf_img[:, :, slice_index, :]
            ax2.imshow(cbf_slice)
        ax2.set_title('CBF')
        ax2.axis('off')

        # Overlay
        ax3.imshow(cth_img[:, :, slice_index], cmap='gray')
        if cbf_img.ndim == 3:
            ax3.imshow(cbf_slice, cmap='jet', alpha=opacity)
        else:
            ax3.imshow(cbf_slice, alpha=opacity)
        ax3.set_title('Overlay')
        ax3.axis('off')

        plt.suptitle(f"Patient: {patient_name} (Image {image_index+1}/{len(patient_images)}, Overall: {start + patient_index + 1} / {total_patients})")

        clear_output(wait=True)
        display(fig)
        plt.close(fig)

    def _show_batch(new_start):
        """Load the batch that begins at new_start and redraw its first patient."""
        nonlocal current_start, patient_pairs, patient_names, start, end, total_patients
        current_start = new_start
        patient_pairs, start, end, total_patients = get_patient_pairs(cth_dir, cbf_dir, current_start, batch_size)
        # patient_names is rebound nonlocally so update() sees the new batch. A plain
        # local assignment here would leave update() indexing the old batch.
        patient_names = list(patient_pairs.keys())
        patient_slider.max = max(0, len(patient_names) - 1)
        patient_slider.value = 0
        image_slider.value = 0
        # Re-run through the interactive widget so the figure lands in its output area.
        interactive_plot.update()

    def load_next_batch(b):
        _show_batch(end)

    def load_prev_batch(b):
        _show_batch(max(0, start - batch_size))

    next_batch_button = widgets.Button(description="Next Batch")
    next_batch_button.on_click(load_next_batch)

    prev_batch_button = widgets.Button(description="Previous Batch")
    prev_batch_button.on_click(load_prev_batch)

    controls = widgets.HBox([prev_batch_button, next_batch_button])

    display(controls)

    # Create and display the interactive widget
    interactive_plot = interactive(update, patient_index=patient_slider, image_index=image_slider, slice_index=slice_slider, opacity=opacity_slider)
    display(interactive_plot)

# Usage: page through the registered CBF maps 6 patients at a time.
archive_root = "/mnt/d/CTH_archive"
cth_dir = os.path.join(archive_root, "CTH_NIFTI")
# "NIFIT" (sic) matches the output_dir spelling used by the text-removal cells.
cbf_dir = os.path.join(archive_root, "CBF_COLORED_NIFIT_NOTEXT_REG")

visualize_registration(cth_dir, cbf_dir, batch_size=6)

In [None]:
import nibabel as nib
import matplotlib.pyplot as plt
import numpy as np
import os
from matplotlib.widgets import Slider
from IPython.display import display
import ipywidgets as widgets

# Enable interactive Matplotlib backend (ipympl). The matplotlib Slider widgets used
# below only respond on an interactive canvas.
%matplotlib widget

# Function to load and prepare the images
def load_images(cth_path, reg_path):
    """Load a CTH volume and its registered counterpart for overlay display.

    The CTH volume is min-max normalised to [0, 1]. A constant volume cannot be
    normalised and is returned as all zeros. The registered volume only has its
    length-1 axes squeezed away; its values are otherwise left unchanged.

    Returns:
        ``(cth_img, reg_img)`` as float arrays.
    """
    cth_img = nib.load(cth_path).get_fdata()
    reg_img = nib.load(reg_path).get_fdata()

    # Ensure dimensions match by reconciling dimensions
    reg_img = np.squeeze(reg_img)  # Remove any singleton dimensions

    # Normalize the CTH image only to [0, 1] for better overlay visualization.
    # A constant volume would divide by zero and become all-NaN, so show it as black.
    cth_min, cth_max = np.min(cth_img), np.max(cth_img)
    cth_range = cth_max - cth_min
    if cth_range > 0:
        cth_img = (cth_img - cth_min) / cth_range
    else:
        cth_img = np.zeros_like(cth_img)

    return cth_img, reg_img

# Function to visualize with overlay and transparency control
def visualize_overlay(cth_img, reg_img):
    """Show a CTH slice with the registered image overlaid, with transparency and slice sliders.

    Args:
        cth_img: CTH volume indexed as ``[:, :, slice]``, ideally normalised to [0, 1].
        reg_img: registered volume indexed the same way. A trailing RGB axis is shown
            in colour.
    """
    plt.close('all')  # Close any existing figures
    fig, ax = plt.subplots()
    fig.subplots_adjust(left=0.25, bottom=0.25)

    # Only slices present in both volumes can be shown.
    n_slices = min(cth_img.shape[2], reg_img.shape[2])
    mid_slice = n_slices // 2

    base = ax.imshow(cth_img[:, :, mid_slice], cmap='gray')
    overlay = ax.imshow(reg_img[:, :, mid_slice], alpha=0.5)  # Initial transparency, preserving original colors

    ax.margins(x=0)

    # Slider for transparency
    axcolor = 'lightgoldenrodyellow'
    ax_alpha = fig.add_axes([0.25, 0.1, 0.65, 0.03], facecolor=axcolor)
    alpha_slider = Slider(ax_alpha, 'Transparency', 0.0, 1.0, valinit=0.5)

    def update(val):
        overlay.set_alpha(alpha_slider.val)
        fig.canvas.draw_idle()

    alpha_slider.on_changed(update)

    # Slider for scrolling through slices (integer steps only)
    ax_slice = fig.add_axes([0.25, 0.15, 0.65, 0.03], facecolor=axcolor)
    slice_slider = Slider(ax_slice, 'Slice', 0, n_slices - 1, valinit=mid_slice, valstep=1, valfmt='%0.0f')

    def update_slice(val):
        slice_idx = int(slice_slider.val)
        base.set_data(cth_img[:, :, slice_idx])
        overlay.set_data(reg_img[:, :, slice_idx])
        fig.canvas.draw_idle()

    slice_slider.on_changed(update_slice)

    # Matplotlib widgets stop responding once they are garbage-collected, and nothing
    # else references these sliders after this function returns. Attach them to the
    # figure, which pyplot keeps alive until it is closed.
    fig._overlay_sliders = (alpha_slider, slice_slider)

    plt.show()

# Define directories. Both are looked up with the same per-patient file name.
archive_root = '/mnt/d/CTH_archive'
cth_dir = os.path.join(archive_root, 'CTH_NIFTI')
# "NIFIT" (sic) matches the output_dir spelling used by the text-removal cells.
reg_dir = os.path.join(archive_root, 'CBF_COLORED_NIFIT_NOTEXT_REG')

# Function to process and visualize a single patient
def process_patient(patient_file):
    """Open the overlay viewer for one patient file name.

    The same file name is looked up in ``cth_dir`` and ``reg_dir``. If the registered
    file is missing, only a message is printed.
    """
    reg_path = os.path.join(reg_dir, patient_file)
    if not os.path.exists(reg_path):
        print(f"Coregistered file for {patient_file} not found.")
        return

    cth_path = os.path.join(cth_dir, patient_file)
    print(f"\nVisualizing registration for {patient_file}")
    visualize_overlay(*load_images(cth_path, reg_path))

# Iterate through all patient files
cth_files = sorted([f for f in os.listdir(cth_dir) if f.endswith('.nii')])

# Create a dropdown to select a patient file
patient_dropdown = widgets.Dropdown(
    options=cth_files,
    description='Select Patient:',
    disabled=False,
)

# Output produced inside widget callbacks is not reliably shown in the cell (JupyterLab
# routes it to the log console). Capture the viewer in an Output widget displayed
# under the dropdown instead.
viewer_output = widgets.Output()

def on_patient_change(change):
    """Redraw the viewer for the newly selected patient file."""
    with viewer_output:
        viewer_output.clear_output(wait=True)
        process_patient(change['new'])

# Observe only the 'value' trait rather than filtering every trait change by hand.
patient_dropdown.observe(on_patient_change, names='value')

display(patient_dropdown, viewer_output)
