In [5]:
import numpy as np
import cv2
import glob
import itertools
import os
from tqdm import tqdm
import random
# Fix the seed first so the palette below is identical on every run.
random.seed(0)
# One pseudo-random 3-component colour (each component in 0..255) for each of
# 5000 possible class ids; components are drawn left to right, class by class.
class_colors = [
    tuple(random.randint(0, 255) for _channel in range(3))
    for _class_id in range(5000)
]

In [6]:
def get_pairs_from_paths(images_path, segs_path):
    """Pair every ``*.jpg`` image with its ``*_NEW.png`` segmentation file.

    Args:
        images_path: Directory searched (non-recursively) for ``*.jpg`` files.
        segs_path: Directory searched (non-recursively) for ``*_NEW.png``
            files.

    Returns:
        A list of ``(image_path, segmentation_path)`` tuples, sorted by image
        path so the result does not depend on the directory listing order.

    Raises:
        AssertionError: If an image has no ``<image stem>_NEW.png``
            counterpart in ``segs_path``.
    """
    # glob returns matches in arbitrary order; sorting makes the output (and
    # any seeded shuffle a caller applies to it) reproducible.
    images = sorted(glob.glob(os.path.join(images_path, "*.jpg")))
    # Only membership is tested, so a set is all that is needed.
    segmentations = set(glob.glob(os.path.join(segs_path, "*_NEW.png")))

    ret = []

    for im in images:
        # Swap only the trailing extension; str.replace would also rewrite a
        # ".jpg" that happens to appear in the middle of the file name.
        stem, _ext = os.path.splitext(os.path.basename(im))
        seg_bnme = stem + "_NEW.png"
        seg = os.path.join(segs_path, seg_bnme)
        if seg not in segmentations:
            # Raised explicitly (not via `assert`) so the check is not
            # stripped when Python runs with -O.
            raise AssertionError(im + " is present in "+images_path +" but "+seg_bnme+" is not found in "+segs_path + " . Make sure annotation image are in .png")
        ret.append((im, seg))

    return ret

In [7]:
def get_segmentation_arr(path, nClasses, width, height, no_reshape=False):
    """Turn a label image into a one-hot class array.

    Args:
        path: Either a file path, which is loaded with ``cv2.imread(path, 1)``,
            or an already loaded ``np.ndarray``. For a 3-D array only
            channel 0 is used as the label map; a 2-D array is used as is.
        nClasses: Number of output channels. A pixel with value ``c`` sets
            channel ``c`` to 1 when ``0 <= c < nClasses``; pixels with any
            other value stay all-zero.
        width: Target width handed to ``cv2.resize``.
        height: Target height handed to ``cv2.resize``.
        no_reshape: When true, keep the ``(height, width, nClasses)`` layout
            instead of flattening the two spatial axes.

    Returns:
        A float64 array of zeros and ones with shape
        ``(width * height, nClasses)``, or ``(height, width, nClasses)`` when
        ``no_reshape`` is true.

    Raises:
        OSError: If ``path`` is not an array and ``cv2.imread`` cannot load it.
    """
    if isinstance(path, np.ndarray):
        img = path
    else:
        img = cv2.imread(path, 1)
        if img is None:
            # imread signals failure by returning None rather than raising;
            # fail here with the offending path instead of inside cv2.resize.
            raise OSError("Could not read segmentation image: " + str(path))

    # Nearest-neighbour interpolation so resizing never blends class ids.
    img = cv2.resize(img, (width, height), interpolation=cv2.INTER_NEAREST)
    if img.ndim == 3:
        # Class ids are read from the first channel only.
        img = img[:, :, 0]

    # Compare each pixel against every class id at once: (H, W, 1) == (C,)
    # broadcasts to (H, W, C). float64 matches the dtype np.zeros would give.
    seg_labels = (img[:, :, None] == np.arange(nClasses)).astype(np.float64)

    if no_reshape:
        return seg_labels

    return np.reshape(seg_labels, (width * height, nClasses))

In [8]:
def image_segmentation_generator(images_path, segs_path, batch_size, n_classes, input_height, input_width, output_height, output_width, do_augment=False):
    """Yield endless ``(X, Y)`` batches of images and segmentation targets.

    The image/segmentation pairs returned by ``get_pairs_from_paths`` are
    shuffled once and then cycled through in that order forever.

    Args:
        images_path: Directory with the input images.
        segs_path: Directory with the segmentation images.
        batch_size: Number of pairs per yielded batch.
        n_classes: Number of classes forwarded to ``get_segmentation_arr``.
        input_height: Height forwarded to ``get_image_arr``.
        input_width: Width forwarded to ``get_image_arr``.
        output_height: Height forwarded to ``get_segmentation_arr``.
        output_width: Width forwarded to ``get_segmentation_arr``.
        do_augment: When true, pass each image and channel 0 of its
            segmentation through ``augment_seg`` before encoding.

    Yields:
        ``(np.array(X), np.array(Y))`` where ``X`` collects the
        ``get_image_arr`` results and ``Y`` the ``get_segmentation_arr``
        results for one batch.

    Raises:
        RuntimeError: On the first ``next()`` if no image/segmentation pairs
            are found.

    NOTE(review): ``get_image_arr``, ``augment_seg`` and ``IMAGE_ORDERING``
    are not defined in the visible cells; they must already exist in the
    notebook namespace when the generator is advanced.
    """
    img_seg_pairs = get_pairs_from_paths(images_path, segs_path)
    if not img_seg_pairs:
        # Cycling an empty list would raise StopIteration inside this
        # generator, which Python reports as the unhelpful
        # "generator raised StopIteration"; fail with a usable message.
        raise RuntimeError(
            "No image/segmentation pairs found in "
            + str(images_path) + " and " + str(segs_path)
        )

    random.shuffle(img_seg_pairs)
    zipped = itertools.cycle(img_seg_pairs)

    while True:
        X = []
        Y = []
        for _ in range(batch_size):
            im_path, seg_path = next(zipped)

            im = cv2.imread(im_path, 1)
            seg = cv2.imread(seg_path, 1)

            if do_augment:
                # Only channel 0 of the segmentation is augmented, in place.
                im, seg[:, :, 0] = augment_seg(im, seg[:, :, 0])

            # `odering` (sic) is the keyword this call has always used.
            X.append(get_image_arr(im, input_width, input_height, odering=IMAGE_ORDERING))
            Y.append(get_segmentation_arr(seg, n_classes, output_width, output_height))

        yield np.array(X), np.array(Y)