In [15]:
import numpy as np
import os
import time

import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms
from sklearn.metrics import jaccard_score, accuracy_score
import matplotlib.pyplot as plt
import skimage.io

from config import msd_testing_root
from misc import check_mkdir, crf_refine
from mirrornet import MirrorNet
from misc import compute_iou, compute_acc_mirror, compute_acc_image, compute_mae, compute_ber

In [11]:
device_ids = [0]
torch.cuda.set_device(device_ids[0])
data_root = "/home/research/Datasets/NormalNet/MSD/test"
ckpt_path = "./final_results"
exp_name = "MirrorNet"
args = {"snapshot": "160", "scale": 384, "crf": True}

img_transform = transforms.Compose(
    [
        transforms.Resize((args["scale"], args["scale"])),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

to_test = {"MSD": msd_testing_root}

to_pil = transforms.ToPILImage()

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [12]:
net = MirrorNet(
        backbone_path="/home/research/Datasets/NormalNet/ICCV2019_MirrorNet/backbone/resnext/resnext_101_32x4d.pth"
    ).to(device)

finetuned_path = os.path.join(ckpt_path, f"{exp_name}_base.pth")
    # net.load_state_dict(torch.load(finetuned_path))
net.load_state_dict(torch.load(finetuned_path, map_location=torch.device("cuda")))

trying to load weights from /home/research/Datasets/NormalNet/ICCV2019_MirrorNet/backbone/resnext/resnext_101_32x4d.pth
Load ResNeXt Weights Succeed!


<All keys matched successfully>

In [18]:
net.eval()

with torch.no_grad():
    img_list = [
                img_name
                for img_name in os.listdir(
                    "/home/research/Datasets/NormalNet/ICCV2019_MirrorNet/final_results/MirrorNet/MirrorNet_160_nocrf"
                )
            ]
    for idx, img_name in enumerate(img_list):
        gt_mask = Image.open(
                    os.path.join(
                        "/home/research/Datasets/NormalNet/MSD/test/",
                        "mask",
                        f"{img_name.split('.',1)[0]}.png",
                    )
                )
        img_path = os.path.join("/home/research/Datasets/NormalNet/ICCV2019_MirrorNet/final_results/MirrorNet/MirrorNet_160_nocrf", img_name)
        predict_mask = Image.open(img_path)
        
        # Visualization
        iou_score = compute_iou(predict_mask, gt_mask)
        acc_mirror = compute_acc_mirror(predict_mask, gt_mask)
        acc_image = compute_acc_image(predict_mask, gt_mask)
        mae = compute_mae(predict_mask, gt_mask)
        ber = compute_ber(predict_mask, gt_mask)
        
        # Visualization
        plt.figure(figsize=(12, 6))
        
        plt.subplot(1, 3, 1)
        plt.imshow(np.array(Image.open(os.path.join("/path/to/images", img_name))))
        plt.title(f"Image: {img_name}")
        plt.axis('off')
        
        plt.subplot(1, 3, 2)
        plt.imshow(gt_mask, cmap='gray')
        plt.title("Ground Truth Mask")
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.imshow(predict_mask, cmap='gray')
        plt.title("Predicted Mask")
        plt.axis('off')
        
        plt.suptitle(f"IoU: {iou_score:.3f}, Acc (Mirror): {acc_mirror:.3f}, Acc (Image): {acc_image:.3f}, MAE: {mae:.3f}, BER: {ber:.3f}")
        
        plt.show()
        
        # For demonstration, let's break after the first image-mask pair to avoid flooding the output.
        break


AttributeError: 'PngImageFile' object has no attribute 'shape'