# Rough align
---

#### Overview
Interactive rough 3D alignment of serial sections. Point matches are found between montages (2D-stitched sections) that are rendered to disk at 5% scale. A filtering step then removes false point matches found on the border of the ROA. The alignment between the downsampled montages is then solved, which produces a roughly aligned stack in render-ws. Finally, the transformations from this stack are applied to the full-scale data, creating a montaged, roughly aligned stack.

In [4]:
import pathlib
import requests

from tqdm.notebook import tqdm
import numpy as np
import renderapi

from scripted_render_pipeline import basic_auth
from scripted_render_pipeline.importer import uploader
from interactive_render import plotting
from interactive_render.utils import clear_image_cache

# Start from a clean image cache so the plots below reflect the current state of
# render-ws (NOTE(review): exact cache semantics live in interactive_render.utils — confirm)
clear_image_cache()

#### parameters and `render-ws` environment variables
* `host` : Localhost or web address that hosts render-ws (used by the uploader). It is usually the prefix of the link to the render-ws HTML page, i.e. `{host_name}/render-ws/view/index.html`
* `owner` : render-ws ID of dataset
* `project`: render-ws project name
* `auth`: authorization for render-ws. See https://github.com/hoogenboom-group/scripted-render-pipeline/tree/master for instructions
* `dir_project`: full path to project directory on disk
* `memGB`: Gigabytes of RAM to give to java PointMatchClient. Increase proportionally if using more threads to find matches
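* `SCALE`: Factor by which montage images are downsampled (default 0.05, i.e. 5%).
* `num_threads`: Number of threads used to find point matches (also determines `memGB` for the Java client).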

In [6]:
# Set script parameters
dir_project = pathlib.Path("/long_term_storage/akievits/FAST-EM/tests/20231107_MCF7_UAC_test/")  # project directory
SCALE = 0.05  # downsampling scale, default is 0.05
num_threads = 20  # threads for PointMatchClient; also sizes its Java heap (memGB) below

# Set stacks for downsampling
stack_2_downsample = {
    'in': 'postcorrection_stitched',
    'out': 'postcorrection_dsmontages'
}

# create an authorized session
auth = basic_auth.load_auth()
sesh = requests.Session()
sesh.auth = auth

# render-ws environment variables (passed as **kwargs to renderapi calls)
params_render = {
    "host": "http://localhost",
    "port": 8081,
    "client_scripts": "/home/catmaid/render/render-ws-java-client/src/main/scripts",
    "client_script": "/home/catmaid/render/render-ws-java-client/src/main/scripts/run_ws_client.sh",
    "owner": "akievits",
    "project": "20231107_MCF7_UAC_test",
    "session": sesh
}

# Same render-ws settings plus the Java heap size for PointMatchClient.
# Derived from params_render so the two dictionaries cannot drift apart.
params_align = {
    **params_render,
    "memGB": f"{num_threads * 2}G",  # ~2 GB of heap per thread
}

# The uploader talks to the public host; owner/project are shared with params_render
params_uploader = {
    "host": "https://sonic.tnw.tudelft.nl",
    "owner": params_render["owner"],
    "project": params_render["project"],
    "auth": auth
}

## 1) Rough alignment (part I)
---
Perform rough alignment of downsampled section images

### Create downsampled montage stack

In [7]:
from interactive_render.utils import create_downsampled_stack

In [8]:
# Input (stitched) and output (downsampled montage) stack names used for downsampling.
# Same values as in the configuration cell; repeated so this section can run on its own.
stack_2_downsample = dict([
    ('in', 'postcorrection_stitched'),
    ('out', 'postcorrection_dsmontages'),
])

In [10]:
# Render the downsampled montages and build the corresponding stack definition
ds_stack = create_downsampled_stack(dir_project, stack_2_downsample, **params_render)

# Upload the downsampled stack to render-ws
# (clobber=False: presumably refuses to overwrite an existing stack — confirm in uploader docs)
ds_uploader = uploader.Uploader(**params_uploader, clobber=False)
ds_uploader.upload_to_render(stacks=[ds_stack], z_resolution=100)

  0%|          | 0/3 [00:00<?, ?it/s]

uploading: 100%|██████████| 1/1 [00:00<00:00,  2.67stacks/s]


### Inspect downsampled montage stack

In [13]:
# Interactive viewer of the downsampled montage stack (display range 0..65535 = full uint16 range)
plotting.plot_stacks(stacks=[ds_stack.name], width=1000, vmin=0, vmax=65535, **params_render)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=2), IntSlider(value=5, description='vmin', max=6…

## 2) Rough alignment (part II)
---
Get point matches for `dsmontage` stack and roughly align
### Align `dsmontage` stack


In [14]:
# Downsampled montage stack to align ('in') and the name of the aligned result ('out')
ds_stack_2_align = dict(zip(
    ('in', 'out'),
    ('postcorrection_dsmontages', 'postcorrection_dsmontages_aligned'),
))

### Get point matches

Use the `render-ws` `PointMatchClient` script to find matching features between neighboring z-levels
#### Collect tile pairs

In [15]:
# Integer z-levels present in the downsampled input stack
z_values = list(map(int, renderapi.stack.get_z_values_for_stack(ds_stack_2_align['in'], **params_render)))

# Pair each tile with the tiles on neighbouring z-levels of the (not yet aligned) input stack
tilepair_response = renderapi.client.tilePairClient(
    stack=ds_stack_2_align['in'],
    minz=min(z_values),
    maxz=max(z_values),
    zNeighborDistance=1,              # half-height of search cylinder
    excludeSameLayerNeighbors=False,
    subprocess_mode="check_output",   # suppresses output
    **params_render
)
tilepairs = tilepair_response["neighborPairs"]

# PointMatchClient expects plain [pId, qId] pairs
tile_pairs_reformat = [[pair['p']['id'], pair['q']['id']] for pair in tilepairs]

# Summary plus a preview of the first few pairs
header = f"Number of tile pairs... {len(tile_pairs_reformat)}"
print(header, "\n" + "-" * len(header))
tile_pairs_reformat[:5]



Number of tile pairs... 2 
-------------------------


[['t0_z0.0_y0_x0', 't0_z1.0_y0_x0'], ['t0_z1.0_y0_x0', 't0_z2.0_y0_x0']]

In [16]:
from renderapi.client import (
    SiftPointMatchOptions,
    MatchDerivationParameters,
    FeatureExtractionParameters
)

In [17]:
# Point-match collection name: <project>_<input stack>_matches
match_collection = "_".join([params_render['project'], ds_stack_2_align['in'], "matches"])
match_collection

'20231107_MCF7_UAC_test_postcorrection_dsmontages_matches'

In [18]:
# Optional reset: uncomment to delete an existing point-match collection with this
# name before (re-)running point matching below.
# renderapi.pointmatch.delete_collection(match_collection,
#                                        **params_render)

#### Set SIFT + RANSAC parameters

In [19]:
# `RANSAC` parameters (match derivation / outlier rejection)
params_RANSAC = MatchDerivationParameters(
    matchIterations=None,
    matchMaxEpsilon=50,        # maximal alignment error
    matchMaxNumInliers=None,
    matchMaxTrust=None,
    matchMinInlierRatio=0.01,  # minimal inlier ratio
    matchMinNumInliers=3,      # minimal number of inliers
    matchModelType='AFFINE',   # expected transformation
    matchRod=0.92              # closest/next closest ratio
)

# `SIFT` feature-extraction parameters. Kept under its own name instead of
# re-binding `params_SIFT` to two different types within one cell.
params_feature = FeatureExtractionParameters(
    SIFTfdSize=8,              # feature descriptor size
    SIFTmaxScale=0.8,          # (width/height *) maximum image size
    SIFTminScale=0.2,          # (width/height *) minimum image size
    SIFTsteps=5                # steps per scale octave
)

# Combined `SIFT` & `RANSAC` parameters passed to PointMatchClient
params_SIFT = SiftPointMatchOptions(
    fillWithNoise=True,
    **{**params_RANSAC.__dict__,
       **params_feature.__dict__}
)

# Multithreading: no more threads than tile pairs, but always at least one.
# A value of 0 (no tile pairs) would otherwise become a zero batch size and
# make `range(..., 0)` raise ValueError in the batching cell.
params_SIFT.numberOfThreads = max(1, min(len(tile_pairs_reformat), num_threads))
params_SIFT.__dict__

{'SIFTfdSize': 8,
 'SIFTmaxScale': 0.8,
 'SIFTminScale': 0.2,
 'SIFTsteps': 5,
 'matchIterations': None,
 'matchMaxEpsilon': 50,
 'matchMaxNumInliers': None,
 'matchMaxTrust': None,
 'matchMinInlierRatio': 0.01,
 'matchMinNumInliers': 3,
 'matchModelType': 'AFFINE',
 'matchRod': 0.92,
 'renderScale': None,
 'fillWithNoise': True,
 'numberOfThreads': 2}

In [20]:
# Split point-match finding into batches equal in size to the number of threads.
# `batches` holds the start index of each batch in `tile_pairs_reformat`.
# max(1, ...) guards against a zero step, which would make range() raise ValueError.
batch_size = max(1, params_SIFT.numberOfThreads)
batches = list(range(0, len(tile_pairs_reformat), batch_size))

In [21]:
# Run SIFT + RANSAC via the render-ws PointMatchClient, one batch of tile pairs at a time
for start in tqdm(batches, total=len(batches)):
    batch = tile_pairs_reformat[start:start + batch_size]
    renderapi.client.pointMatchClient(
        stack=ds_stack_2_align['in'],
        collection=match_collection,
        tile_pairs=batch,
        sift_options=params_SIFT,
        excludeAllTransforms=True,
        subprocess_mode='check_output',  # suppresses output
        **params_align
    )

  0%|          | 0/1 [00:00<?, ?it/s]

### Inspect

In [22]:
# Overlay the raw (unfiltered) point matches on the downsampled input stack
plotting.plot_dsstack_with_alignment_matches(ds_stack_2_align['in'], match_collection, width=1000, **params_render)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=1), Output()), _dom_classes=('widget-interact',)…

### Filter point matches
Remove false point matches located near the edges of the stack (within a `clip` fraction of the stack width/height from each border)

In [23]:
clip = 0.05  # fraction of image size to clip from each border, can be varied

# Define clipping range based on stack bounds.
# NOTE(review): point-match coordinates are tile-local; using the stack's maxX/maxY
# as the tile size assumes the downsampled tiles sit at the origin — confirm.
bounds = renderapi.stack.get_stack_bounds(ds_stack_2_align['in'],
                                          **params_render)
min_X, max_X = clip * bounds['maxX'], (1 - clip) * bounds['maxX']
min_Y, max_Y = clip * bounds['maxY'], (1 - clip) * bounds['maxY']  # was maxX: wrong axis


def _inside_clip_region(xy):
    """Boolean mask of the columns of a (2, n) [[x...], [y...]] array inside the clip region.

    Uses min_X/max_X/min_Y/max_Y computed above (inclusive bounds).
    """
    x, y = xy
    return (x >= min_X) & (x <= max_X) & (y >= min_Y) & (y <= max_Y)


# Filtered matches, formatted for upload to the render-ws point-match database
matches_filtered = []
skipped_pairs = []

# Get tile pairs from the downsampled input stack (same query as for point matching)
tile_pairs = renderapi.client.tilePairClient(
    stack=ds_stack_2_align['in'],
    minz=min(z_values),
    maxz=max(z_values),
    zNeighborDistance=1,  # half-height of search cylinder
    excludeSameLayerNeighbors=False,
    subprocess_mode="check_output",  # suppresses output
    **params_render
)["neighborPairs"]

# loop through tile pairs
for tp in tqdm(tile_pairs):
    # groupId (aka sectionId) of each tile
    p_sectionId = tp["p"]["groupId"]
    q_sectionId = tp["q"]["groupId"]
    matches_per_group = renderapi.pointmatch.get_matches_from_group_to_group(match_collection,
                                                                             p_sectionId,
                                                                             q_sectionId,
                                                                             **params_render)
    if not matches_per_group:
        # PointMatchClient stored nothing for this pair; indexing [0] would raise IndexError
        skipped_pairs.append((p_sectionId, q_sectionId))
        continue

    # Match coordinates as (2, n) arrays: row 0 = x, row 1 = y
    p_xy = np.asarray(matches_per_group[0]['matches']['p']).reshape(2, -1)
    q_xy = np.asarray(matches_per_group[0]['matches']['q']).reshape(2, -1)

    # Keep a match only if BOTH of its points lie inside the clip region
    keep = _inside_clip_region(p_xy) & _inside_clip_region(q_xy)
    if not keep.any():
        # Nothing survives the filter; don't upload an empty/malformed match entry
        skipped_pairs.append((p_sectionId, q_sectionId))
        continue

    d = {
        "pGroupId": p_sectionId,  # sectionId for image P
        "qGroupId": q_sectionId,  # sectionId for image Q
        "pId": tp["p"]["id"],  # tileId for image P
        "qId": tp["q"]["id"],  # tileId for image Q
        "matches": {
            "p": p_xy[:, keep].tolist(),  # [[x...], [y...]]
            "q": q_xy[:, keep].tolist(),
            "w": np.ones(int(keep.sum())).tolist()  # uniform weights
        }
    }
    matches_filtered.append(d)

if skipped_pairs:
    print(f"Skipped {len(skipped_pairs)} tile pair(s) without matches inside the clip region: {skipped_pairs}")



  0%|          | 0/2 [00:00<?, ?it/s]

### Delete original point matches and upload filtered matches

In [25]:
# Replace the collection wholesale: drop the original matches, then upload the filtered ones.
# NOTE: the delete happens first, so if the import fails the original matches are gone
# and point matching has to be re-run.
renderapi.pointmatch.delete_collection(match_collection, **params_render)
renderapi.pointmatch.import_matches(match_collection, matches_filtered, **params_render)

<Response [201]>

### Inspect filtered matches

In [26]:
# Overlay the filtered point matches on the downsampled input stack
stack_to_inspect = ds_stack_2_align['in']
plotting.plot_dsstack_with_alignment_matches(stack_to_inspect, match_collection,
                                             width=1000, **params_render)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=1), Output()), _dom_classes=('widget-interact',)…

## Align `dsmontage` stack
---

### Create alignment files

In [24]:
from pathlib import Path
import os
import subprocess
import json
from pprint import pprint

In [25]:
z_values = renderapi.stack.get_z_values_for_stack(ds_stack_2_align['in'], **params_render)

# Start from the BigFeta align.json template (path is relative to the current working directory)
template_align_json = Path('../templates/align.json')
with template_align_json.open('r') as template_file:
    params_align_rough = json.load(template_file)

# BigFeta solver schema: 3D similarity solve over every z-level of the input stack
params_align_rough.update(
    first_section=min(z_values),
    last_section=max(z_values),
    solve_type='3D',
    transformation='SimilarityModel',
    log_level='INFO',
)

# Point input stack, point-match collection and output stack at this render-ws owner/project
render_target = {'host': params_render['host'], 'owner': params_render['owner']}
params_align_rough['input_stack'].update(render_target,
                                         project=params_render['project'],
                                         name=ds_stack_2_align['in'])
params_align_rough['pointmatch'].update(render_target, name=match_collection)
params_align_rough['output_stack'].update(render_target,
                                          project=params_render['project'],
                                          name=ds_stack_2_align['out'])

# Regularization parameters
params_align_rough['regularization'].update({
    'default_lambda': 1e3,          # default: 0.005
    'translation_factor': 0.00001,  # default: 0.005
    'thinplate_factor': 1e-5,       # default: 1e-5
})

# Export alignment settings to the project directory
align_json = dir_project / '_jsons_align_rough' / ds_stack_2_align['in'] / 'align_rough.json'
align_json.parent.mkdir(parents=True, exist_ok=True)
align_json.write_text(json.dumps(params_align_rough, indent=2))

# Check alignment parameters
print(align_json)
print('-' * len(align_json.as_posix()))
pprint(params_align_rough)

/long_term_storage/akievits/FAST-EM/tests/20231107_MCF7_UAC_test/_jsons_align_rough/postcorrection_dsmontages/align_rough.json
------------------------------------------------------------------------------------------------------------------------------
{'close_stack': 'True',
 'first_section': 0.0,
 'hdf5_options': {'chunks_per_file': -1, 'output_dir': ''},
 'input_stack': {'client_scripts': '/home/catmaid/render/render-ws-java-client/src/main/scripts',
                 'collection_type': 'stack',
                 'db_interface': 'render',
                 'host': 'http://localhost',
                 'mongo_host': 'sonic.tnw.tudelft.nl',
                 'mongo_port': 27017,
                 'name': 'postcorrection_dsmontages',
                 'owner': 'akievits',
                 'port': 8081,
                 'project': '20231107_MCF7_UAC_test'},
 'last_section': 2.0,
 'log_level': 'INFO',
 'matrix_assembly': {'choose_random': 'False',
                     'cross_pt_weight': 0.5,
 

### Run `BigFeta`

In [26]:
# Path to `BigFeta`
cwd = Path.cwd().as_posix()
BigFeta_dir = Path('/home/catmaid/BigFeta/')

# Select json for rough alignment
align_json = dir_project / '_jsons_align_rough' / ds_stack_2_align['in'] / 'align_rough.json'

# Call `BigFeta.BigFeta` process -- have to switch to BigFeta directory
os.chdir(BigFeta_dir.as_posix())
subprocess.run(['python', '-m', 'bigfeta.bigfeta', '--input_json', align_json.as_posix()])
os.chdir(cwd)

INFO:bigfeta.utils:
 loaded 3 tile specs from 3 zvalues in 0.1 sec using interface: render
INFO:__main__: A created in 0.2 seconds
INFO:__main__:
 solved in 0.0 sec
 precision [norm(Kx-Lm)/norm(Lm)] = 7.8e-12
 error     [norm(Ax-b)] = 52.647
 [mean(Ax) +/- std(Ax)] : 0.0 +/- 1.0
 [mean(error mag) +/- std(error mag)] : 0.8 +/- 0.6
  self.M[0, 0] = vec[0]
  self.M[0, 1] = vec[1]
  self.M[0, 2] = vec[2]
  self.M[1, 0] = -vec[1]
  self.M[1, 1] = vec[0]
  self.M[1, 2] = vec[3]
INFO:__main__:
 scales: 0.40 +/- 0.00, 0.40 +/- 0.00
INFO:bigfeta.utils:
ingesting results to http://localhost:8081 akievits__20231107_MCF7_UAC_test__postcorrection_dsmontages_aligned
INFO:bigfeta.utils:render output is going to /dev/null
INFO:__main__: total time: 1.8


### Inspect alignment

In [27]:
# Clear cached images so the plot below shows the newly aligned stack.
# The call returns an HTTP response (shown as cell output); NOTE(review): exact cache semantics — confirm
clear_image_cache()

<Response [200]>

In [28]:
# plot stack
plotting.plot_stacks(
    stacks=[ds_stack_2_align['out']],
    **params_render
)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=2), IntSlider(value=5, description='vmin', max=6…

## 3) Rough alignment (part III)
---
### Propagate transform to stitched together stack
The output stack is only used to generate tile pairs from. It doesn't really matter what the resulting stack looks like, as long as it is roughly aligned in `z` and not too distorted.

In [29]:
# Full-resolution stitched stack ('in') and name of the rough-aligned result ('out')
stack_2_align_rough = {}
stack_2_align_rough['in'] = 'postcorrection_stitched'
stack_2_align_rough['out'] = 'postcorrection_aligned_rough'

In [31]:
# Append the per-section transforms solved on the downsampled stack to every tile
# of the full-resolution stitched stack, then write the result to a new stack.
ts_rough_aligned = []
z_values = renderapi.stack.get_z_values_for_stack(stack_2_align_rough['in'], **params_render)

for z in tqdm(z_values, total=len(z_values)):
    # Transform list of the first tile of this section in the aligned downsampled stack
    ds_tile = renderapi.tilespec.get_tile_specs_from_z(ds_stack_2_align['out'], z=z, **params_render)[0]
    tform_ds = ds_tile.tforms
    # Full-resolution tiles of the same section
    tilespec = renderapi.tilespec.get_tile_specs_from_z(stack_2_align_rough['in'], z=z, **params_render)

    # Scale only the 2x2 linear part of affine-type transforms by SCALE (in place);
    # translations are left untouched.
    # NOTE(review): the resulting stack therefore appears to live in roughly downsampled
    # coordinates (cf. the bounds inspected below) — confirm this is intended.
    affine_tforms = [tf for tf in tform_ds if isinstance(tf, renderapi.transform.leaf.AffineModel)]
    for tf in affine_tforms:
        tf.M[0:2, 0:2] *= SCALE

    # Every full-res tile of this section receives the same (shared) transform objects
    for ts in tilespec:
        ts.tforms += tform_ds
        ts_rough_aligned.append(ts)

# Create the output stack, import the transformed tile specs and mark the stack COMPLETE
renderapi.stack.create_stack(stack_2_align_rough['out'], **params_render)
renderapi.client.import_tilespecs(stack_2_align_rough['out'], tilespecs=ts_rough_aligned, **params_render)
renderapi.stack.set_stack_state(stack_2_align_rough['out'], state='COMPLETE', **params_render)

  0%|          | 0/3 [00:00<?, ?it/s]



org.janelia.render.client.ImportJsonClient

  Running: /home/catmaid/render/deploy/jdk1.8.0_131/bin/java -cp /home/catmaid/render/render-ws-java-client/target/render-ws-java-client-2.0.1-SNAPSHOT-standalone.jar -Xms1G -Xmx1G -Djava.awt.headless=true -XX:+UseSerialGC org.janelia.render.client.ImportJsonClient --baseDataUrl http://localhost:8081/render-ws/v1 --owner akievits --project 20231107_MCF7_UAC_test --stack postcorrection_aligned_rough /tmp/tmpvbn4ewk9.json


16:03:53.210 [main] INFO  [org.janelia.render.client.ClientRunner] run: entry
16:03:53.376 [main] INFO  [org.janelia.render.client.ImportJsonClient] runClient: entry, parameters={
  "renderWeb" : {
    "baseDataUrl" : "http://localhost:8081/render-ws/v1",
    "owner" : "akievits",
    "project" : "20231107_MCF7_UAC_test"
  },
  "tileSpecValidator" : { },
  "stack" : "postcorrection_aligned_rough",
  "tileFiles" : [
    "/tmp/tmpvbn4ewk9.json"
  ]
}
16:03:53.652 [main] INFO  [org.janelia.render.client.RenderDataClient] getSta

<Response [201]>

### Inspect

In [32]:
# Interactive viewer of the rough-aligned full-resolution stack
plotting.plot_stacks(stacks=[stack_2_align_rough['out']], **params_render)

  0%|          | 0/1 [00:00<?, ?it/s]

interactive(children=(IntSlider(value=0, description='z', max=2), IntSlider(value=28550, description='vmin', m…

In [33]:
# Bounding box (minX/maxX/minY/maxY/minZ/maxZ) of the rough-aligned stack
bounds = renderapi.stack.get_stack_bounds(stack_2_align_rough['out'], **params_render)
bounds

{'minX': -58.0,
 'minY': -26.0,
 'minZ': 0.0,
 'maxX': 659.0,
 'maxY': 646.0,
 'maxZ': 2.0}