# Align
---

#### Overview
Interactive 3D alignment of serial sections.

In [1]:
import pathlib
import requests

from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import renderapi
import os

from scripted_render_pipeline import basic_auth
from scripted_render_pipeline.importer import uploader
from interactive_render import plotting
from interactive_render.utils import clear_image_cache

#### `render-ws` environment variables

In [2]:
# Credentials shared by the HTTP session (render client) and the uploader
auth = basic_auth.load_auth()
sesh = requests.Session()
sesh.auth = auth

# Owner / project / client-script location shared by both parameter sets below
render_owner = "akievits"
render_project = "20230914_RP_exocrine_partial_test"
render_scripts = "/home/catmaid/render/render-ws-java-client/src/main/scripts"

# render-ws environment variables (unpacked as **kwargs into renderapi calls)
params_render = dict(
    host="http://localhost",
    port=8081,
    client_scripts=render_scripts,
    client_script=f"{render_scripts}/run_ws_client.sh",
    owner=render_owner,
    project=render_project,
    session=sesh,
)

# settings for scripted_render_pipeline's Uploader
params_uploader = dict(
    host="https://sonic.tnw.tudelft.nl",
    owner=render_owner,
    project=render_project,
    auth=auth,
)

# project directory on long-term storage (one sub-directory per project)
dir_project = pathlib.Path("/long_term_storage/akievits/FAST-EM") / render_project

## 2) Rough alignment (I)
---
Perform a rough alignment of downsampled section images.

### Create downsampled montage stack

In [3]:
from interactive_render.utils import create_downsampled_stack

In [4]:
# Stacks for downsampling: read the stitched stack ('in'), write the montage stack ('out')
stack_2_downsample = {'in': 'postcorrection_stitched', 'out': 'postcorrection_dsmontages'}

In [5]:
# Create the downsampled stack from stack_2_downsample['in']
ds_stack = create_downsampled_stack(
    dir_project,
    stack_2_downsample,
    **params_render
)

# Upload it to render-ws (clobber=False: presumably keeps an existing stack
# untouched -- see scripted_render_pipeline's Uploader)
uppity = uploader.Uploader(**params_uploader, clobber=False)
uppity.upload_to_render(stacks=[ds_stack], z_resolution=100)

  0%|          | 0/3 [00:00<?, ?it/s]

uploading: 100%|██████████| 1/1 [00:00<00:00,  4.24stacks/s]


### Inspect downsampled montage stack

In [6]:
# Interactive viewer for the downsampled montage stack, spanning the full
# 16-bit intensity range (0..65535)
plotting.plot_stacks(
    [ds_stack.name], width=1000, vmin=0, vmax=65535, **params_render
)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=2), IntSlider(value=26, description='vmin', max=…

## 3) Rough alignment (II)
---
Get point matches for `dsmontage` stack and roughly align
### Align `dsmontage` stack

In [7]:
# Input montage stack ('in') and the stack name for its alignment ('out')
ds_stack_2_align = {'in': 'postcorrection_dsmontages', 'out': 'postcorrection_dsmontages_aligned'}

### Get point matches

Find matching features between neighboring z-levels with SIFT + RANSAC (scikit-image), using parameters modeled on the `render-ws` `PointMatchClient` defaults.
#### Collect tile pairs

In [8]:
# z-levels present in the downsampled montage stack
z_values = [
    int(z)
    for z in renderapi.stack.get_z_values_for_stack(ds_stack_2_align['in'], **params_render)
]

# Pair sections with their z-neighbours in the downsampled montage stack
tilepairs = renderapi.client.tilePairClient(
    stack=ds_stack_2_align['in'],
    minz=min(z_values),
    maxz=max(z_values),
    zNeighborDistance=1,  # half-height of search cylinder
    excludeSameLayerNeighbors=False,
    subprocess_mode="check_output",  # suppresses output
    **params_render,
)["neighborPairs"]

# Show tile pairs: a header line, an underline, then the first few pairs
out = f"Number of tile pairs... {len(tilepairs)}"
print(f"{out} \n{'-' * len(out)}")
tilepairs[:5]



Number of tile pairs... 2 
-------------------------


[{'p': {'groupId': 'S003', 'id': 't0_z0.0_y0_x0'},
  'q': {'groupId': 'S004', 'id': 't0_z1.0_y0_x0'}},
 {'p': {'groupId': 'S004', 'id': 't0_z1.0_y0_x0'},
  'q': {'groupId': 'S005', 'id': 't0_z2.0_y0_x0'}}]

In [None]:
# Optional reset: uncomment to delete the point-match collection before
# re-running the matching below. NOTE: `match_collection` is defined in the
# next cell, so that cell must have been run first.
# renderapi.pointmatch.delete_collection(match_collection,
#                                        **params_render)

In [10]:
# Name for pointmatch collection: "<project>_<input stack>_matches"
match_collection = "_".join([params_render['project'], ds_stack_2_align['in'], "matches"])
match_collection

'20230914_RP_exocrine_partial_test_postcorrection_dsmontages_matches'

In [None]:
from skimage.measure import ransac
from skimage.morphology import binary_dilation
from skimage.filters import threshold_li
from skimage.transform import AffineTransform

from interactive_render.prematching import (
    get_bbox_from_relative_position,
    get_image_pair_for_matching,
)
from interactive_render.features import (
    find_feature_correspondences,
    find_robust_feature_correspondences
)

#### Set SIFT + RANSAC parameters

In [11]:
# SIFT + RANSAC parameters, chosen to mirror the render-ws PointMatchClient
# defaults where an equivalent exists.
# NOTE: `AffineTransform` comes from the scikit-image import cell above; run
# that cell first, otherwise this cell raises NameError.
params_imScale = {
    "maxScale": 0.8, # Defined in PointMatchClient, used here to fetch bbox image at maxScale for SIFT extraction
    "minScale": 0.05 # Determined by n_octaves
}

params_SIFT = {
    "upsampling": 1,  # no upsampling
    "n_octaves": 3, # 3, Octaves are defined as (maxScale, maxScale/2, maxScale/4, ..., minScale) in PointMatchClient
    "n_scales": 5, # 5, SIFTsteps per scale octave
    "sigma_min": 1.6, # 3.2 * params_imScale["maxScale"], # Not defined in PointMatchClient, but we assume it is 3.2 standard for full scale image
    "n_bins": 8, # 8, No parameter in PointMatchClient
    "n_hist": 8 # 8, SIFTfdSize in PointMatchClient
}

params_MATCH = {
    "metric": None, # Not defined in PointMatchClient
    "cross_check": True, # Not defined in PointMatchClient
    "max_ratio": 0.92 # matchRod in PointMatchClient
}

params_RANSAC = {
    "model_class": AffineTransform, # matchModelType in PointMatchClient
    "min_samples": 12, # Minimal amount of data points to fit model to, related to matchMinNumInliers?
    "residual_threshold": 25, # matchMaxEpsilon in PointMatchClient
    "max_trials": 10000, # matchIterations in PointMatchClient
    # matchMaxNumInliers? skimage's ransac stops early once this many inliers
    # are found and compares the inlier count against it, so "never stop
    # early" must be np.inf (its default) -- None raises a TypeError.
    "stop_sample_num": np.inf
}

NameError: name 'AffineTransform' is not defined

### Run SIFT + RANSAC

In [None]:
# Find point matches between neighbouring sections with SIFT + RANSAC and
# collect them in render-ws point-match format (see the upload cell below).

# Set to False to run RANSAC on all SIFT candidates, i.e. without excluding
# candidates that land in low-intensity regions.
FILTER_LOW_INTENSITY = True

# initialize collection of point matches
matches = []

# stack bounds are the same for every tile pair, so fetch them once
bounds = renderapi.stack.get_stack_bounds(ds_stack_2_align['in'],
                                          **params_render)


def _render_section(z):
    """Render section `z` over the full stack bounds at `maxScale`.

    Returns the rendered image and a dilated boolean mask of pixels below the
    Li threshold, i.e. the low-intensity regions to exclude from matching.
    """
    image = renderapi.image.get_bb_image(
        stack=ds_stack_2_align['in'],
        z=z,
        x=bounds['minX'],
        y=bounds['minY'],
        width=(bounds['maxX'] - bounds['minX']),
        height=(bounds['maxY'] - bounds['minY']),
        scale=params_imScale["maxScale"],
        img_format='tiff16',
        **params_render
    )
    mask = binary_dilation(image < threshold_li(image))
    return image, mask


def _outside_mask(points, mask):
    """Return a boolean array that is True where a point does not land on `mask`.

    NOTE(review): assumes `points` are (x, y) pixel coordinates (column 0 is
    x), consistent with uploading them via `.T` as [[x...], [y...]] below --
    confirm against `find_feature_correspondences`.
    """
    pts = np.rint(np.asarray(points, dtype=float).reshape(-1, 2)).astype(int)
    # clip so points on the image border cannot index out of bounds
    cols = np.clip(pts[:, 0], 0, mask.shape[1] - 1)
    rows = np.clip(pts[:, 1], 0, mask.shape[0] - 1)
    return ~mask[rows, cols]


# The q section of one pair is usually the p section of the next one; keep the
# most recently rendered section so it is not rendered twice. Keyed by z, so
# pairs that are not consecutive are still rendered correctly.
cached_z, cached_section = None, None

# loop through tile pairs
for tp in tqdm(tilepairs):

    # get z values from groupIds (aka sectionIds)
    z_p = renderapi.stack.get_section_z_value(
        ds_stack_2_align['in'],
        tp["p"]["groupId"],
        **params_render)
    z_q = renderapi.stack.get_section_z_value(
        ds_stack_2_align['in'],
        tp["q"]["groupId"],
        **params_render)

    # render image pair together with their low-intensity masks
    if z_p == cached_z:
        bbox_p, mask_p = cached_section
    else:
        bbox_p, mask_p = _render_section(z_p)
    bbox_q, mask_q = _render_section(z_q)
    cached_z, cached_section = z_q, (bbox_q, mask_q)

    # get point match candidates
    matches_p, matches_q = find_feature_correspondences(
        bbox_p,
        bbox_q,
        feature_detector="SIFT",
        params_features=params_SIFT,
        params_match=params_MATCH
    )

    # drop candidates where either point lands in a low-intensity region
    if FILTER_LOW_INTENSITY:
        keep = _outside_mask(matches_p, mask_p) & _outside_mask(matches_q, mask_q)
        matches_p, matches_q = matches_p[keep], matches_q[keep]

    pair_label = f"{tp['p']['groupId']} -> {tp['q']['groupId']}"
    # ransac cannot fit a model from fewer samples than `min_samples`
    if len(matches_p) < params_RANSAC["min_samples"]:
        print(f"Skipping {pair_label}: only {len(matches_p)} candidate matches")
        continue

    # robustify the point match candidates
    model, inliers = ransac(
        (matches_p, matches_q),
        **params_RANSAC
    )
    # ransac signals failure with inliers=None; indexing with None would
    # silently add an axis instead of selecting points
    if inliers is None or not np.any(inliers):
        print(f"Skipping {pair_label}: RANSAC found no inliers")
        continue

    # Scale point matches from the maxScale render back to full resolution.
    # NOTE(review): coordinates are relative to the rendered bbox origin
    # (bounds['minX'], bounds['minY']); confirm this matches the coordinate
    # frame render-ws expects for pId/qId.
    inliers_p = matches_p[inliers] / params_imScale["maxScale"]
    inliers_q = matches_q[inliers] / params_imScale["maxScale"]

    # format matches for uploading to render-ws point match database
    d = {
        "pGroupId": tp["p"]["groupId"],  # sectionId for image P
        "qGroupId": tp["q"]["groupId"],  # sectionId for image Q
        "pId": tp["p"]["id"],  # tileId for image P
        "qId": tp["q"]["id"],  # tileId for image Q
        "matches": {
            "p": inliers_p.T.tolist(),
            "q": inliers_q.T.tolist(),
            "w": np.ones(len(inliers_p)).tolist()
        }
    }

    matches.append(d)

### Upload point matches

In [None]:
# Push the point matches collected above into the render-ws match collection
renderapi.pointmatch.import_matches(match_collection, matches, **params_render)

#### Inspect

In [None]:
# Visualize the point matches stored in `match_collection` on the input stack
plotting.plot_dsstack_with_alignment_matches(
    ds_stack_2_align['in'], match_collection, width=1000, **params_render
)

In [None]:
# Side-by-side view of the image pair rendered last by the matching loop.
# NOTE: `bbox_p` / `bbox_q` are left over from the loop's final iteration.
# A 1x2 layout needs a wide figure; (5, 10) squeezed both panels into a tall strip.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(bbox_p, cmap="Greys_r")
ax[0].set_title("p")
ax[1].imshow(bbox_q, cmap="Greys_r")
ax[1].set_title("q")
plt.show()