In [None]:
import os
from pathlib import Path

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib import colors
from matplotlib import pyplot as plt
from omegaconf import DictConfig, OmegaConf

from src.config import BeachSegConfig
from src.data import BeachSegDataModule, torch_apply_mask_rgb
from src.geo_util import (
    get_masks, 
    infer_date, 
    group_images_by_date, 
    compute_raster_extent, 
    load_and_merge_masks, 
    rasterize_gdf,
    merged_no_data_mask,
    extract_linestring,
    tif_image,
    plot_mask,
    plot_line,
    plot_crops,
    merge_tifs,
    create_per_day_crops,
)
from src.ml_util import generate_square_crops_along_line

In [None]:
# Dataset root. Set BORDERFIELD_DATA_DIR to point at your own copy instead of
# editing this machine-specific default.
base_path = Path(os.environ.get("BORDERFIELD_DATA_DIR", "/Users/kyledorman/data/BorderField/"))
classification_dir = base_path / "Classifications"
shp_dir = base_path / "Lines"
mask_dir = base_path / "Masks"
# Every GeoTIFF one directory below SatelliteImagery/, in deterministic (sorted) order.
img_paths = sorted(base_path.glob("SatelliteImagery/*/*.tif"))

# Square crop edge length in pixels, plus a buffer expressed as a fraction of it.
crop_size = 224
buffer_factor = 0.125
buffer_px = int(crop_size * buffer_factor)

In [None]:
# Masks and dates
veg_masks = get_masks(mask_dir, "Mask_*.shp")
water_masks = get_masks(mask_dir, "WaterMask_*.shp")
mask_date = infer_date(veg_masks + water_masks)

len(veg_masks), len(water_masks), mask_date

In [None]:
# Group images
groups = group_images_by_date(img_paths)
ref_imgs = groups.pop(mask_date, [])

len(ref_imgs), len(groups)

In [None]:
# Compute extent & raster masks
out_transform, out_shape, CRS = compute_raster_extent(ref_imgs + sum(groups.values(), []))

out_shape, CRS

In [None]:
# Rasterize the vegetation and water polygons onto the shared grid as boolean masks.
veg_gdf = load_and_merge_masks(veg_masks)
veg_mask = rasterize_gdf(veg_gdf, out_shape, out_transform) == 1
water_gdf = load_and_merge_masks(water_masks)
water_mask = rasterize_gdf(water_gdf, out_shape, out_transform) == 1
full_no_data = merged_no_data_mask(water_mask, veg_mask)
# Sand = every pixel that is not no-data, water, or vegetation.
sand_mask = ~(full_no_data | water_mask | veg_mask)

# Label map: 0 = unassigned (no-data), 1 = water, 2 = vegetation, 3 = sand.
# Assignment order matters: vegetation overwrites water wherever the two overlap.
merged_mask = np.zeros(veg_mask.shape, dtype=np.uint8)
merged_mask[water_mask] = 1
merged_mask[veg_mask] = 2
merged_mask[sand_mask] = 3

panels = (
    ("water", water_mask),
    ("vegetation", veg_mask),
    ("no data", full_no_data),
    ("sand", sand_mask),
)
fig, axes = plt.subplots(1, len(panels), figsize=(5, 5))
for ax, (title, panel) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title, fontsize="small")
    ax.axis('off')
fig.tight_layout()

In [None]:
# Extract lines and generate crops
water_line = extract_linestring(water_mask, full_no_data)
assert water_line is not None
veg_line = extract_linestring(veg_mask, full_no_data)
assert veg_line is not None
prompt_crops = generate_square_crops_along_line(water_line, crop_size, 0)

In [None]:
# Mosaic the reference-date images onto the shared grid. The second return value
# is presumably the mosaic's no-data mask (inferred from its name; confirm in merge_tifs).
merged_img, merged_img_nodata = merge_tifs(
    ref_imgs,
    out_shape,
    out_transform,
    CRS,
)

In [None]:
# The reference mosaic on its own, before any overlays.
fig, ax = plt.subplots(1, 1, figsize=(5, 16))
ax.imshow(merged_img)
ax.set_title(f"Reference imagery ({mask_date})")
ax.axis('off')
fig.tight_layout()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(5, 16))
ax.imshow(merged_img)

# Layer order: vegetation overlay, water overlay, boundary lines, then crop boxes on top.
for layer, layer_color in ((veg_mask, 'teal'), (water_mask, 'hotpink')):
    plot_mask(layer, layer_color, 0.3, ax)
for boundary, boundary_color in ((water_line, 'blue'), (veg_line, 'green')):
    plot_line(boundary, boundary_color, ax)
plot_crops(prompt_crops, 'red', ax)

ax.axis('off')
fig.tight_layout()

In [None]:
p_imgs, p_masks, p_nodata = create_per_day_crops(prompt_crops, merged_img, merged_img_nodata, merged_mask, crop_size)
count = len(p_imgs)
# Drop crops whose no-data mask is entirely set, keeping the three lists index-aligned.
# Use `not` rather than `~`: `~` is logical only for numpy bools, while on a plain
# Python bool it yields -1/-2 (always truthy).
keep = [i for i, nd in enumerate(p_nodata) if not np.all(nd)]
p_imgs = [p_imgs[i] for i in keep]
p_masks = [p_masks[i] for i in keep]
p_nodata = [p_nodata[i] for i in keep]

len(keep), count

In [None]:
cols = 5
n_crops = len(p_imgs)
# Ceiling division, with at least one row so plt.subplots never gets nrows=0
# (which raises when every crop was filtered out).
rows = max(1, -(-n_crops // cols))

# squeeze=False always returns a 2-D axes array, whatever the grid size.
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols * 4, rows * 4), squeeze=False)

for idx, ax in enumerate(axes.flat):
    # Hide every panel's frame, including unused trailing panels in the last row.
    ax.axis('off')
    if idx >= n_crops:
        continue
    ax.imshow(p_imgs[idx])
    # Same colour convention as the overview plot: water (1) hotpink, vegetation (2) teal.
    plot_mask(p_masks[idx] == 1, 'hotpink', 0.6, ax)
    plot_mask(p_masks[idx] == 2, 'teal', 0.6, ax)

fig.tight_layout()
plt.show()

In [None]:
base_conf = OmegaConf.structured(BeachSegConfig)
# Presumably disables DataLoader worker processes, so loading happens in-kernel
# and is easier to debug. Confirm against BeachSegConfig.
base_conf.workers = 0

dm = BeachSegDataModule(base_conf)
dm.setup("train")
dl = dm.train_dataloader()
# Take exactly one batch. next() raises StopIteration on an empty loader. The old
# `for batch in ...: break` left `batch` undefined, or stale from a previous run.
batch = next(iter(dl))
    

In [None]:
def _first_as_hwc(batch_tensor):
    """Return item 0 of a batched tensor as a numpy array moved to channel-last and clipped to [0, 1].

    Assumes ``batch_tensor`` is (B, C, H, W). TODO confirm for mask outputs.
    """
    return batch_tensor[0].detach().cpu().numpy().transpose((1, 2, 0)).clip(0, 1)


bv = dm.aug(batch)
bt = dm.train_aug(batch)

fig, axes = plt.subplots(2, 2, figsize=(5, 5))
for ax in axes.flatten():
    ax.axis('off')

# Row 0: dm.aug output, row 1: dm.train_aug output. Left column image, right column mask.
# NOTE(review): the original called `torch_randomize_mask_rgb`, which is never imported
# (NameError). The imported `torch_apply_mask_rgb` is used instead. Confirm it takes a
# single mask batch.
for row, (label, aug_batch) in enumerate((("dm.aug", bv), ("dm.train_aug", bt))):
    axes[row, 0].imshow(_first_as_hwc(dm.denormalize(aug_batch["image"])))
    axes[row, 0].set_title(f"{label} image", fontsize="small")
    axes[row, 1].imshow(_first_as_hwc(torch_apply_mask_rgb(aug_batch["mask"])))
    axes[row, 1].set_title(f"{label} mask", fontsize="small")

fig.tight_layout()

In [None]:
def build_palette(num_labels: int):
    """Build a list of ``num_labels + 1`` RGB tuples for label colouring.

    Entry 0 is black, reserved for class index 0 (background). Label ``i`` gets
    entry ``i + 1``. Its colour comes from writing ``i`` as three base-``radix``
    digits (R, G, B) and stepping each channel down from 255 by ``digit * step``.
    Here ``radix = int(cbrt(num_labels)) + 1`` and ``step = 256 // radix``.
    """
    radix = int(num_labels ** (1 / 3)) + 1
    step = 256 // radix

    palette = [(0, 0, 0)]
    for idx in range(num_labels):
        r_digit, remainder = divmod(idx, radix**2)
        g_digit, b_digit = divmod(remainder, radix)
        palette.append(tuple(255 - digit * step for digit in (r_digit, g_digit, b_digit)))
    return palette

# Sanity check: build_palette prepends a black background entry, so N labels give N + 1 RGB triples.
palette = build_palette(3)
assert len(palette) == 3 + 1
torch.tensor(palette).shape