In [1]:
import os
import sys
import rasterio
import wandb
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms.functional import crop
from rasterio import windows
from tqdm import tqdm
import numpy as np

# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.realpath(os.path.abspath(""))))

from unet.inference import DeadwoodInferenceDataset
from unet.unet_model import UNet
from unet.dataset import get_windows

In [2]:
inference_file = "/net/data_ssd/tree_mortality_orthophotos/temp/uavforsat_DDH002_ortho.reprojected.0.06.4326.tif"

In [3]:
run_id = "garg20i1"
api = wandb.Api()
experiment = api.run(f"jmoehring/standing-deadwood-unet-pro/{run_id}")

In [4]:
fold: int = 0
epoch: int = 19

checkpoint_dir = f"/net/scratch/jmoehring/checkpoints/{experiment.name}"
model_checkpoint = f"fold_{fold}_epoch_{epoch}.pth"
output_dir = f"/net/scratch/jmoehring/inference/{experiment.name}"
filename = os.path.basename(inference_file)

In [5]:
dataset = DeadwoodInferenceDataset(inference_file, padding=128)

In [6]:
dataset.cropped_windows

[Window(col_off=128, row_off=128, width=256, height=256),
 Window(col_off=128, row_off=384, width=256, height=256),
 Window(col_off=128, row_off=640, width=256, height=256),
 Window(col_off=128, row_off=896, width=256, height=256),
 Window(col_off=128, row_off=1152, width=256, height=256),
 Window(col_off=128, row_off=1408, width=256, height=256),
 Window(col_off=128, row_off=1664, width=256, height=256),
 Window(col_off=128, row_off=1920, width=256, height=256),
 Window(col_off=128, row_off=2176, width=256, height=256),
 Window(col_off=128, row_off=2432, width=256, height=256),
 Window(col_off=128, row_off=2493, width=256, height=256),
 Window(col_off=384, row_off=128, width=256, height=256),
 Window(col_off=384, row_off=384, width=256, height=256),
 Window(col_off=384, row_off=640, width=256, height=256),
 Window(col_off=384, row_off=896, width=256, height=256),
 Window(col_off=384, row_off=1152, width=256, height=256),
 Window(col_off=384, row_off=1408, width=256, height=256),
 Wind

In [7]:
loader_args = {
    "batch_size": 1,
    "num_workers": 4,
    "pin_memory": True,
    "shuffle": False,
}
inference_loader = DataLoader(dataset, **loader_args)

In [8]:
# preferably use GPU
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
device_ids = [0]
# model with three input channels (RGB)
model = UNet(n_channels=3, n_classes=1, bilinear=True)
if torch.cuda.device_count() > 1:
    device_ids = [2]
model = nn.DataParallel(model, device_ids=device_ids)
model.load_state_dict(torch.load(os.path.join(checkpoint_dir, model_checkpoint)))
model = model.to(memory_format=torch.channels_last, device=device)

model.eval()

DataParallel(
  (module): UNet(
    (inc): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (down1): Down(
      (maxpool_conv): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): DoubleConv(
          (double_conv): Sequential(
            (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(128, 128, 

In [9]:
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
output_image_path = os.path.join(output_dir, filename)
output_image = rasterio.open(
    output_image_path,
    "w+",
    driver="GTiff",
    width=dataset.width,
    height=dataset.height,
    count=1,
    dtype=rasterio.uint8,
    crs=dataset.image_src.crs,
    transform=dataset.image_src.transform,
    compress="DEFLATE",
)


for images, cropped_windows in tqdm(inference_loader):
    images = images.to(device=device, memory_format=torch.channels_last)
    with torch.no_grad():
        output = model(images)
        output = torch.sigmoid(output)
        output = (output > 0.3).float()

        # crop tensor by dataset padding
        output = crop(
            output,
            top=dataset.padding,
            left=dataset.padding,
            height=dataset.tile_size - (2 * dataset.padding),
            width=dataset.tile_size - (2 * dataset.padding),
        )

        # add white edge to image of two pixels
        output = torch.nn.functional.pad(output, (1, 1, 1, 1), value=2)

        output_image.write(
            output[0].cpu().numpy(),
            window=windows.Window(
                cropped_windows["col_off"].item(),
                cropped_windows["row_off"].item(),
                cropped_windows["width"].item(),
                cropped_windows["height"].item(),
            ),
        )
output_image.close()

100%|██████████| 242/242 [00:03<00:00, 63.81it/s]
