In [27]:
# Input locations, relative to the notebook's working directory.
splits = 'Potholes/splits.json'       # split definition file (not read in the cells shown)
path = 'Potholes/annotated-images/'   # image/annotation directory
data = 'proposals.pkl'                # NOTE: `data` is rebound to the unpickled proposals in the loading cell

In [28]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

# OpenCV performance knobs for selective search:
# let parallel regions use 4 threads, and enable the optimized code paths.
cv2.setNumThreads(4)
cv2.setUseOptimized(True)

def selective_search(image_path, num_rects, quality=True):
    """Run OpenCV selective search on an image file and return region proposals.

    Args:
        image_path: Path of the image, read with ``cv2.imread``.
        num_rects: Maximum number of proposals to return; the first
            ``num_rects`` rectangles produced by ``ss.process()`` are kept.
        quality: If True, use OpenCV's "quality" selective-search strategy
            (slower per the OpenCV docs); if False, use the "fast" strategy.

    Returns:
        The rectangles from ``ss.process()`` truncated to ``num_rects``; each row
        is ``(x, y, w, h)``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None, not by raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    ss = cv2.ximgproc.segmentation.createSelectiveSearchSegmentation()
    ss.setBaseImage(image)

    if quality:
        ss.switchToSelectiveSearchQuality()
    else:
        ss.switchToSelectiveSearchFast()

    rects = ss.process()

    return rects[:num_rects]

def show_selective_search(image, rects):
    """Draw region proposals on a copy of a BGR image and display it with matplotlib.

    Args:
        image: BGR image array (as loaded by ``cv2.imread``); it is copied, not modified.
        rects: Iterable of ``(x, y, w, h)`` proposals, e.g. from ``selective_search``.
    """
    canvas = image.copy()

    # iterate over all the region proposals
    for x, y, w, h in rects:
        # Random per-box color so overlapping proposals stay distinguishable;
        # integer channels in 0..255, the range of an 8-bit image.
        color = tuple(int(c) for c in np.random.randint(0, 256, size=3))
        # Plain Python ints: some OpenCV builds reject NumPy integer scalars as points.
        cv2.rectangle(canvas, (int(x), int(y)), (int(x + w), int(y + h)), color, 2, cv2.LINE_AA)

    # Reverse the channel axis (BGR -> RGB) for matplotlib.
    plt.imshow(canvas[..., ::-1])
    plt.axis('off')

In [None]:
import torch
import pickle
import matplotlib.pyplot as plt
import torchvision.ops as ops
from xml.etree import ElementTree as ET
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torchvision.io import decode_image
from torchvision.utils import draw_bounding_boxes
from sklearn.model_selection import train_test_split

def read_xml(path: str) -> list:
    """Read the bounding box of every ``<object>`` in a Pascal-VOC style XML file.

    Args:
        path: Path of the annotation XML file.

    Returns:
        A list of ``(xmin, ymin, xmax, ymax)`` integer tuples, one per
        ``<object>`` element, in document order. Files without objects give ``[]``.

    Raises:
        ValueError: If an ``<object>`` lacks one of the ``bndbox`` coordinates.
    """
    root = ET.parse(path).getroot()

    obj_list = []
    for obj in root.iter('object'):
        coords = []
        for name in ('xmin', 'ymin', 'xmax', 'ymax'):
            node = obj.find(f"bndbox/{name}")
            if node is None or node.text is None:
                raise ValueError(f"{path}: <object> is missing bndbox/{name}")
            # Some annotation tools write decimals ("12.0"), which int() alone rejects.
            coords.append(int(float(node.text)))
        obj_list.append(tuple(coords))

    return obj_list

def collate_fn(batch):
    """Transpose a list of per-sample tuples into a tuple of per-field tuples.

    Passed to ``DataLoader`` so that samples with different numbers of boxes
    or ROIs are grouped into tuples rather than stacked into one tensor.
    """
    per_field = zip(*batch)
    return tuple(per_field)

def visualize_boxes(images, annotations):
    """Show each image (left) next to a copy with its boxes drawn on it (right).

    Args:
        images: Sequence of ``(C, H, W)`` image tensors; ``draw_bounding_boxes``
            requires uint8. They may live on any device.
        annotations: Sequence of box tensors, one per image, in the
            ``(xmin, ymin, xmax, ymax)`` format ``draw_bounding_boxes`` expects.
    """
    # matplotlib needs host memory; the dataset may have placed tensors on the GPU.
    pairs = [(image.cpu(), boxes.cpu()) for image, boxes in zip(images, annotations)]
    if not pairs:
        return

    # squeeze=False keeps `axes` 2-D even for a single image, so axes[idx, col] always works.
    _, axes = plt.subplots(len(pairs), 2, figsize=(10, len(pairs) * 5), squeeze=False)

    for idx, (image, boxes) in enumerate(pairs):
        overlay = draw_bounding_boxes(image, boxes, width=2)

        axes[idx, 0].imshow(image.permute(1, 2, 0))
        axes[idx, 0].axis('off')

        axes[idx, 1].imshow(overlay.permute(1, 2, 0))
        axes[idx, 1].axis('off')

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.02, hspace=0.02)
    plt.show()

class Pothole_Dataset(torch.utils.data.Dataset):
    """Per-image dataset of labelled region proposals cropped with ROI Align.

    Each entry of ``data`` is indexed as ``(image_source, ground_truths, regions)``:
    ``image_source`` is passed to ``decode_image``; ``ground_truths`` and
    ``regions`` are box tensors in ``(x1, y1, x2, y2)`` pixel coordinates
    (the layout ``ops.box_iou`` and ``ops.roi_align`` use).

    Args:
        data: Sequence of ``(image_source, ground_truths, regions)`` entries.
        feature_size: Side length of the square ROI crops.
        pos_thresh: Proposals whose best IoU with a ground truth is ``>=`` this are positives.
        neg_thresh: Proposals whose best IoU is ``<`` this are background candidates.
        val: If True, the random flip augmentation is skipped.
        device: Device the image, boxes, labels and crops are moved to.
    """

    def __init__(self, data, feature_size=256, pos_thresh=.7, neg_thresh=.3, val=False, device="cpu"):
        self.data = data
        self.size = feature_size
        self.pos_thresh = pos_thresh
        self.neg_thresh = neg_thresh
        self.val = val
        self.device = device

        self.transforms = v2.Compose([
                            v2.RandomHorizontalFlip(),
                            v2.RandomVerticalFlip()
                          ])

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _max_iou(regions, ground_truths):
        """Best IoU of every proposal against any ground-truth box.

        An image without ground-truth boxes yields all-zero IoUs instead of
        failing on ``max`` over an empty dimension.
        """
        if ground_truths.numel() == 0:
            return torch.zeros(len(regions), device=regions.device)
        return ops.box_iou(regions, ground_truths).max(dim=1).values

    @staticmethod
    def _sample_regions(regions, max_iou, pos_thresh, neg_thresh, bg_per_pos=4):
        """Select positive and background proposals and build their 1/0 labels.

        Every proposal with ``max_iou >= pos_thresh`` is a positive; up to
        ``bg_per_pos`` proposals with ``max_iou < neg_thresh`` are kept per positive
        (with the default 4, background is at most 80% of the selection). Labels are
        sized from the rows actually selected, so they stay aligned with the regions
        when fewer background candidates exist than requested.

        Returns:
            ``(selected_regions, region_labels)``, positives first; labels are
            float 1.0 for positives and 0.0 for background.
        """
        pos_samples = regions[max_iou >= pos_thresh]
        bg_candidates = regions[max_iou < neg_thresh]
        bg_samples = bg_candidates[: len(pos_samples) * bg_per_pos]

        selected_regions = torch.cat([pos_samples, bg_samples], dim=0)
        region_labels = torch.cat([
            torch.ones(len(pos_samples), device=regions.device),
            torch.zeros(len(bg_samples), device=regions.device),
        ])
        return selected_regions, region_labels

    def __getitem__(self, idx):
        """Return ``(image, ground_truths, regions, selected_regions, region_labels, rois)``.

        ``rois`` holds one ``feature_size x feature_size`` ROI-Align crop per row of
        ``selected_regions``; ``region_labels`` has one entry per row as well.
        """
        datum = self.data[idx]
        image = decode_image(datum[0]).to(self.device)
        ground_truths = datum[1].to(self.device)
        regions = datum[2].to(self.device)

        max_iou = self._max_iou(regions, ground_truths)
        selected_regions, region_labels = self._sample_regions(
            regions, max_iou, self.pos_thresh, self.neg_thresh
        )

        # ROI Align expects boxes as (batch_idx, x1, y1, x2, y2). All crops come from
        # the single image below, so the batch-index column stays 0.
        batched_image = image.unsqueeze(0).float()  # (1, C, H, W)
        batched_regions = torch.zeros(len(selected_regions), 5, device=self.device)
        batched_regions[:, 1:] = selected_regions

        rois = ops.roi_align(
            input=batched_image,
            boxes=batched_regions,
            output_size=(self.size, self.size),
            spatial_scale=1.0,  # boxes are already in input-pixel coordinates
            aligned=True,       # half-pixel offset for better alignment with the image
        )

        if not self.val:
            # NOTE(review): the flips receive the whole (N, C, H, W) stack as one input,
            # so every crop of this image shares the same random flip decision.
            rois = self.transforms(rois)

        return image, ground_truths, regions, selected_regions, region_labels, rois

In [30]:
feature_size = 512
batch_size = 6
device = ("cuda" if torch.cuda.is_available() else "cpu")
# Pothole_Dataset.__getitem__ moves tensors to `device`. CUDA cannot be used from
# fork-started DataLoader workers, so load in the main process when on the GPU.
num_workers = 0 if device == "cuda" else 1

In [31]:
pkl_file = 'proposals.pkl'
# SECURITY: pickle.load can execute arbitrary code; only load proposal files you generated yourself.
with open(pkl_file, 'rb') as file:
    data = pickle.load(file)

# The first `train_len` entries are the training split; the remainder is split 50/50 into val/test.
train_len = 532
assert len(data) > train_len, f"expected more than {train_len} entries in {pkl_file}, got {len(data)}"
train_data = data[:train_len]
# Fixed random_state so the val/test partition is identical across kernel restarts.
val_data, test_data = train_test_split(data[train_len:], train_size=.5, random_state=42)

trainset = Pothole_Dataset(train_data, feature_size=feature_size, device=device)
valset = Pothole_Dataset(val_data, feature_size=feature_size, device=device, val=True)
testset = Pothole_Dataset(test_data, feature_size=feature_size, device=device, val=True)

train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn)
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn)

In [32]:
# Pull one training batch for inspection; each field is a tuple with one entry per image.
first_batch = next(iter(train_loader))
images, ground_truths, regions, selected_regions, region_labels, rois = first_batch

In [None]:
# All selective-search proposals for each image in the batch.
visualize_boxes(images=images, annotations=regions)

In [None]:
# Only the sampled positive + background proposals used for training.
visualize_boxes(images=images, annotations=selected_regions)