# Stitching Map
---
### Imports

In [None]:
# Libraries needed
import os
import subprocess
from functools import partial
import re
from pathlib import Path
import json
from pprint import pprint

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from skimage.io import imread
from skimage.transform import pyramid_reduce

import renderapi

### Initialize render

In [None]:
# Names of the render-ws owner/project/stack/match collection used throughout
owner = 'rlane'  # replace with your name
project = 'iCAT_align'
stack = 'lil_EM'
match_collection = 'lil_EM_points'
project_dir = Path('/long_term_storage/rlane/SECOM/iCAT_sample_data')

# Keyword arguments forwarded to renderapi.connect, which builds the
# Render connection object used by every renderapi call below
render_connect_params = dict(
    host='sonic',
    port=8080,
    owner=owner,
    project=project,
    client_scripts='/home/catmaid/render/render-ws-java-client/src/main/scripts',
    memGB='2G',
)

render = renderapi.connect(**render_connect_params)

### Initialize Grid

In [None]:
# Grid settings: top-left corner as (row, col) and grid extent as (rows, cols)
r0, c0 = (0, 16)
Nr, Nc = (4, 4)

# Every tile spec in the stack
stack_tile_specs = renderapi.tilespec.get_tile_specs_from_stack(stack, render=render)

# Mean tile width and height over the stack, truncated to int
w = int(np.mean([spec.width for spec in stack_tile_specs]))
h = int(np.mean([spec.height for spec in stack_tile_specs]))

# Layout row/column of every tile in the stack (same order as stack_tile_specs)
rows = [spec.layout.imageRow for spec in stack_tile_specs]
cols = [spec.layout.imageCol for spec in stack_tile_specs]

# Keep the tiles whose (row, col) lies inside the Nr x Nc grid starting at (r0, c0)
grid_tile_specs = [
    spec
    for spec, r, c in zip(stack_tile_specs, rows, cols)
    if (r0 <= r <= (r0 + Nr - 1)) and (c0 <= c <= (c0 + Nc - 1))
]

### Show Grid

In [None]:
# Make a heatmap of the stack: 1 marks tiles inside the selected grid, 0 the rest.
# Layout indices are 0-based, so the map needs max + 1 rows/cols to include
# the last row and column of the stack.
stack_map = np.zeros((max(rows) + 1, max(cols) + 1))
# Half-open slices covering rows r0..r0+Nr-1 and cols c0..c0+Nc-1, i.e. the
# same inclusive range used to select grid_tile_specs
stack_map[r0:r0 + Nr, c0:c0 + Nc] = 1
fig, ax = plt.subplots(figsize=(7, 7))
ax = sns.heatmap(stack_map, vmin=0, vmax=1, ax=ax);

# Adjust x ticks: move ticks from cell centres to the integer positions
# named by their labels, then re-apply the labels centred
xticks = [t.get_text() for t in ax.get_xticklabels()]
plt.xticks(np.array(xticks).astype(int), ha='left');
ax.set_xticklabels(xticks, ha='center');
# Adjust y ticks the same way
yticks = [t.get_text() for t in ax.get_yticklabels()]
plt.yticks(np.array(yticks).astype(int), va='top');
ax.set_yticklabels(yticks, va='center');
ax.set_aspect('equal')

### Get Tile Pair List
This is done by running `renderapi.client.tilePairClient` on the whole stack and whittling the result down to the pairs whose tiles both lie within the grid.

In [None]:
# Get positional bounds of image stack
stack_bounds = renderapi.stack.get_stack_bounds(stack, render=render)

# Generate tile pairs for input into point match generator
tile_pair_data = renderapi.client.tilePairClient(stack,
                                                 minz=stack_bounds['minZ'],
                                                 maxz=stack_bounds['maxZ'],
                                                 render=render)

# Generate list of tile pairs    
stack_tile_pairs = [(tp['p']['id'], tp['q']['id']) for tp in tile_pair_data['neighborPairs']]

# Filter tile pairs to include only those in grid
grid_tileIds = [tile_spec.tileId for tile_spec in grid_tile_specs]
grid_tile_pairs = []
for tile_pair in stack_tile_pairs:
    # Both tiles in pair must be in grid
    if (tile_pair[0] in grid_tileIds) and (tile_pair[1] in grid_tileIds):
        grid_tile_pairs.append(tile_pair)

# Show tile pairs
print(f'{len(grid_tile_pairs)} / {len(stack_tile_pairs)} Tile Pairs in Grid')
grid_tile_pairs

### Get Matches

In [None]:
# Get group IDs once: the result does not depend on the tile pair.
# NOTE(review): only the first group is used for both p and q, which presumably
# assumes all grid tiles share one section (z) — confirm for multi-section stacks.
groupIds = renderapi.pointmatch.get_match_groupIds(match_collection, render=render)

# Point matches keyed by ((p_row, p_col), (q_row, q_col)). Values are arrays whose
# rows are match points (column 0 = x, column 1 = y as used when plotting);
# pairs without matches get an empty (0, 2) array.
p_matches = {}
q_matches = {}

for tile_pair in grid_tile_pairs:

    tile_pair_matches = renderapi.pointmatch.get_matches_from_tile_to_tile(
                            match_collection,
                            pgroup=groupIds[0],
                            pid=tile_pair[0],
                            qgroup=groupIds[0],
                            qid=tile_pair[1],
                            render=render)

    # Suss out column and row number from tileId
    # NOTE(review): assumes each tileId contains exactly two integers, column first
    pc, pr = [int(i) for i in re.findall(r'\d+', tile_pair[0])]
    qc, qr = [int(i) for i in re.findall(r'\d+', tile_pair[1])]

    if tile_pair_matches:
        # Stuff matches into arrays (transposed so each row is one match point)
        p_matches_arr = np.array(tile_pair_matches[0]['matches']['p']).T
        q_matches_arr = np.array(tile_pair_matches[0]['matches']['q']).T
        # Ensure p and q matches check out (explicit check; asserts vanish under -O)
        if len(p_matches_arr) != len(q_matches_arr):
            raise ValueError(
                f"Mismatched p/q match counts for tile pair: "
                f"{tile_pair[0]} <--> {tile_pair[1]}")
    else:
        # No matches for this tile pair
        print(f"No matches for tile pair:  {tile_pair[0]} <--> {tile_pair[1]}")
        p_matches_arr = np.empty((0, 2))
        q_matches_arr = np.empty((0, 2))

    p_matches[((pr, pc), (qr, qc))] = p_matches_arr
    q_matches[((pr, pc), (qr, qc))] = q_matches_arr

### Make Mosaic

In [None]:
# Down sample factor
dsf = 2
Npx = 2048

# Initialize mosaic
mosaic = np.full(fill_value=2**15, shape=(int(Nr*w//dsf),
                                          int(Nc*h//dsf)))

# Loop through tile specs
for tile_spec in grid_tile_specs:
    
    # Get row and column of image tile
    r = tile_spec.layout.imageRow
    c = tile_spec.layout.imageCol
    
    # Create thumbnail image
    imageUrl = tile_spec.to_dict()['mipmapLevels']['0']['imageUrl']
    tile_img = imread(imageUrl)
    clipped = np.clip(tile_img, a_min=tile_spec.minint, a_max=tile_spec.maxint)
    thumb = pyramid_reduce(clipped, downscale=dsf, multichannel=False)
    
    # Fill in mosaic
    ii, ji = ((r-r0) * h//dsf, (c-c0) * w//dsf)
    ij, jj = (ii + h//dsf, ji + w//dsf)
    mosaic[ii:ij, ji:jj] = thumb

### Make Stitch Figure

In [None]:
# Make plot
fig, ax = plt.subplots(figsize=(14, 11))
ax.imshow(mosaic, cmap='Greys_r');

# Reduce visible stitch lines by factor
rf = 10  # (1 = no reduction)

# Make stitch lines. p_matches and q_matches share the same keys
# ((pr, pc), (qr, qc)), so pair them by key rather than by dict order.
for ((pr, pc), (qr, qc)), pms in p_matches.items():
    qms = q_matches[((pr, pc), (qr, qc))]

    # Skip tile pairs that have no matches (stored as empty containers)
    if len(pms) == 0:
        continue
    pms = np.asarray(pms)
    qms = np.asarray(qms)

    # Each row is (x_p, x_q, y_p, y_q) in mosaic pixels: match coordinates are
    # down-sampled by dsf and offset by the tile's position, using the same
    # (c - c0) * w // dsf and (r - r0) * h // dsf offsets as the mosaic placement
    lines = np.vstack((pms[:, 0] / dsf + (pc - c0) * w // dsf,
                       qms[:, 0] / dsf + (qc - c0) * w // dsf,
                       pms[:, 1] / dsf + (pr - r0) * h // dsf,
                       qms[:, 1] / dsf + (qr - r0) * h // dsf)).T[::rf]
    # reshape(-1, 2) alternates [x_p, x_q] and [y_p, y_q] rows, so
    # ax.plot(x0, y0, x1, y1, ...) draws one segment per match
    ax.plot(*lines.reshape(-1, 2), color='#00FF77');

ax.axis('off');