In [2]:
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import cv2
import numpy as np
from functools import partial, reduce
from scipy import interpolate
import copy
import json


In [3]:
# ======================================================
# Aligner Class (handles warping based on morph points)
# ======================================================
class Aligner:
    def __init__(self):
        self.matrices = []      # Transformation matrices (for few points)
        self.n = 0              # Number of fixed (morphological) point pairs
        self.interp_order = (1, 1)
        self.min_interp_points = 5  # Use spline interpolation only when at least 5 points
        self._pairs = []        # List of (source, target) pairs (both np.array([x,y]))
        self._inv_maps = None   # Cached inverse mapping functions

    @property
    def pairs(self):
        if self.n == 0:
            return np.empty((0, 4), dtype=np.float32)
        return np.array([np.hstack((s, t)) for s, t in self._pairs], dtype=np.float32)

    @property
    def inv_maps(self):
        if self._inv_maps is not None:
            return self._inv_maps
        elif self.n < self.min_interp_points:
            # With too few points, return identity maps that output numpy arrays.
            def id_map_x(x, y):
                return np.array(x, dtype=np.float32)
            def id_map_y(x, y):
                return np.array(y, dtype=np.float32)
            return [id_map_x, id_map_y]
        else:
            return self.get_inv_interp(self.pairs[:, 0:2], self.pairs[:, 2:], order=self.interp_order)

    def __call__(self, im, dsize=None):
        if dsize is None:
            dsize = (im.shape[1], im.shape[0])
        im_map = self.get_im_map(dsize=dsize)
        return im_map(im)

    def add_point_pair(self, source, target):
        self._pairs.append((np.array(source, dtype=np.float32), np.array(target, dtype=np.float32)))
        self.n += 1
        if self.n < self.min_interp_points:
            self.matrices = self.compute_simple_matrix()
        else:
            self._inv_maps = None

    def remove_point(self, idx):
        if 0 <= idx < self.n:
            self._pairs.pop(idx)
            self.n -= 1
            if self.n < self.min_interp_points:
                self.matrices = self.compute_simple_matrix()
            else:
                self._inv_maps = None

    def compute_simple_matrix(self):
        if self.n == 0:
            return []
        elif self.n == 1:
            src, tgt = self._pairs[0]
            return [self.get_translation_matrix(src, tgt)]
        elif self.n == 2:
            src1, tgt1 = self._pairs[0]
            src2, tgt2 = self._pairs[1]
            return [self.get_rotation_matrix(tgt1, src2, tgt2)]
        elif self.n == 3:
            src = np.array([p[0] for p in self._pairs], dtype=np.float32)
            tgt = np.array([p[1] for p in self._pairs], dtype=np.float32)
            M = cv2.getAffineTransform(src, tgt)
            return [M]
        else:
            return []

    def get_im_map(self, dsize, n=None):
        n = n if n is not None else self.n
        if n < self.min_interp_points:
            if self.matrices:
                M = self.chain_matrices(self.matrices)
                return lambda im: cv2.warpAffine(im, M, dsize, flags=cv2.INTER_NEAREST)
            else:
                return lambda im: im
        else:
            x = np.arange(dsize[0])
            y = np.arange(dsize[1])
            fx = self.inv_maps[0](x, y).astype(np.float32).T
            fy = self.inv_maps[1](x, y).astype(np.float32).T
            return lambda im: cv2.remap(im, fx, fy, interpolation=cv2.INTER_NEAREST)

    @staticmethod
    def get_translation_matrix(source, target):
        dx, dy = target - source
        return np.array([[1, 0, dx], [0, 1, dy]], dtype=np.float32)

    @staticmethod
    def get_rotation_matrix(center, source, target):
        w1 = source - center
        w2 = target - center
        dot = np.dot(w1, w2)
        det = np.linalg.det(np.vstack([w1, w2]))
        norm1 = np.linalg.norm(w1)
        norm2 = np.linalg.norm(w2)
        scale = norm2 / norm1 if norm1 != 0 else 1.0
        angle = -np.arctan2(det, dot) * 180 / np.pi
        return cv2.getRotationMatrix2D(tuple(center), angle, scale)

    @staticmethod
    def chain_matrices(matrices):
        if not matrices:
            return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32)
        def to_3x3(mat):
            M = np.eye(3, dtype=np.float32)
            M[:2, :] = mat
            return M
        M = reduce(lambda A, B: np.matmul(A, to_3x3(B)), matrices, np.eye(3, dtype=np.float32))
        return M[:2, :]

    def get_inv_interp(self, source, target, order=(1, 1)):
        if target.size == 0:
            def id_map(x, y):
                return np.array(x, dtype=np.float32)
            return [id_map, id_map]
        n = target.shape[0]
        # Use linear interpolation when too few points exist.
        ord_val = 1 if n < 6 else min(3, int(np.sqrt(n)) - 1)
        x_max, y_max = target.max(axis=0)
        lims = {"xb": -5. * x_max, "xe": 5. * x_max, "yb": -5. * y_max, "ye": 5. * y_max}
        smooth_val = 10  # Increased smoothing parameter to reduce warnings.
        maps = []
        for j in range(2):
            tck = interpolate.bisplrep(
                target[:, 0], target[:, 1], source[:, j],
                kx=ord_val, ky=ord_val, s=smooth_val, **lims
            )
            maps.append(partial(interpolate.bisplev, tck=tck))
        return maps

# ======================================================
# Tkinter Application with Pre‑Morph Modes, Zoom, and Metadata
# ======================================================
class AlignApp(tk.Tk):
    def __init__(self):
        super().__init__()
        self.title("EBSD/Microstructure Image Aligner")
        self.geometry("1200x900")
        # Metadata dictionary to record operations.
        self.metadata = {'resize': [], 'translate': [], 'morph': []}
        self.meta_loaded = False

        # Top frame for buttons and instructions.
        self.btn_frame = tk.Frame(self)
        self.btn_frame.pack(side=tk.TOP, fill=tk.X)
        tk.Button(self.btn_frame, text="Load Microstructure", command=self.load_microstructure).pack(side=tk.LEFT, padx=5)
        tk.Button(self.btn_frame, text="Load Grain Boundaries", command=self.load_grain_boundaries).pack(side=tk.LEFT, padx=5)
        tk.Button(self.btn_frame, text="Save Result", command=self.save_result).pack(side=tk.LEFT, padx=5)
        tk.Button(self.btn_frame, text="Load Meta", command=self.load_meta).pack(side=tk.LEFT, padx=5)
        tk.Button(self.btn_frame, text="Reset Points", command=self.reset_points).pack(side=tk.LEFT, padx=5)
        tk.Label(self.btn_frame, text="r+left mouse: resize   t+left mouse: translate").pack(side=tk.LEFT, padx=10)
        
        # Canvas with scrollbars.
        self.canvas = tk.Canvas(self, bg="black")
        self.canvas.pack(fill=tk.BOTH, expand=True)
        self.hbar = tk.Scrollbar(self, orient=tk.HORIZONTAL, command=self.canvas.xview)
        self.hbar.pack(side=tk.BOTTOM, fill=tk.X)
        self.vbar = tk.Scrollbar(self, orient=tk.VERTICAL, command=self.canvas.yview)
        self.vbar.pack(side=tk.RIGHT, fill=tk.Y)
        self.canvas.configure(xscrollcommand=self.hbar.set, yscrollcommand=self.vbar.set)
        
        # Images and state.
        self.micro_img = None     # Microstructure image (BGR, real size)
        self.gb_img = None        # Grain boundary image (BGR)
        self.overlay = None       # Composite overlay (BGR)
        self.photo = None         # PhotoImage for display
        self.aligner = Aligner()  # For morph points (only active after pre‑morph adjustments)
        self.point_radius = 5
        self.zoom_factor = 1.0
        self.dragging = False
        self.source_pt = None     # Starting point of current drag (in image coordinates)
        self.current_mode = "morph"  # Modes: "morph" (default), "resize", "translate"
        
        # Bind mouse and key events.
        self.canvas.bind("<ButtonPress-1>", self.on_left_press)
        self.canvas.bind("<B1-Motion>", self.on_left_drag)
        self.canvas.bind("<ButtonRelease-1>", self.on_left_release)
        self.canvas.bind("<Button-3>", self.on_right_click)
        self.canvas.bind("<MouseWheel>", self.on_mousewheel)  # Windows
        self.canvas.bind("<Button-4>", self.on_mousewheel)    # Linux scroll up
        self.canvas.bind("<Button-5>", self.on_mousewheel)    # Linux scroll down
        self.bind("<KeyPress>", self.on_key_press)
        self.bind("<KeyRelease>", self.on_key_release)
        
    # -------------- Image Loading --------------
    def load_microstructure(self):
        path = filedialog.askopenfilename(title="Select Microstructure Image")
        if not path: 
            return
        self.micro_img = cv2.imread(path, cv2.IMREAD_COLOR)
        if self.micro_img is None:
            messagebox.showerror("Error", "Could not load microstructure image.")
            return
        self.update_overlay()
        
    def load_grain_boundaries(self):
        path = filedialog.askopenfilename(title="Select Grain Boundary Image")
        if not path: 
            return
        gb = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if gb is None:
            messagebox.showerror("Error", "Could not load grain boundary image.")
            return
        # Determine which pixel (white or black) is the minority.
        white = np.sum(gb > 127)
        black = np.sum(gb <= 127)
        if white < black:
            binary = (gb > 127).astype(np.uint8) * 255
        else:
            binary = (gb <= 127).astype(np.uint8) * 255
        self.gb_img = np.zeros((binary.shape[0], binary.shape[1], 3), dtype=np.uint8)
        self.gb_img[binary > 0] = (0, 0, 255)
        self.update_overlay()
    
    def compute_union_size(self):
        if self.micro_img is None:
            return (800, 600)
        h1, w1 = self.micro_img.shape[:2]
        if self.gb_img is not None:
            h2, w2 = self.gb_img.shape[:2]
            return (max(w1, w2), max(h1, h2))
        return (w1, h1)
    
    # -------------- Overlay and Display --------------
    def update_overlay(self):
        if self.micro_img is None:
            return
        union_w, union_h = self.compute_union_size()
        base = np.zeros((union_h, union_w, 3), dtype=np.uint8)
        # Place microstructure at top-left.
        mh, mw = self.micro_img.shape[:2]
        base[0:mh, 0:mw] = self.micro_img.copy()
        # If GB image is loaded, apply current morph transform.
        if self.gb_img is not None:
            warped_gb = self.aligner(self.gb_img, dsize=(union_w, union_h))
            mask = cv2.cvtColor(warped_gb, cv2.COLOR_BGR2GRAY)
            base[mask > 0] = (0, 0, 255)
        # Draw fixed morph points.
        for pair in self.aligner._pairs:
            pt = tuple(map(int, pair[1]))
            cv2.circle(base, pt, self.point_radius+2, (255, 255, 255), 2)
        self.overlay = base
        self.show_image()
        
    def show_image(self):
        if self.overlay is None:
            return
        disp = cv2.resize(self.overlay, (0, 0), fx=self.zoom_factor, fy=self.zoom_factor, interpolation=cv2.INTER_NEAREST)
        im_rgb = cv2.cvtColor(disp, cv2.COLOR_BGR2RGB)
        im_pil = Image.fromarray(im_rgb)
        self.photo = ImageTk.PhotoImage(image=im_pil)
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, anchor=tk.NW, image=self.photo)
        self.canvas.config(scrollregion=(0, 0, disp.shape[1], disp.shape[0]))
    
    # -------------- Key and Mouse Event Handlers --------------
    def on_key_press(self, event):
        # Pre-morph adjustments are allowed only if no morph points exist.
        if self.aligner.n > 0:
            return
        if event.char.lower() == 'r':
            self.current_mode = "resize"
        elif event.char.lower() == 't':
            self.current_mode = "translate"
            
    def on_key_release(self, event):
        # Nothing to do on key release; we'll finalize on mouse release.
        pass
            
    def on_left_press(self, event):
        self.dragging = True
        self.source_pt = (event.x / self.zoom_factor, event.y / self.zoom_factor)
        
    def on_left_drag(self, event):
        if not self.dragging:
            return
        current_pt = (event.x / self.zoom_factor, event.y / self.zoom_factor)
        union_size = self.compute_union_size()
        temp = np.zeros((union_size[1], union_size[0], 3), dtype=np.uint8)
        if self.micro_img is not None:
            mh, mw = self.micro_img.shape[:2]
            temp[0:mh, 0:mw] = self.micro_img.copy()
        if self.gb_img is not None:
            if self.current_mode == "resize":
                dx = current_pt[0] - self.source_pt[0]
                scale = 1 + dx / 100.0  # Adjust sensitivity as needed.
                M = np.array([[scale, 0, 0],
                              [0, scale, 0]], dtype=np.float32)
                temp_gb = cv2.warpAffine(self.gb_img, M, union_size, flags=cv2.INTER_NEAREST)
            elif self.current_mode == "translate":
                dx = current_pt[0] - self.source_pt[0]
                dy = current_pt[1] - self.source_pt[1]
                M = np.array([[1, 0, dx],
                              [0, 1, dy]], dtype=np.float32)
                temp_gb = cv2.warpAffine(self.gb_img, M, union_size, flags=cv2.INTER_NEAREST)
            else:  # Morph mode.
                temp_aligner = copy.deepcopy(self.aligner)
                temp_aligner.add_point_pair(self.source_pt, current_pt)
                temp_gb = temp_aligner(self.gb_img, dsize=union_size)
            mask = cv2.cvtColor(temp_gb, cv2.COLOR_BGR2GRAY)
            temp[mask > 0] = (0, 0, 255)
        for pair in self.aligner._pairs:
            pt = tuple(map(int, pair[1]))
            cv2.circle(temp, pt, self.point_radius+2, (255, 255, 255), 2)
        if self.current_mode == "morph":
            cv2.line(temp, tuple(map(int, self.source_pt)), tuple(map(int, current_pt)), (0, 255, 0), 2)
        self.overlay = temp
        self.show_image()
        
    def on_left_release(self, event):
        if not self.dragging:
            return
        release_pt = (event.x / self.zoom_factor, event.y / self.zoom_factor)
        union_size = self.compute_union_size()
        # Finalize based on current mode.
        if self.current_mode == "resize":
            dx = release_pt[0] - self.source_pt[0]
            scale = 1 + dx / 100.0
            M = np.array([[scale, 0, 0],
                          [0, scale, 0]], dtype=np.float32)
            # Permanently update gb_img.
            self.gb_img = cv2.warpAffine(self.gb_img, M, union_size, flags=cv2.INTER_NEAREST)
            self.metadata['resize'].append({'start': self.source_pt, 'end': release_pt, 'scale': scale})
            # Reset aligner so that morph coordinates are in the new system.
            self.aligner = Aligner()
        elif self.current_mode == "translate":
            dx = release_pt[0] - self.source_pt[0]
            dy = release_pt[1] - self.source_pt[1]
            M = np.array([[1, 0, dx],
                          [0, 1, dy]], dtype=np.float32)
            self.gb_img = cv2.warpAffine(self.gb_img, M, union_size, flags=cv2.INTER_NEAREST)
            self.metadata['translate'].append({'start': self.source_pt, 'end': release_pt, 'dx': dx, 'dy': dy})
            self.aligner = Aligner()
        else:  # Morph mode.
            self.aligner.add_point_pair(self.source_pt, release_pt)
            self.metadata['morph'] = self.aligner.pairs.tolist()
        self.current_mode = "morph"
        self.dragging = False
        self.source_pt = None
        self.update_overlay()
        
    def on_right_click(self, event):
        if self.aligner.n == 0:
            return
        pts = np.array([pair[1] for pair in self.aligner._pairs])
        click_pt = np.array([event.x / self.zoom_factor, event.y / self.zoom_factor])
        dists = np.linalg.norm(pts - click_pt, axis=1)
        idx = int(np.argmin(dists))
        if dists[idx] < 15:
            self.aligner.remove_point(idx)
            self.metadata['morph'] = self.aligner.pairs.tolist()
            self.update_overlay()
            
    def on_mousewheel(self, event):
        if hasattr(event, 'delta'):
            if event.delta > 0:
                self.zoom_factor *= 1.1
            else:
                self.zoom_factor /= 1.1
        else:
            if event.num == 4:
                self.zoom_factor *= 1.1
            elif event.num == 5:
                self.zoom_factor /= 1.1
        self.show_image()
    
    # -------------- Saving and Meta File --------------
    def save_result(self):
        if self.micro_img is None or self.gb_img is None:
            messagebox.showwarning("Warning", "Both images must be loaded.")
            return
        union_w, union_h = self.compute_union_size()
        warped_gb = self.aligner(self.gb_img, dsize=(union_w, union_h))
        base = np.zeros((union_h, union_w, 3), dtype=np.uint8)
        mh, mw = self.micro_img.shape[:2]
        base[0:mh, 0:mw] = self.micro_img.copy()
        mask = cv2.cvtColor(warped_gb, cv2.COLOR_BGR2GRAY)
        base[mask > 0] = (0, 0, 255)
        overlap_x = min(mw, union_w)
        overlap_y = min(mh, union_h)
        if overlap_x <= 0 or overlap_y <= 0:
            messagebox.showwarning("Warning", "No overlapping region to save.")
            return
        cropped_overlay = base[0:overlap_y, 0:overlap_x]
        cropped_gb = warped_gb[0:overlap_y, 0:overlap_x]
        filename = filedialog.asksaveasfilename(defaultextension=".tif",
                                                filetypes=[("TIFF files", "*.tif"), ("All files", "*.*")])
        if not filename:
            return
        cv2.imwrite(filename, cropped_overlay)
        gb_filename = filename.replace(".tif", "_gb.tif")
        cv2.imwrite(gb_filename, cropped_gb)
        pts_filename = filename.replace(".tif", "_pts.txt")
        with open(pts_filename, "wt") as f:
            for pair in self.aligner._pairs:
                s, t = pair
                f.write(f"{s[0]}, {s[1]} -> {t[0]}, {t[1]}\n")
        meta_filename = filename.replace(".tif", ".meta")
        with open(meta_filename, "wt") as f:
            json.dump(self.metadata, f, indent=4)
        messagebox.showinfo("Saved", "Result and metadata saved.")
        
    def load_meta(self):
        path = filedialog.askopenfilename(title="Select Metadata File", filetypes=[("Meta files", "*.meta")])
        if not path:
            return
        try:
            with open(path, "rt") as f:
                meta = json.load(f)
            # Reapply pre-morph operations.
            if 'resize' in meta:
                for op in meta['resize']:
                    scale = op['scale']
                    union_size = self.compute_union_size()
                    M = np.array([[scale, 0, 0],
                                  [0, scale, 0]], dtype=np.float32)
                    self.gb_img = cv2.warpAffine(self.gb_img, M, union_size, flags=cv2.INTER_NEAREST)
            if 'translate' in meta:
                for op in meta['translate']:
                    dx = op['dx']
                    dy = op['dy']
                    union_size = self.compute_union_size()
                    M = np.array([[1, 0, dx],
                                  [0, 1, dy]], dtype=np.float32)
                    self.gb_img = cv2.warpAffine(self.gb_img, M, union_size, flags=cv2.INTER_NEAREST)
            # Load morph points.
            if 'morph' in meta:
                self.aligner = Aligner()
                for row in meta['morph']:
                    s = (row[0], row[1])
                    t = (row[2], row[3])
                    self.aligner.add_point_pair(s, t)
            self.metadata = meta
            self.meta_loaded = True
            messagebox.showinfo("Meta Loaded", "Metadata loaded successfully. You may now continue morphing.")
            self.update_overlay()
        except Exception as e:
            messagebox.showerror("Error", f"Could not load metadata file:\n{e}")
        
    def reset_points(self):
        self.aligner = Aligner()
        self.metadata['morph'] = []
        self.update_overlay()

# ======================================================
# Main Execution
# ======================================================
if __name__ == "__main__":
    app = AlignApp()
    app.mainloop()


Exception in Tkinter callback
Traceback (most recent call last):
  File "c:\Users\lorenzo.francesia\AppData\Local\Programs\Python\Python312\Lib\tkinter\__init__.py", line 1968, in __call__
    return self.func(*args)
           ^^^^^^^^^^^^^^^^
  File "C:\Users\lorenzo.francesia\AppData\Local\Temp\ipykernel_17872\1268127277.py", line 200, in load_microstructure
    self.update_overlay()
  File "C:\Users\lorenzo.francesia\AppData\Local\Temp\ipykernel_17872\1268127277.py", line 243, in update_overlay
    base[mask > 0] = (0, 0, 255)
    ~~~~^^^^^^^^^^
IndexError: boolean index did not match indexed array along axis 1; size of axis is 8192 but size of corresponding boolean axis is 2048
Exception in Tkinter callback
Traceback (most recent call last):
  File "c:\Users\lorenzo.francesia\AppData\Local\Programs\Python\Python312\Lib\tkinter\__init__.py", line 1968, in __call__
    return self.func(*args)
           ^^^^^^^^^^^^^^^^
  File "C:\Users\lorenzo.francesia\AppData\Local\Temp\ipykernel