## Install environment
pip install --upgrade "git+https://github.com/JEFworks-Lab/STalign.git"

In [2]:
# Refer to example code within STalign
# https://github.com/JEFworks-Lab/STalign/blob/main/docs/notebooks/xenium-heimage-alignment.ipynb
# Tested on Python=3.11.11
# Other dependencies are described in STalign original repo.
# 
# Please Note: This method requires to provide pairs of landmarks: pointsI and pointsJ

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import tifffile as tf
from STalign import STalign
import math
import seaborn as sns


In [3]:
def get_cell_loc(HE_quant_fn, col_name_x='Centroid X µm', col_name_y='Centroid Y µm'):
    HE_quant_df = pd.read_csv(HE_quant_fn, sep='\t')
    he_x = HE_quant_df[col_name_x]
    he_y = HE_quant_df[col_name_y]
    source = np.array([he_x, he_y]).T
    return source


In [4]:

def get_rotation_translation(matrix):
    # Calculate rotation angle in degrees
    rotation_angle = math.degrees(math.atan2(matrix[1, 0], matrix[0, 0]))
    # Calculate translation distance
    translation_distance = math.sqrt(matrix[0, 2]**2 + matrix[1, 2]**2)

    return rotation_angle, translation_distance

def draw_points(source_points, target_points, title):
    plt.figure(1)
    plt.scatter(source_points[:, 1], source_points[:, 0], c='r', s=10)
    plt.scatter(target_points[:, 1], target_points[:, 0], c='b', s=10)
    plt.legend(["HE", "MxIF"])
    plt.title(title)
    plt.axis('equal')
    plt.show()

In [None]:
def STalign_registration(HE_quant_fn, MxIF_quant_fn, HE_img_fn):
    HE_centroids = get_cell_loc(HE_quant_fn)   # you may need to specify column names to get the cell coordinates
    MxIF_centroids = get_cell_loc(MxIF_quant_fn)

    xM = HE_centroids[:, 0]
    yM = HE_centroids[:, 1]
    xN = MxIF_centroids[:, 0]
    yN = MxIF_centroids[:, 1]

    he_img = tf.TiffFile(HE_img_fn).pages[0].asarray().astype(np.float)  # read the images

    Inorm = STalign.normalize(he_img)
    I = Inorm.transpose(2,0,1)
    YI = np.array(range(I.shape[1]))*1. # needs to be longs not doubles for STalign.transform later so multiply by 1.
    XI = np.array(range(I.shape[2]))*1. # needs to be longs not doubles for STalign.transform later so multiply by 1.

    XJ,YJ,M = STalign.rasterize(xM, yM, dx=30, draw=0)
    J = np.vstack((M, M, M)) # make into 3xNxM
    # normalize
    J = STalign.normalize(J)

    # Randomly select 4 points from HE_centroids and the corresponding points from MxIF_centroids
    indices = np.random.choice(len(xM), 4, replace=False)
    selected_HE_points = HE_centroids[indices]
    selected_MxIF_points = MxIF_centroids[indices]

    temp = selected_HE_points[:,0].copy()
    selected_HE_points[:,0]  = selected_HE_points[:,1]
    selected_HE_points[:,1]  =  temp

    temp = selected_MxIF_points[:,0].copy()
    selected_MxIF_points[:,0]  = selected_MxIF_points[:,1]
    selected_MxIF_points[:,1]  =  temp

    pointsI = selected_HE_points
    pointsJ = selected_MxIF_points

    L,T = STalign.L_T_from_points(pointsI,pointsJ)

    if torch.cuda.is_available():
        torch.set_default_device('cuda:0')
    else:
        torch.set_default_device('cpu')

    #  run LDDMM
    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'

    # keep all other parameters default
    params = {'L':L,'T':T,
            'niter':2000,
            'pointsI':pointsI,
            'pointsJ':pointsJ,
            'device':device,
            'sigmaM':0.15, 
            'sigmaB':0.10,
            'sigmaA':0.11,
            'epV': 10,
            'muB': torch.tensor([0,0,0]), # black is background in target
            'muA': torch.tensor([1,1,1]) # use white as artifact 
            }

    out = STalign.LDDMM([YI,XI],I,[YJ,XJ],J,**params)
    A = out['A']
    v = out['v']
    xv = out['xv']

    return A, v, xv

In [None]:
#  Quantitatively evaluate the alignment results
# 1. calculate delta theta and delta translation based on the homography matrix.
# 2. Apply the homography matrix to the manually labeled landmarks of image 1, and compare with the landmarks of image 2, calculate delta distance.

# add setting.json to the .vscode folder
# '''
# {
#     "python.analysis.extraPaths": [
#         "./MultimodalityHistoComb/release/eval"
#     ]
# }
# '''
from eval_utils import *

DEBUG = False

Sec = 2

HE_pixel_size = 0.2201  # unit micron
MxIF_pixel_size = 0.325

core_list_fn = "MultimodalityHistoComb/release/core_list.txt"
core_list = open(core_list_fn, 'r').readlines()
roi_id_list = []
for core in core_list:
    roi_id = core.split(".")[0]
    roi_id_list.append(roi_id)


img_data_dir = "/Users/jjiang10/Data/OV_TMA"  #TODO: change to your own data path
output_dir = "/Users/jjiang10/Data/OV_TMA/align_eval"  #TODO: change to your own data path

data_root_dir = "/temp/Ovarian_TMA"  #TODO: change to your own data path
eval_data_root = os.path.join(data_root_dir, "AlignmentEval", "GroundTruthEvaluation")
anno_data_root = os.path.join(data_root_dir, "AlignmentEval", "GroundTruthAnnotation")
# /temp/Ovarian_TMA/AlignmentEval/QuPathAnnoProj_MxIF/export/
HE_cell_quant_fn = os.path.join(data_root_dir, "AlignmentEval", "QuPathAnnoProj_HE", "export")
MxIF_cell_quant_fn = os.path.join(data_root_dir, "AlignmentEval", "QuPathAnnoProj_MxIF", "export")
if Sec == 1:
    # Sec1
    ground_truth_output_dir = os.path.join(eval_data_root, "GT_HE_Sec1_MxIF")
    HE_export_dir = os.path.join(anno_data_root, "HE_Sec1")
elif Sec == 2:
    # Sec2
    ground_truth_output_dir = os.path.join(eval_data_root, "GT_HE_Sec2_MxIF")
    HE_export_dir = os.path.join(anno_data_root, "HE_Sec2")
else:
    raise Exception("Undefined Section")

MxIF_export_dir = os.path.join(anno_data_root, "MxIF_Sec1")


if not os.path.exists(output_dir):
    os.makedirs(output_dir)
output_csv_fn = os.path.join(output_dir, "STalign_alignment_eval_Sec%d.csv" % Sec)
df = pd.DataFrame(columns=["ROI", "GT_dist", "Test_dist", "GT_angle", "Test_angle", "GT_landmark", "Test_landmark"])

for roi in roi_id_list:
    print("Processing %s" % roi)
    HE_landmarks_fn = os.path.join(HE_export_dir, roi+"_align_anno_8.csv")
    MxIF_landmarks_fn = os.path.join(MxIF_export_dir, roi+"_align_anno_8.csv")
    point_sort_fn = os.path.join(data_root_dir, "GroundTruthAnnotation", "Landmark_correspondence", "Sec%d" %Sec, roi+".csv")

    gt_trans_fn = os.path.join(eval_data_root, "GT_HE_Sec%d_MxIF" % Sec, roi+"_gt_trans.npy")

    HE_img_path = os.path.join(img_data_dir, f'HE_{roi}.tif')
    MxIF_img_path = os.path.join(img_data_dir, f'MxIF_{roi}.ome.tif')

    # /temp/Ovarian_TMA/AlignmentEval/QuPathAnnoProj_MxIF/export/A-8_StarDist_QUANT.tsv
    HE_quant_fn = os.path.join(HE_cell_quant_fn, roi+"_StarDist_QUANT.tsv")
    MxIF_quant_fn = os.path.join(MxIF_cell_quant_fn, roi+"_StarDist_QUANT.tsv")


    aff_matrix, _, _ = STalign_registration(HE_quant_fn, MxIF_quant_fn, HE_img_path)

    # HE_img, MxIF_img = load_images(HE_img_path, MxIF_img_path)
    # aligned_img, good_matches, keypoints1, keypoints2, aff_matrix = sift_alignment(HE_img, MxIF_img)

    sorted_HE_landmarks, sorted_MxIF_landmarks = get_sorted_annotation_landmark_pairs(HE_landmarks_fn, MxIF_landmarks_fn, point_sort_fn)
    sorted_HE_landmarks_um = sorted_HE_landmarks * HE_pixel_size
    sorted_MxIF_landmarks_um = sorted_MxIF_landmarks * MxIF_pixel_size

    gt_M = np.load(gt_trans_fn)

    # check distance between transformed source landmark points and target landmark points
    gt_trans_HE_landmarks_um = apply_aff_trans2points(sorted_HE_landmarks_um, gt_M)
    gt_landmark_dist = calculate_transformed_landmark_dist(gt_trans_HE_landmarks_um, sorted_MxIF_landmarks_um)

    trans_HE_landmarks_um = apply_aff_trans2points(sorted_HE_landmarks, aff_matrix)*MxIF_pixel_size
    test_landmark_dist = calculate_transformed_landmark_dist(trans_HE_landmarks_um, sorted_MxIF_landmarks_um)

    gt_rotation_angle, gt_translation_distance = get_rotation_translation(gt_M)
    test_rotation_angle, test_translation_distance = get_rotation_translation(aff_matrix)
    if DEBUG:
        draw_points(gt_trans_HE_landmarks_um, sorted_MxIF_landmarks_um, "Landmarks with ground truth transformation")
        draw_points(trans_HE_landmarks_um, sorted_MxIF_landmarks_um, "Landmarks with test transformation")

        print("\t Distance differences: %f\t %f" % (gt_landmark_dist, test_landmark_dist))
    data_list = [roi, gt_translation_distance, test_translation_distance, gt_rotation_angle, test_rotation_angle, gt_landmark_dist, test_landmark_dist] 

    df = pd.concat([pd.DataFrame([data_list], columns=df.columns), df], ignore_index=True)

df.to_csv(output_csv_fn)

# draw plots
vals = list(abs(df['Test_landmark']-df['GT_landmark']))
avg = sum(vals)/len(vals)
plt.figure(figsize=(4,3), dpi=300)
plt.xlabel(r'$\mu$' + "m")
plt.ylabel("percentage")
sns.histplot(vals, stat='percent', bins=5, shrink=0.5, color='red', edgecolor="black")
plt.axvline(avg, 0, 0.8, color='brown', linestyle='--')
plt.title(r'$\Delta$' +" Distance")
plt.tight_layout()
plt.show()


vals = list(abs(df['Test_angle']-df['GT_angle']))
avg = sum(vals)/len(vals)
plt.figure(figsize=(4,3), dpi=300)
plt.xlabel("Degree" + '$^\circ$')
plt.ylabel("percentage")
sns.histplot(vals, stat='percent', bins=5, shrink=0.5, color='green', edgecolor="black")
plt.axvline(avg, 0, 0.8, color='brown', linestyle='--')
plt.title(r'$\Delta$' + " Rotation")
plt.tight_layout()
plt.show()

vals = list(abs(df['Test_dist']-df['GT_dist']))
avg = sum(vals)/len(vals)
plt.figure(figsize=(4,3), dpi=300)
plt.xlabel(r'$\mu$' + "m")
plt.ylabel("percentage")
sns.histplot(vals, stat='percent', bins=5, shrink=0.5, color='blue', edgecolor="black")
plt.axvline(avg, 0, 0.8, color='brown', linestyle='--')
plt.title(r'$\Delta$' + " Translation")
plt.tight_layout()
plt.show()

print("Done")

ModuleNotFoundError: No module named 'eval_utils'