# Align
---

#### Overview
Interactive 3D alignment of serial sections.

In [4]:
import pathlib
import requests

from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import renderapi
import os

from scripted_render_pipeline import basic_auth
from scripted_render_pipeline.importer import uploader
from interactive_render import plotting
from interactive_render.utils import clear_image_cache

In [5]:
# Clear the interactive_render image cache (presumably so later plots fetch
# fresh images); the call returns an HTTP response object.
clear_image_cache()

<Response [200]>

#### `render-ws` environment variables

In [6]:
# create an authorized session
auth = basic_auth.load_auth()
sesh = requests.Session()
sesh.auth = auth

# Identifiers shared by every parameter set below. They are defined once so that
# params_render, params_uploader and dir_project cannot drift apart.
RENDER_OWNER = "akievits"
RENDER_PROJECT = "20230914_RP_exocrine_partial_test"
RENDER_CLIENT_SCRIPTS = "/home/catmaid/render/render-ws-java-client/src/main/scripts"

# render-ws environment variables (passed to renderapi calls as **params_render)
params_render = {
    "host": "http://localhost",
    "port": 8081,
    "client_scripts": RENDER_CLIENT_SCRIPTS,
    "client_script": f"{RENDER_CLIENT_SCRIPTS}/run_ws_client.sh",
    "owner": RENDER_OWNER,
    "project": RENDER_PROJECT,
    "session": sesh
}

# parameters for the scripted_render_pipeline uploader (same owner/project)
params_uploader = {
    "host": "https://sonic.tnw.tudelft.nl",
    "owner": RENDER_OWNER,
    "project": RENDER_PROJECT,
    "auth": auth
}

# set project directory (one sub-directory per render project)
dir_project = pathlib.Path("/long_term_storage/akievits/FAST-EM") / RENDER_PROJECT

## 2) Rough alignment (I)
---
Perform a rough alignment of downsampled section images.

### Create downsampled montage stack

In [70]:
from interactive_render.utils import create_downsampled_stack

In [71]:
# Downsampling maps the stitched full-resolution stack ('in') onto a new
# stack of downsampled montages ('out')
stack_2_downsample = dict(zip(
    ('in', 'out'),
    ('postcorrection_stitched', 'postcorrection_dsmontages'),
))

In [73]:
# Downsampling factor of the `dsmontages` stack relative to full resolution.
# The rough-alignment propagation divides the downsampled translations by it.
# NOTE(review): SCALE is not passed to `create_downsampled_stack`; confirm that
# function uses the same factor, otherwise the propagated shifts are wrong.
SCALE = 0.05

In [72]:
# Build the downsampled montage stack from the stitched stack.
# NOTE(review): SCALE is not passed here -- confirm the factor used by
# `create_downsampled_stack` matches SCALE.
ds_stack = create_downsampled_stack(
    dir_project,
    stack_2_downsample,
    **params_render,
)

# Push the new stack to render-ws through the scripted_render_pipeline uploader
# (clobber=False -- presumably an existing stack is not overwritten; confirm)
uppity = uploader.Uploader(clobber=False, **params_uploader)
uppity.upload_to_render(stacks=[ds_stack], z_resolution=100)

  0%|          | 0/3 [00:00<?, ?it/s]

uploading: 100%|██████████| 1/1 [00:00<00:00,  4.17stacks/s]


### Inspect downsampled montage stack

In [73]:
# plot stack
# Interactive viewer of the freshly uploaded downsampled montage stack.
plotting.plot_stacks(
    [ds_stack.name],
    width=1000,
    vmin=0,
    vmax=65535,  # 0..65535 is the full uint16 range -- presumably 16-bit images; confirm
    **params_render
)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=2), IntSlider(value=26, description='vmin', max=…

## 3) Rough alignment (II)
---
Get point matches for `dsmontage` stack and roughly align
### Align `dsmontage` stack


In [8]:
# Downsampled montages ('in') are rigidly aligned into a new stack ('out')
ds_stack_2_align = {'in': 'postcorrection_dsmontages'}
ds_stack_2_align['out'] = 'postcorrection_dsmontages_aligned'

### Get point matches

Use the `render-ws` `PointMatchClient` script to find matching features between neighboring z-levels.
#### Collect tile pairs

In [75]:
# z-levels of the (not yet aligned) downsampled montage stack
z_values = [int(z) for z in renderapi.stack.get_z_values_for_stack(
    ds_stack_2_align['in'],
    **params_render
)]
if not z_values:
    raise ValueError(f"stack {ds_stack_2_align['in']!r} has no z-levels")

# Pair each downsampled montage with its neighbours in z, taken from the
# downsampled montage stack itself (it has not been aligned yet)
tilepairs = renderapi.client.tilePairClient(
    stack=ds_stack_2_align['in'],
    minz=min(z_values),
    maxz=max(z_values),
    zNeighborDistance=1,  # half-height of search cylinder
    excludeSameLayerNeighbors=False,
    subprocess_mode="check_output",  # suppresses output
    **params_render
)["neighborPairs"]

# Show tile pairs
out = f"Number of tile pairs... {len(tilepairs)}"
print(out)
print("-" * len(out))
tilepairs[:5]



Number of tile pairs... 2 
-------------------------


[{'p': {'groupId': 'S003', 'id': 't0_z0.0_y0_x0'},
  'q': {'groupId': 'S004', 'id': 't0_z1.0_y0_x0'}},
 {'p': {'groupId': 'S004', 'id': 't0_z1.0_y0_x0'},
  'q': {'groupId': 'S005', 'id': 't0_z2.0_y0_x0'}}]

In [76]:
from renderapi.client import (
    SiftPointMatchOptions,
    MatchDerivationParameters,
    FeatureExtractionParameters
)

In [9]:
# Point-match collection name: <project>_<input stack>_matches
match_collection = "_".join((params_render['project'], ds_stack_2_align['in'], "matches"))
match_collection

'20230914_RP_exocrine_partial_test_postcorrection_dsmontages_matches'

In [78]:
# Optional reset: uncomment to delete the existing point-match collection
# before computing matches again in the cells below.
# renderapi.pointmatch.delete_collection(match_collection,
#                                        **params_render)

#### Set SIFT + RANSAC parameters

In [79]:
# `RANSAC` parameters
params_RANSAC = MatchDerivationParameters(
    matchIterations=None,
    matchMaxEpsilon=25,        # maximal alignment error
    matchMaxNumInliers=None,
    matchMaxTrust=None,
    matchMinInlierRatio=0.01,  # minimal inlier ratio
    matchMinNumInliers=3,      # minimal number of inliers
    matchModelType='AFFINE',   # expected transformation
    matchRod=0.92              # closest/next closest ratio
)

# `SIFT` feature-extraction parameters
params_feature_extraction = FeatureExtractionParameters(
    SIFTfdSize=8,              # feature descriptor size
    SIFTmaxScale=0.8,          # (width/height *) maximum image size
    SIFTminScale=0.2,          # (width/height *) minimum image size
    SIFTsteps=7                # steps per scale octave
)

# Combined `SIFT` & `RANSAC` parameters. Kept under its own name, so
# `params_SIFT` only ever refers to the combined SiftPointMatchOptions object
# (it is not first bound to the feature-extraction parameters).
params_SIFT = SiftPointMatchOptions(
    fillWithNoise=True,
    **{**vars(params_RANSAC),
       **vars(params_feature_extraction)}
)

# Extra parameters
params_SIFT.numberOfThreads = 40  # multithreading
# NOTE(review): the saved output of this cell lists 'SIFTsteps': 5, which
# disagrees with SIFTsteps=7 above -- that output is stale; re-run the cell.
params_SIFT.__dict__

{'SIFTfdSize': 8,
 'SIFTmaxScale': 0.8,
 'SIFTminScale': 0.2,
 'SIFTsteps': 5,
 'matchIterations': None,
 'matchMaxEpsilon': 25,
 'matchMaxNumInliers': None,
 'matchMaxTrust': None,
 'matchMinInlierRatio': 0.01,
 'matchMinNumInliers': 3,
 'matchModelType': 'AFFINE',
 'matchRod': 0.92,
 'renderScale': None,
 'fillWithNoise': True,
 'numberOfThreads': 40}

In [80]:
# Compute SIFT + RANSAC matches with render-ws PointMatchClient,
# one client invocation per tile pair
for tp in tqdm(tilepairs):
    # (p tileId, q tileId)
    tp_ids = tuple(tp[side]["id"] for side in ("p", "q"))
    renderapi.client.pointMatchClient(
        stack=ds_stack_2_align['in'],
        tile_pairs=[tp_ids],
        collection=match_collection,
        excludeAllTransforms=True,
        sift_options=params_SIFT,
        subprocess_mode='check_output',  # suppresses output
        **params_render
    )

  0%|          | 0/2 [00:00<?, ?it/s]



### Inspect

In [10]:
# Interactive view of the downsampled montage stack together with the point
# matches stored in `match_collection` (as the helper's name indicates).
plotting.plot_dsstack_with_alignment_matches(
    ds_stack_2_align['in'],
    match_collection,
    width=1000,
    **params_render
)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=1), Output()), _dom_classes=('widget-interact',)…

### Filter point matches
Remove false point matches located near the edges of the section images. A match is kept only if both of its points lie inside the clipping window.

In [83]:
# Define the clipping window from the stack bounds: a match survives only if it
# lies at least `clip` (fraction of the bound) away from the image edges.
# NOTE(review): match coordinates are compared directly with stack-bound
# coordinates, which assumes each downsampled tile starts at (0, 0) -- confirm.
bounds = renderapi.stack.get_stack_bounds(ds_stack_2_align['in'],
                                          **params_render)
clip = 0.05  # fraction of image size
min_X, max_X = clip * bounds['maxX'], (1 - clip) * bounds['maxX']
# y limits come from maxY (using maxX here would mis-clip non-square images)
min_Y, max_Y = clip * bounds['maxY'], (1 - clip) * bounds['maxY']


def in_clip_window(xy):
    """Boolean mask over the columns of a (2, N) [xs, ys] array that lie inside
    [min_X, max_X] x [min_Y, max_Y]."""
    x, y = xy
    return (x >= min_X) & (x <= max_X) & (y >= min_Y) & (y <= max_Y)


# Filtered matches, formatted for the render-ws point match database
matches_filtered = []

# Visit every (pGroupId, qGroupId) section pair once, even if several tile
# pairs share it (avoids importing duplicate match entries)
group_pairs = list(dict.fromkeys(
    (tp["p"]["groupId"], tp["q"]["groupId"]) for tp in tilepairs
))
for p_sectionId, q_sectionId in tqdm(group_pairs):
    matches_per_group = renderapi.pointmatch.get_matches_from_group_to_group(
        match_collection, p_sectionId, q_sectionId, **params_render
    )
    # One entry per (pId, qId) tile pair between the two sections; an empty
    # result simply contributes nothing
    for tile_matches in matches_per_group:
        coords = tile_matches['matches']
        p_xy = np.asarray(coords['p'], dtype=float).reshape(2, -1)
        q_xy = np.asarray(coords['q'], dtype=float).reshape(2, -1)
        weights = np.asarray(coords.get('w', np.ones(p_xy.shape[1])), dtype=float)

        # Column i of p and column i of q form ONE correspondence, so both ends
        # must pass the filter together; filtering them separately would pair
        # unrelated points and could leave p and q with different lengths.
        keep = in_clip_window(p_xy) & in_clip_window(q_xy)
        if not keep.any():
            print(f"no matches left for {tile_matches['pId']} -> "
                  f"{tile_matches['qId']} after clipping; skipped")
            continue

        matches_filtered.append({
            "pGroupId": tile_matches["pGroupId"],  # sectionId for image P
            "qGroupId": tile_matches["qGroupId"],  # sectionId for image Q
            "pId": tile_matches["pId"],  # tileId for image P
            "qId": tile_matches["qId"],  # tileId for image Q
            "matches": {
                "p": p_xy[:, keep].tolist(),
                "q": q_xy[:, keep].tolist(),
                "w": weights[keep].tolist()
            }
        })
    


  0%|          | 0/2 [00:00<?, ?it/s]

### Delete original point matches and upload filtered matches

In [84]:
# Refuse to wipe the collection when filtering left nothing to re-import;
# otherwise the raw matches would be lost with nothing to replace them.
if not matches_filtered:
    raise RuntimeError(
        f"no filtered matches to import; keeping collection {match_collection!r}"
    )

# Delete existing collection
renderapi.pointmatch.delete_collection(match_collection,
                                       **params_render)

# Import filtered matches
renderapi.pointmatch.import_matches(
    match_collection,
    matches_filtered,
    **params_render
)

<Response [201]>

### Inspect filtered matches

In [85]:
# Same view as before, now showing the edge-filtered matches that were
# re-imported into `match_collection`.
plotting.plot_dsstack_with_alignment_matches(
    ds_stack_2_align['in'],
    match_collection,
    width=1000,
    **params_render
)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=1), Output()), _dom_classes=('widget-interact',)…

### Align `dsmontage` stack
---

### Create alignment files

In [11]:
from pathlib import Path
import os
import subprocess
import json
from pprint import pprint

In [21]:
z_values = renderapi.stack.get_z_values_for_stack(ds_stack_2_align['in'],
                                                  **params_render)
# Load align.json template (path is relative to the kernel's working directory)
template_align_json = Path('../templates/align.json')
with template_align_json.open('r') as json_data:
    params_align_rough = json.load(json_data)

# Edit BigFeta solver schema: solve all z-levels of the stack in one 3D solve
params_align_rough['first_section'] = min(z_values)
params_align_rough['last_section'] = max(z_values)
params_align_rough['solve_type'] = '3D'
params_align_rough['transformation'] = 'rigid'
params_align_rough['log_level'] = 'INFO'

# Point the input stack, point-match collection and output stack at the same
# render-ws server as `params_render`. The port is overridden as well as the
# host, so a template port that differs from params_render['port'] cannot
# silently send the solve to another server.
for db_key, db_name in (('input_stack', ds_stack_2_align['in']),
                        ('pointmatch', match_collection),
                        ('output_stack', ds_stack_2_align['out'])):
    params_align_rough[db_key]['host'] = params_render['host']
    params_align_rough[db_key]['port'] = params_render['port']
    params_align_rough[db_key]['owner'] = params_render['owner']
    params_align_rough[db_key]['name'] = db_name
# Only the stacks get a project; the point-match collection is left as-is
for db_key in ('input_stack', 'output_stack'):
    params_align_rough[db_key]['project'] = params_render['project']

# Edit regularization parameters
params_align_rough['regularization']['default_lambda'] = 0.1      # default: 0.005
params_align_rough['regularization']['translation_factor'] = 0.1  # default: 0.005
params_align_rough['regularization']['thinplate_factor'] = 1e-5   # default: 1e-5

# Export alignment settings to JSON (input for BigFeta)
align_json = dir_project / '_jsons_align_rough' / ds_stack_2_align['in'] / 'align_rough.json'
align_json.parent.mkdir(parents=True, exist_ok=True)
with align_json.open('w') as json_data:
    json.dump(params_align_rough, json_data, indent=2)

# Check alignment parameters
print(align_json)
print('-'*len(align_json.as_posix()))
pprint(params_align_rough)

/long_term_storage/akievits/FAST-EM/20230914_RP_exocrine_partial_test/_jsons_align_rough/postcorrection_dsmontages/align_rough.json
-----------------------------------------------------------------------------------------------------------------------------------
{'close_stack': 'True',
 'first_section': 0.0,
 'hdf5_options': {'chunks_per_file': -1, 'output_dir': ''},
 'input_stack': {'client_scripts': '/home/catmaid/render/render-ws-java-client/src/main/scripts',
                 'collection_type': 'stack',
                 'db_interface': 'render',
                 'host': 'http://localhost',
                 'mongo_host': 'sonic.tnw.tudelft.nl',
                 'mongo_port': 27017,
                 'name': 'postcorrection_dsmontages',
                 'owner': 'akievits',
                 'port': 8081,
                 'project': '20230914_RP_exocrine_partial_test'},
 'last_section': 2.0,
 'log_level': 'INFO',
 'matrix_assembly': {'cross_pt_weight': 1.0,
                     'depth

### Run `BigFeta`

In [22]:
# Path to `BigFeta`; `python -m bigfeta.bigfeta` has to run from this directory
BigFeta_dir = Path('/home/catmaid/BigFeta/')

# Select json for rough alignment
align_json = dir_project / '_jsons_align_rough' / ds_stack_2_align['in'] / 'align_rough.json'

# Run BigFeta with `cwd=` instead of os.chdir: the kernel's own working
# directory is never changed, so relative paths used by other cells keep working
# even if the run fails. `check=True` raises CalledProcessError on a failed solve
# instead of continuing silently with a missing/stale output stack.
bigfeta_run = subprocess.run(
    ['python', '-m', 'bigfeta.bigfeta', '--input_json', align_json.as_posix()],
    cwd=BigFeta_dir,
    check=True,
)

INFO:bigfeta.utils:
 loaded 3 tile specs from 3 zvalues in 0.0 sec using interface: render
INFO:__main__: A created in 0.1 seconds
INFO:__main__:
 solved in 0.0 sec
 precision [norm(Kx-Lm)/norm(Lm)] = 1.8e-12
 error     [norm(Ax-b)] = 0.008
 [mean(Ax) +/- std(Ax)] : -0.0 +/- 0.0
 [mean(error mag) +/- std(error mag)] : 0.0 +/- 0.0
  self.M[0, 0] = vec[0] * s
  self.M[0, 1] = vec[1] * s
  self.M[0, 2] = vec[2] * s
  self.M[1, 0] = -vec[1] * s
  self.M[1, 1] = vec[0] * s
  self.M[1, 2] = vec[3] * s
INFO:__main__:
 scales: 1.00 +/- 0.00, 1.00 +/- 0.00
INFO:bigfeta.utils:
ingesting results to http://localhost:8081 akievits__20230914_RP_exocrine_partial_test__postcorrection_dsmontages_aligned
INFO:bigfeta.utils:render output is going to /dev/null
INFO:__main__: total time: 1.3


### Inspect alignment

In [23]:
# plot stack
# Inspect the BigFeta output: the rigidly aligned downsampled montages.
plotting.plot_stacks(
    stacks=[ds_stack_2_align['out']],
    **params_render
)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=2), IntSlider(value=26, description='vmin', max=…

## 4) Rough alignment (III)
---
### Propagate transform to the stitched stack
The output stack is only used to generate tile pairs. It does not matter much what the resulting stack looks like, as long as it is roughly aligned in `z` and not too distorted.

In [36]:
# Full-resolution stitched stack ('in') that receives the rough-alignment
# transforms, written to a new stack ('out')
stack_2_align_rough = {}
stack_2_align_rough['in'] = 'postcorrection_stitched'
stack_2_align_rough['out'] = 'postcorrection_aligned_rough'

In [75]:
# Propagate the rough alignment of each downsampled montage to every tile of the
# full-resolution stitched stack, write the new TileSpecs to JSON and import
# them into a new render-ws stack.
ts_rough_aligned = []
z_values = renderapi.stack.get_z_values_for_stack(stack_2_align_rough['in'],
                                                  **params_render)
for z in tqdm(z_values, total=len(z_values)):
    # Get the transformation from the aligned downsampled stack
    # NOTE(review): assumes one downsampled tile per z whose first transform is
    # the BigFeta result -- confirm.
    tform_ds = renderapi.tilespec.get_tile_specs_from_z(ds_stack_2_align['out'],
                                                        z=z,
                                                        **params_render)[0].tforms[0]
    # Express it in full-resolution coordinates. With u = SCALE * x:
    # (M u + B) / SCALE = M x + B / SCALE, so only the translation is rescaled.
    # NOTE(review): assumes the downsampled montage and the full-resolution
    # section share the same origin -- confirm against create_downsampled_stack.
    tform_full = renderapi.transform.AffineModel(M00=tform_ds.M00,
                                                 M01=tform_ds.M01,
                                                 M10=tform_ds.M10,
                                                 M11=tform_ds.M11,
                                                 B0=tform_ds.B0 / SCALE,
                                                 B1=tform_ds.B1 / SCALE)
    # Get TileSpecs from high res stack to transform
    tilespec = renderapi.tilespec.get_tile_specs_from_z(stack_2_align_rough['in'],
                                                        z=z,
                                                        **params_render)
    # Append the section transform after each tile's existing transforms, so
    # aligned(p) = tform_full(T_tile(p)). Replacing the list with
    # (M_ds, B_ds / SCALE + B_tile) would leave tile offsets unrotated (tiles
    # of one section drift apart whenever the alignment contains a rotation)
    # and would drop everything in T_tile except its translation.
    for ts in tilespec:
        ts.tforms = list(ts.tforms) + [tform_full]
        ts_rough_aligned.append(ts.to_dict())

# Dump new TileSpecs to json; the `with` block closes (and flushes) the file
# before render-ws reads it below
rough_align_json = dir_project / '_jsons_align_rough' / stack_2_align_rough['out'] / 'tilespecs_rough_align.json'
rough_align_json.parent.mkdir(parents=True, exist_ok=True)
with rough_align_json.open('w') as json_file:
    json.dump(ts_rough_aligned, json_file, indent=4)
# Create new stack
renderapi.stack.create_stack(stack_2_align_rough['out'],
                             **params_render)
# Import TileSpecs from json
renderapi.client.import_jsonfiles(stack_2_align_rough['out'],
                                  [rough_align_json],
                                  **params_render)
    

  0%|          | 0/3 [00:00<?, ?it/s]

org.janelia.render.client.ImportJsonClient

  Running: /home/catmaid/render/deploy/jdk1.8.0_131/bin/java -cp /home/catmaid/render/render-ws-java-client/target/render-ws-java-client-2.0.1-SNAPSHOT-standalone.jar -Xms1G -Xmx1G -Djava.awt.headless=true -XX:+UseSerialGC org.janelia.render.client.ImportJsonClient --baseDataUrl http://localhost:8081/render-ws/v1 --owner akievits --project 20230914_RP_exocrine_partial_test --stack postcorrection_aligned_rough /long_term_storage/akievits/FAST-EM/20230914_RP_exocrine_partial_test/_jsons_align_rough/postcorrection_aligned_rough/tilespecs_rough_align.json


11:47:51.758 [main] INFO  [org.janelia.render.client.ClientRunner] run: entry
11:47:51.917 [main] INFO  [org.janelia.render.client.ImportJsonClient] runClient: entry, parameters={
  "renderWeb" : {
    "baseDataUrl" : "http://localhost:8081/render-ws/v1",
    "owner" : "akievits",
    "project" : "20230914_RP_exocrine_partial_test"
  },
  "tileSpecValidator" : { },
  "stack" : "postcorrection_

### Inspect

In [76]:
# plot stack
# Inspect the full-resolution stack after propagating the rough alignment.
plotting.plot_stacks(
    stacks=[stack_2_align_rough['out']],
    **params_render
)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=2), IntSlider(value=25226, description='vmin', m…