In [42]:
# TODO
# 1. fix/examine the displacement vector maths
# 2. perhaps base the colour map on the net displacement from image 0 to the final image (nf)
# 3. look at combining multiple feature algorithms, e.g. edge detection and region segmentation
# 4. find more data for the deep learning model
# 5. add a few more optical flow algorithms and average their results to improve the precision of the displacement vectors


In [43]:
# requirements
import os
import cv2
import numpy as np
from glob import glob
import matplotlib.pyplot as plt
import tkinter as tk
from tkinter import filedialog, ttk
from PIL import Image, ImageTk
import fitz  
import io

In [44]:
'''
This section optimises image pre-processing.
Aims:
- manage sequential images and the folder structure so that the frame order is easily retained
- apply filtering that improves image contrast and prepares frames for the downstream tracking algorithm
- denoise the images (speckle decorrelation?)
- enhance contrast
- correct motion artifacts?
'''
# For now the input is a folder of PNG images whose sorted filenames give their chronological order.
def preprocess_image_sequence(image_sequence):
    """Denoise, contrast-enhance and rigidly align a folder of PNG frames.

    Every ``*.png`` in ``image_sequence`` is processed in sorted-filename order.
    Each stage is written to its own folder (created relative to the current
    working directory), named by the frame's position in that order
    (``000.png``, ``001.png``, ...):

    - ``denoised_sequence/``      3x3 median filter
    - ``contrasted_sequence/``    CLAHE (clip limit 2.0, 8x8 tiles)
    - ``preprocessed_sequence/``  Euclidean ECC alignment to the first frame

    If ECC fails for a frame (``cv2.error``), the unaligned enhanced frame is
    written instead, matching ``preprocess_single_image``. A 4-panel figure of
    the *last* frame's stages is saved via ``plt.savefig("preprocessing_pipeline")``.

    Args:
        image_sequence: path to the folder containing the PNG frames.

    Raises:
        FileNotFoundError: if the folder contains no ``*.png`` files.
        ValueError: if one of the images cannot be read.
    """
    for folder in ("denoised_sequence", "contrasted_sequence", "preprocessed_sequence"):
        os.makedirs(folder, exist_ok=True)

    image_paths = sorted(glob(os.path.join(image_sequence, "*.png")))
    if not image_paths:
        raise FileNotFoundError(f"No .png images found in {image_sequence!r}")

    def _read_gray(path):
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:  # cv2.imread signals failure by returning None
            raise ValueError(f"Could not read image {path!r}")
        return img

    # Same parameters for every frame, so build the CLAHE object once.
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    # ECC compares intensities, so the template must go through the same
    # denoise + CLAHE steps as the frames aligned to it: CLAHE is a non-linear,
    # locally varying mapping that ECC's photometric invariance does not undo.
    reference_image = clahe.apply(cv2.medianBlur(_read_gray(image_paths[0]), 3))

    for idx, image_path in enumerate(image_paths):
        image = _read_gray(image_path)
        out_name = f"{idx:03d}.png"

        # denoise image via median filtering
        denoised_image = cv2.medianBlur(image, 3)
        cv2.imwrite(os.path.join("denoised_sequence", out_name), denoised_image)

        # contrast enhancement via adaptive histogram equalisation
        enhanced_image = clahe.apply(denoised_image)
        cv2.imwrite(os.path.join("contrasted_sequence", out_name), enhanced_image)

        # rigid (Euclidean) alignment of the current frame to the reference
        warp_matrix = np.eye(2, 3, dtype=np.float32)
        try:
            _, warp_matrix = cv2.findTransformECC(reference_image, enhanced_image,
                                                  warp_matrix, cv2.MOTION_EUCLIDEAN)
            corrected_image = cv2.warpAffine(
                enhanced_image, warp_matrix,
                (enhanced_image.shape[1], enhanced_image.shape[0]),
                flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)
        except cv2.error:
            # ECC did not converge: keep the frame unaligned rather than abort the run
            corrected_image = enhanced_image.copy()
        cv2.imwrite(os.path.join("preprocessed_sequence", out_name), corrected_image)

    # Visualise the stages of the last processed frame
    fig, axes = plt.subplots(1, 4, figsize=(16, 5))
    images = [image, denoised_image, enhanced_image, corrected_image]
    titles = ['Original', 'Denoised', 'Contrast Enhanced', 'Motion Corrected']

    for ax, img, title in zip(axes, images, titles):
        ax.imshow(img, cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    fig.savefig("preprocessing_pipeline")

# A simplified single-image version of the pipeline, used by the Tkinter GUI pages.
def preprocess_single_image(img, reference_img=None):
    """Run the four preprocessing stages on one grayscale frame.

    Stages: 3x3 median filter -> CLAHE (clip limit 2.0, 8x8 tiles) ->
    Euclidean ECC alignment to ``reference_img``.

    Args:
        img: single-channel image (presumably uint8, as CLAHE expects an
            8/16-bit input — TODO confirm at call sites).
        reference_img: optional ECC template. When it is None, or when ECC or
            the warp raise ``cv2.error``, the 'Motion Corrected' stage is a
            copy of the enhanced image.

    Returns:
        dict of stage name -> image, in the order 'Original', 'Denoised',
        'Contrast Enhanced', 'Motion Corrected'.
    """
    original = img.copy()
    denoised = cv2.medianBlur(img, 3)
    enhanced = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(denoised)

    corrected = None
    if reference_img is not None:
        try:
            _, warp = cv2.findTransformECC(reference_img, enhanced,
                                           np.eye(2, 3, dtype=np.float32),
                                           cv2.MOTION_EUCLIDEAN)
            height, width = enhanced.shape[:2]
            corrected = cv2.warpAffine(enhanced, warp, (width, height),
                                       flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)
        except cv2.error:
            corrected = None  # best effort: fall back to the unaligned image below
    if corrected is None:
        corrected = enhanced.copy()

    return {
        'Original': original,
        'Denoised': denoised,
        'Contrast Enhanced': enhanced,
        'Motion Corrected': corrected,
    }



In [45]:
# === Base Page Class ===
class Page(tk.Frame):
    """Common base for every application page: a plain ``tk.Frame``.

    Subclasses build their widgets in ``__init__``; all keyword arguments are
    forwarded unchanged to ``tk.Frame``.
    """

    def __init__(self, parent, **kwargs):
        tk.Frame.__init__(self, parent, **kwargs)

# === Individual Pages ===
class HomePage(Page):
    """Landing page showing a single welcome banner."""

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)
        ttk.Label(
            self,
            text="Welcome to the OCT Deformation Tracking Suite",
            font=("Helvetica", 18),
        ).pack(pady=20)

In [46]:

class PreprocessingPage(Page):
    """Page that extracts OCT frames from a folder and previews preprocessing.

    Workflow:
    1. "Select Input Folder" scans the chosen folder, in sorted filename order,
       for files whose name contains "OCT". From every PDF page it takes the
       embedded image at index ``_PDF_IMAGE_INDEX`` of ``page.get_images()``.
       Raster files (PNG/JPG/JPEG/TIF/TIFF) are read directly. Each frame is
       saved as a grayscale PNG in ``<folder>/extracted_OCT``, whose old PNGs
       are deleted first, under a name prefixed by a 3-digit sequence index.
    2. Every frame goes through ``preprocess_single_image`` with the first
       frame as the ECC reference. The four stages are shown side by side as
       220x220 thumbnails.
    3. "Reorder…" lets the user move frames up/down with buttons. "Apply"
       rebuilds the reference and all preprocessing results.
    """

    # Which embedded image of each PDF page is extracted. Presumably index 1 is
    # the OCT scan in the report layout — TODO confirm against the source PDFs.
    _PDF_IMAGE_INDEX = 1
    # Upper-case raster extensions read directly with OpenCV.
    _RASTER_EXTS = (".PNG", ".JPG", ".JPEG", ".TIF", ".TIFF")

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)

        # --- Controls Bar ---
        ctrl = ttk.Frame(self)
        ctrl.pack(fill='x', pady=10)
        ttk.Button(ctrl, text="Select Input Folder", command=self.load_images)\
            .pack(side='left', padx=5)
        ttk.Button(ctrl, text="Previous", command=self.prev_image)\
            .pack(side='left', padx=5)
        ttk.Button(ctrl, text="Next", command=self.next_image)\
            .pack(side='left', padx=5)
        ttk.Button(ctrl, text="Reorder…", command=self.open_reorder_dialog)\
            .pack(side='left', padx=5)

        # --- Canvas for display ---
        self.canvas = tk.Canvas(self, bg='#f0f0f0')
        self.canvas.pack(fill='both', expand=True, padx=10, pady=10)

        # --- Internal State ---
        self.entries = []            # list of (name, np.ndarray) in extraction order
        self.images = []             # just the image arrays
        self.reference_img = None
        self.index = 0
        self.preprocessed_steps = [] # list of dicts per image
        self.tk_imgs = {}            # cache of PhotoImage for canvas (keeps refs alive)

    def preprocess_single_image(self, img, reference_img=None):
        """Return a dict of the four preprocessing stages for ``img``.

        Keys, in order: 'Original', 'Denoised' (3x3 median), 'Contrast Enhanced'
        (CLAHE 2.0 / 8x8) and 'Motion Corrected'. The last stage is Euclidean ECC
        alignment to ``reference_img``. It falls back to a copy of the enhanced
        image when no reference is given or OpenCV raises ``cv2.error``.
        """
        steps = {'Original': img.copy()}
        steps['Denoised'] = cv2.medianBlur(img, 3)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        steps['Contrast Enhanced'] = clahe.apply(steps['Denoised'])
        if reference_img is not None:
            M = np.eye(2,3, dtype=np.float32)
            try:
                _, M = cv2.findTransformECC(reference_img, steps['Contrast Enhanced'], M,
                                            cv2.MOTION_EUCLIDEAN)
                steps['Motion Corrected'] = cv2.warpAffine(
                    steps['Contrast Enhanced'], M,
                    (img.shape[1], img.shape[0]),
                    flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP
                )
            except cv2.error:
                steps['Motion Corrected'] = steps['Contrast Enhanced'].copy()
        else:
            steps['Motion Corrected'] = steps['Contrast Enhanced'].copy()
        return steps

    def load_images(self):
        """Ask for a folder, extract its OCT frames and show the first one."""
        # messagebox is not imported at the top of the notebook.
        from tkinter import messagebox
        import glob

        src_folder = filedialog.askdirectory(title="Select Input Folder")
        if not src_folder:
            return

        out_folder = os.path.join(src_folder, "extracted_OCT")
        os.makedirs(out_folder, exist_ok=True)

        # Clear frames left over from a previous extraction.
        for old in glob.glob(os.path.join(out_folder, "*.png")):
            os.remove(old)

        self.entries = []

        # os.listdir order is arbitrary; sort so the initial frame order is
        # deterministic. "Reorder…" can adjust it afterwards.
        for fname in sorted(os.listdir(src_folder)):
            if "OCT" not in fname.upper():
                continue
            path = os.path.join(src_folder, fname)
            stem, ext = os.path.splitext(fname)
            ext = ext.upper()

            if ext == ".PDF":
                self._extract_pdf_frames(path, stem, out_folder)
            elif ext in self._RASTER_EXTS:
                img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
                if img is None:
                    continue
                base_name = f"{len(self.entries):03d}_{stem}.png"
                Image.fromarray(img).save(os.path.join(out_folder, base_name), "PNG")
                self.entries.append((base_name, img))

        # Rebuild even when nothing was found, so stale frames from a previous
        # load are not left on screen while self.entries is empty.
        self._rebuild_images_and_steps()
        self.index = 0
        self.show_current_image()

        if not self.entries:
            messagebox.showwarning("No OCT Images",
                                   "No files with 'OCT' found in the selected folder.")

    def _extract_pdf_frames(self, path, stem, out_folder):
        """Save and record the selected embedded image of every page of a PDF."""
        doc = fitz.open(path)
        try:
            img_idx = self._PDF_IMAGE_INDEX
            for pnum, page in enumerate(doc):
                page_images = page.get_images(full=True)
                if img_idx >= len(page_images):
                    continue
                xref = page_images[img_idx][0]
                img_dict = doc.extract_image(xref)
                pil_img = Image.open(io.BytesIO(img_dict["image"])).convert("L")
                arr = np.array(pil_img)

                # the sequence index prefix keeps names unique and ordered
                base_name = f"{len(self.entries):03d}_{stem}_p{pnum+1}_i{img_idx}.png"
                pil_img.save(os.path.join(out_folder, base_name), "PNG")
                self.entries.append((base_name, arr))
        finally:
            doc.close()  # release the file even if extraction raised

    def _rebuild_images_and_steps(self):
        """Recompute images, reference and preprocessing from ``self.entries``.

        Called after a load or a reorder. An empty ``entries`` list clears
        everything.
        """
        self.images = [arr for (name, arr) in self.entries]
        self.reference_img = self.images[0] if self.images else None
        self.preprocessed_steps = [
            self.preprocess_single_image(img, self.reference_img)
            for img in self.images
        ]

    def show_current_image(self):
        """Draw the four stage thumbnails of the current frame."""
        self.canvas.delete("all")
        if not self.preprocessed_steps:
            return
        steps = self.preprocessed_steps[self.index]
        for i, (title, img) in enumerate(steps.items()):
            thumb  = Image.fromarray(img).resize((220,220))
            tk_img = ImageTk.PhotoImage(thumb)
            self.tk_imgs[i] = tk_img  # keep a reference or Tk drops the image
            x = 10 + i*240
            self.canvas.create_image(x,10,anchor='nw',image=tk_img)
            self.canvas.create_text(x+110,240,text=title,font=("Helvetica",12))

    def next_image(self):
        if self.index < len(self.images)-1:
            self.index += 1
            self.show_current_image()

    def prev_image(self):
        if self.index>0:
            self.index -= 1
            self.show_current_image()

    def open_reorder_dialog(self):
        """Open a modal dialog for moving frames up/down in ``self.entries``."""
        if not self.entries:
            # Without this guard Apply would rebuild from nothing and fail
            # inside the Tk callback, leaving a grabbed dialog behind.
            from tkinter import messagebox
            messagebox.showinfo("Reorder OCT Frames", "Load a folder of OCT images first.")
            return

        dlg = tk.Toplevel(self)
        dlg.title("Reorder OCT Frames")
        dlg.grab_set()

        lb = tk.Listbox(dlg, height= min(10, len(self.entries)), width=40)
        lb.pack(side='left', fill='y', padx=(10,0), pady=10)
        for name, _ in self.entries:
            lb.insert('end', name)

        btn_frame = ttk.Frame(dlg)
        btn_frame.pack(side='left', fill='y', padx=10, pady=10)

        def move(up):
            # Swap the selected row with its neighbour and keep it selected.
            idx = lb.curselection()
            if not idx:
                return
            i = idx[0]
            j = i-1 if up else i+1
            if j<0 or j>=lb.size():
                return
            name_i = lb.get(i)
            name_j = lb.get(j)
            lb.delete(i); lb.insert(i, name_j)
            lb.delete(j); lb.insert(j, name_i)
            lb.selection_clear(0,'end'); lb.selection_set(j)

        ttk.Button(btn_frame, text="↑ Move Up",   command=lambda: move(True)).pack(fill='x', pady=5)
        ttk.Button(btn_frame, text="↓ Move Down", command=lambda: move(False)).pack(fill='x', pady=5)
        ttk.Separator(btn_frame, orient='horizontal').pack(fill='x', pady=5)

        def apply_and_close():
            # Reorder self.entries to match the Listbox. Names are unique
            # thanks to the sequence-index prefix.
            new_order = [lb.get(i) for i in range(lb.size())]
            name_to_entry = {name: arr for name, arr in self.entries}
            self.entries = [(n, name_to_entry[n]) for n in new_order]
            self._rebuild_images_and_steps()
            self.index = 0
            self.show_current_image()
            dlg.destroy()

        ttk.Button(btn_frame, text="Apply", command=apply_and_close).pack(fill='x', pady=(20,5))
        ttk.Button(btn_frame, text="Cancel", command=dlg.destroy).pack(fill='x')


In [None]:
import os
import re
import cv2
import numpy as np
import tkinter as tk
from tkinter import ttk, filedialog
from PIL import Image, ImageTk

class OpticalFlowPage(ttk.Frame):
    """Tkinter page that overlays displacement vectors on an image sequence.

    The user picks a folder of frames, optionally with ``dx_*.npy`` /
    ``dy_*.npy`` displacement files in the same folder. Flow between
    consecutive frames is computed with the selected algorithm and drawn as:

    - computed trail (blue): path of each sample point from frame 0 to the
      displayed frame;
    - computed net (red): start point -> end point after all computed flows;
    - NPY trail (green) / NPY net (lime): the same using the loaded .npy
      displacements, with their sign inverted;
    - net error (pink, drawn with the NPY net): NPY-net endpoint ->
      computed-net endpoint.

    ``compute_flow`` returns either a dense ``(h, w, 2)`` per-pixel (dx, dy)
    array (Farneback, TVL1) or a sparse ``(n, 4)`` array of ``(x, y, dx, dy)``
    rows (Lucas-Kanade, Speckle Tracking). ``flows_inc[t]`` is the flow from
    frame t-1 to frame t. ``npy_inc[t]`` is the t-th dx/dy file pair in
    sorted order (assumed to be the step into frame t — TODO confirm).
    Index 0 of both lists is ``None``.
    """

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)

        # Available algorithms; cv2.optflow only exists in some OpenCV builds
        # (presumably opencv-contrib), so TVL1 is dropped when it is missing.
        self.algorithms = ['Farneback', 'Lucas-Kanade', 'Speckle Tracking', 'TVL1']
        try:
            self.tvl1 = cv2.optflow.DualTVL1OpticalFlow_create()
        except Exception:
            self.algorithms.remove('TVL1')
            self.tvl1 = None

        # Controls frame
        ctrl = ttk.Frame(self)
        ctrl.pack(fill='x', pady=10)
        ttk.Button(ctrl, text="Select Folder", command=self.load_images).pack(side='left', padx=5)
        ttk.Button(ctrl, text="Previous Image", command=self.prev_image).pack(side='left', padx=5)
        ttk.Button(ctrl, text="Next Image", command=self.next_image).pack(side='left', padx=5)

        ttk.Label(ctrl, text="Algorithm:").pack(side='left', padx=(20,5))
        self.selected_alg = tk.StringVar(value=self.algorithms[0])
        alg_combo = ttk.Combobox(ctrl, textvariable=self.selected_alg,
                                 values=self.algorithms, state='readonly', width=15)
        alg_combo.pack(side='left', padx=5)
        alg_combo.bind('<<ComboboxSelected>>', lambda e: self._compute_all_and_show())

        ttk.Label(ctrl, text="Grid Step:").pack(side='left', padx=(20,5))
        self.grid_step = tk.IntVar(value=20)
        ttk.Scale(ctrl, from_=1, to=50, variable=self.grid_step,
                  command=lambda e: self.show_flow()).pack(side='left', padx=5)
        self.step_label = ttk.Label(ctrl, text=str(self.grid_step.get()))
        self.step_label.pack(side='left')
        self.grid_step.trace_add("write", lambda *args: self.step_label.config(text=str(self.grid_step.get())))

        # Display options: computed trail/net and NPY-based
        self.show_trail_var = tk.BooleanVar(value=True)
        self.show_net_var   = tk.BooleanVar(value=False)
        self.show_npy_trail_var = tk.BooleanVar(value=False)
        self.show_npy_net_var   = tk.BooleanVar(value=False)
        self.show_net_error_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(ctrl, text="Show Computed Trail", variable=self.show_trail_var,
                        command=self.show_flow).pack(side='left', padx=(20,5))
        ttk.Checkbutton(ctrl, text="Show Computed Net", variable=self.show_net_var,
                        command=self.show_flow).pack(side='left', padx=5)
        ttk.Checkbutton(ctrl, text="Show NPY Trail", variable=self.show_npy_trail_var,
                        command=self.show_flow).pack(side='left', padx=5)
        ttk.Checkbutton(ctrl, text="Show NPY Net", variable=self.show_npy_net_var,
                        command=self.show_flow).pack(side='left', padx=5)
        ttk.Checkbutton(ctrl, text="Show Net Error", variable=self.show_net_error_var,
                        command=self.show_flow).pack(side='left', padx=5)

        # Canvas for display
        self.canvas = tk.Canvas(self, bg='#f0f0f0')
        self.canvas.pack(fill='both', expand=True, padx=10, pady=10)
        self.canvas.bind('<Configure>', lambda e: self.show_flow())

        # Internal state
        self.images = []
        self.flows_inc = []      # computed incremental flows
        self.net_flow = None     # computed net flow
        self.npy_inc = []        # loaded incremental dx/dy as npy per frame
        self.npy_net = None      # loaded or computed net from npy
        self.index = 0

    @staticmethod
    def _numeric_key(fn):
        """Sort key: by the first integer in the stem; names without one go last."""
        name = os.path.splitext(fn)[0]
        m = re.search(r"(\d+)", name)
        return (0, int(m.group(1))) if m else (1, name.lower())

    @staticmethod
    def _flow_at(flow, x, y):
        """Return the (dx, dy) displacement of ``flow`` at (x, y), or None.

        Dense ``(h, w, 2)`` flows are sampled at the rounded position, clipped
        to the flow's own extent. Sparse ``(n, 4)`` flows use the row whose
        ``(x, y)`` is nearest. Returns None for a missing or empty flow.
        """
        if flow is None:
            return None
        if flow.ndim == 3:
            h, w = flow.shape[:2]
            yi = int(np.clip(round(y), 0, h - 1))
            xi = int(np.clip(round(x), 0, w - 1))
            return float(flow[yi, xi, 0]), float(flow[yi, xi, 1])
        if flow.ndim != 2 or len(flow) == 0:
            return None  # e.g. Lucas-Kanade with no successfully tracked points
        dists = np.hypot(flow[:, 0] - x, flow[:, 1] - y)
        k = int(np.argmin(dists))
        return float(flow[k, 2]), float(flow[k, 3])

    @classmethod
    def _integrate(cls, flows, x0, y0, t_stop, sign=1.0):
        """Follow (x0, y0) through ``flows[1:t_stop]`` head to tail.

        ``t_stop`` is clamped to ``len(flows)`` and the walk stops early at a
        missing or empty flow. ``sign`` multiplies every displacement (-1.0
        for the inverted NPY data). Returns the visited points, starting with
        ``(x0, y0)``.
        """
        x, y = x0, y0
        path = [(x, y)]
        for t in range(1, min(t_stop, len(flows))):
            d = cls._flow_at(flows[t], x, y)
            if d is None:
                break
            x += sign * d[0]
            y += sign * d[1]
            path.append((x, y))
        return path

    def load_images(self):
        """Load frames (and optional dx_/dy_ .npy files) from a chosen folder."""
        folder = filedialog.askdirectory()
        if not folder:
            return
        files = os.listdir(folder)
        image_exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')
        names = sorted((f for f in files if f.lower().endswith(image_exts)),
                       key=self._numeric_key)
        # cv2.imread returns None for unreadable files; drop those frames.
        loaded = (cv2.imread(os.path.join(folder, f), cv2.IMREAD_GRAYSCALE) for f in names)
        self.images = [im for im in loaded if im is not None]

        # Load .npy dx_/dy_ files for incremental displacement
        dx_files = sorted([f for f in files if f.startswith('dx_') and f.endswith('.npy')], key=self._numeric_key)
        dy_files = sorted([f for f in files if f.startswith('dy_') and f.endswith('.npy')], key=self._numeric_key)
        self.npy_inc = [None]
        for dx_fn, dy_fn in zip(dx_files, dy_files):
            dx = np.load(os.path.join(folder, dx_fn))
            dy = np.load(os.path.join(folder, dy_fn))
            self.npy_inc.append(np.dstack((dx, dy)))
        # NPY net is derived dynamically from the trail, not precomputed
        self.npy_net = None
        self.index = 0
        self._compute_all()
        self.show_flow()

    def prev_image(self):
        if self.index > 0:
            self.index -= 1
            self.show_flow()

    def next_image(self):
        if self.index < len(self.images) - 1:
            self.index += 1
            self.show_flow()

    def _compute_all(self):
        """Compute the incremental flow between every pair of consecutive frames."""
        alg = self.selected_alg.get()
        self.flows_inc = [None]
        for i in range(1, len(self.images)):
            self.flows_inc.append(self.compute_flow(self.images[i-1], self.images[i], alg))

    def _compute_all_and_show(self):
        if self.images:
            self._compute_all()
            self.show_flow()

    def compute_flow(self, img1, img2, method):
        """Flow from ``img1`` to ``img2`` (dense or sparse, see class docstring), or None."""
        if method == 'Farneback':
            return cv2.calcOpticalFlowFarneback(
                img1, img2, None,
                pyr_scale=0.5, levels=3,
                winsize=15, iterations=3,
                poly_n=5, poly_sigma=1.2, flags=0)
        if method == 'TVL1' and self.tvl1:
            return self.tvl1.calc(img1, img2, None)
        if method == 'Lucas-Kanade':
            p0 = cv2.goodFeaturesToTrack(img1, maxCorners=200,
                                         qualityLevel=0.01, minDistance=7)
            if p0 is None:
                return None
            p1, st, _ = cv2.calcOpticalFlowPyrLK(img1, img2, p0, None)
            pts0 = p0[st.flatten()==1].reshape(-1,2)
            pts1 = p1[st.flatten()==1].reshape(-1,2)
            disp = pts1 - pts0
            return np.stack((pts0[:,0], pts0[:,1], disp[:,0], disp[:,1]), axis=1)
        if method == 'Speckle Tracking':
            # Block matching: an 11x11 template around each grid point of img1
            # is searched for within +/-step pixels in img2.
            h, w = img1.shape
            step = self.grid_step.get()
            half = 5
            tpl_h = tpl_w = 2 * half + 1
            starts, ends = [], []
            for y0 in range(0, h, step):
                for x0 in range(0, w, step):
                    ty, tx = max(y0 - half, 0), max(x0 - half, 0)   # template top-left
                    wy, wx = max(y0 - step, 0), max(x0 - step, 0)   # window top-left
                    tpl = img1[ty:y0 + half + 1, tx:x0 + half + 1]
                    win = img2[wy:y0 + step + 1, wx:x0 + step + 1]
                    starts.append((x0, y0))
                    if win.shape[0] < tpl_h or win.shape[1] < tpl_w:
                        ends.append((x0, y0))
                        continue
                    _, _, _, (mx, my) = cv2.minMaxLoc(
                        cv2.matchTemplate(win, tpl, cv2.TM_CCOEFF_NORMED))
                    # Displacement = matched top-left in img2 minus template
                    # top-left in img1. Measuring from the actual window
                    # origins keeps this correct where windows are cropped at
                    # the image border.
                    ends.append((x0 + (wx + mx) - tx, y0 + (wy + my) - ty))
            starts_arr = np.array(starts)
            return np.concatenate((starts_arr, np.array(ends) - starts_arr), axis=1)
        return None

    def show_flow(self):
        """Redraw the current frame and every enabled displacement overlay."""
        self.canvas.delete('all')
        if not self.images:
            return

        # Fit the raw frame to the canvas (aspect preserved, centred)
        img = self.images[self.index]
        h_img, w_img = img.shape
        c_w, c_h = self.canvas.winfo_width(), self.canvas.winfo_height()
        scale = min(c_w / w_img, c_h / h_img)
        nw, nh = int(w_img * scale), int(h_img * scale)
        if nw < 1 or nh < 1:
            return  # canvas not laid out yet; the <Configure> binding redraws later
        x_off, y_off = (c_w - nw) // 2, (c_h - nh) // 2
        rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        self.photo = ImageTk.PhotoImage(Image.fromarray(rgb).resize((nw, nh)))
        self.canvas.create_image(x_off, y_off, anchor='nw', image=self.photo)

        def to_canvas(pt):
            return pt[0] * scale + x_off, pt[1] * scale + y_off

        def draw_line(p, q, color):
            self.canvas.create_line(*to_canvas(p), *to_canvas(q), fill=color)

        def draw_path(path, color):
            for p, q in zip(path, path[1:]):
                draw_line(p, q, color)

        step = self.grid_step.get()
        grid = [(float(x0), float(y0))
                for y0 in range(0, h_img, step) for x0 in range(0, w_img, step)]

        # Computed trail: frame 0 -> current frame
        if self.show_trail_var.get() and self.index > 0:
            first = self.flows_inc[1]
            if first is not None and first.ndim == 2:
                # sparse flow: start from the points tracked out of frame 0
                starts = [(float(x), float(y)) for x, y in first[:, :2]]
            else:
                starts = grid
            for x0, y0 in starts:
                draw_path(self._integrate(self.flows_inc, x0, y0, self.index + 1), 'blue')

        # Computed net, head to tail from image 0 to the final image
        if self.show_net_var.get() and len(self.flows_inc) > 1:
            for x0, y0 in grid:
                path = self._integrate(self.flows_inc, x0, y0, len(self.flows_inc))
                draw_line(path[0], path[-1], 'red')

        # NPY-based trail (displacements inverted)
        if self.show_npy_trail_var.get() and self.npy_inc and self.index > 0:
            for x0, y0 in grid:
                draw_path(self._integrate(self.npy_inc, x0, y0, self.index + 1, sign=-1.0),
                          'green')

        # NPY-based net: the NPY trail evaluated at the final frame. It stops
        # at the last loaded file or the last image, whichever comes first.
        if self.show_npy_net_var.get() and len(self.npy_inc) > 1:
            t_stop = min(len(self.npy_inc), len(self.images))
            show_error = self.show_net_error_var.get()
            for x0, y0 in grid:
                npy_path = self._integrate(self.npy_inc, x0, y0, t_stop, sign=-1.0)
                if show_error:
                    # error line: NPY-net endpoint -> computed-net endpoint
                    comp_path = self._integrate(self.flows_inc, x0, y0, len(self.flows_inc))
                    draw_line(npy_path[-1], comp_path[-1], 'pink')
                draw_line(npy_path[0], npy_path[-1], 'lime')



In [48]:
class FeatureRegistrationPage(Page):
    """Page that outlines local dark spots on raw, denoised and enhanced frames.

    For the current frame, three variants are shown side by side:
    - raw;
    - bilateral-filtered;
    - bilateral + CLAHE.
    On each variant, contours of regions darker than their blurred local
    background are drawn in green. Two sliders control the segmentation:
    "Rel Thresh" (minimum background-minus-pixel difference) and "Min Area"
    (minimum contour area in pixels).
    """

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)

        # --- Controls bar ---
        ctrl = ttk.Frame(self)
        ctrl.pack(fill='x', pady=10)

        ttk.Button(ctrl, text="Select Folder", command=self.load_images).pack(side='left', padx=5)
        ttk.Button(ctrl, text="Previous",      command=self.prev_image).pack(side='left', padx=5)
        ttk.Button(ctrl, text="Next",          command=self.next_image).pack(side='left', padx=5)

        # Relative threshold slider (background minus pixel)
        ttk.Label(ctrl, text="Rel Thresh:").pack(side='left', padx=(20,5))
        self.rel_thresh = tk.IntVar(value=15)
        rel_sld = ttk.Scale(ctrl, from_=0, to=100, variable=self.rel_thresh,
                            command=lambda e: self.show_registration())
        rel_sld.pack(side='left')
        self.rel_lbl = ttk.Label(ctrl, text=str(self.rel_thresh.get()))
        self.rel_lbl.pack(side='left', padx=(5,10))
        self.rel_thresh.trace_add("write",
            lambda *a: self.rel_lbl.config(text=str(self.rel_thresh.get()))
        )

        # Min contour area slider
        ttk.Label(ctrl, text="Min Area:").pack(side='left', padx=(10,5))
        self.min_area = tk.IntVar(value=50)
        area_sld = ttk.Scale(ctrl, from_=1, to=2000, variable=self.min_area,
                             command=lambda e: self.show_registration())
        area_sld.pack(side='left', padx=(0,5))
        self.area_lbl = ttk.Label(ctrl, text=str(self.min_area.get()))
        self.area_lbl.pack(side='left')
        self.min_area.trace_add("write",
            lambda *a: self.area_lbl.config(text=str(self.min_area.get()))
        )

        # --- Canvas ---
        self.canvas = tk.Canvas(self, bg='#f0f0f0')
        self.canvas.pack(fill='both', expand=True, padx=10, pady=10)

        # --- State ---
        self.images = []    # grayscale frames
        self.index  = 0
        self.tk_imgs = []   # PhotoImage refs (Tk drops images that are not referenced)

    @staticmethod
    def _sort_key(fn):
        """Sort key: purely numeric stems numerically, before all other names.

        A tuple key avoids comparing int with str, which raised TypeError for
        folders that mix numeric and non-numeric file names.
        """
        stem = os.path.splitext(fn)[0]
        return (0, int(stem)) if stem.isdigit() else (1, fn)

    def load_images(self):
        """Load every PNG of a chosen folder as grayscale and show the first."""
        folder = filedialog.askdirectory()
        if not folder:
            return

        fnames = sorted((f for f in os.listdir(folder) if f.lower().endswith('.png')),
                        key=self._sort_key)
        # cv2.imread returns None for unreadable files; skip those.
        loaded = (cv2.imread(os.path.join(folder, fn), cv2.IMREAD_GRAYSCALE) for fn in fnames)
        self.images = [im for im in loaded if im is not None]
        self.index = 0
        # With no images this just clears the canvas.
        self.show_registration()

    def _segment_local_dark_spots(self, img):
        """
        Find spots darker than their local background.

        Pixels where ``blur(img) - img`` exceeds the "Rel Thresh" slider form a
        mask. The mask is opened once and closed twice with a 5x5 ellipse.
        External contours whose area lies in [Min Area, 0.8 * image area] are
        returned.
        """
        # 1) local background via a small Gaussian blur. A kernel of roughly
        #    1/10 of the smaller image dimension was tried previously.
        k = 3
        bg = cv2.GaussianBlur(img, (k, k), 0)

        # 2) difference image bg - img (cv2.subtract saturates at 0)
        diff = cv2.subtract(bg, img)

        # 3) threshold difference
        _, mask = cv2.threshold(diff, self.rel_thresh.get(), 255, cv2.THRESH_BINARY)

        # 4) clean mask: remove specks, then fill small gaps
        kern = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kern, iterations=1)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kern, iterations=2)

        # 5) find contours
        cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        # 6) filter by area (the upper bound rejects a near-whole-image blob)
        h, w = img.shape
        min_a = self.min_area.get()
        max_a = 0.8 * h * w
        return [c for c in cnts if min_a <= cv2.contourArea(c) <= max_a]

    def show_registration(self):
        """Draw the three annotated variants of the current frame and update the title."""
        self.canvas.delete("all")
        self.tk_imgs.clear()
        if not self.images:
            return

        img = self.images[self.index]

        # prepare three variants
        denoised = cv2.bilateralFilter(img, d=9, sigmaColor=75, sigmaSpace=75)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        enhanced = clahe.apply(denoised)
        variants = [('Raw', img), ('Denoised', denoised), ('Enhanced', enhanced)]

        # layout thumbnails: three columns with padding
        self.canvas.update_idletasks()
        W = self.canvas.winfo_width()
        pad = 20
        # clamp so an unmapped or very narrow canvas cannot yield a 0-px resize
        thumb_w = max(1, (W - pad*4) // 3)

        for i, (name, vimg) in enumerate(variants):
            # RGB (not BGR) because the array is handed to PIL below
            vis = cv2.cvtColor(vimg, cv2.COLOR_GRAY2RGB)
            cnts = self._segment_local_dark_spots(vimg)
            cv2.drawContours(vis, cnts, -1, (0,255,0), 1)  # green outlines

            h, w = vis.shape[:2]
            thumb_h = max(1, int(thumb_w * h / w))
            thumb = cv2.resize(vis, (thumb_w, thumb_h), interpolation=cv2.INTER_AREA)

            imgtk = ImageTk.PhotoImage(Image.fromarray(thumb))
            self.tk_imgs.append(imgtk)

            x = pad + i*(thumb_w + pad)
            y = pad
            self.canvas.create_image(x, y, anchor='nw', image=imgtk)
            self.canvas.create_text(
                x + thumb_w//2, y + thumb_h + 12,
                text=name, font=("Helvetica", 9, "bold"), fill="#333"
            )

        total = len(self.images)
        # winfo_toplevel() finds the window regardless of how deeply the page
        # is nested (master.master assumed exactly two levels).
        self.winfo_toplevel().title(
            f"Local Dark Spots — Img{self.index}/{total-1}  "
            f"(Δ>{self.rel_thresh.get()}, Area>{self.min_area.get()})"
        )

    def next_image(self):
        if self.index < len(self.images)-1:
            self.index += 1
            self.show_registration()

    def prev_image(self):
        if self.index > 0:
            self.index -= 1
            self.show_registration()


In [49]:
import tkinter as tk
from tkinter import ttk
from PIL import Image, ImageTk
import numpy as np
import imageio
from scipy.ndimage import gaussian_filter, map_coordinates
import cv2

def generate_oct_like_image(h, w, struct_sigma=12, speckle_scale=0.2, speckle_contrast=5.0):
    """Synthesise an OCT-like float32 image of shape (h, w) with values in [0, 255].

    A Gaussian-smoothed white-noise field supplies low-frequency "tissue"
    structure, normalised to [0, 255]. It is modulated multiplicatively by
    contrast-stretched Rayleigh speckle, and the result is clipped.

    Parameters
    ----------
    h, w : int
        Output height and width in pixels.
    struct_sigma : float
        Gaussian sigma (pixels) of the structural component; larger is smoother.
    speckle_scale : float
        Scale parameter of the Rayleigh speckle distribution.
    speckle_contrast : float
        Factor by which speckle deviations from their mean are amplified.

    Returns
    -------
    np.ndarray
        float32 array of shape (h, w), clipped to [0, 255].

    Draws from the global ``np.random`` state; seed it for reproducible output.
    """
    base_struct = gaussian_filter(np.random.randn(h, w), sigma=struct_sigma)
    # np.ptp rather than ndarray.ptp(): the method was removed in NumPy 2.0
    value_range = np.ptp(base_struct)
    if value_range > 0:
        norm = (base_struct - base_struct.min()) / value_range
    else:
        # flat structure (e.g. a 1x1 image): avoid 0/0 producing NaN pixels
        norm = np.zeros_like(base_struct)
    tissue = norm * 255
    speckle = np.random.rayleigh(scale=speckle_scale, size=(h, w))
    # stretch the speckle about its mean to exaggerate contrast
    speckle = (speckle - speckle.mean()) * speckle_contrast + speckle.mean()
    full_img = np.clip(tissue * (1 + speckle), 0, 255)
    return full_img.astype(np.float32)

def generate_displacement_field(h, w, magnitude=10, smoothness=30):
    """Return a smooth random displacement field ``(dx, dy)``, each of shape (h, w).

    Each component is unit white noise, Gaussian-smoothed with sigma
    ``smoothness`` and multiplied by ``magnitude``. ``dx`` is drawn before
    ``dy`` from the global ``np.random`` state.

    Note: ``magnitude`` is a gain on smoothed noise, not a displacement in
    pixels. Smoothing shrinks the noise's standard deviation, roughly by
    1 / (2 * sqrt(pi) * smoothness) for a 2-D Gaussian. The per-pixel
    displacements are therefore much smaller than ``magnitude``.
    """
    def _smoothed_component():
        noise = np.random.randn(h, w)
        return gaussian_filter(noise, sigma=smoothness) * magnitude

    field_x = _smoothed_component()
    field_y = _smoothed_component()
    return field_x, field_y

def apply_displacement(img, dx, dy):
    """Warp a 2-D image by backward sampling.

    ``out[y, x] = img[y + dy[y, x], x + dx[y, x]]``, using bilinear
    interpolation (``order=1``) and reflecting coordinates that fall outside
    the image. Because the sampling is backward, image content appears to move
    by approximately ``-(dx, dy)``.
    """
    rows, cols = np.indices(img.shape)
    sample_points = np.stack([rows + dy, cols + dx])
    return map_coordinates(img, sample_points, order=1, mode='reflect')

class ImageSynthesisPage(ttk.Frame):
    """Tab that synthesises OCT-like frame sequences with known deformation fields.

    A speckled base image is generated with ``generate_oct_like_image``. Then
    ``num_frames`` smooth random fields from ``generate_displacement_field``
    are accumulated, and the base image is warped by each running total with
    ``apply_displacement``. Frame 0 is the undeformed base; frame k uses the
    sum of the first k fields. The fields can be overlaid as arrows and
    exported as ``.npy`` ground truth next to the PNG frames.

    Convention: ``apply_displacement`` samples backward
    (``out(p) = base(p + d)``), so the stored (dx, dy) fields are sampling
    offsets. Image content moves by roughly ``-(dx, dy)``. The arrow overlays
    draw the stored fields as-is.
    """

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)
        ctrl = ttk.Frame(self)
        ctrl.pack(fill='x', pady=10)

        # generator parameters; all are read when "Generate Sequence" is pressed
        ttk.Label(ctrl, text="Frames:").pack(side='left')
        self.num_frames = tk.IntVar(value=9)
        ttk.Spinbox(ctrl, from_=1, to=20, width=4, textvariable=self.num_frames).pack(side='left')

        ttk.Label(ctrl, text="Deform:").pack(side='left', padx=(10,0))
        self.deform_scale = tk.DoubleVar(value=300.0)
        ttk.Scale(ctrl, from_=0, to=400, variable=self.deform_scale, length=100).pack(side='left')
        ttk.Label(ctrl, textvariable=self.deform_scale, width=4).pack(side='left')

        ttk.Label(ctrl, text="Noise:").pack(side='left', padx=(10,0))
        self.noise_scale = tk.DoubleVar(value=0.12)
        ttk.Scale(ctrl, from_=0, to=0.4, variable=self.noise_scale, length=100).pack(side='left')
        ttk.Label(ctrl, textvariable=self.noise_scale, width=4).pack(side='left')

        ttk.Label(ctrl, text="Struct.:").pack(side='left', padx=(10,0))
        self.struct_sigma = tk.DoubleVar(value=10.0)
        ttk.Scale(ctrl, from_=1, to=20, variable=self.struct_sigma, length=100).pack(side='left')
        ttk.Label(ctrl, textvariable=self.struct_sigma, width=4).pack(side='left')

        ttk.Button(ctrl, text="Generate Sequence", command=self.generate_sequence).pack(side='left', padx=10)
        ttk.Button(ctrl, text="Export Sequence", command=self.export_sequence).pack(side='left')

        self.show_cumulative = tk.BooleanVar(value=False)
        self.show_incremental = tk.BooleanVar(value=False)
        ttk.Checkbutton(ctrl, text="Show Cumulative", variable=self.show_cumulative, command=self.refresh_current_frame).pack(side='left')
        ttk.Checkbutton(ctrl, text="Show Incremental", variable=self.show_incremental, command=self.refresh_current_frame).pack(side='left')

        nav = ttk.Frame(self)
        nav.pack(fill='x')
        ttk.Button(nav, text="⟨ Prev", command=self.show_prev).pack(side='left')
        ttk.Button(nav, text="Next ⟩", command=self.show_next).pack(side='right')
        self.frame_label = ttk.Label(nav, text="Frame 0/0")
        self.frame_label.pack(side='left', expand=True)

        self.canvas = tk.Canvas(self, bg='black', width=364, height=364)
        self.canvas.pack(fill='both', expand=True, padx=10, pady=10)
        self.tk_img = None  # reference to the displayed PhotoImage so Tk keeps it alive

        self.original = None
        self.sequence = []                 # list of uint8 frames; frame 0 is the undeformed base
        self.curr_idx = 0
        self.dx_total = None               # sum of all per-frame fields (float32)
        self.dy_total = None
        self.per_frame_displacements = []  # (dx, dy) taking frame i to frame i+1

    def generate_sequence(self):
        """Create a new base image and ``num_frames`` progressively deformed frames, then show frame 0."""
        h, w = 364, 364
        base = generate_oct_like_image(h, w, struct_sigma=self.struct_sigma.get(), speckle_scale=self.noise_scale.get())
        self.original = base.astype(np.uint8)
        self.sequence = [self.original]
        self.per_frame_displacements = []

        self.dx_total = np.zeros((h, w), dtype=np.float32)
        self.dy_total = np.zeros((h, w), dtype=np.float32)

        for _ in range(self.num_frames.get()):
            dx, dy = generate_displacement_field(h, w, magnitude=self.deform_scale.get(), smoothness=30)
            self.per_frame_displacements.append((dx, dy))
            self.dx_total += dx
            self.dy_total += dy
            # always warp the pristine base by the running total (no resampling drift)
            warped = apply_displacement(base, self.dx_total, self.dy_total)
            self.sequence.append(np.clip(warped, 0, 255).astype(np.uint8))

        self.curr_idx = 0
        self.show_frame(0)

    def _cumulative_displacement(self, idx):
        """Return the summed (dx, dy) of the first ``idx`` per-frame fields.

        Frame ``idx`` was produced with exactly this total: frame 0 gets a zero
        field and the last frame gets ``dx_total`` / ``dy_total``. The float32
        accumulation mirrors ``generate_sequence``.
        """
        dx_cum = np.zeros_like(self.dx_total)
        dy_cum = np.zeros_like(self.dy_total)
        for dx, dy in self.per_frame_displacements[:idx]:
            dx_cum += dx
            dy_cum += dy
        return dx_cum, dy_cum

    def show_frame(self, idx):
        """Display frame ``idx`` with the optional incremental and cumulative arrow overlays.

        The incremental overlay shows the field that produced this frame from
        the previous one. The cumulative overlay shows the total field applied
        to the base image for this frame. Does nothing if no sequence has been
        generated yet.
        """
        if not self.sequence:
            return
        arr = self.sequence[idx].copy()
        if self.show_incremental.get() and idx > 0:
            dx, dy = self.per_frame_displacements[idx - 1]
            arr = self.overlay_displacement_arrows(arr, dx, dy, color=(0, 0, 255))
        if self.show_cumulative.get():
            dx_cum, dy_cum = self._cumulative_displacement(idx)
            arr = self.overlay_displacement_arrows(arr, dx_cum, dy_cum, color=(0, 0, 200))
        self.display_array(arr)
        self.frame_label.config(text=f"Frame {idx+1}/{len(self.sequence)}")

    def refresh_current_frame(self):
        """Redraw the current frame (e.g. after an overlay checkbox changes)."""
        self.show_frame(self.curr_idx)

    def show_prev(self):
        """Step back one frame, stopping at the first."""
        if not self.sequence:
            return
        self.curr_idx = max(0, self.curr_idx - 1)
        self.show_frame(self.curr_idx)

    def show_next(self):
        """Step forward one frame, stopping at the last."""
        if not self.sequence:
            return
        self.curr_idx = min(len(self.sequence)-1, self.curr_idx + 1)
        self.show_frame(self.curr_idx)

    def display_array(self, arr):
        """Show a uint8 grayscale or RGB array centred on the canvas."""
        pil = Image.fromarray(arr.astype(np.uint8))
        self.tk_img = ImageTk.PhotoImage(pil)
        self.canvas.config(width=arr.shape[1], height=arr.shape[0])
        self.canvas.delete('all')
        self.canvas.update_idletasks()
        canvas_w = self.canvas.winfo_width()
        canvas_h = self.canvas.winfo_height()
        img_w, img_h = self.tk_img.width(), self.tk_img.height()
        x = (canvas_w - img_w) // 2
        y = (canvas_h - img_h) // 2
        self.canvas.create_image(x, y, anchor='nw', image=self.tk_img)

    def overlay_displacement_arrows(self, img, dx, dy, step=20, color=(0, 0, 255)):
        """Draw the field (dx, dy) as arrows every ``step`` pixels and return an RGB image.

        ``img`` may be single-channel grayscale or an RGB image (such as the
        output of a previous call), so overlays can be stacked. ``color`` is
        given in BGR order, as for all OpenCV drawing functions.
        """
        if img.ndim == 2:
            vis = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        else:
            # output of an earlier overlay is RGB; convert so `color` keeps its BGR meaning
            vis = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        h, w = img.shape[:2]
        for y in range(0, h, step):
            for x in range(0, w, step):
                tip = (int(x + dx[y, x]), int(y + dy[y, x]))
                cv2.arrowedLine(vis, (x, y), tip, color, 1, tipLength=0.3)
        return cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)

    def export_sequence(self):
        """Write frames and ground-truth fields to a user-chosen folder.

        Files written: ``frame_NNN.png`` for every frame,
        ``dx_i_j.npy`` / ``dy_i_j.npy`` for each incremental field from frame
        i to j=i+1, and ``dx_cumulative.npy`` / ``dy_cumulative.npy`` for the
        total field.
        """
        if not self.sequence:
            return

        folder = filedialog.askdirectory(title="Select Export Folder")
        if not folder:
            return

        # Export images
        for i, img in enumerate(self.sequence):
            img_path = os.path.join(folder, f"frame_{i:03d}.png")
            imageio.imwrite(img_path, img.astype(np.uint8))

        # Export incremental deformation fields (from frame i to i+1)
        for i, (dx, dy) in enumerate(self.per_frame_displacements):
            np.save(os.path.join(folder, f"dx_{i:03d}_{i+1:03d}.npy"), dx)
            np.save(os.path.join(folder, f"dy_{i:03d}_{i+1:03d}.npy"), dy)

        # Optional: export cumulative deformation fields
        np.save(os.path.join(folder, "dx_cumulative.npy"), self.dx_total)
        np.save(os.path.join(folder, "dy_cumulative.npy"), self.dy_total)

        print(f"Exported {len(self.sequence)} frames and {len(self.per_frame_displacements)} deformation fields to {folder}")





In [50]:
import os
import glob
import tkinter as tk
from tkinter import ttk, filedialog
import itertools
import numpy as np
import cv2
from scipy.ndimage import map_coordinates

class BenchmarkingPage(ttk.Frame):
    """Tab that grid-searches Farneback optical-flow parameters on a synthetic dataset.

    The dataset root must contain one sub-folder per sequence. Each sub-folder
    holds ``frame_NNN.png`` images plus ground-truth ``dx_i_j.npy`` /
    ``dy_i_j.npy`` fields for each consecutive pair (i, j=i+1).

    For every parameter combination the table shows two metrics. Both are
    averaged per sequence and then across sequences:
    - the mean end-point error (EPE) of the predicted flow against ground truth;
    - the mean squared error of reconstructing frame i by warping frame i+1
      with the predicted flow.

    Ground-truth convention: ``apply_displacement`` samples backward
    (``frame(p) = base(p + d)``), so stored fields are sampling offsets. The
    true prev->next motion is approximately ``-(dx, dy)`` (first order, for
    smooth fields).

    NOTE(review): the ``import glob`` in this cell rebinds the notebook-global
    name ``glob``, which an earlier cell imported as ``from glob import glob``.
    Any earlier function calling ``glob(...)`` breaks after this cell runs.
    This class imports the function locally, so that cell-level import can be
    dropped.
    """

    def __init__(self, parent, **kwargs):
        super().__init__(parent, **kwargs)
        self.dataset_root = None  # set by load_dataset_root

        # Controls frame
        ctrl = ttk.Frame(self)
        ctrl.pack(fill='x', pady=10)

        # Dataset selector
        ttk.Button(ctrl, text="Select Dataset Folder", command=self.load_dataset_root).pack(side='left', padx=5)
        self.dataset_label = ttk.Label(ctrl, text="No folder selected")
        self.dataset_label.pack(side='left', padx=5)

        # Algorithm selector (only Farneback for grid search)
        ttk.Label(ctrl, text="Algorithm:").pack(side='left', padx=(20,5))
        self.algorithms = ['Farneback']
        self.selected_alg = tk.StringVar(value='Farneback')
        alg_combo = ttk.Combobox(ctrl, textvariable=self.selected_alg,
                                 values=self.algorithms, state='readonly', width=12)
        alg_combo.pack(side='left')

        # Farneback parameter entries (comma-separated lists)
        param_defs = [
            ("pyr_scale", "0.3,0.5,0.7"),
            ("levels",    "1,2,3,4"),
            ("winsize",   "9,15,21"),
            ("iterations","1,3,5"),
            ("poly_n",    "5,7"),
            ("poly_sigma","1.1,1.5,1.9"),
        ]
        self.param_vars = {}
        for label, default in param_defs:
            var = tk.StringVar(value=default)
            self.param_vars[label] = var
            ttk.Label(ctrl, text=f"{label}:").pack(side='left', padx=(10,0))
            ttk.Entry(ctrl, textvariable=var, width=8).pack(side='left')

        # Run button
        ttk.Button(ctrl, text="Run Grid Search", command=self.run_benchmark).pack(side='left', padx=10)
        self.status = ttk.Label(ctrl, text="Idle")
        self.status.pack(side='left', padx=5)

        # Results table
        cols = ("params", "avgEPE", "avgWarpErr")
        self.table = ttk.Treeview(self, columns=cols, show="headings")
        for c in cols:
            self.table.heading(c, text=c)
            self.table.column(c, width=120)
        self.table.pack(fill='both', expand=True, padx=10, pady=10)

    def load_dataset_root(self):
        """Ask the user for the dataset root folder and remember it."""
        folder = filedialog.askdirectory(title="Select Synthetic OCT Dataset Folder")
        if not folder:
            return
        self.dataset_root = folder
        self.dataset_label.config(text=os.path.basename(folder))
        self.status.config(text="Dataset loaded")

    @staticmethod
    def _parse_values(text):
        """Parse a comma-separated list into ints where possible, floats otherwise.

        Whitespace and empty items (e.g. from a trailing comma) are ignored.
        Raises ValueError for an item that is not a number.
        """
        values = []
        for token in text.split(','):
            token = token.strip()
            if not token:
                continue
            try:
                values.append(int(token))
            except ValueError:
                values.append(float(token))
        return values

    def _load_subset(self, path):
        """Load one sequence folder as a list of ``(img1, img2, dx_gt, dy_gt)`` consecutive pairs.

        Returns None if any frame cannot be read by OpenCV. Missing
        ground-truth files raise as before.
        """
        # local import: the cell-level `import glob` shadows the notebook's `glob` function
        from glob import glob as find_files

        frames = sorted(find_files(os.path.join(path, "frame_*.png")))
        imgs = [cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in frames]
        if any(im is None for im in imgs):
            return None
        pairs = []
        for i in range(len(imgs) - 1):
            dx_gt = np.load(os.path.join(path, f"dx_{i:03d}_{i+1:03d}.npy"))
            dy_gt = np.load(os.path.join(path, f"dy_{i:03d}_{i+1:03d}.npy"))
            pairs.append((imgs[i], imgs[i + 1], dx_gt, dy_gt))
        return pairs

    def _pair_metrics(self, img1, img2, flow, dx_gt, dy_gt):
        """Return ``(epe, warp_mse)`` for a predicted prev->next flow of shape (H, W, 2).

        EPE compares the flow with the ground-truth motion ``-(dx_gt, dy_gt)``,
        because the stored fields are backward-sampling offsets. Warp error
        follows OpenCV's Farneback convention ``prev(p) ~ next(p + flow(p))``:
        ``img2`` is warped by the flow and compared with ``img1``.
        """
        u, v = flow[..., 0], flow[..., 1]
        epe = float(np.hypot(u + dx_gt, v + dy_gt).mean())
        reconstructed = self.warp_image(img2.astype(np.float32), u, v)
        warp_mse = float(((reconstructed - img1.astype(np.float32)) ** 2).mean())
        return epe, warp_mse

    def run_benchmark(self):
        """Run the Farneback grid search and fill the table with one row per parameter combination."""
        # Ensure dataset selected
        if not getattr(self, 'dataset_root', None):
            self.status.config(text="Select dataset folder first")
            return

        # Parse parameter lists
        try:
            lists = {name: self._parse_values(var.get()) for name, var in self.param_vars.items()}
        except ValueError:
            self.status.config(text="Invalid parameter list")
            return

        # Generate all combinations
        combos = list(itertools.product(
            lists['pyr_scale'], lists['levels'], lists['winsize'],
            lists['iterations'], lists['poly_n'], lists['poly_sigma']
        ))

        # Clear previous results
        for row in self.table.get_children():
            self.table.delete(row)

        # Load every sequence once: frames and ground truth do not depend on the parameters
        sets = sorted(d for d in os.listdir(self.dataset_root)
                      if os.path.isdir(os.path.join(self.dataset_root, d)))
        loaded = [self._load_subset(os.path.join(self.dataset_root, d)) for d in sets]
        dataset = [pairs for pairs in loaded if pairs is not None]
        skipped = len(loaded) - len(dataset)
        total = len(combos)

        # Grid search loop
        for count, (ps, lv, ws, iters, pn, psig) in enumerate(combos, start=1):
            epe_accum, werr_accum = [], []

            for pairs in dataset:
                epe_vals, werr_vals = [], []
                for img1, img2, dx_gt, dy_gt in pairs:
                    flow = cv2.calcOpticalFlowFarneback(
                        img1, img2, None,
                        pyr_scale=ps, levels=int(lv),
                        winsize=int(ws), iterations=int(iters),
                        poly_n=int(pn), poly_sigma=psig,
                        flags=0
                    )
                    epe, werr = self._pair_metrics(img1, img2, flow, dx_gt, dy_gt)
                    epe_vals.append(epe)
                    werr_vals.append(werr)

                if epe_vals:
                    epe_accum.append(np.mean(epe_vals))
                    werr_accum.append(np.mean(werr_vals))

            # Compute and insert average metrics
            avg_epe = float(np.mean(epe_accum)) if epe_accum else np.nan
            avg_werr = float(np.mean(werr_accum)) if werr_accum else np.nan
            param_str = f"ps={ps},lv={lv},ws={ws},it={iters},pn={pn},psig={psig}"
            self.table.insert("", "end", values=(param_str, f"{avg_epe:.4f}", f"{avg_werr:.4f}"))

            self.status.config(text=f"Test {count}/{total}")
            self.update_idletasks()

        if skipped:
            self.status.config(text=f"Grid search complete ({skipped} unreadable sequence(s) skipped)")
        else:
            self.status.config(text="Grid search complete")

    def warp_image(self, img, dx, dy):
        """Backward-warp ``img``: ``out[y, x] = img[y + dy, x + dx]`` (bilinear, reflected borders)."""
        h, w = img.shape
        Y, X = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
        coords = [Y + dy, X + dx]
        return map_coordinates(img, coords, order=1, mode='reflect')


In [51]:
class MainApp(tk.Tk):
    """Top-level window that hosts every tool page as a tab in a ttk Notebook."""

    def __init__(self):
        super().__init__()
        self.title("OCT Deformation Tracking Suite")
        self.geometry("1200x800")

        theme = ttk.Style(self)
        theme.theme_use('clam')
        theme.configure('TNotebook.Tab', padding=(10, 10))

        tabs = ttk.Notebook(self)
        tabs.pack(fill='both', expand=True)

        # (page class, tab label) in the order the tabs appear
        page_specs = (
            (HomePage, "Home"),
            (PreprocessingPage, "Preprocessing"),
            (OpticalFlowPage, "Optical Flow"),
            (FeatureRegistrationPage, "Feature Reg."),
            (ImageSynthesisPage, "Image Synthesis"),
            (BenchmarkingPage, "Benchmarking"),
        )
        for page_cls, label in page_specs:
            tabs.add(page_cls(tabs), text=label)

if __name__ == "__main__":
    # build the window and hand control to the Tk event loop
    app = MainApp()
    app.mainloop()