In [1]:
import os
import sys
import rasterio
import wandb
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms.functional import crop
from rasterio import windows
from tqdm import tqdm
from safetensors.torch import load_model
import numpy as np

# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.realpath(os.path.abspath(""))))

from unet.inference_dataset import DeadwoodInferenceDataset
from unet.unet_model import UNet
from unet.train_dataset import get_windows

In [2]:
inference_file = "/net/data_ssd/tree_mortality_orthophotos/orthophotos/spain_13_09_2023_east_granada_1_ortho.tif"

In [3]:
run_id = "cgv3wg89"
api = wandb.Api()
experiment = api.run(f"jmoehring/standing-deadwood-unet-pro/{run_id}")

In [4]:
fold: int = 1
epoch: int = 64

experiment_name = experiment.name.replace(f"_fold{fold}", "")
model_path = f"/net/home/jmoehring/experiments/{experiment_name}/fold_{fold}_epoch_{epoch}/model.safetensors"
output_dir = f"/net/scratch/jmoehring/inference/{experiment.name}"
filename = os.path.basename(inference_file)

In [5]:
dataset = DeadwoodInferenceDataset(inference_file, padding=256, tile_size=1024)

In [6]:
dataset.cropped_windows

[Window(col_off=256, row_off=256, width=512, height=512),
 Window(col_off=256, row_off=768, width=512, height=512),
 Window(col_off=256, row_off=1280, width=512, height=512),
 Window(col_off=256, row_off=1792, width=512, height=512),
 Window(col_off=256, row_off=2304, width=512, height=512),
 Window(col_off=256, row_off=2816, width=512, height=512),
 Window(col_off=256, row_off=3328, width=512, height=512),
 Window(col_off=256, row_off=3840, width=512, height=512),
 Window(col_off=256, row_off=4352, width=512, height=512),
 Window(col_off=256, row_off=4864, width=512, height=512),
 Window(col_off=256, row_off=5376, width=512, height=512),
 Window(col_off=256, row_off=5888, width=512, height=512),
 Window(col_off=256, row_off=6400, width=512, height=512),
 Window(col_off=256, row_off=6912, width=512, height=512),
 Window(col_off=256, row_off=7424, width=512, height=512),
 Window(col_off=256, row_off=7936, width=512, height=512),
 Window(col_off=256, row_off=8448, width=512, height=512),

In [7]:
loader_args = {
    "batch_size": 1,
    "num_workers": 4,
    "pin_memory": True,
    "shuffle": False,
}
inference_loader = DataLoader(dataset, **loader_args)

In [8]:
# preferably use GPU
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
device_ids = [0]
# model with three input channels (RGB)
model = UNet(
    n_channels=3,
    n_classes=1,
).to(memory_format=torch.channels_last)
load_model(model, model_path)
model = model.to(memory_format=torch.channels_last, device=device)

model.eval()

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [9]:
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
output_image_path = os.path.join(output_dir, filename)
output_image = rasterio.open(
    output_image_path,
    "w+",
    driver="GTiff",
    width=dataset.width,
    height=dataset.height,
    count=1,
    dtype=rasterio.uint8,
    crs=dataset.image_src.crs,
    transform=dataset.image_src.transform,
    compress="DEFLATE",
)


for images, cropped_windows in tqdm(inference_loader):
    images = images.to(device=device, memory_format=torch.channels_last)
    with torch.no_grad():
        output = model(images)
        output = torch.sigmoid(output)
        output = (output > 0.3).float()

        # crop tensor by dataset padding
        output = crop(
            output,
            top=dataset.padding,
            left=dataset.padding,
            height=dataset.tile_size - (2 * dataset.padding),
            width=dataset.tile_size - (2 * dataset.padding),
        )

        # add white edge to image of two pixels
        # output = torch.nn.functional.pad(output, (2, 2, 2, 2), value=2)

        output_image.write(
            output[0].cpu().numpy(),
            window=windows.Window(
                cropped_windows["col_off"].item(),
                cropped_windows["row_off"].item(),
                cropped_windows["width"].item(),
                cropped_windows["height"].item(),
            ),
        )
output_image.close()

100%|██████████| 3200/3200 [03:04<00:00, 17.39it/s]
