In [1]:
import os
import sys
import numpy as np
import torch.utils.data as data
import torch

# Make the parent of the working directory importable so `utils` can be found.
# The quoted '__file__' is a plain string, not the module attribute: realpath
# resolves it against the current working directory, so `current` is the
# (symlink-resolved) working directory and `parent` is the directory above it.
current = os.path.dirname(os.path.realpath('__file__'))
parent = os.path.dirname(current)
sys.path.append(parent)

from utils import get_img_arr, get_mask_arr

In [2]:
# Patch identifiers: every entry of the Landsat patch folder with a trailing
# '.tif' removed. The suffix is cut explicitly because a strip-style call treats
# '.tif' as a *set of characters* and would also remove trailing '.', 't', 'i'
# or 'f' characters that belong to the stem (e.g. 'cliff.tif' -> 'cl').
# dtype=str keeps the array string-typed even when the folder is empty.
images = np.asarray(
    [
        name[:-len('.tif')] if name.endswith('.tif') else name
        for name in os.listdir('dataset/manual_annotations/landsat_patches')
    ],
    dtype=str,
)

In [3]:
class SegmentationDataset(data.Dataset):
    """Dataset that pairs each image patch with its annotation mask.

    The items are the entries of ``image_folder`` in ``os.listdir`` order.
    The mask of an image is the file in ``mask_folder`` whose name equals the
    image file name with ``'v1_'`` inserted before its last 10 characters.
    Images without such a file get an all-zero mask.

    Args:
        image_folder: directory holding the image patches.
        mask_folder: directory holding the mask patches.
        n_channels: forwarded to ``get_img_arr`` as its second argument.
    """

    def __init__(self, image_folder, mask_folder, n_channels = 3):
        self.image_folder = image_folder
        self.mask_folder = mask_folder

        self.n_channels = n_channels

        self.images = os.listdir(image_folder)
        self.masks = os.listdir(mask_folder)
        # Set copy of the mask names: membership tests in __getitem__ become
        # constant-time instead of scanning the whole list for every sample.
        self._mask_names = frozenset(self.masks)

    def __len__(self):
        """Return the number of image patches."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for the image at position ``idx``.

        The image is loaded with ``get_img_arr``; the mask is loaded with
        ``get_mask_arr`` when a matching mask file exists, otherwise it is a
        zero tensor.
        """
        image_name = self.images[idx]
        image_path = os.path.join(self.image_folder, image_name)

        image = get_img_arr(image_path, self.n_channels)

        # Mask files carry an extra 'v1_' right before the last 10 characters
        # of the image file name.
        mask_name = image_name[:-10] + 'v1_' + image_name[-10:]
        if mask_name in self._mask_names:
            mask_path = os.path.join(self.mask_folder, mask_name)
            mask = get_mask_arr(mask_path)
        else:
            # NOTE(review): the placeholder shape is hard-coded; this assumes
            # get_mask_arr yields (1, 256, 256) masks -- TODO confirm.
            mask = torch.zeros((1, 256, 256))

        return image, mask

In [4]:
# Sum the mask of every patch to get its fire-pixel count.
dataset = SegmentationDataset(
    image_folder='dataset/manual_annotations/landsat_patches',
    mask_folder='dataset/manual_annotations/manual_annotations_patches',
)
# DataLoader defaults spelled out: one patch per batch, visited in dataset
# order, so fire_pixels[i] belongs to dataset.images[i].
dataloader = data.DataLoader(dataset, batch_size=1, shuffle=False)
fire_pixels = []
for image, mask in dataloader:
    fire_pixels += [int(mask.sum())]

In [5]:
# Stack names and counts as two rows, then transpose into one
# (name, fire_pixel_count) row per patch. Everything becomes a string.
# NOTE: this rebinds `images` from a 1-D name array to a 2-column table, so the
# cell is meant to be run once per kernel session.
images = np.transpose(np.array([images, fire_pixels]))
images

array([['LC08_L1GT_226074_20200921_20200921_01_RT_p00021', '0'],
       ['LC08_L1GT_226074_20200921_20200921_01_RT_p00022', '0'],
       ['LC08_L1GT_226074_20200921_20200921_01_RT_p00023', '0'],
       ...,
       ['LC08_L1TP_193029_20200914_20200914_01_RT_p00910', '0'],
       ['LC08_L1TP_193029_20200914_20200914_01_RT_p00936', '0'],
       ['LC08_L1TP_193029_20200914_20200914_01_RT_p00937', '0']],
      dtype='<U51')

In [6]:
# Write the (name, fire_pixel_count) table as comma-separated text, one patch
# per line.
np.savetxt(
    'dataset/manual_annotations/fire_pixels.csv',
    images,
    fmt='%s',
    delimiter=',',
)

In [7]:
# Seeded generator so the 70/20/10 split yields the same index partition on
# every run. It only has an effect when it is handed to random_split; without
# it the split is drawn from the global RNG and changes between runs.
generator = torch.Generator().manual_seed(42)
dataset = SegmentationDataset('dataset/manual_annotations/landsat_patches', 'dataset/manual_annotations/manual_annotations_patches')
train_dataset, val_dataset, test_dataset = data.random_split(dataset, [0.7, 0.2, 0.1], generator=generator)

In [8]:
# Pick the (name, fire_pixel_count) rows of each split via the Subset index
# lists. This relies on the rows of `images` being in the same order as the
# items of the dataset that was split.
train_images, val_images, test_images = (
    images[subset.indices]
    for subset in (train_dataset, val_dataset, test_dataset)
)

In [9]:
# Training split: number of patches, patches with at least one fire pixel,
# total fire pixels.
train_fire_counts = train_images[:, 1].astype('int')
print(len(train_images))
print(np.count_nonzero(train_fire_counts > 0))
print(train_fire_counts.sum())

6324
66
37378


In [10]:
# Validation split: number of patches, patches with at least one fire pixel,
# total fire pixels.
val_fire_counts = val_images[:, 1].astype('int')
print(len(val_images))
print(np.count_nonzero(val_fire_counts > 0))
print(val_fire_counts.sum())

1806
21
17383


In [11]:
# Test split: number of patches, patches with at least one fire pixel,
# total fire pixels.
test_fire_counts = test_images[:, 1].astype('int')
print(len(test_images))
print(np.count_nonzero(test_fire_counts > 0))
print(test_fire_counts.sum())

903
13
17070


In [12]:
# Write each split's (name, fire_pixel_count) rows to its own CSV file.
for split_path, split_rows in (
    ('dataset/manual_annotations/train.csv', train_images),
    ('dataset/manual_annotations/val.csv', val_images),
    ('dataset/manual_annotations/test.csv', test_images),
):
    np.savetxt(split_path, split_rows, delimiter=',', fmt='%s')