In [8]:
import os
import re
from glob import glob

import numpy as np
from PIL import Image
from skimage.segmentation import slic

In [9]:
# --- SLIC configuration ---
COMPOSITION = '467'  # selects the scenes_<COMPOSITION>/ input folder and tags the saved .npy names
MAX_ITER = 10        # forwarded to slic() as max_num_iter

# Accumulator: region tag -> SLIC label array, populated by the segmentation loop.
slics = {}

In [10]:
def get_region(filepath):
    """Extract the region tag from a file path.

    The region tag is the first occurrence of a lowercase ``x`` followed by
    exactly two digits (e.g. ``x07``) anywhere in *filepath*, directories
    included.

    Args:
        filepath: Path string to search.

    Returns:
        The matched tag such as ``"x07"``, or ``None`` when the path contains
        no such substring.
    """
    found = re.search(r'(x\d{2})', filepath)
    return found.group(1) if found else None

In [11]:
# Target superpixel count (slic's n_segments) for each region tag.
# NOTE(review): there is no entry for x05; any scene tagged x05 would fail the lookup.
num_segments_dict = dict(
    x01=7000,
    x02=9000,
    x03=7000,
    x04=5000,
    x06=5000,
    x07=11000,
    x08=6000,
    x09=7000,
    x10=8000,
)

In [12]:
# Run SLIC on every scene of the selected composition, restricted to the pixels
# flagged by its truth mask, and store the label array in `slics` keyed by region.
img_paths = sorted(glob(f'scenes_{COMPOSITION}/*'))
truth_paths = sorted(glob('truth_masks/*'))

# Scenes and truth masks are paired purely by sorted position, so both folders
# must hold the same number of files in matching order; fail loudly otherwise
# instead of silently mis-pairing (or crashing midway with an IndexError).
assert len(img_paths) == len(truth_paths), (
    f'{len(img_paths)} scenes in scenes_{COMPOSITION}/ vs '
    f'{len(truth_paths)} files in truth_masks/')

for path, truth_path in zip(img_paths, truth_paths):
    # Context manager releases the image file handle once pixels are copied out.
    with Image.open(path) as pil_img:
        img = np.array(pil_img)
    truth = np.load(truth_path)

    region = get_region(path)
    if region not in num_segments_dict:
        raise KeyError(f'No n_segments configured for region {region!r} (from {path})')
    N_SEGMENTS = num_segments_dict[region]

    # slic computes superpixels only where the mask is True. Pixels whose truth
    # value is 0 or 255 are excluded (presumably background / no-data; TODO confirm).
    mask = (truth != 0) & (truth != 255)

    print(f'Generating slic for region {region}...')
    slic_results = slic(img, n_segments=N_SEGMENTS, max_num_iter=MAX_ITER, mask=mask.squeeze(), convert2lab=True)

    slics[region] = slic_results

Generating slic for region x01...
Generating slic for region x02...
Generating slic for region x03...
Generating slic for region x04...
Generating slic for region x06...
Generating slic for region x07...
Generating slic for region x08...
Generating slic for region x09...
Generating slic for region x10...


In [13]:
# Persist each region's SLIC label array to slics/slic_<region>-<COMPOSITION>.npy.
# The loop variable must NOT be named `slic`: that would rebind the imported
# skimage function, and re-running the segmentation cell would then fail.
os.makedirs('slics', exist_ok=True)  # np.save does not create missing directories
for region, labels in slics.items():
    savepath = f'slics/slic_{region}-{COMPOSITION}.npy'
    np.save(savepath, labels)
    print(f'saved: {savepath}')

saved: slics/slic_x01-467.npy
saved: slics/slic_x02-467.npy
saved: slics/slic_x03-467.npy
saved: slics/slic_x04-467.npy
saved: slics/slic_x06-467.npy
saved: slics/slic_x07-467.npy
saved: slics/slic_x08-467.npy
saved: slics/slic_x09-467.npy
saved: slics/slic_x10-467.npy
