In [1]:
import numpy as np
import os
from pathlib import Path
from dataset import CBERS4A_CloudDataset
from model import UNet
import config
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import torch

In [2]:
# Validation split of the pilot dataset: image tiles (a "6bands" directory, per its
# name) and the matching masks, wrapped in the project's CBERS4A_CloudDataset.
imgs_dir = '/media/reginaldo/pfc-dados/dataset-piloto/dados-validacao/images_LOCAL_CONFIG_6bands'
masks_dir = '/media/reginaldo/pfc-dados/dataset-piloto/dados-validacao/masks_LOCAL_CONFIG'

val_ds = CBERS4A_CloudDataset(imgs_dir=imgs_dir, masks_dir=masks_dir)

# Worker count is independent of batch size: use config.NUM_WORKERS when the config
# module defines it, otherwise fall back to the previous value (config.BATCH_SIZE)
# so existing configurations behave exactly as before.
val_dataloader = DataLoader(
    val_ds,
    batch_size=config.BATCH_SIZE,
    num_workers=getattr(config, 'NUM_WORKERS', config.BATCH_SIZE),
    shuffle=False,
)

In [3]:
def generate_visualization(fig_title=None, fig_size=None, font_size=16, **images):
    """Plot the given images side by side in a single-row figure.

    Args:
        fig_title: Optional figure-level title (drawn with ``fig.suptitle``).
        fig_size: Figure size in inches; defaults to (16, 5).
        font_size: Font size of the figure-level title.
        **images: Panel name -> array accepted by ``imshow``. Underscores in the
            name are replaced by spaces and the result is title-cased to form the
            panel title. An array of shape (4, 512, 512) is treated as
            channels-first: it is moved to channels-last and only its first three
            channels are displayed.

    Returns:
        Tuple ``(axarr, fig)``: the axes object(s) returned by ``plt.subplots``
        and the figure.
    """
    n = len(images)
    fig_size = (16, 5) if fig_size is None else fig_size
    fig, axarr = plt.subplots(1, n, figsize=fig_size)
    if fig_title is not None:
        fig.suptitle(fig_title, fontsize=font_size)
    # plt.subplots returns a bare Axes (not an array) when there is a single panel.
    axes = axarr if n > 1 else [axarr]
    for ax, (name, image) in zip(axes, images.items()):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(" ".join(name.split("_")).title())
        if image.shape == (4, 512, 512):
            # transpose returns a new array; it must be reassigned, otherwise the
            # channel slice below would cut the width axis of the CHW array.
            image = image.transpose([1, 2, 0])[:, :, :3]
        ax.imshow(image)
    fig.subplots_adjust(top=0.8)
    return axarr, fig

def save_inference_to_disk(plot, image_name, output_path):
    """Save a figure as ``report_image_<stem>.jpg`` inside ``output_path``.

    Args:
        plot: Object with a matplotlib-style ``savefig`` method (e.g. a Figure).
        image_name: Name or path; the part of its base name before the first '.'
            becomes ``<stem>``.
        output_path: Destination directory. It is created if missing; an empty
            string means the current working directory.

    Returns:
        The path of the written report image.
    """
    stem = Path(image_name).name.split(".")[0]
    # savefig does not create missing parent directories.
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    report_path = os.path.join(output_path, f"report_image_{stem}.jpg")
    plot.savefig(report_path, format="jpg", bbox_inches="tight")
    return report_path

In [4]:
# Checkpoint file for each of the five training runs (treino1 .. treino5), also
# unpacked into model_1 .. model_5 for direct use in later cells.
models = [
    f'/media/reginaldo/pfc-dados/dataset-piloto/testes/{run}/{ckpt}'
    for run, ckpt in (
        ('treino1', 'epoch=25-step=1846.ckpt'),
        ('treino2', 'epoch=60-step=4331.ckpt'),
        ('treino3', 'epoch=44-step=3870.ckpt'),
        ('treino4', 'epoch=53-step=4644.ckpt'),
        ('treino5', 'epoch=67-step=5848.ckpt'),
    )
]
model_1, model_2, model_3, model_4, model_5 = models

In [8]:
def _to_display_image(image):
    """Convert one channels-first image tensor into an (H, W, 3) uint8 array for imshow.

    Values are scaled so the image maximum maps to 255; only the first three
    channels are kept.
    """
    pixels = image.detach().cpu().numpy()
    peak = pixels.max()
    # A blank (all-zero or non-positive) tile would otherwise divide by zero -> NaNs.
    if peak > 0:
        pixels = pixels / peak * 255.
    else:
        pixels = np.zeros_like(pixels)
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    return pixels.transpose((1, 2, 0))[:, :, :3]


def _to_label_mask(logits):
    """Reduce a batch-of-one model output to an (H, W) uint8 map of argmax class indices.

    NOTE(review): assumes the output is (1, K, H, W) with one channel per class.
    """
    scores = logits[0].detach().cpu().numpy()
    return np.argmax(scores, axis=0).astype(np.uint8)


def val_inference(dataloader=val_dataloader, checkpoint_path=None,
                  output_path='/media/reginaldo/pfc-dados/dataset-piloto/dados-validacao/resultados-treino5',
                  every_n=1000):
    """Run a UNet checkpoint on every ``every_n``-th validation sample and save comparison plots.

    For each selected dataset index ``idx`` a figure showing the input image, its
    ground-truth mask and the predicted mask is saved via ``save_inference_to_disk``
    under the name ``batch_<idx>``.

    Args:
        dataloader: Loader whose ``.dataset`` is indexed directly; samples are run
            one at a time, so the loader's own batching is not used.
        checkpoint_path: Lightning checkpoint to load; ``None`` means ``model_5``.
        output_path: Directory for the report images (created if missing).
        every_n: Stride over dataset indices; must be >= 1.

    Raises:
        ValueError: If ``every_n`` is smaller than 1.
    """
    if every_n < 1:
        raise ValueError(f"every_n must be >= 1, got {every_n}")
    dataset = dataloader.dataset
    if checkpoint_path is None:
        checkpoint_path = model_5

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = UNet.load_from_checkpoint(checkpoint_path=checkpoint_path)
    # load_from_checkpoint does not use `device`; move the weights explicitly so
    # they live on the same device as the inputs sent to the model below.
    model = model.float().to(device)
    model.eval()
    os.makedirs(output_path, exist_ok=True)

    with torch.no_grad():
        # Index only the samples that are plotted instead of loading every sample
        # from disk and discarding all but one in `every_n`.
        for idx in range(0, len(dataset), every_n):
            sample = dataset[idx]
            image, mask = sample['image'], sample['mask']

            logits = model(image.unsqueeze(0).float().to(device))

            _, fig = generate_visualization(
                image=_to_display_image(image),
                ground_truth_mask=np.squeeze(mask.numpy().astype(np.uint8)),
                predicted_mask=_to_label_mask(logits),
            )
            try:
                save_inference_to_disk(fig, f'batch_{idx}', output_path=output_path)
            finally:
                # Close each figure right away so open figures do not pile up.
                plt.close(fig)

In [9]:
# Render every 1000th validation sample with the treino5 checkpoint (model_5) and save the report images.
val_inference()

Lightning automatically upgraded your loaded checkpoint from v1.7.1 to v2.3.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../../../../media/reginaldo/pfc-dados/dataset-piloto/testes/treino5/epoch=67-step=5848.ckpt`
