In [None]:
import pandas as pd
import geopandas as gpd
from shapely import wkt
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import sys
import os
from PIL import Image
import matplotlib.pyplot as plt
import tifffile as tiff
# Project root is taken to be the parent of the current working directory
# (presumably the notebook lives one level below the root — confirm).
PROJECT_ROOT = Path(os.getcwd()).resolve().parents[0]

In [None]:
# Folder that holds the CSV tables read further down in this notebook.
data_dir = PROJECT_ROOT.joinpath("data")
print(type(data_dir))

In [None]:
test_dir = data_dir / "Louisiana-West-Test"
# Stop early with a readable message when an expected folder is missing.
# Explicit `if` statements are used instead of `x or sys.exit(...)` so the cell
# does not echo a stray `True` as its output when both folders exist.
if not data_dir.exists():
    sys.exit(f"Data directory {data_dir} does not exist. Please check the path.")
if not test_dir.exists():
    sys.exit(f"Test directory {test_dir} does not exist. Please check the path.")

In [None]:
# Which public CSV tables to load, and the short key each one is stored under in `dfs`.
location, split = "Germany", "Training"
data_types = ["label_image_mapping", "reference"]
naming_dict = dict(label_image_mapping='scenes', reference='objects')
dfs = {
    naming_dict[data_type]: pd.read_csv(data_dir / f"{location}_{split}_Public_{data_type}.csv")
    for data_type in data_types
}

In [None]:
# For every loaded table: show its key, its row count and a preview of the first rows.
for data_type, df in dfs.items():
    print(f"Data type: {data_type}", f"Number of rows: {len(df)}", sep="\n")
    display(df.head())

In [None]:
# Tag every scene with its region. `location` is reused so the tag cannot drift
# from the CSV files that were actually loaded above.
dfs['scenes']['city'] = location
# Show a preview only; printing the whole frame floods the cell output.
display(dfs['scenes'].head())

In [None]:
# Every scene label is expected to occur exactly once in the mapping table.
unique_labels = dfs['scenes']['label'].is_unique
print("All labels are unique." if unique_labels else "There are duplicate labels.")


In [None]:
# True for scenes that come with a second post-event acquisition.
image2_mask = dfs['scenes']["post-event image 2"].notna()
print(f"Number of rows with post-event image 2: {image2_mask.sum()}")

In [None]:
def show_pre_post_images(label_image_mapping, label, images='all'):
    """Display the pre-/post-event images registered for one scene label.

    Parameters
    ----------
    label_image_mapping : DataFrame with a 'label' column and the image-name
        columns 'pre-event image', 'post-event image 1', 'post-event image 2'.
    label : value of the 'label' column identifying the scene.
    images : 'all' (default) to show every image column, a single column name,
        or an iterable of column names.

    Raises
    ------
    ValueError
        If `images` names a column other than the three image columns.
    """
    known_cols = ['pre-event image', 'post-event image 1', 'post-event image 2']
    # Previously `image_cols` was only assigned for images == 'all', so any other
    # value crashed with UnboundLocalError.
    if isinstance(images, str):
        image_cols = known_cols if images == 'all' else [images]
    else:
        image_cols = list(images)
    unknown = [col for col in image_cols if col not in known_cols]
    if unknown:
        raise ValueError(f"Unknown image column(s) {unknown}; expected 'all' or any of {known_cols}.")

    row = label_image_mapping[label_image_mapping['label'] == label]
    for col in image_cols:
        # An unknown label gives an empty selection, which also ends up here.
        if row[col].isna().all():
            print(f"No image found for {col} in label {label}.")
            continue
        # NOTE(review): get_image() reads from data_dir / '<city>_Training' / ...,
        # while this helper reads straight from data_dir — confirm which layout is current.
        subfolder = 'PRE-event' if col == 'pre-event image' else 'POST-event'
        # Context manager so the file handle is released once the pixels are drawn.
        with Image.open(data_dir / subfolder / row[col].values[0]) as img:
            plt.imshow(img)
        plt.axis('off')
        plt.show()

In [None]:
# Visual sanity check on a single scene.
show_pre_post_images(label_image_mapping=dfs['scenes'], label='0_41_59.geojson')

In [None]:
# Check whether any post-event image name (text before the first '.') is used
# as an ImageId in the reference table.
# A set gives constant-time membership tests instead of scanning the whole
# ImageId column once per image name.
reference_ids = set(dfs['objects']['ImageId'])
items_in_ref = []
for col in ['post-event image 1', 'post-event image 2']:
    for item in dfs['scenes'][col].dropna():
        if item.split('.')[0] in reference_ids:
            items_in_ref.append(item)
            print(f"Image {item} from {col} found in reference table.")
if not items_in_ref:
    print("No post image items found in reference table.")

In [None]:
# Helpers: classify WKT strings by geometry type, display images, and load the geometries / images of a scene

def get_geotype(wkt):
    """Classify a WKT string by the first geometry keyword it contains.

    Keywords are tried in the order POLYGON, LINESTRING, POINT using a plain
    substring test, so e.g. a MULTIPOLYGON string is reported as 'Polygon'.
    Returns 'Unknown' when none of the keywords occurs.
    """
    keyword_to_type = (
        ('POLYGON', 'Polygon'),
        ('LINESTRING', 'LineString'),
        ('POINT', 'Point'),
    )
    for keyword, geotype in keyword_to_type:
        if keyword in wkt:
            return geotype
    return 'Unknown'
def convert_to_image(img):
    """Return `img` with its first axis moved last when that axis has length 3.

    An array whose leading dimension is 3 is treated as channel-first and
    transposed with axes (1, 2, 0); anything else is returned unchanged.
    """
    leading_axis_is_channels = img.shape[0] == 3
    return img.transpose((1, 2, 0)) if leading_axis_is_channels else img

def view_image(img, ax=None):
    """Draw `img` on a matplotlib axes and return that axes.

    When `ax` is None a new 6x6 figure is created, its axis is hidden and the
    figure is shown immediately. When an axes is supplied the image is only
    drawn on it; hiding the axis and calling `plt.show()` is left to the caller.
    The image is passed through `convert_to_image` before drawing.
    """
    owns_figure = ax is None
    if owns_figure:
        ax = plt.figure(figsize=(6,6)).add_subplot(111)
    ax.imshow(convert_to_image(img))
    if owns_figure:
        plt.axis('off')
        plt.show()
    return ax
 
def get_geoms(reference, ImageId):
    """
    Parse the pixel-space geometries of one image and split them by type.

    Rows of `reference` whose 'ImageId' equals `ImageId` are selected, their
    'Wkt_Pix' strings are parsed with `wkt.loads`, and two Series are returned:
    (LineString geometries, Polygon geometries). Other geometry types are dropped.
    """
    wkt_strings = reference.loc[reference['ImageId'] == ImageId, 'Wkt_Pix']
    geoms = wkt_strings.apply(wkt.loads)
    geom_types = geoms.apply(lambda geom: geom.geom_type)
    return geoms[geom_types == 'LineString'], geoms[geom_types == 'Polygon']

def get_image(regions, label, img_type='pre-event image', city='Germany', verbose=True, data_dir=data_dir, Id_only=False):
    """Load one image of a scene, or return the scene's image Id.

    Parameters
    ----------
    regions : DataFrame with a 'label' column and the image-name columns.
    label : scene label to look up.
    img_type : image-name column to read ('pre-event image' by default).
    city : prefix of the '<city>_Training' folder below `data_dir`.
    verbose : print a message when nothing is found.
    data_dir : base data folder (defaults to the notebook-level `data_dir` as it
        was when this function was defined).
    Id_only : return the pre-event file name without '.tif' instead of pixels.

    Returns
    -------
    None when the label is unknown or the `img_type` entry is missing,
    otherwise the Id string (`Id_only=True`) or the array read by `tiff.imread`.
    """
    row = regions[regions['label'] == label]
    nothing_found = row.empty or row[img_type].isna().all()
    if nothing_found:
        if verbose:
            print(f"No {img_type} found for label {label}.")
        return None
    if Id_only:
        # Ids are the pre-event image names without the .tif extension
        return row['pre-event image'].values[0].split('.tif')[0]
    subfolder = 'PRE-event' if img_type == 'pre-event image' else 'POST-event'
    return tiff.imread(data_dir / f'{city}_Training' / subfolder / row[img_type].values[0])

In [None]:
def plot_image_with_geometries(label_image_mapping, reference, label, flag_idx=None):
    """Show every available image of a scene with its reference geometries overlaid.

    Parameters
    ----------
    label_image_mapping : scene table (label -> image file names).
    reference : object table with 'ImageId' and 'Wkt_Pix' columns.
    label : scene label to plot.
    flag_idx : optional positional index (or indices) into the scene's
        LineString geometries; flagged lines are drawn in purple, not red.

    Line geometries are drawn in red and polygons in semi-transparent blue.
    Returns None; returns early when the scene has no pre-event image, because
    the reference geometries are keyed by the pre-event image Id.
    """
    row = label_image_mapping[label_image_mapping['label'] == label]

    #get image Id
    ImageId = get_image(row, label, Id_only=True)
    if ImageId is None:
        return

    # Geometries, the flag mask and the GeoDataFrames do not depend on the image
    # type, so they are built once and not re-parsed for every image.
    lines, polys = get_geoms(reference, ImageId)
    flag_mask = np.ones(len(lines), dtype=bool)
    if flag_idx is not None:
        flag_mask[flag_idx] = False
    gdf_lines = gpd.GeoDataFrame(geometry=lines[flag_mask])
    gdf_polys = gpd.GeoDataFrame(geometry=polys)
    gdf_flag = gpd.GeoDataFrame(geometry=lines[~flag_mask]) if flag_idx is not None else None

    image_types = ['pre-event image', 'post-event image 1', 'post-event image 2']
    for img_type in image_types:
        if img_type != 'pre-event image' and row[img_type].isna().all():
            print(f"No image found for {img_type} in label {label}.")
            continue

        #load image
        img = get_image(row, label, img_type=img_type)

        print(f"Showing {img_type} for label {label}.")
        fig, ax = plt.subplots(figsize=(6, 6))
        ax = view_image(img, ax=ax)
        if gdf_flag is not None:
            # print pixels of flagged geometries
            print(f"Flagged geometries: {lines.iloc[flag_idx]}")
            gdf_flag.plot(ax=ax, facecolor='none', edgecolor='purple')

        gdf_lines.plot(ax=ax, facecolor='none', edgecolor='red')
        gdf_polys.plot(ax=ax, facecolor='blue', edgecolor='blue', alpha = 0.7)
        plt.axis('off')
        plt.show()

In [None]:
'''Extract objects in a certain image'''
#------------------------------------------
label = '0_41_58.geojson'
#------------------------------------------
# Must stay 'pre-event image': reference rows are linked to pre-event images only.
which_image = 'pre-event image'
# ImageId = pre-event file name of the chosen scene without its '.tif' suffix.
ImageId = dfs['scenes'].loc[dfs['scenes']['label'] == label, which_image].values[0].split('.tif')[0]
image_mask = dfs['objects']['ImageId'].eq(ImageId)


In [None]:
# Parse the selected image's WKT strings, keep the LineStrings, and count the
# vertices of each line.
shapes = dfs['objects'].loc[image_mask, 'Wkt_Pix'].apply(wkt.loads)
lines = shapes[shapes.apply(lambda geom: geom.geom_type == 'LineString')]
line_lengths = lines.apply(lambda geom: len(geom.coords))

In [None]:
idx = 20
# plot_image_with_geometries() interprets flag_idx as a position among the
# LineStrings of the plotted scene. The inspection below therefore uses the same
# scene (`label`, from which `lines` was derived) and indexes `lines` rather than
# `shapes`, which also contains polygons, so the printed point belongs to the
# geometry highlighted in the plot.
assert 0 <= idx < len(lines), f'idx {idx} is out of range for {len(lines)} line geometries'
plot_image_with_geometries(dfs['scenes'], dfs['objects'], label, flag_idx=idx)
point = lines.iloc[idx].coords[0]
print(point)
print('number of segments with same point:', lines.apply(lambda geom: point in list(geom.coords)).sum())

In [None]:
# Counts of geometry types in the Wkt_Pix column of the reference table.
print(dfs['objects']['Wkt_Pix'].map(get_geotype).value_counts())

In [None]:
# Plot every available image of one scene with the reference geometries overlaid.
plot_image_with_geometries(
    label_image_mapping=dfs['scenes'],
    reference=dfs['objects'],
    label='0_26_62.geojson',
    flag_idx=None,
)


In [None]:
# Inspect pre- and post-event images: load each one and report its array shape.
image_types = ['pre-event image', 'post-event image 1', 'post-event image 2']
# Loaded arrays in iteration order (image type outer, label inner). The list was
# previously created but never filled, so the `arrs[-1]` lookup in the upsampling
# cell failed with IndexError.
# NOTE(review): this keeps every image in memory — trim if the dataset is large.
arrs = []
for img_type in image_types:
    subfolder = 'PRE-event' if img_type == 'pre-event image' else 'POST-event'
    for label in dfs['scenes']['label'].unique():
        row = dfs['scenes'][dfs['scenes']['label'] == label]
        if row.empty or row[img_type].isna().all():
            print(f"No {img_type} found for label {label}.")
            continue
        # NOTE(review): get_image() reads from data_dir / '<city>_Training' / ...;
        # confirm which folder layout is current.
        img_path = data_dir / subfolder / row[img_type].values[0]
        img = tiff.imread(img_path)
        arrs.append(img)
        print(f'shape of {img_type} for label {label}: ', img.shape)

In [None]:
# upsample post-event images
# Resizes the most recently collected image to num_pix x num_pix and shows the
# original and the resized version in two separate figures.
import cv2
num_pix = 1300
# `arrs[-1]` requires `arrs` to be non-empty — make sure the image-loading cell
# above has actually filled it before running this cell.
img = arrs[-1]
out = cv2.resize(
        img, 
        (num_pix, num_pix),
        interpolation=cv2.INTER_LANCZOS4  # Lanczos with 8-tap window
    )
# first figure: image as loaded
plt.figure(figsize=(6, 6))
plt.imshow(img)
plt.axis('off')
# second figure: resized image
plt.figure(figsize=(6, 6))
plt.imshow(out)
plt.axis('off')

In [None]:
# Report every scene whose first post-event image is not square.
for label in dfs['scenes']['label'].unique():
    row = dfs['scenes'][dfs['scenes']['label'] == label]
    has_post_image = (not row.empty) and row['post-event image 1'].notna().any()
    if not has_post_image:
        print(f"No post-event image 1 found for label {label}.")
        continue
    img_path = data_dir / 'POST-event' / row['post-event image 1'].values[0]
    img = tiff.imread(img_path)
    # compares the first two array dimensions
    if img.shape[0] != img.shape[1]:
        print(f"Label {label} has non-square post-event image 1: {img.shape}")


In [None]:
# Check for non-square post-event images (first two dimensions differing by more
# than one pixel) and, for each affected scene, show the pre-event image and the
# available post-event images with their shapes.
image_types = ['post-event image 1', 'post-event image 2']
for label in dfs['scenes']['label'].unique():
    imgs = [get_image(dfs['scenes'], label, img_type=image_type, verbose=False) for image_type in image_types]
    mismatch = any(img is not None and abs(img.shape[0] - img.shape[1]) > 1 for img in imgs)
    if not mismatch:
        continue

    pre_img = get_image(dfs['scenes'], label, img_type='pre-event image', verbose=False)
    print(f"Label {label} has non-square post-event images")
    panels = [('pre-event image', pre_img)] + list(zip(image_types, imgs))
    for image_type, img in panels:
        if img is None:
            continue
        # Create the axes first and title it before showing. Previously the title
        # was set after view_image() had already called plt.show(), so it never
        # appeared, and an extra empty figure was created for each post image.
        fig, ax = plt.subplots(figsize=(6, 6))
        view_image(img, ax=ax)
        ax.set_title(f"{image_type} shape: {img.shape}")
        ax.axis('off')
        plt.show()
        print(f"{image_type} for label {label} shape: {img.shape}")

In [None]:
def rootSIFT(D, eps=1e-12):
    """Divide each row of `D` by its sum (plus `eps`) and take the element-wise square root.

    `D` is indexed as [row, feature]; `eps` guards against division by zero for
    all-zero rows. A new array is returned and the input is not modified.
    """
    row_sums = np.sum(D, axis=1, keepdims=True)
    return np.sqrt(D / (row_sums + eps))

def scale_filtered_pairs(kp1, kp2, match, max_ratio=1.5, min_px=None, max_px=None):
    """Decide whether a match links keypoints of compatible size.

    `match.queryIdx` indexes `kp1` and `match.trainIdx` indexes `kp2`; each
    keypoint exposes a `.size`. Returns False when either size is below `min_px`
    or above `max_px` (a falsy bound such as None or 0 disables that check),
    otherwise True exactly when larger/smaller size <= `max_ratio`
    (e.g. 1.5 allows at most a 1.5x scale difference).
    """
    sizes = (kp1[match.queryIdx].size, kp2[match.trainIdx].size)
    smaller, larger = min(sizes), max(sizes)
    if min_px and smaller < min_px:
        return False
    if max_px and larger > max_px:
        return False
    # the 1e-6 floor avoids dividing by a zero-sized keypoint
    return bool(larger / max(1e-6, smaller) <= max_ratio)

def sift_matches(imgA, imgB, ratio=0.7):
    """Match SIFT keypoints between two images and return the filtered pairs.

    Keypoints are detected in both images, matched by brute force (L2, two
    nearest neighbours) and kept when they pass the keypoint-size filter
    (`scale_filtered_pairs` with max_ratio=1.5, min_px=10) and the
    nearest-neighbour ratio test `m.distance < ratio * n.distance`.

    Returns
    -------
    ptsA, ptsB : float32 arrays with the matched keypoint coordinates.
    distances : list with the descriptor distance of each kept match.
    szA, szB : arrays with the keypoint sizes of each kept match.

    Raises
    ------
    RuntimeError
        If an image yields no descriptors, or fewer than 3 matches survive.
    """
    sift = cv2.SIFT_create(
        nOctaveLayers=4,          # keep near-edge features
        sigma=1.2
    )

    kA, dA = sift.detectAndCompute(imgA, None)
    kB, dB = sift.detectAndCompute(imgB, None)
    # detectAndCompute returns None descriptors when nothing is detected; fail
    # with a readable error instead of an OpenCV assertion inside knnMatch.
    if dA is None or dB is None:
        raise RuntimeError("No SIFT descriptors found in at least one image")

    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=False)
    knn = bf.knnMatch(dA, dB, k=2)
    print(len(knn), "matches found")

    good = []
    for pair in knn:
        # knnMatch can return fewer than k neighbours (e.g. a single train
        # descriptor); the ratio test needs both, so such entries are skipped.
        if len(pair) < 2:
            continue
        m, n = pair
        # only keep matches whose keypoints are large enough and similar in size
        if not scale_filtered_pairs(kA, kB, m, max_ratio=1.5, min_px=10, max_px=None):
            continue
        if m.distance < ratio * n.distance:
            good.append((m.queryIdx, m.trainIdx, m.distance))

    if len(good) < 3:
        raise RuntimeError("Not enough matches after ratio test")

    ptsA = np.float32([kA[i].pt for i, _, _ in good])
    ptsB = np.float32([kB[j].pt for _, j, _ in good])
    szA = np.array([kA[i].size for i, _, _ in good])
    szB = np.array([kB[j].size for _, j, _ in good])
    distances = [d for _, _, d in good]
    return ptsA, ptsB, distances, szA, szB

In [None]:
# Load the pre-event image and the first post-event image of one scene
# (img_type and verbose fall back to get_image's defaults where possible).
img_pre = get_image(dfs['scenes'], '0_40_59.geojson')
img_post = get_image(dfs['scenes'], '0_40_59.geojson', img_type='post-event image 1')

# pre-event image: shape, then picture
print(f"pre-event image shape: {img_pre.shape}")
view_image(img_pre)

# post-event image: shape, then picture
print(f"post-event image shape: {img_post.shape}")
view_image(img_post)


In [None]:
# Show the scalar type of a single pixel value of the pre-event image.
type(img_pre[0, 0, 0])

In [None]:
def resize(img, ref_size=(1300, 1300), padding='top'):
    """Resize `img` to the reference width while keeping its aspect ratio.

    Parameters
    ----------
    img : image array; an array whose first axis has length 3 is first
        transposed to channel-last.
    ref_size : tuple of (width, height) to resize the image to.
    padding : 'top' or 'bottom' adds black rows on that side until the reference
        height is reached; any other value disables padding.

    Lanczos interpolation is used when the reference width exceeds the image
    width, area interpolation otherwise. Square inputs are returned right after
    scaling. Padding never crops: a scaled image taller than the reference
    height keeps its height.
    """
    if img.shape[0] == 3:
        img = img.transpose((1, 2, 0))
    src_h, src_w = img.shape[0], img.shape[1]
    interp = cv2.INTER_LANCZOS4 if ref_size[0] > src_w else cv2.INTER_AREA
    new_size = (ref_size[0], int(ref_size[0] * src_h / src_w)) # (W, H)
    scaled_img = cv2.resize(img, new_size, interpolation=interp)

    if src_h == src_w or padding not in ('top', 'bottom'):
        return scaled_img

    missing_rows = max(0, ref_size[1] - new_size[1])
    pad_top, pad_bottom = (missing_rows, 0) if padding == 'top' else (0, missing_rows)
    return cv2.copyMakeBorder(scaled_img, pad_top, pad_bottom, 0, 0, cv2.BORDER_CONSTANT, value=(0, 0, 0))


def align_w_pre(img_post, img_pre, clahe=False):
    ref_size = img_pre.shape[:2]
    img_post = resize(img_post, ref_size=(1300, 1300), padding='top')

    # convert to grayscale (required for ECC)
    gray_post = cv2.cvtColor(convert_to_image(img_post), cv2.COLOR_BGR2GRAY)
    gray_pre = cv2.cvtColor(convert_to_image(img_pre), cv2.COLOR_BGR2GRAY)
    
    if clahe:
        clahe = cv2.createCLAHE(clipLimit=3, tileGridSize=(16, 16))
        gray_pre = clahe.apply(gray_pre)
        gray_post = clahe.apply(gray_post)
    # template = reference, input = to be warped
    template = gray_pre
    input_img = gray_post

    # initial warp matrix for translation = 2x3 identity
    warp_matrix = np.eye(2, 3, dtype=np.float32)

    # define the stopping criteria
    criteria = (
        cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
        10000,    # max iterations
        1e-6     # convergence threshold
    )

    # run ECC
    cc, warp_matrix = cv2.findTransformECC(
        template,
        input_img,
        warp_matrix,
        motionType=cv2.MOTION_TRANSLATION,
        criteria=criteria
    )

    print("Correlation coefficient:", cc)
    print("Warp matrix:\n", warp_matrix)

    # apply the translation
    h, w = template.shape
    aligned = cv2.warpAffine(img_post, warp_matrix, (w, h), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)


In [None]:
size = 1300
start_x, start_y = 0, 0
img_pre.shape
view_image(img_pre[:, start_y:start_y+size, start_x:start_x+size])
view_image(aligned[start_y:start_y+size, start_x:start_x+size, :])
view_image(img_post[start_y:start_y+size, start_x:start_x+size, :])


In [None]:
num_pix = 1300
# Bring the post-event image to num_pix x num_pix (Lanczos with 8-tap window).
img_post = cv2.resize(img_post, (num_pix, num_pix), interpolation=cv2.INTER_LANCZOS4)
# Grayscale versions of both images, used as input for feature matching below.
gray_pre, gray_post = (
    cv2.cvtColor(convert_to_image(image), cv2.COLOR_BGR2GRAY)
    for image in (img_pre, img_post)
)
# CLAHE object is created but currently not applied to either image.
clahe = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(16, 16))
print(gray_pre.shape)


In [None]:
# Match post-event (query) against pre-event (train) keypoints with a ratio-test threshold of 0.9.
pts_post, pts_pre, distances, sz_post, sz_pre = sift_matches(imgA=gray_post, imgB=gray_pre, ratio=0.9)

In [None]:
# One row per retained match.
num_pts = len(pts_post)
print(f"Number of matched points retained: {num_pts}")

In [None]:
# Number of matches to list; currently every retained match.
max_idx = num_pts
print(f'Top {max_idx} matches:')
# Per-match dump (currently disabled):
#for idx in range(max_idx):
    #print(f"Match {idx}: diff {pts_post[idx, :]-pts_pre[idx, :]} post {pts_post[idx, :]} pre {pts_pre[idx, :]} distance {distances[idx]} size_post {sz_post[idx]} size_pre {sz_pre[idx]}")


In [None]:
#check matches
# Express the matched coordinates on a 1300-pixel grid by scaling with
# 1300 / (number of rows of the corresponding grayscale image).
pts_post_rescaled = pts_post*(1300/gray_post.shape[0])
pts_pre_rescaled = pts_pre*(1300/gray_pre.shape[0])
max_idx = num_pts
# Per-match dump of the rescaled coordinates (currently disabled):
#print(f'Top {max_idx} matches:')
#for idx in range(max_idx):
    #print(f"Match {idx}: post {pts_post_rescaled[idx, :]} pre {pts_pre_rescaled[idx, :]} distance {distances[idx]}")


In [None]:
# Load and show all images of scene 0_27_62; the second post-event image is optional.
pre_img = get_image(dfs['scenes'], '0_27_62.geojson')
view_image(pre_img)
post_img = get_image(dfs['scenes'], '0_27_62.geojson', img_type='post-event image 1')
view_image(post_img)
post_img2 = get_image(dfs['scenes'], '0_27_62.geojson', img_type='post-event image 2')
if post_img2 is not None:
    view_image(post_img2)

In [None]:
# Resize the post-event images to a 1300 x 1300 reference, then show all
# images again for comparison.
post_img = resize(post_img, ref_size=(1300, 1300))
if post_img2 is not None:
    post_img2 = resize(post_img2, ref_size=(1300, 1300))
view_image(pre_img)
view_image(post_img)
if post_img2 is not None:
    view_image(post_img2)

In [None]:
# Shapes after resizing.
print(post_img.shape)
if post_img2 is None:
    print("No second post-event image")
else:
    print(post_img2.shape)

In [None]:
def point_to_segment_distance(px, py, x0, y0, x1, y1):
    """Vectorized distance from points (px, py) to the segment (x0, y0)-(x1, y1).

    Returns `(distance, t)` where `t` in [0, 1] is the clipped position of each
    point's projection along the segment (0 = first end point, 1 = second).
    A 1e-12 term in the squared length keeps zero-length segments from dividing
    by zero.
    """
    seg_x, seg_y = x1 - x0, y1 - y0
    rel_x, rel_y = px - x0, py - y0
    seg_len_sq = seg_x*seg_x + seg_y*seg_y + 1e-12
    t = np.clip((rel_x*seg_x + rel_y*seg_y) / seg_len_sq, 0.0, 1.0)
    # offset between each point and its closest point on the segment
    off_x = px - (x0 + t*seg_x)
    off_y = py - (y0 + t*seg_y)
    return np.sqrt(off_x*off_x + off_y*off_y), t  # t is useful for parallel distances too

# D: stack of per-segment distances, shape [n_segments, H, W], all >= 0
# tau: softness (pixels), e.g., 0.5–1.0
def softmin_distance(D, tau, eps=1e-12):
    """
    Soft minimum over the first axis of a stack of distance maps.

    D: [n, H, W] unsquared distances; may contain +inf
    tau: softness (pixels)
    returns: d_soft [H, W] = m - tau * log(sum_i exp(-(D_i - m) / tau)), with m
             the per-pixel minimum; +inf where no finite distance exists
    """
    finite = np.isfinite(D)                  # [n,H,W]
    any_finite = finite.any(axis=0)          # [H,W]

    # Per-pixel minimum over the finite entries; inf where there are none
    nearest = np.where(finite, D, np.inf).min(axis=0)

    # (D - nearest) / tau, evaluated only on finite entries so inf - inf never occurs
    scaled = np.empty_like(D)
    np.subtract(D, nearest[None, ...], out=scaled, where=finite)
    np.divide(scaled, max(tau, eps), out=scaled, where=finite)

    # Sum of exp(-scaled) over the finite entries only
    weights = np.zeros_like(D, dtype=D.dtype)
    np.exp(-scaled, out=weights, where=finite)
    weight_sum = np.where(any_finite, weights.sum(axis=0), 0.0)

    # Soft minimum where a finite distance exists, +inf elsewhere
    soft = np.where(any_finite, nearest - tau * np.log(weight_sum + eps), np.inf)

    # Clamp tiny negatives caused by floating-point error
    return np.where(np.isfinite(soft), np.maximum(soft, 0.0), soft)

def soft_or_centerlines(H, W, segments, sigma_perp, sigma_para=None):
    """Rasterise segments into an [H, W] map by a soft OR of per-segment Gaussians.

    segments: iterable of objects whose `.coords` holds exactly two points
        ((x0, y0), (x1, y1)) in pixel coordinates.
    sigma_perp: Gaussian width applied to the point-to-segment distance.
    sigma_para: optional second width for an additional along-segment term;
        None uses the distance-only kernel.

    Each segment contributes g_i inside a window padded by 3 * sigma_perp; the
    result is Y = 1 - prod_i (1 - g_i), accumulated in log space and clipped to
    [0, 1].
    """
    log_one_minus = np.zeros((H, W), dtype=np.float32)  # accumulates log(1 - g_i); log(1) = 0
    reach = 3.0 * sigma_perp

    for segment in segments:
        (x0,y0),(x1,y1) = segment.coords
        # window around the segment, padded by ~3 sigma and clipped to the image
        col_lo = int(max(0, np.floor(min(x0,x1) - reach)))
        col_hi = int(min(W-1, np.ceil (max(x0,x1) + reach)))
        row_lo = int(max(0, np.floor(min(y0,y1) - reach)))
        row_hi = int(min(H-1, np.ceil (max(y0,y1) + reach)))
        if col_lo > col_hi or row_lo > row_hi:
            continue

        px, py = np.meshgrid(
            np.arange(col_lo, col_hi+1, dtype=np.float32),
            np.arange(row_lo, row_hi+1, dtype=np.float32),
        )  # shape [roi_h, roi_w]

        d_perp, t = point_to_segment_distance(px, py, x0, y0, x1, y1)

        exponent = -0.5 * (d_perp / sigma_perp)**2
        if sigma_para is not None:
            # along-segment component of the vector from the projection to the point
            vx, vy = x1 - x0, y1 - y0
            seg_len = np.sqrt(vx*vx + vy*vy) + 1e-12
            ux, uy = vx/seg_len, vy/seg_len
            dx, dy = px - (x0 + t*vx), py - (y0 + t*vy)
            d_para = dx*ux + dy*uy
            exponent = exponent - 0.5 * (d_para / sigma_para)**2

        g = np.clip(np.exp(exponent), 0.0, 1.0 - 1e-7)  # keep 1 - g > 0 for the log
        log_one_minus[row_lo:row_hi+1, col_lo:col_hi+1] += np.log1p(-g)  # log(1 - g)

    # log(1 - Y) = sum_i log(1 - g_i)
    return np.clip(1.0 - np.exp(log_one_minus), 0.0, 1.0)

    # given: list of segments and a function point_to_segment_distance(px,py, ...)

def softmin_gauss_centerlines_old(H, W, segments, tau=10, sigma=3):
    """Gaussian of the soft-min distance to a set of segments (older variant).

    One full [H, W] distance map is built per segment (+inf outside a window of
    200 pixels around the segment), the maps are stacked and reduced with
    `softmin_distance(..., tau)`, and exp(-d^2 / (2 sigma^2)) is returned.
    Memory use grows with the number of segments. Each segment's `.coords` must
    hold exactly two points.
    """
    pad = 200  # half-width of the window evaluated around each segment
    layers = []
    num_segments = len(segments)
    for i, segment in enumerate(segments):
        print(f"Processing segment {i+1}/{num_segments}")
        (x0,y0),(x1,y1) = segment.coords
        xmin = int(max(0, np.floor(min(x0,x1) - pad)))
        xmax = int(min(W-1, np.ceil (max(x0,x1) + pad)))
        ymin = int(max(0, np.floor(min(y0,y1) - pad)))
        ymax = int(min(H-1, np.ceil (max(y0,y1) + pad)))
        if xmin>xmax or ymin>ymax:
            continue

        px, py = np.meshgrid(
            np.arange(xmin, xmax+1, dtype=np.float32),
            np.arange(ymin, ymax+1, dtype=np.float32),
        )
        d, _ = point_to_segment_distance(px, py, x0, y0, x1, y1)

        # fresh all-inf map per segment, filled only inside the window
        layer = np.full((H, W), np.inf)
        layer[ymin:ymax+1, xmin:xmax+1] = d
        layers.append(layer)

    # soft-min over the stacked maps, shape [n_segments, H, W]
    d_soft = softmin_distance(np.stack(layers, axis=0), tau)

    # gaussian of the soft-min distance
    return np.exp(-(d_soft**2) / (2 * sigma**2))

def softmin_gauss_centerlines(H, W, segments, sigma=5, gamma=2, eps=1e-16, pad=200):
    """Gaussian of a power-mean soft-min distance to a set of segments.

    For every segment, (d + eps)^(-gamma) is accumulated inside a window of
    `pad` pixels around it; the soft distance is the accumulated value raised to
    -1/gamma. Pixels never touched by a window get distance +inf. Returns
    exp(-d_soft^2 / (2 sigma^2)) as an [H, W] array. Each segment's `.coords`
    must hold exactly two points.
    """
    inv_power_sum = np.zeros((H, W))  # running sum of (d + eps)^(-gamma)
    num_segments = len(segments)
    for i, segment in enumerate(segments):
        print(f"Processing segment {i+1}/{num_segments}")
        (x0,y0),(x1,y1) = segment.coords
        xmin = int(max(0, np.floor(min(x0,x1) - pad)))
        xmax = int(min(W-1, np.ceil (max(x0,x1) + pad)))
        ymin = int(max(0, np.floor(min(y0,y1) - pad)))
        ymax = int(min(H-1, np.ceil (max(y0,y1) + pad)))
        if xmin>xmax or ymin>ymax:
            continue

        px, py = np.meshgrid(np.arange(xmin, xmax+1), np.arange(ymin, ymax+1))
        d, _ = point_to_segment_distance(px, py, x0, y0, x1, y1)
        inv_power_sum[ymin:ymax+1, xmin:xmax+1] += (d+eps)**(-gamma)

    # True where at least one segment window contributed
    covered = inv_power_sum!=0
    d_soft = covered*(inv_power_sum+eps**2)**(-1/gamma) - eps  # eps**2 avoids division by zero
    d_soft[~covered] = np.inf
    # gaussian of the soft-min distance
    return np.exp(-(d_soft**2) / (2 * sigma**2))

def gauss_centerlines(H, W, segments, sigma=5, eps=1e-16, pad=200):
    """Gaussian of the exact minimum distance to a set of segments.

    The distance to each segment is evaluated inside a window of `pad` pixels
    around it and the per-pixel minimum is kept; pixels outside every window
    keep distance +inf. Returns exp(-d_min^2 / (2 sigma^2)) as an [H, W] array.
    `eps` is accepted for signature parity with the soft-min variant but is not
    used. Each segment's `.coords` must hold exactly two points.
    """
    nearest = np.full((H, W), np.inf)  # running minimum distance, starts at +inf
    num_segments = len(segments)
    for i, segment in enumerate(segments):
        print(f"Processing segment {i+1}/{num_segments}")
        (x0,y0),(x1,y1) = segment.coords
        xmin = int(max(0, np.floor(min(x0,x1) - pad)))
        xmax = int(min(W-1, np.ceil (max(x0,x1) + pad)))
        ymin = int(max(0, np.floor(min(y0,y1) - pad)))
        ymax = int(min(H-1, np.ceil (max(y0,y1) + pad)))
        if xmin>xmax or ymin>ymax:
            continue

        px, py = np.meshgrid(np.arange(xmin, xmax+1), np.arange(ymin, ymax+1))
        d, _ = point_to_segment_distance(px, py, x0, y0, x1, y1)

        window = (slice(ymin, ymax+1), slice(xmin, xmax+1))
        nearest[window] = np.minimum(nearest[window], d)

    # gaussian of the min distance
    return np.exp(-(nearest**2) / (2 * sigma**2))


In [None]:
# Pick a scene and fetch its reference geometries: the lines are treated as
# roads and the polygons as buildings.
label = '0_41_59.geojson'

ImageId = get_image(regions=dfs['scenes'], label=label, Id_only=True)
roads, buildings = get_geoms(reference=dfs['objects'], ImageId=ImageId)

In [None]:
# Rasterise the road centerlines into a 1300 x 1300 map and display it.
# Alternative target: softmin_gauss_centerlines(1300, 1300, roads, sigma=10, gamma=10)
road_target = gauss_centerlines(1300, 1300, roads, sigma=10)
fig, ax = plt.subplots(figsize=(10, 10))
view_image(road_target, ax=ax)
plt.axis('off')
plt.show()

In [None]:
# Value range of the rasterised target (largest value first).
print(road_target.max(), road_target.min())

In [None]:
# Pre-event image of the same scene, for comparison with the rasterised target.
view_image(get_image(regions=dfs['scenes'], label=label, img_type='pre-event image', verbose=True))

In [None]:
class TileSampler:
    """Random sampler of tile origins on a jittered grid over a square image.

    Assumes the image and the tiles are square and accepts only single integers
    for sizes. A grid with spacing `stride` is shifted by a random offset
    (ux, uy) drawn at construction. `sample_tile_coords` picks a random pixel,
    maps it to a grid cell and returns that cell's tile origin; each grid cell is
    handed out at most once per sampler.

    NOTE(review): interior origins are `index * stride + offset` and are not
    clamped, so with stride < core_size a tile can reach past the image border —
    confirm whether that is intended (e.g. handled by padding elsewhere).
    """

    def __init__(self, img_size, core_size, halo_size, stride):
        self.img_size = img_size
        self.core_size = core_size
        self.halo_size = halo_size  # stored but not used by the sampler itself
        self.stride = stride
        # sample grid jitter
        self.ux = np.random.randint(0, self.stride)
        self.uy = np.random.randint(0, self.stride)
        # number of grid lines for this arrangement: ceil((img_size - offset) / stride)
        self.num_grid_x = -(-(self.img_size - self.ux) // self.stride)
        self.num_grid_y = -(-(self.img_size - self.uy) // self.stride)
        self.used_tiles = []  # (num_x, num_y) grid indices already handed out

    def _grid_index(self, pixel, offset, num_grid):
        """Map a pixel coordinate to a grid index along one axis.

        Index 0 is the tile of the first full stride, index -1 the partial tile
        before the jitter offset, and num_grid - 1 the last (fractional) tile.
        A pixel within one stride of a border whose cell neighbours the edge tile
        is randomly assigned to one of the two.
        """
        near_start = pixel < self.stride
        near_end = pixel >= self.img_size - self.stride
        num = (pixel - offset) // self.stride
        if near_start and num == 0:
            num = np.random.randint(-1, 1)
        elif near_end and num == num_grid - 2:
            num = np.random.randint(num_grid - 2, num_grid)
        return num

    def _tile_origin(self, num, offset, num_grid):
        """Origin coordinate of grid index `num` along one axis."""
        if num == -1:
            return 0
        if num == num_grid - 1:
            # last tile is anchored to the far border
            return self.img_size - self.core_size
        return num * self.stride + offset

    def sample_tile_coords(self, max_attempts=1000):
        """Return the origin (x0, y0) of a randomly chosen, not yet used tile.

        Candidates are drawn until an unused grid cell is found. The search is an
        iterative loop bounded by `max_attempts`, replacing the earlier unbounded
        recursion that ended in a RecursionError deep in the stack once the tiles
        were exhausted.

        Raises
        ------
        RuntimeError
            If no unused tile is found within `max_attempts` draws.
        """
        for _ in range(max_attempts):
            # sample anchor pixel
            px, py = np.random.randint(0, self.img_size, size=(2,))
            num_x = self._grid_index(px, self.ux, self.num_grid_x)
            num_y = self._grid_index(py, self.uy, self.num_grid_y)
            if (num_x, num_y) not in self.used_tiles:
                break
        else:
            raise RuntimeError(
                f"No unused tile found after {max_attempts} attempts; "
                f"{len(self.used_tiles)} tiles have already been sampled."
            )

        x0 = self._tile_origin(num_x, self.ux, self.num_grid_x)
        y0 = self._tile_origin(num_y, self.uy, self.num_grid_y)

        # remember the grid cell so it is not returned again
        self.used_tiles.append((num_x, num_y))
        return (x0, y0)

    def get_tiles(self):
        # NOTE(review): this method reads self.image, self.tile_size and
        # self.get_tile, none of which are defined in this class, so calling it
        # raises AttributeError. Left unchanged until its intended behaviour is
        # clarified.
        tiles = []
        for y in range(0, self.image.shape[0], self.tile_size):
            for x in range(0, self.image.shape[1], self.tile_size):
                tiles.append(self.get_tile(x, y))
        return tiles

In [None]:
# Draw ten tile origins from a 1300-pixel image (512-pixel tiles, stride 256).
sampler = TileSampler(img_size=1300, core_size=512, halo_size=32, stride=256)
for _ in range(10):
    print(sampler.sample_tile_coords())