In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from load_data import CSAWS
from utils import *
from config import *
import argparse

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Colour assigned to each class index when rendering segmentation masks:
# position i in this list is the colour painted on pixels whose label equals i.
# The tuples are handed to matplotlib's imshow as uint8 triples, i.e. (R, G, B).
CUSTOM_COLORMAP = [
    (0, 0, 0),  # class 0 - background (black)
    (255, 0, 0),  # class 1 - nipple (red)
    (0, 0, 255)  # class 2 - pectoral muscle (blue)
]
# Helper used to undo input normalisation before an image is displayed.
# The mean/std values are the standard ImageNet channel statistics.
# NOTE(review): presumably these match the normalisation applied by the test
# transform - confirm against the config module.
unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

In [None]:
def argument_parser(argv=None):
    """Parse the options for testing a segmentation model.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes:
            init_model_file: path to the trained model file.
            test_image_dir / test_data_dir: test image location (same value
                under both names).
            test_mask_dir / mask_dir: test mask location (same value under
                both names).
            transform: transform applied to the test data.
    """
    parser = argparse.ArgumentParser(description='Testing a segmentation model')
    parser.add_argument('--init_model_file', default=BEST_MODEL_DIR, help='Path to the trained model file', dest='init_model_file')
    parser.add_argument('--test_image_dir', default=TEST_IMAGE, help='Path to the test data file', dest='test_data_dir')
    parser.add_argument('--test_mask_dir', default=TEST_MASK, help='Path to the test mask file', dest='mask_dir')
    # argparse applies `type` only to string values, so a non-string default
    # is passed through untouched.
    # NOTE(review): `A` is not imported explicitly in this notebook; it has to
    # come from one of the star imports - confirm.
    parser.add_argument('--transform', type=A.Compose, default=TEST_TRANSFORM, help='Data augmentation')

    # A Jupyter kernel is started with `-f <connection file>` in sys.argv;
    # parse_args() would abort on it with "unrecognized arguments", so unknown
    # arguments are tolerated and ignored here.
    args, _unknown = parser.parse_known_args(argv)

    # The flags are stored under the dest names `test_data_dir` / `mask_dir`,
    # while the dataset loader reads `test_image_dir` / `test_mask_dir`.
    # Expose both spellings so either one works.
    args.test_image_dir = args.test_data_dir
    args.test_mask_dir = args.mask_dir

    return args

In [3]:
def load_test_dataset(args):
    """Build the CSAWS test dataset from the parsed options.

    Args:
        args: Namespace-like object that must expose the attributes
            ``test_image_dir``, ``test_mask_dir`` and ``transform``.

    Returns:
        The ``CSAWS`` instance constructed from those three values, passed
        positionally in that order.
    """
    image_dir, mask_dir, transform = args.test_image_dir, args.test_mask_dir, args.transform
    return CSAWS(image_dir, mask_dir, transform)

In [None]:
def load_model(model_file, device):
    """Load a serialized model and place it on ``device``.

    Args:
        model_file: Path (or file-like object) accepted by ``torch.load``.
        device: ``torch.device`` (or device string) the model should live on.

    Returns:
        The object stored in ``model_file``. When it is an ``nn.Module`` it is
        moved to ``device`` before being returned.

    Security:
        ``torch.load`` unpickles the file, which can execute arbitrary code.
        Only load model files from a trusted source.
    """
    try:
        # map_location remaps the stored tensors, so a checkpoint saved on a
        # GPU can be opened on a CPU-only machine. weights_only=False is
        # needed because the file holds a whole pickled model object, which
        # the weights_only=True default of PyTorch >= 2.6 refuses to load.
        model = torch.load(model_file, map_location=device, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the weights_only keyword.
        model = torch.load(model_file, map_location=device)

    # Inputs are moved to `device` at inference time; the model has to be on
    # the same device or the forward pass fails.
    if isinstance(model, torch.nn.Module):
        model = model.to(device)
    return model

In [6]:
def plot_figure(id, image, color_mask, color_mask_predict):
    """Show one sample as three side-by-side panels.

    Panels, left to right: the de-normalized input image, the ground-truth
    colour mask and the predicted colour mask. Axes are hidden on all panels
    and the figure is displayed with ``plt.show()``.

    Args:
        id: Sample identifier, used only in the panel titles.
        image: Input image; it is passed through ``unorm`` and then permuted
            with ``(1, 2, 0)`` for display (assumes a channel-first tensor -
            TODO confirm against the dataset).
        color_mask: Array handed to ``imshow`` for the true mask.
        color_mask_predict: Array handed to ``imshow`` for the predicted mask.
    """
    _, (ax_image, ax_true, ax_pred) = plt.subplots(1, 3, figsize=(15, 5))

    # Undo the input normalization and reorder the axes for imshow.
    displayable = unorm(image).permute(1, 2, 0).cpu()
    ax_image.imshow(displayable, cmap='gray')
    ax_image.set_title(f'Original Image {id}')

    ax_true.imshow(color_mask)
    ax_true.set_title(f'True Mask {id}')

    ax_pred.imshow(color_mask_predict)
    ax_pred.set_title(f'Predicted Mask {id}')

    for panel in (ax_image, ax_true, ax_pred):
        panel.axis('off')

    plt.show()

In [7]:
def _colorize_mask(label_mask):
    """Turn a 2-D array of class indices into an (H, W, 3) uint8 colour image."""
    colored = np.zeros((*label_mask.shape, 3), dtype=np.uint8)
    for class_index, color in enumerate(CUSTOM_COLORMAP):
        colored[label_mask == class_index] = np.array(color)
    return colored


def visualize_predictions(model, test_dataset, device):
    """Plot image, true mask and predicted mask for every test sample.

    For each ``(x, y)`` pair in ``test_dataset`` the model output is resized
    to the shape of ``y`` with bilinear interpolation, the per-pixel argmax is
    taken as the predicted class, both masks are coloured with
    ``CUSTOM_COLORMAP`` and the result is handed to ``plot_figure``.

    Args:
        model: Module called on a single-sample batch; put into eval mode here.
        test_dataset: Indexable collection with ``len()`` whose items are
            ``(x, y)`` pairs. ``x`` must support ``unsqueeze``/``to``; ``y`` is
            the label mask (tensor or array-like) whose shape is the target
            size of the prediction.
        device: Device the input batch is moved to before the forward pass.

    Returns:
        None. One figure is shown per sample; nothing is plotted for an empty
        dataset.
    """
    model.eval()
    with torch.no_grad():
        for sample_index in range(len(test_dataset)):
            x, y = test_dataset[sample_index]

            # Work with a NumPy ground truth so its boolean masks index the
            # NumPy colour buffers directly.
            y_true = y.detach().cpu().numpy() if torch.is_tensor(y) else np.asarray(y)

            logits = model(x.unsqueeze(0).to(device))
            logits = F.interpolate(logits, size=tuple(y_true.shape), mode="bilinear")
            # squeeze(0) drops only the batch axis; a bare squeeze() would also
            # drop a spatial axis of length 1.
            y_predict = logits.argmax(dim=1).squeeze(0).cpu().numpy()

            color_mask_predict = _colorize_mask(y_predict)
            color_mask = _colorize_mask(y_true)

            # Plot inside the loop so every sample is shown, not just the last.
            plot_figure(sample_index, x, color_mask, color_mask_predict)

In [None]:
# Driver cell: the steps below depend on each other and must run in this order.
# Parse the run options (model path, test image/mask locations, transform).
args = argument_parser()
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the trained model from the configured model file.
model = load_model(args.init_model_file, device)
# Build the test dataset from the parsed options.
test_dataset = load_test_dataset(args)
# Run inference over the test set and plot image / true mask / predicted mask.
visualize_predictions(model, test_dataset, device)