In [None]:
from glob import glob
from PIL import Image
import os
import numpy as np
import re
from skimage.segmentation import slic, mark_boundaries

In [None]:
def load_scenes(pattern="scenes_sentinel_pca/*.npy"):
    """Load every PCA scene file matching `pattern`, keyed by region id.

    The region id is the first ``x<digits>`` token in the file's *base name*
    (e.g. ``"x01"``), the same key format used by ``n_segments``. Only the
    base name is searched so that directory names cannot supply a bogus id.

    Files are visited in sorted order so the returned dict -- and every
    downstream loop over it -- has a deterministic order (``glob`` itself
    returns files in arbitrary order).

    Args:
        pattern: glob pattern selecting the ``.npy`` scene files.

    Returns:
        dict mapping region id -> array loaded with ``np.load``.

    Raises:
        ValueError: if a matched file name contains no ``x<digits>`` token.
    """
    loaded = {}
    for filename in sorted(glob(pattern)):
        match = re.search(r"x(\d+)", os.path.basename(filename))
        if match is None:
            raise ValueError(f"Cannot determine region id from file name: {filename}")
        region = match.group(0)
        pca_region = np.load(filename)
        loaded[region] = pca_region
        print(f'Loaded PCA scene for region {region} with shape {pca_region.shape}')
    return loaded


scenes = load_scenes()

In [None]:
# Base SLIC superpixel counts per region, keyed by the "x<NN>" region id that is
# parsed from the PCA scene file names. The SLIC cell scales these before use.
# NOTE(review): there is no entry for x05 -- a scene file for that region would
# raise KeyError when SLIC is run; confirm that omission is intended.
n_segments = dict(
    x01=21_000,
    x02=27_000,
    x03=21_000,
    x04=15_000,
    x06=18_000,
    x07=33_000,
    x08=18_000,
    x09=21_000,
    x10=24_000,
)

In [None]:
# SLIC is run with SEGMENT_MULTIPLIER times the base per-region counts in
# `n_segments`. The scaled count is kept in a local variable instead of being
# written back into `n_segments`, so re-running this cell does not keep
# multiplying the targets (the cell is idempotent).
SEGMENT_MULTIPLIER = 3

slics = {}
for region, pca_region in scenes.items():
    target_segments = n_segments[region] * SEGMENT_MULTIPLIER
    print(f'Running SLIC for region {region} with {target_segments} segments...')
    slics[region] = slic(pca_region, n_segments=target_segments)

In [None]:
# Directory the SLIC label arrays are written to. It must be the directory that
# is created here; previously 'slics_sentinel' was created while the files were
# written to 'slics_manysegs', so np.save failed on a fresh checkout.
# NOTE(review): confirm 'slics_manysegs' (not 'slics_sentinel') is the intended target.
SLIC_OUTPUT_DIR = 'slics_manysegs'

os.makedirs(SLIC_OUTPUT_DIR, exist_ok=True)
for region, slic_region in slics.items():
    print(f'Saving SLIC for region {region}...')
    print(f'Total of {len(np.unique(slic_region))} segments')
    # np.save appends the '.npy' extension to this path.
    np.save(os.path.join(SLIC_OUTPUT_DIR, f'slic_{region}'), slic_region)

In [None]:
# Segments with fewer pixels than this are blacked out in the saved preview.
MIN_SEGMENT_SIZE = 630
# Fill colour for small segments. (1, 1, 1) rather than (0, 0, 0) is kept from
# the original -- NOTE(review): presumably to stay distinct from true black; confirm.
SMALL_SEGMENT_FILL = (1, 1, 1)
# Must match the directory actually written below (previously the code created
# 'slics_sentinel/images' but saved into 'slics_manysegs/images').
IMAGE_OUTPUT_DIR = os.path.join('slics_manysegs', 'images')

os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
for region, slic_region in slics.items():
    print(f'Saving SLIC images for region {region}...')
    # NOTE(review): assumes the PCA scene has 3 channels with values in [0, 1];
    # otherwise the uint8 cast wraps and the RGB fill below will not broadcast.
    img = np.uint8(mark_boundaries(scenes[region], slic_region) * 255)

    # Pixel count per label (index == label id). Labels that never occur -- e.g.
    # label 0 when SLIC numbers segments from 1 -- have count 0 and would skew
    # the statistics, so they are excluded from them.
    segment_sizes = np.bincount(slic_region.ravel())
    present_sizes = segment_sizes[segment_sizes > 0]
    n_kept = np.sum(present_sizes >= MIN_SEGMENT_SIZE)
    print(f'Average segment size: {np.mean(present_sizes)}')
    print(f'Median segment size: {np.median(present_sizes)}')
    print(f'Number of segments with size >= {MIN_SEGMENT_SIZE}: {n_kept} ({n_kept / len(present_sizes) * 100:.2f}%)')

    # Black out every pixel whose segment is below the threshold in one
    # vectorized step: index a per-label boolean table with the label image,
    # instead of building a full-image mask for each segment.
    is_small_segment = segment_sizes < MIN_SEGMENT_SIZE
    img[is_small_segment[slic_region]] = SMALL_SEGMENT_FILL

    Image.fromarray(img).save(os.path.join(IMAGE_OUTPUT_DIR, f'slic_{region}.png'))
    print(f'{30 * "-"}')