In [1]:
import os
import sys
import warnings

# Add the directory two levels above the working directory to the import path;
# presumably this is what lets the `finetune` package imported in the next cell
# resolve -- TODO confirm against the repository layout.
sys.path.append("../../")
# Silence every Python warning for the rest of the session (note: this also
# hides deprecation and numerical warnings that could be relevant).
warnings.filterwarnings("ignore")

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from matplotlib.colors import ListedColormap

from finetune.segment.chesapeake_datamodule import ChesapeakeDataModule
from finetune.segment.chesapeake_model import ChesapeakeSegmentor

### Define paths and parameters

In [3]:
# Checkpoints and band metadata (relative paths, resolved against the
# working directory of the kernel).
CHESAPEAKE_CHECKPOINT_PATH = (
    "../../checkpoints/segment/chesapeake-11class-segment_epoch-36_val-iou-0.6501.ckpt"
)
CLAY_CHECKPOINT_PATH = "../../checkpoints/clay-v1.5.ckpt"
METADATA_PATH = "../../configs/metadata.yaml"

# Single root for the chip/label dataset. The default keeps the original
# machine-specific location; set the CLAY_SEGMENT_DATA_ROOT environment
# variable to run the notebook against a dataset stored elsewhere without
# editing four separate paths.
DATA_ROOT = os.environ.get(
    "CLAY_SEGMENT_DATA_ROOT",
    "/home/clebson/Documentos/dataset_goiasmuticlasse4claymodel/data",
).rstrip("/")

# Directory layout below DATA_ROOT: <split>/img for chips, <split>/gt for labels.
# The trailing slash is kept so the default values match the original strings.
TRAIN_CHIP_DIR = f"{DATA_ROOT}/train/img/"
TRAIN_LABEL_DIR = f"{DATA_ROOT}/train/gt/"
VAL_CHIP_DIR = f"{DATA_ROOT}/val/img/"
VAL_LABEL_DIR = f"{DATA_ROOT}/val/gt/"

# Data-loader settings and the metadata key of the imagery platform.
BATCH_SIZE = 16
NUM_WORKERS = 1
PLATFORM = "naip"

### Model Loading

In [4]:
def get_model(
    chesapeake_checkpoint_path, clay_checkpoint_path, metadata_path, map_location=None
):
    """Load the fine-tuned segmentation model and put it in evaluation mode.

    Args:
        chesapeake_checkpoint_path: Path to the fine-tuned segmentor checkpoint.
        clay_checkpoint_path: Path to the Clay checkpoint, forwarded to the
            model as ``ckpt_path``.
        metadata_path: Path to the metadata YAML, forwarded to the model.
        map_location: Optional device mapping forwarded to
            ``load_from_checkpoint`` (e.g. ``"cpu"``). Useful when the GPU
            cannot hold the model, which otherwise surfaces as a CUDA
            out-of-memory error while loading. ``None`` (the default) keeps
            the loader's own device choice, i.e. the previous behaviour.
            NOTE(review): whether the Clay weights behind ``ckpt_path`` honour
            this mapping depends on the model's constructor -- confirm.

    Returns:
        The loaded ``ChesapeakeSegmentor`` after ``eval()`` has been called.
    """
    model = ChesapeakeSegmentor.load_from_checkpoint(
        checkpoint_path=chesapeake_checkpoint_path,
        map_location=map_location,
        metadata_path=metadata_path,
        ckpt_path=clay_checkpoint_path,
    )
    # Inference only: switch off training-time behaviour such as dropout.
    model.eval()
    return model

### Data Preparation

In [5]:
def get_data(
    train_chip_dir,
    train_label_dir,
    val_chip_dir,
    val_label_dir,
    metadata_path,
    batch_size,
    num_workers,
    platform,
):
    """Return the first validation batch together with the data module's metadata.

    All arguments are forwarded unchanged, by keyword, to
    ``ChesapeakeDataModule``. The module is set up for the ``"fit"`` stage and
    only the first batch of its validation loader is consumed.

    Returns:
        A ``(batch, metadata)`` tuple, where ``metadata`` is the data module's
        ``metadata`` attribute.
    """
    datamodule_kwargs = dict(
        train_chip_dir=train_chip_dir,
        train_label_dir=train_label_dir,
        val_chip_dir=val_chip_dir,
        val_label_dir=val_label_dir,
        metadata_path=metadata_path,
        batch_size=batch_size,
        num_workers=num_workers,
        platform=platform,
    )
    datamodule = ChesapeakeDataModule(**datamodule_kwargs)
    datamodule.setup(stage="fit")

    # Only one batch is needed for the visual check further down.
    first_val_batch = next(iter(datamodule.val_dataloader()))
    return first_val_batch, datamodule.metadata

### Prediction

In [6]:
def run_prediction(model, batch, size=(256, 256)):
    """Run the model on one batch and resize its output to ``size``.

    Args:
        model: Callable that takes ``batch`` and returns a 4-D tensor
            (bilinear interpolation over ``size`` requires two spatial axes).
        batch: Input passed to ``model`` unchanged.
        size: Target ``(height, width)`` of the returned tensor. Defaults to
            ``(256, 256)``, the value that used to be hard-coded, so existing
            callers are unaffected; pass another size for other chip sizes.

    Returns:
        The model output bilinearly resampled to ``size``. No gradients are
        tracked.
    """
    with torch.no_grad():
        logits = model(batch)
        # Resample inside the no-grad context as well so the whole inference
        # path is guaranteed to be gradient-free.
        return F.interpolate(logits, size=size, mode="bilinear", align_corners=False)

### Post-Processing

In [7]:
def denormalize_images(normalized_images, means, stds):
    """Undo per-channel standardisation and return displayable 8-bit images.

    Args:
        normalized_images: Array whose second axis is the channel axis; the
            statistics are reshaped to ``(1, C, 1, 1)`` and broadcast, so a
            ``(batch, channel, height, width)`` layout is expected.
        means: Per-channel means, one per channel.
        stds: Per-channel standard deviations, one per channel.

    Returns:
        ``np.uint8`` array with the same shape as the input, computed as
        ``x * std + mean``, rounded to the nearest integer and clipped to
        ``[0, 255]``. The 8-bit range is intended for NAIP/LINZ imagery.
    """
    means = np.asarray(means, dtype=np.float64).reshape(1, -1, 1, 1)
    stds = np.asarray(stds, dtype=np.float64).reshape(1, -1, 1, 1)
    restored = np.asarray(normalized_images) * stds + means
    # A bare astype(uint8) truncates (254.9999 -> 254) and wraps values that
    # fall outside 0..255 (300 -> 44), producing speckled colours. Round first,
    # then clamp to the valid range before casting.
    return np.clip(np.rint(restored), 0, 255).astype(np.uint8)  # Do for NAIP/LINZ


def post_process(batch, outputs, metadata, platform="naip"):
    """Convert a batch and the model outputs into arrays ready for plotting.

    Args:
        batch: Mapping with tensor entries ``"label"`` and ``"pixels"``.
        outputs: Model output tensor; the arg-max is taken over axis 1
            (presumably the class axis -- TODO confirm against the model).
        metadata: Mapping from platform name to an object exposing
            ``bands.mean`` and ``bands.std`` mappings of per-band statistics.
        platform: Key into ``metadata`` selecting the normalisation
            statistics. Defaults to ``"naip"``, which used to be hard-coded.
            Note that ``denormalize_images`` casts to uint8, so platforms
            whose raw values exceed 8 bits will not display correctly.

    Returns:
        Tuple ``(images, labels, preds)`` of NumPy arrays: ``images`` holds
        the first three denormalised channels in ``b h w c`` order, ``labels``
        is ``batch["label"]`` and ``preds`` is the per-pixel arg-max.
    """
    preds = torch.argmax(outputs, dim=1).detach().cpu().numpy()
    labels = batch["label"].detach().cpu().numpy()
    pixels = batch["pixels"].detach().cpu().numpy()

    band_stats = metadata[platform].bands
    means = list(band_stats.mean.values())
    stds = list(band_stats.std.values())
    denorm_pixels = denormalize_images(pixels, means, stds)

    # Keep only the first three channels (presumably R, G, B) and move the
    # channel axis last, which is the layout imshow expects.
    images = rearrange(denorm_pixels[:, :3, :, :], "b c h w -> b h w c")

    return images, labels, preds

### Plotting

In [9]:
def plot_predictions(images, labels, preds):
    """Plot images, ground-truth masks and predicted masks in a grid.

    Items are laid out eight per row-group; every group uses three rows
    (image, actual, predicted). The grid is sized from the number of items,
    so small batches no longer leave empty framed axes and batches larger
    than 32 no longer raise an IndexError.

    Args:
        images: Sequence of image arrays accepted by ``imshow``.
        labels: Sequence of ground-truth class maps.
        preds: Sequence of predicted class maps.
    """
    # NOTE(review): ten colours are defined here; if the masks contain more
    # than ten class ids, the extra ones will not get a distinct colour.
    colors = [
        "#ffffff",  # White
        "#1f8d49",  # Dark green
        "#7dc975",  # Light green
        "#7a6c00",  # Mustard yellow
        "#519799",  # Teal blue
        "#d6bc74",  # Beige
        "#C27BA0",  # Light pink
        "#ffefc3",  # Cream
        "#db4d4f",  # Soft red
        "#ffaa5f",  # Light orange
    ]
    cmap = ListedColormap(colors)
    max_class = len(colors) - 1

    # These two values must match the layout hard-wired into plot_data.
    n_cols = 8
    rows_per_group = 3

    n_items = max(len(images), len(labels), len(preds))
    if n_items == 0:
        return  # nothing to draw

    n_groups = -(-n_items // n_cols)  # ceiling division
    n_rows = n_groups * rows_per_group

    # 1.5 inches per row reproduces the original 12x18 figure for 12 rows;
    # squeeze=False keeps `axes` two-dimensional for any grid size.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(12, 1.5 * n_rows), squeeze=False
    )

    # Blank every cell up front so slots left unused by a partial last group
    # do not show empty frames and ticks.
    for cell in axes.ravel():
        cell.axis("off")

    # Plot the images
    plot_data(axes, images, row_offset=0, title="Image")

    # Plot the actual segmentation maps
    plot_data(axes, labels, row_offset=1, title="Actual", cmap=cmap, vmin=0, vmax=max_class)

    # Plot the predicted segmentation maps
    plot_data(axes, preds, row_offset=2, title="Pred", cmap=cmap, vmin=0, vmax=max_class)

    plt.tight_layout()
    plt.show()

def plot_data(ax, data, row_offset, title=None, cmap=None, vmin=None, vmax=None):
    """Draw each item of ``data`` into its slot of an axes grid.

    Items fill the grid eight per row; each group of eight occupies a band of
    three rows, and ``row_offset`` picks the row inside that band.

    Args:
        ax: Grid indexable as ``ax[row, col]`` whose cells provide ``imshow``,
            ``axis`` and ``set_title``.
        data: Iterable of arrays to display.
        row_offset: Row within each three-row band to draw into.
        title: Optional title, set only on cells in column 0.
        cmap, vmin, vmax: Forwarded to ``imshow``.
    """
    n_cols, rows_per_group = 8, 3
    for index, item in enumerate(data):
        group, col = divmod(index, n_cols)
        cell = ax[row_offset + group * rows_per_group, col]
        cell.imshow(item, cmap=cmap, vmin=vmin, vmax=vmax)
        cell.axis("off")
        if col == 0 and title:
            cell.set_title(title, rotation=0, fontsize=12)




In [10]:
# Load the fine-tuned segmentation model (get_model also switches it to eval mode)
model = get_model(
    chesapeake_checkpoint_path=CHESAPEAKE_CHECKPOINT_PATH,
    clay_checkpoint_path=CLAY_CHECKPOINT_PATH,
    metadata_path=METADATA_PATH,
)

OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB. GPU 0 has a total capacity of 5.80 GiB of which 6.38 MiB is free. Including non-PyTorch memory, this process has 5.74 GiB memory in use. Of the allocated memory 5.65 GiB is allocated by PyTorch, and 15.05 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Fetch one validation batch plus the normalisation metadata
batch, metadata = get_data(
    train_chip_dir=TRAIN_CHIP_DIR,
    train_label_dir=TRAIN_LABEL_DIR,
    val_chip_dir=VAL_CHIP_DIR,
    val_label_dir=VAL_LABEL_DIR,
    metadata_path=METADATA_PATH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    platform=PLATFORM,
)
# Move the batch to the device that actually holds the model's weights instead
# of assuming "cuda": this also works on CPU-only machines and avoids a device
# mismatch when the model was loaded somewhere else. Non-tensor entries are
# passed through untouched.
device = next(model.parameters()).device
batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}

In [None]:
# Forward pass on the prepared batch (no gradients are tracked)
outputs = run_prediction(model=model, batch=batch)

In [None]:
# Turn tensors into NumPy arrays: displayable images, label maps, predicted maps
images, labels, preds = post_process(batch=batch, outputs=outputs, metadata=metadata)

In [None]:
# Show images, ground truth and predictions in one grid
plot_predictions(images=images, labels=labels, preds=preds)