In [1]:
from tensorflow import keras
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.image import load_img
import random,os,sys
import rasterio as rs
from rasterio.windows import Window,from_bounds
import geopandas as gpd

In [2]:
class fetch_image(keras.utils.Sequence):
    """Keras Sequence yielding one (image, label) raster pair per step.

    Every file found (recursively) under ``input_dir_path`` is paired with the
    file at the same sorted position under ``target_dir_path``. The pairs are
    then shuffled together with a fixed seed, so the order is reproducible and
    an input can never be separated from its target.

    NOTE(review): pairing is purely positional. It assumes corresponding input
    and target files sort into the same order (e.g. identical file names);
    confirm against the data layout.

    Args:
        input_dir_path: Root directory of the input rasters.
        target_dir_path: Root directory of the target (label) rasters.
        seed: Seed for the joint shuffle (default 1337, the original value).

    Raises:
        ValueError: If the two directories contain a different number of files,
            which would make positional pairing misalign inputs and targets.
    """

    def __init__(self, input_dir_path, target_dir_path, seed=1337):
        # Keras 3's PyDataset (aliased as keras.utils.Sequence) expects its
        # __init__ to run; under older tf.keras this is a harmless no-op.
        super().__init__()
        self.input_dir_path = input_dir_path
        self.target_dir_path = target_dir_path

        inputs = self._list_files(self.input_dir_path)
        targets = self._list_files(self.target_dir_path)
        if len(inputs) != len(targets):
            raise ValueError(
                f"Found {len(inputs)} input files but {len(targets)} target files; "
                "inputs and targets are paired by sorted position and must match one-to-one."
            )

        # Shuffle the (input, target) pairs as a unit rather than shuffling two
        # lists independently and hoping the permutations coincide.
        pairs = list(zip(inputs, targets))
        random.Random(seed).shuffle(pairs)
        self.input_img_paths = [inp for inp, _ in pairs]
        self.target_img_paths = [tgt for _, tgt in pairs]

    @staticmethod
    def _list_files(root):
        """Return the full path of every file under ``root`` (recursive), sorted."""
        return sorted(
            os.path.join(dirname, filename)
            for dirname, _, filenames in os.walk(root)
            for filename in filenames
        )

    @staticmethod
    def _read_hwc(path):
        """Read all bands of the raster at ``path`` as a (rows, cols, bands) array.

        The dataset is closed on return, so no file handle is leaked per sample.
        """
        with rs.open(path) as src:
            # rasterio returns (bands, rows, cols); move the band axis last.
            return np.moveaxis(src.read(), 0, 2)

    def __len__(self):
        """Number of (image, label) pairs, i.e. steps per epoch."""
        return len(self.target_img_paths)

    def __getitem__(self, idx):
        """Return pair ``idx`` as two arrays of shape (1, rows, cols, bands)."""
        image = self._read_hwc(self.input_img_paths[idx])
        label = self._read_hwc(self.target_img_paths[idx])
        # Add a leading batch axis of size 1 so each step is a one-sample batch.
        return np.expand_dims(image, axis=0), np.expand_dims(label, axis=0)

In [3]:
class fetch_image_batch(keras.utils.Sequence):
    """Keras Sequence yielding fixed-size batches of (image, label) raster pairs.

    Files found (recursively) under the two directories are paired by sorted
    position and shuffled together with a fixed seed. Each step returns float32
    arrays of shape (batch_size, *img_size, n_channels) and
    (batch_size, *img_size, n_label_channels).

    NOTE(review): pairing is positional. It assumes corresponding files sort
    into the same order. Every raster is expected to already be exactly
    ``img_size`` with the configured band count; a mismatch either fails the
    array assignment or, for single-band data, silently broadcasts.

    Args:
        input_dir_path: Root directory of the input rasters.
        target_dir_path: Root directory of the target (label) rasters.
        batch_size: Samples per batch; a trailing partial batch is dropped.
        img_size: (rows, cols) of every raster.
        n_channels: Bands per input raster (default 9, the previous hard-coded value).
        n_label_channels: Bands per label raster (default 1, the previous hard-coded value).
        seed: Seed for the joint shuffle (default 1337, the original value).

    Raises:
        ValueError: If the two directories contain a different number of files.
    """

    def __init__(self, input_dir_path: str, target_dir_path: str, batch_size: int, img_size: tuple,
                 n_channels: int = 9, n_label_channels: int = 1, seed: int = 1337):
        # Keras 3's PyDataset (aliased as keras.utils.Sequence) expects its
        # __init__ to run; under older tf.keras this is a harmless no-op.
        super().__init__()
        self.input_dir_path = input_dir_path
        self.target_dir_path = target_dir_path
        self.batch_size = batch_size
        self.img_size = img_size
        self.n_channels = n_channels
        self.n_label_channels = n_label_channels

        inputs = self._list_files(self.input_dir_path)
        targets = self._list_files(self.target_dir_path)
        if len(inputs) != len(targets):
            raise ValueError(
                f"Found {len(inputs)} input files but {len(targets)} target files; "
                "inputs and targets are paired by sorted position and must match one-to-one."
            )

        # Shuffle the (input, target) pairs as a unit so they can never drift apart.
        pairs = list(zip(inputs, targets))
        random.Random(seed).shuffle(pairs)
        self.input_img_paths = [inp for inp, _ in pairs]
        self.target_img_paths = [tgt for _, tgt in pairs]

    @staticmethod
    def _list_files(root):
        """Return the full path of every file under ``root`` (recursive), sorted."""
        return sorted(
            os.path.join(dirname, filename)
            for dirname, _, filenames in os.walk(root)
            for filename in filenames
        )

    @staticmethod
    def _read_hwc(path):
        """Read all bands of the raster at ``path`` as a (rows, cols, bands) array.

        The dataset is closed on return, so no file handle is leaked per sample.
        """
        with rs.open(path) as src:
            # rasterio returns (bands, rows, cols); move the band axis last.
            return np.moveaxis(src.read(), 0, 2)

    def __len__(self):
        """Number of full batches; leftover samples beyond the last full batch are skipped."""
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as (image, label) float32 arrays."""
        start = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[start : start + self.batch_size]
        batch_target_img_paths = self.target_img_paths[start : start + self.batch_size]

        spatial = tuple(self.img_size)
        image = np.zeros((self.batch_size,) + spatial + (self.n_channels,), dtype="float32")
        for j, path in enumerate(batch_input_img_paths):
            image[j] = self._read_hwc(path)

        label = np.zeros((self.batch_size,) + spatial + (self.n_label_channels,), dtype="float32")
        for j, path in enumerate(batch_target_img_paths):
            label[j] = self._read_hwc(path)

        return image, label

In [4]:
class fetch_geom_batch(keras.utils.Sequence):
    """Keras Sequence that crops training patches out of two large rasters.

    Patch footprints are read from a vector file with geopandas, filtered to the
    rows whose ``'label'`` column equals ``label``, and shuffled once. For each
    footprint, its bounding box is read from the feature raster and from the
    label raster and resampled to ``img_size`` with nearest-neighbour. The
    results are stacked into float32 batches of shape
    (batch_size, *img_size, n_channels) and (batch_size, *img_size, n_label_channels).

    The label batch is then cropped by ``label_crop`` pixels on every spatial
    edge. Presumably this matches a model whose output is smaller than its input
    (e.g. 'valid' convolutions); TODO confirm against the model definition.

    Args:
        patches_path: Path to the vector file of patch footprints.
        feature_image_path: Path to the multi-band feature raster.
        label_image_path: Path to the label raster.
        label: Value of the ``'label'`` column selecting which patches to use.
        img_size: (rows, cols) every patch is resampled to.
        batch_size: Patches per batch; a trailing partial batch is dropped.
        n_channels: Bands in the feature raster (default 9, the previous hard-coded value).
        n_label_channels: Bands in the label raster (default 1, the previous hard-coded value).
        label_crop: Pixels trimmed from each spatial edge of the label batch
            (default 3, the previous hard-coded value; 0 disables cropping).
        random_state: Seed forwarded to ``DataFrame.sample``. The default None
            keeps the original unseeded (non-reproducible) shuffle.
    """

    def __init__(self, patches_path: str, feature_image_path: str, label_image_path: str, label: str,
                 img_size: tuple, batch_size: int, n_channels: int = 9, n_label_channels: int = 1,
                 label_crop: int = 3, random_state=None):
        # Keras 3's PyDataset (aliased as keras.utils.Sequence) expects its
        # __init__ to run; under older tf.keras this is a harmless no-op.
        super().__init__()
        self.patches_path = patches_path
        self.feature_image_path = feature_image_path
        self.label_image_path = label_image_path
        self.label = label
        self.batch_size = batch_size
        self.img_size = img_size
        self.n_channels = n_channels
        self.n_label_channels = n_label_channels
        self.label_crop = label_crop

        self.patches = gpd.read_file(self.patches_path)
        self.patches_label = self.patches[self.patches['label'] == self.label]
        # frac=1 returns every row in random order, i.e. a full shuffle.
        self.patches_label_shuffle = self.patches_label.sample(frac=1, random_state=random_state)

    def __len__(self):
        """Number of full batches; leftover patches beyond the last full batch are skipped."""
        return len(self.patches_label) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as (image, cropped label) float32 arrays."""
        start = idx * self.batch_size
        geom_patch_batch = self.patches_label_shuffle.iloc[start : start + self.batch_size]

        spatial = tuple(self.img_size)
        image = np.zeros((self.batch_size,) + spatial + (self.n_channels,), dtype="float32")
        label = np.zeros((self.batch_size,) + spatial + (self.n_label_channels,), dtype="float32")

        # Open both rasters once per batch and close them when the batch is built
        # (previously they were opened on every call and never closed).
        with rs.open(self.feature_image_path) as feat_img, rs.open(self.label_image_path) as label_img:
            # Use the GeoDataFrame's active geometry column instead of a
            # hard-coded positional column index.
            for j, geom in enumerate(geom_patch_batch.geometry):
                minx, miny, maxx, maxy = geom.bounds
                window_f = from_bounds(minx, miny, maxx, maxy, transform=feat_img.transform)
                window_l = from_bounds(minx, miny, maxx, maxy, transform=label_img.transform)
                # resampling=0 is rasterio's Resampling.nearest; out_shape forces
                # every patch to img_size regardless of the footprint's pixel extent.
                img_feat = feat_img.read(window=window_f, out_shape=self.img_size, resampling=0)
                img_label = label_img.read(window=window_l, out_shape=self.img_size, resampling=0)
                # rasterio returns (bands, rows, cols); move the band axis last.
                image[j] = np.moveaxis(img_feat, 0, 2)
                label[j] = np.moveaxis(img_label, 0, 2)

        crop = self.label_crop
        # A stop of -0 would produce an empty slice, so map crop == 0 to None.
        stop = -crop if crop else None
        return image, label[:, crop:stop, crop:stop, :]