In [1]:
import torch
from torch import nn
torch.cuda.is_available()
from torchgeo.samplers import RandomGeoSampler
from torch.utils.data import DataLoader
from torchgeo.datasets import stack_samples
import os
from torchgeo.datasets import RasterDataset
import pylab as plt
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class BengaluruDataset(RasterDataset):
    """
    Bengaluru image rasters: every ``*.tif`` file found under ``root``.
    """

    filename_glob = "*.tif"


class BengaluruDatasetLabels(RasterDataset):
    """
    Bengaluru label rasters: every ``*.tif`` file found under ``root``.

    ``is_image = False`` makes torchgeo return the raster under the ``"mask"``
    key of each sample instead of ``"image"``.
    """
    is_image = False
    filename_glob = "*.tif"

In [3]:
# Root folders of the two dataset splits ("Data/Train" and "Data/Test").
TRAIN_PATH, TEST_PATH = (os.path.join("Data", split) for split in ("Train", "Test"))


def tr_labels(x):
    """Return ``x[:, 0, :, :]`` (entry 0 along dim 1) cast to an int64 tensor.

    NOTE(review): not referenced elsewhere in this notebook; the mask
    conversion is done by ``TransBengaluruLabels`` instead.
    """
    first_band = x[:, 0, :, :]
    return first_band.long()

class TransBengaluruImages(nn.Module):
    """Min-max scale the ``"image"`` entry of a sample dict into [0, 1].

    The minimum and maximum are taken over the whole tensor (all bands and
    pixels together), so relative band intensities are preserved.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Rescale ``inputs["image"]`` and return the same dict.

        Args:
            inputs: sample dict holding an ``"image"`` tensor; all other keys
                are passed through untouched.

        Returns:
            ``inputs`` with ``"image"`` replaced by a floating-point tensor in
            [0, 1]. A constant image (max == min) becomes all zeros instead of
            NaN from 0 / 0.
        """
        image = inputs["image"]
        # In-place true division is illegal on integer tensors; promote them.
        if not torch.is_floating_point(image):
            image = image.float()
        # Out-of-place ops so a tensor the caller still holds is never modified.
        image = image - image.min()
        peak = image.max()
        if peak > 0:
            image = image / peak
        inputs["image"] = image
        return inputs
    
class TransBengaluruLabels(nn.Module):
    """Reduce the ``"mask"`` entry of a sample dict to its first band as int64."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Replace ``inputs["mask"]`` with band 0 cast to ``long`` and return the dict.

        A 4-D mask is treated as a batch (band axis is dim 1); any other rank
        is treated as a single sample (band axis is dim 0).
        """
        mask = inputs["mask"]
        first_band = mask[:, 0, :, :] if mask.ndim == 4 else mask[0, :, :]
        inputs["mask"] = first_band.long()
        return inputs

def _image_label_pair(split_root):
    """Build the (images, labels) raster datasets for one split folder.

    ``split_root`` is expected to contain ``Images`` and ``Labels`` sub-folders.
    """
    images = BengaluruDataset(os.path.join(split_root, "Images"), transforms=TransBengaluruImages())
    labels = BengaluruDatasetLabels(os.path.join(split_root, "Labels"), transforms=TransBengaluruLabels())
    return images, labels


# `&` combines the image and label datasets into a torchgeo intersection
# dataset, so every sample carries both "image" and "mask".
ds, la = _image_label_pair(TRAIN_PATH)
train_ds = ds & la

ds2, la2 = _image_label_pair(TEST_PATH)
test_ds = ds2 & la2

Converting BengaluruDatasetLabels resolution from 1.200444240953174 to 1.2004442409532352


In [4]:
DATASET_SIZE = 2048  # patches drawn by each sampler per pass over its loader
BATCH_SIZE = 128
IMG_SIZE = 128       # patch size handed to RandomGeoSampler(size=...)


def _random_patch_loader(dataset):
    """Return (sampler, loader) yielding random IMG_SIZE patches of ``dataset`` in batches."""
    sampler = RandomGeoSampler(dataset, size=IMG_SIZE, length=DATASET_SIZE)
    # stack_samples collates the per-patch sample dicts into one batched dict.
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, collate_fn=stack_samples)
    return sampler, loader


train_sampler, train_dl = _random_patch_loader(train_ds)
test_sampler, test_dl = _random_patch_loader(test_ds)

In [5]:
def visualize(image, mask):
    """PLot images in one row."""
    fig = plt.figure(figsize=(16, 8))
    #plt.imshow(image[0].transpose(0,2), vmin=0, vmax=1)
    image /= image.numpy().max()
    plt.imshow(image, cmap='gray')
    #plt.imshow(image[0])
    if mask.max()>0:
        plt.imshow(mask/9, alpha=0.25, vmin=0, vmax=1, cmap='Set1', interpolation='nearest')
        
    cb = plt.colorbar(cmap='Dark2', ticks=np.linspace(0.5/9, 3.5/9, 4), boundaries=np.linspace(0,4/9,5))
    cb.set_ticklabels(['Background', 'Soil', 'Herbaceous', 'Woody'])
    plt.show()

In [6]:
class UNET(nn.Module):
    """Four-level encoder/decoder segmentation network with U-Net skip connections.

    Each encoder stage halves H and W (stride-2 max-pool) and each decoder stage
    doubles them (stride-2 transposed conv). H and W of the input must be
    divisible by 16 for the skip concatenations to line up. The output has
    ``out_channels`` maps at the input resolution (e.g. one score map per class).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder: 32 -> 64 -> 128 -> 256 channels, spatial size halved per stage.
        self.conv1 = self.contract_block(in_channels, 32, 7, 3)
        self.conv2 = self.contract_block(32, 64, 3, 1)
        self.conv3 = self.contract_block(64, 128, 3, 1)
        self.conv4 = self.contract_block(128, 256, 3, 1)

        # Decoder: stages after the first take 2x channels because the matching
        # encoder output is concatenated onto the upsampled features.
        self.upconv4 = self.expand_block(256, 128, 3, 1)
        self.upconv3 = self.expand_block(128*2, 64, 3, 1)
        self.upconv2 = self.expand_block(64*2, 32, 3, 1)
        self.upconv1 = self.expand_block(32*2, out_channels, 3, 1)

    def forward(self, x):
        """Run the network on a batch ``x`` of shape (N, in_channels, H, W).

        Defining ``forward`` (not overriding ``__call__``) keeps nn.Module
        hooks working, and ``unet(x)`` still dispatches here.
        """
        # Encoder (downsampling part).
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)

        # Decoder: upsample, then concatenate the same-size encoder output on dim 1.
        upconv4 = self.upconv4(conv4)
        upconv3 = self.upconv3(torch.cat([upconv4, conv3], 1))
        upconv2 = self.upconv2(torch.cat([upconv3, conv2], 1))
        upconv1 = self.upconv1(torch.cat([upconv2, conv1], 1))

        return upconv1

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Two (Conv2d -> BatchNorm2d -> ReLU) layers, then a stride-2 max-pool.

        ``padding`` is applied to both convolutions. The max-pool
        (kernel 3, stride 2, padding 1) halves H and W (rounding up).
        """
        contract = nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        return contract

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Two (Conv2d -> BatchNorm2d -> ReLU) layers, then a stride-2 transposed conv.

        The ConvTranspose2d (kernel 3, stride 2, padding 1, output_padding 1)
        exactly doubles H and W.
        """
        expand = nn.Sequential(torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
                            )
        return expand

In [7]:
MODEL_PATH = 'unet_128px-4ly.pt'

# The checkpoint is a whole pickled nn.Module (presumably the UNET class
# above), not a state_dict, so that class must be defined before loading.
# Unpickling can execute arbitrary code: only load checkpoints you trust.
# map_location="cuda" places the weights on the GPU to match the `.cuda()`
# inputs used below, whatever device the model was saved from.
try:
    # torch >= 2.6 defaults to weights_only=True, which rejects full modules.
    unet = torch.load(MODEL_PATH, map_location="cuda", weights_only=False)
except TypeError:
    # torch < 1.13 does not accept the `weights_only` argument.
    unet = torch.load(MODEL_PATH, map_location="cuda")
unet.eval();

In [8]:
# Preview one training tile: ground-truth overlay, then the model's prediction.
idx = 0  # which tile of the batch to display
sample = next(iter(train_dl))
# Band 1 of tile `idx`; same as sample["image"].moveaxis(1, 3)[idx, :, :, 1].
image = sample["image"][idx, 1]
target = sample["mask"][idx]
with torch.no_grad():  # inference only: no autograd graph needed
    pred = unet(sample["image"].cuda()).cpu()[idx]  # same tile as image/target

visualize(image, target)
visualize(image, torch.argmax(pred, 0))

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "rasterio/_base.pyx", line 302, in rasterio._base.DatasetBase.__init__
  File "rasterio/_base.pyx", line 213, in rasterio._base.open_dataset
  File "rasterio/_err.pyx", line 217, in rasterio._err.exc_wrap_pointer
rasterio._err.CPLE_OpenFailedError: Data/Train/Labels/Labelled_P04.tif: Remote I/O error

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/work/python/dkottke/pytorch/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3398, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_3968/3723953950.py", line 1, in <cell line: 1>
    for sample in train_dl:
  File "/mnt/work/python/dkottke/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
  File "/mnt/work/python/dkottke/pytorch/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 692, in _next_data
  File "/mnt/work/pyth

In [None]:
# Preview one test tile: ground-truth overlay, then the model's prediction.
idx = 0  # which tile of the batch to display
sample = next(iter(test_dl))
# Band 1 of tile `idx`; same as sample["image"].moveaxis(1, 3)[idx, :, :, 1].
image = sample["image"][idx, 1]
target = sample["mask"][idx]
with torch.no_grad():  # inference only: no autograd graph needed
    pred = unet(sample["image"].cuda()).cpu()[idx]  # same tile as image/target

visualize(image, target)
visualize(image, torch.argmax(pred, 0))