In [1]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell runs, so edits to local .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split



import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable

from torchvision import transforms as T
import torchvision



from PIL import Image
import cv2
import albumentations as A

import time
import os
from tqdm.notebook import tqdm

from pathlib import Path
import tifffile as tiff

import cv2 as cv


In [None]:
class DroneMaskCreator(Dataset):
    """Convert colour-coded semantic masks into per-class multi-channel TIFF files.

    Under ``src_path`` the class reads ``images/*.jpg``,
    ``gt/semantic/label_images/*.png`` and ``gt/semantic/class_dict.csv``
    (columns ``name``, ``red``, ``green``, ``blue``).  Every mask is downscaled
    by ``scale`` and expanded into an ``(H, W, n_classes)`` int64 array holding
    255 where the pixel has the class colour and 0 elsewhere.

    Args:
        src_path: Root directory of the dataset split.
        dst_dirname: Output directory name, created under ``gt/semantic``.
        scale: Factor applied to both mask dimensions before conversion.
    """

    def __init__(self, src_path, dst_dirname='label_tiff', scale=0.3):
        self.root = src_path
        self.dst_dirname = dst_dirname
        self.scale = scale

        self._load_files()
        self._create_dataframe()
        self._create_classes_dataframe()

    def _load_files(self):
        """Resolve input/output locations, create the output directory and list the files."""
        semantic_dir = Path(self.root).joinpath('gt/semantic/')
        image_dir = Path(self.root).joinpath('images')
        mask_dir = Path(self.root).joinpath('gt/semantic/label_images')
        class_csv = Path(self.root).joinpath('gt/semantic/class_dict.csv')

        self.tiff_path = semantic_dir.joinpath(self.dst_dirname)
        os.makedirs(self.tiff_path, exist_ok=True)

        # Sorting keeps images and masks in the same positional order.
        self.image_files = sorted(image_dir.glob('*.jpg'))
        self.mask_files = sorted(mask_dir.glob('*.png'))
        self.class_df = pd.read_csv(str(class_csv))

    def _create_dataframe(self):
        """Build ``self.files`` with one row per image: id, image path and mask path.

        Images and masks are paired purely by position in the sorted listings,
        so an image without a mask raises ``IndexError`` here.
        """
        records = [
            {
                'id': image_file.stem,
                'image_path': str(image_file),
                'mask_path': str(self.mask_files[idx]),
            }
            for idx, image_file in enumerate(self.image_files)
        ]
        self.files = pd.DataFrame(records, columns=['id', 'image_path', 'mask_path'])

    def _create_classes_dataframe(self):
        """Build ``self.classes_dataframe`` with a ``name`` and an ``(r, g, b)`` ``color`` per class."""
        names, colors = [], []
        for idx in range(len(self.class_df)):
            row = self.class_df.iloc[idx]
            names.append(row['name'])
            colors.append((row['red'], row['green'], row['blue']))

        self.classes_dataframe = pd.DataFrame({'name': names, 'color': colors})

    def _load_image(self, path, to_np=False):
        """Open ``path`` with PIL; return a NumPy array instead when ``to_np`` is true."""
        img = Image.open(path)
        return np.array(img) if to_np else img

    def _scale_image(self, img):
        """Resize ``img`` by ``self.scale`` along both axes.

        Nearest-neighbour sampling is used on purpose: the input is a
        colour-coded label mask, and any averaging interpolation would blend
        neighbouring class colours into RGB values that belong to no class.
        """
        height, width = img.shape[:2]
        new_size = (int(width * self.scale), int(height * self.scale))  # cv expects (w, h)
        return cv.resize(img, new_size, interpolation=cv.INTER_NEAREST)

    def _convert_to_classes_masks(self, mask):
        """Expand an ``(H, W, 3)`` colour mask into an ``(H, W, n_classes)`` array.

        Channel ``k`` is 255 where the pixel equals the colour of class ``k``
        in *all* three channels, and 0 everywhere else.
        """
        layers = []
        for color in self.classes_dataframe['color']:
            # A pixel belongs to the class only if every channel matches.
            matches = np.all(mask == np.asarray(color), axis=-1)
            layers.append(matches.astype(np.int64) * 255)
        return np.dstack(layers)

    def __len__(self):
        """Number of rows in ``self.files``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(class_masks, output_file_name)`` for the mask at position ``idx``."""
        mask_path = self.files['mask_path'].iloc[idx]
        mask_new_name = f'{Path(mask_path).stem}.tiff'

        mask_img = self._load_image(mask_path, to_np=True)
        mask_img = self._scale_image(mask_img)
        mask_img = self._convert_to_classes_masks(mask_img)

        return mask_img, mask_new_name

    def run_convert_and_save(self, compression_val=1):
        """Convert every mask and write it as a TIFF into ``self.tiff_path``.

        Args:
            compression_val: Value forwarded to ``tiff.imsave`` as ``compress``.
        """
        # NOTE(review): recent tifffile releases appear to favour
        # imwrite(..., compression=...) over imsave(..., compress=...) - confirm
        # against the installed version before upgrading.
        for idx in tqdm(range(len(self))):
            mask_img, mask_name = self[idx]
            mask_fname = str(self.tiff_path.joinpath(mask_name))
            tiff.imsave(mask_fname, mask_img, compress=compression_val)
           

In [None]:
# Dataset location: set the SEMANTIC_DRONE_ROOT environment variable to point at
# another copy of the data; otherwise the original default path is used.
ROOT_PATH = os.environ.get('SEMANTIC_DRONE_ROOT', '/data/semantic_drone/training_set')

# Convert every label mask at 20% resolution and write the per-class TIFFs.
dmc = DroneMaskCreator(src_path=ROOT_PATH, scale=0.2)
dmc.run_convert_and_save(compression_val=1)