# Stitch backscattered electron (BSE) images

Håkon Wiik Ånes

2019-12-18

Based on:
* https://scikit-image.org/docs/stable/auto_examples/transform/plot_register_translation.html#sphx-glr-auto-examples-transform-plot-register-translation-py
* https://github.com/scikit-image/skimage-tutorials/blob/master/lectures/adv3_panorama-stitching.ipynb

In [12]:
%matplotlib qt5

import os

import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as scn
import skimage.exposure as ske
import skimage.color as skc
import skimage.feature as skf
import skimage.measure as skm
import skimage.transform as skt

# Directory the '5000x<i>.tif' images are read from when loading the images.
# NOTE: machine-specific absolute path; adjust before running elsewhere.
datadir = '/home/hakon/phd/data/p/prover/0s/191217_ebsd/se_bse/scan'

## Parameters

In [93]:
# Pixel dimensions of the two image sizes handled in this notebook.
height_big = 1536
width_big = 2048
height_small = 768
width_small = 1024

# nm/px, for 5000x magnification
scale = (1/93.8) * 1e3

# Height in pixels of the banner cropped from the bottom of each image,
# looked up by the full image height given as a string.
banner_height = {
    str(height_small): 77,
    str(height_big): 155,
}

## Load images

In [95]:
# Read every image, convert it to grey scale, crop away the bottom banner and
# bring all images to a common shape, collecting the results in `imgs`.
imgs = []

# Common output shape: the big image size with its banner removed
img_size = (height_big - banner_height[str(height_big)], width_big)

for i in range(26):
    # Image number 1 is deliberately left out
    if i == 1:
        continue

    # Read image
    fname = '5000x' + str(i) + '.tif'
    img = plt.imread(os.path.join(datadir, fname))
    dtype_in = img.dtype.type  # Only needed by the optional conversion below

    # Turn into grey image (rgb2gray is the spelling supported by all
    # scikit-image versions; the rgb2grey alias was deprecated and removed)
    img_rgb = skc.rgba2rgb(img)
    img_grey = skc.rgb2gray(img_rgb)

    # Crop away bottom banner, whose height depends on the image height
    img_height = img_grey.shape[0]
    img_banner_height = banner_height[str(img_height)]
    img_cropped = img_grey[:-img_banner_height, :]

    # Upscale images acquired with lower scan resolution. The comparison is
    # against height_big, the height img_size is derived from (the previously
    # used name max_img_height is not defined anywhere in this notebook).
    if img_height != height_big:
        img_rescaled = skt.resize(
            img_cropped, output_shape=img_size, anti_aliasing=True)
    else:
        img_rescaled = img_cropped

    # Optional: rescale intensity back to the input dtype
#    img_rescaled = ske.rescale_intensity(img_rescaled, out_range=dtype_in)
#    img_rescaled = img_rescaled.astype(dtype_in)

    # Append to list
    imgs.append(img_rescaled)

## Find image shifts

In [368]:
# Estimate the translation between the overlapping strips of two neighbouring
# images and visualise the registration.
img1 = imgs[0]
img2 = imgs[1]

# Alternative pair:
#img1 = imgs[-4]
#img2 = imgs[0]

# Get overlapping subimages
relationship = 'left-right'
img_part_size = 600  # Either height or width, depending on relationship
if relationship == 'left-right':
    # Right-most columns of img1 against left-most columns of img2
    img1_part = img1[:, -img_part_size:]
    img2_part = img2[:, :img_part_size]
else:
    # Fail here with a clear message instead of a NameError further down
    raise ValueError(
        "Unsupported relationship {!r}; only 'left-right' is "
        "implemented".format(relationship))

# NOTE: register_translation is deprecated in newer scikit-image releases in
# favour of skimage.registration.phase_cross_correlation.
shift, error, diffphase = skf.register_translation(
    src_image=img2_part, target_image=img1_part, upsample_factor=100)

# Apply the estimated shift to img1_part in Fourier space
img1_part_shifted = scn.fourier_shift(np.fft.fftn(img1_part), shift)
img1_part_shifted = np.fft.ifftn(img1_part_shifted)

# Panels: img1 strip, img2 strip, shifted img1 strip, cross-correlation
fig = plt.figure(figsize=(8, 3))
ax1 = plt.subplot(1, 4, 1)
ax2 = plt.subplot(1, 4, 2, sharex=ax1, sharey=ax1)
ax3 = plt.subplot(1, 4, 3, sharex=ax1, sharey=ax1)
ax4 = plt.subplot(1, 4, 4)

ax1.imshow(img1_part, cmap='gray')
ax1.set_axis_off()

ax2.imshow(img2_part, cmap='gray')
ax2.set_axis_off()

# The inverse FFT result is complex; only the real part is displayed
ax3.imshow(img1_part_shifted.real, cmap='gray')
ax3.set_axis_off()

# Cross-correlation between the img2 strip and the shifted img1 strip,
# with the zero-shift position moved to the image centre by fftshift
image_product = np.fft.fft2(img2_part) * np.fft.fft2(img1_part_shifted).conj()
cc_image = np.fft.fftshift(np.fft.ifft2(image_product))
ax4.imshow(cc_image.real)
ax4.set_axis_off()

print(shift)
print(error)

[  47.29 -217.16]
0.09626753743024664


In [394]:
# Register img0 and img2 against the centre image img1 using their
# overlapping strips, and build the translation transforms offset01 and
# offset21 from the resulting shifts.
img0 = imgs[-4]
img1 = imgs[0]
img2 = imgs[1]

overlap = 600  # Width in pixels of the strips compared between neighbours

# Full image width in pixels, taken from the image itself (this name was
# previously used without being defined anywhere in the notebook)
width = img1.shape[1]

img0_part = img0[:, -overlap:]   # Right-most columns of img0
img1_part1 = img1[:, :overlap]   # Left-most columns of img1
img1_part2 = img1[:, -overlap:]  # Right-most columns of img1
img2_part = img2[:, :overlap]    # Left-most columns of img2

shift01, error01, phasediff01 = skf.register_translation(
    src_image=img1_part1, target_image=img0_part, upsample_factor=100)
# The shift is measured between strips; add the horizontal distance between
# the strip origins (img0 strip starts at column width - overlap, img1 strip
# at column 0) to get a shift between full images
shift01[1] = -width + overlap + shift01[1]
# shift is (row, col); the transform takes its translation reversed
offset01 = skt.SimilarityTransform(translation=shift01[::-1])

shift21, error21, phasediff21 = skf.register_translation(
    src_image=img1_part2, target_image=img2_part, upsample_factor=100)
# Same correction with opposite sign: the img1 strip starts at column
# width - overlap and the img2 strip at column 0
shift21[1] = width - overlap + shift21[1]
offset21 = skt.SimilarityTransform(translation=shift21[::-1])

# Alternative registration direction (img2 against img1), not in use:
#shift12, error12, phasediff12 = skf.register_translation(
#    src_image=img2_part, target_image=img1_part2, upsample_factor=100)
#shift12[1] += overlap - width
#offset12 = skt.SimilarityTransform(translation=shift12[::-1])

# Show the four strips side by side for visual inspection
fig, ax = plt.subplots(ncols=4)
ax[0].imshow(img0_part)
ax[1].imshow(img1_part1)
ax[2].imshow(img1_part2)
ax[3].imshow(img2_part)

print(shift01, shift21)

[  105.39 -1449.03] [ -47.29 1665.16]


## Find image coordinates in master image

In [400]:
# Determine the extent of the master image from the transformed corners of
# all three images.
row, col = img1.shape[:2]

# Image corners as (col, row) pairs
corners = np.array([
    [0, 0],
    [0, row],
    [col, 0],
    [col, row],
])

# Corners of the neighbouring images after applying their offsets
corners01 = offset01(corners)
#corners12 = offset12(corners)
corners21 = offset21(corners)

# Stack with the untransformed corners of the centre image
#all_corners = np.vstack((corners01, corners12, corners))
all_corners = np.vstack((corners01, corners21, corners))

# Bounding box of all corners
corner_min = all_corners.min(axis=0)
corner_max = all_corners.max(axis=0)

# Reverse the extent into (rows, cols) order and round up to whole pixels
output_shape = corner_max - corner_min
output_shape = np.ceil(output_shape[::-1]).astype(int)

print(output_shape)

[1534 5163]


## Place images in master image

In [404]:
# Warp the three images into the common master image frame and build a
# boolean footprint mask for each of them.

# Translation moving the minimum corner of the bounding box to the origin
offset1 = skt.SimilarityTransform(translation=-corner_min)

# cval=-1 fills pixels outside an image's footprint with the sentinel -1, so
# that the `!= -1` test below separates footprint from background. Without
# it warp fills with 0 and every mask is True everywhere.
img1_warped = skt.warp(
    img1, offset1.inverse, output_shape=output_shape, cval=-1)
img1_mask = (img1_warped != -1)
img1_warped[~img1_mask] = 0  # Background back to 0 so it adds nothing

transform01 = (offset01 + offset1).inverse
img0_warped = skt.warp(
    img0, transform01, output_shape=output_shape, cval=-1)
img0_mask = (img0_warped != -1)
img0_warped[~img0_mask] = 0

transform21 = (offset21 + offset1).inverse
img2_warped = skt.warp(
    img2, transform21, output_shape=output_shape, cval=-1)
img2_mask = (img2_warped != -1)
img2_warped[~img2_mask] = 0

# NOTE(review): `compare` is not defined in the visible cells; presumably the
# plotting helper from the panorama stitching tutorial -- confirm it is
# defined before running this cell.
compare(img0_warped, img1_warped, img2_warped, figsize=(12, 10));

In [405]:
# Average the warped images wherever their footprints overlap.
merged = img0_warped + img1_warped + img2_warped

# Number of images covering each pixel. The masks are boolean, and `+` on
# boolean arrays is a logical OR, so cast to int to get an actual count.
overlap = (
    img0_mask.astype(int) + img1_mask.astype(int) + img2_mask.astype(int))

# Divide by at least 1 so pixels covered by no image stay 0 instead of
# becoming NaN from a division by zero
normalized = merged / np.maximum(overlap, 1)

## Inspect result

In [406]:
# Show the normalized master image in grey scale in a new figure
plt.figure()
plt.imshow(normalized, cmap='gray')

<matplotlib.image.AxesImage at 0x7f816439ffd0>