# CORE - A Cell-Level Coarse-to-Fine Image Registration Engine for Multi-stain Image Alignment

This notebook demonstrates the complete workflow for Whole Slide Image (WSI) registration using rigid and non-rigid techniques with nuclei-based analysis.

## Overview
- **Coarse Registration**: Initial coarse alignment using CORE for global deformation estimation.
- **Fine Shape-Aware Nuclei-Based Registration**: shape-aware registration of nuclei centroids (rigid point-set alignment followed by non-rigid CPD) for local deformation estimation.
- **Interactive Visualisation**: real-time inspection of the estimated deformations with TIAViz.


## 1. Setup and Imports

In [1]:
import os
import sys

# The `core` package lives one directory above this notebook; put that
# directory at the front of the import path so it is found first.
project_root = os.path.dirname(os.path.abspath(os.getcwd()))  # adjust if needed
already_importable = project_root in sys.path
if not already_importable:
    sys.path.insert(0, project_root)


In [None]:

# Render matplotlib figures inline and have IPython reload edited modules
# before each cell runs, so changes to the `core` package are picked up live.
%matplotlib inline
%load_ext autoreload
%autoreload 2
import SimpleITK as sitk
import numpy as np
from matplotlib import pyplot as plt

# Project modules. NOTE(review): these star imports are order-sensitive: a name
# exported by a later module shadows the same name from an earlier one. Names
# used later without an explicit import (e.g. `util`, `rigid`, `CPD`,
# `AffineWSITransformer`, `RegistrationParams`, the *_PATH / *_RESOLUTION
# constants) presumably come from these modules; confirm which one defines each.
from core.utils.imports import *
from core.config import *
from core.preprocessing.padding import *
from core.preprocessing.preprocessing import *
from core.registration.registration import *
from core.evaluation.evaluation import *
from core.visualization.visualization import *
from core.preprocessing.nuclei_analysis import *
from core.preprocessing.stainnorm import *
from core.registration.nonrigid import *
from core.cpd import *


# Setup Bokeh for notebook output
setup_bokeh_notebook()

# Echo the configured slide paths so the run records which pair was processed.
print("✅ All modules imported successfully!")
print(f"Source WSI: {SOURCE_WSI_PATH}")
print(f"Target WSI: {TARGET_WSI_PATH}")

## 2. Configuration Check

Verify that all file paths are correct and files exist.

In [None]:
import os

# Input files the pipeline reads; anything marked ❌ will make a later cell fail.
files_to_check = [
    SOURCE_WSI_PATH,
    TARGET_WSI_PATH,
    FIXED_POINTS_PATH,
    MOVING_POINTS_PATH,
]

print("File existence check:")
for path_to_check in files_to_check:
    marker = "✅" if os.path.exists(path_to_check) else "❌"
    print(f"{marker} {path_to_check}")

# Echo the tunable configuration constants so the run is self-documenting.
print("\nCurrent Parameters:")
for label, value in (
    ("Preprocessing Resolution", PREPROCESSING_RESOLUTION),
    ("Registration Resolution", REGISTRATION_RESOLUTION),
    ("Patch Size", PATCH_SIZE),
    ("Fixed Threshold", FIXED_THRESHOLD),
    ("Moving Threshold", MOVING_THRESHOLD),
    ("Min Nuclei Area", MIN_NUCLEI_AREA),
):
    print(f"- {label}: {value}")

## 3. Load and Preprocess WSIs

In [None]:
# Open both slides. `source`/`target` are presumably image arrays read at
# PREPROCESSING_RESOLUTION (used for coarse registration) — confirm.
print("Loading WSI images...")
wsi_paths = (SOURCE_WSI_PATH, TARGET_WSI_PATH)
source_wsi, target_wsi, source, target = load_wsi_images(*wsi_paths, PREPROCESSING_RESOLUTION)

print("\nLoaded images:")
for label, image in (("Source", source), ("Target", target)):
    print(f"{label} shape: {image.shape}")

In [None]:
# Pad source and target (presumably to a common canvas). `padding_params` is
# passed on later so nuclei coordinates can be shifted into the padded frame.
print("Preprocessing images...")
padded = pad_images(source, target)
source_prep, target_prep, padding_params = padded

# Tissue masks for both padded images; these are handed to the rigid
# registration below.
print("Extracting tissue masks...")
source_mask, target_mask = extract_tissue_masks(
    source_prep,
    target_prep,
    artefacts=False,
)

print("✅ Preprocessing completed!")

## 4. Visualize Images and Tissue Masks

In [None]:
# 2x2 overview: top row = padded images, bottom row = their tissue masks.
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
panels = (
    (source_prep, 'Source Image (Moving)', None),
    (target_prep, 'Target Image (Fixed)', None),
    (source_mask, 'Source Tissue Mask', 'gray'),
    (target_mask, 'Target Tissue Mask', 'gray'),
)
# axes.flat walks the grid row by row, matching the order of `panels`.
for ax, (image, title, cmap) in zip(axes.flat, panels):
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

## 5. Coarse Registration

In [None]:
# Coarse step: rigid registration of the padded source onto the padded target,
# guided by both tissue masks; then overlay the result on the fixed image.
print("Performing rigid registration...")
rigid_result = perform_rigid_registration(source_prep, target_prep, source_mask, target_mask)
moving_img_transformed, final_transform = rigid_result
visualize_overlays(target_prep, source_prep, moving_img_transformed)


In [None]:
# Dense field induced by the inverse of the coarse rigid transform on the
# padded source grid (presumably x/y displacement components — confirm).
r_x, r_y = util.matrix_df(source_prep, np.linalg.inv(final_transform))
rigid_field = np.stack((r_x, r_y), axis=-1)

# Elastic refinement of the rigidly aligned source against the fixed image.
displacement_field, warped_source = elastic_image_registration(
    moving_img_transformed,
    target_prep,
)
print("non rigid displacement field", displacement_field.shape)

# Show the elastically warped result; the rigid-only overlay was already shown
# in the previous cell.
visualize_overlays(target_prep, source_prep, warped_source)




In [None]:
# Compose the rigid field (r_x, r_y from the previous cell) with the elastic
# displacement to obtain the full coarse deformation, then save it as .mha.
disp_field = util.deform_conversion(displacement_field)
w_x, w_y = util.combine_deformation(r_x, r_y, disp_field[0], disp_field[1])
deformation_field = np.stack((w_x, w_y), axis=-1)
sitk_image = sitk.GetImageFromArray(deformation_field)
sitk.WriteImage(sitk_image, './533.mha')

## 6. TIAViz Registration Visualization 

In [None]:
%%bash 
# Launch the TIAViz viewer. NOTE(review): BOKEH_ALLOW_WS_ORIGIN whitelists the
# websocket origin localhost:5007, presumably the address the viewer is served
# on — confirm.
export BOKEH_ALLOW_WS_ORIGIN=localhost:5007
# Replace the placeholder paths with the slide directory and the overlay
# directory before running.
tiatoolbox visualize --slides "path-to-slides" --overlays "path-to-overlays"


## 7. Patch Visualization

In [None]:
# Rescale the coarse transform from PREPROCESSING_RESOLUTION to
# REGISTRATION_RESOLUTION for high-resolution reads.
transform_40x = scale_transformation_matrix(
    final_transform,
    PREPROCESSING_RESOLUTION,
    REGISTRATION_RESOLUTION,
)

# Tile the fixed (target) slide using its tissue mask.
print("\nExtracting patches...")
fixed_patch_extractor = extract_patches_from_wsi(
    target_wsi,
    target_mask,
    PATCH_SIZE,
    PATCH_STRIDE,
)

n_patches = len(fixed_patch_extractor)
print(f"Total patches extracted: {n_patches}")

### Coarse Registration Visualization

In [None]:
# Pick one fixed-slide patch to compare before/after coarse alignment.
patch_idx = 70  # You can change this index
loc = fixed_patch_extractor.coordinate_list[patch_idx]
location = tuple(loc[:2])

print(f"Visualizing patch {patch_idx} at location {location}")

# Read the same region from both slides at 40x objective power.
read_kwargs = dict(resolution=40, units="power")
fixed_tile = target_wsi.read_rect(location, VISUALIZATION_SIZE, **read_kwargs)
moving_tile = source_wsi.read_rect(location, VISUALIZATION_SIZE, **read_kwargs)

# Same region of the source read through the rescaled coarse transform.
# NOTE(review): this read uses level 0 while the tiles above use 40x power;
# they agree only if level 0 of the source slide is 40x — confirm.
tfm = AffineWSITransformer(source_wsi, transform_40x)
transformed_tile = tfm.read_rect(location=location, size=VISUALIZATION_SIZE, resolution=0, units="level")

visualize_patches(fixed_tile, moving_tile, transformed_tile)

In [None]:
# Overlay the three tiles from the previous cell (fixed, moving, transformed).
tiles = (fixed_tile, moving_tile, transformed_tile)
visualize_overlays(*tiles)

## 8. Interactive Nuclei Visualization

In [None]:

# Per-slide nuclei tables. The paths can be overridden through environment
# variables so the notebook is not tied to one workstation's directory layout;
# the defaults are the original locations.
FIXED_NUCLEI_CSV = os.environ.get(
    "FIXED_NUCLEI_CSV",
    "/home/u5552013/Nextcloud/HYRECO/Data/nuclei_points/he_533_nuclei.csv",
)
MOVING_NUCLEI_CSV = os.environ.get(
    "MOVING_NUCLEI_CSV",
    "/home/u5552013/Nextcloud/HYRECO/Data/nuclei_points/ki67_533_nuclei.csv",
)

# Fail early with an actionable message instead of deep inside the loader.
for nuclei_csv in (FIXED_NUCLEI_CSV, MOVING_NUCLEI_CSV):
    if not os.path.exists(nuclei_csv):
        raise FileNotFoundError(
            f"Nuclei CSV not found: {nuclei_csv} "
            "(set FIXED_NUCLEI_CSV / MOVING_NUCLEI_CSV to override)"
        )

# Load nuclei coordinates
moving_df = load_nuclei_coordinates(MOVING_NUCLEI_CSV)
fixed_df = load_nuclei_coordinates(FIXED_NUCLEI_CSV)

print("Loaded nuclei data:")
print(f"- Fixed nuclei: {len(fixed_df)}")
print(f"- Moving nuclei: {len(moving_df)}")

# Create basic nuclei overlay plot
print("\nCreating interactive nuclei overlay plot...")
plot1 = create_nuclei_overlay_plot(moving_df, fixed_df,
                                  "Original Nuclei Coordinates (Before Registration)")
show_plot(plot1)

In [None]:
# Apply the coarse registration (rigid transform + elastic displacement, with
# the padding offsets) to the nuclei tables; the point sets returned here feed
# the fine point-set registration below.
coarse_outputs = compute_deformation_and_apply(
    source_prep,
    final_transform,
    displacement_field,
    moving_df,
    fixed_df,
    padding_params,
    util,
    pad_landmarks,
)
deformation_field, moving_updated, fixed_points, moving_points = coarse_outputs

In [None]:
# Plot fixed, original moving and coarsely transformed points together to
# judge how well the coarse deformation aligned the nuclei.
visualize_cluster_alignment(
    fixed_points,
    moving_points,
    moving_updated,
    title='Cluster Centers: Fixed, Original Moving, and Transformed',
    figsize=(10, 10),
    save_path=None,
)

## 10. Shape-Aware Point Set Registration

In [None]:
# Fine rigid step: shape-aware point-set registration of the coarsely
# transformed moving nuclei (`moving_updated`) onto the fixed nuclei.
# This step uses all points; subsampling happens only in the next cell, for
# the non-rigid CPD step.
print("Performing Shape-Aware Point Set Registration...")

shape_registrator, shape_transform, shape_transformed_coords = rigid.perform_shape_aware_registration(
    fixed_points,
    moving_updated,
    shape_weight=0.3,  # 30% weight for shape, 70% for spatial distance
    max_iterations=100,
    tolerance=1e-11,
)


In [None]:
# Non-rigid CPD is run on a subsample of each point set to keep it tractable.
N_CPD_POINTS = 5000
fixed_subsample = util.skip_subsample(fixed_points, n_samples=N_CPD_POINTS)
moving_subsample = util.skip_subsample(shape_transformed_coords, n_samples=N_CPD_POINTS)
X, Y = fixed_subsample, moving_subsample

print('Non-rigid CPD:')
cpd_nonrigid = CPD(method='nonrigid')
# NOTE(review): w=0 is presumably the outlier weight (no outliers assumed) — confirm.
nonrigid_transformed_coords, P = cpd_nonrigid(X, Y, w=0, max_iterations=50)



In [None]:
# Combine the fine rigid (shape-aware) transform with the rigid field r_x, r_y
# computed after the coarse registration; output_path names the .mha target.
# NOTE(review): the fine step was fitted on points already moved by the full
# coarse (rigid + elastic) deformation, so w_x, w_y may be the intended base
# field here — confirm.
fine_deform = util.create_deformation_field(shape_transform, source_prep, r_x, r_y, util, output_path='./533_finerigid.mha')

In [None]:
from scipy.ndimage import map_coordinates
import numpy as np
import pandas as pd
from scipy.interpolate import griddata
# Dense displacement field interpolated from the CPD correspondences, on the
# same grid as `target_prep`.
print("Creating displacement field...")
# NOTE(review): assumes point coordinates are 64x the resolution of
# `target_prep`; derive this from the configured resolutions rather than
# hard-coding it — confirm.
scale_factor = 64
source_points_scaled = moving_subsample / scale_factor
target_points_scaled = nonrigid_transformed_coords / scale_factor

# Use a separate name so the elastic `displacement_field` consumed by earlier
# cells is not clobbered (re-running those cells stays valid).
fine_displacement_field = create_displacement_field(
    source_points_scaled, target_points_scaled,
    target_prep.shape,
    method=RegistrationParams.INTERPOLATION_METHOD,
    sigma=RegistrationParams.DISPLACEMENT_SIGMA,
    max_displacement=RegistrationParams.MAX_DISPLACEMENT,
)

# Compose the coarse deformation (w_x, w_y) with the fine non-rigid field and
# save the COMPOSED result (fr_x, fr_y), not the coarse field alone.
fr_x, fr_y = util.combine_deformation(
    w_x, w_y, fine_displacement_field[..., 0], fine_displacement_field[..., 1]
)
nonrigid_deformation_field = np.stack((fr_x, fr_y), axis=-1)
sitk_image = sitk.GetImageFromArray(nonrigid_deformation_field)
sitk.WriteImage(sitk_image, './533_nonrigid.mha')


## 11. TIAViz Registration Visualization

In [None]:
%%bash 
# Launch the TIAViz viewer for the final registration result.
# NOTE(review): BOKEH_ALLOW_WS_ORIGIN whitelists the websocket origin
# localhost:5007, presumably the address the viewer is served on — confirm.
export BOKEH_ALLOW_WS_ORIGIN=localhost:5007

# Replace the placeholder paths with the slide directory and the overlay
# directory before running.
tiatoolbox visualize --slides "path-to-slides" --overlays "path-to-overlays"
