# In [1]:
import geopandas as gpd
import rasterio
from PIL import Image, ImageTk, ImageDraw
import tkinter as tk
from tkinter import ttk, messagebox
import numpy as np
import os

# Locations of the annotation GeoPackage and of the source imagery.
GPKG_PATH = "/Users/ilseoplee/cape_town_annotation_checker/2.sampling_for_qc/2nd_sampled_part4_group2.gpkg"  # Adjust this path as needed
IMAGE_FOLDER = "/Users/ilseoplee/cape_town_annotation_checker/1.db_pipeline/download/images"
# QC results are written back to the same file that is read below.
OUTPUT_PATH = GPKG_PATH

gdf = gpd.read_file(GPKG_PATH)

# QC flag columns. Any column the file does not have yet is created and
# filled with 0; columns that already exist are left untouched.
qc_cols = [
    "PV_normal_qc",
    "PV_heater_qc",
    "PV_pool_qc",
    "PV_heater_mat_qc",
    "uncertflag_qc",
    "delete_qc",
    "resizing_qc",
]
for col in qc_cols:
    if col in gdf.columns:
        continue
    gdf[col] = 0

class QCChecker:
    """Tkinter window for stepping through annotations and recording QC flags.

    For the current row of the module-level ``gdf`` the window shows two
    crops of the row's image side by side: the plain crop and the same crop
    with the row's geometry outlined in red. One button per entry of
    ``qc_cols`` records a flag; every recorded decision rewrites
    ``OUTPUT_PATH`` and moves on to the next row.

    ``resizing_qc`` and ``uncertflag_qc`` are two-step flags: pressing one of
    them only arms it, and the next (non two-step) button pressed is stored
    together with it.

    ``self.index`` is always a *positional* row number into ``gdf``.
    """

    BASE_CROP = 300    # half-size (pixels) of the crop window at zoom_scale 1.0
    MAX_DISPLAY = 800  # crops wider/taller than this are shrunk for display
    ZOOM_STEP = 1.2    # factor applied to zoom_scale per zoom button press
    TWO_STEP_QCS = frozenset({"resizing_qc", "uncertflag_qc"})

    def __init__(self, master):
        """Build the widgets inside *master* and show the first unchecked row."""
        self.master = master
        self.index = self.find_next_unchecked_index()
        self.resizing_selected = False  # True while a two-step flag is armed
        self.selected_qcs = set()
        self.zoom_scale = 1.0
        self.history = []  # For back button
        # Last successfully rendered PIL images; None when nothing is shown.
        self.img_original_pil = None
        self.img_annotated_pil = None

        self.label_frame = tk.Frame(master)
        self.label_frame.pack()
        self.label_original = tk.Label(self.label_frame)
        self.label_original.pack(side=tk.LEFT, padx=5)
        self.label_annotated = tk.Label(self.label_frame)
        self.label_annotated.pack(side=tk.LEFT, padx=5)

        self.info = tk.Label(master, text="", font=("Arial", 12), justify="left")
        self.info.pack()

        button_frame = ttk.Frame(master)
        button_frame.pack()
        for i, col in enumerate(qc_cols):
            # c=col binds the current column name; a bare closure would make
            # every button report the last column.
            ttk.Button(button_frame, text=col, command=lambda c=col: self.mark(c)).grid(row=0, column=i, padx=4)

        ttk.Button(button_frame, text="Back", command=self.go_back).grid(row=1, column=0, columnspan=2, pady=5)

        zoom_frame = ttk.Frame(master)
        zoom_frame.pack(pady=5)
        ttk.Button(zoom_frame, text="Zoom +", command=self.zoom_in).pack(side=tk.LEFT, padx=5)
        ttk.Button(zoom_frame, text="Zoom -", command=self.zoom_out).pack(side=tk.LEFT, padx=5)

        self.load_image()

    def find_next_unchecked_index(self):
        """Return the position of the first row with no QC flag equal to 1.

        Returns 0 (after printing a notice) when every row already has a flag.
        """
        checked = (gdf[qc_cols] == 1).any(axis=1).to_numpy()
        unchecked = np.flatnonzero(~checked)
        if unchecked.size:
            return int(unchecked[0])
        print("All annotations already checked. Starting from the beginning.")
        return 0

    def mark(self, col):
        """Handle a press of the QC button for column *col*.

        A two-step flag is only armed (nothing is saved yet). Any other
        column is stored, together with an armed flag if there is one, and
        the view advances to the next row.
        """
        if col in self.TWO_STEP_QCS:
            self.selected_qcs = {col}
            self.resizing_selected = True

            if col == "uncertflag_qc":
                self.save_current_images_as_png()

            print(f"{col} selected. Please choose one more QC label.")
            return

        if self.resizing_selected:
            self.selected_qcs.add(col)
            if "uncertflag_qc" in self.selected_qcs:
                # Saved again at commit time: the view may have been zoomed
                # since the flag was armed.
                self.save_current_images_as_png()
        else:
            self.selected_qcs = {col}

        self.save_and_advance()
        self.advance_or_quit()

    def save_and_advance(self):
        """Write the current selection into ``gdf`` and save it to ``OUTPUT_PATH``.

        Every QC column of the current row is set to 1 if selected, else 0.
        A failed file write is reported on stdout and does not raise; the
        values stay in ``gdf`` in memory. The pending selection is cleared
        afterwards. Despite the name, moving to the next row is done by
        ``advance_or_quit``.
        """
        self.history.append(self.index)  # Save index for back
        for qc in qc_cols:
            # Positional write, so it targets the same row that iloc shows.
            gdf.iloc[self.index, gdf.columns.get_loc(qc)] = 1 if qc in self.selected_qcs else 0
        try:
            gdf.to_file(OUTPUT_PATH, driver="GPKG")
            print(f"Saved at ID {gdf.iloc[self.index].get('id', self.index)}")
        except Exception as e:
            print(f"Save failed: {e}")
        self._reset_selection()

    def advance_or_quit(self):
        """Show the next row, or announce completion and stop the Tk loop."""
        if self.index < len(gdf) - 1:
            self.index += 1
            self.load_image()
        else:
            messagebox.showinfo("Done", "QC annotations are complete!")
            self.master.quit()

    def go_back(self):
        """Return to the most recently saved row, if any."""
        if not self.history:
            messagebox.showinfo("Info", "No previous image to go back to.")
            return
        # A flag armed on the row being left must not leak onto the row we
        # return to.
        self._reset_selection()
        self.index = self.history.pop()
        print(f"Returning to index: {self.index}")
        self.load_image()

    def save_current_images_as_png(self, save_dir="saved_screens"):
        """Save both displayed images as PNG files inside *save_dir*.

        Files are named ``<id>_original.png`` and ``<id>_annotated.png``,
        where ``<id>`` is the row's ``id`` value or, if absent, the row
        position. Failures are printed, not raised.
        """
        os.makedirs(save_dir, exist_ok=True)
        row = gdf.iloc[self.index]
        image_id = row.get("id", self.index)

        original_path = os.path.join(save_dir, f"{image_id}_original.png")
        annotated_path = os.path.join(save_dir, f"{image_id}_annotated.png")

        try:
            self.img_original_pil.save(original_path, "PNG")
            self.img_annotated_pil.save(annotated_path, "PNG")
            print(f"Saved images for ID {image_id}")
        except Exception as e:
            print(f"Failed to save images for ID {image_id}: {e}")

    def zoom_in(self):
        """Shrink the crop window by ``ZOOM_STEP`` and redraw."""
        self.zoom_scale /= self.ZOOM_STEP
        self.load_image()

    def zoom_out(self):
        """Enlarge the crop window by ``ZOOM_STEP`` and redraw."""
        self.zoom_scale *= self.ZOOM_STEP
        self.load_image()

    def load_image(self):
        """Render the current row's crops and show them with the row's info.

        If the row has no usable ``image_name`` or rendering fails, the
        error is printed, both image labels are cleared and the info label
        shows the reason, so a previous row's picture is never left on
        screen for the new row.
        """
        row = gdf.iloc[self.index]
        image_name = row.get("image_name")
        if not isinstance(image_name, str) or not image_name:
            reason = "row has no image_name"
            print(f"Error loading image for row {self.index}: {reason}")
            self._show_load_error(row, image_name, reason)
            return

        image_path = os.path.join(IMAGE_FOLDER, image_name + ".tif")
        try:
            img_original, img_annotated = self._render_views(row, image_path)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            self._show_load_error(row, image_name, e)
            return

        self.img_original_pil = img_original
        self.img_annotated_pil = img_annotated

        # Tk does not keep PhotoImage objects alive; hold references on self
        # and on the labels.
        self.tk_img_original = ImageTk.PhotoImage(img_original)
        self.tk_img_annotated = ImageTk.PhotoImage(img_annotated)

        self.label_original.configure(image=self.tk_img_original)
        self.label_annotated.configure(image=self.tk_img_annotated)

        self.label_original.image = self.tk_img_original
        self.label_annotated.image = self.tk_img_annotated

        self.info.config(
            text=(
                f"ID: {row.get('id', 'NA')} | image: {image_name} | annotator: {row.get('annotator', 'NA')}\n"
                f"PV_normal: {row.get('PV_normal')}, "
                f"PV_heater: {row.get('PV_heater')}, "
                f"PV_pool: {row.get('PV_pool')}, "
                f"uncertflag: {row.get('uncertflag')}"
            )
        )

    def _reset_selection(self):
        """Forget any pending QC selection and disarm the two-step mode."""
        self.selected_qcs = set()
        self.resizing_selected = False

    def _show_load_error(self, row, image_name, reason):
        """Blank both image labels and report *reason* in the info label."""
        for label in (self.label_original, self.label_annotated):
            label.configure(image="")
            label.image = None
        self.tk_img_original = None
        self.tk_img_annotated = None
        # With these set to None, save_current_images_as_png reports a
        # failure instead of saving another row's picture.
        self.img_original_pil = None
        self.img_annotated_pil = None
        self.info.config(
            text=(
                f"ID: {row.get('id', 'NA')} | image: {image_name}\n"
                f"Could not load image: {reason}"
            )
        )

    def _render_views(self, row, image_path):
        """Return ``(original, annotated)`` PIL images for *row*.

        Reads bands 1-3 of *image_path* in a window centred on the geometry
        centroid, converts them to 8-bit, and outlines the geometry in red
        on the annotated copy. Raises on any read or geometry problem.
        """
        with rasterio.open(image_path) as src:
            geom = row.geometry
            centroid = geom.centroid
            # Inverse affine: map coordinates -> (column, row) pixel position.
            cx, cy = ~src.transform * (centroid.x, centroid.y)

            # At least one pixel, however far the user zooms in.
            half = max(1, int(self.BASE_CROP * self.zoom_scale))
            left, top, right, bottom = self._crop_box(cx, cy, half, src.width, src.height)
            if right <= left or bottom <= top:
                raise ValueError("annotation centroid falls outside the image")

            window = rasterio.windows.Window(
                col_off=left,
                row_off=top,
                width=right - left,
                height=bottom - top,
            )
            data = src.read([1, 2, 3], window=window)
            win_transform = src.window_transform(window)

        # (bands, rows, cols) -> (rows, cols, bands) as expected by PIL.
        rgb = self._to_uint8(np.transpose(data, (1, 2, 0)))

        img_original = Image.fromarray(rgb)
        img_annotated = img_original.copy()

        draw = ImageDraw.Draw(img_annotated)
        for part in self._polygon_parts(geom):
            pixels = []
            for x, y, *_ in part.exterior.coords:  # *_ tolerates a z value
                px, py = ~win_transform * (x, y)
                pixels.append((int(px), int(py)))
            if len(pixels) > 2:
                draw.polygon(pixels, outline="red", width=3)

        for img in (img_original, img_annotated):
            if img.width > self.MAX_DISPLAY or img.height > self.MAX_DISPLAY:
                img.thumbnail((self.MAX_DISPLAY, self.MAX_DISPLAY), Image.LANCZOS)

        return img_original, img_annotated

    @staticmethod
    def _crop_box(cx, cy, half, width, height):
        """Return ``(left, top, right, bottom)`` of a square clipped to the image.

        The square is centred on pixel ``(cx, cy)`` with half-size *half*
        and is clipped to ``[0, width] x [0, height]``.
        """
        return (
            max(0, int(cx - half)),
            max(0, int(cy - half)),
            min(width, int(cx + half)),
            min(height, int(cy + half)),
        )

    @staticmethod
    def _to_uint8(rgb):
        """Return *rgb* as ``uint8``, min-max scaling other dtypes to 0..254.

        NaNs are replaced first. ``uint8`` input is returned without
        scaling. Scaling is done in float64 so integer inputs cannot wrap.
        """
        rgb = np.nan_to_num(rgb)
        if rgb.dtype == np.uint8:
            return rgb
        values = rgb.astype(np.float64)
        low = values.min()
        span = values.max() - low
        # The epsilon keeps a constant image from dividing by zero.
        return ((values - low) / (span + 1e-6) * 255).astype(np.uint8)

    @staticmethod
    def _polygon_parts(geom):
        """Return the parts of *geom* that have an ``exterior`` ring.

        A geometry with its own ``exterior`` is returned as a single part;
        otherwise the members of ``geom.geoms`` (if any) are filtered.
        """
        if hasattr(geom, "exterior"):
            return [geom]
        return [part for part in getattr(geom, "geoms", ()) if hasattr(part, "exterior")]

# Start GUI: create the Tk root window, build the QC checker inside it
# (its constructor already loads the first image), then hand control to the
# Tk event loop.
root = tk.Tk()
root.title("Annotation QC Checker")
app = QCChecker(root)
root.mainloop()


# :