In [1]:
import os

import numpy as np
import pandas as pd
import rasterio
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms

class TestSet(Dataset):
    """Test dataset that stacks an RGB tile with a depth layer and pairs it with a mask.

    For every row kept from ``new_palsa_labels.csv`` the index value is used as
    a file stem and three rasters are read: ``rgb/<stem>.tif``,
    ``<hs|dem>/<stem>.tif`` and ``groundtruth_mask/<stem>.tif``.

    Args:
        depth_layer: ``"hs"`` selects the ``hs`` directory; any other value
            selects ``dem``.
        gt_dir: Root directory containing the raster sub-directories and the
            labels CSV.
        normalize: When truthy, the stacked tensor is standardized per channel
            with the hard-coded dataset-wide mean / std for the chosen layer.
    """

    # Dataset-wide per-channel (R, G, B, depth) mean and std, keyed by the
    # name of the depth directory.
    _NORM_STATS = {
        'hs': ([74.90, 85.26, 80.06, 179.18], [15.05, 13.88, 12.01, 10.65]),
        'dem': ([74.90, 85.26, 80.06, 608.95], [15.05, 13.88, 12.01, 2.30]),
    }

    def __init__(self, depth_layer, gt_dir, normalize):
        self.RGB_dir = os.path.join(gt_dir, 'rgb')
        self.hs_dir = os.path.join(gt_dir, 'hs')
        self.dem_dir = os.path.join(gt_dir, 'dem')
        self.groundtruth_dir = os.path.join(gt_dir, 'groundtruth_mask')
        self.depth_dir = self.hs_dir if depth_layer == "hs" else self.dem_dir
        self.labels_path = os.path.join(gt_dir, 'new_palsa_labels.csv')
        self.im_size = 200
        self.normalize = normalize

        # configure labels file.
        # only use samples where MS-Backe difference is <10%
        unfiltered_labels_df = pd.read_csv(self.labels_path, index_col=0)
        self.labels_df = unfiltered_labels_df.loc[
            unfiltered_labels_df['difference'] < 10]

    def __len__(self):
        """Return the number of label rows that passed the difference filter."""
        return len(self.labels_df)

    @staticmethod
    def _read_as_tensor(path):
        """Read every band of the raster at ``path`` into a float32 tensor."""
        with rasterio.open(path) as src:
            bands = src.read()
        # Cast in NumPy first so the dtype handed to torch is always float32.
        return torch.from_numpy(np.asarray(bands, dtype=np.float32))

    def _upsample(self, bands):
        """Bilinearly resize a (C, H, W) tensor to ``2 * im_size`` per side."""
        bilinear = nn.Upsample(size=self.im_size * 2, mode='bilinear')
        return bilinear(bands.unsqueeze(0)).squeeze(0)

    def _normalize(self, stacked):
        """Standardize a (C, H, W) tensor with the stats of the active depth layer."""
        mean, std = self._NORM_STATS[os.path.basename(self.depth_dir)]
        mean_t = torch.tensor(mean, dtype=stacked.dtype).view(-1, 1, 1)
        std_t = torch.tensor(std, dtype=stacked.dtype).view(-1, 1, 1)
        return (stacked - mean_t) / std_t

    def __getitem__(self, idx):
        """Build one sample.

        Returns:
            A tuple ``(combined_tensor, binary_label, perc_label,
            gt_upsampled_tensor, img_name)`` where ``combined_tensor`` is the
            RGB bands concatenated with the upsampled depth bands,
            ``binary_label`` is 1 when the first CSV column is > 0 (else 0),
            ``perc_label`` is that column divided by 100, and
            ``gt_upsampled_tensor`` is the ground-truth raster upsampled like
            the depth layer.
        """
        img_name = self.labels_df.index[idx]

        rgb_tensor = self._read_as_tensor(
            os.path.join(self.RGB_dir, f"{img_name}.tif"))
        depth_tensor = self._upsample(self._read_as_tensor(
            os.path.join(self.depth_dir, f"{img_name}.tif")))

        # Channel-wise stack; the RGB raster must already have the upsampled
        # spatial size for this concatenation to succeed.
        combined_tensor = torch.cat((rgb_tensor, depth_tensor))

        if self.normalize:
            # Previously the transform was constructed but never applied, so
            # normalize=True silently returned raw pixel values.
            combined_tensor = self._normalize(combined_tensor)

        label = self.labels_df.iloc[idx, 0]
        binary_label = 1 if label > 0 else 0
        perc_label = label / 100

        # Ground-truth mask, resized with the same bilinear upsampling; the
        # result is a float tensor and may hold fractional values along edges.
        gt_upsampled_tensor = self._upsample(self._read_as_tensor(
            os.path.join(self.groundtruth_dir, f"{img_name}.tif")))

        return combined_tensor, binary_label, perc_label, gt_upsampled_tensor, img_name

: 

In [None]:
############
# Imports #
############

import json
import os

import torch
import wandb
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Root folder of the test tiles. Set the PALSA_TESTSET_DIR environment variable
# to run the notebook on another machine without editing this cell.
testset_dir = os.environ.get(
    'PALSA_TESTSET_DIR',
    '/Users/nadja/Documents/UU/Thesis/Data/FINALFINAL_200m_groundtruths',
)
depth_layer = 'hs'
normalize = True

test_set = TestSet(depth_layer, testset_dir, normalize)
# num_workers=0 keeps loading in the main process: TestSet is defined inside
# this notebook, and worker processes started with the "spawn" method (the
# default on macOS and Windows) cannot import a class that lives in __main__.
test_loader = DataLoader(test_set, batch_size=1, shuffle=True, num_workers=0)



In [None]:
# Draw one random sample (batch_size=1) and plot it next to its ground truth.
im, lab, perc_label, gt_mask, img_name = next(iter(test_loader))
# Samples with binary label 0 are skipped; re-run the cell to draw another one.
if int(lab) != 0:

    print(img_name)

    # (1, C, H, W) -> (H, W, 3): keep the RGB channels only. The values are
    # rescaled to [0, 1] so the image renders the same way whether the loader
    # returns raw pixel values or standardized ones.
    rgb = im.squeeze(0).detach().cpu()[:3].permute(1, 2, 0).float()
    rgb_min, rgb_max = rgb.min(), rgb.max()
    cpu_img = ((rgb - rgb_min) / (rgb_max - rgb_min).clamp_min(1e-8)).numpy()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 6))

    ax1.imshow(cpu_img)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_title('original image')

    # `gt`, `cmap` and `norm` were never defined in this notebook; plot the
    # unpacked `gt_mask` with an explicit grayscale map and default scaling.
    ax2.imshow(gt_mask.squeeze(0).permute(1, 2, 0).long().numpy(), cmap='gray')
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_title('Ground Truth')

    plt.tight_layout()
    plt.show()