In [57]:
import os
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from skimage.transform import resize
from skimage.io import imread
from torchvision.io import read_image
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
from cs_6804_project.src.torch_cloudnet.model import CloudNet
from cs_6804_project.src.keras_cloudnet.utils import get_input_image_names
from cs_6804_project.src.keras_cloudnet.augmentation import (flipping_img_and_msk, rotate_cclk_img_and_msk,
                                                             rotate_clk_img_and_msk, zoom_img_and_msk)

In [2]:
# Paths to data. Set the CLOUD38_DATA_DIR environment variable to point at a
# local copy of the 38-Cloud dataset; when it is unset the original default
# location is used, so existing setups keep working.
GLOBAL_PATH = os.environ.get("CLOUD38_DATA_DIR", "D:/38-Cloud/")
TRAIN_FOLDER = os.path.join(GLOBAL_PATH, '38-Cloud_training')
TEST_FOLDER = os.path.join(GLOBAL_PATH, '38-Cloud_test')

# Seed the global RNGs early: the dataset's random augmentation draws from
# NumPy's global RNG, and torch-side randomness uses torch's global generator.
# Without this, Restart & Run All does not reproduce results.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set params
in_rows = 192  # patch height after resizing
in_cols = 192  # patch width after resizing
num_of_channels = 4  # input bands: red, green, blue, near-infrared
num_of_classes = 1  # single-channel (cloud) mask output
starting_learning_rate = 1e-4
end_learning_rate = 1e-8  # presumably the floor for the LR decay schedule - confirm in the training cell
max_num_epochs = 2000  # just a huge number. The actual training should not be limited by this value
val_ratio = 0.2  # fraction of patches held out for validation
patience = 15  # presumably epochs without improvement before LR decay / stopping - confirm
decay_factor = 0.7  # presumably the multiplicative LR decay factor - confirm
batch_sz = 12
max_bit = 65535  # maximum gray level in landsat 8 images
experiment_name = "Cloud-Net"
# NOTE(review): the `.h5` extension is carried over from the Keras version;
# PyTorch checkpoints are conventionally `.pt`/`.pth`. Kept unchanged so
# existing paths keep resolving.
weights_path = os.path.join(GLOBAL_PATH, experiment_name + '.h5')
train_resume = False

100%|██████████████████████████████████████████████████████████████████████████| 5155/5155 [00:00<00:00, 396305.53it/s]


In [None]:
# Get input/target image names. The CSV under GLOBAL_PATH lists the non-empty
# training patches; the project helper get_input_image_names turns it into
# parallel collections of input and mask file names (presumably rooted at
# TRAIN_FOLDER - see the helper for the exact layout).
train_patches_csv_name = 'training_patches_38-cloud_nonempty.csv'
df_train_img = pd.read_csv(
    os.path.join(GLOBAL_PATH, train_patches_csv_name)
)
train_img, train_msk = get_input_image_names(
    df_train_img,
    TRAIN_FOLDER,
    if_train=True,
)

In [None]:
# Split data into training and validation: after shuffling, `val_ratio` of the
# patches is held out for validation; the fixed random_state makes the split
# repeatable across runs.
# NOTE(review): `val_mask_split` is spelled differently from
# `train_msk_split`; kept as-is because later cells may reference it.
(
    train_img_split,
    val_img_split,
    train_msk_split,
    val_mask_split,
) = train_test_split(
    train_img,
    train_msk,
    shuffle=True,
    random_state=42,
    test_size=val_ratio,
)

In [None]:
class CloudDataset(Dataset):
    """Map-style PyTorch dataset of 38-Cloud patches (4-band input + mask).

    Each sample is assembled from four single-band image files (red, green,
    blue, near-infrared) and one ground-truth mask file. Both are resized to
    ``(img_rows, img_cols)``, optionally augmented with random flips,
    rotations and zoom, then normalised (image divided by ``max_bit``, mask
    divided by 255) and returned as channel-first float32 tensors.

    Args:
        train_files: Indexable collection; ``train_files[i]`` holds the four
            band file paths of sample ``i`` in (red, green, blue, nir) order.
        target_files: Indexable collection of mask file paths, aligned
            index-for-index with ``train_files``.
        img_rows: Output height in pixels.
        img_cols: Output width in pixels.
        max_bit: Divisor applied to the input bands (maximum gray level).
        transform: If truthy, ``__getitem__`` applies random augmentation via
            the ``transform`` method.

    Raises:
        ValueError: If ``train_files`` and ``target_files`` differ in length.
    """

    def __init__(self, train_files, target_files, img_rows, img_cols, max_bit, transform=False):
        if len(train_files) != len(target_files):
            raise ValueError(
                f"train_files ({len(train_files)}) and target_files "
                f"({len(target_files)}) must have the same length"
            )
        self.train_files = train_files
        self.target_files = target_files
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.max_bit = max_bit
        # Stored under a different attribute name on purpose: assigning to
        # ``self.transform`` would shadow the ``transform`` method below and
        # make augmentation uncallable.
        self.augment = bool(transform)

    def __len__(self):
        return len(self.target_files)

    def __getitem__(self, idx):
        """Load, optionally augment, normalise and tensorise sample ``idx``.

        Returns:
            Tuple ``(images, target)`` of float32 tensors in (C, H, W) layout:
            4 channels for ``images`` and 1 for ``target`` (assuming the mask
            file is single-band and the augmentation helpers keep array rank -
            TODO confirm).

        NOTE(review): augmentation draws from NumPy's global RNG; with a
        multi-worker DataLoader, workers may need reseeding (worker_init_fn)
        to avoid identical augmentations.
        """
        images, target = self._load_pair(idx)
        if self.augment:
            images, target = self.transform(images, target)
        images, target = self._normalize(images, target)
        return self._to_chw_tensor(images), self._to_chw_tensor(target)

    def _load_pair(self, idx):
        """Read the four bands and the mask of sample ``idx``, resized to (img_rows, img_cols)."""
        size = (self.img_rows, self.img_cols)
        # Bands 0..3 are red, green, blue, near-infrared, stacked on the last axis.
        bands = [imread(self.train_files[idx][band]) for band in range(4)]
        images = np.stack(bands, axis=-1).astype('int32')
        images = resize(images, size, preserve_range=True, mode='symmetric')
        target = imread(self.target_files[idx])
        # NOTE(review): resize interpolates, so when downsampling a mask its
        # values may no longer be strictly binary.
        target = resize(target, size, preserve_range=True, mode='symmetric')
        return images, target

    def transform(self, images, target):
        """Randomly flip, rotate and/or zoom an image and its mask together.

        Four independent fair coin flips (from NumPy's global RNG) decide
        which augmentation helpers run, in the order flip, clockwise
        rotation, counter-clockwise rotation, zoom.

        Args:
            images: Input image array.
            target: Mask array aligned with ``images``.

        Returns:
            Tuple ``(images, target)`` after augmentation (not normalised).
        """
        rnd_flip = np.random.randint(2, dtype=int)
        rnd_rotate_clk = np.random.randint(2, dtype=int)
        rnd_rotate_cclk = np.random.randint(2, dtype=int)
        rnd_zoom = np.random.randint(2, dtype=int)

        if rnd_flip == 1:
            images, target = flipping_img_and_msk(images, target)

        if rnd_rotate_clk == 1:
            images, target = rotate_clk_img_and_msk(images, target)

        if rnd_rotate_cclk == 1:
            images, target = rotate_cclk_img_and_msk(images, target)

        if rnd_zoom == 1:
            images, target = zoom_img_and_msk(images, target)

        return images, target

    def _normalize(self, images, target):
        """Scale the image by ``max_bit`` and the mask by 255, as new float64 arrays.

        A 2-D mask gets a trailing channel axis so it becomes (H, W, 1).
        Dividing by 255 assumes 0/255-valued ground-truth masks - TODO confirm.
        """
        images = np.asarray(images, dtype=np.float64) / self.max_bit
        target = np.asarray(target, dtype=np.float64) / 255
        if target.ndim == 2:
            target = target[..., np.newaxis]
        return images, target

    @staticmethod
    def _to_chw_tensor(array):
        """Convert an (H, W) or (H, W, C) array into a contiguous float32 (C, H, W) tensor.

        float32 matches the default parameter dtype of torch modules; feeding
        float64 tensors to float32 weights raises a dtype mismatch error.
        """
        array = np.asarray(array, dtype=np.float32)
        if array.ndim == 2:
            array = array[..., np.newaxis]
        return torch.from_numpy(np.ascontiguousarray(array.transpose(2, 0, 1)))