# Net Localization (Cartesian-first)

Load sonar frames, convert to Cartesian, and extract net structure (edges/lines) for debugging.

In [None]:
# Google Colab Setup (uncomment if using Colab)
# from google.colab import drive
# drive.mount('/content/drive')

# # Update the path to where you uploaded your data
# COLAB_MODE = True
# COLAB_DATA_PATH = '/content/drive/MyDrive/sonar-sim/simulation/data/runs/net_following_fish'

# For local execution
COLAB_MODE = False
COLAB_DATA_PATH = None

print(f"Colab mode: {COLAB_MODE}")

## 0. Setup for Google Colab (Optional)

If running on Google Colab, uncomment and run this cell first to mount your Google Drive and set the correct path to your data.

In [3]:
import json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import cv2
from ipywidgets import interact, IntSlider, fixed
from scipy import ndimage
import sys

# Basic evidence extraction helpers (inline, no external module)
def make_config():
    return {
        "denoise": {"d": 5, "sigmaColor": 50, "sigmaSpace": 50},
        "clahe": {"clipLimit": 2.0, "tileGridSize": (8, 8)},
        "canny": {"low": 100, "high": 120},
        "hough": {"rho": 1, "theta": np.pi / 180, "threshold": 50, "minLineLength": 40, "maxLineGap": 4},
        "lsd": {"scale": 0.8, "sigma_scale": 0.6, "quant": 2.0, "ang_th": 22.5, "log_eps": 0, "density_th": 0.7, "n_bins": 1024},
        "threshold": {"value": 80, "maxval": 255, "type": cv2.THRESH_BINARY},  # Direct intensity threshold
        "morph": {"kernel_size": 3, "iterations": 1},  # Morphological operations
    }

CFG = make_config()

# Steps

def denoise(img_u8):
    return cv2.bilateralFilter(img_u8, d=CFG["denoise"]["d"], sigmaColor=CFG["denoise"]["sigmaColor"], sigmaSpace=CFG["denoise"]["sigmaSpace"])


def enhance(img_u8):
    clahe = cv2.createCLAHE(clipLimit=CFG["clahe"]["clipLimit"], tileGridSize=CFG["clahe"]["tileGridSize"])
    return clahe.apply(img_u8)


def edges(img_u8):
    return cv2.Canny(img_u8, CFG["canny"]["low"], CFG["canny"]["high"])


def lines(edges_img):
    result = cv2.HoughLinesP(
        edges_img,
        rho=CFG["hough"]["rho"],
        theta=CFG["hough"]["theta"],
        threshold=CFG["hough"]["threshold"],
        minLineLength=CFG["hough"]["minLineLength"],
        maxLineGap=CFG["hough"]["maxLineGap"],
    )
    return result if result is not None else []


# ALTERNATIVE METHODS (no edge detection needed)

def lines_lsd(img_u8):
    """Line Segment Detector - detects lines directly without edge detection"""
    lsd = cv2.createLineSegmentDetector(
        cv2.LSD_REFINE_STD,
        scale=CFG["lsd"]["scale"],
        sigma_scale=CFG["lsd"]["sigma_scale"],
        quant=CFG["lsd"]["quant"],
        ang_th=CFG["lsd"]["ang_th"],
        log_eps=CFG["lsd"]["log_eps"],
        density_th=CFG["lsd"]["density_th"],
        n_bins=CFG["lsd"]["n_bins"]
    )
    lines_result = lsd.detect(img_u8)[0]
    
    # Convert to HoughLinesP format: [[x1, y1, x2, y2]]
    if lines_result is not None:
        return lines_result.reshape(-1, 1, 4).astype(np.int32)
    return []


def morphological_skeleton(binary):
    """Simple morphological skeleton using erosion (scipy-based)"""
    skeleton = np.zeros_like(binary, dtype=bool)
    element = np.array([[0, 1, 0],
                        [1, 1, 1],
                        [0, 1, 0]], dtype=bool)
    
    eroded = binary.astype(bool)
    while np.any(eroded):
        opened = ndimage.binary_opening(eroded, structure=element)
        skeleton |= eroded & ~opened
        eroded = ndimage.binary_erosion(eroded, structure=element)
    
    return skeleton


def threshold_skeleton(img_u8):
    """Threshold + morphological skeleton to extract line structure"""
    # Apply threshold to keep only bright pixels (likely net)
    _, binary = cv2.threshold(img_u8, CFG["threshold"]["value"], 
                              CFG["threshold"]["maxval"], 
                              CFG["threshold"]["type"])
    
    # Morphological closing to connect nearby regions
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, 
                                       (CFG["morph"]["kernel_size"], 
                                        CFG["morph"]["kernel_size"]))
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, 
                              iterations=CFG["morph"]["iterations"])
    
    # Extract skeleton using scipy
    binary_bool = closed > 0
    skeleton_bool = morphological_skeleton(binary_bool)
    skeleton = (skeleton_bool * 255).astype(np.uint8)
    
    return skeleton


def gradient_magnitude(img_u8):
    """Calculate gradient magnitude to highlight intensity changes"""
    # Compute gradients using Sobel
    grad_x = cv2.Sobel(img_u8, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(img_u8, cv2.CV_64F, 0, 1, ksize=3)
    
    # Magnitude
    magnitude = np.sqrt(grad_x**2 + grad_y**2)
    magnitude = np.clip(magnitude, 0, 255).astype(np.uint8)
    
    return magnitude

print("Imports OK; inline extractor ready (Canny + Hough | LSD | Threshold+Skeleton)")

Imports OK; inline extractor ready (Canny + Hough | LSD | Threshold+Skeleton)


## 1. Load frames and metadata

In [None]:
# Paths and frame selection
if COLAB_MODE and COLAB_DATA_PATH:
    run_path = Path(COLAB_DATA_PATH)
else:
    run_path = Path.cwd().parent / 'simulation' / 'data' / 'runs' / 'net_following_fish'

frames_dir = run_path / 'frames'
frame_files = sorted(frames_dir.glob('frame_*.npz'))[:1000]

sonar_images = []
fov_list = []
range_list = []
meta_list = []

for f in frame_files:
    data = np.load(f, mmap_mode='r')
    if 'sonar_image' not in data:
        data.close(); continue
    sonar_images.append(np.array(data['sonar_image']))
    fov_deg = 120.0
    range_m = 20.0
    meta = {}
    if 'meta_json' in data:
        meta = json.loads(str(data['meta_json']))
        fov_deg = float(meta.get('fov_deg', fov_deg))
        range_m = float(meta.get('range_m', range_m))
    fov_list.append(fov_deg)
    range_list.append(range_m)
    meta_list.append(meta)
    data.close()

print(f'Loaded {len(sonar_images)} frames')
print(f'Example shape: {sonar_images[0].shape if sonar_images else None}')
print(f'FOV range: {min(fov_list):.1f}–{max(fov_list):.1f} deg, range: {min(range_list):.1f}–{max(range_list):.1f} m')

Loaded 0 frames
Example shape: None
No frames loaded - check that the path exists and contains frame_*.npz files


## 2. Polar → Cartesian conversion helper

In [5]:
def polar_to_cartesian(polar_image, max_range_m=20.0, fov_deg=120.0, output_size=400):
    r_bins, n_beams = polar_image.shape
    half_width = max_range_m * np.sin(np.radians(fov_deg / 2))
    x = np.linspace(-half_width, half_width, output_size)
    y = np.linspace(0, max_range_m, output_size)
    X, Y = np.meshgrid(x, y)
    R = np.sqrt(X**2 + Y**2)
    Theta = np.arctan2(X, Y)
    fov_rad = np.radians(fov_deg)
    beam_idx = (Theta + fov_rad / 2) / fov_rad * (n_beams - 1)
    range_idx = (R / max_range_m) * (r_bins - 1)
    beam_idx = np.clip(beam_idx, 0, n_beams - 1).astype(np.float32)
    range_idx = np.clip(range_idx, 0, r_bins - 1).astype(np.float32)
    cart = cv2.remap(polar_image.astype(np.float32), beam_idx, range_idx, cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    cart[(R > max_range_m) | (Theta < -fov_rad/2) | (Theta > fov_rad/2)] = 0
    extent = [-half_width, half_width, 0, max_range_m]
    return cart, extent

## 3. Visualize raw sonar (polar & Cartesian)

In [6]:
def show_raw(frame_idx=0):
    img = sonar_images[frame_idx]
    fov_deg = fov_list[frame_idx]; max_r = range_list[frame_idx]
    cart, extent = polar_to_cartesian(img, max_r, fov_deg, 400)
    fig, ax = plt.subplots(1, 2, figsize=(14, 6))
    ax[0].imshow(img, cmap='gray', origin='lower', aspect='auto')
    ax[0].set_title(f'Polar frame {frame_idx}')
    ax[0].set_xlabel('Beam'); ax[0].set_ylabel('Range bin')
    ax[1].imshow(cart, cmap='gray', origin='lower', extent=extent, aspect='equal')
    ax[1].set_title(f'Cartesian frame {frame_idx}')
    ax[1].set_xlabel('X (m)'); ax[1].set_ylabel('Y (m)')
    ax[1].grid(True, alpha=0.3)
    plt.tight_layout(); plt.show()

interact(show_raw, frame_idx=IntSlider(min=0, max=len(sonar_images)-1, value=0, step=1, description='Frame'));

TraitError: setting max < min

## 4. Net extraction in Cartesian (edges/lines)

In [None]:
def extract_net(frame_idx=0):
    img = sonar_images[frame_idx]
    fov_deg = fov_list[frame_idx]
    max_r = range_list[frame_idx]

    img_u8 = (img * 255).astype(np.uint8)
    den_img = denoise(img_u8)
    enh_img = enhance(den_img)
    edge_img = edges(enh_img)
    line_segments = lines(edge_img)

    cart_raw, extent = polar_to_cartesian(img, max_r, fov_deg, 400)
    cart_edges, _ = polar_to_cartesian(edge_img.astype(np.float32) / 255.0, max_r, fov_deg, 400)

    # Edge points to Cartesian
    r_bins, n_beams = img.shape
    fov_rad = np.radians(fov_deg)
    pts = np.argwhere(edge_img > 0)  # (row=r, col=beam)
    cart_pts = None
    if len(pts) > 0:
        r = (pts[:, 0] / (r_bins - 1)) * max_r
        th = ((pts[:, 1] / (n_beams - 1)) - 0.5) * fov_rad
        cx = r * np.sin(th)
        cy = r * np.cos(th)
        cart_pts = np.stack([cx, cy], axis=1)

    fig, ax = plt.subplots(2, 2, figsize=(12, 10))
    ax[0, 0].imshow(cart_raw, cmap='gray', origin='lower', extent=extent, aspect='equal')
    ax[0, 0].set_title('Cartesian raw')

    ax[0, 1].imshow(cart_edges, cmap='gray', origin='lower', extent=extent, aspect='equal')
    ax[0, 1].set_title(f'Edges (Canny)')

    # Lines overlay
    ax[1, 0].imshow(cart_raw, cmap='gray', origin='lower', extent=extent, aspect='equal', alpha=0.8)
    if line_segments is not None and len(line_segments) > 0:
        for seg in line_segments:
            x1, y1, x2, y2 = seg[0]
            r1 = (y1 / (r_bins - 1)) * max_r; r2 = (y2 / (r_bins - 1)) * max_r
            t1 = ((x1 / (n_beams - 1)) - 0.5) * fov_rad; t2 = ((x2 / (n_beams - 1)) - 0.5) * fov_rad
            cx1, cy1 = r1 * np.sin(t1), r1 * np.cos(t1)
            cx2, cy2 = r2 * np.sin(t2), r2 * np.cos(t2)
            ax[1, 0].plot([cx1, cx2], [cy1, cy2], color='yellow', linewidth=2, alpha=0.8)
    ax[1, 0].set_title(f'Lines (Hough)')

    # Edge points scatter
    ax[1, 1].imshow(cart_raw, cmap='gray', origin='lower', extent=extent, aspect='equal', alpha=0.5)
    if cart_pts is not None:
        ax[1, 1].scatter(cart_pts[:, 0], cart_pts[:, 1], s=4, c='cyan', alpha=0.6)
    ax[1, 1].set_title(f'Edge points ({0 if cart_pts is None else len(cart_pts)})')

    for a in ax.flat:
        a.grid(True, alpha=0.3)
        a.set_xlabel('X (m)')
        a.set_ylabel('Y (m)')

    plt.tight_layout()
    plt.show()

interact(extract_net, frame_idx=IntSlider(min=0, max=len(sonar_images)-1, value=0, step=1, description='Frame'));

interactive(children=(IntSlider(value=0, description='Frame', max=999), Output()), _dom_classes=('widget-inter…

## 5. Ground Truth for ML: Semantic Segmentation

**Why Segmentation beats Bounding Boxes:**
- ✅ Nets curve, bend, and have irregular shapes
- ✅ Pixel-wise labels capture exact net geometry  
- ✅ No wasted space including fish/debris
- ✅ Works for partial/occluded nets

**Recommended Approach: Binary Segmentation**
- **Input**: Sonar image (1024×256 or 400×400 Cartesian)
- **Output**: Binary mask (net=1, background=0)
- **Architecture**: U-Net, SegFormer, or DeepLabV3+
- **Loss**: Dice Loss + BCE (handles class imbalance)

**Data Pipeline:**
1. Extract ground truth material IDs from simulation
2. Convert net material → binary mask
3. Apply polar-to-Cartesian transform to both image & mask
4. Train segmentation model
5. Post-process predictions with line fitting/thinning

In [None]:
def create_segmentation_mask(frame_idx=0, net_material_id=1, thickness=3):
    """
    Create binary segmentation mask from ground truth material IDs.
    
    Material IDs from simulation (simulation/src/core/materials.py):
        0 = EMPTY (water/air)
        1 = NET (fishing nets) ← DEFAULT
        2 = ROPE (support ropes)
        3 = FISH (fish bodies)
        4 = WALL (solid barriers)
        5 = BIOMASS (organic accumulation)
        6-8 = DEBRIS (light/medium/heavy)
        9 = CONCRETE, 10 = WOOD, 11 = FOLIAGE, 12 = METAL, 13 = GLASS
    
    Args:
        frame_idx: Frame index
        net_material_id: Material ID to extract (default 1 = NET)
        thickness: Line thickness in pixels for visualization/training
    
    Returns:
        mask_polar, mask_cartesian, img_cartesian, extent
    """
    # Load ground truth
    frame_path = frame_files[frame_idx]
    data = np.load(frame_path)
    
    if 'ground_truth' not in data:
        print("No ground truth available")
        return None, None
    
    ground_truth = data['ground_truth']  # Material ID map (1024, 256)
    sonar_img = data['sonar_image']
    
    # Create binary mask: net=1, everything else=0
    mask_polar = (ground_truth == net_material_id).astype(np.uint8) * 255
    
    # Optionally thicken lines for better training
    if thickness > 1:
        kernel = np.ones((thickness, thickness), np.uint8)
        mask_polar = cv2.dilate(mask_polar, kernel, iterations=1)
    
    # Convert to Cartesian
    fov_deg = fov_list[frame_idx]
    max_r = range_list[frame_idx]
    mask_cart, extent = polar_to_cartesian(mask_polar.astype(np.float32) / 255.0, max_r, fov_deg, 400)
    mask_cart = (mask_cart > 0.5).astype(np.uint8) * 255  # Re-binarize after interpolation
    
    # Also get image
    img_cart, _ = polar_to_cartesian(sonar_img, max_r, fov_deg, 400)
    
    data.close()
    
    return mask_polar, mask_cart, img_cart, extent


def visualize_segmentation_gt(frame_idx=0, net_material_id=1, thickness=2):
    """
    Visualize sonar image + ground truth mask overlay.
    
    Use the Mat ID slider to explore different materials:
    - 1 = NET (default - should show net structure)
    - 2 = ROPE (net support cables)
    - 3 = FISH (individual fish)
    - 5 = BIOMASS (accumulated organic matter)
    """
    result = create_segmentation_mask(frame_idx, net_material_id, thickness)
    
    if result[0] is None:
        print("No ground truth available for this frame")
        return
    
    mask_polar, mask_cart, img_cart, extent = result
    
    fig, ax = plt.subplots(1, 3, figsize=(20, 6))
    
    # Raw image
    ax[0].imshow(img_cart, cmap='gray', origin='lower', extent=extent, aspect='equal')
    ax[0].set_title('Sonar Image (Input)')
    ax[0].grid(True, alpha=0.3)
    ax[0].set_xlabel('X (m)'); ax[0].set_ylabel('Y (m)')
    
    # Ground truth mask
    ax[1].imshow(mask_cart, cmap='hot', origin='lower', extent=extent, aspect='equal')
    ax[1].set_title(f'Ground Truth Mask (Mat ID={net_material_id})')
    ax[1].grid(True, alpha=0.3)
    ax[1].set_xlabel('X (m)'); ax[1].set_ylabel('Y (m)')
    
    # Overlay
    ax[2].imshow(img_cart, cmap='gray', origin='lower', extent=extent, aspect='equal', alpha=0.8)
    ax[2].imshow(mask_cart, cmap='hot', origin='lower', extent=extent, aspect='equal', alpha=0.4)
    ax[2].set_title('Overlay (Image + Mask)')
    ax[2].grid(True, alpha=0.3)
    ax[2].set_xlabel('X (m)'); ax[2].set_ylabel('Y (m)')
    
    # Stats
    net_pixels = np.sum(mask_cart > 0)
    total_pixels = mask_cart.size
    coverage = 100 * net_pixels / total_pixels
    print(f"Material ID {net_material_id}: {net_pixels}/{total_pixels} pixels ({coverage:.2f}% coverage)")
    
    plt.tight_layout()
    plt.show()

interact(visualize_segmentation_gt, 
         frame_idx=IntSlider(min=0, max=len(sonar_images)-1, value=0, step=1, description='Frame'),
         net_material_id=IntSlider(min=1, max=13, value=1, step=1, description='Mat ID'),
         thickness=IntSlider(min=1, max=5, value=2, step=1, description='Thickness'));


interactive(children=(IntSlider(value=0, description='Frame', max=999), IntSlider(value=1, description='Mat ID…

## 6. ML Pipeline: Step 1 - Export Training Data

In [None]:
def export_training_data(output_dir='../training_data', net_material_id=1, thickness=2, 
                         train_split=0.8, num_frames=None):
    """
    Export training dataset as PNG image pairs.
    
    Args:
        output_dir: Directory to save data
        net_material_id: Material ID to extract (1=NET)
        thickness: Mask thickness in pixels
        train_split: Fraction for training (rest is validation)
        num_frames: Number of frames to export (None = all)
    
    Creates:
        output_dir/
            train/
                images/000000.png, 000001.png, ...
                masks/000000.png, 000001.png, ...
            val/
                images/000800.png, 000801.png, ...
                masks/000800.png, 000801.png, ...
    """
    from pathlib import Path
    from PIL import Image
    
    output_path = Path(output_dir)
    
    # Create directories
    for split in ['train', 'val']:
        (output_path / split / 'images').mkdir(parents=True, exist_ok=True)
        (output_path / split / 'masks').mkdir(parents=True, exist_ok=True)
    
    # Determine frames to process
    total_frames = len(frame_files) if num_frames is None else min(num_frames, len(frame_files))
    train_count = int(total_frames * train_split)
    
    print(f"Exporting {total_frames} frames ({train_count} train, {total_frames - train_count} val)...")
    print(f"Material ID: {net_material_id}, Thickness: {thickness}px")
    
    for i in range(total_frames):
        # Generate mask
        result = create_segmentation_mask(i, net_material_id, thickness)
        if result[0] is None:
            print(f"Skipping frame {i}: no ground truth")
            continue
        
        mask_polar, mask_cart, img_cart, extent = result
        
        # Normalize image to 0-255
        img_norm = ((img_cart - img_cart.min()) / (img_cart.max() - img_cart.min() + 1e-8) * 255).astype(np.uint8)
        
        # Determine split
        split = 'train' if i < train_count else 'val'
        
        # Save as PNG
        img_pil = Image.fromarray(img_norm)
        mask_pil = Image.fromarray(mask_cart)
        
        img_pil.save(output_path / split / 'images' / f'{i:06d}.png')
        mask_pil.save(output_path / split / 'masks' / f'{i:06d}.png')
        
        if (i + 1) % 100 == 0:
            print(f"  Processed {i + 1}/{total_frames} frames...")
    
    print(f"\n✅ Dataset exported to: {output_path.absolute()}")
    print(f"   Train: {train_count} image-mask pairs")
    print(f"   Val: {total_frames - train_count} image-mask pairs")
    
    return output_path

# Export dataset (uncomment to run)
export_training_data(num_frames=1000, train_split=0.8)

Exporting 1000 frames (800 train, 200 val)...
Material ID: 1, Thickness: 2px
  Processed 100/1000 frames...
  Processed 200/1000 frames...
  Processed 300/1000 frames...
  Processed 400/1000 frames...
  Processed 500/1000 frames...
  Processed 600/1000 frames...
  Processed 700/1000 frames...
  Processed 800/1000 frames...
  Processed 900/1000 frames...
  Processed 1000/1000 frames...

✅ Dataset exported to: /Users/eirikvarnes/code/sonar-sim/net_localization/../training_data
   Train: 800 image-mask pairs
   Val: 200 image-mask pairs


PosixPath('../training_data')

## 7. ML Pipeline: Step 2 - Install Dependencies

**For Google Colab:** These will install on the cloud instance.  
**For local:** Skip if already installed with `pip install -r requirements_ml.txt`

In [1]:
# Install ML dependencies (uncomment to run)
# !pip install torch torchvision segmentation-models-pytorch albumentations tqdm

print("✅ Dependencies ready (uncomment above to install)")

✅ Dependencies ready (uncomment above to install)


In [None]:
## 8. ML Pipeline: Step 3 - Dataset & Training Classes

Training code template prepared. Save as 'train_segmentation.py' to use.

Quick start:
  1. Export data: export_training_data(num_frames=1000)
  2. Install deps: pip install torch segmentation-models-pytorch albumentations
  3. Train: python train_segmentation.py --epochs 50


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import segmentation_models_pytorch as smp
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2


class SonarDataset(Dataset):
    """Load sonar images and masks from PNG files"""
    def __init__(self, data_dir, split='train', transform=None):
        self.img_dir = Path(data_dir) / split / 'images'
        self.mask_dir = Path(data_dir) / split / 'masks'
        self.files = sorted(self.img_dir.glob('*.png'))
        self.transform = transform
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        img_path = self.files[idx]
        mask_path = self.mask_dir / img_path.name
        
        # Load as grayscale
        image = np.array(Image.open(img_path))
        mask = np.array(Image.open(mask_path))
        
        # Convert to 3-channel (required by pretrained models)
        if len(image.shape) == 2:
            image = np.stack([image, image, image], axis=-1)
        
        # Binarize mask
        mask = (mask > 127).astype(np.float32)
        
        # Apply augmentations
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        
        return image, mask.unsqueeze(0)  # Add channel dim


def get_transforms(train=True):
    """Data augmentation for sonar images"""
    if train:
        return A.Compose([
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=15, p=0.5),
            A.RandomBrightnessContrast(p=0.3),
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])


class DiceLoss(nn.Module):
    """Dice Loss for handling class imbalance"""
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth
    
    def forward(self, pred, target):
        pred = torch.sigmoid(pred)
        intersection = (pred * target).sum(dim=(2, 3))
        union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()


print("✅ Dataset and loss classes defined")

## 9. ML Pipeline: Step 4 - Train Model

**Training Parameters:**
- `epochs`: Number of training epochs (start with 20-50)
- `batch_size`: Batch size (8 for GPU, 2-4 for CPU)
- `lr`: Learning rate (1e-4 is good default)

**Google Colab:** Make sure to enable GPU runtime (Runtime → Change runtime type → GPU)

In [None]:
def train_model(data_dir='training_data', epochs=20, batch_size=8, lr=1e-4):
    """Train U-Net segmentation model"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Validate data directory
    data_path = Path(data_dir)
    if not data_path.exists():
        raise FileNotFoundError(f"Data directory not found: {data_dir}")
    
    # Create datasets
    train_dataset = SonarDataset(data_dir, 'train', get_transforms(train=True))
    val_dataset = SonarDataset(data_dir, 'val', get_transforms(train=False))
    
    # Use num_workers=0 for compatibility (Colab/macOS)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    
    print(f"Train: {len(train_dataset)} images, Val: {len(val_dataset)} images")
    
    # Create model (U-Net with ResNet34 encoder)
    model = smp.Unet(
        encoder_name='resnet34',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
        activation=None  # We'll apply sigmoid in loss
    ).to(device)
    
    # Loss and optimizer
    criterion = DiceLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5)
    
    best_val_loss = float('inf')
    
    # Training loop
    for epoch in range(epochs):
        # Train
        model.train()
        train_loss = 0
        for images, masks in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
            images, masks = images.to(device), masks.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # Validate
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
        
        val_loss /= len(val_loader)
        scheduler.step(val_loss)
        
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_net_segmentation.pth')
            print(f"  ✅ Saved best model (val_loss={val_loss:.4f})")
    
    print("\n✅ Training complete!")
    return model

# Train the model (uncomment to run)
# trained_model = train_model(data_dir='training_data', epochs=20, batch_size=8)

In [None]:
def load_trained_model(model_path='best_net_segmentation.pth'):
    """Load trained segmentation model"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = smp.Unet(
        encoder_name='resnet34',
        encoder_weights=None,
        in_channels=3,
        classes=1,
        activation='sigmoid'
    ).to(device)
    
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    
    print(f"✅ Model loaded from {model_path} on {device}")
    return model, device


def predict_net_mask(model, device, image_cart):
    """
    Run inference on Cartesian sonar image.
    
    Args:
        model: Trained segmentation model
        device: torch device
        image_cart: Cartesian sonar image (H, W)
    
    Returns:
        Binary mask (H, W) with net predictions
    """
    # Prepare image
    img_norm = ((image_cart - image_cart.min()) / (image_cart.max() - image_cart.min() + 1e-8) * 255).astype(np.uint8)
    img_3ch = np.stack([img_norm, img_norm, img_norm], axis=-1)
    
    # Transform
    transform = A.Compose([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])
    img_tensor = transform(image=img_3ch)['image'].unsqueeze(0).to(device)
    
    # Predict
    with torch.no_grad():
        output = model(img_tensor)
        mask_pred = (output[0, 0].cpu().numpy() > 0.5).astype(np.uint8) * 255
    
    return mask_pred

# Load model (uncomment after training)
# model, device = load_trained_model('best_net_segmentation.pth')
print("Model loading functions ready")

In [None]:
## 11. ML Pipeline: Step 6 - Visualize Results with Trained Model

In [None]:
def extract_net_lines_from_mask(mask, min_length=20):
    """
    Convert binary mask to line segments.
    
    This is the key post-processing step that converts ML predictions
    back to geometric representations (lines) for tracking/control.
    
    Args:
        mask: Binary mask (255 = net, 0 = background)
        min_length: Minimum line length in pixels
    
    Returns:
        lines: Array of line segments [(x1,y1,x2,y2), ...]
    """
    # 1. Thin the mask to skeleton (1-pixel wide)
    mask_binary = (mask > 127).astype(np.uint8)
    skeleton = morphological_skeleton(mask_binary)
    skeleton_u8 = (skeleton * 255).astype(np.uint8)
    
    # 2. Find contours
    contours, _ = cv2.findContours(skeleton_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # 3. Fit lines to contours
    lines = []
    for contour in contours:
        if len(contour) < 5:  # Need at least 5 points for fitLine
            continue
        
        # Fit line using least squares
        [vx, vy, x, y] = cv2.fitLine(contour, cv2.DIST_L2, 0, 0.01, 0.01)
        
        # Get endpoints from contour extent
        pts = contour.reshape(-1, 2)
        x_min, y_min = pts.min(axis=0)
        x_max, y_max = pts.max(axis=0)
        
        # Project endpoints onto fitted line
        length = np.sqrt((x_max - x_min)**2 + (y_max - y_min)**2)
        
        if length < min_length:
            continue
        
        # Compute line endpoints
        t1 = -length / 2
        t2 = length / 2
        x1, y1 = x + vx * t1, y + vy * t1
        x2, y2 = x + vx * t2, y + vy * t2
        
        lines.append([x1[0], y1[0], x2[0], y2[0]])
    
    return np.array(lines) if lines else np.array([]).reshape(0, 4)


def visualize_ml_pipeline(frame_idx=0, model=None, device=None, show_gt=True):
    """
    Visualize complete ML pipeline: Input → Prediction → Lines.
    
    Args:
        frame_idx: Frame index
        model: Trained model (None = use ground truth as proxy)
        device: torch device
        show_gt: Show ground truth comparison
    """
    # Load data
    result = create_segmentation_mask(frame_idx, net_material_id=1, thickness=2)
    if result[0] is None:
        print("No data for this frame")
        return
    
    mask_polar, mask_gt, img_cart, extent = result
    
    # Predict (or use ground truth as proxy)
    if model is not None:
        mask_pred = predict_net_mask(model, device, img_cart)
        print(f"✅ Running TRAINED model prediction")
    else:
        mask_pred = mask_gt  # Use ground truth as "prediction" for demo
        print("⚠️  Using ground truth as proxy (train model first)")
    
    # Extract lines from prediction
    lines = extract_net_lines_from_mask(mask_pred, min_length=20)
    
    # Visualize
    n_cols = 4 if show_gt else 3
    fig, ax = plt.subplots(1, n_cols, figsize=(6*n_cols, 6))
    
    # 1. Input image
    ax[0].imshow(img_cart, cmap='gray', origin='lower', extent=extent, aspect='equal')
    ax[0].set_title('Input: Sonar Image')
    ax[0].grid(True, alpha=0.3)
    
    # 2. Predicted mask
    ax[1].imshow(img_cart, cmap='gray', origin='lower', extent=extent, aspect='equal', alpha=0.5)
    ax[1].imshow(mask_pred, cmap='hot', origin='lower', extent=extent, aspect='equal', alpha=0.5)
    title = 'ML Prediction' if model is not None else 'Ground Truth (Proxy)'
    ax[1].set_title(title)
    ax[1].grid(True, alpha=0.3)
    
    # 3. Extracted lines
    ax[2].imshow(img_cart, cmap='gray', origin='lower', extent=extent, aspect='equal', alpha=0.5)
    if len(lines) > 0:
        # Convert pixel coords to meters
        h, w = img_cart.shape
        x_min, x_max, y_min, y_max = extent
        for line in lines:
            x1_m = x_min + (line[0] / w) * (x_max - x_min)
            y1_m = y_min + (line[1] / h) * (y_max - y_min)
            x2_m = x_min + (line[2] / w) * (x_max - x_min)
            y2_m = y_min + (line[3] / h) * (y_max - y_min)
            ax[2].plot([x1_m, x2_m], [y1_m, y2_m], 'lime', linewidth=2, alpha=0.8)
    ax[2].set_title(f'Output: Line Segments ({len(lines)} lines)')
    ax[2].grid(True, alpha=0.3)
    
    # 4. Ground truth comparison (optional)
    if show_gt:
        ax[3].imshow(img_cart, cmap='gray', origin='lower', extent=extent, aspect='equal', alpha=0.5)
        ax[3].imshow(mask_gt, cmap='spring', origin='lower', extent=extent, aspect='equal', alpha=0.5)
        ax[3].set_title('Ground Truth (for comparison)')
        ax[3].grid(True, alpha=0.3)
    
    for a in ax:
        a.set_xlabel('X (m)')
        a.set_ylabel('Y (m)')
    
    plt.tight_layout()
    plt.show()
    
    print(f"✅ Extracted {len(lines)} line segments from mask")
    return lines

# Demo with ground truth as proxy (before training)
interact(visualize_ml_pipeline, 
         frame_idx=IntSlider(min=0, max=len(sonar_images)-1, value=0, description='Frame'),
         model=fixed(None),
         device=fixed(None),
         show_gt=fixed(True));

## 12. Test Trained Model (Run After Training)

After training is complete, use this cell to load the model and test predictions.

In [None]:
# Load trained model and visualize predictions (uncomment after training)
# model, device = load_trained_model('best_net_segmentation.pth')

# interact(visualize_ml_pipeline, 
#          frame_idx=IntSlider(min=0, max=len(sonar_images)-1, value=0, description='Frame'),
#          model=fixed(model),
#          device=fixed(device),
#          show_gt=fixed(True));