In [None]:
# TODO
# 1. fix/examine displacement vector math
# 2. perhaps establish the colour map over the net displacement from image 0 -> image nf
# 3. look at combining multiple feature algorithms eg edge detection, region segmentation
# 4. find more data for deep learning model
# 5. establish a few more optical flow algorithms to attempt to average out/improve precision of displacement vectors


In [122]:
# requirements
import io
import os
import re
import tkinter as tk
from glob import glob
from tkinter import filedialog, messagebox, ttk

import cv2
import fitz
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageTk

In [123]:
'''
This section is designed to optimise image pre-processing 
Aims:
- manage sequential images and folder structure such that sequence is easily retained
- perform visual/mathematical filtering to refine image contrast and enable the subsequent algorithm
- denoising the image (speckle decorrelation?)
- contrast enhancement
- motion artifact correction?
'''
# For now, the input is the path to a folder of PNG images whose sorted filenames are in chronological order.
def preprocess_image_sequence(image_sequence):
    """Preprocess every PNG frame in a folder and save each intermediate stage.

    Frames are processed in sorted-filename order through
    ``preprocess_single_image`` (median denoise -> CLAHE -> ECC alignment to
    the first frame, falling back to the unaligned frame if ECC fails).
    Each stage is written as ``NNN.png`` (NNN = position in the sorted list)
    into ``denoised_sequence``, ``contrasted_sequence`` and
    ``preprocessed_sequence`` under the current working directory. A
    four-panel figure of the last processed frame is saved as
    ``preprocessing_pipeline``.

    Args:
        image_sequence: Path of the folder containing the ``*.png`` frames.

    Raises:
        ValueError: If the folder contains no PNG files, or the first
            (reference) frame cannot be read.
    """
    # sort image paths from input folder
    image_paths = sorted(glob(os.path.join(image_sequence, "*.png")))
    if not image_paths:
        raise ValueError(f"No PNG images found in {image_sequence!r}")

    # read first image as the reference image
    reference_image = cv2.imread(image_paths[0], cv2.IMREAD_GRAYSCALE)
    if reference_image is None:
        raise ValueError(f"Could not read reference image {image_paths[0]!r}")

    # stage name (as produced by preprocess_single_image) -> output directory
    stage_dirs = {
        'Denoised': "denoised_sequence",
        'Contrast Enhanced': "contrasted_sequence",
        'Motion Corrected': "preprocessed_sequence",
    }
    for directory in stage_dirs.values():
        os.makedirs(directory, exist_ok=True)

    last_steps = None
    for idx, image_path in enumerate(image_paths):
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising
            print(f"Skipping unreadable image: {image_path}")
            continue

        steps = preprocess_single_image(image, reference_image)
        for stage, directory in stage_dirs.items():
            cv2.imwrite(os.path.join(directory, f"{idx:03d}.png"), steps[stage])
        last_steps = steps

    # Visualize the steps of the last processed frame; the dict keys
    # ('Original', 'Denoised', 'Contrast Enhanced', 'Motion Corrected')
    # double as panel titles.
    fig, axes = plt.subplots(1, 4, figsize=(16, 5))
    for ax, (title, img) in zip(axes, last_steps.items()):
        ax.imshow(img, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.savefig("preprocessing_pipeline")

# A simplified single-image version of the pipeline, intended for use by the tkinter GUI
def preprocess_single_image(img, reference_img=None):
    """Run the preprocessing stages on one image and return every stage.

    Args:
        img: Input image; passed to ``cv2.medianBlur`` and CLAHE.
        reference_img: Optional template image for ECC alignment. When
            omitted, the alignment stage just copies the enhanced image.

    Returns:
        dict mapping 'Original', 'Denoised', 'Contrast Enhanced' and
        'Motion Corrected' (in that order) to the corresponding image.
    """
    # Denoise with a 3x3 median filter, then equalise contrast locally.
    median_filtered = cv2.medianBlur(img, 3)
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalised = equaliser.apply(median_filtered)

    # Rigid (Euclidean) alignment to the reference; None means "not aligned".
    aligned = None
    if reference_img is not None:
        initial_warp = np.eye(2, 3, dtype=np.float32)
        try:
            _, estimated_warp = cv2.findTransformECC(
                reference_img, equalised, initial_warp, cv2.MOTION_EUCLIDEAN
            )
            output_size = (equalised.shape[1], equalised.shape[0])
            aligned = cv2.warpAffine(
                equalised, estimated_warp, output_size,
                flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP,
            )
        except cv2.error:
            aligned = None  # ECC failed: fall through to the copy below

    if aligned is None:
        aligned = equalised.copy()

    return {
        'Original': img.copy(),
        'Denoised': median_filtered,
        'Contrast Enhanced': equalised,
        'Motion Corrected': aligned,
    }



In [124]:
'''
The purpose of this section is to test combinations of several tracking algorithms into a single, hybrid methodology
Aims:
- building a combined hybrid methodology which factors the strengths of each algorithm
- tweaking the influence or 'strength' of each algorithm based on overall model performance
- perhaps model efficiency optimisation
'''

# optical tracking implementation
'''
Meant to target larger homogeneous regions of the tissue (areas of similar contrast)
'''
def optical_tracking():
    """Placeholder for the optical-tracking component of the hybrid method.

    Not implemented yet; currently always returns 0.
    """
    return 0

# feature based alignment implementation
'''
Meant to target tissue landmarks and structures
need to test some segmentation algorithms and maybe combine to form an OCT optimal strategy
'''
def feature_alignment():
    """Placeholder for the feature-based alignment component of the hybrid method.

    Not implemented yet; currently always returns 0.
    """
    return 0

# deep learning fine-tuning tracking implementation
'''
Uses a library of labelled imagery to make tracking adjustments
not sure yet on method of implementation, probably looking at CNN
'''
def CNN():
    """Placeholder for the deep-learning tracking component of the hybrid method.

    Not implemented yet; currently always returns 0.
    """
    return 0



In [125]:
# === Base Page Class ===
class Page(tk.Frame):
    """Common base class for notebook tabs: a ``tk.Frame`` with nothing added."""

    def __init__(self, parent, **kwargs):
        # Forward everything untouched to the frame constructor.
        tk.Frame.__init__(self, parent, **kwargs)

# === Individual Pages ===
class HomePage(Page):
    """Landing tab that shows only the application's welcome banner."""

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)
        banner = ttk.Label(
            self,
            font=("Helvetica", 18),
            text="Welcome to the OCT Deformation Tracking Suite",
        )
        banner.pack(pady=20)

In [126]:

class PreprocessingPage(Page):
    """Tab that imports OCT frames from a folder and previews each preprocessing stage.

    Files whose name contains "OCT" (case-insensitive) are collected from the
    chosen folder: one embedded image per PDF page, or the file itself for
    raster formats. Every frame is re-saved as PNG into an ``extracted_OCT``
    sub-folder and run through the preprocessing pipeline, using the first
    frame as the alignment reference.
    """

    # Position, within each PDF page's embedded-image list, of the image to keep.
    # NOTE(review): presumably the OCT scan in the exported report layout —
    # confirm against sample PDFs.
    PDF_IMAGE_INDEX = 1
    # Raster extensions (upper-cased) that are read directly with OpenCV.
    RASTER_EXTENSIONS = (".PNG", ".JPG", ".JPEG", ".TIF", ".TIFF")
    # Size in pixels of each stage thumbnail drawn on the canvas.
    THUMB_SIZE = (220, 220)

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)

        # --- Controls Bar ---
        ctrl = ttk.Frame(self)
        ctrl.pack(fill='x', pady=10)
        for text, command in (
            ("Select Input Folder", self.load_images),
            ("Previous", self.prev_image),
            ("Next", self.next_image),
            ("Reorder…", self.open_reorder_dialog),
        ):
            ttk.Button(ctrl, text=text, command=command).pack(side='left', padx=5)

        # --- Canvas for display ---
        self.canvas = tk.Canvas(self, bg='#f0f0f0')
        self.canvas.pack(fill='both', expand=True, padx=10, pady=10)

        # --- Internal State ---
        self.entries = []            # list of (name, np.ndarray) in sequence order
        self.images = []             # just the image arrays
        self.reference_img = None
        self.index = 0
        self.preprocessed_steps = [] # list of dicts per image
        self.tk_imgs = {}            # keeps PhotoImage objects referenced for the canvas

    @staticmethod
    def _natural_key(name):
        """Sort key that orders digit runs numerically ("f2" before "f10"), ignoring case."""
        # re.split with a capturing group alternates text, digits, text, ...
        # so odd positions are always digit runs.
        parts = re.split(r'(\d+)', name.lower())
        return [int(tok) if pos % 2 else tok for pos, tok in enumerate(parts)]

    def preprocess_single_image(self, img, reference_img=None):
        """Return the dict of preprocessing stages for ``img``.

        Delegates to the module-level ``preprocess_single_image`` so the GUI
        and the batch pipeline share one implementation.
        """
        return preprocess_single_image(img, reference_img)

    def _extract_pdf_images(self, path):
        """Return ``[(page_number, image_index, PIL image)]`` extracted from a PDF.

        Only the embedded image at ``PDF_IMAGE_INDEX`` of each page is kept;
        images are converted to 8-bit grayscale ("L"). Page numbers are 1-based.
        """
        extracted = []
        doc = fitz.open(path)
        try:
            for pnum, page in enumerate(doc):
                for img_idx, img_info in enumerate(page.get_images(full=True)):
                    if img_idx != self.PDF_IMAGE_INDEX:
                        continue
                    img_dict = doc.extract_image(img_info[0])  # img_info[0] is the xref
                    pil_img = Image.open(io.BytesIO(img_dict["image"])).convert("L")
                    extracted.append((pnum + 1, img_idx, pil_img))
        finally:
            # close the document even if an image fails to decode
            doc.close()
        return extracted

    def load_images(self):
        """Ask for a folder, extract its OCT frames and show the first one."""
        src_folder = filedialog.askdirectory(title="Select Input Folder")
        if not src_folder:
            return

        out_folder = os.path.join(src_folder, "extracted_OCT")
        os.makedirs(out_folder, exist_ok=True)

        # clear PNGs left over from a previous extraction
        for stale in glob(os.path.join(out_folder, "*.png")):
            os.remove(stale)

        entries = []
        # os.listdir gives no ordering guarantee, so sort to get a stable,
        # human-expected initial sequence (the user can still reorder later).
        for fname in sorted(os.listdir(src_folder), key=self._natural_key):
            if "OCT" not in fname.upper():
                continue
            path = os.path.join(src_folder, fname)
            if not os.path.isfile(path):
                continue  # e.g. the extracted_OCT output directory itself
            stem, ext = os.path.splitext(fname)
            ext = ext.upper()

            if ext == ".PDF":
                for page_no, img_idx, pil_img in self._extract_pdf_images(path):
                    # name with sequence index
                    base_name = f"{len(entries):03d}_{stem}_p{page_no}_i{img_idx}.png"
                    pil_img.save(os.path.join(out_folder, base_name), "PNG")
                    entries.append((base_name, np.array(pil_img)))

            elif ext in self.RASTER_EXTENSIONS:
                img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                if img is None:
                    continue
                base_name = f"{len(entries):03d}_{stem}.png"
                Image.fromarray(img).save(os.path.join(out_folder, base_name), "PNG")
                entries.append((base_name, img))

        self.entries = entries
        self.index = 0
        if not entries:
            messagebox.showwarning("No OCT Images",
                                   "No files with 'OCT' found in the selected folder.")

        # rebuild images/preprocessing (this also clears stale state when empty)
        self._rebuild_images_and_steps()
        self.show_current_image()

    def _rebuild_images_and_steps(self):
        """Recompute ``images``, ``reference_img`` and all stage dicts from ``entries``.

        Called after loading and after reordering; the first entry becomes the
        alignment reference.
        """
        self.images = [arr for (_name, arr) in self.entries]
        if not self.images:
            self.reference_img = None
            self.preprocessed_steps = []
            return
        self.reference_img = self.images[0]
        self.preprocessed_steps = [
            self.preprocess_single_image(img, self.reference_img)
            for img in self.images
        ]

    def show_current_image(self):
        """Draw the stage thumbnails of the current frame side by side on the canvas."""
        self.canvas.delete("all")
        if not self.preprocessed_steps:
            return
        steps = self.preprocessed_steps[self.index]
        for i, (title, img) in enumerate(steps.items()):
            thumb = Image.fromarray(img).resize(self.THUMB_SIZE)
            tk_img = ImageTk.PhotoImage(thumb)
            # Tk only displays a PhotoImage while Python still holds a reference to it
            self.tk_imgs[i] = tk_img
            x = 10 + i*240
            self.canvas.create_image(x, 10, anchor='nw', image=tk_img)
            self.canvas.create_text(x+110, 240, text=title, font=("Helvetica", 12))

    def next_image(self):
        """Advance to the next frame, if there is one."""
        if self.index < len(self.images)-1:
            self.index += 1
            self.show_current_image()

    def prev_image(self):
        """Go back to the previous frame, if there is one."""
        if self.index > 0:
            self.index -= 1
            self.show_current_image()

    def open_reorder_dialog(self):
        """Open a modal dialog with Move Up / Move Down buttons to reorder ``self.entries``."""
        if not self.entries:
            messagebox.showinfo("Reorder OCT Frames", "Load images before reordering.")
            return

        dlg = tk.Toplevel(self)
        dlg.title("Reorder OCT Frames")
        dlg.grab_set()

        lb = tk.Listbox(dlg, height=min(10, len(self.entries)), width=40)
        lb.pack(side='left', fill='y', padx=(10,0), pady=10)
        for name, _ in self.entries:
            lb.insert('end', name)

        # order[k] = index into self.entries of the item shown at listbox row k;
        # tracking positions (not names) keeps reordering correct even if two
        # entries ever share a name.
        order = list(range(len(self.entries)))

        btn_frame = ttk.Frame(dlg)
        btn_frame.pack(side='left', fill='y', padx=10, pady=10)

        def move(up):
            selection = lb.curselection()
            if not selection:
                return
            i = selection[0]
            j = i-1 if up else i+1
            if j < 0 or j >= lb.size():
                return
            # swap rows i and j in the Listbox and in the parallel order list
            name_i = lb.get(i)
            name_j = lb.get(j)
            lb.delete(i); lb.insert(i, name_j)
            lb.delete(j); lb.insert(j, name_i)
            order[i], order[j] = order[j], order[i]
            lb.selection_clear(0, 'end'); lb.selection_set(j)

        ttk.Button(btn_frame, text="↑ Move Up",   command=lambda: move(True)).pack(fill='x', pady=5)
        ttk.Button(btn_frame, text="↓ Move Down", command=lambda: move(False)).pack(fill='x', pady=5)
        ttk.Separator(btn_frame, orient='horizontal').pack(fill='x', pady=5)

        def apply_and_close():
            # reorder self.entries to match the Listbox, then reprocess because
            # the reference frame (first entry) may have changed
            self.entries = [self.entries[k] for k in order]
            self._rebuild_images_and_steps()
            self.index = 0
            self.show_current_image()
            dlg.destroy()

        ttk.Button(btn_frame, text="Apply", command=apply_and_close).pack(fill='x', pady=(20,5))
        ttk.Button(btn_frame, text="Cancel", command=dlg.destroy).pack(fill='x')


In [127]:
class OpticalFlowPage(ttk.Frame):
    """Tab that overlays optical-flow displacement arrows on a frame sequence.

    For the current frame it shows three preprocessing variants (Raw,
    Denoised, Enhanced). On each it can draw the incremental flow
    (previous frame -> current, green) and the net flow (frame 0 -> current,
    red or magnitude-coloured), using the algorithm chosen in the combo box.
    """

    # Half-size in pixels of the square template used by 'Speckle Tracking'.
    SPECKLE_TEMPLATE_HALF = 5

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)

        # Available algorithms; TVL1 needs the optional cv2.optflow module
        self.algorithms = ['Farneback', 'Lucas-Kanade', 'Speckle Tracking', 'TVL1']
        try:
            self.tvl1 = cv2.optflow.DualTVL1OpticalFlow_create()
        except Exception:
            self.algorithms.remove('TVL1')
            self.tvl1 = None

        # Controls
        ctrl = ttk.Frame(self)
        ctrl.pack(fill='x', pady=10)
        ttk.Button(ctrl, text="Select Folder", command=self.load_images).pack(side='left', padx=5)
        ttk.Button(ctrl, text="Previous Image", command=self.prev_image).pack(side='left', padx=5)
        ttk.Button(ctrl, text="Next Image", command=self.next_image).pack(side='left', padx=5)

        ttk.Label(ctrl, text="Algorithm:").pack(side='left', padx=(20,5))
        self.selected_alg = tk.StringVar(value=self.algorithms[0])
        alg_combo = ttk.Combobox(ctrl, textvariable=self.selected_alg,
                                 values=self.algorithms, state='readonly', width=15)
        alg_combo.pack(side='left', padx=5)
        alg_combo.bind('<<ComboboxSelected>>', lambda e: self.show_flow())

        ttk.Label(ctrl, text="Grid Step:").pack(side='left', padx=(20,5))
        self.grid_step = tk.IntVar(value=20)
        step_slider = ttk.Scale(ctrl, from_=1, to=20, variable=self.grid_step,
                                command=lambda e: self.show_flow())
        step_slider.pack(side='left', padx=5)
        self.step_label = ttk.Label(ctrl, text=str(self.grid_step.get()))
        self.step_label.pack(side='left')
        self.grid_step.trace_add("write", lambda *args: self.step_label.config(text=str(self.grid_step.get())))

        self.show_net_var  = tk.BooleanVar(value=True)
        self.show_inc_var  = tk.BooleanVar(value=True)
        self.color_net_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(ctrl, text="Show Net Vector",         variable=self.show_net_var,  command=self.show_flow).pack(side='left', padx=(20,5))
        ttk.Checkbutton(ctrl, text="Show Incremental Vector", variable=self.show_inc_var,  command=self.show_flow).pack(side='left', padx=5)
        ttk.Checkbutton(ctrl, text="Color Net by Mag",        variable=self.color_net_var, command=self.show_flow).pack(side='left', padx=5)

        # Canvas for visualization
        self.canvas = tk.Canvas(self, bg='#f0f0f0')
        self.canvas.pack(fill='both', expand=True, padx=10, pady=10)

        # State
        self.images = []
        self.index  = 0
        self.tk_imgs = []

    @staticmethod
    def _frame_sort_key(filename):
        """Sort key: integer-named files ("0.png", "10.png") first in numeric order,
        then every other name lexicographically.

        Returning tuples with a leading group flag keeps ints and strings from
        ever being compared with each other (which raises TypeError).
        """
        stem = os.path.splitext(filename)[0]
        try:
            return (0, int(stem), "")
        except ValueError:
            return (1, 0, filename)

    def load_images(self):
        """Ask for a folder, load its PNG frames in sequence order and redraw."""
        folder = filedialog.askdirectory()
        if not folder:
            return

        # 1) List all .png files in that folder, in sequence order
        png_files = sorted(
            (f for f in os.listdir(folder) if f.lower().endswith('.png')),
            key=self._frame_sort_key,
        )

        # 2) Load in that order, dropping files OpenCV cannot decode
        #    (cv2.imread returns None instead of raising)
        paths, frames = [], []
        for fname in png_files:
            path = os.path.join(folder, fname)
            frame = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if frame is None:
                continue
            paths.append(path)
            frames.append(frame)
        self.image_paths = paths
        self.images = frames

        # 3) Reset index and redraw
        self.index = 0
        self.show_flow()

    def preprocess_variants(self, img):
        """Return the "Raw", "Denoised" (3x3 median) and "Enhanced" (CLAHE) versions of ``img``."""
        variants = {
            "Raw":      img.copy(),
            "Denoised": cv2.medianBlur(img, 3)
        }
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        variants["Enhanced"] = clahe.apply(variants["Denoised"])
        return variants

    def compute_flow(self, img1, img2, method):
        """Estimate motion from ``img1`` to ``img2`` with the named algorithm.

        Returns:
            * 'Farneback' / 'TVL1': the dense flow array returned by OpenCV.
            * 'Lucas-Kanade' / 'Speckle Tracking': a ``(pts0, pts1)`` tuple of
              matching start and end point arrays.
            * ``None`` if the method is unknown/unavailable or no features
              were found.
        """
        if method == 'Farneback':
            return cv2.calcOpticalFlowFarneback(
                img1, img2, None,
                pyr_scale=0.5, levels=3,
                winsize=15, iterations=3,
                poly_n=5, poly_sigma=1.2, flags=0
            )
        if method == 'TVL1' and self.tvl1:
            return self.tvl1.calc(img1, img2, None)
        if method == 'Lucas-Kanade':
            p0 = cv2.goodFeaturesToTrack(img1, maxCorners=200,
                                         qualityLevel=0.01, minDistance=7, blockSize=7)
            if p0 is None:
                return None
            p1, st, _ = cv2.calcOpticalFlowPyrLK(
                img1, img2, p0, None,
                winSize=(15,15), maxLevel=2
            )
            tracked = st.flatten() == 1  # keep only successfully tracked points
            return (p0[tracked].reshape(-1,2), p1[tracked].reshape(-1,2))
        if method == 'Speckle Tracking':
            return self._speckle_tracking(img1, img2)
        return None

    def _speckle_tracking(self, img1, img2):
        """Block-match a template around each grid point of ``img1`` inside ``img2``.

        The search window extends one grid step around the point, so the
        maximum detectable shift is ``grid_step - SPECKLE_TEMPLATE_HALF``.
        NOTE(review): for grid steps below the template half-size the window
        is smaller than the template and every point reports zero motion.

        Returns:
            ``(pts0, pts1)`` arrays of (x, y) start and matched end points.
        """
        h, w = img1.shape
        s = self.grid_step.get()
        half_t = self.SPECKLE_TEMPLATE_HALF
        pts0, pts1 = [], []
        for y0 in range(0, h, s):
            for x0 in range(0, w, s):
                pts0.append((x0, y0))
                # top-left corners, clipped to the image, of template and window
                tpl_y, tpl_x = max(y0-half_t, 0), max(x0-half_t, 0)
                win_y, win_x = max(y0-s, 0), max(x0-s, 0)
                tpl = img1[tpl_y:y0+half_t+1, tpl_x:x0+half_t+1]
                win = img2[win_y:y0+s+1, win_x:x0+s+1]
                if win.shape[0] < tpl.shape[0] or win.shape[1] < tpl.shape[1]:
                    pts1.append((x0, y0))  # cannot match: report no motion
                    continue
                res = cv2.matchTemplate(win, tpl, cv2.TM_CCOEFF_NORMED)
                _, _, _, max_loc = cv2.minMaxLoc(res)
                # max_loc is the template's top-left inside the window, so the
                # displacement is (matched top-left) - (original top-left) in
                # image coordinates; this stays correct when either patch was
                # clipped at an image border.
                dx = (win_x + max_loc[0]) - tpl_x
                dy = (win_y + max_loc[1]) - tpl_y
                pts1.append((x0+dx, y0+dy))
        return (np.array(pts0), np.array(pts1))

    def draw_arrows(self, vis, flow, method, color):
        """Draw flow arrows of one colour onto ``vis`` in place.

        Dense flow ('Farneback', 'TVL1') is sampled on a grid with spacing
        ``grid_step``; sparse results are drawn point-to-point.
        """
        step = self.grid_step.get()
        if method in ['Farneback','TVL1']:
            h, w = vis.shape[:2]
            # Flatten the grid to 1-D first so that flow[y, x] is (N, 2) and
            # its transpose yields fx, fy in the same order as x, y.
            y, x = np.mgrid[step//2:h:step, step//2:w:step].reshape(2, -1).astype(int)
            fx, fy = flow[y, x].T
            for x0, y0, dx, dy in zip(x, y, fx, fy):
                cv2.arrowedLine(vis, (int(x0), int(y0)), (int(x0+dx), int(y0+dy)),
                                color=color, thickness=1)
        else:
            pts0, pts1 = flow
            for (x0,y0),(x1,y1) in zip(pts0, pts1):
                cv2.arrowedLine(vis, (int(x0),int(y0)), (int(x1),int(y1)),
                                color=color, thickness=1)

    def _draw_arrows_from_pts(self, vis, pts0, disp, color_or_lut):
        """Draw one arrow per ``(pts0[i], disp[i])`` pair onto ``vis`` in place.

        ``color_or_lut`` is either a BGR tuple (single colour) or a colour-map
        lookup table as returned by ``cv2.applyColorMap`` on a 256-entry ramp,
        in which case each arrow is coloured by its magnitude relative to the
        largest magnitude in ``disp``.
        """
        if len(pts0) == 0:
            return  # nothing to draw; also avoids max() of an empty array
        if isinstance(color_or_lut, tuple):
            for (x0,y0),(dx,dy) in zip(pts0, disp):
                cv2.arrowedLine(vis, (int(x0),int(y0)), (int(x0+dx),int(y0+dy)),
                                color_or_lut, thickness=1)
        else:
            lut = color_or_lut[:,0]
            mags = np.hypot(disp[:,0], disp[:,1])
            maxm = mags.max() or 1.0  # avoid dividing by zero when nothing moved
            for (x0,y0),(dx,dy),m in zip(pts0, disp, mags):
                idx = int(255 * min(m/maxm,1.0))
                b,g,r = lut[idx]
                cv2.arrowedLine(vis, (int(x0),int(y0)), (int(x0+dx),int(y0+dy)),
                                (int(b),int(g),int(r)), thickness=1)

    def show_flow(self):
        """Recompute the selected flows for the current frame and redraw the canvas."""
        self.canvas.delete("all")
        self.tk_imgs.clear()
        if not self.images:
            return

        alg      = self.selected_alg.get()
        show_inc = self.show_inc_var.get()
        show_net = self.show_net_var.get()
        color_net= self.color_net_var.get()
        step     = self.grid_step.get()

        curr = self.images[self.index]
        prev = self.images[self.index-1] if self.index>0 else None
        ref  = self.images[0]           if self.index>0 else None

        curr_vars = self.preprocess_variants(curr)
        prev_vars = self.preprocess_variants(prev) if prev is not None else {}
        ref_vars  = self.preprocess_variants(ref)  if ref  is not None else {}

        self.canvas.update_idletasks()
        W   = self.canvas.winfo_width(); pad = 20
        # the canvas can report a tiny width before it is laid out; keep the
        # thumbnail size positive so cv2.resize does not fail
        tw  = max((W - pad*4)//3, 1)

        for idx, var in enumerate(["Raw","Denoised","Enhanced"]):
            vis = cv2.cvtColor(curr_vars[var], cv2.COLOR_GRAY2BGR)

            if self.index > 0:
                # incremental
                if show_inc:
                    flow_prev = self.compute_flow(prev_vars[var], curr_vars[var], alg)
                    if flow_prev is not None:
                        self.draw_arrows(vis, flow_prev, alg, (0,255,0))

                # net
                if show_net:
                    flow_net = self.compute_flow(ref_vars[var], curr_vars[var], alg)
                    if flow_net is not None:
                        # extract pts0, disp
                        if isinstance(flow_net, tuple):
                            p0, p1 = flow_net
                            pts0, disp = p0.astype(int), p1 - p0
                        else:
                            h,w = flow_net.shape[:2]
                            y,x = np.mgrid[step//2:h:step, step//2:w:step].astype(int)
                            pts0 = np.stack([x.flatten(), y.flatten()], axis=-1)
                            fx,fy = flow_net[...,0], flow_net[...,1]
                            disp = np.stack([fx[y,x], fy[y,x]], axis=-1).reshape(-1,2)

                        if color_net:
                            jet_lut = cv2.applyColorMap(
                                np.arange(256, dtype=np.uint8),
                                cv2.COLORMAP_JET
                            )
                            self._draw_arrows_from_pts(vis, pts0, disp, jet_lut)
                        else:
                            self._draw_arrows_from_pts(vis, pts0, disp, (0,0,255))

            # thumbnail & display
            h2,w2 = vis.shape[:2]; th = max(int(tw*h2/w2), 1)
            thumb = cv2.resize(vis, (tw,th), interpolation=cv2.INTER_AREA)
            img_tk = ImageTk.PhotoImage(Image.fromarray(
                        cv2.cvtColor(thumb, cv2.COLOR_BGR2RGB)))
            self.tk_imgs.append(img_tk)  # keep a reference so Tk keeps displaying it
            x = pad + idx*(tw+pad); y = pad
            self.canvas.create_image(x, y, anchor='nw', image=img_tk)
            self.canvas.create_text(x+tw//2, y+th+10,
                                    text=var, font=("Helvetica",8,"bold"))

        # winfo_toplevel finds the window regardless of how deeply the page is nested
        self.winfo_toplevel().title(f"{alg} — Img0→Img{self.index}/{len(self.images)-1}")

    def next_image(self):
        """Advance to the next frame, if there is one."""
        if self.index < len(self.images)-1:
            self.index += 1
            self.show_flow()

    def prev_image(self):
        """Go back to the previous frame, if there is one."""
        if self.index > 0:
            self.index -= 1
            self.show_flow()


In [None]:
class FeatureRegistrationPage(Page):
    """Tab that outlines regions darker than their local background on each frame.

    The current frame is shown as three variants (Raw, bilateral-filtered,
    CLAHE-enhanced); on each, contours of locally dark spots are drawn in
    green. Two sliders control the darkness threshold and the minimum
    contour area.
    """

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)

        # --- Controls bar ---
        ctrl = ttk.Frame(self)
        ctrl.pack(fill='x', pady=10)

        ttk.Button(ctrl, text="Select Folder", command=self.load_images).pack(side='left', padx=5)
        ttk.Button(ctrl, text="Previous",      command=self.prev_image).pack(side='left', padx=5)
        ttk.Button(ctrl, text="Next",          command=self.next_image).pack(side='left', padx=5)

        # Relative threshold slider (background minus pixel)
        ttk.Label(ctrl, text="Rel Thresh:").pack(side='left', padx=(20,5))
        self.rel_thresh = tk.IntVar(value=15)
        rel_sld = ttk.Scale(ctrl, from_=0, to=100, variable=self.rel_thresh,
                            command=lambda e: self.show_registration())
        rel_sld.pack(side='left')
        self.rel_lbl = ttk.Label(ctrl, text=str(self.rel_thresh.get()))
        self.rel_lbl.pack(side='left', padx=(5,10))
        self.rel_thresh.trace_add("write",
            lambda *a: self.rel_lbl.config(text=str(self.rel_thresh.get()))
        )

        # Min contour area slider
        ttk.Label(ctrl, text="Min Area:").pack(side='left', padx=(10,5))
        self.min_area = tk.IntVar(value=50)
        area_sld = ttk.Scale(ctrl, from_=10, to=2000, variable=self.min_area,
                             command=lambda e: self.show_registration())
        area_sld.pack(side='left', padx=(0,5))
        self.area_lbl = ttk.Label(ctrl, text=str(self.min_area.get()))
        self.area_lbl.pack(side='left')
        self.min_area.trace_add("write",
            lambda *a: self.area_lbl.config(text=str(self.min_area.get()))
        )

        # --- Canvas ---
        self.canvas = tk.Canvas(self, bg='#f0f0f0')
        self.canvas.pack(fill='both', expand=True, padx=10, pady=10)

        # --- State ---
        self.images = []    # grayscale frames
        self.index  = 0
        self.tk_imgs = []   # PhotoImage refs

    @staticmethod
    def _frame_sort_key(filename):
        """Sort key: integer-named files first in numeric order, then other names
        lexicographically.

        Tuples with a leading group flag keep ints and strings from being
        compared with each other (which raises TypeError).
        """
        stem = os.path.splitext(filename)[0]
        try:
            return (0, int(stem), "")
        except ValueError:
            return (1, 0, filename)

    def load_images(self):
        """Ask for a folder, load its PNG frames in sequence order and redraw."""
        folder = filedialog.askdirectory()
        if not folder:
            return

        fnames = sorted(
            (f for f in os.listdir(folder) if f.lower().endswith('.png')),
            key=self._frame_sort_key,
        )
        # cv2.imread returns None for files it cannot decode; drop those
        frames = [
            cv2.imread(os.path.join(folder, fn), cv2.IMREAD_GRAYSCALE)
            for fn in fnames
        ]
        self.images = [frame for frame in frames if frame is not None]
        if not self.images:
            return

        self.index = 0
        self.show_registration()

    def _segment_local_dark_spots(self, img):
        """
        Finds spots darker than their local background.
        Returns the contours whose area lies between the "Min Area" slider
        value and 80% of the image area.
        """
        # 1) compute local background by Gaussian blur
        #    kernel size = roughly 1/10th of min dimension, forced odd, at least 5
        k = max(5, (min(img.shape)//10)//2*2+1)
        bg = cv2.GaussianBlur(img, (k,k), 0)

        # 2) difference image: bg - img (cv2.subtract saturates at 0)
        diff = cv2.subtract(bg, img)

        # 3) threshold difference
        thresh = self.rel_thresh.get()
        _, mask = cv2.threshold(diff, thresh, 255, cv2.THRESH_BINARY)

        # 4) clean mask: open to remove specks, close to fill small gaps
        kern = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kern, iterations=1)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kern, iterations=2)

        # 5) find contours
        cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # 6) filter by area
        h, w = img.shape
        min_a = self.min_area.get()
        max_a = 0.8 * h * w
        good = [c for c in cnts
                if min_a <= cv2.contourArea(c) <= max_a]
        return good

    def show_registration(self):
        """Segment the current frame's three variants and redraw the canvas."""
        self.canvas.delete("all")
        self.tk_imgs.clear()
        if not self.images:
            return

        img = self.images[self.index]

        # prepare three variants
        denoised = cv2.bilateralFilter(img, d=9, sigmaColor=75, sigmaSpace=75)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        enhanced = clahe.apply(denoised)
        variants = [('Raw', img), ('Denoised', denoised), ('Enhanced', enhanced)]

        # layout thumbnails
        self.canvas.update_idletasks()
        W = self.canvas.winfo_width()
        pad = 20
        # the canvas can report a tiny width before it is laid out; keep the
        # thumbnail size positive so cv2.resize does not fail
        thumb_w = max((W - pad*4) // 3, 1)

        for i, (name, vimg) in enumerate(variants):
            vis = cv2.cvtColor(vimg, cv2.COLOR_GRAY2BGR)
            # detect local dark spots
            cnts = self._segment_local_dark_spots(vimg)
            # draw outlines in green
            cv2.drawContours(vis, cnts, -1, (0,255,0), 1)

            # resize to thumbnail
            h, w = vis.shape[:2]
            thumb_h = max(int(thumb_w * h / w), 1)
            thumb = cv2.resize(vis, (thumb_w, thumb_h), interpolation=cv2.INTER_AREA)

            # OpenCV images are BGR while PIL expects RGB
            imgtk = ImageTk.PhotoImage(Image.fromarray(
                cv2.cvtColor(thumb, cv2.COLOR_BGR2RGB)))
            self.tk_imgs.append(imgtk)  # keep a reference so Tk keeps displaying it

            x = pad + i*(thumb_w + pad)
            y = pad
            self.canvas.create_image(x, y, anchor='nw', image=imgtk)
            self.canvas.create_text(
                x + thumb_w//2, y + thumb_h + 12,
                text=name, font=("Helvetica", 9, "bold"), fill="#333"
            )

        total = len(self.images)
        # winfo_toplevel finds the window regardless of how deeply the page is nested
        self.winfo_toplevel().title(
            f"Local Dark Spots — Img{self.index}/{total-1}  "
            f"(Δ>{self.rel_thresh.get()}, Area>{self.min_area.get()})"
        )

    def next_image(self):
        """Advance to the next frame, if there is one."""
        if self.index < len(self.images)-1:
            self.index += 1
            self.show_registration()

    def prev_image(self):
        """Go back to the previous frame, if there is one."""
        if self.index > 0:
            self.index -= 1
            self.show_registration()


In [129]:
class DeepLearningRegistrationPage(Page):
    """Placeholder tab for deep-learning registration; shows only a heading."""

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)
        heading = ttk.Label(
            self,
            font=("Helvetica", 18),
            text="Deep Learning Registration",
        )
        heading.pack(pady=20)



In [130]:
class HybridModelPage(Page):
    """Placeholder tab for the hybrid model; shows only a heading."""

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)
        heading = ttk.Label(
            self,
            font=("Helvetica", 18),
            text="Hybrid Model (Feature + Flow + Deep Learning)",
        )
        heading.pack(pady=20)



In [131]:
class BenchmarkingPage(Page):
    """Placeholder tab for the benchmarking framework; shows only a heading."""

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)
        heading = ttk.Label(
            self,
            font=("Helvetica", 18),
            text="Benchmarking Framework",
        )
        heading.pack(pady=20)



In [132]:
# === Main Application ===
class MainApp(tk.Tk):
    """Root window: a 1200x800 'clam'-themed notebook with one tab per page class."""

    def __init__(self):
        super().__init__()
        self.title("OCT Deformation Tracking Suite")
        self.geometry("1200x800")

        theme = ttk.Style(self)
        theme.theme_use('clam')
        theme.configure('TNotebook.Tab', padding=(10, 10))

        tabs = ttk.Notebook(self)
        tabs.pack(fill='both', expand=True)

        # (tab caption, page class) in display order
        tab_specs = (
            ("Home", HomePage),
            ("Preprocessing", PreprocessingPage),
            ("Optical Flow", OpticalFlowPage),
            ("Feature Reg.", FeatureRegistrationPage),
            ("Deep Learning", DeepLearningRegistrationPage),
            ("Hybrid Model", HybridModelPage),
            ("Benchmarking", BenchmarkingPage),
        )
        for caption, page_cls in tab_specs:
            # each page is parented to the notebook, then registered as a tab
            tabs.add(page_cls(tabs), text=caption)

# Build the main window and enter the Tk event loop. mainloop() blocks until
# the window is closed, so when this cell is run in a notebook the cell stays
# busy for as long as the GUI is open.
if __name__ == "__main__":
    MainApp().mainloop()