In [1]:
import sys
import os
import numpy as np

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

from eval import eval_net
from unet import UNet
from utils.dataset import UNetDataset

import matplotlib.pyplot as plt

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
val_im_folder_path = '/mnt/data/zli85/github/OECModelRetraining/pilot_study/data_drift_existence/cloud_detection/unet/train_tao_unet_cloud/tao_experiments/data/320/val/subscenes'
val_mask_folder_path = '/mnt/data/zli85/github/OECModelRetraining/pilot_study/data_drift_existence/cloud_detection/unet/train_tao_unet_cloud/tao_experiments/data/320/val/masks'

val_dataset = UNetDataset(im_folder_path=val_im_folder_path, mask_folder_path=val_mask_folder_path, format='image')
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

In [4]:
checkpoint_path = './checkpoints/best.pt'

In [5]:
# load model
net = UNet(n_channels=3, n_classes=1, f_channels='model_channels.txt')
net.load_state_dict(torch.load(checkpoint_path))
net.to(device=device)

UNet(
  (inc): inconv(
    (conv): double_conv(
      (conv): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (down1): down(
    (mpconv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): double_conv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2

In [6]:
# eval
net.eval()

UNet(
  (inc): inconv(
    (conv): double_conv(
      (conv): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (down1): down(
    (mpconv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): double_conv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2

In [7]:
output_folder = './results'
os.makedirs(output_folder, exist_ok=True)

In [8]:
for batch_idx, (data, target) in enumerate(val_loader):
    data = data.to(device=device)

    with torch.no_grad():
        pred = net(data)
        pred = torch.sigmoid(pred)
        pred = (pred > 0.5).float()

    # save images
    pred = pred.cpu().numpy()
    pred = np.squeeze(pred)
    pred = pred * 255
    pred = pred.astype(np.uint8)

    mask = target.cpu().numpy()
    mask = np.squeeze(mask)
    mask = mask * 255
    mask = mask.astype(np.uint8)

    im = data.cpu().numpy()
    im = np.squeeze(im)
    im = np.moveaxis(im, 0, -1)
    im = im * 255
    im = im.astype(np.uint8)

    fig = plt.figure()
    plt.subplot(1, 3, 1)
    plt.imshow(im)
    plt.title('Image')
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(mask, cmap='gray')
    plt.title('Mask')
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(pred, cmap='gray')
    plt.title('Prediction')
    plt.axis('off')

    fig.savefig(os.path.join(output_folder, f'{batch_idx}.png'))
    plt.close(fig)