In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Authenticates this session with Kaggle via kagglehub so that the
# competition download in the next cell can run (presumably the competition
# files are not downloadable anonymously — confirm for your account).
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
#
# Downloads the competition files and stores the returned location in
# `the_blind_flight_synapse_drive_ps_1_path`.
# NOTE(review): the rest of the notebook reads images from hard-coded
# "/kaggle/input/the-blind-flight-synapse-drive-ps-1/..." paths rather than
# from this variable. Outside Kaggle the download location may differ —
# confirm those paths before running the remaining cells.

the_blind_flight_synapse_drive_ps_1_path = kagglehub.competition_download('the-blind-flight-synapse-drive-ps-1')

print('Data source import complete.')


In [None]:
# Source - https://stackoverflow.com/a
# Posted by Shaima' safaaldin Bahaaldin, modified by community. See post 'Timeline' for change history
# Retrieved 2025-12-20, License - CC BY-SA 4.0


!pip install "numpy<2"


In [None]:

import cv2
import numpy as np
import matplotlib.pyplot as plt

def order_points(pts):
    """Return the four corners of ``pts`` as float32 in TL, TR, BR, BL order.

    ``pts`` holds (x, y) rows. The top-left corner has the smallest x + y and
    the bottom-right the largest; the top-right has the smallest y - x and the
    bottom-left the largest.
    """
    coord_sum = pts.sum(axis=1)
    coord_diff = np.diff(pts, axis=1)  # y - x for each point
    return np.array(
        [
            pts[np.argmin(coord_sum)],   # top-left
            pts[np.argmin(coord_diff)],  # top-right
            pts[np.argmax(coord_sum)],   # bottom-right
            pts[np.argmax(coord_diff)],  # bottom-left
        ],
        dtype="float32",
    )

def intelligent_grid_slice(image_path, target_size=800, grid_n=20, crop=4,
                           min_area_frac=0.5):
    """Flatten a board image to a square and slice it into ``grid_n x grid_n`` cells.

    Looks for a large 4-sided contour (the board outline). If one covers more
    than ``min_area_frac`` of the image, its corners are perspective-warped to
    a ``target_size`` square. Otherwise the whole image is assumed to be the
    board and is simply resized.

    Args:
        image_path: Path (str or path-like) to the image.
        target_size: Side length in pixels of the flattened board.
        grid_n: Number of rows and of columns to slice into (default 20).
        crop: Pixels trimmed from every side of each cell to drop grid lines;
            0 disables the crop.
        min_area_frac: Minimum fraction of the image area that the outline
            contour must cover to be trusted.

    Returns:
        ``(warped, cells, mode)``:
        - ``warped`` is the flattened BGR board.
        - ``cells`` is an array of ``grid_n * grid_n`` crops in row-major order.
        - ``mode`` names the branch taken.
        Returns ``(None, None, "Error")`` if the image cannot be read.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        return None, None, "Error"

    h, w = img.shape[:2]
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # 1. Binarize with a local (adaptive) threshold so faint lines survive uneven lighting.
    blur = cv2.GaussianBlur(gray, (5, 5), 0)
    thresh = cv2.adaptiveThreshold(blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY_INV, 11, 2)

    # 2. Outer contours, largest first.
    cnts, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = sorted(cnts, key=cv2.contourArea, reverse=True)

    min_area = h * w * min_area_frac
    grid_contour = None
    for c in cnts:
        if cv2.contourArea(c) <= min_area:
            # Sorted by area descending: no remaining contour can qualify.
            break
        peri = cv2.arcLength(c, True)
        approx = cv2.approxPolyDP(c, 0.02 * peri, True)
        # The area requirement keeps us from latching onto a small square inside the board.
        if len(approx) == 4:
            grid_contour = approx
            break

    # 3. Flatten the board.
    if grid_contour is not None:
        # CASE A: outline found -> perspective de-skew onto a square.
        rect = order_points(grid_contour.reshape(4, 2))
        dst = np.array([
            [0, 0],
            [target_size - 1, 0],
            [target_size - 1, target_size - 1],
            [0, target_size - 1]], dtype="float32")
        M = cv2.getPerspectiveTransform(rect, dst)
        warped = cv2.warpPerspective(img, M, (target_size, target_size))
        mode = "Contour De-Skew"
    else:
        # CASE B: no outline (zoomed-in board) -> assume the image IS the grid.
        warped = cv2.resize(img, (target_size, target_size))
        mode = "Full Image Resize"

    # 4. Slice into equal cells, row-major.
    cell_h = target_size // grid_n
    cell_w = target_size // grid_n
    cells = []
    for y in range(grid_n):
        for x in range(grid_n):
            cell = warped[y * cell_h:(y + 1) * cell_h, x * cell_w:(x + 1) * cell_w]
            # Center crop removes the grid lines at the cell borders. The crop > 0 guard
            # matters because cell[0:-0] would be empty.
            if crop > 0 and cell.shape[0] > 2 * crop and cell.shape[1] > 2 * crop:
                cell = cell[crop:-crop, crop:-crop]
            cells.append(cell)

    return warped, np.array(cells), mode

# --- VISUALIZATION ---
# Quick single-image check without the full figure:
#   warped, cells, mode = intelligent_grid_slice("test_image.png")
#   print(f"Processing Mode: {mode}")
#   plt.imshow(warped)
def visualize_processing(image_path):
    """Plot the original image, the flattened board, and all slices re-stitched with gaps.

    Runs ``intelligent_grid_slice`` on ``image_path`` and draws three panels:
    - top-left: the original image;
    - top-right: the flattened board, titled with the processing mode;
    - bottom: every cell placed back in row-major order on a white canvas,
      with a 2 px gap between cells so slicing errors are easy to spot.
    """
    # 1. Run the slicer.
    warped, cells, mode = intelligent_grid_slice(image_path)
    if warped is None or len(cells) == 0:
        print("❌ Error processing image")
        return

    fig = plt.figure(figsize=(15, 10))

    # --- Top-left: original image ---
    ax1 = fig.add_subplot(2, 2, 1)
    orig_img = cv2.imread(str(image_path))
    ax1.imshow(cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB))
    ax1.set_title("Original Image", fontsize=14, fontweight='bold')
    ax1.axis('off')

    # --- Top-right: flattened board ---
    ax2 = fig.add_subplot(2, 2, 2)
    ax2.imshow(cv2.cvtColor(warped, cv2.COLOR_BGR2RGB))
    ax2.set_title(f"Processed Board ({mode})", fontsize=14, fontweight='bold', color='green')
    ax2.axis('off')

    # --- Bottom: cells stitched back together ---
    print(f"✅ Sliced into {len(cells)} cells. Displaying grid...")

    # The slicer returns a square grid in row-major order. Derive its side from
    # the cell count instead of assuming 20x20.
    grid_n = int(round(np.sqrt(len(cells))))
    cell_h, cell_w = cells[0].shape[:2]

    gap = 2
    canvas_h = grid_n * cell_h + (grid_n + 1) * gap
    canvas_w = grid_n * cell_w + (grid_n + 1) * gap
    grid_canvas = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)  # white background

    for idx, cell in enumerate(cells[:grid_n * grid_n]):
        row, col = divmod(idx, grid_n)
        y_start = gap + row * (cell_h + gap)
        x_start = gap + col * (cell_w + gap)
        grid_canvas[y_start:y_start + cell_h, x_start:x_start + cell_w] = cell

    ax3 = fig.add_subplot(2, 1, 2)
    ax3.imshow(cv2.cvtColor(grid_canvas, cv2.COLOR_BGR2RGB))
    ax3.set_title(f"Final {grid_n}x{grid_n} Sliced Grid (With padding)", fontsize=14, fontweight='bold')
    ax3.axis('off')

    plt.tight_layout()
    plt.show()

# --- USAGE ---
# Replace with a real path to test
TEST_IMG = "/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/test/images/0004.png"

visualize_processing(TEST_IMG)

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def order_points(pts):
    """
    Arrange four corner points as Top-Left, Top-Right, Bottom-Right, Bottom-Left.

    Corners are picked by extremes:
    - x + y is smallest at TL and largest at BR;
    - y - x is smallest at TR and largest at BL.
    The result is a (4, 2) float32 array.
    """
    totals = pts.sum(axis=1)
    skews = np.diff(pts, axis=1)  # second column minus first, i.e. y - x
    picks = (np.argmin(totals), np.argmin(skews), np.argmax(totals), np.argmax(skews))

    rect = np.zeros((4, 2), dtype="float32")
    for slot, source_index in enumerate(picks):
        rect[slot] = pts[source_index]
    return rect

def align_by_corners(image_path, target_size=800):
    """Square up a board image using the outermost corners of its grid-line mesh.

    Steps:
    1. Isolate long horizontal and vertical strokes with morphological opening,
       which removes terrain and other texture.
    2. Detect corners on that mesh with ``cv2.goodFeaturesToTrack``.
    3. Warp the four extreme corners (by x + y and y - x) onto a
       ``target_size`` square.

    Args:
        image_path: Path to the image (str or path-like).
        target_size: Side length in pixels of the output square.

    Returns:
        ``(warped, grid_mask, status)``. On failure ``warped`` and
        ``grid_mask`` are ``None`` and ``status`` describes the problem.
    """
    # 1. Load. str() so pathlib.Path inputs work, matching the other loaders in this notebook.
    img = cv2.imread(str(image_path))
    if img is None:
        return None, None, "Error"

    # 2. Isolate grid lines: keep only dark runs at least `scale` px long in either direction.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                   cv2.THRESH_BINARY_INV, 15, 5)
    scale = 20
    h_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (scale, 1))
    v_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, scale))
    mask_h = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, h_kernel)
    mask_v = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, v_kernel)
    grid_mask = cv2.addWeighted(mask_h, 0.5, mask_v, 0.5, 0)

    # 3. Corners on the clean mesh.
    # - At most 1000 corners are returned.
    # - qualityLevel=0.2 drops corners weaker than 20% of the strongest one.
    # - minDistance=20 keeps returned corners at least 20 px apart.
    corners = cv2.goodFeaturesToTrack(grid_mask, maxCorners=1000, qualityLevel=0.2, minDistance=20)
    if corners is None:
        return None, None, "No corners found"

    # Use the float coordinates directly. np.int0 is a deprecated alias removed
    # in NumPy 2, and order_points returns float32 anyway.
    corner_pts = corners.reshape(-1, 2)

    # 4. Extreme corners define the board quadrilateral.
    rect = order_points(corner_pts)
    # With fewer than 4 corners, or unlucky layouts, the sum/diff extremes can
    # pick the same point twice. A degenerate quad makes the warp meaningless.
    if len(np.unique(rect, axis=0)) < 4:
        return None, None, "Degenerate corner set"

    # 5. Perspective transform onto a target_size x target_size square.
    dst = np.array([
        [0, 0],
        [target_size - 1, 0],
        [target_size - 1, target_size - 1],
        [0, target_size - 1]], dtype="float32")
    M = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(img, M, (target_size, target_size))

    return warped, grid_mask, "Success"

# --- VISUALIZATION ---
def visualize_deskew(image_path):
    """Plot the original image, the morphological grid mask, and the de-skewed board side by side.

    If ``align_by_corners`` fails, its status message is printed instead.
    """
    warped, mask, status = align_by_corners(image_path)
    if warped is None:
        print(status)
        return

    fig, (ax_orig, ax_mask, ax_warp) = plt.subplots(1, 3, figsize=(18, 6))
    original_rgb = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)

    # (axis, image, extra imshow kwargs, title) for each panel, left to right.
    panels = (
        (ax_orig, original_rgb, {}, "Original"),
        (ax_mask, mask, {"cmap": "gray"}, "Morphological Grid Mask"),
        (ax_warp, cv2.cvtColor(warped, cv2.COLOR_BGR2RGB), {}, "Aligned & Reskewed"),
    )
    for ax, image, extra, title in panels:
        ax.imshow(image, **extra)
        ax.set_title(title)

    plt.show()

visualize_deskew("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/test/images/0009.png")

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

def get_morphological_mask(img):
    """Return a mask of the long horizontal and vertical grid lines in a BGR image.

    How it works:
    - Dark strokes are binarized with an inverted Gaussian adaptive threshold.
    - The result is opened with 20-px line kernels, so only runs at least that
      long survive.
    - The horizontal and vertical results are averaged 50/50.

    As a result, pixels on a single line family come out at about half
    intensity, crossings at 255, and everything else at 0.
    """
    binary = cv2.adaptiveThreshold(
        cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 15, 5,
    )

    line_len = 20
    horizontal = cv2.morphologyEx(
        binary, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_RECT, (line_len, 1))
    )
    vertical = cv2.morphologyEx(
        binary, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_RECT, (1, line_len))
    )
    return cv2.addWeighted(horizontal, 0.5, vertical, 0.5, 0)

def slice_from_mask(image_path, margin=2):
    """Cut an image into cell crops, using the gaps between its detected grid lines.

    The morphological grid mask is binarized and inverted so that cell
    interiors are white. Each outer white blob whose area lies between 1/900
    and 1/100 of the image is treated as a cell and cropped ``margin`` px
    inside its bounding box.

    Args:
        image_path: Path to the image (str or path-like).
        margin: Pixels trimmed from each side of a cell's bounding box to keep
            grid lines out of the crop.

    Returns:
        ``(cells, inverted_mask)``, with cells in row-major order (top-to-bottom,
        then left-to-right within each row). Returns ``(None, None)`` if the
        image cannot be read. A missing cell simply shortens its row, so list
        positions do not necessarily map onto fixed grid indices.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        return None, None

    # 1. Grid skeleton, binarized (lines white), then inverted (cells white).
    grid_mask = get_morphological_mask(img)
    _, binary_grid = cv2.threshold(grid_mask, 50, 255, cv2.THRESH_BINARY)
    inverted_mask = cv2.bitwise_not(binary_grid)

    # 2. Each outer white blob is a candidate cell.
    cnts, _ = cv2.findContours(inverted_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # 3. Keep blobs sized like a cell of a roughly 10x10 to 30x30 grid. The
    #    upper bound also rejects the background region around the board.
    img_area = img.shape[0] * img.shape[1]
    min_area = img_area / (30 * 30)
    max_area = img_area / (10 * 10)

    found = []  # (center_x, center_y, box_height, crop)
    for c in cnts:
        if not (min_area < cv2.contourArea(c) < max_area):
            continue
        x, y, w, h = cv2.boundingRect(c)
        if w <= 2 * margin or h <= 2 * margin:
            continue
        roi = img[y + margin:y + h - margin, x + margin:x + w - margin]
        found.append((x + w // 2, y + h // 2, h, roi))

    if not found:
        return [], inverted_mask

    # 4. Row-major order. Walk the cells top-to-bottom and start a new row when
    #    a center lies more than half a (median) cell height below the first
    #    cell of the current row. Then order each row by x.
    found.sort(key=lambda item: item[1])
    row_tol = np.median([item[2] for item in found]) / 2.0
    rows = [[found[0]]]
    for item in found[1:]:
        if item[1] - rows[-1][0][1] > row_tol:
            rows.append([item])
        else:
            rows[-1].append(item)

    cells = [it[3] for row in rows for it in sorted(row, key=lambda it: it[0])]
    return cells, inverted_mask

# --- VISUALIZATION ---
def visualize_blob_slice(image_path, max_preview=8):
    """Show the inverted grid mask next to a mosaic of the first extracted cells.

    Args:
        image_path: Image passed to ``slice_from_mask``.
        max_preview: Maximum side of the preview mosaic. At most
            ``max_preview ** 2`` cells are shown, each resized to 32x32.
    """
    cells, mask = slice_from_mask(image_path)

    # slice_from_mask returns (None, None) when the file cannot be read.
    if mask is None:
        print("❌ Could not read image")
        return
    if not cells:
        print("❌ No cells found!")
        return

    print(f"✅ Extracted {len(cells)} candidate cells.")

    fig = plt.figure(figsize=(12, 8))

    # Left: the mask used for extraction.
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(mask, cmap='gray')
    ax1.set_title("Inverted Grid Mask (White = Cell)")
    ax1.axis('off')

    # Right: a square mosaic of up to max_preview x max_preview cells, for a sanity check.
    ax2 = fig.add_subplot(1, 2, 2)
    preview_dim = min(int(np.sqrt(len(cells))), max_preview)
    shown = preview_dim * preview_dim
    cell_h, cell_w = 32, 32  # display size per cell
    canvas = np.zeros((preview_dim * cell_h, preview_dim * cell_w, 3), dtype=np.uint8)

    for idx, cell in enumerate(cells[:shown]):
        row, col = divmod(idx, preview_dim)
        canvas[row * cell_h:(row + 1) * cell_h, col * cell_w:(col + 1) * cell_w] = \
            cv2.resize(cell, (cell_w, cell_h))

    ax2.imshow(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
    ax2.set_title(f"Sample Extracted Cells (first {shown})")
    ax2.axis('off')

    plt.tight_layout()
    plt.show()

# Replace with your image

visualize_blob_slice("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/test/images/0009.png")

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from scipy.interpolate import Rbf

# ---------------- CONFIG ----------------
IMAGE_PATH = "/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/test/images/0004.png"  # CHANGE THIS
GRID_N = 20  # cells per side; the board has GRID_N + 1 lines in each direction

# ---------------- LOAD IMAGE ----------------
img = cv2.imread(IMAGE_PATH)
if img is None:
    # cv2.imread reports a missing or unreadable file by returning None, not by raising.
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
h, w = img.shape[:2]

# ---------------- EDGE + LINE DETECTION ----------------
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, 60, 150)

# Only long segments (>= 250 px) are kept, so most terrain edges are ignored.
lines = cv2.HoughLinesP(
    edges,
    rho=1,
    theta=np.pi/180,
    threshold=180,
    minLineLength=250,
    maxLineGap=20
)

if lines is None:
    raise RuntimeError("No grid lines detected")

# Split segments by dominant direction: |dx| > |dy| -> horizontal, otherwise vertical.
h_lines, v_lines = [], []
for x1, y1, x2, y2 in lines[:, 0]:
    if abs(y1 - y2) < abs(x1 - x2):
        h_lines.append((x1, y1, x2, y2))
    else:
        v_lines.append((x1, y1, x2, y2))

# ---------------- INTERSECTIONS ----------------
def intersect(l1, l2):
    """Intersection point of the infinite lines through two segments.

    Each segment is ``(x1, y1, x2, y2)``. Solves
    ``p1 + t * (p2 - p1) = p3 + s * (p4 - p3)`` for ``t``. Returns ``(x, y)``,
    or ``None`` when the lines are (near-)parallel.
    """
    ax, ay, bx, by = l1
    cx, cy, dx, dy = l2

    coeffs = np.array([[bx - ax, cx - dx],
                       [by - ay, cy - dy]])
    rhs = np.array([cx - ax, cy - ay])

    if abs(np.linalg.det(coeffs)) < 1e-6:
        return None

    t, _ = np.linalg.solve(coeffs, rhs)
    return ax + t * (bx - ax), ay + t * (by - ay)

pts = []
for hl in h_lines:
    for vl in v_lines:
        p = intersect(hl, vl)
        # intersect() returns None for parallel lines; keep only points inside the image.
        if p is not None and 0 <= p[0] < w and 0 <= p[1] < h:
            pts.append(p)

pts = np.array(pts)
print("Intersections:", len(pts))

# ---------------- CLUSTER POINTS INTO GRID ----------------
# A GRID_N x GRID_N board has (GRID_N + 1)^2 line intersections.
n_side = GRID_N + 1
K = n_side ** 2
if len(pts) < K:
    # KMeans cannot form more clusters than it has samples. Fail with an explicit message.
    raise RuntimeError(
        f"Only {len(pts)} intersections found; need at least {K} for a {GRID_N}x{GRID_N} grid"
    )

# Seeded subsample so re-running the cell picks the same points.
rng = np.random.default_rng(0)
pts = pts[rng.choice(len(pts), min(len(pts), K * 2), replace=False)]

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=K, n_init=10, random_state=0).fit(pts)
src = kmeans.cluster_centers_

# Sort row-major: order by y, cut into rows of n_side points, then order each row by x.
# A lexsort on float (x, y) is effectively a pure y-sort, because exact y ties
# never occur. Points within a row would then come out in y order, not x order.
src = src[np.argsort(src[:, 1])].reshape(n_side, n_side, 2)
src = np.concatenate([row[np.argsort(row[:, 0])] for row in src])

# ---------------- IDEAL GRID ----------------
xs = np.linspace(0, w-1, n_side)
ys = np.linspace(0, h-1, n_side)
dst = np.array([(x, y) for y in ys for x in xs])

# ---------------- TPS VIA RBF ----------------
# cv2.remap pulls pixels: output(x, y) = input(map_x(x, y), map_y(x, y)).
# The maps must therefore send ideal-grid (output) coordinates to detected
# (input) coordinates, so the TPS is fitted dst -> src.
rbf_x = Rbf(dst[:, 0], dst[:, 1], src[:, 0], function="thin_plate")
rbf_y = Rbf(dst[:, 0], dst[:, 1], src[:, 1], function="thin_plate")

# Evaluate in horizontal bands. Rbf builds an (n_eval x n_nodes) distance
# matrix, which for a whole image at once can take several GB.
ROW_BAND = 32
map_x = np.empty((h, w), dtype=np.float32)
map_y = np.empty((h, w), dtype=np.float32)
col_coords = np.arange(w)
for r0 in range(0, h, ROW_BAND):
    r1 = min(r0 + ROW_BAND, h)
    band_x, band_y = np.meshgrid(col_coords, np.arange(r0, r1))
    map_x[r0:r1] = rbf_x(band_x, band_y)
    map_y[r0:r1] = rbf_y(band_x, band_y)

warped = cv2.remap(img, map_x, map_y, interpolation=cv2.INTER_LINEAR)

# ---------------- VISUALIZE ----------------
plt.figure(figsize=(14,6))

plt.subplot(1,2,1)
plt.title("Original (Skewed)")
plt.imshow(img)
plt.scatter(src[:,0], src[:,1], s=5, c="red")
plt.axis("off")

plt.subplot(1,2,2)
plt.title("TPS Rectified (SciPy)")
plt.imshow(warped)
plt.axis("off")

plt.show()


In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN

def preprocess_for_grid(img):
    """
    Binarize a BGR image so dark grid lines become white (255) on black.

    A 7x7 median blur first flattens fine texture (e.g. sand or checkerboard
    patterns in Lab/Forest maps) while keeping strong edges. An inverted
    Gaussian adaptive threshold (block 19, C=5) then picks out the dark strokes.
    """
    smoothed = cv2.medianBlur(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 7)
    return cv2.adaptiveThreshold(
        smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 19, 5
    )

def get_intersections(img):
    """
    Locate grid-line crossings in a BGR image.

    Steps:
    - Horizontal and vertical strokes at least 25 px long are isolated by
      morphological opening; their overlap marks the crossings.
    - Crossing blobs larger than 10 px are reduced to centroids.
    - Centroids within 20 px of each other (DBSCAN) are averaged into one point.

    Returns an (N, 2) array of (x, y) points, or an empty array if none are found.
    """
    binary = preprocess_for_grid(img)

    line_len = 25  # slightly larger than elsewhere, for robustness
    horizontal = cv2.morphologyEx(
        binary, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_RECT, (line_len, 1))
    )
    vertical = cv2.morphologyEx(
        binary, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_RECT, (1, line_len))
    )

    # Dilate the crossings so each one forms a connected blob.
    crossings = cv2.dilate(cv2.bitwise_and(horizontal, vertical), np.ones((5, 5)))

    n_labels, _, stats, centroids = cv2.connectedComponentsWithStats(crossings)
    # Label 0 is the background; tiny blobs are noise.
    candidates = [centroids[label] for label in range(1, n_labels)
                  if stats[label, cv2.CC_STAT_AREA] > 10]
    if not candidates:
        return np.array([])

    # Merge double detections of the same crossing.
    candidates = np.array(candidates)
    groups = DBSCAN(eps=20, min_samples=1).fit(candidates).labels_
    return np.array([np.mean(candidates[groups == g], axis=0) for g in set(groups)])

def sort_points_robust(points):
    """
    Group (x, y) points into grid rows and order them.

    Rows come from a 1-D DBSCAN on the y coordinate (eps=25 px, at least 3
    points per row). Points labelled as noise are dropped.

    Returns a list of (k, 2) arrays:
    - rows are ordered top-to-bottom by mean y;
    - points within each row are ordered left-to-right.
    """
    labels = DBSCAN(eps=25, min_samples=3).fit(points[:, 1].reshape(-1, 1)).labels_

    rows_by_label = {}
    for pt, label in zip(points, labels):
        if label != -1:
            rows_by_label.setdefault(label, []).append(pt)

    def mean_y(label):
        return np.mean([p[1] for p in rows_by_label[label]])

    return [np.array(sorted(rows_by_label[label], key=lambda p: p[0]))
            for label in sorted(rows_by_label, key=mean_y)]

def reconstruct_grid_safe(image_path, cell_size=40, max_cell_width=80,
                          max_column_drift=60, fallback_size=800):
    """
    Rebuild a flattened board by warping each detected grid cell onto a canvas.

    Grid-line intersections are found and grouped into rows. A cell quad is
    formed by a point, its right-hand neighbour, and their x-closest matches in
    the next row. Each quad is perspective-warped to a ``cell_size`` square and
    pasted at (row index, position of the point within its detected row).
    NOTE: a missing point earlier in a row shifts the later cells of that row left.

    Main function with FALLBACK. Never returns None:
    - if the image cannot be read, a black ``fallback_size`` square is returned;
    - if detection is too weak, or anything raises, the image is resized to
      ``fallback_size``.

    Args:
        image_path: Path to the image (str or path-like).
        cell_size: Side length in pixels of every output cell.
        max_cell_width: Largest x gap (px) accepted between a point and its right
            neighbour. Larger gaps are treated as missing cells.
        max_column_drift: Largest x offset (px) accepted between a point and its
            match in the row below (skew limit).
        fallback_size: Side length of the fallback image.

    Returns:
        A BGR uint8 image.
    """
    # 1. Load
    img = cv2.imread(str(image_path))
    if img is None:
        print(f"❌ Error: Cannot read {image_path}")
        return np.zeros((fallback_size, fallback_size, 3), dtype=np.uint8)

    fallback_img = cv2.resize(img, (fallback_size, fallback_size))

    try:
        # 2. Detect
        points = get_intersections(img)
        if len(points) < 50:
            print(f"⚠️ Low detection ({len(points)} pts). Using fallback resize.")
            return fallback_img

        # 3. Organize into rows
        grid_rows = sort_points_robust(points)
        if len(grid_rows) < 10:  # too few rows for a usable grid
            return fallback_img

        # 4. Canvas: one cell per gap between consecutive rows/points. The median
        #    row length estimates the number of vertical lines.
        num_rows = len(grid_rows)
        num_cols = int(np.median([len(r) for r in grid_rows]))
        canvas_h = (num_rows - 1) * cell_size
        canvas_w = (num_cols - 1) * cell_size
        # Grey canvas so missing cells stay visible.
        final_grid = np.full((canvas_h, canvas_w, 3), 128, dtype=np.uint8)

        # The target square is the same for every cell.
        dst = np.array([[0, 0], [cell_size, 0], [cell_size, cell_size], [0, cell_size]],
                       dtype="float32")

        for r in range(num_rows - 1):
            row_top = grid_rows[r]
            row_btm = grid_rows[r + 1]
            y_pos = r * cell_size

            # Rows are sorted left-to-right, so the enumerate index is the
            # point's position within its detected row (used as the canvas column).
            for c_idx, pt_top in enumerate(row_top):
                # Bottom-left candidates: next-row points within the skew limit.
                candidates_bl = [p for p in row_btm if abs(p[0] - pt_top[0]) < max_column_drift]
                if not candidates_bl:
                    continue

                # Top-right: the immediate right neighbour in the same row.
                neighbors = [p for p in row_top if p[0] > pt_top[0]]
                if not neighbors:
                    continue
                pt_tr = min(neighbors, key=lambda p: p[0])
                if abs(pt_tr[0] - pt_top[0]) > max_cell_width:
                    continue  # gap in the grid: the neighbour belongs to a farther cell

                pt_bl = min(candidates_bl, key=lambda p: abs(p[0] - pt_top[0]))

                # Bottom-right: the next-row point matching the top-right corner.
                candidates_br = [p for p in row_btm if abs(p[0] - pt_tr[0]) < max_column_drift]
                if not candidates_br:
                    continue
                pt_br = min(candidates_br, key=lambda p: abs(p[0] - pt_tr[0]))

                x_pos = c_idx * cell_size
                if y_pos + cell_size > canvas_h or x_pos + cell_size > canvas_w:
                    continue  # row is longer than the median estimate: no room on the canvas

                src = np.array([pt_top, pt_tr, pt_br, pt_bl], dtype="float32")
                M = cv2.getPerspectiveTransform(src, dst)
                warped = cv2.warpPerspective(img, M, (cell_size, cell_size))

                # Trim a 2 px border, then restore the cell size before pasting.
                warped = cv2.resize(warped[2:-2, 2:-2], (cell_size, cell_size))
                final_grid[y_pos:y_pos + cell_size, x_pos:x_pos + cell_size] = warped

        return final_grid

    except Exception as e:
        # Deliberate best-effort: report the problem but still return an image.
        print(f"⚠️ Processing Error: {e}. Falling back.")
        return fallback_img

# --- ERROR-FREE VISUALIZATION ---
def visualize_safe(image_path):
    """Display the output of ``reconstruct_grid_safe`` for one image.

    ``reconstruct_grid_safe`` always returns an array. This function only
    guards against a degenerate (zero-sized) result before plotting.
    """
    result = reconstruct_grid_safe(image_path)

    # Check before creating the figure, so an empty result does not leave a
    # blank figure behind. .size also catches zero-width results, which would
    # crash cvtColor.
    if result.size == 0:
        print("Empty result")
        return

    plt.figure(figsize=(10, 10))
    plt.imshow(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
    plt.title("Grid Result (Safe Mode)")
    plt.axis('off')
    plt.show()

# Replace with your Lab image path
visualize_safe("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/test/images/0064.png")

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN
from pathlib import Path
import random
from tqdm.auto import tqdm

# --- CONFIG ---
# Test-split image folder, and how many of its images the batch report samples.
TEST_IMG_DIR = Path("/kaggle", "input", "the-blind-flight-synapse-drive-ps-1",
                    "SynapseDrive_Dataset", "test", "images")
NUM_TO_VISUALIZE = 50

# ==========================================
#    CORE GRID PROCESSING FUNCTIONS
# ==========================================

def preprocess_for_grid(img):
    """ Binary line mask: median-blurred grayscale, then an inverted Gaussian adaptive threshold.

    The median blur is what suppresses Lab/Forest texture while keeping the black lines.
    """
    blur_ksize, block_size, offset = 7, 19, 5
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    smooth = cv2.medianBlur(gray, blur_ksize)
    return cv2.adaptiveThreshold(smooth, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                 cv2.THRESH_BINARY_INV, block_size, offset)

def get_intersections(img):
    """ Return de-duplicated (x, y) centroids of grid-line crossings, or an empty array.

    Uses 25x3 / 3x25 opening kernels. Crossing blobs larger than 10 px become
    centroids, and centroids within 20 px of each other (DBSCAN) are averaged.
    """
    mask = preprocess_for_grid(img)
    line_len = 25

    # --- CHANGE 1: THICKER KERNELS (25x3) ---
    # Opening keeps only pixels covered by a full 25x3 (or 3x25) white rectangle.
    # NOTE(review): the intent was to tolerate slightly tilted lines, but a
    # 3-px-thick kernel also requires strokes at least 3 px thick — confirm.
    opened = {}
    for direction, ksize in (("h", (line_len, 3)), ("v", (3, line_len))):
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, ksize)
        opened[direction] = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    crossings = cv2.dilate(cv2.bitwise_and(opened["h"], opened["v"]), np.ones((5, 5)))

    count, _, stats, centroids = cv2.connectedComponentsWithStats(crossings)
    keep = [i for i in range(1, count) if stats[i, cv2.CC_STAT_AREA] > 10]
    if not keep:
        return np.array([])

    # DBSCAN clustering merges double detections of one crossing.
    pts = centroids[keep]
    labels = DBSCAN(eps=20, min_samples=1).fit(pts).labels_
    merged = []
    for label in set(labels):
        merged.append(np.mean(pts[labels == label], axis=0))
    return np.array(merged)

def sort_points_robust(points):
    """ Split points into rows and order them.

    Rows come from a 1-D DBSCAN on y (eps=25 px, at least 3 points per row);
    noise points are dropped. Rows are returned top-to-bottom by mean y, with
    points left-to-right within each row.
    """
    row_labels = DBSCAN(eps=25, min_samples=3).fit(points[:, 1].reshape(-1, 1)).labels_

    grouped = {}
    for label, pt in zip(row_labels, points):
        if label == -1:
            continue
        grouped.setdefault(label, []).append(pt)

    order = sorted(grouped, key=lambda lbl: np.mean([p[1] for p in grouped[lbl]]))
    final_rows = []
    for lbl in order:
        final_rows.append(np.array(sorted(grouped[lbl], key=lambda p: p[0])))
    return final_rows

def reconstruct_grid_safe_with_status(image_path, cell_size=40):
    """ Rebuild the board cell-by-cell and report which path was taken.

    Grid intersections are detected and grouped into rows. A cell quad is
    formed by a point, its right neighbour, and their closest matches in the
    next row. Each quad is warped to a ``cell_size`` square and pasted on a
    grey canvas; grey shows cells that could not be recovered.

    Returns:
        ``(image, status)``. ``status`` is one of:
        - starts with "Active": the grid was reconstructed;
        - "Load Error": ``image`` is a black 800x800 square;
        - starts with "Fallback": ``image`` is the input resized to 800x800.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        return np.zeros((800, 800, 3), dtype=np.uint8), "Load Error"

    fallback_img = cv2.resize(img, (800, 800))

    try:
        points = get_intersections(img)
        # VALIDATION THRESHOLDS
        if len(points) < 40:
            return fallback_img, f"Fallback: Low Pts ({len(points)})"

        grid_rows = sort_points_robust(points)
        if len(grid_rows) < 8:
            return fallback_img, f"Fallback: Few Rows ({len(grid_rows)})"

        num_rows = len(grid_rows)
        # Median row length estimates the number of vertical lines.
        num_cols_est = int(np.median([len(r) for r in grid_rows]))
        if num_cols_est < 8:
            return fallback_img, f"Fallback: Narrow Grid ({num_cols_est} cols)"

        # Column placement uses the detected grid's own x extent, not the full
        # image width. Boards that do not fill the frame then still line up, and
        # top-left corners land on integer column indices.
        all_x = np.concatenate([row[:, 0] for row in grid_rows])
        x_min = all_x.min()
        col_pitch = (all_x.max() - x_min) / (num_cols_est - 1)
        if col_pitch <= 0:
            return fallback_img, "Fallback: Degenerate Grid"

        canvas_h = (num_rows - 1) * cell_size
        canvas_w = (num_cols_est - 1) * cell_size
        final_grid = np.full((canvas_h, canvas_w, 3), 128, dtype=np.uint8)  # grey = gap
        dst = np.array([[0, 0], [cell_size, 0], [cell_size, cell_size], [0, cell_size]],
                       dtype="float32")
        cells_warped = 0

        for r in range(num_rows - 1):
            row_top, row_btm = grid_rows[r], grid_rows[r + 1]
            y_pos = r * cell_size

            for pt_top in row_top:
                # 1. Top-right: immediate right neighbour in the same row.
                neighbors_tr = [p for p in row_top if p[0] > pt_top[0]]
                if not neighbors_tr:
                    continue
                pt_tr = min(neighbors_tr, key=lambda p: p[0])
                if abs(pt_tr[0] - pt_top[0]) > 120:  # relaxed limit: larger gap = missing cell
                    continue

                # 2. Bottom-left: x-closest point of the next row.
                candidates_bl = [p for p in row_btm if abs(p[0] - pt_top[0]) < 70]
                if not candidates_bl:
                    continue
                pt_bl = min(candidates_bl, key=lambda p: abs(p[0] - pt_top[0]))

                # 3. Bottom-right: next-row point matching the top-right corner.
                candidates_br = [p for p in row_btm if abs(p[0] - pt_tr[0]) < 70]
                if not candidates_br:
                    continue
                pt_br = min(candidates_br, key=lambda p: abs(p[0] - pt_tr[0]))

                c_idx = int(round((pt_top[0] - x_min) / col_pitch))
                x_pos = c_idx * cell_size
                if y_pos + cell_size > canvas_h or x_pos + cell_size > canvas_w:
                    continue

                src = np.array([pt_top, pt_tr, pt_br, pt_bl], dtype="float32")
                M = cv2.getPerspectiveTransform(src, dst)
                # BORDER_REPLICATE fills any area outside the image with smeared
                # edge texture instead of black walls.
                warped = cv2.warpPerspective(img, M, (cell_size, cell_size),
                                             borderMode=cv2.BORDER_REPLICATE)
                warped = cv2.resize(warped[2:-2, 2:-2], (cell_size, cell_size))
                final_grid[y_pos:y_pos + cell_size, x_pos:x_pos + cell_size] = warped
                cells_warped += 1

        if cells_warped < 20:
            return fallback_img, "Fallback: Warp Failed"
        return final_grid, "Active: Gridded"

    except Exception as e:
        # Deliberate best-effort so the batch loop never crashes. Keep the error
        # type visible in the report instead of hiding it.
        return fallback_img, f"Fallback: Error ({type(e).__name__})"

# ==========================================
#    BATCH VISUALIZATION
# ==========================================

def visualize_batch_processing(num_images=None, pairs_per_row=5, seed=0):
    """Run ``reconstruct_grid_safe_with_status`` on a sample of test images and plot a report.

    Each sampled image is shown as an (original, processed) pair. The processed
    panel is titled with its status: green when the grid was reconstructed
    ("Active..."), orange-red for fallbacks.

    Args:
        num_images: How many images to sample; defaults to ``NUM_TO_VISUALIZE``.
        pairs_per_row: Image pairs per figure row.
        seed: Seed for the random sample so reruns show the same images.
            ``None`` draws a fresh sample each time.
    """
    if num_images is None:
        num_images = NUM_TO_VISUALIZE

    all_images = sorted(TEST_IMG_DIR.glob("*.png"))
    if not all_images:
        print("❌ No images found in directory.")
        return

    sample_paths = random.Random(seed).sample(all_images, min(num_images, len(all_images)))

    results = []
    print(f"Processing {len(sample_paths)} images...")
    for p in tqdm(sample_paths):
        orig = cv2.imread(str(p))
        processed, status = reconstruct_grid_safe_with_status(p)
        results.append((p.name, orig, processed, status))

    # Size the figure to the sample. A fixed 10x10 grid silently dropped images
    # beyond the 50th while still counting them in the title.
    n_rows = (len(results) + pairs_per_row - 1) // pairs_per_row
    fig, axes = plt.subplots(n_rows, 2 * pairs_per_row,
                             figsize=(4 * pairs_per_row, 2.5 * n_rows), squeeze=False)
    axes = axes.flatten()

    active_count = 0
    for i, (name, orig, proc, status) in enumerate(results):
        ax_orig, ax_proc = axes[2 * i], axes[2 * i + 1]

        # Original
        if orig is not None:
            ax_orig.imshow(cv2.cvtColor(orig, cv2.COLOR_BGR2RGB))
        ax_orig.set_title(f"{name}\nOriginal", fontsize=9)
        ax_orig.axis('off')

        # Processed, title colour-coded by status
        is_active = status.startswith("Active")
        active_count += int(is_active)
        ax_proc.imshow(cv2.cvtColor(proc, cv2.COLOR_BGR2RGB))
        ax_proc.set_title(status, fontsize=9, color='green' if is_active else 'orangered',
                          fontweight='bold')
        ax_proc.axis('off')

    # Hide unused axes in the last row.
    for ax in axes[2 * len(results):]:
        ax.axis('off')

    plt.tight_layout()
    plt.suptitle(f"Batch Processing Report: {active_count}/{len(results)} Active Grid Reconstruction", y=1.02, fontsize=16)
    plt.show()

# Run the report
visualize_batch_processing()

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN
from pathlib import Path
import random
from tqdm.auto import tqdm


# Test-set images processed by this cell (Kaggle competition input path).
TEST_IMG_DIR = Path("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/test/images")
# Side length (px) of each rectified cell tile on the reconstruction canvas.
CELL_SIZE = 40
# Upper bound on how many random test images visualize_batch() processes and plots.
NUM_TO_VISUALIZE = 50
def tiny_deskew(img):
    """Undo a global rotation estimated from near-horizontal Hough lines.

    Lines whose tilt is within 20 degrees of horizontal vote, and the skew is
    their median tilt. The input is returned unchanged when Hough finds no
    lines, none are near-horizontal, or the skew is under 3 degrees.
    Otherwise the image is rotated about its centre by the skew, with
    reflected borders and the same output size.
    """
    edge_map = cv2.Canny(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 50, 150)
    hough = cv2.HoughLines(edge_map, 1, np.pi / 180, 250)
    if hough is None:
        return img

    # theta is the angle of each line's normal; minus 90 deg gives the line's tilt.
    tilts = [(theta - np.pi / 2) * 180 / np.pi for _rho, theta in hough[:, 0]]
    near_horizontal = [t for t in tilts if abs(t) < 20]
    if not near_horizontal:
        return img

    skew = np.median(near_horizontal)
    if abs(skew) < 3:  # small tilts are deliberately left alone
        return img

    height, width = img.shape[:2]
    rotation = cv2.getRotationMatrix2D((width // 2, height // 2), skew, 1.0)
    return cv2.warpAffine(img, rotation, (width, height), borderMode=cv2.BORDER_REFLECT_101)
def preprocess_for_grid(img):
    """Binarise an image so dark grid lines become 255 foreground.

    Grayscale, then a 7x7 median blur (suppresses fine texture), then an
    inverted Gaussian adaptive threshold (block size 19, C=5).
    """
    smoothed = cv2.medianBlur(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 7)
    block_size, offset = 19, 5
    return cv2.adaptiveThreshold(smoothed, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                 cv2.THRESH_BINARY_INV, block_size, offset)
def get_intersections(img):
    """Locate grid-line crossings and return their (x, y) centroids.

    Horizontal and vertical strokes are isolated by morphological opening of
    the ``preprocess_for_grid`` mask with 25x5 and 5x25 rectangles. Opening
    removes strokes thinner than 5 px in the cross direction. The overlap of
    the two masks, dilated 5x5, gives one blob per crossing. Blobs of area
    <= 15 px are dropped, and centroids within 20 px are merged (DBSCAN,
    cluster mean).

    Returns an (N, 2) float array, or an empty 1-D array when nothing is found.
    """
    binary = preprocess_for_grid(img)
    line_len, line_thick = 25, 5
    # getStructuringElement takes (width, height).
    horiz = cv2.morphologyEx(binary, cv2.MORPH_OPEN,
                             cv2.getStructuringElement(cv2.MORPH_RECT, (line_len, line_thick)))
    vert = cv2.morphologyEx(binary, cv2.MORPH_OPEN,
                            cv2.getStructuringElement(cv2.MORPH_RECT, (line_thick, line_len)))
    crossings = cv2.dilate(cv2.bitwise_and(horiz, vert), np.ones((5, 5)))

    n_labels, _, stats, centroids = cv2.connectedComponentsWithStats(crossings)
    # Label 0 is the background component.
    kept = [centroids[k] for k in range(1, n_labels) if stats[k, cv2.CC_STAT_AREA] > 15]
    if not kept:
        return np.array([])

    kept = np.array(kept)
    labels = DBSCAN(eps=20, min_samples=1).fit(kept).labels_
    return np.array([kept[labels == lbl].mean(axis=0) for lbl in set(labels)])
def sort_points_robust(points):
    """Group intersection points into rows, top to bottom, each sorted left to right.

    Rows are DBSCAN clusters of the y coordinate (eps=30, min_samples=3).
    Points that DBSCAN marks as noise (-1) are discarded.
    """
    row_ids = DBSCAN(eps=30, min_samples=3).fit(points[:, 1:2]).labels_

    grouped = {}
    for point, row_id in zip(points, row_ids):
        if row_id != -1:
            grouped.setdefault(row_id, []).append(point)

    def mean_y(row):
        return np.mean([pt[1] for pt in row])

    ordered = sorted(grouped.values(), key=mean_y)
    return [np.array(sorted(row, key=lambda pt: pt[0])) for row in ordered]
def reconstruct_grid_safe(img):
    """Rectify each detected grid cell of ``img`` onto a regular canvas.

    Steps: deskew, detect line intersections, group them into rows. Then each
    (top-left, top-right, bottom-right, bottom-left) quad of neighbouring
    intersections is warped to a CELL_SIZE x CELL_SIZE tile on a grey canvas.

    Args:
        img: BGR image as returned by ``cv2.imread``. May be ``None`` when the
            file could not be read.

    Returns:
        (image, status):
          - on success, the canvas and "Active: Gridded";
          - on failure, an 800x800 resize of the deskewed input and a status
            starting with "Fallback:";
          - for an unreadable input, an 800x800 black image and "Load Error".
    """
    if img is None:
        # cv2.imread returns None for unreadable files; cvtColor in tiny_deskew would raise.
        return np.zeros((800, 800, 3), dtype=np.uint8), "Load Error"

    img = tiny_deskew(img)
    fallback = cv2.resize(img, (800, 800))

    points = get_intersections(img)
    if len(points) < 40:
        return fallback, "Fallback: Low intersections"

    rows = sort_points_robust(points)
    if len(rows) < 8:
        return fallback, "Fallback: Few rows"

    num_rows = len(rows)
    num_cols = int(np.median([len(r) for r in rows]))
    if num_cols < 8:
        return fallback, "Fallback: Few cols"

    # Grey (128) background so cells that were not warped show up as gaps.
    canvas = np.full(((num_rows - 1) * CELL_SIZE, (num_cols - 1) * CELL_SIZE, 3), 128, dtype=np.uint8)
    # Target corners (TL, TR, BR, BL) are the same for every cell.
    dst = np.array([[0, 0], [CELL_SIZE, 0], [CELL_SIZE, CELL_SIZE], [0, CELL_SIZE]], dtype=np.float32)

    warped = 0
    w = img.shape[1]
    col_width = w / num_cols  # approximate column pitch used to choose the canvas column

    for r in range(num_rows - 1):
        top, bottom = rows[r], rows[r + 1]

        for pt_tl in top:
            # Top-right corner: nearest intersection to the right in the same row.
            rights = [p for p in top if p[0] > pt_tl[0]]
            if not rights:
                continue
            pt_tr = min(rights, key=lambda p: p[0])

            # Tighter skew limit: reject cells wider than 90 px (presumably a missed intersection).
            if abs(pt_tr[0] - pt_tl[0]) > 90:
                continue

            # Bottom corners: closest x match in the next row, within 70 px.
            bls = [p for p in bottom if abs(p[0] - pt_tl[0]) < 70]
            brs = [p for p in bottom if abs(p[0] - pt_tr[0]) < 70]
            if not bls or not brs:
                continue
            pt_bl = min(bls, key=lambda p: abs(p[0] - pt_tl[0]))
            pt_br = min(brs, key=lambda p: abs(p[0] - pt_tr[0]))

            src = np.array([pt_tl, pt_tr, pt_br, pt_bl], dtype=np.float32)

            # Quad safety: skip tiny or non-convex (mismatched-corner) quads.
            if cv2.contourArea(src) < 150 or not cv2.isContourConvex(src):
                continue

            M = cv2.getPerspectiveTransform(src, dst)
            cell = cv2.warpPerspective(img, M, (CELL_SIZE, CELL_SIZE), borderMode=cv2.BORDER_REPLICATE)

            c_idx = int(pt_tl[0] / col_width)
            y, x = r * CELL_SIZE, c_idx * CELL_SIZE
            if y + CELL_SIZE <= canvas.shape[0] and x + CELL_SIZE <= canvas.shape[1]:
                canvas[y:y + CELL_SIZE, x:x + CELL_SIZE] = cell
                warped += 1

    if warped < 20:
        return fallback, "Fallback: Warp failed"
    return canvas, "Active: Gridded"
def visualize_batch():
    """Plot original vs. reconstructed grid for a random sample of test images.

    Samples up to NUM_TO_VISUALIZE PNGs from TEST_IMG_DIR and draws one figure
    row per image. The original is on the left. The output of
    ``reconstruct_grid_safe`` is on the right, titled with its status (green
    when it contains "Active", red otherwise).
    """
    imgs = sorted(TEST_IMG_DIR.glob("*.png"))  # sorted so a seeded `random` is reproducible
    if not imgs:
        print("No images found in directory.")
        return
    samples = random.sample(imgs, min(NUM_TO_VISUALIZE, len(imgs)))

    # squeeze=False keeps `axes` 2-D when only one image is sampled.
    fig, axes = plt.subplots(len(samples), 2, figsize=(10, 3 * len(samples)), squeeze=False)

    for i, p in enumerate(tqdm(samples)):
        img = cv2.imread(str(p))
        out, status = reconstruct_grid_safe(img)

        ax_orig, ax_out = axes[i, 0], axes[i, 1]
        if img is not None:  # cv2.imread returns None for unreadable files
            ax_orig.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax_orig.set_title("Original")
        ax_orig.axis("off")

        ax_out.imshow(cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
        ax_out.set_title(status, color="green" if "Active" in status else "red")
        ax_out.axis("off")

    plt.tight_layout()
    plt.show()
# Render the per-image original / reconstruction report.
visualize_batch()



In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN
from pathlib import Path
import random
from tqdm.auto import tqdm

# --- CONFIG ---
# Test-set images sampled for the batch report (Kaggle competition input path).
TEST_IMG_DIR = Path("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/test/images")
# Upper bound on how many random test images are processed and plotted.
NUM_TO_VISUALIZE = 50

# ==========================================
#    CORE GRID PROCESSING FUNCTIONS
# ==========================================

def preprocess_for_grid(img):
    """Return an inverted binary map in which dark grid lines are 255.

    A 7x7 median blur is meant to remove fine texture (e.g. lab/forest tiles)
    while keeping the black grid lines. A Gaussian adaptive threshold
    (block 19, C=5, inverted) then turns those lines into foreground.
    """
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    denoised = cv2.medianBlur(grayscale, 7)  # median filtering suppresses texture noise
    return cv2.adaptiveThreshold(
        denoised,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        19,
        5,
    )

def get_intersections(img):
    """Find grid-line crossing points in ``img``.

    Morphological opening of the ``preprocess_for_grid`` mask with 25x1 and
    1x25 rectangles keeps horizontal and vertical runs at least 25 px long.
    Where both overlap (dilated 5x5) is treated as a crossing. Components
    larger than 10 px contribute their centroid. Centroids closer than 20 px
    are merged by DBSCAN into their mean, which removes double detections.

    Returns:
        (N, 2) array of (x, y) points, or an empty 1-D array if none are found.
    """
    mask = preprocess_for_grid(img)
    run_length = 25

    def keep_runs(kernel_size):
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, kernel_size)
        return cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

    horizontal = keep_runs((run_length, 1))
    vertical = keep_runs((1, run_length))
    overlap = cv2.dilate(cv2.bitwise_and(horizontal, vertical), np.ones((5, 5)))

    count, _, stats, centroids = cv2.connectedComponentsWithStats(overlap)
    candidates = [
        centroids[idx]
        for idx in range(1, count)  # skip background label 0
        if stats[idx, cv2.CC_STAT_AREA] > 10
    ]
    if not candidates:
        return np.array([])

    candidates = np.array(candidates)
    cluster_ids = DBSCAN(eps=20, min_samples=1).fit(candidates).labels_
    merged = []
    for cid in set(cluster_ids):
        merged.append(np.mean(candidates[cluster_ids == cid], axis=0))
    return np.array(merged)

def sort_points_robust(points):
    """Split points into rows by y; return rows top-to-bottom, each sorted by x.

    Rows are 1-D DBSCAN clusters of the y coordinates (eps=25 px,
    min_samples=3). Points labelled as noise (-1) are dropped. Assumes
    ``points`` is an (N, 2) array of (x, y); each returned row is an array
    of such points.
    """
    y_labels = DBSCAN(eps=25, min_samples=3).fit(points[:, 1].reshape(-1, 1)).labels_

    rows_by_label = {}
    for idx, label in enumerate(y_labels):
        if label == -1:
            continue
        rows_by_label.setdefault(label, []).append(points[idx])

    row_order = sorted(rows_by_label, key=lambda lbl: np.mean([p[1] for p in rows_by_label[lbl]]))
    return [np.array(sorted(rows_by_label[lbl], key=lambda p: p[0])) for lbl in row_order]

def reconstruct_grid_safe_with_status(image_path, cell_size=40):
    """Rectify the grid cells of an image file onto a regular canvas.

    Detects grid-line intersections, groups them into rows, and warps each
    quad of neighbouring intersections to a ``cell_size`` square. A cell's
    column slot is estimated from its left x as a fraction of image width.

    Args:
        image_path: Path (or str) of the image, read with ``cv2.imread``.
        cell_size: Side length in px of each rectified cell.

    Returns:
        (processed_image, status_string):
          - the canvas (grey 128 where nothing was placed) and "Active: Gridded";
          - an 800x800 resize of the input and "Fallback: <reason>" when too
            few intersections/rows/columns/cells were found or an exception
            was raised (the exception type is named in the reason);
          - an 800x800 black image and "Load Error" if the file cannot be read.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        return np.zeros((800, 800, 3), dtype=np.uint8), "Load Error"

    w = img.shape[1]
    fallback_img = cv2.resize(img, (800, 800))
    # Target corners (TL, TR, BR, BL) are identical for every cell.
    dst = np.array([[0, 0], [cell_size, 0], [cell_size, cell_size], [0, cell_size]], dtype="float32")

    try:
        points = get_intersections(img)
        # VALIDATION THRESHOLDS
        if len(points) < 40:
            return fallback_img, f"Fallback: Low Pts ({len(points)})"

        grid_rows = sort_points_robust(points)
        num_rows = len(grid_rows)
        if num_rows < 8:
            return fallback_img, f"Fallback: Few Rows ({num_rows})"

        # Median row length estimates grid width (>= 8 rows here, so never empty).
        num_cols_est = int(np.median([len(r) for r in grid_rows]))
        if num_cols_est < 8:
            return fallback_img, f"Fallback: Narrow Grid ({num_cols_est} cols)"

        canvas_h = (num_rows - 1) * cell_size
        canvas_w = (num_cols_est - 1) * cell_size
        # Grey background to show gaps
        final_grid = np.ones((canvas_h, canvas_w, 3), dtype=np.uint8) * 128
        cells_warped = 0

        for r in range(num_rows - 1):
            row_top = grid_rows[r]
            row_btm = grid_rows[r + 1]

            for pt_top in row_top:
                # 1. TR corner: nearest point to the right in the same row
                neighbors_tr = [p for p in row_top if p[0] > pt_top[0]]
                if not neighbors_tr:
                    continue
                pt_tr = min(neighbors_tr, key=lambda p: p[0])
                if abs(pt_tr[0] - pt_top[0]) > 100:
                    continue  # gap too big

                # 2. BL corner: closest x match in the bottom row
                candidates_bl = [p for p in row_btm if abs(p[0] - pt_top[0]) < 70]
                if not candidates_bl:
                    continue
                pt_bl = min(candidates_bl, key=lambda p: abs(p[0] - pt_top[0]))

                # 3. BR corner: closest x match to TR in the bottom row
                candidates_br = [p for p in row_btm if abs(p[0] - pt_tr[0]) < 70]
                if not candidates_br:
                    continue
                pt_br = min(candidates_br, key=lambda p: abs(p[0] - pt_tr[0]))

                src = np.array([pt_top, pt_tr, pt_br, pt_bl], dtype="float32")
                M = cv2.getPerspectiveTransform(src, dst)
                warped = cv2.warpPerspective(img, M, (cell_size, cell_size))

                # Trim a 2 px border (presumably grid-line remnants) and scale back up.
                warped = cv2.resize(warped[2:-2, 2:-2], (cell_size, cell_size))

                # Column slot from the left x relative to image width
                c_idx = int(pt_top[0] / (w / num_cols_est))
                y_pos, x_pos = r * cell_size, c_idx * cell_size

                if y_pos + cell_size <= canvas_h and x_pos + cell_size <= canvas_w:
                    final_grid[y_pos:y_pos + cell_size, x_pos:x_pos + cell_size] = warped
                    cells_warped += 1

        if cells_warped < 20:
            return fallback_img, "Fallback: Warp Failed"
        return final_grid, "Active: Gridded"

    except Exception as e:
        # Still fall back so a batch keeps running, but report what failed
        # instead of silently discarding the exception.
        return fallback_img, f"Fallback: Error ({type(e).__name__})"

# ==========================================
#    BATCH VISUALIZATION
# ==========================================

def visualize_batch_processing():
    """Plot a side-by-side report of grid reconstruction on random test images.

    Samples up to ``NUM_TO_VISUALIZE`` PNGs from ``TEST_IMG_DIR``, runs
    ``reconstruct_grid_safe_with_status`` on each and lays out
    (original, processed) pairs, five pairs per figure row. Each processed
    panel is titled with its status string: green if it starts with
    "Active", orange-red otherwise. The suptitle reports the Active count.
    """
    all_images = sorted(TEST_IMG_DIR.glob("*.png"))
    if not all_images:
        print("❌ No images found in directory.")
        return

    # Select random sample (listing is sorted, so a seeded `random` is reproducible)
    sample_paths = random.sample(all_images, min(NUM_TO_VISUALIZE, len(all_images)))

    results = []
    print(f"Processing {len(sample_paths)} images...")
    for p in tqdm(sample_paths):
        orig = cv2.imread(str(p))
        processed, status = reconstruct_grid_safe_with_status(p)
        results.append((p.name, orig, processed, status))

    # Five (original, processed) pairs per row; size the grid to the results so
    # NUM_TO_VISUALIZE > 50 no longer silently drops images.
    pairs_per_row = 5
    total_cols = pairs_per_row * 2
    rows = -(-len(results) // pairs_per_row)  # ceiling division

    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(rows, total_cols, figsize=(20, 2.5 * rows), squeeze=False)
    axes = axes.flatten()

    active_count = 0
    for i, (name, orig, proc, status) in enumerate(results):
        ax_orig, ax_proc = axes[2 * i], axes[2 * i + 1]

        # Original (None when cv2.imread could not read the file)
        if orig is not None:
            ax_orig.imshow(cv2.cvtColor(orig, cv2.COLOR_BGR2RGB))
        ax_orig.set_title(f"{name}\nOriginal", fontsize=9)
        ax_orig.axis('off')

        # Processed, colour-coded by status
        is_active = status.startswith("Active")
        active_count += int(is_active)
        ax_proc.imshow(cv2.cvtColor(proc, cv2.COLOR_BGR2RGB))
        ax_proc.set_title(status, fontsize=9, color='green' if is_active else 'orangered',
                          fontweight='bold')
        ax_proc.axis('off')

    # Hide unused axes in the last row
    for ax in axes[2 * len(results):]:
        ax.axis('off')

    plt.tight_layout()
    plt.suptitle(f"Batch Processing Report: {active_count}/{len(results)} Active Grid Reconstruction", y=1.02, fontsize=16)
    plt.show()

# Run the report: sample test images, reconstruct their grids and plot the results.
visualize_batch_processing()

In [None]:
import cv2
import numpy as np
import os
from pathlib import Path
from tqdm.auto import tqdm
import shutil

# --- CONFIG ---
TRAIN_IMG_DIR = Path("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/train/images")
TRAIN_MASK_DIR = Path("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/train/masks")
OUTPUT_DATASET_DIR = Path("./cnn_dataset")

# Output classes, one sub-folder each, keyed off ground-truth mask colours:
#   road  - Lab/Road floor, RGB around (128,128,128); also the default label
#   wall  - black
#   mud   - Mud/Forest/Sand texture (mask colour not yet identified)
#   start - green
#   goal  - red
for category in ("road", "wall", "mud", "start", "goal"):
    (OUTPUT_DATASET_DIR / category).mkdir(parents=True, exist_ok=True)

# Reuse your BEST Slicer Function
def get_grid_cells_and_masks(img, mask, cell_size=40, grid_size=20, crop=3):
    """Slice an image and its mask into aligned grid cells.

    Both inputs are resized to ``grid_size * cell_size`` px square. The image
    uses the default interpolation; the mask uses nearest-neighbour so its
    label colours stay exact. Both are then cut into ``grid_size x grid_size``
    cells in row-major order, and ``crop`` px are trimmed from every side of
    each cell to drop grid-line pixels.

    This assumes the grid exactly fills the image; it does not detect grid
    lines. The defaults (40 px, 20x20, 800 px) reproduce the original
    hard-coded behaviour.

    Args:
        img: BGR image.
        mask: Ground-truth mask aligned with ``img``.
        cell_size: Side length in px of one cell after resizing.
        grid_size: Number of cells per side.
        crop: Border in px removed from each cell side (0 keeps the whole cell).

    Returns:
        List of ``grid_size ** 2`` tuples (cell_img, cell_mask), each
        ``cell_size - 2 * crop`` px square.
    """
    side = grid_size * cell_size
    img_r = cv2.resize(img, (side, side))
    mask_r = cv2.resize(mask, (side, side), interpolation=cv2.INTER_NEAREST)

    cells = []
    for r in range(grid_size):
        for c in range(grid_size):
            # Explicit end indices (not `-crop`) so crop=0 keeps the full cell.
            y1, y2 = r * cell_size + crop, (r + 1) * cell_size - crop
            x1, x2 = c * cell_size + crop, (c + 1) * cell_size - crop
            cells.append((img_r[y1:y2, x1:x2], mask_r[y1:y2, x1:x2]))

    return cells

def _count_color_pixels(mask_cell, lower, upper):
    """Count pixels whose every channel lies in [lower, upper], inclusive like cv2.inRange."""
    in_range = (mask_cell >= np.array(lower)) & (mask_cell <= np.array(upper))
    return int(np.count_nonzero(np.all(in_range, axis=-1)))


def identify_label(mask_cell, min_marker_pixels=10, wall_fraction=0.5):
    """Classify a ground-truth mask cell as "goal", "start", "wall" or "road".

    Rules, checked in priority order:
      1. goal  - more than ``min_marker_pixels`` red pixels, BGR in (0-50, 0-50, 200-255)
      2. start - more than ``min_marker_pixels`` green pixels, BGR in (0-50, 200-255, 0-50)
      3. wall  - more than ``wall_fraction`` of pixels have every channel < 30
      4. road  - everything else (default)

    "mud" is never returned: the mask colour for mud/texture is not known yet,
    so such cells fall through to "road".

    Args:
        mask_cell: Mask patch of shape (H, W, 3) in BGR order, as loaded by cv2.imread.
        min_marker_pixels: Marker-colour pixel count that must be exceeded.
        wall_fraction: Fraction of near-black pixels that must be exceeded.
    """
    # Count pixels rather than summing a 0/255 mask: with a sum, one stray
    # marker pixel (255 > 10) was enough to label the whole cell.
    if _count_color_pixels(mask_cell, (0, 0, 200), (50, 50, 255)) > min_marker_pixels:
        return "goal"
    if _count_color_pixels(mask_cell, (0, 200, 0), (50, 255, 50)) > min_marker_pixels:
        return "start"

    # Walls are (near) black in the mask.
    black_pixels = np.sum(np.all(mask_cell < 30, axis=-1))
    total_pixels = mask_cell.shape[0] * mask_cell.shape[1]
    if black_pixels > total_pixels * wall_fraction:
        return "wall"

    return "road"  # Default

# --- EXECUTION ---
img_paths = sorted(TRAIN_IMG_DIR.glob("*.png"))
mask_paths = sorted(TRAIN_MASK_DIR.glob("*.png"))

# Images and masks are paired by sorted position; if the folders differ in
# size, every pair after the first gap gets the wrong mask.
if len(img_paths) != len(mask_paths):
    print(f"Warning: {len(img_paths)} images but {len(mask_paths)} masks; "
          "image/mask pairing by sorted order may be misaligned.")

print(f"Generating dataset from {len(img_paths)} images...")

count = 0
skipped = 0
for img_p, mask_p in tqdm(zip(img_paths, mask_paths), total=min(len(img_paths), len(mask_paths))):
    img = cv2.imread(str(img_p))
    mask = cv2.imread(str(mask_p))
    if img is None or mask is None:
        # cv2.imread returns None for unreadable files; cv2.resize would raise.
        skipped += 1
        continue

    # Simple resize-and-slice (not the robust slicer) so image cells line up
    # exactly with mask cells for labelling.
    cell_data = get_grid_cells_and_masks(img, mask)

    for i, (c_img, c_mask) in enumerate(cell_data):
        label = identify_label(c_mask)

        filename = f"{img_p.stem}_cell_{i}.png"
        save_path = OUTPUT_DATASET_DIR / label / filename

        # Only save valid chunks
        if c_img.shape[0] > 10 and c_img.shape[1] > 10:
            cv2.imwrite(str(save_path), c_img)
            count += 1

print(f"✅ Generated {count} labeled cell images!")
if skipped:
    print(f"Skipped {skipped} unreadable image/mask pairs.")

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import DBSCAN
from pathlib import Path
from tqdm.auto import tqdm

# --- CONFIG ---
# Test-set images (Kaggle competition input path).
TEST_IMG_DIR = Path("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/test/images")
CELL_SIZE = 64         # Side (px) of each warped cell; fixed size for CNN input
# Capacity of the per-image cell tensor: extract_grid_data does not store
# rows/columns beyond these limits.
MAX_ROWS = 20
MAX_COLS = 20

# ==========================================
#  CORE HELPERS (Unchanged)
# ==========================================
def preprocess_for_grid(img):
    """Turn a BGR image into a binary mask in which dark grid lines are 255.

    A median blur (size 7) flattens texture. An inverted Gaussian adaptive
    threshold (19 px neighbourhood, C=5) then marks pixels that are at least
    5 levels below their weighted local mean.
    """
    blurred = cv2.medianBlur(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 7)
    method, kind = cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV
    return cv2.adaptiveThreshold(blurred, 255, method, kind, 19, 5)

def get_intersections(img):
    """Detect grid-line intersections and return de-duplicated (x, y) centroids.

    Pipeline:
      1. Open the ``preprocess_for_grid`` mask with 25x1 and 1x25 rectangles
         to keep only long horizontal and vertical strokes.
      2. AND the two masks and dilate 5x5.
      3. Take the centroid of every connected blob larger than 10 px.
      4. Merge centroids within 20 px via DBSCAN (cluster mean).

    Returns an (N, 2) array, or ``np.array([])`` when no blob qualifies.
    """
    binary = preprocess_for_grid(img)
    rect = cv2.MORPH_RECT
    h_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, cv2.getStructuringElement(rect, (25, 1)))
    v_lines = cv2.morphologyEx(binary, cv2.MORPH_OPEN, cv2.getStructuringElement(rect, (1, 25)))
    blobs = cv2.dilate(cv2.bitwise_and(h_lines, v_lines), np.ones((5, 5)))

    n_blobs, _, stats, centroids = cv2.connectedComponentsWithStats(blobs)
    # Index 0 is the background component.
    found = [centroids[j] for j in range(1, n_blobs) if stats[j, cv2.CC_STAT_AREA] > 10]
    if len(found) == 0:
        return np.array([])

    found = np.array(found)
    groups = DBSCAN(eps=20, min_samples=1).fit(found).labels_
    return np.array([np.mean(found[groups == g], axis=0) for g in set(groups)])

def sort_points_robust(points):
    """Cluster points into horizontal rows and order them for reading.

    The y coordinates are clustered with 1-D DBSCAN (eps=25, min_samples=3),
    and noise points (label -1) are discarded. Rows are ordered by mean y
    (top first), and the points inside a row by x (left first).

    Returns a list of arrays, one per row.
    """
    labels = DBSCAN(eps=25, min_samples=3).fit(points[:, 1].reshape(-1, 1)).labels_

    buckets = {}
    for pt, lbl in zip(points, labels):
        if lbl != -1:
            buckets.setdefault(lbl, []).append(pt)

    ordered_rows = []
    for lbl in sorted(buckets, key=lambda k: np.mean([p[1] for p in buckets[k]])):
        ordered_rows.append(np.array(sorted(buckets[lbl], key=lambda p: p[0])))
    return ordered_rows

# ==========================================
#  NEW: GLOBAL COLUMN ALIGNMENT
# ==========================================
def get_global_column_grid(grid_rows):
    """Estimate the x positions of the vertical grid lines for the whole image.

    Pools the x coordinate of every point in every row and clusters them in
    1-D with DBSCAN (eps=15 px, min_samples=1). Each cluster's mean x is one
    column line. Returns the centres sorted left to right, or an empty array
    when there are no points at all.
    """
    xs = [pt[0] for row in grid_rows for pt in row]
    if not xs:
        return np.array([])

    xs = np.array(xs).reshape(-1, 1)
    labels = DBSCAN(eps=15, min_samples=1).fit(xs).labels_

    # Noise label -1 is skipped (with min_samples=1 DBSCAN should not emit it).
    centres = [np.mean(xs[labels == lbl]) for lbl in set(labels) if lbl != -1]
    return np.array(sorted(centres))

# ==========================================
#  UPDATED EXTRACTOR
# ==========================================
def extract_grid_data(image_path):
    """Rectify each detected grid cell of an image into a fixed-size tensor.

    Detects grid intersections, groups them into rows, and estimates global
    vertical-line positions. A cell quad is formed from:
      - a point,
      - its right-hand neighbour in the same row (at most 100 px away),
      - the nearest matching points in the next row (within 70 px in x).
    Each quad is perspective-warped to a CELL_SIZE x CELL_SIZE tile.
    Degenerate or non-convex quads are skipped.

    Args:
        image_path: Path (or str) of the image, read with ``cv2.imread``.

    Returns:
        (grid_tensor, mask_tensor, status):
          grid_tensor: uint8 (MAX_ROWS, MAX_COLS, CELL_SIZE, CELL_SIZE, 3),
              zero where no cell was placed.
          mask_tensor: uint8 (MAX_ROWS, MAX_COLS), 1 where a cell was placed.
          status: "Active: Gridded", a "Fallback: ..." reason, or
              "Error: <message>" if an exception was raised part-way (the
              tensors are returned as filled so far).
        An unreadable file yields (None, None, "Load Error").

    The row index is the position of the cell's top edge among detected rows.
    The column index is the global vertical line nearest the cell's left edge.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        return None, None, "Load Error"

    grid_tensor = np.zeros((MAX_ROWS, MAX_COLS, CELL_SIZE, CELL_SIZE, 3), dtype=np.uint8)
    mask_tensor = np.zeros((MAX_ROWS, MAX_COLS), dtype=np.uint8)
    # Target corners (TL, TR, BR, BL) are identical for every cell.
    dst = np.array([[0, 0], [CELL_SIZE, 0], [CELL_SIZE, CELL_SIZE], [0, CELL_SIZE]], dtype="float32")

    try:
        points = get_intersections(img)
        if len(points) < 40:
            return grid_tensor, mask_tensor, "Fallback: Low Pts"

        grid_rows = sort_points_robust(points)
        if len(grid_rows) < 4:
            return grid_tensor, mask_tensor, "Fallback: Few Rows"

        # Canonical column positions shared by all rows, so a cell's column
        # index does not depend on which intersections its own row detected.
        col_centers = get_global_column_grid(grid_rows)
        if len(col_centers) < 2:
            return grid_tensor, mask_tensor, "Fallback: No Cols"

        for r in range(min(len(grid_rows) - 1, MAX_ROWS)):
            row_top = grid_rows[r]
            row_btm = grid_rows[r + 1]

            for pt_top in row_top:
                # Top-right corner: nearest point to the right in the same row.
                neighbors_tr = [p for p in row_top if p[0] > pt_top[0]]
                if not neighbors_tr:
                    continue
                pt_tr = min(neighbors_tr, key=lambda p: p[0])
                if abs(pt_tr[0] - pt_top[0]) > 100:
                    continue  # gap too big

                # Bottom corners: closest x match in the next row.
                candidates_bl = [p for p in row_btm if abs(p[0] - pt_top[0]) < 70]
                if not candidates_bl:
                    continue
                pt_bl = min(candidates_bl, key=lambda p: abs(p[0] - pt_top[0]))

                candidates_br = [p for p in row_btm if abs(p[0] - pt_tr[0]) < 70]
                if not candidates_br:
                    continue
                pt_br = min(candidates_br, key=lambda p: abs(p[0] - pt_tr[0]))

                src = np.array([pt_top, pt_tr, pt_br, pt_bl], dtype="float32")
                # Reject tiny or non-convex quads (mismatched corners). Warping
                # them would store a distorted tile and mark it valid.
                if cv2.contourArea(src) < 150 or not cv2.isContourConvex(src):
                    continue

                M = cv2.getPerspectiveTransform(src, dst)
                warped = cv2.warpPerspective(img, M, (CELL_SIZE, CELL_SIZE))

                # Column = canonical vertical line nearest the cell's left edge.
                c_idx = np.argmin(np.abs(col_centers - pt_top[0]))
                if c_idx < MAX_COLS:
                    grid_tensor[r, c_idx] = warped
                    mask_tensor[r, c_idx] = 1

        if np.sum(mask_tensor) < 20:
            return grid_tensor, mask_tensor, "Fallback: Warp Failed"
        return grid_tensor, mask_tensor, "Active: Gridded"

    except Exception as e:
        return grid_tensor, mask_tensor, f"Error: {str(e)}"

# ==========================================
#  VISUALIZATION
# ==========================================
# 1. Pick an image (sorted so the choice is stable across runs)
test_images = sorted(TEST_IMG_DIR.glob("*.png"))
if not test_images:
    raise FileNotFoundError(f"No PNG images found in {TEST_IMG_DIR}")
img_path = test_images[0]

# 2. Extract Data
grid_data, valid_mask, status = extract_grid_data(img_path)

print(f"Status: {status}")
if grid_data is None:
    # extract_grid_data returns (None, None, "Load Error") for unreadable files.
    raise RuntimeError(f"Could not read {img_path}")
print(f"Grid Shape: {grid_data.shape}")
print(f"Valid Cells Found: {np.sum(valid_mask)}")

# 3. Reconstruct: each valid cell goes to its (row, col) slot; empty slots
#    stay dark grey (30) so gaps are obvious.
reconstruction = np.full((MAX_ROWS * CELL_SIZE, MAX_COLS * CELL_SIZE, 3), 30, dtype=np.uint8)
for r, c in zip(*np.nonzero(valid_mask)):
    reconstruction[r*CELL_SIZE:(r+1)*CELL_SIZE, c*CELL_SIZE:(c+1)*CELL_SIZE] = grid_data[r, c]

plt.figure(figsize=(12,12))
plt.imshow(cv2.cvtColor(reconstruction, cv2.COLOR_BGR2RGB))
plt.title(f"Aligned Reconstruction: {img_path.name}")
plt.axis('off')
plt.show()

In [None]:
import cv2
import numpy as np
import os
from pathlib import Path
from tqdm import tqdm
import random

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Source tiles; generate_final_dataset() searches this folder recursively for *.png.
ASSETS_ROOT = Path("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/assets")
# Deleted and rebuilt on every generate_final_dataset() run.
OUTPUT_DIR = Path("./generated_dataset_final_v7")
# Every asset is resized to IMG_SIZE x IMG_SIZE before augmentation.
IMG_SIZE = 64

# Augmented copies written per asset file:
BASE_SAMPLES = 600
BOOST_SAMPLES = 1200  # Lab_Start and the *_Road classes
MEGA_BOOST_SAMPLES = 2000 # For the problematic Desert Start/End

# ==========================================
# 2. MAPPING
# ==========================================
# Keys are substrings matched against the lower-cased asset file name without
# ".png"; the first matching key (dict order) wins. Values are the class
# folder names used as ImageFolder labels.
ASSET_MAP = {
    # --- DESERT ---
    "t1_sand": "Desert_Road",
    "t1_cacti": "Desert_Cacti",
    "t1_rocks": "Desert_Rocks",
    "t1_quicksand": "Desert_Hazard",
    "t1_rover": "Desert_Start",    # <--- MEGA TARGET
    "t1_goal": "Desert_End",       # <--- MEGA TARGET

    # --- FOREST ---
    "t0_dirt": "Forest_Road",
    "t0_tree": "Forest_Tree",
    "t0_puddle": "Forest_Hazard",
    "t0_startship": "Forest_Start",
    "t0_goal": "Forest_End",

    # --- LAB ---
    "t2_floor": "Lab_Road",
    "t2_wall": "Lab_Wall",
    "t2_plasma": "Lab_Plasma",
    "t2_glue": "Lab_Hazard",
    "t2_drone": "Lab_Start",
    "t2_goal": "Lab_End",
}

# ==========================================
# 3. AUGMENTATION ENGINES
# ==========================================
def apply_realistic_tint(img):
    """Blend a faint blue, green or grey colour cast into a BGR image.

    With probability 0.4 the input object itself is returned unchanged.
    Otherwise a random mode and an intensity in [20, 50] choose a flat BGR
    colour, which is mixed in at 15% weight (85% original).
    """
    if random.random() > 0.6:
        return img

    mode = random.choice(["Blue", "Green", "Grey"])
    strength = random.randint(20, 50)
    palette = {
        "Blue": (strength + 20, strength // 2, 0),
        "Green": (strength // 2, strength + 20, 0),
        "Grey": (strength, strength, strength),
    }
    tint = np.zeros_like(img)
    tint[:] = palette[mode]
    return cv2.addWeighted(img, 0.85, tint, 0.15, 0)

def augment_standard(img):
    """Generic augmentation used for every class except Desert start/end.

    Works on a copy. With probability 0.6 it rotates by a random multiple of
    90 degrees, and with probability 0.5 it mirrors horizontally. The result
    is then passed through ``apply_realistic_tint``.
    """
    out = img.copy()

    do_rotate = random.random() > 0.40
    if do_rotate:
        out = np.rot90(out, random.choice([1, 2, 3]))

    do_flip = random.random() > 0.5
    if do_flip:
        out = cv2.flip(out, 1)

    return apply_realistic_tint(out)

def augment_gentle(img):
    """Lighter augmentation for the Desert start/end assets.

    Works on a copy. With probability 0.7 it rotates by a random multiple of
    90 degrees (so orientation is not preserved). It never mirrors, because
    the rover/flag shape may be asymmetric. With probability 0.2 it calls
    ``apply_realistic_tint``, which itself tints only part of the time, so
    most samples keep their true colours.
    """
    result = img.copy()

    should_rotate = random.random() > 0.30
    if should_rotate:
        result = np.rot90(result, random.choice([1, 2, 3]))

    # Deliberately no flipping here.
    should_tint = random.random() > 0.8
    return apply_realistic_tint(result) if should_tint else result

# ==========================================
# 4. GENERATOR LOOP
# ==========================================
def _resolve_class_label(filename):
    """Return the class folder for a lower-cased asset name (no .png), or None to skip it.

    The first ASSET_MAP key contained in ``filename`` wins (dict order). Names
    without a known key fall back to keyword matching.
    """
    for key, label in ASSET_MAP.items():
        if key in filename:
            return label

    if "rover" in filename:
        return "Desert_Start"
    if "startship" in filename:
        return "Forest_Start"
    if "drone" in filename:
        return "Lab_Start"
    if "goal" in filename:
        if "t0" in filename:
            return "Forest_End"
        if "t1" in filename:
            return "Desert_End"
        if "t2" in filename:
            return "Lab_End"
        return "Desert_End"
    return None


def _samples_for_class(class_label):
    """Number of augmented copies to write for each asset of ``class_label``."""
    # Mega boost for the problematic Desert Start/End
    if class_label in ("Desert_Start", "Desert_End"):
        return MEGA_BOOST_SAMPLES
    # Regular boost for other tricky classes
    if class_label in ("Lab_Start", "Desert_Road", "Forest_Road", "Lab_Road"):
        return BOOST_SAMPLES
    return BASE_SAMPLES


def generate_final_dataset():
    """Rebuild OUTPUT_DIR as an ImageFolder dataset of augmented asset tiles.

    Every PNG under ASSETS_ROOT that maps to a class is resized to
    IMG_SIZE x IMG_SIZE. It is then written ``_samples_for_class(label)``
    times to ``OUTPUT_DIR/<label>/``. Desert start/end use
    ``augment_gentle``; all other classes use ``augment_standard``.
    Any existing OUTPUT_DIR is deleted first.
    """
    if OUTPUT_DIR.exists():
        import shutil
        shutil.rmtree(OUTPUT_DIR)
    OUTPUT_DIR.mkdir(parents=True)

    # Sorted for a deterministic processing order.
    all_assets = sorted(ASSETS_ROOT.rglob("*.png"))
    total_images = 0

    print("Generating V7 Precision Dataset...")

    for asset_path in all_assets:
        filename = asset_path.name.lower().replace(".png", "")
        class_label = _resolve_class_label(filename)
        if class_label is None:
            continue

        original = cv2.imread(str(asset_path))
        if original is None:
            # Skip before creating the folder: an empty class folder can make
            # torchvision's ImageFolder raise when the dataset is loaded.
            continue
        original = cv2.resize(original, (IMG_SIZE, IMG_SIZE))

        class_dir = OUTPUT_DIR / class_label
        class_dir.mkdir(exist_ok=True)

        if class_label in ("Desert_Start", "Desert_End"):
            augment = augment_gentle
        else:
            augment = augment_standard

        for i in range(_samples_for_class(class_label)):
            img = augment(original)
            # Include the asset stem so several assets of one class do not
            # overwrite each other's samples.
            save_name = f"{class_label}_{asset_path.stem}_{i}.png"
            cv2.imwrite(str(class_dir / save_name), img)
            total_images += 1

    print(f"✅ Generated {total_images} V7 images.")
    print(f"Classes: {sorted([d.name for d in OUTPUT_DIR.iterdir()])}")

if __name__ == "__main__":
    generate_final_dataset()
    # NOTE(review): `visualize_clean_samples` is not defined in this cell; it exists
    # only if an earlier cell defined it. Guard the call so a fresh "Run All" does
    # not stop with a NameError right after the dataset has been generated.
    if "visualize_clean_samples" in globals():
        visualize_clean_samples()
    else:
        print("visualize_clean_samples() is not defined; skipping sample preview.")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import json
from pathlib import Path
from tqdm import tqdm

# Config
# ImageFolder root written by generate_final_dataset(); assumes the generator
# ran with /kaggle/working as the working directory — confirm if run elsewhere.
DATA_DIR = Path("/kaggle/working/generated_dataset_final_v7")
# Outputs: model weights (state_dict) and the index -> class-name JSON mapping.
MODEL_SAVE_PATH = "terrain_classifier_granular7.pth"
MAPPING_SAVE_PATH = "class_mapping.json"
BATCH_SIZE = 32
EPOCHS = 8
LEARNING_RATE = 0.001  # Adam learning rate
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load Data
# ToTensor: PIL image -> float CHW tensor scaled to [0, 1].
transform = transforms.ToTensor()
full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# Seeded generator so the same images land in train/val on every run.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# NOTE(review): every image is an augmented copy of one of the asset PNGs, so
# train and val share the same originals and validation accuracy is optimistic.

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Model
class SimpleTerrainCNN(nn.Module):
    """Small three-stage CNN classifier for terrain cells.

    Each stage is conv(3x3, padding 1) -> ReLU -> 2x2 max-pool, so the spatial
    size halves three times. ``fc1`` expects 64 channels at 8x8, which means
    inputs must be 3 x 64 x 64 images.

    The attribute names form the saved ``state_dict`` keys and must not change,
    or previously saved checkpoints will no longer load.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Construction order kept fixed so seeded weight init is reproducible.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map an (N, 3, 64, 64) batch to (N, num_classes) logits."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(self.relu(conv(x)))
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)

model = SimpleTerrainCNN(num_classes=len(full_dataset.classes)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Train on the V7 dataset in DATA_DIR; log mean train loss and val accuracy per epoch.
print(f"Starting Training V7 ({len(full_dataset.classes)} classes)...")
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Mean loss over the epoch's training batches (NaN if the loader is empty).
    avg_train_loss = running_loss / len(train_loader) if len(train_loader) else float("nan")

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # An empty validation split must not raise ZeroDivisionError after a full epoch.
    val_acc = 100 * correct / total if total else float("nan")
    print(f"Epoch {epoch+1}: Loss = {avg_train_loss:.4f}, Acc = {val_acc:.2f}%")

torch.save(model.state_dict(), MODEL_SAVE_PATH)
# Inverse of ImageFolder's class_to_idx; json stores the int keys as strings.
idx_to_class = {v: k for k, v in full_dataset.class_to_idx.items()}
with open(MAPPING_SAVE_PATH, 'w') as f:
    json.dump(idx_to_class, f)
print(f"✅ Saved Model to {MODEL_SAVE_PATH}")

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, models
import cv2
import numpy as np
import matplotlib.pyplot as plt
import heapq
import json
from pathlib import Path
from sklearn.cluster import DBSCAN
import random
from tqdm.auto import tqdm

# ==========================================
# 1. CONFIGURATION
# ==========================================
MODEL_PATH = "terrain_classifier_resnet.pth"
MAPPING_PATH = "class_mapping.json"
TEST_IMG_DIR = Path("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/test/images")
CELL_SIZE = 64      # side (px) of each warped grid-cell patch
MAX_ROWS = 20       # grid capacity; extra detected rows/columns are dropped
MAX_COLS = 20
NUM_SAMPLES = 50    # how many random test images to process and plot

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Per-biome A* step costs, keyed by the terrain suffix after "_" in a class
# label. 999.0 marks impassable terrain (the solver only enters cells costing
# < 100); "Unknown" is used for grid cells that received no prediction.
COST_TABLES = {
    "Desert": dict(Road=1.2, Start=1.2, End=2.2,
                   Hazard=3.7,
                   Cacti=999.0, Rocks=999.0, Obstacle=999.0,
                   Unknown=8.0),
    "Forest": dict(Road=1.5, Start=1.5, End=2.5,
                   Hazard=2.8,
                   Tree=999.0, Obstacle=999.0,
                   Unknown=8.0),
    "Lab": dict(Road=1.0, Start=1.0, End=2.0,
                Hazard=5.0,  # Glue
                Wall=999.0, Plasma=999.0, Obstacle=999.0,
                Unknown=8.0),
}

# ==========================================
# 2. VISION HELPERS (Global Alignment)
# ==========================================
def preprocess_for_grid(img):
    """Binarise a BGR image for grid-line detection.

    A 7x7 median blur suppresses fine texture, then inverted adaptive Gaussian
    thresholding (block size 19, C=5) sets locally dark pixels (presumably the
    grid lines) to 255 and everything else to 0.
    """
    smoothed = cv2.medianBlur(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 7)
    return cv2.adaptiveThreshold(
        smoothed,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        19,
        5,
    )

def get_intersections(img):
    """Locate grid-line crossings in a BGR image.

    Horizontal and vertical line masks are isolated with 25-px morphological
    openings. Their overlap, dilated 5x5, marks candidate crossings. Connected
    components with area > 10 px are kept, and nearby centroids are merged
    with DBSCAN (eps=20) into a single point.

    Returns:
        np.ndarray of (x, y) crossing points, or an empty array if none survive.
    """
    binary = preprocess_for_grid(img)
    line_len = 25
    horiz = cv2.morphologyEx(binary, cv2.MORPH_OPEN,
                             cv2.getStructuringElement(cv2.MORPH_RECT, (line_len, 1)))
    vert = cv2.morphologyEx(binary, cv2.MORPH_OPEN,
                            cv2.getStructuringElement(cv2.MORPH_RECT, (1, line_len)))
    crossings = cv2.dilate(cv2.bitwise_and(horiz, vert), np.ones((5, 5)))

    n_components, _, stats, centroids = cv2.connectedComponentsWithStats(crossings)
    # Component 0 is the background, so scanning starts at 1.
    candidates = []
    for idx in range(1, n_components):
        if stats[idx, cv2.CC_STAT_AREA] > 10:
            candidates.append(centroids[idx])
    if len(candidates) == 0:
        return np.array([])

    labels = DBSCAN(eps=20, min_samples=1).fit(candidates).labels_
    candidates = np.array(candidates)
    merged = [np.mean(candidates[labels == lab], axis=0) for lab in set(labels)]
    return np.array(merged)

def sort_points_robust(points):
    """Group crossing points into rows, top-to-bottom, each row left-to-right.

    Rows are found by clustering the y coordinates with DBSCAN (eps=25,
    min_samples=3). Points DBSCAN labels as noise (-1) are dropped.

    Args:
        points: array of (x, y) points, indexed as ``points[:, 1]`` for y.

    Returns:
        list of arrays, one per row, ordered by the row's mean y; the points in
        each row are ordered by x.
    """
    row_labels = DBSCAN(eps=25, min_samples=3).fit(points[:, 1].reshape(-1, 1)).labels_
    buckets = {}
    for label, pt in zip(row_labels, points):
        if label != -1:
            buckets.setdefault(label, []).append(pt)

    def mean_y(label):
        return np.mean([p[1] for p in buckets[label]])

    ordered_labels = sorted(buckets, key=mean_y)
    return [np.array(sorted(buckets[label], key=lambda p: p[0])) for label in ordered_labels]

def get_global_column_grid(grid_rows):
    """Estimate shared column x-positions from every row's points.

    All x coordinates are clustered with DBSCAN (eps=15, min_samples=1), and
    each cluster is reduced to its mean.

    Returns:
        sorted 1-D array of column-centre x values; empty if there are no points.
    """
    xs = [point[0] for row in grid_rows for point in row]
    if len(xs) == 0:
        return np.array([])
    xs = np.array(xs).reshape(-1, 1)
    labels = DBSCAN(eps=15, min_samples=1).fit(xs).labels_
    centres = sorted(np.mean(xs[labels == lab]) for lab in set(labels) if lab != -1)
    return np.array(centres)

def extract_grid_data(image_path):
    """Slice a maze image into per-cell patches using detected grid crossings.

    Detects line crossings, groups them into rows, and derives global column
    centres. For each crossing it builds the quad formed with its right
    neighbour and the matching points in the next row, then perspective-warps
    that quad into a CELL_SIZE x CELL_SIZE patch.

    Args:
        image_path: path to the image (anything ``str()`` accepts).

    Returns:
        (grid_tensor, mask_tensor, status):
          * grid_tensor: uint8 (MAX_ROWS, MAX_COLS, CELL_SIZE, CELL_SIZE, 3) BGR
            patches, zero where no cell was found;
          * mask_tensor: uint8 (MAX_ROWS, MAX_COLS), 1 where a patch was written;
          * status: "Active: Gridded" on success, "Fallback: ..." when too few
            points/rows/columns were detected, or "Error" if slicing raised.
        Returns (None, None, "Load Error") if the image cannot be read.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        return None, None, "Load Error"

    grid_tensor = np.zeros((MAX_ROWS, MAX_COLS, CELL_SIZE, CELL_SIZE, 3), dtype=np.uint8)
    mask_tensor = np.zeros((MAX_ROWS, MAX_COLS), dtype=np.uint8)
    # The warp target square is the same for every cell, so build it once.
    dst = np.array([[0, 0], [CELL_SIZE, 0], [CELL_SIZE, CELL_SIZE], [0, CELL_SIZE]], dtype="float32")

    try:
        points = get_intersections(img)
        if len(points) < 40:
            return grid_tensor, mask_tensor, "Fallback: Low Pts"
        grid_rows = sort_points_robust(points)
        if len(grid_rows) < 4:
            return grid_tensor, mask_tensor, "Fallback: Few Rows"
        col_centers = get_global_column_grid(grid_rows)
        if len(col_centers) < 2:
            return grid_tensor, mask_tensor, "Fallback: No Cols"

        # Each cell lies between row r (top edge) and row r + 1 (bottom edge).
        for r in range(min(len(grid_rows) - 1, MAX_ROWS)):
            row_top = grid_rows[r]
            row_btm = grid_rows[r + 1]
            for pt_top in row_top:
                # Top-right corner: the nearest crossing to the right on the same
                # row. It is rejected when implausibly far away (> 100 px).
                neighbors_tr = [p for p in row_top if p[0] > pt_top[0]]
                if not neighbors_tr:
                    continue
                pt_tr = min(neighbors_tr, key=lambda p: p[0])
                if abs(pt_tr[0] - pt_top[0]) > 100:
                    continue

                # Bottom corners: the closest x-match in the next row, within 70 px.
                candidates_bl = [p for p in row_btm if abs(p[0] - pt_top[0]) < 70]
                if not candidates_bl:
                    continue
                pt_bl = min(candidates_bl, key=lambda p: abs(p[0] - pt_top[0]))
                candidates_br = [p for p in row_btm if abs(p[0] - pt_tr[0]) < 70]
                if not candidates_br:
                    continue
                pt_br = min(candidates_br, key=lambda p: abs(p[0] - pt_tr[0]))

                src = np.array([pt_top, pt_tr, pt_br, pt_bl], dtype="float32")
                M = cv2.getPerspectiveTransform(src, dst)
                warped = cv2.warpPerspective(img, M, (CELL_SIZE, CELL_SIZE))

                # Column index = nearest global column centre to the top-left x.
                c_idx = np.argmin(np.abs(col_centers - pt_top[0]))
                if c_idx < MAX_COLS:
                    grid_tensor[r, c_idx] = warped
                    mask_tensor[r, c_idx] = 1
        return grid_tensor, mask_tensor, "Active: Gridded"
    except Exception:
        # Best-effort slicing: any failure on an odd image is reported as
        # "Error" so batch processing continues. Catching Exception (not a bare
        # except) lets KeyboardInterrupt/SystemExit through, so the notebook
        # can still be interrupted.
        return grid_tensor, mask_tensor, "Error"

# ==========================================
# 3. RESNET MODEL LOADING
# ==========================================
# 1. Load the index -> class-name mapping first; its size determines the head width.
with open(MAPPING_PATH, 'r') as f:
    idx_to_class = json.load(f)
    # JSON object keys are strings; restore the integer class indices.
    idx_to_class = {int(k): v for k, v in idx_to_class.items()}
# NOTE(review): "class_mapping.json" is also written by the SimpleTerrainCNN
# training cell (v7 dataset), while this ResNet is trained on v5. Confirm that
# the file on disk came from the ResNet run: a class-count mismatch makes
# load_state_dict raise, and a same-count/different-order mismatch would
# silently mislabel cells.

# 2. Define & load ResNet-34 with a replaced classification head.
print("🚀 Loading ResNet-34...")
model = models.resnet34(pretrained=False)  # architecture only; weights come from MODEL_PATH
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(idx_to_class))
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model = model.to(device)
model.eval()

# 3. Normalization (CRITICAL for ResNet): it must match the training
# transform, which normalizes with these same ImageNet mean/std values.
transform_pipe = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("✅ Model Loaded Successfully")

# ==========================================
# 4. BATCH PROCESSOR
# ==========================================
def process_single_image(image_path):
    # 1. Slice
    grid_tensor, mask_tensor, status = extract_grid_data(image_path)
    if "Active" not in status: return None, "Slicer Fail"

    rows, cols = mask_tensor.shape
    batch_tensors = []
    coords = []

    # 2. Prepare Batch (BGR -> RGB -> Normalize)
    for r in range(rows):
        for c in range(cols):
            if mask_tensor[r, c] == 1:
                rgb_cell = cv2.cvtColor(grid_tensor[r, c], cv2.COLOR_BGR2RGB)
                batch_tensors.append(transform_pipe(rgb_cell))
                coords.append((r, c))

    if not batch_tensors: return None, "Empty Grid"

    # 3. Predict
    batch_stack = torch.stack(batch_tensors).to(device)
    with torch.no_grad():
        outputs = model(batch_stack)
        _, preds = torch.max(outputs, 1)

    # 4. Build Map & Biome Voting
    terrain_map = np.full((rows, cols), "Unknown", dtype=object)
    counts = {"Desert": 0, "Forest": 0, "Lab": 0}

    for i, (r, c) in enumerate(coords):
        class_name = idx_to_class[preds[i].item()]
        terrain_map[r, c] = class_name

        biome = class_name.split("_")[0] if "_" in class_name else "Lab"

        # Weight unique items higher to prevent ambiguity
        weight = 1
        if any(x in class_name for x in ["Cacti", "Tree", "Plasma", "Wall", "Rocks"]): weight = 3
        if "Start" in class_name or "End" in class_name: weight = 2

        counts[biome] = counts.get(biome, 0) + weight

    dominant_biome = max(counts, key=counts.get) if counts else "Lab"
    costs = COST_TABLES.get(dominant_biome, COST_TABLES["Lab"])

    # 5. Solver (A*)
    start_pos, end_pos = None, None
    for r in range(rows):
        for c in range(cols):
            if "Start" in terrain_map[r, c]: start_pos = (r, c)
            if "End" in terrain_map[r, c]: end_pos = (r, c)

    path = []
    status_msg = "No Path"

    if start_pos and end_pos:
        pq = [(0, start_pos)]
        cost_so_far = {start_pos: 0}
        came_from = {}

        while pq:
            curr_cost, curr = heapq.heappop(pq)
            if curr == end_pos:
                status_msg = "Solved"
                break

            r, c = curr
            for nr, nc in [(r-1,c), (r+1,c), (r,c-1), (r,c+1)]:
                if 0 <= nr < rows and 0 <= nc < cols:
                    label = terrain_map[nr, nc]

                    if label == "Unknown": cell_type = "Unknown"
                    else: cell_type = label.split("_")[1] if "_" in label else label

                    step_cost = costs.get(cell_type, 999.0)
                    new_cost = cost_so_far[curr] + step_cost

                    if step_cost < 100:
                        if (nr, nc) not in cost_so_far or new_cost < cost_so_far[(nr, nc)]:
                            cost_so_far[(nr, nc)] = new_cost
                            priority = new_cost + abs(nr-end_pos[0]) + abs(nc-end_pos[1])
                            heapq.heappush(pq, (priority, (nr, nc)))
                            came_from[(nr, nc)] = curr

        if status_msg == "Solved":
            curr = end_pos
            while curr != start_pos:
                path.append(curr)
                curr = came_from[curr]
            path.append(start_pos)

    # 6. Visualization
    img_vis = cv2.imread(str(image_path))
    img_vis = cv2.resize(img_vis, (400, 400))
    cell_h, cell_w = 400 // rows, 400 // cols

    for r in range(rows):
        for c in range(cols):
            if mask_tensor[r, c] == 1:
                lbl = terrain_map[r, c]
                cx, cy = int((c + 0.5) * cell_w), int((r + 0.5) * cell_h)
                color = (100, 100, 100)
                if "Start" in lbl: color = (0, 255, 255) # Yellow
                if "End" in lbl: color = (0, 0, 255)     # Red
                if any(x in lbl for x in ["Obstacle", "Wall", "Tree", "Cacti", "Rocks", "Plasma"]): color = (0, 0, 0)
                if "Hazard" in lbl: color = (0, 165, 255) # Orange
                cv2.circle(img_vis, (cx, cy), 3, color, -1)
            else:
                cx, cy = int((c + 0.5) * cell_w), int((r + 0.5) * cell_h)
                cv2.circle(img_vis, (cx, cy), 1, (50, 50, 50), -1)

    if path:
        for i in range(len(path) - 1):
            p1 = (int((path[i][1]+0.5)*cell_w), int((path[i][0]+0.5)*cell_h))
            p2 = (int((path[i+1][1]+0.5)*cell_w), int((path[i+1][0]+0.5)*cell_h))
            cv2.line(img_vis, p1, p2, (0, 255, 0), 2)

    return img_vis, f"{dominant_biome}: {status_msg}"

# ==========================================
# 5. EXECUTION
# ==========================================
test_files = list(TEST_IMG_DIR.glob("*.png"))
samples = random.sample(test_files, min(len(test_files), NUM_SAMPLES))

print(f"Processing {len(samples)} images...")
results = []
for img_path in tqdm(samples):
    res_img, status = process_single_image(img_path)
    if res_img is not None: results.append((res_img, status))

if not results:
    # plt.subplots(0, 5) would raise, so report instead of crashing the cell.
    print("No images could be processed (all slicer failures or empty grids).")
else:
    rows = (len(results) + 4) // 5
    # squeeze=False always yields a 2-D array of axes, so flatten() is uniform.
    fig, axes = plt.subplots(rows, 5, figsize=(20, 4 * rows), squeeze=False)
    axes = axes.flatten()
    for ax, (img, status) in zip(axes, results):
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        title_color = 'green' if "Solved" in status else 'red'
        ax.set_title(status, color=title_color, fontsize=10, fontweight='bold')
        ax.axis('off')
    # Hide the unused panels in the last row.
    for ax in axes[len(results):]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import json
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

# ==========================================
# 1. CONFIGURATION
# ==========================================
DATA_DIR = Path("/kaggle/working/generated_dataset_final_v5")
MODEL_SAVE_PATH = "terrain_classifier_resnet.pth"
# NOTE(review): the SimpleTerrainCNN training cell writes this same mapping file
# from the v7 dataset; whichever training cell runs last decides which class
# list is on disk.
MAPPING_SAVE_PATH = "class_mapping.json"
BATCH_SIZE = 32
EPOCHS = 10
LEARNING_RATE = 0.0001
SEED = 42  # seeds the global torch RNG (head init, shuffling) and the split
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)

# ==========================================
# 2. DATA LOADING & RESNET SETUP
# ==========================================
# ImageNet mean/std normalization for the pretrained backbone; the ResNet
# inference cell applies the same values.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
# A seeded generator keeps the train/val split identical across re-runs.
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("🚀 Loading ResNet-34...")
model = models.resnet34(pretrained=True)  # ImageNet-pretrained backbone
num_ftrs = model.fc.in_features
# Replace the ImageNet classification head with one sized to our classes.
model.fc = nn.Linear(num_ftrs, len(full_dataset.classes))
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# ==========================================
# 3. TRAINING LOOP WITH LOGGING
# ==========================================
# Per-epoch metrics, plotted after training.
history = {
    'train_loss': [], 'val_loss': [],
    'train_acc': [], 'val_acc': []
}

print(f"Starting Training on {len(full_dataset.classes)} classes...")

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    # --- TRAINING PHASE ---
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()
        loop.set_postfix(loss=loss.item())

    # Conditional averages: an empty loader/split yields NaN instead of
    # raising ZeroDivisionError at the end of an epoch.
    avg_train_loss = running_loss / len(train_loader) if len(train_loader) else float("nan")
    train_acc = 100 * correct_train / total_train if total_train else float("nan")

    # --- VALIDATION PHASE ---
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    avg_val_loss = val_loss / len(val_loader) if len(val_loader) else float("nan")
    val_acc = 100 * correct_val / total_val if total_val else float("nan")

    # Store Stats
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f} | Train Acc={train_acc:.2f}%, Val Acc={val_acc:.2f}%")

# ==========================================
# 4. SAVE & VISUALIZE
# ==========================================
torch.save(model.state_dict(), MODEL_SAVE_PATH)
# Inverse of ImageFolder's class_to_idx; json stores the int keys as strings.
idx_to_class = {v: k for k, v in full_dataset.class_to_idx.items()}
with open(MAPPING_SAVE_PATH, 'w') as f:
    json.dump(idx_to_class, f)
print(f"✅ Saved Model to {MODEL_SAVE_PATH}")

# PLOTTING: loss and accuracy curves side by side. The x axis uses 1-based
# epochs to match the per-epoch log lines printed during training.
epoch_axis = range(1, len(history['train_loss']) + 1)
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(epoch_axis, history['train_loss'], label='Train Loss', color='blue', linestyle='--')
ax_loss.plot(epoch_axis, history['val_loss'], label='Val Loss', color='red')
ax_loss.set_title('Loss: Training vs Validation')
ax_loss.set_xlabel('Epochs')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.grid(True)

ax_acc.plot(epoch_axis, history['train_acc'], label='Train Acc', color='blue', linestyle='--')
ax_acc.plot(epoch_axis, history['val_acc'], label='Val Acc', color='green')
ax_acc.set_title('Accuracy: Training vs Validation')
ax_acc.set_xlabel('Epochs')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.legend()
ax_acc.grid(True)

fig.tight_layout()
plt.show()

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
import cv2
import numpy as np
import matplotlib.pyplot as plt
import heapq
import json
from pathlib import Path
from sklearn.cluster import DBSCAN
import random
from tqdm.auto import tqdm

# ==========================================
# 1. CONFIGURATION
# ==========================================
MODEL_PATH = "/kaggle/working/terrain_classifier_granular7.pth"
MAPPING_PATH = "class_mapping.json"
TEST_IMG_DIR = Path("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset/test/images")
CELL_SIZE = 64      # side (px) of each warped grid-cell patch
MAX_ROWS = 20       # grid capacity; extra detected rows/columns are dropped
MAX_COLS = 20
NUM_SAMPLES = 50    # how many random test images to process and plot

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Cost Tables (V3 - Split Classes): A* step costs keyed by the terrain suffix
# of a class label; 999.0 = impassable, "Unknown" = cell without a prediction.
# NOTE(review): Lab "Hazard" is 3.0 here but 5.0 in the other cost tables in
# this notebook — confirm which value is intended.
COST_TABLES = {
    "Desert": dict(Road=1.2, Start=1.2, End=2.2, Hazard=3.7,
                   Cacti=999.0, Rocks=999.0, Obstacle=999.0, Unknown=8.0),
    "Forest": dict(Road=1.5, Start=1.5, End=2.5, Hazard=2.8,
                   Tree=999.0, Obstacle=999.0, Unknown=8.0),
    "Lab": dict(Road=1.0, Start=1.0, End=2.0, Hazard=3.0,
                Wall=999.0, Plasma=999.0, Obstacle=999.0, Unknown=8.0),
}

# ==========================================
# 2. CORE HELPERS (Raw Vision - No Tint Fix)
# ==========================================
def preprocess_for_grid(img):
    """Binarise a raw BGR image for grid-line detection (no white-balance step).

    A 7x7 median blur suppresses fine texture, then inverted adaptive Gaussian
    thresholding (block size 19, C=5) sets locally dark pixels (presumably the
    grid lines) to 255 and everything else to 0.
    """
    smoothed = cv2.medianBlur(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 7)
    return cv2.adaptiveThreshold(
        smoothed,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        19,
        5,
    )

def get_intersections(img):
    """Locate grid-line crossings in a BGR image.

    Horizontal and vertical line masks are isolated with 25-px morphological
    openings. Their overlap, dilated 5x5, marks candidate crossings. Connected
    components with area > 10 px are kept, and nearby centroids are merged
    with DBSCAN (eps=20) into a single point.

    Returns:
        np.ndarray of (x, y) crossing points, or an empty array if none survive.
    """
    binary = preprocess_for_grid(img)
    line_len = 25
    horiz = cv2.morphologyEx(binary, cv2.MORPH_OPEN,
                             cv2.getStructuringElement(cv2.MORPH_RECT, (line_len, 1)))
    vert = cv2.morphologyEx(binary, cv2.MORPH_OPEN,
                            cv2.getStructuringElement(cv2.MORPH_RECT, (1, line_len)))
    crossings = cv2.dilate(cv2.bitwise_and(horiz, vert), np.ones((5, 5)))

    n_components, _, stats, centroids = cv2.connectedComponentsWithStats(crossings)
    # Component 0 is the background, so scanning starts at 1.
    candidates = []
    for idx in range(1, n_components):
        if stats[idx, cv2.CC_STAT_AREA] > 10:
            candidates.append(centroids[idx])
    if len(candidates) == 0:
        return np.array([])

    labels = DBSCAN(eps=20, min_samples=1).fit(candidates).labels_
    candidates = np.array(candidates)
    merged = [np.mean(candidates[labels == lab], axis=0) for lab in set(labels)]
    return np.array(merged)

def sort_points_robust(points):
    """Group crossing points into rows, top-to-bottom, each row left-to-right.

    Rows are found by clustering the y coordinates with DBSCAN (eps=25,
    min_samples=3). Points DBSCAN labels as noise (-1) are dropped.

    Args:
        points: array of (x, y) points, indexed as ``points[:, 1]`` for y.

    Returns:
        list of arrays, one per row, ordered by the row's mean y; the points in
        each row are ordered by x.
    """
    row_labels = DBSCAN(eps=25, min_samples=3).fit(points[:, 1].reshape(-1, 1)).labels_
    buckets = {}
    for label, pt in zip(row_labels, points):
        if label != -1:
            buckets.setdefault(label, []).append(pt)

    def mean_y(label):
        return np.mean([p[1] for p in buckets[label]])

    ordered_labels = sorted(buckets, key=mean_y)
    return [np.array(sorted(buckets[label], key=lambda p: p[0])) for label in ordered_labels]

def get_global_column_grid(grid_rows):
    """Estimate shared column x-positions from every row's points.

    All x coordinates are clustered with DBSCAN (eps=15, min_samples=1), and
    each cluster is reduced to its mean.

    Returns:
        sorted 1-D array of column-centre x values; empty if there are no points.
    """
    xs = [point[0] for row in grid_rows for point in row]
    if len(xs) == 0:
        return np.array([])
    xs = np.array(xs).reshape(-1, 1)
    labels = DBSCAN(eps=15, min_samples=1).fit(xs).labels_
    centres = sorted(np.mean(xs[labels == lab]) for lab in set(labels) if lab != -1)
    return np.array(centres)

# ==========================================
# 3. ROBUST EXTRACTOR
# ==========================================
def extract_grid_data(image_path):
    """Slice a maze image into per-cell patches using detected grid crossings.

    Detects line crossings, groups them into rows, and derives global column
    centres. For each crossing it builds the quad formed with its right
    neighbour and the matching points in the next row, then perspective-warps
    that quad of the raw image (no tint/white-balance correction) into a
    CELL_SIZE x CELL_SIZE patch.

    Args:
        image_path: path to the image (anything ``str()`` accepts).

    Returns:
        (grid_tensor, mask_tensor, status):
          * grid_tensor: uint8 (MAX_ROWS, MAX_COLS, CELL_SIZE, CELL_SIZE, 3) BGR
            patches, zero where no cell was found;
          * mask_tensor: uint8 (MAX_ROWS, MAX_COLS), 1 where a patch was written;
          * status: "Active: Gridded" on success, "Fallback: ..." when too few
            points/rows/columns were detected, or "Error" if slicing raised.
        Returns (None, None, "Load Error") if the image cannot be read.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        return None, None, "Load Error"

    grid_tensor = np.zeros((MAX_ROWS, MAX_COLS, CELL_SIZE, CELL_SIZE, 3), dtype=np.uint8)
    mask_tensor = np.zeros((MAX_ROWS, MAX_COLS), dtype=np.uint8)
    # The warp target square is the same for every cell, so build it once.
    dst = np.array([[0, 0], [CELL_SIZE, 0], [CELL_SIZE, CELL_SIZE], [0, CELL_SIZE]], dtype="float32")

    try:
        points = get_intersections(img)
        if len(points) < 40:
            return grid_tensor, mask_tensor, "Fallback: Low Pts"
        grid_rows = sort_points_robust(points)
        if len(grid_rows) < 4:
            return grid_tensor, mask_tensor, "Fallback: Few Rows"
        col_centers = get_global_column_grid(grid_rows)
        if len(col_centers) < 2:
            return grid_tensor, mask_tensor, "Fallback: No Cols"

        # Each cell lies between row r (top edge) and row r + 1 (bottom edge).
        for r in range(min(len(grid_rows) - 1, MAX_ROWS)):
            row_top = grid_rows[r]
            row_btm = grid_rows[r + 1]
            for pt_top in row_top:
                # Top-right corner: the nearest crossing to the right on the same
                # row. It is rejected when implausibly far away (> 100 px).
                neighbors_tr = [p for p in row_top if p[0] > pt_top[0]]
                if not neighbors_tr:
                    continue
                pt_tr = min(neighbors_tr, key=lambda p: p[0])
                if abs(pt_tr[0] - pt_top[0]) > 100:
                    continue

                # Bottom corners: the closest x-match in the next row, within 70 px.
                candidates_bl = [p for p in row_btm if abs(p[0] - pt_top[0]) < 70]
                if not candidates_bl:
                    continue
                pt_bl = min(candidates_bl, key=lambda p: abs(p[0] - pt_top[0]))
                candidates_br = [p for p in row_btm if abs(p[0] - pt_tr[0]) < 70]
                if not candidates_br:
                    continue
                pt_br = min(candidates_br, key=lambda p: abs(p[0] - pt_tr[0]))

                src = np.array([pt_top, pt_tr, pt_br, pt_bl], dtype="float32")
                M = cv2.getPerspectiveTransform(src, dst)
                # Warping the pure image (no tint fix applied)
                warped = cv2.warpPerspective(img, M, (CELL_SIZE, CELL_SIZE))

                # Column index = nearest global column centre to the top-left x.
                c_idx = np.argmin(np.abs(col_centers - pt_top[0]))
                if c_idx < MAX_COLS:
                    grid_tensor[r, c_idx] = warped
                    mask_tensor[r, c_idx] = 1
        return grid_tensor, mask_tensor, "Active: Gridded"
    except Exception:
        # Best-effort slicing: any failure on an odd image is reported as
        # "Error" so batch processing continues. Catching Exception (not a bare
        # except) lets KeyboardInterrupt/SystemExit through, so the notebook
        # can still be interrupted.
        return grid_tensor, mask_tensor, "Error"

# ==========================================
# 4. MODEL LOADING
# ==========================================
class SimpleTerrainCNN(nn.Module):
    """Small three-stage CNN classifier for terrain cells.

    Each stage is conv(3x3, padding 1) -> ReLU -> 2x2 max-pool, so the spatial
    size halves three times. ``fc1`` expects 64 channels at 8x8, which means
    inputs must be 3 x 64 x 64 images.

    The attribute names form the saved ``state_dict`` keys and must not change,
    or previously saved checkpoints will no longer load.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Construction order kept fixed so seeded weight init is reproducible.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map an (N, 3, 64, 64) batch to (N, num_classes) logits."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(self.relu(conv(x)))
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)

# JSON object keys are always strings; convert them back to integer class indices.
# NOTE(review): the ResNet training cell that runs just before this one rewrites
# "class_mapping.json" from the v5 dataset, while this model was trained on v7.
# Confirm that the mapping on disk matches MODEL_PATH before trusting the labels.
with open(MAPPING_PATH, 'r') as f:
    idx_to_class = {int(key): name for key, name in json.load(f).items()}

model = SimpleTerrainCNN(num_classes=len(idx_to_class))
model = model.to(device)
state = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state)
model.eval()  # inference mode; the model is only used for prediction below

# ==========================================
# 5. BATCH PROCESSOR (Weighted Biome Voting)
# ==========================================
def process_single_image(image_path):
    grid_tensor, mask_tensor, status = extract_grid_data(image_path)
    if "Active" not in status: return None, "Slicer Fail"

    rows, cols = mask_tensor.shape
    batch_tensors = []
    coords = []

    transform_pipe = transforms.Compose([
        transforms.ToPILImage(), transforms.Resize((64, 64)), transforms.ToTensor()
    ])

    for r in range(rows):
        for c in range(cols):
            if mask_tensor[r, c] == 1:
                rgb_cell = cv2.cvtColor(grid_tensor[r, c], cv2.COLOR_BGR2RGB)
                batch_tensors.append(transform_pipe(rgb_cell))
                coords.append((r, c))

    if not batch_tensors: return None, "Empty Grid"

    batch_stack = torch.stack(batch_tensors).to(device)
    with torch.no_grad():
        outputs = model(batch_stack)
        _, preds = torch.max(outputs, 1)

    terrain_map = np.full((rows, cols), "Unknown", dtype=object)

    # --- WEIGHTED BIOME VOTING ---
    # Give higher weight to unique items (Cacti/Plasma) to override generic Roads
    counts = {"Desert": 0, "Forest": 0, "Lab": 0}

    for i, (r, c) in enumerate(coords):
        class_name = idx_to_class[preds[i].item()]
        terrain_map[r, c] = class_name

        biome = class_name.split("_")[0] if "_" in class_name else "Lab"

        weight = 1
        # If it's a unique object, it counts for 3 votes
        if any(x in class_name for x in ["Cacti", "Tree", "Plasma", "Wall", "Rocks"]):
            weight = 3
        if "Start" in class_name or "End" in class_name:
            weight = 2

        counts[biome] = counts.get(biome, 0) + weight

    dominant_biome = max(counts, key=counts.get) if counts else "Lab"
    costs = COST_TABLES.get(dominant_biome, COST_TABLES["Lab"])

    # 6. Solver
    start_pos, end_pos = None, None
    for r in range(rows):
        for c in range(cols):
            if "Start" in terrain_map[r, c]: start_pos = (r, c)
            if "End" in terrain_map[r, c]: end_pos = (r, c)

    path = []
    status_msg = "No Path"

    if start_pos and end_pos:
        pq = [(0, start_pos)]
        cost_so_far = {start_pos: 0}
        came_from = {}

        while pq:
            curr_cost, curr = heapq.heappop(pq)
            if curr == end_pos:
                status_msg = "Solved"
                break

            r, c = curr
            for nr, nc in [(r-1,c), (r+1,c), (r,c-1), (r,c+1)]:
                if 0 <= nr < rows and 0 <= nc < cols:
                    label = terrain_map[nr, nc]

                    if label == "Unknown": cell_type = "Unknown"
                    else: cell_type = label.split("_")[1] if "_" in label else label

                    step_cost = costs.get(cell_type, 999.0)
                    new_cost = cost_so_far[curr] + step_cost

                    if step_cost < 100:
                        if (nr, nc) not in cost_so_far or new_cost < cost_so_far[(nr, nc)]:
                            cost_so_far[(nr, nc)] = new_cost
                            # Heuristic: simple Manhattan distance
                            priority = new_cost + abs(nr-end_pos[0]) + abs(nc-end_pos[1])
                            heapq.heappush(pq, (priority, (nr, nc)))
                            came_from[(nr, nc)] = curr

        if status_msg == "Solved":
            curr = end_pos
            while curr != start_pos:
                path.append(curr)
                curr = came_from[curr]
            path.append(start_pos)

    # 7. Visualization
    img_vis = cv2.imread(str(image_path))
    # No white balance correction here either!
    img_vis = cv2.resize(img_vis, (400, 400))
    cell_h, cell_w = 400 // rows, 400 // cols

    for r in range(rows):
        for c in range(cols):
            if mask_tensor[r, c] == 1:
                lbl = terrain_map[r, c]
                cx, cy = int((c + 0.5) * cell_w), int((r + 0.5) * cell_h)
                color = (100, 100, 100)
                if "Start" in lbl: color = (0, 255, 255)
                if "End" in lbl: color = (0, 0, 255)
                # Black for obstacles
                if any(x in lbl for x in ["Obstacle", "Wall", "Tree", "Cacti", "Rocks", "Plasma"]):
                    color = (0, 0, 0)
                # Orange for passable hazards
                if "Hazard" in lbl: color = (0, 165, 255)

                cv2.circle(img_vis, (cx, cy), 3, color, -1)
            else:
                cx, cy = int((c + 0.5) * cell_w), int((r + 0.5) * cell_h)
                cv2.circle(img_vis, (cx, cy), 1, (50, 50, 50), -1)

    if path:
        for i in range(len(path) - 1):
            p1 = (int((path[i][1]+0.5)*cell_w), int((path[i][0]+0.5)*cell_h))
            p2 = (int((path[i+1][1]+0.5)*cell_w), int((path[i+1][0]+0.5)*cell_h))
            cv2.line(img_vis, p1, p2, (0, 255, 0), 2)

    return img_vis, f"{dominant_biome}: {status_msg}"

# ==========================================
# 6. RUN
# ==========================================
test_files = list(TEST_IMG_DIR.glob("*.png"))
samples = random.sample(test_files, min(len(test_files), NUM_SAMPLES))

print(f"Processing {len(samples)} images...")
results = []
for img_path in tqdm(samples):
    res_img, status = process_single_image(img_path)
    if res_img is not None: results.append((res_img, status))

if not results:
    # plt.subplots(0, 5) would raise, so report instead of crashing the cell.
    print("No images could be processed (all slicer failures or empty grids).")
else:
    rows = (len(results) + 4) // 5
    # squeeze=False always yields a 2-D array of axes, so flatten() is uniform.
    fig, axes = plt.subplots(rows, 5, figsize=(20, 4 * rows), squeeze=False)
    axes = axes.flatten()
    for ax, (img, status) in zip(axes, results):
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        title_color = 'green' if "Solved" in status else 'red'
        ax.set_title(status, color=title_color, fontsize=10, fontweight='bold')
        ax.axis('off')
    # Hide the unused panels in the last row.
    for ax in axes[len(results):]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, models
import cv2
import numpy as np
import matplotlib.pyplot as plt
import heapq
import json
from pathlib import Path
from sklearn.cluster import DBSCAN
import random
from tqdm.auto import tqdm

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Checkpoints of the two classifiers combined by the hybrid router.
PATH_SIMPLE_CNN = "terrain_classifier_granular7.pth"  # Model A: SimpleTerrainCNN (Good for Lab/Forest)
PATH_RESNET = "terrain_classifier_resnet.pth"  # Model B: ResNet (Good for Desert)
MAPPING_PATH = "class_mapping.json"  # JSON object {class index: class name}

_DATASET_ROOT = Path("/kaggle/input/the-blind-flight-synapse-drive-ps-1")
TEST_IMG_DIR = _DATASET_ROOT / "SynapseDrive_Dataset" / "test" / "images"

CELL_SIZE = 64               # side length (px) of each warped grid cell
MAX_ROWS, MAX_COLS = 20, 20  # capacity of the sliced cell grid
NUM_SAMPLES = 50             # number of test images visualised per run

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-biome step costs for the path solver. The solver treats a step cost
# >= 100 as impassable; "Unknown" is the label of cells the slicer missed.
_BLOCKED_COST = 999.0
_UNKNOWN_COST = 8.0
COST_TABLES = {
    "Desert": {"Road": 1.2, "Start": 1.2, "End": 2.2, "Hazard": 3.7,
               "Cacti": _BLOCKED_COST, "Rocks": _BLOCKED_COST, "Obstacle": _BLOCKED_COST,
               "Unknown": _UNKNOWN_COST},
    "Forest": {"Road": 1.5, "Start": 1.5, "End": 2.5, "Hazard": 2.8,
               "Tree": _BLOCKED_COST, "Obstacle": _BLOCKED_COST,
               "Unknown": _UNKNOWN_COST},
    "Lab": {"Road": 1.0, "Start": 1.0, "End": 2.0, "Hazard": 5.0,
            "Wall": _BLOCKED_COST, "Plasma": _BLOCKED_COST, "Obstacle": _BLOCKED_COST,
            "Unknown": _UNKNOWN_COST},
}

# ==========================================
# 2. DEFINE BOTH ARCHITECTURES
# ==========================================
class SimpleTerrainCNN(nn.Module):
    """Small three-stage CNN that classifies one RGB grid cell.

    Each stage is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool, so three
    poolings shrink the spatial size by 8; ``fc1`` expects 64 channels of
    8x8, i.e. it is sized for 64x64 inputs.

    Attribute names (conv1..conv3, fc1, fc2) are the state_dict keys of saved
    checkpoints, so they must not be renamed.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Registration order is kept so parameter initialisation is unchanged.
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return unnormalised class logits for a batch of cells."""
        features = x
        for stage_conv in (self.conv1, self.conv2, self.conv3):
            features = self.pool(self.relu(stage_conv(features)))
        hidden = self.relu(self.fc1(self.flatten(features)))
        return self.fc2(hidden)

# ==========================================
# 3. LOAD BOTH MODELS
# ==========================================
# Class-index -> class-name mapping. JSON object keys are always strings,
# so they are converted back to ints to match the models' output indices.
with open(MAPPING_PATH, 'r') as mapping_file:
    raw_mapping = json.load(mapping_file)
idx_to_class = {int(idx): name for idx, name in raw_mapping.items()}
num_classes = len(idx_to_class)

print("🚀 Loading Hybrid Engine...")

# --- Load Model A (Simple CNN) ---
model_a = SimpleTerrainCNN(num_classes).to(device)
try:
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    model_a.load_state_dict(torch.load(PATH_SIMPLE_CNN, map_location=device))
    print("✅ Model A (SimpleCNN) Loaded")
except Exception as exc:
    # Best-effort: keep the notebook running, but report why. The model keeps
    # its random initial weights, so its predictions are meaningless.
    print(f"⚠️ Model A not found or mismatch. Check path. ({type(exc).__name__}: {exc})")

# --- Load Model B (ResNet) ---
model_b = models.resnet34(pretrained=False) # Or resnet18 depending on what you trained last
num_ftrs = model_b.fc.in_features
model_b.fc = nn.Linear(num_ftrs, num_classes)
# Move to the device AFTER replacing the head so the new fc layer moves too.
# load_state_dict copies values into the existing parameters without moving
# them, so without this the ResNet would stay on CPU and fail on CUDA inputs.
model_b = model_b.to(device)
try:
    model_b.load_state_dict(torch.load(PATH_RESNET, map_location=device))
    print("✅ Model B (ResNet) Loaded")
except Exception as exc:
    print(f"⚠️ Model B not found or mismatch. Check path. ({type(exc).__name__}: {exc})")

# Inference mode (affects dropout / batch-norm layers).
model_a.eval()
model_b.eval()

# --- Per-model preprocessing ---
# Both pipelines turn an RGB uint8 cell into a 64x64 float tensor in [0, 1];
# only the ResNet pipeline additionally applies ImageNet mean/std normalisation.
_CELL_TO_TENSOR = [transforms.ToPILImage(), transforms.Resize((64, 64)), transforms.ToTensor()]

# SimpleCNN: raw [0, 1] tensors.
transform_a = transforms.Compose(list(_CELL_TO_TENSOR))

# ResNet: same steps followed by normalisation.
transform_b = transforms.Compose(
    _CELL_TO_TENSOR
    + [transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

# ==========================================
# 4. VISION HELPERS (Grid Extraction)
# ==========================================
def preprocess_for_grid(img):
    """Binarise a BGR image so that candidate grid-line pixels become 255.

    A 7px median blur first suppresses terrain texture. A Gaussian adaptive
    threshold (19px neighbourhood, C=5) with THRESH_BINARY_INV then marks
    pixels that are darker than their local neighbourhood (presumably the
    grid lines) as 255 and everything else as 0.
    """
    smoothed = cv2.medianBlur(cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), 7)
    return cv2.adaptiveThreshold(
        smoothed, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV,
        19, 5,
    )

def get_intersections(img):
    """Locate grid-line intersections in a BGR image.

    Morphological opening with 25px-long horizontal and vertical line kernels
    keeps only long horizontal / vertical strokes. Their overlap, dilated with
    a 5x5 kernel, forms intersection blobs. Blobs larger than 10 px are kept,
    and centroids closer than 20 px are merged by DBSCAN. With
    min_samples=1 no point is discarded as noise.

    Returns:
        np.ndarray of shape (k, 2) holding merged (x, y) centroids, or an
        empty array when no blob survives.
    """
    binary = preprocess_for_grid(img)
    line_len = 25
    horiz_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (line_len, 1))
    vert_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, line_len))
    horizontal = cv2.morphologyEx(binary, cv2.MORPH_OPEN, horiz_kernel)
    vertical = cv2.morphologyEx(binary, cv2.MORPH_OPEN, vert_kernel)
    crossings = cv2.dilate(cv2.bitwise_and(horizontal, vertical), np.ones((5, 5)))

    n_labels, _, stats, centroids = cv2.connectedComponentsWithStats(crossings)
    # Label 0 is the background component.
    blobs = [centroids[k] for k in range(1, n_labels) if stats[k, cv2.CC_STAT_AREA] > 10]
    if not blobs:
        return np.array([])

    cluster_ids = DBSCAN(eps=20, min_samples=1).fit(blobs).labels_
    blob_arr = np.array(blobs)
    merged = [np.mean(blob_arr[cluster_ids == cid], axis=0) for cid in set(cluster_ids)]
    return np.array(merged)

def sort_points_robust(points):
    """Group intersection points into grid rows.

    The y coordinates are clustered with DBSCAN (eps=25, min_samples=3), and
    points labelled as noise (-1) are dropped. Rows are ordered top to bottom
    by mean y, and each row's points left to right by x.

    Args:
        points: array of shape (n, 2) holding (x, y) points.

    Returns:
        list of arrays, one (m, 2) array per row.
    """
    y_labels = DBSCAN(eps=25, min_samples=3).fit(points[:, 1].reshape(-1, 1)).labels_
    buckets = {}
    for label, pt in zip(y_labels, points):
        if label != -1:
            buckets.setdefault(label, []).append(pt)
    ordered = sorted(buckets, key=lambda lab: np.mean([p[1] for p in buckets[lab]]))
    return [np.array(sorted(buckets[lab], key=lambda p: p[0])) for lab in ordered]

def get_global_column_grid(grid_rows):
    """Estimate global column x-positions shared by all rows.

    The x coordinates of every point in every row are pooled and clustered
    with DBSCAN (eps=15, min_samples=1). Each cluster's mean x is one column
    centre.

    Returns:
        np.ndarray of column centres sorted ascending; empty when there are
        no points.
    """
    x_coords = np.array([pt[0] for row in grid_rows for pt in row]).reshape(-1, 1)
    if x_coords.size == 0:
        return np.array([])
    cluster_ids = DBSCAN(eps=15, min_samples=1).fit(x_coords).labels_
    centres = [x_coords[cluster_ids == cid].mean() for cid in set(cluster_ids) if cid != -1]
    centres.sort()
    return np.array(centres)

def _grid_cell_quad(pt_top, row_top, row_btm, max_cell_width=100, max_col_drift=70):
    """Return the source corners (TL, TR, BR, BL) of the cell whose top-left
    intersection is ``pt_top``, or None if the cell cannot be closed.

    TR is the nearest intersection to the right in the same row; it is rejected
    if it lies more than ``max_cell_width`` px away. BL / BR are the points in
    the next row closest in x to TL / TR, and each must lie within
    ``max_col_drift`` px.
    """
    neighbors_tr = [p for p in row_top if p[0] > pt_top[0]]
    if not neighbors_tr:
        return None
    pt_tr = min(neighbors_tr, key=lambda p: p[0])
    if abs(pt_tr[0] - pt_top[0]) > max_cell_width:
        return None
    candidates_bl = [p for p in row_btm if abs(p[0] - pt_top[0]) < max_col_drift]
    if not candidates_bl:
        return None
    pt_bl = min(candidates_bl, key=lambda p: abs(p[0] - pt_top[0]))
    candidates_br = [p for p in row_btm if abs(p[0] - pt_tr[0]) < max_col_drift]
    if not candidates_br:
        return None
    pt_br = min(candidates_br, key=lambda p: abs(p[0] - pt_tr[0]))
    return np.array([pt_top, pt_tr, pt_br, pt_bl], dtype="float32")


def extract_grid_data(image_path):
    """Slice a grid image into perspective-corrected cell tiles.

    Grid-line intersections are detected and grouped into rows. Every
    quadrilateral formed by two adjacent intersections and their counterparts
    in the next row is warped to a CELL_SIZE x CELL_SIZE tile. The tile is
    placed in the column whose global centre is nearest its left edge.

    Args:
        image_path: path (str or Path) of the image to read.

    Returns:
        (grid_tensor, mask_tensor, status)
          grid_tensor: uint8 (MAX_ROWS, MAX_COLS, CELL_SIZE, CELL_SIZE, 3) BGR
                       tiles, zeros where no cell was found.
          mask_tensor: uint8 (MAX_ROWS, MAX_COLS), 1 where a tile was filled.
          status:      "Active: Gridded" on success; "Fallback: ..." when too
                       few intersections/rows/columns were found; "Error: ..."
                       (with the exception type and message) if slicing raised.
        If the image cannot be read, returns (None, None, "Load Error").
    """
    img = cv2.imread(str(image_path))
    if img is None:
        return None, None, "Load Error"

    grid_tensor = np.zeros((MAX_ROWS, MAX_COLS, CELL_SIZE, CELL_SIZE, 3), dtype=np.uint8)
    mask_tensor = np.zeros((MAX_ROWS, MAX_COLS), dtype=np.uint8)
    # Destination corners (TL, TR, BR, BL) are the same for every cell.
    dst = np.array([[0, 0], [CELL_SIZE, 0], [CELL_SIZE, CELL_SIZE], [0, CELL_SIZE]], dtype="float32")

    try:
        points = get_intersections(img)
        if len(points) < 40: return grid_tensor, mask_tensor, "Fallback: Low Pts"
        grid_rows = sort_points_robust(points)
        if len(grid_rows) < 4: return grid_tensor, mask_tensor, "Fallback: Few Rows"
        col_centers = get_global_column_grid(grid_rows)
        if len(col_centers) < 2: return grid_tensor, mask_tensor, "Fallback: No Cols"

        for r in range(min(len(grid_rows) - 1, MAX_ROWS)):
            row_top, row_btm = grid_rows[r], grid_rows[r + 1]
            for pt_top in row_top:
                src = _grid_cell_quad(pt_top, row_top, row_btm)
                if src is None:
                    continue
                M = cv2.getPerspectiveTransform(src, dst)
                warped = cv2.warpPerspective(img, M, (CELL_SIZE, CELL_SIZE))

                c_idx = np.argmin(np.abs(col_centers - pt_top[0]))
                if c_idx < MAX_COLS:
                    grid_tensor[r, c_idx] = warped
                    mask_tensor[r, c_idx] = 1
        return grid_tensor, mask_tensor, "Active: Gridded"
    except Exception as exc:
        # Best-effort per image, but no longer swallows KeyboardInterrupt and
        # keeps the cause visible. Status still contains no "Active".
        return grid_tensor, mask_tensor, f"Error: {type(exc).__name__}: {exc}"

# ==========================================
# 5. THE ROUTER (Determine Desert vs Non-Desert)
# ==========================================
def detect_desert_mode(image_path, red_margin=20, green_margin=10):
    """Heuristically decide whether an image shows the Desert biome.

    The mean BGR colour over all pixels is compared with its blue channel.
    The image counts as desert (a warm, sandy tint) when red exceeds blue by
    more than ``red_margin`` and green exceeds blue by more than
    ``green_margin``.

    Args:
        image_path: path (str or Path) of the image.
        red_margin: required red-over-blue margin on the 0-255 scale.
        green_margin: required green-over-blue margin on the 0-255 scale.

    Returns:
        True if the image is likely Desert. Otherwise False, including when
        the image cannot be read, so the caller falls back to the standard
        model and the slicer reports the load error.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        return False
    # Use a float mean over every pixel. A 1x1 resize with default bilinear
    # interpolation only samples pixels near the centre, and uint8 arithmetic
    # such as `b + 20` wraps around for bright images (250 + 20 -> 14).
    b, g, r = img.reshape(-1, 3).mean(axis=0)
    return bool(r > b + red_margin and g > b + green_margin)

# ==========================================
# 6. HYBRID BATCH PROCESSOR
# ==========================================
def _hybrid_select_model(image_path):
    """Route an image to (model, transform, display name) via detect_desert_mode."""
    if detect_desert_mode(image_path):
        return model_b, transform_b, "RESNET (Desert)"  # ResNet
    return model_a, transform_a, "SimpleCNN (Std)"      # SimpleCNN


def _hybrid_classify_cells(grid_tensor, mask_tensor, model, transform):
    """Classify every cell flagged in ``mask_tensor`` in one batch.

    Returns:
        dict {(row, col): class_name} in row-major order; empty if no cell is
        flagged.
    """
    rows, cols = mask_tensor.shape
    coords = [(r, c) for r in range(rows) for c in range(cols) if mask_tensor[r, c] == 1]
    if not coords:
        return {}
    # Slicer tiles are BGR (read with OpenCV); convert to RGB before the
    # model-specific transform.
    batch = torch.stack([
        transform(cv2.cvtColor(grid_tensor[r, c], cv2.COLOR_BGR2RGB)) for r, c in coords
    ]).to(device)
    with torch.no_grad():
        _, preds = torch.max(model(batch), 1)
    return {rc: idx_to_class[p] for rc, p in zip(coords, preds.tolist())}


def _hybrid_dominant_biome(class_names):
    """Weighted vote for the biome, using the class-name prefix before "_".

    Names without "_" count as "Lab". Obstacle-like classes weigh 3 and
    Start/End classes weigh 2 (Start/End takes precedence); all other classes
    weigh 1. Ties go to the first biome in Desert, Forest, Lab order (then any
    other prefix in order of first appearance).
    """
    counts = {"Desert": 0, "Forest": 0, "Lab": 0}
    for name in class_names:
        biome = name.split("_")[0] if "_" in name else "Lab"
        weight = 1
        if any(x in name for x in ["Cacti", "Tree", "Plasma", "Wall", "Rocks"]):
            weight = 3
        if "Start" in name or "End" in name:
            weight = 2
        counts[biome] = counts.get(biome, 0) + weight
    return max(counts, key=counts.get)


def _hybrid_find_path(terrain_map, costs, start_pos, end_pos):
    """A* over the 4-connected grid of ``terrain_map`` labels.

    A cell's type is the label part after "_" (or the whole label). Its step
    cost is ``costs[type]``, defaulting to 999.0; steps costing >= 100 are
    impassable.

    Returns:
        (path, status): ``path`` lists (row, col) from end_pos back to
        start_pos and ``status`` is "Solved", or ([], "No Path") if End is
        unreachable.
    """
    rows, cols = terrain_map.shape
    pq = [(0, start_pos)]
    cost_so_far = {start_pos: 0}
    came_from = {}
    solved = False

    while pq:
        _, curr = heapq.heappop(pq)
        if curr == end_pos:
            solved = True
            break
        r, c = curr
        for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1)):
            if not (0 <= nr < rows and 0 <= nc < cols):
                continue
            label = terrain_map[nr, nc]
            cell_type = label.split("_")[1] if "_" in label else label
            step_cost = costs.get(cell_type, 999.0)
            if step_cost >= 100:
                continue
            new_cost = cost_so_far[curr] + step_cost
            if (nr, nc) not in cost_so_far or new_cost < cost_so_far[(nr, nc)]:
                cost_so_far[(nr, nc)] = new_cost
                # Manhattan heuristic: admissible for COST_TABLES, whose passable
                # steps all cost >= 1.0.
                priority = new_cost + abs(nr - end_pos[0]) + abs(nc - end_pos[1])
                heapq.heappush(pq, (priority, (nr, nc)))
                came_from[(nr, nc)] = curr

    if not solved:
        return [], "No Path"
    path = []
    node = end_pos
    while node != start_pos:
        path.append(node)
        node = came_from[node]
    path.append(start_pos)
    return path, "Solved"


def _hybrid_draw_overlay(image_path, terrain_map, mask_tensor, path, size=400):
    """Draw cell markers and the solved path on a ``size`` x ``size`` copy of the image.

    Dots are placed on a uniform rows x cols lattice over the resized image:
    yellow for Start, red for End, black for obstacles, orange for Hazard,
    grey otherwise, and small dark dots for unsliced cells. The path is drawn
    as green line segments.
    """
    img_vis = cv2.resize(cv2.imread(str(image_path)), (size, size))
    rows, cols = terrain_map.shape
    cell_h, cell_w = size // rows, size // cols

    def centre(r, c):
        return int((c + 0.5) * cell_w), int((r + 0.5) * cell_h)

    for r in range(rows):
        for c in range(cols):
            if mask_tensor[r, c] == 1:
                lbl = terrain_map[r, c]
                color = (100, 100, 100)
                if "Start" in lbl: color = (0, 255, 255)
                if "End" in lbl: color = (0, 0, 255)
                if any(x in lbl for x in ["Obstacle", "Wall", "Tree", "Cacti", "Rocks", "Plasma"]): color = (0, 0, 0)
                if "Hazard" in lbl: color = (0, 165, 255)
                cv2.circle(img_vis, centre(r, c), 3, color, -1)
            else:
                cv2.circle(img_vis, centre(r, c), 1, (50, 50, 50), -1)

    for p, q in zip(path, path[1:]):
        cv2.line(img_vis, centre(*p), centre(*q), (0, 255, 0), 2)
    return img_vis


def process_single_image(image_path):
    """Run the hybrid pipeline on one image and render the result.

    Steps:
      1. Route to a classifier (desert -> ResNet, otherwise SimpleCNN).
      2. Slice the grid.
      3. Classify every cell.
      4. Pick the dominant biome's cost table.
      5. Solve Start -> End with A*.
      6. Draw the overlay.

    Returns:
        (img_vis, status): a 400x400 BGR visualisation and the status
        "<model name>\\n<biome>: Solved|No Path". Returns (None, "Slicer Fail")
        when slicing does not succeed, and (None, "Empty Grid") when no cell
        was sliced.
    """
    # --- ROUTING STEP ---
    active_model, active_transform, model_name = _hybrid_select_model(image_path)

    # 1. Slice
    grid_tensor, mask_tensor, status = extract_grid_data(image_path)
    if "Active" not in status:
        return None, "Slicer Fail"

    # 2-3. Predict
    cell_labels = _hybrid_classify_cells(grid_tensor, mask_tensor, active_model, active_transform)
    if not cell_labels:
        return None, "Empty Grid"

    # 4. Build Map (unsliced cells stay "Unknown")
    rows, cols = mask_tensor.shape
    terrain_map = np.full((rows, cols), "Unknown", dtype=object)
    for (r, c), class_name in cell_labels.items():
        terrain_map[r, c] = class_name
    dominant_biome = _hybrid_dominant_biome(cell_labels.values())
    costs = COST_TABLES.get(dominant_biome, COST_TABLES["Lab"])

    # 5. Solver -- the last Start / End found in row-major order wins.
    start_pos = end_pos = None
    for r in range(rows):
        for c in range(cols):
            if "Start" in terrain_map[r, c]: start_pos = (r, c)
            if "End" in terrain_map[r, c]: end_pos = (r, c)

    path, status_msg = [], "No Path"
    if start_pos and end_pos:
        path, status_msg = _hybrid_find_path(terrain_map, costs, start_pos, end_pos)

    # 6. Visualization
    img_vis = _hybrid_draw_overlay(image_path, terrain_map, mask_tensor, path)

    # Add Model Name to Status for Debugging
    return img_vis, f"{model_name}\n{dominant_biome}: {status_msg}"

# ==========================================
# 7. RUN
# ==========================================
SAMPLE_SEED = 42  # fixed so the same test images are drawn on every run
N_GRID_COLS = 5   # panels per figure row

# Sort the glob so the seeded sample is reproducible (glob order is filesystem-dependent).
test_files = sorted(TEST_IMG_DIR.glob("*.png"))
samples = random.Random(SAMPLE_SEED).sample(test_files, min(len(test_files), NUM_SAMPLES))

print(f"Processing {len(samples)} images...")
results = []
for img_path in tqdm(samples):
    res_img, status = process_single_image(img_path)
    if res_img is not None:
        results.append((res_img, status))

if not results:
    # plt.subplots(0, ...) raises, and there is nothing to show anyway.
    print("No image produced a visualisation; nothing to plot.")
else:
    rows = (len(results) + N_GRID_COLS - 1) // N_GRID_COLS
    # squeeze=False keeps `axes` 2-D for any grid shape, so flatten() always works.
    fig, axes = plt.subplots(rows, N_GRID_COLS, figsize=(4 * N_GRID_COLS, 4 * rows), squeeze=False)
    axes = axes.flatten()
    for ax, (img, status) in zip(axes, results):
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        title_color = 'green' if "Solved" in status else 'red'
        ax.set_title(status, color=title_color, fontsize=9, fontweight='bold')
        ax.axis('off')
    # Blank out the unused panels of the last row.
    for ax in axes[len(results):]:
        ax.axis('off')
    plt.tight_layout()
    plt.show()