# Stitch
---

#### Overview
Interactive stitching of one or several sections.

In [3]:
# Disable the jedi-based completer so IPython uses its non-jedi completer
# instead (a common workaround that re-enables tab autocomplete when jedi
# completion does not work in this environment)
%config Completer.use_jedi = False

# Automatically reload edited modules (e.g. `interactive_render`) before
# each cell is executed
%load_ext autoreload
%autoreload 2

In [4]:
import pathlib
import requests

from tqdm.notebook import tqdm
import renderapi

from scripted_render_pipeline import basic_auth
from interactive_render import plotting

#### `render-ws` environment variables

In [5]:
# Authorized HTTP session, reused by every render-ws request below
auth = basic_auth.load_auth()
sesh = requests.Session()
sesh.auth = auth

project = "20231107_MCF7_UAC_test"

# render-ws connection parameters, passed as **kwargs to renderapi calls
params_render = dict(
    host="http://localhost",
    port=8081,
    client_scripts="/home/catmaid/render/render-ws-java-client/src/main/scripts",
    client_script="/home/catmaid/render/render-ws-java-client/src/main/scripts/run_ws_client.sh",
    owner="akievits",
    project=project,
    session=sesh,
)

# Identical connection parameters plus a memory allocation ('memGB') for the client
params_align = {**params_render, "memGB": '2G'}

# Project directory (montage json files are written below it)
dir_project = pathlib.Path("/long_term_storage/akievits/FAST-EM/tests/20231107_MCF7_UAC_test/")

# Upper bound on worker threads for multithreaded steps
max_workers = 10

## Stitching (part I)
---
Get tile pairs, then compute point matches between the tiles of each pair.

### Get tile pairs

Use render-ws `tilePairClient` to get the set of tile pairs.

In [6]:
# Choose the stack from which to derive tile pairs
stack = "postcorrection"
z_values = [int(z) for z in renderapi.stack.get_z_values_for_stack(
    stack,
    **params_render
)]
if not z_values:
    # Fail with a clear message instead of an opaque `min()` error below
    raise ValueError(f"Stack '{stack}' has no z-values; nothing to stitch.")

# Get tile pairs for every section of `stack`. A search-cylinder
# half-height of 0 keeps the search within a single section, and
# excludeSameLayerNeighbors=False keeps those same-section neighbours.
tilepairs = renderapi.client.tilePairClient(
    stack=stack,
    minz=min(z_values),
    maxz=max(z_values),
    zNeighborDistance=0,  # half-height of search cylinder
    excludeSameLayerNeighbors=False,
    subprocess_mode="check_output",  # suppresses output
    **params_render
)["neighborPairs"]

# Reformat as (pId, qId) tuples for PointMatchClient
tile_pairs_reformat = [(tp["p"]["id"], tp["q"]["id"]) for tp in tilepairs]

# Relative position of the first tile of every pair; the point-match client
# uses it to decide how to clip the rendered images
relativePositions = [tp["p"]["relativePosition"] for tp in tilepairs]

# Show tile pairs (header line, underline, first few pairs)
out = f"Number of tile pairs... {len(tile_pairs_reformat)}"
print(out, "-" * len(out), sep="\n")
tile_pairs_reformat[:5]

Number of tile pairs... 120 
---------------------------


[('t00_z0_y4_x4', 't05_z0_y3_x4'),
 ('t00_z0_y4_x4', 't01_z0_y4_x3'),
 ('t01_z0_y4_x3', 't06_z0_y3_x3'),
 ('t01_z0_y4_x3', 't02_z0_y4_x2'),
 ('t02_z0_y4_x2', 't07_z0_y3_x2')]

### Get point matches

Use the `render-ws` `PointMatchClient` script to find matching features between neighboring megafields.

In [21]:
from renderapi.client import (
    SiftPointMatchOptions,
    MatchDerivationParameters,
    FeatureExtractionParameters
)

In [34]:
# Point-match collection name: <project>_<stack>_stitch_matches
match_collection = "_".join([params_render['project'], stack, "stitch_matches"])
match_collection

'20231107_MCF7_UAC_test_postcorrection_stitch_matches'

In [35]:
# Optional reset: uncomment to delete the existing point-match collection of
# this name (e.g. before recomputing all matches from scratch)
# renderapi.pointmatch.delete_collection(match_collection,
#                                        **params_render)

#### Set SIFT + RANSAC parameters

In [36]:
# `RANSAC` parameters (match derivation / model filtering)
params_RANSAC = MatchDerivationParameters(
    matchIterations=None,
    matchMaxEpsilon=25,        # maximal alignment error
    matchMaxNumInliers=None,
    matchMaxTrust=None,
    matchMinInlierRatio=0.05,  # minimal inlier ratio
    matchMinNumInliers=7,      # minimal number of inliers
    matchModelType='AFFINE',   # expected transformation
    matchRod=0.92              # closest/next closest ratio
)

# `SIFT` feature-extraction parameters. Kept under their own name so that
# `params_SIFT` only ever refers to the combined SiftPointMatchOptions object
# that is handed to pointMatchClient.
params_features = FeatureExtractionParameters(
    SIFTfdSize=8,              # feature descriptor size
    SIFTmaxScale=0.20,         # (width/height *) maximum image size
    SIFTminScale=0.05,         # (width/height *) minimum image size
    SIFTsteps=5                # steps per scale octave
)

# Combined `SIFT` & `RANSAC` parameters
params_SIFT = SiftPointMatchOptions(
    fillWithNoise=True,
    **{**params_RANSAC.__dict__,
       **params_features.__dict__}
)

# Extra parameters
params_SIFT.numberOfThreads = 1  # multithreading
params_SIFT.clipWidth = 640      # N pixels included in rendered clips of LEFT/RIGHT oriented montage tiles
params_SIFT.clipHeight = 640     # N pixels included in rendered clips of TOP/BOTTOM oriented montage tiles
params_SIFT.__dict__

{'SIFTfdSize': 8,
 'SIFTmaxScale': 0.2,
 'SIFTminScale': 0.05,
 'SIFTsteps': 5,
 'matchIterations': None,
 'matchMaxEpsilon': 25,
 'matchMaxNumInliers': None,
 'matchMaxTrust': None,
 'matchMinInlierRatio': 0.05,
 'matchMinNumInliers': 7,
 'matchModelType': 'AFFINE',
 'matchRod': 0.92,
 'renderScale': None,
 'fillWithNoise': True,
 'numberOfThreads': 1,
 'clipWidth': 640,
 'clipHeight': 640}

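#### Execute SIFT + RANSAC
Tile pairs are processed sequentially, with one `PointMatchClient` call per tile pair. A parallel version (`N` threads, one per tile pair) is kept below but disabled, because it did not save all matches to the match collection.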

In [25]:
# Only needed by the (currently disabled) parallel point-matching cell
# import concurrent.futures

In [29]:
## BROKEN
## Breaks pointMatchClient in such a way that not all matches are saved in the match collection. FIX
##
## NOTE(review): a likely cause (not verified): every submitted job receives the
## SAME `params_SIFT` object, and its `firstCanvasPosition` is overwritten on each
## loop iteration, possibly before earlier jobs have started. A job can therefore
## run with another pair's relative position (i.e. clip the wrong image region).
## Submitting a per-pair copy (e.g. `copy.copy(params_SIFT)`) would avoid this.
## Also note this version passes `params_align` (memGB='2G'), whereas the
## sequential loop that is actually used passes `params_render`.

# futures = set()
# all_sections = {}
# executor = concurrent.futures.ThreadPoolExecutor(
# max_workers = max_workers
# )

# try:
#     for tile_pair, pos in tqdm(zip(tile_pairs_reformat, relativePositions),
#                                desc="submitting tile pairs",
#                                total=len(tile_pairs_reformat),
#                                unit="tilepairs",
#                                smoothing=0.3):
#         params_SIFT.firstCanvasPosition = pos
#         future = executor.submit(
#             renderapi.client.pointMatchClient,
#             stack=stack,
#             collection=match_collection,
#             tile_pairs=[tile_pair],
#             sift_options=params_SIFT,
#             excludeAllTransforms=True,
#             subprocess_mode='check_output',  # suppresses output
#             **params_align)
#         futures.add(future)

#     for future in tqdm(
#         concurrent.futures.as_completed(futures),
#         desc="extracting point matches",
#         total=len(futures),
#         unit="tilepairs",
#         smoothing=min(100 / len(futures), 0.3),
#     ):
#         futures.remove(future)
        
# finally:
#     for future in futures:
#         future.cancel()
#     executor.shutdown()

submitting tile pairs:   0%|          | 0/120 [00:00<?, ?tilepairs/s]

extracting point matches:   0%|          | 0/120 [00:00<?, ?tilepairs/s]

In [8]:
# Run SIFT + RANSAC (render-ws PointMatchClient) one tile pair at a time
pair_progress = tqdm(
    zip(tile_pairs_reformat, relativePositions),
    desc="extracting point matches",
    total=len(tile_pairs_reformat),
    unit="tilepairs",
    smoothing=0.3,
)
for tile_pair, first_tile_position in pair_progress:
    # The client script needs the first tile's relative position to know
    # how to clip the rendered images
    params_SIFT.firstCanvasPosition = first_tile_position
    renderapi.client.pointMatchClient(
        stack=stack,
        collection=match_collection,
        tile_pairs=[tile_pair],
        sift_options=params_SIFT,
        excludeAllTransforms=True,
        subprocess_mode='check_output',  # suppresses output
        **params_render
    )

extracting point matches:   0%|          | 0/120 [00:00<?, ?tilepairs/s]

NameError: name 'params_SIFT' is not defined

### Inspect

In [38]:
# Interactively browse `stack` together with the point matches stored in
# `match_collection` (rendered with width=256)
plotting.plot_stack_with_stitching_matches(stack, match_collection,
                                           params_render, width=256)

interactive(children=(IntSlider(value=0, description='z', max=2), Output()), _dom_classes=('widget-interact',)…

## Stitching (part II)
---

### Create alignment files

In [39]:
from pathlib import Path
import os
import subprocess
import json
from pprint import pprint

In [40]:
z_values = renderapi.stack.get_z_values_for_stack(stack,
                                                  **params_render)
if not z_values:
    # Fail with a clear message instead of an opaque `min()` error below
    raise ValueError(f"Stack '{stack}' has no z-values; nothing to stitch.")

# Load the montage (stitching) template; the path is relative to the
# notebook's current working directory
template_stitch_json = Path('../templates/montage.json')
with template_stitch_json.open('r') as json_data:
    params_stitch = json.load(json_data)

# Edit BigFeta solver schema: solve all sections of the stack with a
# translation-only model
params_stitch['first_section'] = min(z_values)
params_stitch['last_section'] = max(z_values)
params_stitch['transformation'] = 'TranslationModel'

# Point the input stack, point-match collection and output stack at the same
# render-ws server and owner as `params_render`. The port is copied as well,
# so the template's port cannot silently disagree with the configured server.
for db_key in ('input_stack', 'pointmatch', 'output_stack'):
    params_stitch[db_key]['host'] = params_render['host']
    params_stitch[db_key]['port'] = params_render['port']
    params_stitch[db_key]['owner'] = params_render['owner']

# Edit input stack data
params_stitch['input_stack']['project'] = params_render['project']
params_stitch['input_stack']['name'] = stack

# Edit point match collection data
params_stitch['pointmatch']['name'] = match_collection

# Edit output stack data
params_stitch['output_stack']['project'] = params_render['project']
params_stitch['output_stack']['name'] = f"{stack}_stitched"

# Edit regularization parameters
params_stitch['regularization']['default_lambda'] = 0.005      # default: 0.005
params_stitch['regularization']['translation_factor'] = 0.005  # default: 0.005

# Export montage settings to <dir_project>/_jsons_montage/<stack>/montage.json
stitch_json = dir_project / '_jsons_montage' / stack / 'montage.json'
stitch_json.parent.mkdir(parents=True, exist_ok=True)
with stitch_json.open('w') as json_data:
    json.dump(params_stitch, json_data, indent=2)

# Check alignment parameters
print(stitch_json)
print('-'*len(stitch_json.as_posix()))
pprint(params_stitch)

/long_term_storage/akievits/FAST-EM/tests/20231107_MCF7_UAC_test/_jsons_montage/postcorrection/montage.json
-----------------------------------------------------------------------------------------------------------
{'close_stack': 'True',
 'first_section': 0.0,
 'hdf5_options': {'chunks_per_file': -1, 'output_dir': ''},
 'input_stack': {'client_scripts': '/home/catmaid/render/render-ws-java-client/src/main/scripts',
                 'collection_type': 'stack',
                 'db_interface': 'render',
                 'host': 'http://localhost',
                 'name': 'postcorrection',
                 'owner': 'akievits',
                 'port': 8081,
                 'project': '20231107_MCF7_UAC_test'},
 'last_section': 2.0,
 'log_level': 'INFO',
 'matrix_assembly': {'cross_pt_weight': 1.0,
                     'depth': 2,
                     'inverse_dz': 'True',
                     'montage_pt_weight': 1.0,
                     'npts_max': 500,
                     'npts_mi

### Compute optimal transformations with `BigFeta`

In [41]:
# Path to `BigFeta`
cwd = Path.cwd().as_posix()
dir_BigFeta = Path('/home/catmaid/BigFeta/')

# Select json for stitching
stitch_json = dir_project / '_jsons_montage' / stack / 'montage.json'

# Call `BigFeta.BigFeta` process -- have to switch to BigFeta directory
os.chdir(dir_BigFeta.as_posix())
subprocess.run(['python', '-m', 'bigfeta.bigfeta', '--input_json', stitch_json.as_posix()])
os.chdir(cwd)

INFO:bigfeta.utils:
 loaded 25 tile specs from 1 zvalues in 0.1 sec using interface: render
INFO:__main__: A created in 0.2 seconds
INFO:__main__:
 solved in 0.0 sec
 precision [norm(Kx-Lm)/norm(Lm)] = 6.3e-16, 7.6e-16
 error     [norm(Ax-b)] = 367.359, 277.496
 [mean(Ax) +/- std(Ax)] : 0.0 +/- 5.4, 0.0 +/- 4.1
 [mean(error mag) +/- std(error mag)] : 6.0 +/- 3.1
INFO:__main__:
 scales: 1.00 +/- 0.00, 1.00 +/- 0.00
INFO:bigfeta.utils:
ingesting results to http://localhost:8081 akievits__20231107_MCF7_UAC_test__postcorrection_stitched
INFO:bigfeta.utils:render output is going to /dev/null
INFO:bigfeta.utils:
 loaded 25 tile specs from 1 zvalues in 0.0 sec using interface: render
INFO:__main__: A created in 0.1 seconds
INFO:__main__:
 solved in 0.0 sec
 precision [norm(Kx-Lm)/norm(Lm)] = 7.2e-16, 7.2e-16
 error     [norm(Ax-b)] = 340.132, 255.254
 [mean(Ax) +/- std(Ax)] : 0.0 +/- 5.5, 0.0 +/- 4.1
 [mean(error mag) +/- std(error mag)] : 6.1 +/- 3.2
INFO:__main__:
 scales: 1.00 +/- 0.00, 1.

### Inspect

In [42]:
# Plot stacks
stacks = ["postcorrection", "postcorrection_stitched"]
plotting.plot_stacks(
    stacks,
    **params_render
)

  0%|          | 0/2 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=2), IntSlider(value=28502, description='vmin', m…

### Filter tiles with no point matches (substrate/resin tiles)
---

In [None]:
stitched_stack = f"{stack}_stitched"
z_values = renderapi.stack.get_z_values_for_stack(stack=stitched_stack,
                                                  **params_render)

# Collect every tile that takes part in at least one point-match pair.
# BOTH sides of each pair are recorded: a tile that only ever appears as
# `qId` still has matches and must not be treated as an island tile.
match_tileIds = set()
for z in tqdm(z_values,
              desc="Finding tiles with matches",
              unit=" z-levels"):
    # SectionId to find matches for
    sectionId = renderapi.stack.get_sectionId_for_z(stack=stitched_stack,
                                                    z=z,
                                                    **params_render)
    matches = renderapi.pointmatch.get_matches_within_group(match_collection,
                                                            groupId=sectionId,
                                                            **params_render)
    for match in matches:
        match_tileIds.update((match['pId'], match['qId']))

# All tileIds of the stitched stack (fetched once)
all_tile_Ids = renderapi.stack.get_stack_tileIds(stack=stitched_stack,
                                                 **params_render)
# Tiles without any match are island tiles (e.g. substrate/resin);
# the set makes each membership test O(1)
island_tileIds = [tileId for tileId in all_tile_Ids if tileId not in match_tileIds]

In [None]:
# Clone stack to create filtered stack
renderapi.stack.clone_stack(f"{stack}_stitched",
                            f"{stack}_stitched_filtered",
                            **params_render)

# Filter stack
renderapi.stack.set_stack_state(stack=f"{stack}_stitched_filtered",
                                state='LOADING',
                                **params_render)
for tileId in island_tileIds:
    renderapi.stack.delete_tile(stack=f"{stack}_stitched_filtered",
                                tileId=tileId,
                                **params_render)                                
renderapi.stack.set_stack_state(stack=f"{stack}_stitched_filtered",
                                state='COMPLETE',
                                **params_render)

### Inspect filtered stack

In [None]:
# Plot stacks
stacks = [f"{stack}_stitched", f"{stack}_stitched_filtered"]
plotting.plot_stacks(
    stacks,
    **params_render
)