# Imports & Configuration

In [1]:
# Core libraries
import SimpleITK as sitk
import numpy as np
import glob
import os
import sys
import time
import tifffile
import matplotlib.pyplot as plt
from IPython.display import clear_output
from scipy.signal import correlate2d

# Project config
import config

# Import from the repo
import sys
sys.path.append("py_alpha_amd_release")

from register import Register
from transforms import AffineTransform, Rigid2DTransform, make_image_centered_transform
import filters

# ==================== CONFIGURATION ====================

# Data paths
DATA_BASE_PATH = os.path.join(config.DATASPACE, "TMA_Cores_Grouped_NEW")
WORK_OUTPUT = os.path.join(config.DATASPACE, "Registered")
TARGET_CORE = "Core_11"

# Input/Output folders
INPUT_FOLDER = os.path.join(DATA_BASE_PATH, TARGET_CORE)
OUTPUT_FOLDER = os.path.join(WORK_OUTPUT, TARGET_CORE)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Registration parameters
DOWNSAMPLE_FACTOR = 4  # For faster registration
CK_CHANNEL_INDEX = 6   # Cytokeratin channel for registration
LEARNING_RATE = 1.0
MAX_ITERATIONS = 200

# Channel names
CHANNEL_NAMES = ['DAPI', 'CD31', 'GAP43', 'NFP', 'CD3', 'CD163', 'CK', 'AF']
print(f"Target Core: {TARGET_CORE}")
print(f"Input Folder: {INPUT_FOLDER}")
print(f"Output Folder: {OUTPUT_FOLDER}")

# Get sorted list of input files
file_list = sorted(glob.glob(os.path.join(INPUT_FOLDER, "*.ome.tif")))
print(f"\nFirst slice: {os.path.basename(file_list[0])}")



Target Core: Core_11
Input Folder: /data3/junming/3D-TMA-Register/TMA_Cores_Grouped_NEW/Core_11
Output Folder: /data3/junming/3D-TMA-Register/Registered/Core_11

First slice: 240919_3D_BL_TMA_10_Core11.ome.tif


In [12]:
import re
def get_slice_number(filename):
    # Matches "TMA_" followed by digits, captures the digits
    match = re.search(r"TMA_(\d+)_", os.path.basename(filename))
    if match:
        return int(match.group(1))
    return 0 # Fallback if no number found

# Get file list and sort using the custom key
raw_files = glob.glob(os.path.join(INPUT_FOLDER, "*.ome.tif"))
files = sorted(raw_files, key=get_slice_number)

# Verify the order
print(f"Target: {TARGET_CORE} | Found {len(file_list)} slices")
for i, f in enumerate(file_list):
    print(f"[{i}] {os.path.basename(f)}")


# --- MANUAL SELECTION ---
# UPDATE THESE INDICES BASED ON THE OUTPUT ABOVE
REF_INDEX = 0  # Index of the fixed image
FLO_INDEX = 1  # Index of the moving image

ref_path = files[REF_INDEX]
flo_path = files[FLO_INDEX]

print(f"\nReference: {os.path.basename(ref_path)}")
print(f"Floating:  {os.path.basename(flo_path)}")

Target: Core_11 | Found 20 slices
[0] 240919_3D_BL_TMA_1_Core11.ome.tif
[1] 240919_3D_BL_TMA_2_Core11.ome.tif
[2] 240919_3D_BL_TMA_3_Core11.ome.tif
[3] 240919_3D_BL_TMA_4_Core11.ome.tif
[4] 240919_3D_BL_TMA_5_Core11.ome.tif
[5] 240919_3D_BL_TMA_6_Core11.ome.tif
[6] 240919_3D_BL_TMA_7_Core11.ome.tif
[7] 240919_3D_BL_TMA_8_Core11.ome.tif
[8] 240919_3D_BL_TMA_9_Core11.ome.tif
[9] 240919_3D_BL_TMA_10_Core11.ome.tif
[10] 240919_3D_BL_TMA_11_Core11.ome.tif
[11] 240919_3D_BL_TMA_12_Core11.ome.tif
[12] 240919_3D_BL_TMA_13_Core11.ome.tif
[13] 240919_3D_BL_TMA_14_Core11.ome.tif
[14] 240919_3D_BL_TMA_15_Core11.ome.tif
[15] 240919_3D_BL_TMA_16_Core11.ome.tif
[16] 240919_3D_BL_TMA_17_Core11.ome.tif
[17] 240919_3D_BL_TMA_18_Core11.ome.tif
[18] 240919_3D_BL_TMA_19_Core11.ome.tif
[19] 240919_3D_BL_TMA_20_Core11.ome.tif

Reference: 240919_3D_BL_TMA_1_Core11.ome.tif
Floating:  240919_3D_BL_TMA_2_Core11.ome.tif


# Load & Extract Channel 6

In [3]:
def get_channel_and_cast(path, channel_idx):
    """
    Replaces: sitk.Cast(image[:, :, channel_idx], sitk.sitkFloat32)
    """
    # 1. Load the volume (Numpy array)
    vol = tifffile.imread(path)
    
    # 2. Extract the specific channel
    # Tifffile usually loads OME-TIFF as (Channel, Y, X) or (Z, Channel, Y, X)
    # We check the shape to pick the right dimension.
    
    img_2d = None
    
    # CASE A: Standard (Channel, Y, X) -> Shape like (40, 1000, 1000)
    if vol.ndim == 3:
        print(f"  Layout detected: (Channel, Y, X). Extracting index {channel_idx}...")
        img_2d = vol[channel_idx, :, :]
        
    # CASE B: Channel Last (Y, X, Channel) -> Shape like (1000, 1000, 40)
    # (This matches your old ITK logic directly)
    elif vol.ndim == 3 and vol.shape[2] < vol.shape[0]: 
        print(f"  Layout detected: (Y, X, Channel). Extracting index {channel_idx}...")
        img_2d = vol[:, :, channel_idx]

    # CASE C: Complex (Time/Z, Channel, Y, X) -> Shape like (1, 40, 1000, 1000)
    elif vol.ndim == 4:
        print(f"  Layout detected: (Z, Channel, Y, X). Extracting index {channel_idx} from first Z-slice...")
        img_2d = vol[0, channel_idx, :, :]
        
    else:
        raise ValueError(f"Unknown image shape: {vol.shape}. Cannot automatically find Channel {channel_idx}.")

    # 3. Cast to Float32 (Exact equivalent of sitk.Cast(..., sitkFloat32))
    img_float = img_2d.astype(np.float32)
    
    # 4. Normalize (Required for Alpha-AMD to work well)
    # This scales it to 0.0 - 1.0, preserving relative intensity
    return filters.normalize(img_float)

# --- EXECUTION ---
print(f"Processing: {TARGET_CORE}")
print(f"Extracting Channel Index: {CK_CHANNEL_INDEX}")

ref_im = get_channel_and_cast(ref_path, CK_CHANNEL_INDEX)
flo_im = get_channel_and_cast(flo_path, CK_CHANNEL_INDEX)

print(f"Done. Reference shape: {ref_im.shape}, Type: {ref_im.dtype}")

Processing: Core_11
Extracting Channel Index: 6
  Layout detected: (Channel, Y, X). Extracting index 6...
  Layout detected: (Channel, Y, X). Extracting index 6...
Done. Reference shape: (6080, 6080), Type: float32


# Configure Registration (Naive AMD)

In [4]:
# Initialize 2D Registration
reg = Register(2)

# Set Images
reg.set_reference_image(ref_im)
reg.set_floating_image(flo_im)

# Set Masks (Use all pixels)
reg.set_reference_mask(np.ones(ref_im.shape, dtype='bool'))
reg.set_floating_mask(np.ones(flo_im.shape, dtype='bool'))

# --- PYRAMID SETUP ---
# We use your DOWNSAMPLE_FACTOR as the coarsest level.
# Level format: (downsample_factor, blur_sigma)
reg.add_pyramid_level(DOWNSAMPLE_FACTOR, 5.0)  # Coarse
reg.add_pyramid_level(DOWNSAMPLE_FACTOR // 2, 3.0) # Medium
reg.add_pyramid_level(1, 0.0)                  # Fine (Full resolution)

# --- TRANSFORM SETUP ---
# Using Affine (includes scaling/shearing). 
# If tissue shouldn't stretch, use Rigid2DTransform(2) instead.
reg.add_initial_transform(AffineTransform(2))

# --- OPTIMIZER SETUP ---
reg.set_iterations(MAX_ITERATIONS)
reg.set_sampling_fraction(0.1)  # Use 10% of pixels for speed
reg.set_optimizer('adam')

# Step lengths must match the number of pyramid levels defined above (3 levels)
# Format: [start_step, end_step]
steps = np.array([
    [LEARNING_RATE, LEARNING_RATE],       # Coarse level
    [LEARNING_RATE, LEARNING_RATE * 0.5], # Medium level
    [LEARNING_RATE * 0.5, 0.01]           # Fine level
])
reg.set_step_lengths(steps)

print("Registration configured.")

Registration configured.


# Run Registration

In [5]:
debug_path = os.path.join(OUTPUT_FOLDER, 'debug')
os.makedirs(debug_path, exist_ok=True)

print("Starting registration... (Check terminal/notebook logs for progress)")
reg.initialize(debug_path)
reg.run()

# Retrieve result
(final_transform, final_metric) = reg.get_output(0)
print(f"\nFinal Metric (Lower is better): {final_metric}")
print("Transform Parameters:", final_transform.get_params())

Starting registration... (Check terminal/notebook logs for progress)
#1. --- Value: 1.341205961881025, Grad: [ 0.99999995 -1.          1.          0.99999997 -0.99998197 -0.9999792 ], Param: [ 5.28403739e-08  9.99999995e-01 -9.99999996e-01  3.26552471e-08
  9.99981973e-01  9.99979198e-01]
#26. --- Value: 3.715690814378383, Grad: [-0.08689003  0.07180036 -0.21198293 -0.02377622 -0.5097322  -0.47812568], Param: [-0.24049744  0.61551504 -0.7458162  -0.54614288 18.46043866 10.40877757]
#51. --- Value: 3.5859263546750997, Grad: [ 0.01642573 -0.01995137  0.04302359 -0.00383861 -0.67270434 -0.41831526], Param: [-0.48039225  0.68681678 -0.90705303 -0.60787225 32.55783482 18.35005305]
#76. --- Value: 3.2921748702990365, Grad: [ 0.00333428  0.0174134  -0.01632782 -0.00422634 -0.66916746 -0.63371577], Param: [-0.48020374  0.67271675 -0.95569014 -0.65722678 49.00560936 31.11899123]
#101. --- Value: 3.283373077504367, Grad: [-0.0072477   0.00290045  0.0066129   0.01314448 -0.54367972 -0.48537104], 

# fix deprication

In [6]:
import os
import re

# 1. Define the directory to search (current directory)
target_folder = '.' 

# 2. Define the Regex Pattern
# strict_pattern matches "np.int" ONLY if it is NOT followed by a letter, number, or underscore.
# This ensures we convert "np.int(x)" -> "int(x)"
# But we DO NOT touch "np.int32" or "np.interpolate"
strict_pattern = r"np\.int(?![a-zA-Z0-9_])"

print("Starting automatic fix for 'np.int' deprecation...")
count = 0

# 3. Walk through all files
for root, dirs, files in os.walk(target_folder):
    for filename in files:
        if filename.endswith(".py"):
            filepath = os.path.join(root, filename)
            
            with open(filepath, 'r') as f:
                content = f.read()
            
            # Check if the file contains the target pattern
            if re.search(strict_pattern, content):
                print(f"  Fixing: {filepath}")
                
                # Apply the safe replacement
                new_content = re.sub(strict_pattern, "int", content)
                
                # Write the changes back
                with open(filepath, 'w') as f:
                    f.write(new_content)
                count += 1

print(f"\nDone! Fixed {count} files.")
print("IMPORTANT: Please go to 'Kernel' > 'Restart Kernel' now to apply changes.")

Starting automatic fix for 'np.int' deprecation...
  Fixing: ./venv/lib/python3.14/site-packages/pandas/tests/io/excel/test_writers.py
  Fixing: ./venv/lib/python3.14/site-packages/dask/array/tests/test_reductions.py
  Fixing: ./venv/lib/python3.14/site-packages/scipy/special/_basic.py
  Fixing: ./py_alpha_amd_release/distances/q_image.py
  Fixing: ./py_alpha_amd_release/filters/filt.py

Done! Fixed 5 files.
IMPORTANT: Please go to 'Kernel' > 'Restart Kernel' now to apply changes.


# Visualizing the Result

In [8]:
import imageio.v2 as imageio
# Create the centering transform wrapper
center_fix = make_image_centered_transform(final_transform, ref_im, flo_im)

# Create a blank canvas for the warped result
ref_im_warped = np.zeros(ref_im.shape, dtype=np.float32)

# Warp the floating image
#center_fix.warp(In=flo_im, Out=ref_im_warped, mode='linear', bg_value=0.0)
center_fix.warp(In=flo_im, Out=ref_im_warped, out_spacing=np.ones(ref_im.ndim), mode='linear', bg_value=0.0)

# Save Comparison (Reference vs Registered)
# Combine side-by-side
comparison = np.hstack((ref_im, ref_im_warped))
save_path = os.path.join(OUTPUT_FOLDER, f"Registered_Check_{TARGET_CORE}.png")

# Normalize to 0-255 for saving
save_img = (comparison * 255).astype(np.uint8)
imageio.imwrite(save_path, save_img)

print(f"Saved comparison to: {save_path}")

Saved comparison to: /data3/junming/3D-TMA-Register/Registered/Core_11/Registered_Check_Core_11.png


In [11]:
import numpy as np
import imageio.v2 as imageio
import os

# Ensure the images are 0.0 - 1.0 floats for color mixing
# (They should be already, but this is a safety clip)
ref_norm = np.clip(ref_im, 0.0, 1.0)
reg_norm = np.clip(ref_im_warped, 0.0, 1.0)

# --- METHOD 1: RED-GREEN OVERLAY ---
# Create an empty RGB image
h, w = ref_norm.shape
rgb_overlay = np.zeros((h, w, 3), dtype=np.float32)

# Assign Channels
rgb_overlay[..., 0] = ref_norm      # Red   = Reference
rgb_overlay[..., 1] = reg_norm      # Green = Registered
# rgb_overlay[..., 2] is left as 0 (Blue)

# Convert to 0-255 for saving
overlay_save = (np.clip(rgb_overlay, 0, 1) * 255).astype(np.uint8)

# Save
overlay_path = os.path.join(OUTPUT_FOLDER, f"Overlay_RG_{TARGET_CORE}.png")
imageio.imwrite(overlay_path, overlay_save)
print(f"Saved Red/Green Overlay to: {overlay_path}")

# --- METHOD 2: 50/50 BLEND (Grayscale) ---
# Simple average of the two images
blend = (ref_norm * 0.5) + (reg_norm * 0.5)
blend_save = (np.clip(blend, 0, 1) * 255).astype(np.uint8)

blend_path = os.path.join(OUTPUT_FOLDER, f"Overlay_Blend_{TARGET_CORE}.png")
imageio.imwrite(blend_path, blend_save)
print(f"Saved 50/50 Blend to:       {blend_path}")

Saved Red/Green Overlay to: /data3/junming/3D-TMA-Register/Registered/Core_11/Overlay_RG_Core_11.png
Saved 50/50 Blend to:       /data3/junming/3D-TMA-Register/Registered/Core_11/Overlay_Blend_Core_11.png
