# 1. Basic Image Registration with Twocan

This notebook demonstrates a basic image registration workflow using the `Twocan` library. We will register an example pair of same-slide IF and IMC images using the default utilities:
1. Preprocessors `IFProcessor` and `IMCProcessor`
1. `RegEstimator` to get the transformation matrix
1. `registration_trial` and `single_objective`, which are designed for IF/IMC registration

Custom functions can be provided in place of these defaults to register pairs of images from different technologies or experimental setups (e.g., serial sections).

We will set the IF image as the "fixed" (target) image and the IMC image as the "moving" (source) image.

Further topics:
1. A note on data formats
1. Using callbacks (& other Optuna features)
1. Custom trial & objectives
1. Picking registration channels
1. Rescaling for mismatched resolution


In [22]:
import os

import cv2
import numpy as np
import pandas as pd
from scipy.stats.mstats import winsorize
from twocan import IFProcessor, IMCProcessor, RegEstimator, registration_trial, single_objective
from twocan.callbacks import SaveTrialsDFCallback
from twocan.utils import pick_best_registration
import matplotlib.pyplot as plt
from spatialdata import SpatialData, read_zarr
import optuna

## 1. Load Data

In [11]:
images = read_zarr('examples/01_imc_if/data/cell-line-0028-bd18455.zarr')

Both the IF and IMC images are saved in this Zarr store together with their channel names. The IMC image also has a transformation matrix to the 'aligned' coordinate system; below we show how this matrix was found using Twocan.

In [18]:
images['IF']

| | Array | Chunk |
|---|---|---|
| Bytes | 1.50 MiB | 1.50 MiB |
| Shape | (3, 512, 512) | (3, 512, 512) |
| Dask graph | 1 chunks in 2 graph layers | 1 chunks in 2 graph layers |
| Data type | uint16 numpy.ndarray | uint16 numpy.ndarray |


In [23]:
images['IMC']

| | Array | Chunk |
|---|---|---|
| Bytes | 82.36 MiB | 82.36 MiB |
| Shape | (24, 944, 953) | (24, 944, 953) |
| Dask graph | 1 chunks in 2 graph layers | 1 chunks in 2 graph layers |
| Data type | float32 numpy.ndarray | float32 numpy.ndarray |


We will register using the nuclear channels: DAPI in the IF image and DNA1/DNA2 in the IMC image. All channels whose names match an entry in the list are summed along the channel axis to create the registration image; that is, DAPI is registered to the DNA1 + DNA2 signal.

In [None]:
registration_channels = ['DAPI', 'DNA1', 'DNA2']
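The channel selection described above amounts to a boolean mask over channel names followed by a sum over the channel axis. Below is a minimal NumPy sketch. It assumes a `(channels, height, width)` array and a parallel list of channel names; `registration_image` is a hypothetical helper, not part of Twocan.

```python
import numpy as np

def registration_image(arr, channel_names, registration_channels):
    """Sum every channel whose name appears in registration_channels.

    arr has shape (channels, height, width); channel_names gives the
    name of each channel in the same order (hypothetical helper).
    """
    mask = np.isin(channel_names, registration_channels)  # one bool per channel
    if not mask.any():
        raise ValueError("none of the registration channels are present")
    return arr[mask].sum(axis=0)  # collapse the channel axis

# Toy IMC-like stack: three channels of 2x2 pixels
imc = np.stack([np.full((2, 2), 1.0), np.full((2, 2), 2.0), np.full((2, 2), 10.0)])
imc_reg = registration_image(imc, ["DNA1", "DNA2", "CD45"], ["DAPI", "DNA1", "DNA2"])
print(imc_reg)  # every pixel is DNA1 + DNA2 = 3.0; CD45 is ignored
```

A single list can serve both images because names missing from one image (DAPI in IMC, DNA1/DNA2 in IF) simply do not match.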

The IF image has a different resolution (1.714 µm/px) from the IMC image (1 µm/px).
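The sketch below shows only the arithmetic for bringing both images onto a common micrometre grid, using the pixel sizes stated above. This notebook does not show whether Twocan rescales the IF image up or the IMC image down, or how that option is exposed. The constants and the matrix `S` are illustrative.

```python
import numpy as np

# Pixel sizes stated in the text (micrometres per pixel)
IF_UM_PER_PX = 1.714
IMC_UM_PER_PX = 1.0

# Factor converting IF pixel coordinates into IMC pixel coordinates
scale = IF_UM_PER_PX / IMC_UM_PER_PX

# Size of the 512 x 512 IF field of view measured in IMC pixels
if_shape_in_imc_px = np.round(np.array([512, 512]) * scale).astype(int)

# The same rescaling written as a 2x3 affine matrix (OpenCV warpAffine layout)
S = np.array([[scale, 0.0, 0.0],
              [0.0, scale, 0.0]])

# The IF corner (x=512, y=512) lands at about (877.6, 877.6) in IMC pixels
corner = S @ np.array([512.0, 512.0, 1.0])
print(scale, if_shape_in_imc_px, corner)
```

At these resolutions the IF field (about 878 µm on a side) is slightly smaller than the IMC field (944 × 953 µm). The registration therefore has to place the IF footprint inside the IMC image rather than match the two images edge to edge.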

In [24]:
study = optuna.create_study(direction='maximize', study_name="example_1", sampler=optuna.samplers.TPESampler(seed=435))

[I 2025-05-28 15:42:36,639] A new study created in memory with name: example_1


In [26]:
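# Ask-and-tell interface: study.ask() creates a trial that stays RUNNING until it is
# reported with study.tell(t, value); study.optimize below does not need it.
t = study.ask()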


In [None]:
study.optimize(
    lambda trial: single_objective(
        trial, images, registration_channels,
        moving_image='IMC', static_image='IF',
        moving_preprocesser=IMCProcessor(), static_preprocesser=IFProcessor(),
    ),
    n_trials=50,
)

## 2. Determine Default Parameters from Best Trial

In [None]:
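# Collect all trials as a DataFrame (columns: value, params_*, user_attrs_*, ...).
# A CSV of trials (e.g. one written by SaveTrialsDFCallback) can be loaded with pd.read_csv instead.
study_df_full = study.trials_dataframe()

# Replicate the pick_best_registration logic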
def get_best_trial_from_df(df):
    temp_df = df.copy()
    # Ensure user_attrs_logical_and is not NaN and is non-negative before log
    temp_df['user_attrs_logical_and'] = temp_df['user_attrs_logical_and'].fillna(0).clip(lower=0)
    temp_df['user_attrs_logical_iou'] = temp_df['user_attrs_logical_iou'].fillna(0)
    temp_df['user_attrs_reg_image_max_corr'] = temp_df['user_attrs_reg_image_max_corr'].fillna(0)
    
    # Avoid division by zero or log of zero if max is 0 or values are 0
    log_and_plus_1 = np.log10(temp_df['user_attrs_logical_and'] + 1)
    max_log_and_plus_1 = log_and_plus_1.max()
    temp_df['norm_and'] = log_and_plus_1 / max_log_and_plus_1 if max_log_and_plus_1 > 0 else 0
    
    max_iou = temp_df['user_attrs_logical_iou'].max()
    temp_df['norm_iou'] = temp_df['user_attrs_logical_iou'] / max_iou if max_iou > 0 else 0
    
    max_corr = temp_df['user_attrs_reg_image_max_corr'].max()
    temp_df['norm_corr'] = temp_df['user_attrs_reg_image_max_corr'] / max_corr if max_corr > 0 else 0
    
    # Fill NaNs that might have occurred from division by zero (if all values were 0)
    temp_df[['norm_and', 'norm_iou', 'norm_corr']] = temp_df[['norm_and', 'norm_iou', 'norm_corr']].fillna(0)
    
    temp_df['balanced_score'] = 1/3 * abs(
        temp_df['norm_and'] * temp_df['norm_corr'] + 
        temp_df['norm_corr'] * temp_df['norm_iou'] + 
        temp_df['norm_iou'] * temp_df['norm_and']
    )
    best_row = temp_df.loc[temp_df['balanced_score'].idxmax()]
    return best_row

best_trial_series = get_best_trial_from_df(study_df_full)

default_params = {
    'IF_binarization_threshold': best_trial_series.get('params_IF_binarization_threshold', 0.5),
    'IF_gaussian_sigma': best_trial_series.get('params_IF_gaussian_sigma', 1.0),
    'IMC_arcsinh_normalize': best_trial_series.get('params_IMC_arcsinh_normalize', True),
    'IMC_arcsinh_cofactor': best_trial_series.get('params_IMC_arcsinh_cofactor', 5.0),
    'IMC_winsorization_lower_limit': best_trial_series.get('params_IMC_winsorization_lower_limit', 0.01),
    'IMC_winsorization_upper_limit': best_trial_series.get('params_IMC_winsorization_upper_limit', 0.01),
    'IMC_binarization_threshold': best_trial_series.get('params_IMC_binarization_threshold', 0.5),
    'IMC_gaussian_sigma': best_trial_series.get('params_IMC_gaussian_sigma', 1.0),
    'binarize_images': best_trial_series.get('params_binarize_images', True),
    'registration_max_features': int(best_trial_series.get('params_registration_max_features', 100000)),
    'registration_percentile': best_trial_series.get('params_registration_percentile', 0.9),
    'registration_target': best_trial_series.get('params_registration_target', 'IF')
}

print("Default parameters extracted from best trial:")
for k, v in default_params.items():
    print(f"{k}: {v}")
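The balanced score rewards trials that do well on all three criteria at once. Let a, i and c be the normalised log-overlap, IoU and maximum correlation. The score is (a·c + c·i + i·a) / 3, so one high metric cannot compensate for two poor ones. The sketch below reproduces that arithmetic with NumPy on three made-up trials. `balanced_score` is a hypothetical helper, not the `pick_best_registration` API.

```python
import numpy as np

def balanced_score(logical_and, iou, max_corr):
    """Mean of the pairwise products of three max-normalised metrics.

    Mirrors the arithmetic of get_best_trial_from_df on plain arrays.
    """
    # log10(x + 1) of the non-negative overlap count; NaNs become 0
    log_and = np.log10(np.clip(np.nan_to_num(logical_and), 0, None) + 1)
    metrics = [log_and, np.nan_to_num(iou), np.nan_to_num(max_corr)]
    # Divide each metric by its maximum across trials (guarding against max <= 0)
    a, i, c = [m / m.max() if m.max() > 0 else np.zeros_like(m) for m in metrics]
    return np.abs(a * c + c * i + i * a) / 3

# Three hypothetical trials
and_px = np.array([99.0, 999.0, 9.0])   # overlapping foreground pixels
iou = np.array([0.5, 0.4, 0.8])
corr = np.array([0.6, 0.9, 0.3])
trial_scores = balanced_score(and_px, iou, corr)
best_idx = int(np.argmax(trial_scores))
print(trial_scores, best_idx)
```

Trial 2 has the best IoU, but it has the weakest overlap and correlation, so trial 1 wins with a score of 2/3.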

## 3. Perform Registration
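The parameters tuned by the study (arcsinh cofactor, winsorization limits, binarization threshold) suggest a preprocessing chain like the one sketched below. This is an assumption about what `preprocess_imc` and `IMCProcessor` do, not their actual implementation. Gaussian smoothing (the `*_gaussian_sigma` parameters) is omitted, and winsorization is approximated by clipping at quantiles. `preprocess_sketch` is hypothetical.

```python
import numpy as np

def preprocess_sketch(img, cofactor=5.0, limits=(0.01, 0.01), threshold=0.5, binarize=True):
    """Hypothetical IMC-style preprocessing: arcsinh, winsorize, rescale, threshold.

    Gaussian smoothing is deliberately left out so the sketch needs only NumPy.
    """
    x = np.arcsinh(img / cofactor)                          # compress the count range
    lo, hi = np.quantile(x, [limits[0], 1.0 - limits[1]])   # winsorization bounds
    x = np.clip(x, lo, hi)                                  # clamp the extreme tails
    # Rescale to [0, 1] so the threshold is relative to the clipped range
    x = (x - lo) / (hi - lo) if hi > lo else np.zeros_like(x)
    return x > threshold if binarize else x

# Synthetic "IMC" tile: Poisson background plus one bright square
rng = np.random.default_rng(0)
counts = rng.poisson(3.0, size=(64, 64)).astype(float)
counts[20:40, 20:40] += 50.0
mask = preprocess_sketch(counts)
print(mask.sum(), mask[20:40, 20:40].all())
```

The arcsinh transform compresses the long tail of ion counts before thresholding. Without it, a few very bright pixels would dominate the rescaling, and most of the real signal would fall below the threshold.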

In [None]:
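# Extract the registration channels as numpy arrays.
# NOTE: preprocess_if, preprocess_imc and if_scale are not defined in this notebook
# (they appear to come from the original workflow code) and must be provided first.
IF_image = images['IF']
IMC_image = images['IMC']

IF_arr = IF_image.data[IF_image.c.isin(registration_channels).values].compute()  # dask -> numpy
IMC_arr = IMC_image.data[IMC_image.c.isin(registration_channels).values].compute()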

# Preprocess images
IF_processed = preprocess_if(
    IF_arr,
    if_scale=if_scale,
    binarize=default_params['binarize_images'],
    threshold=default_params['IF_binarization_threshold'],
    sigma=default_params['IF_gaussian_sigma']
)

IMC_processed = preprocess_imc(
    IMC_arr,
    arcsinh_norm=default_params['IMC_arcsinh_normalize'],
    arcsinh_cofactor=default_params['IMC_arcsinh_cofactor'],
    winsorize_limits=[default_params['IMC_winsorization_lower_limit'], default_params['IMC_winsorization_upper_limit']],
    binarize=default_params['binarize_images'],
    threshold=default_params['IMC_binarization_threshold'],
    sigma=default_params['IMC_gaussian_sigma']
)

print(f"IF processed shape: {IF_processed.shape}, dtype: {IF_processed.dtype}")
print(f"IMC processed shape: {IMC_processed.shape}, dtype: {IMC_processed.dtype}")

# Register images
reg = RegEstimator(
    max_features=default_params['registration_max_features'],
    percentile=default_params['registration_percentile']
)

# Choose moving and target images according to the registration target
if default_params['registration_target'] == 'IF':
    moving_img, target_img = IMC_processed, IF_processed  # IMC is moving, IF is target
else:
    moving_img, target_img = IF_processed, IMC_processed  # IF is moving, IMC is target

# Placeholder scores used if registration fails
nan_scores = {k: np.nan for k in ['iou', 'and', 'or', 'xor', 'source_sum', 'target_sum']}

try:
    reg.fit(moving=moving_img, target=target_img)
    registration_matrix = reg.M_
    # Score the same processed (binarized) images, in the same order, as used for fitting
    scores = reg.score(moving_img, target_img)
    print("\nRegistration successful.")
    print(f"Transformation Matrix:\n{registration_matrix}")
    print(f"Scores (IOU, etc.):\n{scores}")
except cv2.error as e:  # OpenCV errors raised during fitting
    print(f"Registration failed: {e}")
    registration_matrix = np.eye(2, 3)  # identity placeholder
    scores = dict(nan_scores)
except Exception as e:  # any other error during fit/score
    print(f"An error occurred during registration: {e}")
    registration_matrix = np.eye(2, 3)
    scores = dict(nan_scores)
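The key names used for `scores` above (`iou`, `and`, `or`, `xor`, `source_sum`, `target_sum`) suggest standard set-overlap statistics on the binarized images. The sketch below shows what they would mean. It assumes `source` is the transformed moving mask and `target` is the fixed mask. These definitions are assumptions; `RegEstimator.score` may compute them differently. `overlap_scores` is hypothetical.

```python
import numpy as np

def overlap_scores(source, target):
    """Pixel-overlap statistics between two binary masks of equal shape (hypothetical)."""
    s = np.asarray(source, dtype=bool)
    t = np.asarray(target, dtype=bool)
    n_and = int(np.logical_and(s, t).sum())   # foreground in both
    n_or = int(np.logical_or(s, t).sum())     # foreground in either
    n_xor = int(np.logical_xor(s, t).sum())   # foreground in exactly one
    return {
        'iou': n_and / n_or if n_or else 0.0,  # intersection over union
        'and': n_and,
        'or': n_or,
        'xor': n_xor,
        'source_sum': int(s.sum()),
        'target_sum': int(t.sum()),
    }

# Two 2x2 squares offset by one row: they share 2 of 6 foreground pixels
a = np.zeros((4, 4), dtype=bool)
a[0:2, 0:2] = True
b = np.zeros((4, 4), dtype=bool)
b[1:3, 0:2] = True
res = overlap_scores(a, b)
print(res)  # iou = 2/6, and = 2, or = 6, xor = 4
```

An empty union returns an IoU of 0.0 here. The notebook instead uses NaN to mark a failed registration, which keeps the two cases distinguishable.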

## 4. Save Trial Data with Callback

In [None]:
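# NOTE: MockFrozenTrial, MockStudy, OUTPUT_CALLBACK_CSV and zarr_id are not defined in this
# notebook and must be provided before this cell can run.
# Prepare data for the callback (mimicking Optuna trial user_attrs)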
user_attrs_for_callback = {
    'registration_matrix': str(registration_matrix.tolist()), # Callback expects string representation
    'logical_iou': scores.get('iou', np.nan),
    'logical_and': scores.get('and', np.nan),
    'logical_or': scores.get('or', np.nan),
    'logical_xor': scores.get('xor', np.nan),
    'prop_source_covered': scores.get('source_sum', np.nan), # Simplified, might need actual proportion
    'prop_target_covered': scores.get('target_sum', np.nan), # Simplified
    # Add other metrics if calculated and required by callback's typical usage
    # For this basic example, we mainly focus on IOU and the matrix.
    # The original Snakefile calculates many more correlations.
    'reg_image_max_corr': best_trial_series.get('user_attrs_reg_image_max_corr', np.nan), # using value from best trial as placeholder
    'corr_image_max_corr': best_trial_series.get('user_attrs_corr_image_max_corr', np.nan), # placeholder
    'objective': 'iou-single-objective', # From CSV
    'sampler': 'DefaultNotebookTrial', # Custom for this notebook
    'zarr_id': zarr_id
}

mock_trial_obj = MockFrozenTrial(
    params=default_params,
    user_attrs=user_attrs_for_callback,
    value=scores.get('iou', 0) * user_attrs_for_callback.get('reg_image_max_corr',0) # Example objective value
)

mock_study_obj = MockStudy()
mock_study_obj.add_trial(mock_trial_obj) # Callback might look at study.trials

# Instantiate and call the callback
if os.path.exists(OUTPUT_CALLBACK_CSV):
    os.remove(OUTPUT_CALLBACK_CSV) # Clear previous run
    
save_trials_callback = SaveTrialsDFCallback(OUTPUT_CALLBACK_CSV, anno_dict={
    'objective': user_attrs_for_callback['objective'],
    'sampler': user_attrs_for_callback['sampler'],
    'zarr_id': user_attrs_for_callback['zarr_id']
})

# The callback is usually called by Optuna after each trial. We call it manually here.
save_trials_callback(mock_study_obj, mock_trial_obj)

print(f"Trial data saved by callback to: {OUTPUT_CALLBACK_CSV}")
if os.path.exists(OUTPUT_CALLBACK_CSV):
    callback_df = pd.read_csv(OUTPUT_CALLBACK_CSV)
    print("\nCallback CSV content (first 5 rows):")
    print(callback_df.head())
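In a real run no mock objects are needed. Optuna's `Study.optimize` accepts a `callbacks` list and calls each entry as `callback(study, frozen_trial)` after every trial. For example, `study.optimize(objective, n_trials=50, callbacks=[save_trials_callback])`, where `objective` stands for the lambda used earlier. The sketch below illustrates that calling protocol without requiring Optuna or Twocan. `AppendTrialRow` is a hypothetical callback that writes one CSV row per trial. It is not the implementation of `SaveTrialsDFCallback`, whose internals are not shown here.

```python
import csv
import io
from types import SimpleNamespace


class AppendTrialRow:
    """Hypothetical Optuna-style callback: append one CSV row per finished trial."""

    def __init__(self, stream, anno_dict=None):
        self.stream = stream
        self.anno = dict(anno_dict or {})  # constant annotation columns
        self.writer = None

    def __call__(self, study, trial):
        row = {'number': trial.number, 'value': trial.value, **self.anno}
        row.update({f'params_{k}': v for k, v in trial.params.items()})
        row.update({f'user_attrs_{k}': v for k, v in trial.user_attrs.items()})
        if self.writer is None:
            # Column order is fixed by the first trial; unknown keys later are dropped
            self.writer = csv.DictWriter(self.stream, fieldnames=list(row), extrasaction='ignore')
            self.writer.writeheader()
        self.writer.writerow(row)


buf = io.StringIO()
callback = AppendTrialRow(buf, anno_dict={'sampler': 'TPE'})
for n, iou in enumerate([0.42, 0.55]):
    finished = SimpleNamespace(number=n, value=iou,
                               params={'IMC_arcsinh_cofactor': 5.0},
                               user_attrs={'logical_iou': iou})
    callback(None, finished)  # Optuna would pass (study, frozen_trial) here
print(buf.getvalue())
```

Because the callback runs after every trial, the CSV stays current even if a long optimization is interrupted.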

## 5. Visualize Registration Results

In [None]:
# Adapted from plot_registration_results in workflow_utils.py
def plot_registration_for_notebook(sdata_obj, trial_params_dict, trial_user_attrs_dict, metadata_info_series, output_fig_path):
    """Create visualization of registration results for the notebook.
    
    Args:
        sdata_obj: SpatialData object containing IF and IMC images
        trial_params_dict: Dictionary of parameters used for this trial (our default_params)
        trial_user_attrs_dict: Dictionary of user_attrs for this trial (like registration_matrix)
        metadata_info_series: Pandas Series with metadata (like zarr_id, if_scale, registration_channels)
        output_fig_path: Path to save the visualization
    """
    reg_channels = metadata_info_series['registration_channels'].split(' ')
    scale_if = metadata_info_series['if_scale']
    current_zarr_id = metadata_info_series['zarr_id']
    
    IF_simg = sdata_obj['IF']
    IMC_simg = sdata_obj['IMC']
    
    IF_reg_arr = IF_simg.data[IF_simg.c.isin(reg_channels).values].compute()
    IMC_reg_arr = IMC_simg.data[IMC_simg.c.isin(reg_channels).values].compute()
    
    # Process images using the trial parameters
    IF_proc = preprocess_if(IF_reg_arr, scale_if, 
                               trial_params_dict['binarize_images'],
                               trial_params_dict['IF_binarization_threshold'],
                               trial_params_dict['IF_gaussian_sigma'])
    
    IMC_proc = preprocess_imc(IMC_reg_arr,
                                 trial_params_dict['IMC_arcsinh_normalize'],
                                 trial_params_dict['IMC_arcsinh_cofactor'],
                                 [trial_params_dict['IMC_winsorization_lower_limit'],
                                  trial_params_dict['IMC_winsorization_upper_limit']],
                                 trial_params_dict['binarize_images'],
                                 trial_params_dict['IMC_binarization_threshold'],
                                 trial_params_dict['IMC_gaussian_sigma'])
    
    # Get registration matrix (ensure it's a numpy array)
    m_str = trial_user_attrs_dict.get('registration_matrix', '[[1,0,0],[0,1,0]]')
    if isinstance(m_str, str):
        try:
            # Simplified parsing, make robust if needed or use twocan.utils.read_M if available and working
            m_list = eval(m_str) 
            M = np.array(m_list)
        except:
            print(f"Warning: Could not parse matrix string '{m_str}'. Using identity.")
            M = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    elif isinstance(m_str, np.ndarray):
        M = m_str
    else:
        M = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) # Fallback
        
    if M.shape != (2,3):
        print(f"Warning: Matrix shape is {M.shape}, expected (2,3). Using identity.")
        M = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
        
    # Create figure
    fig, axes = plt.subplots(3, 2, figsize=(10, 15)) # Adjusted figsize
    fig.suptitle(f"Registration: {current_zarr_id}", fontsize=16)
    
    # Plot original images (summing channels for grayscale visualization)
    axes[0,0].imshow(winsorize(IMC_reg_arr.sum(0), limits=[0.05,0.05]))
    axes[0,0].set_title('IMC Original (Reg Channels Summed)', fontsize=10)
    axes[0,0].axis('off')
    
    # Show the original IF at its native resolution. The rescaled IF is IF_proc (plotted
    # below), so the two top panels are not on the same pixel grid.
    axes[0,1].imshow(winsorize(IF_reg_arr.sum(0), limits=[0.05,0.05]))
    axes[0,1].set_title('IF Original (Reg Channels Summed)', fontsize=10)
    axes[0,1].axis('off')
    
    # Plot processed images
    axes[1,0].imshow(IMC_proc, cmap='gray')
    axes[1,0].set_title('IMC Processed', fontsize=10)
    axes[1,0].axis('off')
    
    axes[1,1].imshow(IF_proc, cmap='gray')
    axes[1,1].set_title('IF Processed (Scaled)', fontsize=10)
    axes[1,1].axis('off')
    
    # Plot transformed IMC overlayed with IF contour (or similar)
    # For simplicity, let's show transformed IMC and IF side-by-side after alignment
    IMC_transformed_display = cv2.warpAffine(IMC_proc.astype(np.float32), M, (IF_proc.shape[1], IF_proc.shape[0]))
    axes[2,0].imshow(IMC_transformed_display, cmap='gray')
    axes[2,0].set_title('IMC Transformed to IF Space', fontsize=10)
    axes[2,0].axis('off')
    
    axes[2,1].imshow(IF_proc, cmap='gray') # IF processed as reference
    axes[2,1].contour(IMC_transformed_display > 0.5, colors='r', linewidths=0.5, alpha=0.7) # Overlay contour
    axes[2,1].set_title('IF Processed with Transformed IMC Contour', fontsize=10)
    axes[2,1].axis('off')
    
    plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to make space for suptitle
    plt.savefig(output_fig_path)
    print(f"\nVisualization saved to {output_fig_path}")
    plt.show()

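# Call the plotting function
# NOTE: selected_pair_info (a Series with 'registration_channels', 'if_scale' and 'zarr_id')
# and OUTPUT_FIGURE_PATH are not defined in this notebook and must be provided first.
plot_registration_for_notebook(
    images,
    default_params,
    user_attrs_for_callback,  # contains the matrix as a string
    selected_pair_info,
    OUTPUT_FIGURE_PATH
)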

print("\nNotebook 01 execution finished.")
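Section 5 uses `cv2.warpAffine` to bring the processed IMC image into the IF frame. Without the `WARP_INVERSE_MAP` flag, OpenCV treats the 2×3 matrix as a forward map from source to destination and inverts it internally. Each output pixel then samples the source at the inverted coordinates. The sketch below makes that convention concrete with NumPy and nearest-neighbour sampling. `warp_affine_nearest` is a hypothetical helper. It is not a drop-in replacement, since OpenCV defaults to bilinear interpolation.

```python
import numpy as np

def warp_affine_nearest(src, M, out_shape):
    """Nearest-neighbour warp of a 2D image with a 2x3 forward affine matrix.

    Each output pixel (x, y) is looked up at M^-1 (x, y) in the source;
    pixels that map outside the source are set to 0.
    """
    A = np.vstack([M, [0.0, 0.0, 1.0]])          # promote to a 3x3 homogeneous matrix
    A_inv = np.linalg.inv(A)
    h, w = out_shape
    ys, xs = np.mgrid[0:h, 0:w]
    dst = np.stack([xs.ravel(), ys.ravel(), np.ones(h * w)])  # homogeneous (x, y, 1)
    sx, sy = np.rint(A_inv[:2] @ dst).astype(int)  # source column and row per output pixel
    valid = (sx >= 0) & (sx < src.shape[1]) & (sy >= 0) & (sy < src.shape[0])
    out = np.zeros(h * w, dtype=src.dtype)
    out[valid] = src[sy[valid], sx[valid]]
    return out.reshape(h, w)

src = np.zeros((5, 5))
src[1, 1] = 1.0
M = np.array([[1.0, 0.0, 2.0],    # shift right by 2 px
              [0.0, 1.0, 1.0]])   # shift down by 1 px
moved = warp_affine_nearest(src, M, (5, 5))
print(np.argwhere(moved == 1.0))  # [[2 3]]  (row 2, column 3)
```

Because the lookup runs from output to source, every output pixel receives exactly one value and no holes appear. This is also why the output size is passed separately, as the `(IF_proc.shape[1], IF_proc.shape[0])` argument is in Section 5.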