# iCAT Alignment
---
Overview and interactive walkthrough of how to align (part of) an individual section using `render-python`. Assumes image data has already been imported following the [iCAT-import workflow](https://github.com/lanery/iCAT-workflow/blob/master/notebooks/iCAT-import.ipynb).

## 1 Generate Point Matches
---

### 1.0 Imports

In [None]:
# Libraries needed
import os
from shutil import rmtree
import subprocess
from functools import partial
import re
from pathlib import Path
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import renderapi
from renderapi.client import (SiftPointMatchOptions,
                              MatchDerivationParameters,
                              FeatureExtractionParameters,
                              ArgumentParameters)

### 1.1 Setup render Environment

In [None]:
# --- User / project configuration ---------------------------------------------
owner = '<user>'  # replace with your name
project = 'iCAT_align'
project_dir = Path(f'/long_term_storage/{owner}/SECOM/projects/{project}')
export_dir = Path(f'/long_term_storage/{owner}/CATMAID/projects/{project}')

# Connection settings for the render web service. The `client_scripts` entry
# is also reused later to locate render's shell client scripts.
render_connect_params = dict(
    host='sonic',
    port=8080,
    owner=owner,
    project=project,
    client_scripts='/home/catmaid/render/render-ws-java-client/src/main/scripts',
    memGB='2G',
)

# Connect and list every stack that exists for this owner/project
render = renderapi.connect(**render_connect_params)
stacks = renderapi.render.get_stacks_by_owner_project(render=render)

# One point-match collection per stack, named <project>_<stack>_points
match_collections = {stack: f'{project}_{stack}_points' for stack in stacks}

out = (f"stacks.............. {stacks}\n"
       f"match collections... {match_collections}\n")
print(out)

### 1.2 Get Tile Pairs
Tile pairs are any two tiles that overlap with each other (including diagonally). `renderapi` provides useful functions for generating lists of tile pairs.

In [None]:
# Per-stack results: bounds for every stack, tile pairs only for multi-tile stacks
stack_bounds, tile_pairs, tile_pair_lists = {}, {}, {}

for stack in stacks:

    # Positional bounds of the image stack (also supplies the z range below)
    stack_bound = renderapi.stack.get_stack_bounds(stack, render=render)
    stack_bounds[stack] = stack_bound

    # A stack with a single tile has nothing to pair with, so skip it
    N_tiles = len(renderapi.stack.get_stack_tileIds(stack, render=render))
    if N_tiles <= 1:
        continue

    # Ask render for the neighbouring tile pairs within the stack's z range
    tile_pairs[stack] = renderapi.client.tilePairClient(
        stack,
        minz=stack_bound['minZ'],
        maxz=stack_bound['maxZ'],
        render=render)

    # Reduce each neighbour pair to a (p tileId, q tileId) tuple
    tile_pair_list = [(pair['p']['id'], pair['q']['id'])
                      for pair in tile_pairs[stack]['neighborPairs']]
    tile_pair_lists[stack] = tile_pair_list

# Preview first several tile pairs
tile_pair_lists['lil_EM'][:8]

### 1.3 Run Point Match Client
Set SIFT and RANSAC parameters.

In [None]:
# Match-derivation (RANSAC / model fitting) settings
match_params = MatchDerivationParameters(
    matchIterations=None,
    matchMaxEpsilon=25,          # maximal alignment error
    matchMaxNumInliers=None,
    matchMaxTrust=None,
    matchMinInlierRatio=0.05,    # minimal inlier ratio
    matchMinNumInliers=7,        # minimal number of inliers
    matchModelType='RIGID',      # expected transformation
    matchRod=0.92,               # closest/next closest ratio
)

# SIFT feature-extraction settings
feature_params = FeatureExtractionParameters(
    SIFTfdSize=4,                # feature descriptor size
    SIFTmaxScale=0.25,           # (width/height *) maximum image size
    SIFTminScale=0.05,           # (width/height *) minimum image size
    SIFTsteps=6,                 # steps per scale octave
    clipWidth=500,
    clipHeight=500,
)

# Merge both parameter sets into one options object (intended for the point match client)
sift_options = SiftPointMatchOptions(**{**vars(match_params),
                                        **vars(feature_params)})

First create a wrapper for `renderapi.client.pointMatchClient` so that we can run the point match client in parallel with `renderapi.client.WithPool`.

In [None]:
def run_point_match_client(tile_pair_chunk, stack, collection, render, sift_options=None):
    """
    Point match client wrapper for use in multiprocessing.

    The chunk of tile pairs is the first positional argument. This lets every
    other argument be bound with `functools.partial` and the result be mapped
    over chunks with `renderapi.client.WithPool`.

    Parameters
    ----------
    tile_pair_chunk : list
        Tile pairs to match, each a (p tileId, q tileId) pair.
    stack : str
        Name of the render stack the tiles belong to.
    collection : str
        Point match collection name.
    render : renderapi.connect.Render
        Render connection.
    sift_options : SiftPointMatchOptions, optional
        SIFT / match-derivation options forwarded to `pointMatchClient`.
        If None (default), no options are passed and `pointMatchClient`
        uses its own defaults, which is this wrapper's previous behaviour.
    """
    # Only forward sift_options when it was given, so that omitting it behaves
    # exactly like the old wrapper, which never passed it.
    options = {} if sift_options is None else {'sift_options': sift_options}
    renderapi.client.pointMatchClient(stack=stack,
                                      collection=collection,
                                      tile_pairs=tile_pair_chunk,
                                      render=render,
                                      **options)

Now break the tile pairs into chunks to prevent Java heap-space out-of-memory errors. Adjust `chunk_size` as necessary. For large (4k x 4k) images, `chunk_size` may have to be reduced to 1.

In [None]:
chunk_size = 6
tile_pair_chunks = {}

# Split each stack's tile-pair list into consecutive slices of at most
# `chunk_size` pairs (the last slice may be shorter)
for stack, tile_pair_list in tile_pair_lists.items():
    tile_pair_chunks[stack] = [tile_pair_list[start:start + chunk_size]
                               for start in range(0, len(tile_pair_list), chunk_size)]

# Preview chunks
tile_pair_chunks['lil_EM'][:3]

Now run the point match client (the function that actually extracts features and finds matches) over the chunks of tile pairs. `partial` binds the fixed arguments (stack, collection, render), so `WithPool`'s `map` only has to supply each chunk. Adjust `N_cores` as necessary. For large (4k x 4k) images, `N_cores` may have to be reduced to ~10.
#### ***COMPUTATIONALLY EXPENSIVE***

In [None]:
N_cores = 30

# Run the point match client over every chunk of every stack, N_cores chunks at a time
for stack, chunks in tile_pair_chunks.items():

    # Bind everything except the chunk so the pool only has to supply the chunk
    point_match_client_partial = partial(run_point_match_client,
                                         stack=stack,
                                         collection=match_collections[stack],
                                         render=render)

    with renderapi.client.WithPool(N_cores) as pool:
        pool.map(point_match_client_partial, chunks)

## 2 Basic Point Match Analysis
---
### 2.1 Aggregate Point Matches

In [None]:
# Aggregate the number of point matches found for every tile pair.
# One record per pair is collected in a list, and the dataframe is built once
# at the end. Row-by-row `df.loc[i, ...]` assignment is slow, and because `i`
# restarted at 0 for every stack it overwrote earlier stacks' rows.
tile_id_numbers = re.compile(r'\d+')  # raw string: '\d' is an invalid escape in a plain str
records = []

for stack, tile_pair_list in tile_pair_lists.items():

    # Group IDs depend only on the match collection, so fetch them once per stack
    groupIds = renderapi.pointmatch.get_match_groupIds(match_collections[stack], render=render)

    for pid, qid in tile_pair_list:

        tile_pair_matches = renderapi.pointmatch.get_matches_from_tile_to_tile(
                                match_collections[stack],
                                pgroup=groupIds[0],
                                pid=pid,
                                qgroup=groupIds[0],
                                qid=qid,
                                render=render)

        # Suss out column and row number from tileId
        # (assumes each tileId contains exactly two integers: column, then row)
        pc, pr = [int(n) for n in tile_id_numbers.findall(pid)]
        qc, qr = [int(n) for n in tile_id_numbers.findall(qid)]

        if tile_pair_matches:
            # Stuff matches into arrays
            p_matches = np.array(tile_pair_matches[0]['matches']['p']).T
            q_matches = np.array(tile_pair_matches[0]['matches']['q']).T
            # Ensure p and q matches check out (explicit raise: asserts vanish under -O)
            if len(p_matches) != len(q_matches):
                raise ValueError(f'Unequal p/q match counts for tile pair ({pid}, {qid})')
            N_matches = len(p_matches)
        else:
            # No matches for this tile pair
            N_matches = 0

        records.append({'pid': pid, 'qid': qid, 'N_matches': N_matches,
                        'pc': pc, 'pr': pr, 'qc': qc, 'qr': qr,
                        'stack': stack})

# Build the dataframe in one go. `stack` is appended as an extra column so that
# pairs from different stacks stay distinguishable.
df = pd.DataFrame(records, columns=['pid', 'qid', 'N_matches', 'pc', 'pr', 'qc', 'qr', 'stack'])
# Infer dtype for each column
df = df.infer_objects()
# Preview dataframe
df.head()

### 2.2 Plot East-West and North-South point matches

In [None]:
Nr = int(df[['pr', 'qr']].max().max() + 1)
Nc = int(df[['pc', 'qc']].max().max() + 1)
EW_matches = np.zeros((Nr, Nc))
NS_matches = np.zeros((Nr, Nc))
# Cells with fewer matches than this are masked out of the heatmaps
# (same value as matchMinNumInliers in the SIFT parameters above)
min_inliers = 7

# Each heatmap cell holds the match count between a tile and the next tile
# along its row (EW) or its column (NS). Diagonal pairs (pr != qr and
# pc != qc), which are also in the tile pair list, are skipped. The former
# bare `else` filed them as north-south matches, where they could overwrite
# the true vertical neighbour's count.
for row in df.itertuples(index=False):
    if row.pr == row.qr:
        EW_matches[row.pr, min(row.pc, row.qc)] = row.N_matches
    elif row.pc == row.qc:
        NS_matches[min(row.pr, row.qr), row.pc] = row.N_matches

# Plot point matches
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(14, 7))
heatmap_kws = dict(annot=False, cbar_kws={"shrink": 0.75}, annot_kws={'fontsize': 7},
                   fmt='.0f', robust=True, square=True)
sns.heatmap(EW_matches, ax=ax1, mask=EW_matches < min_inliers, **heatmap_kws);
sns.heatmap(NS_matches, ax=ax2, mask=NS_matches < min_inliers, **heatmap_kws);
ax1.set_title('EAST-WEST Matches');
ax2.set_title('NORTH-SOUTH Matches');

## 3 Refine Problem Tiles
---
### 3.1 Examine distribution of point matches across all tile pairs

In [None]:
n_bins = 64
hist, bins = np.histogram(df['N_matches'], bins=n_bins)
# Cumulative count at each bin edge (starting from 0), rescaled so its maximum
# equals the tallest histogram bar and both curves share one y-axis
cumsum = np.cumsum(hist)
cumsum = np.insert(cumsum, 0, 0)
cumsum = hist.max() * (cumsum - cumsum.min()) / (cumsum.max() - cumsum.min())
# Plot against the actual bin edges. np.histogram's bins start at the minimum
# match count, not at 0, so a linspace starting at 0 shifted the curve (and the
# percentile bar heights) whenever the fewest-matched pair had >0 matches.
cumsum_x = bins

# Percentile markers, each drawn up to the cumulative curve
bars = np.percentile(df['N_matches'], [10, 25, 50, 75, 100])
bar_heights = np.interp(bars, cumsum_x, cumsum)

fig, ax = plt.subplots(figsize=(12, 5))
sns.distplot(df['N_matches'], bins=n_bins, kde=False, rug=True, hist_kws={'rwidth': 0.9}, ax=ax);
ax.plot(cumsum_x, cumsum);
ax.vlines(bars, np.zeros(bars.size), bar_heights, colors='C1', linestyle='--');

#### Bottom 10% of point matches

In [None]:
_10pct_threshold = np.percentile(df['N_matches'], 10)
# Use <= so that pairs sitting exactly at the threshold are selected too.
# When at least 10% of pairs have no matches at all the threshold is 0, and a
# strict < returned an empty frame, missing exactly the worst tile pairs.
reruns = df[df['N_matches'] <= _10pct_threshold]
reruns.sort_values('N_matches')

### 3.2 Set more thorough SIFT parameters (`SIFTmaxScale` 0.25 → 0.50)

In [None]:
# Match-derivation (RANSAC / model fitting) settings, unchanged from the first pass
match_params = MatchDerivationParameters(
    matchIterations=None,
    matchMaxEpsilon=25,          # maximal alignment error
    matchMaxNumInliers=None,
    matchMaxTrust=None,
    matchMinInlierRatio=0.05,    # minimal inlier ratio
    matchMinNumInliers=7,        # minimal number of inliers
    matchModelType='RIGID',      # expected transformation
    matchRod=0.92,               # closest/next closest ratio
)

# SIFT feature-extraction settings; only SIFTmaxScale differs from the first pass
feature_params = FeatureExtractionParameters(
    SIFTfdSize=4,                # feature descriptor size
    SIFTmaxScale=0.50,           # (width/height *) maximum image size
    SIFTminScale=0.05,           # (width/height *) minimum image size
    SIFTsteps=6,                 # steps per scale octave
    clipWidth=500,
    clipHeight=500,
)

# Merge both parameter sets into one options object
sift_options = SiftPointMatchOptions(**{**vars(match_params),
                                        **vars(feature_params)})

### 3.3 (Re)Run PointMatchClient on Problematic Tile Pairs

In [None]:
chunk_size = 6
# [pid, qid] lists for every under-matched tile pair, in `reruns` order
# (converted once rather than once per chunk)
rerun_pairs = reruns[['pid', 'qid']].values.tolist()
tile_pair_reruns = [rerun_pairs[start:start + chunk_size]
                    for start in range(0, len(reruns), chunk_size)]

# Preview chunks
tile_pair_reruns

#### ***COMPUTATIONALLY EXPENSIVE***

In [None]:
# NOTE(review): `stack` is whatever the earlier per-stack loops left bound
# (the last stack iterated). This is only correct when all re-run pairs come
# from that one stack.
rerun_stack = stack

point_match_client_partial = partial(run_point_match_client,
                                     stack=rerun_stack,
                                     collection=match_collections[rerun_stack],
                                     # Pass the combined match + feature options built in 3.2,
                                     # not just the feature-extraction half (feature_params)
                                     sift_options=sift_options,
                                     render=render)

N_cores = 10
with renderapi.client.WithPool(N_cores) as pool:
    pool.map(point_match_client_partial, tile_pair_reruns)

## 3 Align Tiles
---
So now all the point matches have been generated, but the tiles have not been globally aligned. To do so, we use an additional package, also developed at the Allen Institute, [EM_aligner_python](https://github.com/AllenInstitute/EM_aligner_python).

### 3.1 Edit Montage json

In [None]:
# Start from the EM_aligner montage template and specialise it for each stack
template_path = '/home/rlane/iCAT-workflow/templates/montage.json'
with open(template_path) as template_file:
    montage_settings = json.load(template_file)

for stack in tile_pair_lists:

    # Stack to read, match collection to use, and stack to write the montage to
    montage_settings['input_stack'].update(owner=owner, project=project, name=stack)
    montage_settings['pointmatch'].update(owner=owner, name=match_collections[stack])
    montage_settings['output_stack'].update(owner=owner,
                                            project=f'{project}_montaged',
                                            name=f'{stack}_montaged')

    # One montage json per stack, inside that stack's project directory
    montage_json = project_dir.joinpath(f'{stack}/_montage.json')
    with montage_json.open('w') as json_data:
        json.dump(montage_settings, json_data, indent=2)

### 3.2 Change mongod Configuration
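This works for now but needs refinement. Setting `bindIp: 0.0.0.0` makes mongod listen on all network interfaces, which exposes the database to remote hosts. This is a security risk unless the port is firewalled or authentication is enabled.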

Normal `/etc/mongod.conf` configuration
```
# network interfaces
net:
  port: 27017
#  bindIp: 0.0.0.0
  bindIp: 127.0.0.1 # Listen to local interface only, comment to listen on all interfaces.
#  bindIp: 127.0.0.1 169.254.100.2 131.180.83.183 172.17.0.1
```

Set `bindIp` to `0.0.0.0` (listen on all interfaces) and restart mongod
```
# network interfaces
net:
  port: 27017
  bindIp: 0.0.0.0
#  bindIp: 127.0.0.1 # Listen to local interface only, comment to listen on all interfaces.
#  bindIp: 127.0.0.1 169.254.100.2 131.180.83.183 172.17.0.1
```
### 3.3 Run EM Aligner

In [None]:
EM_aligner_dir = Path('/home/catmaid/EM_aligner_python/').as_posix()
# NOTE(review): `stack` is left over from the montage-json loop above, so only
# the last stack's montage is aligned here.
input_json = project_dir.joinpath(f'{stack}/_montage.json').as_posix()

# Run EM_aligner from inside its own directory. Passing `cwd=` instead of
# os.chdir + chdir back leaves the notebook's working directory untouched even
# if the call raises. check=True stops here if alignment fails, so the export
# below does not run against a missing or partial `_montaged` stack.
subprocess.run(['python', '-m', 'EMaligner.EMaligner', '--input_json', input_json],
               cwd=EM_aligner_dir, check=True)

## 4 Export to CATMAID
---
### 4.1 Set up CATMAID Export Parameters

In [None]:
from renderapi.client import ArgumentParameters

class CatmaidBoxesParameters(ArgumentParameters):
    """
    Parameters for render's `render_catmaid_boxes.sh` client script.

    The instance attributes set here use camelCase names. `to_java_args()` is
    called on an instance of this class to build the script's command line.

    Parameters
    ----------
    stack : str
        Render stack to export.
    root_directory : str
        Directory the exported tiles are written under.
    height, width : int
        Box size (presumably in pixels — confirm).
    format : str
        Image format of the exported tiles, e.g. 'png'.
    max_level : int
        Stored as `maxLevel`; presumably the highest zoom level — confirm.
    host, port : optional
        Render web-service host and port. Default to the values from `render`.
    baseurl : str, optional
        Render web-service base URL. When given, it is used as `baseDataUrl`
        directly and `host`/`port` are not consulted.
    owner, project : str, optional
        Render owner and project. Default to the values from `render`.
    render : renderapi.connect.Render, optional
        Connection supplying defaults for host, port, owner and project.
    **kwargs
        Passed on to `ArgumentParameters`.
    """
    def __init__(self,
                 stack,
                 root_directory,
                 height=1024,
                 width=1024,
                 format='png',
                 max_level=9,
                 host=None, port=None,
                 baseurl=None,
                 owner=None, project=None,
                 render=None, **kwargs):

        super(CatmaidBoxesParameters, self).__init__(**kwargs)

        self.stack = stack
        self.rootDirectory = root_directory
        self.height = height
        self.width = width
        self.format = format
        self.maxLevel = max_level

        # Connection details not given explicitly are taken from `render`.
        # Previously the explicit host/port/baseurl arguments were silently
        # ignored, and leaving `render` at its default of None crashed.
        render_kwargs = render.make_kwargs() if render is not None else {}
        if baseurl is None:
            host = render_kwargs.get('host') if host is None else host
            port = render_kwargs.get('port') if port is None else port
            baseurl = renderapi.render.format_baseurl(host, port)
        self.baseDataUrl = baseurl
        self.owner = render_kwargs.get('owner') if owner is None else owner
        self.project = render_kwargs.get('project') if project is None else project

In [None]:
# Export the montaged version of `stack` (as left bound by the montage loop above)
catmaid_params = CatmaidBoxesParameters(
    stack=f'{stack}_montaged',
    root_directory=export_dir.as_posix(),
    project=f'{project}_montaged',
    render=render,
    height=1024,
    width=1024,
    format='png',
    max_level=9,
)
# Check java arguments
list(catmaid_params.to_java_args())

### 4.2 Call render script
`render_catmaid_boxes.sh`

In [None]:
# The CATMAID box-export script lives alongside render's other client scripts
client_scripts_dir = Path(render_connect_params['client_scripts'])
sh = client_scripts_dir / 'render_catmaid_boxes.sh'
java_args = list(catmaid_params.to_java_args())
p = subprocess.run([sh.as_posix(), b'0', *java_args])

### 4.3 Re-sort CATMAID tiles

In [None]:
montaged_dir = export_dir / f'{project}_montaged'

for stack in stacks:
    # Collect every exported tile. The glob is materialised before any file is
    # renamed, so the directory walk isn't mutated underneath itself.
    # NOTE(review): the glob covers all of export_dir, not a per-stack folder,
    # so with more than one stack every tile ends up under the first stack's
    # directory and later iterations find nothing. Confirm this is intended.
    tiles = list(export_dir.glob('**/1024x1024/**/[0-9]*.png'))

    # Relocate tiles in accordance with tile source convention 1
    for tile_format_7 in tiles:
        # The last four path components (extension dropped) are zoom/z/row/col.
        # Path.parts is used instead of splitting the string on '.', which
        # broke when any parent directory name contained a dot.
        zoom, z, row, col = tile_format_7.with_suffix('').parts[-4:]

        # Move tile to new directory
        tile_format_1 = export_dir / f'{stack}/{z}/{row}_{col}_{zoom}.png'
        # Make directory (and parent directory) if necessary
        tile_format_1.parent.mkdir(parents=True, exist_ok=True)
        tile_format_7.rename(tile_format_1)

# Clean up the directory tree left by the export. This is done once, after the
# loop: removing it inside the loop made the second stack's rmtree raise
# FileNotFoundError.
if montaged_dir.exists():
    rmtree(montaged_dir.as_posix())