In [131]:
import mmcv
import matplotlib.pyplot as plt

import numpy as np 
import os
import sys
from PIL import Image
from tqdm import tqdm

In [132]:
label_root_dir = '/var/datasets/rain_filtering/particle_labels/train'

save_root_dir = '/var/datasets/rain_filtering/ann_dir/train'
os.makedirs(save_root_dir, exist_ok=True)

In [133]:
label_list = ['first', 'last']

In [134]:
image_list= np.sort(os.listdir(os.path.join(label_root_dir, label_list[0])))

In [135]:
def create_mmseg_labels(label_root_dir, label_list,image_list,save_root_dir):
    """
    Created and save labels for mmsegmentation.
    Loaded label images (ranged 0-255) then regenerate labels accorging to their values (ranged 0 -num_classes -1).
    For rain_filtering dataset,
    0... nothing
    1... particle (based on first images)
    2... object (based on last images)
    3... particle and object (intersect between first and last)
    The label images are saved as png file as 'P' mode.
    
    Arguments:
        label_root_dir: path to directory that contains the label image direcotry
        label_list: list that contatins which label directory to use for labels
        image_list; list that contains which image file to create labels
        save_root_dir: path to directory that save directory
    """
    for img_file in tqdm(image_list):
        #load each images 
        label_image = create_label(label_root_dir, label_list, img_file)
        #convert to PIL.Image
        label_pil = Image.fromarray(label_image.astype(np.uint8))
        #save label image
        label_pil.save(os.path.join(save_root_dir, img_file), mode='P')

    print('done')
    

def create_label(label_root_dir, label_list, img_file):
    """
    Created and  labels for mmsegmentation.
    Loaded label images (ranged 0-255) then regenerate labels accorging to their values (ranged 0 -num_classes -1).
    Arguments:
        label_root_dir: path to directory that contains the label image direcotry
        label_list: list that contatins which label directory to use for labels
        img_file: str for load image file name
    Return: 
        img_array: np.array: image shape, whose values are ranged  (ranged 0 -num_classes -1).
        For rain_filtering datasets,
        0... nothing
        1... particle (based on first images)
        2... object (based on last images)
        3... particle and object (intersect between first and last)
    """
    lbl_img_list = []
    for label in label_list:
        #load image as float32
        lbl_img = mmcv.imread(os.path.join(label_root_dir, label, img_file)).astype(np.float32)
        #convert 255-> 1 (Note that pixel values are eigher 0 or 255)
        lbl_img /= 255
        # sum up the channel e.g. (height, width, 3) -> (height, width)
        lbl_img = lbl_img.sum(2)
        #convert3 -> 1 
        lbl_img /=3
        #append to list 
        lbl_img_list.append(lbl_img)
    #regenerate lbl_img
    img_array = np.zeros_like(lbl_img)
    for i, img in enumerate(lbl_img_list):
        img_array += img * (i+1)
    return img_array






In [136]:
create_mmseg_labels(label_root_dir, label_list, image_list, save_root_dir)

100%|██████████| 37448/37448 [03:24<00:00, 182.78it/s]done



In [89]:
first_image = mmcv.imread(os.path.join(label_root_dir, label_list[0], image_list[0])).astype(np.float32) / 255

In [90]:
last_image = mmcv.imread(os.path.join(label_root_dir, label_list[1], image_list[0])).astype(np.float32) / 255

In [91]:
first_image = first_image.sum(2) /3
last_image = last_image.sum(2)/3

In [92]:
label_image =np.zeros_like(first_image)

In [93]:
label_image += last_image *2

In [94]:
label_image += first_image

In [95]:
np.unique(label_image)


array([0., 1., 2., 3.], dtype=float32)

In [98]:
label_image = Image.fromarray(label_image.astype(np.uint8))

In [99]:
label_image.save('test_label.png', mode='P')

In [101]:
qq = Image.open('test_label.png')

In [114]:
qw = np.asarray(qq)

In [122]:
func_img = create_label(label_root_dir, label_list, image_list[0])

In [126]:
np.sum(qw == func_img)  == qw.size

True

In [124]:
qw

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=uint8)

In [125]:
np.unique(func_img)

array([0., 1., 2., 3.], dtype=float32)