# iCAT Correlate
---

#### Packages

In [1]:
from pathlib import Path
import re
from functools import partial
import warnings
from itertools import product

from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import seaborn as sns
import altair as alt

from shapely.geometry import box
from shapely import affinity
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon

from skimage import img_as_uint
from skimage.color import rgb2grey
from skimage.measure import ransac
from skimage.transform import AffineTransform as AffineSkimage
from skimage.external.tifffile import TiffWriter

import renderapi
from renderapi.transform import AffineModel as AffineRender
from renderapi.tilespec import TileSpec
from renderapi.layout import Layout
from renderapi.client import (SiftPointMatchOptions,
                              MatchDerivationParameters,
                              FeatureExtractionParameters,
                              ArgumentParameters)

from icatapi.render_pandas import *
from icatapi.correlate import *

#### Settings

In [2]:
# pandas display settings
# -----------------------
display_settings = {
    'display.max_rows': 20,      # truncate long DataFrames
    'display.width': None,       # no fixed console width
    'display.max_colwidth': 20,  # truncate long strings (tileIds, imageUrls, ...)
}
for option_name, option_value in display_settings.items():
    pd.set_option(option_name, option_value)

## Set up `render-ws` environment
---

In [3]:
# `render` project parameters
# ---------------------------
owner = 'rlane'
project = '20191230_RL010'

# Create a renderapi.connect.Render object
# ----------------------------------------
render_connect_params = {
    'host': 'sonic',
    'port': 8080,
    'owner': owner,
    'project': project,
    'client_scripts': '/home/catmaid/render/render-ws-java-client/src/main/scripts',
    'memGB': '2G'
}
render = renderapi.connect(**render_connect_params)

# Infer stack and section info
# ----------------------------
stacks = renderapi.render.get_stacks_by_owner_project(render=render)
# Stacks are classified by name only: anything without 'EM' in its name is FM
stacks_EM = [stack for stack in stacks if 'EM' in stack]
stacks_FM = [stack for stack in stacks if 'EM' not in stack]
# TODO: refactor this
stacks_2_correlate = {
    'anchor': 'lil_EM_montaged',  # aligned
    'sailor': 'big_EM_overlaid',  # overlaid
}

# Output
# ------
out = f"""\
all stacks............ {stacks}
EM stacks............. {stacks_EM}
FM stacks............. {stacks_FM}
stacks to correlate... {stacks_2_correlate}
...
"""
print(out)

# Create stacks DataFrame
# ------------------------
df_project = create_stacks_DataFrame(stacks=list(stacks_2_correlate.values()),
                                     render=render).dropna(axis=1)
# Preview up to 3 tiles per stack
# (`sample` raises if asked for more rows than a group has)
df_project.groupby('stack')\
          .apply(lambda x: x.sample(min(3, len(x))))

all stacks............ ['insulin_correlated', 'hoechst_correlated', 'big_EM_correlated', 'mm_EM', 'big_EM_overlaid', 'hoechst_overlaid', 'insulin_overlaid', 'lil_EM_montaged', 'lil_EM', 'insulin', 'hoechst', 'big_EM']
EM stacks............. ['big_EM_correlated', 'mm_EM', 'big_EM_overlaid', 'lil_EM_montaged', 'lil_EM', 'big_EM']
FM stacks............. ['insulin_correlated', 'hoechst_correlated', 'hoechst_overlaid', 'insulin_overlaid', 'insulin', 'hoechst']
stacks to correlate... {'anchor': 'lil_EM_montaged', 'sailor': 'big_EM_overlaid'}
...



Unnamed: 0_level_0,Unnamed: 1_level_0,tileId,z,width,height,minIntensity,maxIntensity,stack,sectionId,imageRow,imageCol,stageX,stageY,imageUrl,transforms
stack,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
big_EM_overlaid,2887,aai_big_EM-S001-...,1.0,4096.0,4096.0,31900.0,33450.0,big_EM_overlaid,S001,1,2,-573.418,6295.522,file:///long_ter...,"[M=[[1.000000,0...."
big_EM_overlaid,2954,aad_big_EM-S007-...,7.0,4096.0,4096.0,31900.0,33450.0,big_EM_overlaid,S007,2,2,3021.804,6186.913,file:///long_ter...,"[M=[[1.000000,0...."
big_EM_overlaid,2974,aal_big_EM-S008-...,8.0,4096.0,4096.0,31900.0,33450.0,big_EM_overlaid,S008,0,0,3256.122,6576.072,file:///long_ter...,"[M=[[1.000000,0...."
lil_EM_montaged,2014,adr_lil_EM-S007-...,7.0,4096.0,4096.0,30150.0,33750.0,lil_EM_montaged,S007,14,0,2769.386,6190.437,file:///long_ter...,"[M=[[0.787618,-0..."
lil_EM_montaged,2822,akd_lil_EM-S009-...,9.0,4096.0,4096.0,30150.0,33750.0,lil_EM_montaged,S009,3,7,4080.869,6353.268,file:///long_ter...,"[M=[[0.796012,-0..."
lil_EM_montaged,1462,aha_lil_EM-S005-...,5.0,4096.0,4096.0,30150.0,33750.0,lil_EM_montaged,S005,8,8,1710.57,6308.429,file:///long_ter...,"[M=[[0.782531,-0..."


## Create mini-montages
---

### Find overlapping tiles

Find the raw (unmontaged) lil EM tiles that overlap with each big EM tile.
Will store this information in a `DataFrame` resembling
```python
[1]:  df_overlapping

[1]: section  z  big EM tileId                  overlapping lil EM tileIds
     S001     1  'aaa_big_EM-S001-00001x00001'  ['aaa_lil_EM-S001-00012x00013','aab_lil_EM-S001-00013x00013',...]
     S001     1  'aab_big_EM-S001-00002x00001'  ['aaq_lil_EM-S001-00027x00012','aar_lil_EM-S001-00026x00012',...]
     S001     1  'aad_big_EM-S001-00001x00000'  ['aek_lil_EM-S001-00025x00006','ael_lil_EM-S001-00024x00006',...]
     ...
     S004     4  'aad_big_EM-S004-00003x00003'  ['aio_lil_EM-S004-00013x00000','aip_lil_EM-S004-00012x00000',...]
```

For this to be done correctly, we need to look at the tiles as they are in "real" physical space (i.e. based on their FOV and stage position (all in microns)). Since tiles are imported into `render-ws` in "pixel" space and the pixel size gets omitted from the specification, this is unfortunately complicated. Nonetheless, we can create some functionality to transform into real physical space.

In [None]:
# Hardcoded pixel sizes [nm/px]
# -----------------------------
# TODO: find a better way -- probably involves adding this data to the stacks
#       DataFrame (the pixel size is not part of the `render-ws` tile spec)
ps_a, ps_s = (4.85973748747,     # anchor (lil EM) pixel size, nm/px
              33.5663173837785)  # sailor (big EM) pixel size, nm/px

In [None]:
# Initialize DataFrame of overlapping tiles (one row per big EM tile)
overlapping_cols = ['stack', 'z', 'sectionId', 'tileId',
                    'imageRow', 'imageCol', 'width', 'height',
                    'stageX', 'stageY', 'imageUrl']
df_overlapping = df_project.loc[(df_project['stack'] == stacks_2_correlate['sailor']),
                                overlapping_cols]\
                           .reset_index(drop=True)\
                           .copy()
# Columns for the overlapping lil EM tileIds and their combined `render`-space bbox
df_overlapping['overlappingTileIds'] = None
for col in ['minx', 'miny', 'maxx', 'maxy']:
    df_overlapping[col] = np.nan

# The anchor (lil EM) tiles do not change between iterations -- select them once
df_anchor = df_project.loc[df_project['stack'] == stacks_2_correlate['anchor']]


def _physical_bbox(tile, pixelsize):
    """Return a `shapely` box of a tile's physical extent [um].

    The box is centred on the tile's stage position (`stageX`, `stageY`, um)
    and sized by its `width` x `height` [px] times `pixelsize` [nm/px].
    """
    x0, y0 = tile['stageX'], tile['stageY']  # um
    w = tile['width'] * pixelsize/1e3        # px --> um
    h = tile['height'] * pixelsize/1e3       # px --> um
    return box(x0-w/2, y0-h/2, x0+w/2, y0+h/2)


# For each big EM tile, find all overlapping lil EM tiles
for i, tile_s in tqdm(df_overlapping.iterrows(),
                      total=len(df_overlapping)):

    bbox_s = _physical_bbox(tile_s, ps_s)

    # Collect all the overlapping lil EM tiles and their bounds in `render` space
    overlapping_tiles = []
    bounds = []
    # Brute-force search over the lil EM tiles within the same section
    for _, tile_a in df_anchor.loc[df_anchor['z'] == tile_s['z']].iterrows():
        bbox_a = _physical_bbox(tile_a, ps_a)
        if bboxes_overlap(bbox_s.bounds, bbox_a.bounds):
            overlapping_tiles.append(tile_a['tileId'])
            # bbox of the tile specification in `render` space
            # (used below as [minx, miny, maxx, maxy])
            bounds.append(renderapi.tilespec.get_tile_spec(stack=stacks_2_correlate['anchor'],
                                                           tile=tile_a['tileId'],
                                                           render=render).bbox)

    # Add overlapping tiles and the enclosing bbox to the DataFrame;
    # tiles without any overlap keep NaN bounds (dropped by later `dropna`)
    df_overlapping.at[i, 'overlappingTileIds'] = overlapping_tiles
    if bounds:
        bounds = np.array(bounds)
        minx, miny = bounds[:, :2].min(axis=0)
        maxx, maxy = bounds[:, 2:].max(axis=0)
        df_overlapping.loc[i, ['minx', 'miny', 'maxx', 'maxy']] = [minx, miny, maxx, maxy]

# Preview
df_overlapping.sample(min(10, len(df_overlapping)))

#### Preview set of overlapping tiles

In [None]:
# Preview a random big EM tile together with the lil EM tiles overlapping it
sample = df_overlapping.dropna().sample(1)
overlapping_ids = np.stack(sample['overlappingTileIds'])  # shape (1, n_overlapping)
out = (f"big EM tile......... {sample['tileId'].iloc[0]}\n"
       f"Overlapping tiles... {overlapping_ids.size}\n")
print(out)
with np.printoptions(linewidth=150):
    print(overlapping_ids)

### Write mini-montages to disk

* Render minimontage image
* Scale down minimontage to 4096px width
* Write downsampled minimontage to disk
* Create minimontage stack DataFrame

`TileSpec` parameter | value or origin
-------------------- | ---------------
`z`                  | big EM `z`
`sectionId`          | big EM `sectionId`
`tileId`             | big EM `tileId`
`width`              | ~4096
`height`             | `s` * (`maxy` - `miny`)
`imageRow`           | big EM `imageRow`
`imageCol`           | big EM `imageCol`
`minint`             | `0` (the minimontage is written as a 16-bit greyscale image)
`maxint`             | `65535`
`imageUrl`           | `file://` URI of the minimontage `.tif` written to disk
`transforms`         | `M=[[1/s, 0.0],[0.0, 1/s]] B=[x0, y0]`

where `s = 4096 / width` is the scale factor by which the minimontage was downsampled and (`x0`, `y0`) is the upper left coordinate (`minx`, `miny`) of the set of overlapping tiles.

In [None]:
# TODO: parallelize this shiz

# Name of the minimontage stack
stack_minimontage = 'mm_EM'
# Width [px] the minimontages are downsampled to
mm_width = 4096
# Re-render minimontages whose .tif already exists on disk? Rendering is slow,
# so by default existing files are reused.
overwrite_minimontages = False

# Columns of the minimontage DataFrame (keyword arguments for `TileSpec`)
minimontage_cols = ['stack', 'z', 'sectionId', 'tileId',
                    'width', 'height', 'imageRow', 'imageCol',
                    'minint', 'maxint', 'imageUrl', 'tforms']
minimontage_records = []
minimontage_index = []

# Loop through big EM tiles that overlap with lil EM tiles
df_overlapping_valid = df_overlapping.dropna()
for i, tile_s in tqdm(df_overlapping_valid.iterrows(),
                      total=len(df_overlapping_valid)):

    # Minimontage bbox: encloses all overlapping lil EM tiles (`render` space)
    x = tile_s['minx']
    y = tile_s['miny']
    width = tile_s['maxx'] - tile_s['minx']
    height = tile_s['maxy'] - tile_s['miny']
    scale = np.round(mm_width/width, 5)

    # Set minimontage filepath: a `stack_minimontage` directory next to the
    # directory holding the big EM image
    fp_dir = Path(tile_s['imageUrl'].split('://')[-1]).parents[1] / stack_minimontage
    fp_dir.mkdir(parents=False, exist_ok=True)
    fp = fp_dir / (f"{stack_minimontage}-"
                   f"{tile_s['sectionId']}-"
                   f"{tile_s['imageCol']:05d}x"
                   f"{tile_s['imageRow']:05d}.tif")

    # Render the minimontage and write it to disk, unless it already exists
    # (otherwise the tile spec below would point at a missing file)
    if overwrite_minimontages or not fp.exists():
        image = renderapi.image.get_bb_image(stack=stacks_2_correlate['anchor'],
                                             z=tile_s['z'],
                                             x=x,
                                             y=y,
                                             width=width,
                                             height=height,
                                             scale=scale,
                                             img_format='tif',
                                             render=render)
        # Convert to grey scale 16-bit image
        with warnings.catch_warnings():      # Suppress precision
            warnings.simplefilter('ignore')  # loss warnings
            image = img_as_uint(rgb2grey(image))
        # Save to disk with `TiffWriter`
        with TiffWriter(fp.as_posix()) as tif:
            tif.save(image)

    # Map minimontage pixels back to `render` space: undo the downsampling,
    # then translate to the bbox origin (x, y)
    A = AffineRender(M00=1/scale, B0=x,
                     M11=1/scale, B1=y)

    minimontage_records.append({
        'stack': stack_minimontage,
        'z': tile_s['z'],
        'sectionId': tile_s['sectionId'],
        'tileId': fp.stem,
        'width': np.floor(scale * width),
        'height': np.floor(scale * height),
        'imageRow': tile_s['imageRow'],
        'imageCol': tile_s['imageCol'],
        # Display range: full range of the 16-bit minimontage image
        'minint': 0,
        'maxint': 65535,
        'imageUrl': fp.as_uri(),
        'tforms': [A],
    })
    minimontage_index.append(i)

# Build the minimontage DataFrame in one go (index matches `df_overlapping`)
df_minimontage = pd.DataFrame(minimontage_records,
                              index=minimontage_index,
                              columns=minimontage_cols)

# Preview
df_minimontage.sample(min(10, len(df_minimontage)))

## Upload mini-montages to `render-ws`
---

### Create mini-montage stack

In [None]:
# Build one `TileSpec` per minimontage tile
tile_specs = [TileSpec(**tile.to_dict()) for _, tile in df_minimontage.iterrows()]

# Create the minimontage stack, import the tile specs, then mark it COMPLETE
renderapi.stack.create_stack(stack_minimontage, render=render)
renderapi.client.import_tilespecs(stack_minimontage, tile_specs, render=render)
renderapi.stack.set_stack_state(stack_minimontage, 'COMPLETE', render=render)

### Inspect mini-montage stack

#### Tile map

In [None]:
# Specify stacks to plot: the minimontage stack over the montaged lil EM stack
stacks_2_plot = [stack_minimontage, 'lil_EM_montaged']

# Fetch tile specs; take the sections from the plotted stacks themselves so
# that every layer iterated below has an axis
df_stacks = create_stacks_DataFrame(stacks_2_plot,
                                    render=render)
sections_2_plot = df_stacks['sectionId'].unique().tolist()

# Set up figure: one panel per section, one colour per stack
ncols = len(sections_2_plot)
fig, axes = plt.subplots(ncols=ncols, squeeze=False,
                         figsize=(8*ncols, 8))
axmap = dict(zip(sections_2_plot, axes.flat))
cmap = dict(zip(stacks_2_plot, sns.color_palette(n_colors=len(stacks_2_plot))))

# Iterate through layers
for sectionId, layer in tqdm(df_stacks.groupby('sectionId')):
    # Collect all tiles in each layer to determine bounds
    boxes = []
    ax = axmap[sectionId]

    # Loop through tilesets within each layer
    for stack, tileset in layer.groupby('stack'):
        for i, tile in tileset.iterrows():
            # Outline of the raw image tile in its own pixel coordinates
            b = box(0, 0, tile['width'], tile['height'])
            # Apply the tile's transforms in order; shapely expects the 2D
            # affine as [a, b, d, e, xoff, yoff]
            for tform in tile['transforms']:
                A = (tform.M[:2, :2].ravel().tolist() +
                     tform.M[:2,  2].ravel().tolist())
                b = affinity.affine_transform(b, A)
            boxes.append(b)
            # Draw the transformed outline as a translucent polygon
            xy = np.array(b.exterior.xy).T
            ax.add_patch(Polygon(xy, color=cmap[stack], alpha=0.2))

    # Axis aesthetics
    ax.set_title(sectionId)
    ax.set_xlabel('X [px]')
    ax.set_ylabel('Y [px]')
    # Fit the view to all tile corners in this layer
    bounds = np.swapaxes([b.exterior.xy for b in boxes], 1, 2).reshape(-1, 2)
    ax.set_xlim(bounds[:, 0].min(), bounds[:, 0].max())
    ax.set_ylim(bounds[:, 1].min(), bounds[:, 1].max())
    ax.invert_yaxis()
    ax.set_aspect('equal')

#### Render images

In [None]:
# Stacks DataFrame
df_stacks = create_stacks_DataFrame(['big_EM_overlaid', 'mm_EM'],
                                    render=render)

# Specify stacks and sections
stacks_2_plot = df_stacks['stack'].unique().tolist()
sections_2_plot = df_stacks['sectionId'].unique().tolist()

# Set up figure: one row per stack, one column per section
nrows = len(stacks_2_plot)
ncols = len(sections_2_plot)
fig, axes = plt.subplots(nrows, ncols, squeeze=False,
                         figsize=(8*ncols, 8*nrows))
axmap = dict(zip(product(stacks_2_plot, sections_2_plot), axes.flat))

# Iterate through tilesets
for (stack, sectionId), tileset in tqdm(df_stacks.groupby(['stack', 'sectionId'])):

    ax = axmap[(stack, sectionId)]
    # Show the tile at (imageRow 1, imageCol 2) if this tileset has one,
    # otherwise fall back to its first tile
    at_1_2 = (tileset['imageRow'] == 1) & (tileset['imageCol'] == 2)
    tile = tileset.loc[at_1_2].iloc[0] if at_1_2.any() else tileset.iloc[0]
    tileId = tile['tileId']
    scale = 1024 / tile['width']  # render ~1024 px wide
    image = renderapi.image.get_tile_image_data(stack=stack,
                                                tileId=tileId,
                                                normalizeForMatching=False,
                                                scale=scale,
                                                render=render)
    # Plot
    ax.imshow(image)
    # Axis aesthetics
    ax.set_title(f"{stack} | {sectionId}\n{tileId}")
    ax.set_xlabel('X [px]')
    ax.set_ylabel('Y [px]')

## Generate correlative point matches
---

### Generate (low mag EM, mini-montage) tile pairs

In [None]:
# Columns of the tile pairs DataFrame (p = big EM tile, q = minimontage)
pairs_cols = ['p.stack', 'q.stack', 'z',
              'p.groupId', 'q.groupId',
              'p.id', 'q.id',
              'imageRow', 'imageCol']

# Pair each minimontage (q) with the big EM tile (p) it was built around,
# cross-referencing section (z) and layout position (imageRow, imageCol)
pair_records = []
for i, q_tile in df_minimontage.iterrows():
    p_tile = df_overlapping.loc[(df_overlapping['z'] == q_tile['z']) &
                                (df_overlapping['imageRow'] == q_tile['imageRow']) &
                                (df_overlapping['imageCol'] == q_tile['imageCol'])].iloc[0]
    pair_records.append({
        'p.stack': p_tile['stack'],
        'q.stack': q_tile['stack'],
        'z': q_tile['z'],
        'p.groupId': p_tile['sectionId'],
        'q.groupId': q_tile['sectionId'],
        'p.id': p_tile['tileId'],
        'q.id': q_tile['tileId'],
        'imageRow': q_tile['imageRow'],
        'imageCol': q_tile['imageCol'],
    })

# Build the DataFrame in one go (keeping the minimontage index) instead of row by row
df_pairs = pd.DataFrame(pair_records, index=df_minimontage.index, columns=pairs_cols)

# Preview
df_pairs.sample(min(10, len(df_pairs)))

### Run `pointMatchClient` on (low mag EM, mini-montage) tile pairs
##### Set `SIFT` & `RANSAC` parameters

In [None]:
# `RANSAC` (match derivation) parameters
ransac_kwargs = dict(
    matchIterations=None,
    matchMaxEpsilon=25,        # maximal alignment error
    matchMaxNumInliers=None,
    matchMaxTrust=None,
    matchMinInlierRatio=0.05,  # minimal inlier ratio
    matchMinNumInliers=7,      # minimal number of inliers
    matchModelType='AFFINE',   # expected transformation
    matchRod=0.92,             # closest/next closest ratio
)
match_params = MatchDerivationParameters(**ransac_kwargs)

# `SIFT` (feature extraction) parameters
sift_kwargs = dict(
    SIFTfdSize=8,          # feature descriptor size
    SIFTmaxScale=0.20,     # (width/height *) maximum image size
    SIFTminScale=0.05,     # (width/height *) minimum image size
    SIFTsteps=6,           # steps per scale octave
    clipWidth=500,
    clipHeight=500,
)
feature_params = FeatureExtractionParameters(**sift_kwargs)

# Merge both parameter sets (SIFT values win on key clashes) into one options object
sift_options = SiftPointMatchOptions(**dict(match_params.__dict__,
                                            **feature_params.__dict__))

##### `pointMatchClient` wrapper for parallelized processing

In [None]:
def run_point_match_client(tile_pair_chunk, p_stack, q_stack, collection, sift_options, render):
    """Run `renderapi.client.pointMatchClient` on a chunk of tile pairs.

    Thin wrapper whose first positional argument is the chunk of tile pairs,
    so that it can be bound with `functools.partial` and mapped over a process
    pool. The p and q tiles of each pair may come from different stacks.

    Parameters
    ----------
    tile_pair_chunk : list of (p tileId, q tileId) tuples
        Tile pairs to match.
    p_stack, q_stack : str
        Stacks containing the p and q tiles respectively.
    collection : str
        Point match collection in which the matches are stored.
    sift_options : SiftPointMatchOptions
        Combined SIFT & RANSAC parameters.
    render : renderapi.connect.Render
        `render-ws` connection.
    """
    client_kwargs = dict(stack=p_stack,
                         stack2=q_stack,
                         collection=collection,
                         tile_pairs=tile_pair_chunk,
                         sift_options=sift_options,
                         render=render)
    renderapi.client.pointMatchClient(**client_kwargs)

#### ***COMPUTATIONALLY EXPENSIVE***

##### Run `pointMatchClient` on `N_cores`

In [None]:
# Set number of cores and batch size
N_cores = 10
batch_size = 10

# Set match collection
match_collection = f"{project}_minimontage_points"

# The `pointMatchClient` partial only depends on the (fixed) stacks, collection
# and options, so build it once rather than once per batch.
# Guarded because `iloc[0]` fails when there are no tile pairs.
if len(df_pairs):
    point_match_client_partial = partial(run_point_match_client,
                                         p_stack=df_pairs['p.stack'].iloc[0],
                                         q_stack=df_pairs['q.stack'].iloc[0],
                                         collection=match_collection,
                                         sift_options=sift_options,
                                         render=render)

# Loop through sections
for z, tile_pairs in tqdm(df_pairs.groupby('z')):

    # Group tile pairs into batches of `batch_size`
    grouping = np.arange(len(tile_pairs)) // batch_size
    for i, batch in tqdm(tile_pairs.groupby(grouping), leave=False):

        # One pool task per tile pair; each task gets a single-pair chunk
        tp_batch = [[tuple(tp)] for tp in batch[['p.id', 'q.id']].values.tolist()]

        # Run `pointMatchClient` on `N_cores`
        with renderapi.client.WithPool(N_cores) as pool:
            pool.map(point_match_client_partial, tp_batch)

### Analyze point matches

#### Collect point matches

In [None]:
# Initialize point matches DataFrame from the tile pairs. `imageRow`/`imageCol`
# come from the tile layout (the same values used to match tiles later on)
# rather than being parsed out of the tileId strings.
matches_cols = ['p.stack', 'q.stack', 'z', 'p.groupId', 'q.groupId', 'p.id', 'q.id',
                'imageRow', 'imageCol']
df_matches = df_pairs.loc[:, matches_cols].copy()
# Add columns for the number of point matches and the p, q match coordinates
df_matches['N'] = np.nan
df_matches['p.matches'] = None
df_matches['q.matches'] = None

# Iterate through tile pairs
for i, tp in tqdm(df_matches.iterrows(), total=len(df_matches)):

    # Get json point matches data for each tile pair
    matches_json = renderapi.pointmatch.get_matches_from_tile_to_tile(match_collection,
                                                                      pgroup=tp['p.groupId'],
                                                                      pid=tp['p.id'],
                                                                      qgroup=tp['q.groupId'],
                                                                      qid=tp['q.id'],
                                                                      render=render)
    if matches_json:
        # First match entry for this tile pair; `p`/`q` are 2 x N coordinate
        # arrays (presumably x row, y row)
        p_matches = matches_json[0]['matches']['p']
        q_matches = matches_json[0]['matches']['q']
        df_matches.loc[i, 'N'] = np.array(p_matches).shape[1]
    # No point matches exist for this tile pair
    else:
        p_matches, q_matches = [], []
        df_matches.loc[i, 'N'] = 0
    df_matches.at[i, 'p.matches'] = p_matches
    df_matches.at[i, 'q.matches'] = q_matches

# Preview
df_matches.sample(min(6, len(df_matches)))

#### Heatmap of point matches

In [None]:
# Heatmap data: drop the bulky match coordinates and blank out tiles with 0 matches
source = (df_matches.drop(columns=['p.matches', 'q.matches'])
                    .assign(N=lambda df: df['N'].where(df['N'] != 0)))

# Shared encoding: one cell per layout position
base = alt.Chart(source).encode(x='imageCol:O', y='imageRow:O')
# Coloured cells ...
heatmap = (base.mark_rect()
               .encode(color=alt.Color('N:Q'))
               .properties(width=80, height=80))
# ... labelled with the number of matches
text = base.mark_text(baseline='middle').encode(text='N:Q')

# One heatmap per section; the dot in `p.groupId` must be escaped for Altair
alt.layer(heatmap, text, data=source).facet(column=r'p\.groupId:N')

## Upload correlated stacks to `render-ws`
---

### Create DataFrame of correlated tile specifications
This amounts to applying a new set of affine transformations to the overlaid stacks.

The exact transforms, and the exact order in which they are applied, are critical here. Note also that the point-match-based transform must be inverted, and that the initial big EM transform is discarded.

The exact mapping from each big EM tile to each mini-montage is a combination of the
* Affine transform derived from `SIFT` point matches
* Scaling factor used to render the mini-montage.

In [None]:
# TODO: the correlation transforms built below are known to be wrong;
#       in particular, the way A2 is formed (M01/M10 swapped) is not justified

In [None]:
# Initialize DataFrame of correlated tile specifications
# NOTE: this rebinds `stacks_2_correlate` -- a dict of anchor/sailor stacks in
#       the setup cell -- to a list of overlaid stacks. Earlier cells that index
#       it by key ('anchor', 'sailor') will fail if re-run after this cell.
stacks_2_correlate = [stack for stack in stacks if 'overlaid' in stack]
df_stacks_correlated = create_stacks_DataFrame(stacks_2_correlate,
                                               render=render)
# Add `tforms` column
df_stacks_correlated['tforms'] = None
# Rename min, max intensity columns to the `TileSpec` keyword names
df_stacks_correlated = df_stacks_correlated.rename(columns={'minIntensity': 'minint',
                                                            'maxIntensity': 'maxint'})

# RANSAC settings: same values as matchMinNumInliers / matchMaxEpsilon in the
# SIFT & RANSAC parameters cell (not sure how optimal these parameters are)
ransac_min_samples = 7
ransac_residual_threshold = 25

# Loop through CLEM tiles
for i, tile in tqdm(df_stacks_correlated.iterrows(),
                    total=len(df_stacks_correlated)):

    # Get correlative (low mag EM, minimontage) point matches for this tile position
    matches = df_matches.loc[(df_matches['z'] == tile['z']) &
                             (df_matches['imageRow'] == tile['imageRow']) &
                             (df_matches['imageCol'] == tile['imageCol'])]

    # `df_matches` has a row for every tile pair, including pairs with 0 point
    # matches, and RANSAC cannot fit a model to fewer than `min_samples`
    # correspondences --> correlation not possible for this tile
    n_matches = matches['N'].fillna(0).iloc[0] if len(matches) else 0
    if n_matches < ransac_min_samples:
        df_stacks_correlated.at[i, 'tforms'] = []
        continue

    # Centering translation transform (linear part copied from the tile's first transform)
    A1 = AffineRender()
    A1.M = tile['transforms'][0].M

    # Overlay transform (FM stacks only)
    # TODO: this swaps M01 and M10 of the overlay transform, i.e. transposes its
    #       linear part; suspected bug. A plain copy would be
    #       `A2 = AffineRender(); A2.M = tile['transforms'][1].M`
    if tile['stack'] in stacks_FM:
        tform_overlay = tile['transforms'][1]
        A2 = AffineRender(M00=tform_overlay.M00,
                          M01=tform_overlay.M10,
                          M10=tform_overlay.M01,
                          M11=tform_overlay.M11,
                          B0=tform_overlay.B0,
                          B1=tform_overlay.B1)

    # Point-match-based affine transform, p (big EM) --> q (minimontage), via `RANSAC`
    src = np.array(matches['p.matches'].iloc[0]).T
    tgt = np.array(matches['q.matches'].iloc[0]).T
    model, inliers = ransac((src, tgt),
                            model_class=AffineSkimage,
                            min_samples=ransac_min_samples,
                            residual_threshold=ransac_residual_threshold)
    A3 = AffineRender()
    A3.M = model.params

    # Minimontage transform (undoes the downsampling and places it in `render` space)
    A4 = df_minimontage.loc[(df_minimontage['z'] == tile['z']) &
                            (df_minimontage['imageRow'] == tile['imageRow']) &
                            (df_minimontage['imageCol'] == tile['imageCol']),
                            'tforms'].iloc[0][0]

    # Add transforms to DataFrame
    if tile['stack'] in stacks_FM:
        df_stacks_correlated.at[i, 'tforms'] = [A1, A2, A3, A4]
    else:  # big EM overlaid (hopefully)
        df_stacks_correlated.at[i, 'tforms'] = [A1, A3, A4]

# Preview
df_stacks_correlated.groupby('stack')\
                    .apply(lambda x: x.sample(min(3, len(x))))\
                    .drop('transforms', axis=1)

#### Preview CLEM tile correlated transforms

In [None]:
# Pick a random tile and gather its counterparts (same z, row and col)
# across all correlated stacks
sample_tile = df_stacks_correlated.sample(1).iloc[0]
same_position = ((df_stacks_correlated['z'] == sample_tile['z']) &
                 (df_stacks_correlated['imageRow'] == sample_tile['imageRow']) &
                 (df_stacks_correlated['imageCol'] == sample_tile['imageCol']))
CLEM_tile = df_stacks_correlated.loc[same_position]

# Print each stack name followed by its correlative transforms
for i, tile in CLEM_tile.iterrows():
    out = f"{tile['stack']}\n" + "".join(f"... {tform}\n" for tform in tile['tforms'])
    print(out)

CLEM_tile

### Import to `render-ws`

In [None]:
# Collect names of the correlated stacks that get created
stacks_correlated = []

# Overlaid stacks to import, taken from the correlated tile specs themselves
# (independent of what `stacks_2_correlate` is currently bound to)
stacks_overlaid = df_stacks_correlated['stack'].unique().tolist()
# Only tiles with a non-empty list of correlative transforms can be imported
has_tforms = df_stacks_correlated['tforms'].map(bool)

# Iterate through stacks
for stack_overlaid in tqdm(stacks_overlaid):

    # Set correlated stack name
    stack_correlated = stack_overlaid.replace('overlaid', 'correlated')

    # Create `TileSpec`s
    df_stack = df_stacks_correlated.loc[(df_stacks_correlated['stack'] == stack_overlaid) &
                                        has_tforms]
    tile_specs = [TileSpec(**tile.to_dict()) for _, tile in df_stack.iterrows()]

    # Don't create an empty stack when no tile could be correlated
    if not tile_specs:
        warnings.warn(f"No correlated tiles in {stack_overlaid}; "
                      f"not creating {stack_correlated}")
        continue
    stacks_correlated.append(stack_correlated)

    # Create stack, import TileSpecs and set stack state to complete
    renderapi.stack.create_stack(stack_correlated,
                                 render=render)
    renderapi.client.import_tilespecs(stack_correlated,
                                      tile_specs,
                                      render=render)
    renderapi.stack.set_stack_state(stack_correlated,
                                    'COMPLETE',
                                    render=render)

### Inspect correlated stacks

#### Tile map

In [None]:
# Specify stacks to plot
stacks_2_plot = stacks_correlated

# Fetch tile specs; take the sections from the plotted stacks themselves so
# that every layer iterated below has an axis
df_stacks = create_stacks_DataFrame(stacks_2_plot,
                                    render=render)
sections_2_plot = df_stacks['sectionId'].unique().tolist()

# Set up figure: one panel per section, one colour per stack.
# `squeeze=False` keeps `axes` a 2D array even for a single section.
ncols = len(sections_2_plot)
fig, axes = plt.subplots(ncols=ncols, squeeze=False, figsize=(8*ncols, 8))
axmap = dict(zip(sections_2_plot, axes.flat))
cmap = dict(zip(stacks_2_plot, sns.color_palette(n_colors=len(stacks_2_plot))))

# Iterate through layers
for sectionId, layer in tqdm(df_stacks.groupby('sectionId')):
    # Collect all tiles in each layer to determine bounds
    boxes = []
    ax = axmap[sectionId]

    # Loop through tilesets within each layer
    for stack, tileset in layer.groupby('stack'):
        for i, tile in tileset.iterrows():
            # Outline of the raw image tile in its own pixel coordinates
            b = box(0, 0, tile['width'], tile['height'])
            # Apply the tile's transforms in order; shapely expects the 2D
            # affine as [a, b, d, e, xoff, yoff]
            for tform in tile['transforms']:
                A = (tform.M[:2, :2].ravel().tolist() +
                     tform.M[:2,  2].ravel().tolist())
                b = affinity.affine_transform(b, A)
            boxes.append(b)
            # Draw the transformed outline as a translucent polygon
            xy = np.array(b.exterior.xy).T
            ax.add_patch(Polygon(xy, color=cmap[stack], alpha=0.2))

    # Axis aesthetics
    ax.set_title(sectionId)
    ax.set_xlabel('X [px]')
    ax.set_ylabel('Y [px]')
    # Fit the view to all tile corners in this layer
    bounds = np.swapaxes([b.exterior.xy for b in boxes], 1, 2).reshape(-1, 2)
    ax.set_xlim(bounds[:, 0].min(), bounds[:, 0].max())
    ax.set_ylim(bounds[:, 1].min(), bounds[:, 1].max())
    ax.invert_yaxis()
    ax.set_aspect('equal')

#### Render images

In [None]:
# Specify stacks to plot
stacks_2_plot = stacks_correlated

# Fetch tile specs; take the sections from the plotted stacks themselves so
# that every (stack, section) group iterated below has an axis
df_stacks = create_stacks_DataFrame(stacks_2_plot,
                                    render=render)
sections_2_plot = df_stacks['sectionId'].unique().tolist()

# Set up figure: one row per stack, one column per section.
# `squeeze=False` keeps `axes` a 2D array even for a single stack/section.
nrows = len(stacks_2_plot)
ncols = len(sections_2_plot)
fig, axes = plt.subplots(nrows, ncols, squeeze=False,
                         figsize=(8*ncols, 8*nrows))
axmap = dict(zip(product(stacks_2_plot, sections_2_plot), axes.flat))

# Iterate through tilesets
for (stack, sectionId), tileset in tqdm(df_stacks.groupby(['stack', 'sectionId'])):

    ax = axmap[(stack, sectionId)]
    # Fetch the whole section, downsampled so its longest side is ~1024 px
    z = tileset['z'].iloc[0]
    bounds = renderapi.stack.get_bounds_from_z(stack=stack,
                                               z=z,
                                               render=render)
    scale = 1024 / np.max([bounds['maxX'] - bounds['minX'],
                           bounds['maxY'] - bounds['minY']])
    image = renderapi.image.get_section_image(stack=stack,
                                              z=z,
                                              scale=scale,
                                              maxTileSpecsToRender=30,
                                              render=render)
    # Plot in `render` coordinates; image row 0 is drawn at minY, and the
    # y-axis is inverted below so minY ends up at the top
    extent = [bounds['minX'],  # left
              bounds['maxX'],  # right
              bounds['minY'],  # bottom
              bounds['maxY']]  # top
    ax.imshow(image, extent=extent, origin='lower')
    # Axis aesthetics
    ax.set_title(f"{stack} | {sectionId}")
    ax.set_xlabel('X [px]')
    ax.set_ylabel('Y [px]')
    ax.invert_yaxis()