# Stitch
---

#### Overview
Interactive stitching of one or several sections.

In [1]:
# indirectly enable autocomplete
# (turns off IPython's jedi completion backend; presumably a workaround for
#  jedi-based completion not working in this environment -- TODO confirm)
%config Completer.use_jedi = False

# autoreload modules
# (re-import edited project modules such as `interactive_render` before each
#  cell runs, so library changes are picked up without restarting the kernel)
%load_ext autoreload
%autoreload 2

import logging
# Root logger at INFO so the per-tile-pair `logging.info` messages in the
# matching loop below are displayed.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [2]:
import pathlib
import requests
import logging

from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import renderapi

from scripted_render_pipeline import basic_auth
from interactive_render import plotting

#### `render-ws` environment variables

In [3]:
# Alternative configuration for the remote render-ws server (kept for
# reference; it authenticates via `auth` instead of a local session):
# params_render = {
#     "host": "https://sonic.tnw.tudelft.nl",
#     "owner": "fastem",
#     "project": "20230712_RF_zstack",
#     "auth": basic_auth.load_auth()
# }

# HTTP session carrying the credentials; shared by every render-ws request
auth = basic_auth.load_auth()
sesh = requests.Session()
sesh.auth = auth

# Connection settings for the local render-ws instance. This dict is unpacked
# as keyword arguments (`**params_render`) into the `renderapi` calls below.
params_render = dict(
    host="http://localhost",
    port=8081,
    client_scripts="/home/catmaid/render/render-ws-java-client/src/main/scripts",
    client_script="/home/catmaid/render/render-ws-java-client/src/main/scripts/run_ws_client.sh",
    owner="fastem",
    project="20230712_RF_zstack",
    session=sesh,
)

# Root directory of the acquisition; per-section outputs are written below it
dir_project = pathlib.Path("/long_term_storage/asm_storage/asm_service/2023-07-12/20230712_RF_zstack/")

## Stitching (part I)
---
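Find the pairs of overlapping tiles within each section, then compute point matches between the tiles of each pair and upload them to render-ws.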

### Get tile pairs

Use render-ws `tilePairClient` to get the set of tile pairs.

In [4]:
# Stack whose tiles are paired and matched
stack = "corrected"

# All z values present in the stack, as plain ints
z_values = list(map(int, renderapi.stack.get_z_values_for_stack(stack, **params_render)))

# Ask render-ws for neighbouring tile pairs across the stack's full z range.
# A z search radius of 0 (half-height of the search cylinder) together with
# keeping same-layer neighbours yields only within-section pairs.
tilepairs = renderapi.client.tilePairClient(
    stack=stack,
    minz=min(z_values),
    maxz=max(z_values),
    zNeighborDistance=0,
    excludeSameLayerNeighbors=False,
    subprocess_mode="check_output",  # capture rather than print client output
    **params_render
)["neighborPairs"]

# Summarise and preview the first few pairs
header = f"Number of tile pairs... {len(tilepairs)}"
print(header, "\n" + "-" * len(header))
tilepairs[:5]

Number of tile pairs... 204 
---------------------------


[{'p': {'groupId': 'RF_brain_section_10_10us',
   'id': 't0_z8_y2_x2',
   'relativePosition': 'BOTTOM'},
  'q': {'groupId': 'RF_brain_section_10_10us',
   'id': 't3_z8_y1_x2',
   'relativePosition': 'TOP'}},
 {'p': {'groupId': 'RF_brain_section_10_10us',
   'id': 't0_z8_y2_x2',
   'relativePosition': 'RIGHT'},
  'q': {'groupId': 'RF_brain_section_10_10us',
   'id': 't1_z8_y2_x1',
   'relativePosition': 'LEFT'}},
 {'p': {'groupId': 'RF_brain_section_10_10us',
   'id': 't1_z8_y2_x1',
   'relativePosition': 'BOTTOM'},
  'q': {'groupId': 'RF_brain_section_10_10us',
   'id': 't4_z8_y1_x1',
   'relativePosition': 'TOP'}},
 {'p': {'groupId': 'RF_brain_section_10_10us',
   'id': 't1_z8_y2_x1',
   'relativePosition': 'RIGHT'},
  'q': {'groupId': 'RF_brain_section_10_10us',
   'id': 't2_z8_y2_x0',
   'relativePosition': 'LEFT'}},
 {'p': {'groupId': 'RF_brain_section_10_10us',
   'id': 't2_z8_y2_x0',
   'relativePosition': 'BOTTOM'},
  'q': {'groupId': 'RF_brain_section_10_10us',
   'id': 't5_z8_

### Get point matches

In [5]:
from skimage.measure import ransac
from skimage.transform import EuclideanTransform

from interactive_render.prematching import (
    get_bbox_from_relative_position,
    get_image_pair_for_matching,
)
from interactive_render.features import (
    find_feature_correspondences,
    find_robust_feature_correspondences
)

In [6]:
# Point-match collection name, built as <project>_<stack>_watches
match_collection = "_".join([params_render['project'], stack, "watches"])
match_collection

'20230712_RF_zstack_corrected_watches'

#### Set SIFT + RANSAC parameters

In [7]:
# Feature-detector settings (passed as `params_features` with
# feature_detector="SIFT"); presumably forwarded to skimage's SIFT -- confirm.
params_SIFT = dict(
    upsampling=1,  # no upsampling
    n_octaves=3,
    sigma_min=3,
)

# Descriptor-matching settings (passed as `params_match`); the keys mirror
# skimage.feature.match_descriptors -- NOTE(review): confirm forwarding.
params_MATCH = dict(
    metric=None,
    cross_check=True,
    max_ratio=0.8,
)

# Outlier-rejection settings (passed as `params_RANSAC`); the keys mirror
# skimage.measure.ransac, fitting a rigid (Euclidean) model per tile pair.
params_RANSAC = dict(
    model_class=EuclideanTransform,
    min_samples=12,
    residual_threshold=4,
    max_trials=10000,
)

#### Execute SIFT + RANSAC

In [9]:
def _shift_to_tile_coords(points, tile_spec, bbox, relative_position):
    """Shift overlap-crop keypoint coordinates, in place, into tile coordinates.

    Matching runs on a crop (``bbox``) of each tile, so keypoints are relative
    to the crop origin. For a tile whose relative position is LEFT (TOP) the
    crop is offset along x (y) by the tile size minus the crop size; other
    positions need no shift.

    NOTE(review): assumes ``bbox`` is (xmin, ymin, xmax, ymax) and that the
    crop is flush with the tile's far edge -- confirm against
    ``get_bbox_from_relative_position``.

    Parameters
    ----------
    points : 2-D array whose columns 0 and 1 are x and y; modified in place
    tile_spec : renderapi tile spec providing ``width`` and ``height``
    bbox : 4-element sequence describing the crop
    relative_position : render-ws ``relativePosition`` string (e.g. "LEFT")

    Returns
    -------
    The same ``points`` array, for convenience.
    """
    position = relative_position.lower()
    if position == "left":
        points[:, 0] += int(tile_spec.width - (bbox[2] - bbox[0]))
    elif position == "top":
        points[:, 1] += int(tile_spec.height - (bbox[3] - bbox[1]))
    return points


# sectionId -> z, so each section's z is requested from render-ws only once
section_z = {}

# initialize collection of point matches
matches = []

# loop through tile pairs
for tp in tqdm(tilepairs):
    p, q = tp["p"], tp["q"]

    # get z value from groupId (aka sectionId); only used for logging
    sectionId = p["groupId"]
    if sectionId not in section_z:
        section_z[sectionId] = renderapi.stack.get_section_z_value(
            stack,
            sectionId,
            **params_render
        )
    z = section_z[sectionId]

    # fetch each tile spec once: needed for the crop bbox and the tile size
    ts_p = renderapi.tilespec.get_tile_spec(stack, p["id"], **params_render)
    ts_q = renderapi.tilespec.get_tile_spec(stack, q["id"], **params_render)
    bbox_p = get_bbox_from_relative_position(ts_p, p["relativePosition"])
    bbox_q = get_bbox_from_relative_position(ts_q, q["relativePosition"])

    # render image pair from the same stack the tile pairs came from
    image_p, image_q = get_image_pair_for_matching(
        stack,
        tp,
        **params_render
    )

    # get point matches
    inliers_p, inliers_q = find_robust_feature_correspondences(
        image_p,
        image_q,
        feature_detector="SIFT",
        params_features=params_SIFT,
        params_match=params_MATCH,
        params_RANSAC=params_RANSAC
    )
    logging.info(f"Found {len(inliers_p)} matched features between "
                 f"tiles `{p['id']}` and `{q['id']}` in "
                 f"section `{sectionId}` (z={z:.0f}).")

    # convert crop-relative coordinates to full-tile coordinates, using the
    # tile spec dimensions rather than a hard-coded tile size
    _shift_to_tile_coords(inliers_p, ts_p, bbox_p, p["relativePosition"])
    _shift_to_tile_coords(inliers_q, ts_q, bbox_q, q["relativePosition"])

    # format matches for uploading to render-ws point match database
    # (coordinates are transposed to [[x...], [y...]])
    d = {
        "pGroupId": p["groupId"],  # sectionId for image P
        "qGroupId": q["groupId"],  # sectionId for image Q
        "pId": p["id"],  # tileId for image P
        "qId": q["id"],  # tileId for image Q
        "matches": {
            "p": inliers_p.T.tolist(),
            "q": inliers_q.T.tolist(),
            "w": np.ones(len(inliers_p)).tolist()
        }
    }
    matches.append(d)

### Upload point matches

In [11]:
# import pointmatches
renderapi.pointmatch.import_matches(
    match_collection,
    matches,
    **params_render
)

#### Inspect

In [12]:
# Interactive per-z view of `stack` with the uploaded stitching matches
# overlaid (width=256 is presumably the render width in pixels -- confirm).
plotting.plot_stack_with_stitching_matches(
    stack, match_collection, params_render, width=256
)

interactive(children=(IntSlider(value=0, description='z', max=16), Output()), _dom_classes=('widget-interact',…

## Stitching (part II)
---

### Create alignment files

In [14]:
from pathlib import Path
import os
import subprocess
import json
from pprint import pprint

In [21]:
# Load align.json template (BigFeta montage settings)
template_align_json = Path('../templates/montage.json')
with template_align_json.open('r') as json_data:
    params_align = json.load(json_data)

# --- Settings shared by every section (set once, outside the loop) ---

# BigFeta solver schema
params_align['transformation'] = 'TranslationModel'

# Input stack data
params_align['input_stack']['host'] = params_render['host']
params_align['input_stack']['owner'] = params_render['owner']
params_align['input_stack']['project'] = params_render['project']
params_align['input_stack']['name'] = stack

# Point match collection data
params_align['pointmatch']['host'] = params_render['host']
params_align['pointmatch']['owner'] = params_render['owner']
params_align['pointmatch']['name'] = match_collection

# Output stack data
params_align['output_stack']['host'] = params_render['host']
params_align['output_stack']['owner'] = params_render['owner']
params_align['output_stack']['project'] = params_render['project']
params_align['output_stack']['name'] = f"{stack}_stitched"

# Keep the port consistent with the host taken from `params_render`;
# configurations without a port leave the template value untouched.
render_port = params_render.get('port')
if render_port is not None:
    for db_section in ('input_stack', 'pointmatch', 'output_stack'):
        params_align[db_section]['port'] = render_port

# Regularization parameters
params_align['regularization']['default_lambda'] = 0.005      # default: 0.005
params_align['regularization']['translation_factor'] = 0.005  # default: 0.005

# --- Per-section settings: one montage.json per z value ---
for z in tqdm(z_values):

    # Solve each section on its own
    params_align['first_section'] = z
    params_align['last_section'] = z

    # Export montage settings to <dir_project>/<sectionId>/<stack>/montage.json,
    # creating the directory if it does not exist yet
    sectionId = renderapi.stack.get_sectionId_for_z(stack, z, **params_render)
    align_json = dir_project / f"{sectionId}/{stack}/montage.json"
    align_json.parent.mkdir(parents=True, exist_ok=True)
    with align_json.open('w') as json_data:
        json.dump(params_align, json_data, indent=2)

# Check alignment parameters (of the last section written)
print(align_json)
print('-'*len(align_json.as_posix()))
pprint(params_align)

  0%|          | 0/17 [00:00<?, ?it/s]

/long_term_storage/asm_storage/asm_service/2023-07-12/20230712_RF_zstack/RF_brain_section_18_10us/corrected/montage.json
------------------------------------------------------------------------------------------------------------------------
{'close_stack': 'True',
 'first_section': 16,
 'hdf5_options': {'chunks_per_file': -1, 'output_dir': ''},
 'input_stack': {'client_scripts': '/home/catmaid/render/render-ws-java-client/src/main/scripts',
                 'collection_type': 'stack',
                 'db_interface': 'render',
                 'host': 'http://localhost',
                 'name': 'corrected',
                 'owner': 'fastem',
                 'port': 8081,
                 'project': '20230712_RF_zstack'},
 'last_section': 16,
 'log_level': 'INFO',
 'matrix_assembly': {'cross_pt_weight': 1.0,
                     'depth': 2,
                     'inverse_dz': 'True',
                     'montage_pt_weight': 1.0,
                     'npts_max': 500,
                

### Compute optimal transformations with `BigFeta`

In [22]:
# Path to `BigFeta`
cwd = Path.cwd().as_posix()
dir_BigFeta = Path('/home/rlane/BigFeta/')

# Loop through z values
for z in z_values:

    # Select montage json
    sectionId = renderapi.stack.get_sectionId_for_z(stack, z, **params_render)
    align_json = dir_project / f"{sectionId}/{stack}/montage.json"

    # Call `BigFeta.BigFeta` process -- have to switch to BigFeta directory
    os.chdir(dir_BigFeta.as_posix())
    subprocess.run(['python', '-m', 'bigfeta.bigfeta', '--input_json', align_json.as_posix()])
    os.chdir(cwd)

INFO:bigfeta.utils:
 loaded 9 tile specs from 1 zvalues in 0.0 sec using interface: render
INFO:__main__: A created in 0.1 seconds
INFO:__main__:
 solved in 0.0 sec
 precision [norm(Kx-Lm)/norm(Lm)] = 1.8e-16, 2.9e-16
 error     [norm(Ax-b)] = 213.728, 96.686
 [mean(Ax) +/- std(Ax)] : 0.0 +/- 4.4, 0.0 +/- 2.0
 [mean(error mag) +/- std(error mag)] : 3.9 +/- 2.9
INFO:__main__:
 scales: 1.00 +/- 0.00, 1.00 +/- 0.00
INFO:bigfeta.utils:
ingesting results to http://localhost:8081 fastem__20230712_RF_zstack__corrected_stitched2
INFO:bigfeta.utils:render output is going to /dev/null
INFO:__main__: total time: 1.4
INFO:bigfeta.utils:
 loaded 9 tile specs from 1 zvalues in 0.0 sec using interface: render
INFO:__main__: A created in 0.1 seconds
INFO:__main__:
 solved in 0.0 sec
 precision [norm(Kx-Lm)/norm(Lm)] = 1.4e-16, 2.9e-16
 error     [norm(Ax-b)] = 180.793, 83.035
 [mean(Ax) +/- std(Ax)] : 0.0 +/- 4.2, 0.0 +/- 2.0
 [mean(error mag) +/- std(error mag)] : 3.7 +/- 2.9
INFO:__main__:
 scales: 

### Inspect

In [25]:
# Plot stacks
stacks = ["corrected", "corrected_stitched"]
plotting.plot_stacks(
    stacks,
    params_render
)

  0%|          | 0/2 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=16), IntSlider(value=30201, description='vmin', …