In [11]:
import os, time
from operator import add
import numpy as np
from glob import glob
import cv2
from tqdm import tqdm
import imageio
import torch
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score
from PIL import Image
from model import build_unet
from utils import create_dir, seeding

In [12]:
def calculate_metrics(y_true, y_pred):
    """ Ground truth """
    y_true = y_true.cpu().numpy()
    y_true = y_true > 0.5
    y_true = y_true.astype(np.uint8)
    y_true = y_true.reshape(-1)

    """ Prediction """
    y_pred = y_pred.cpu().numpy()
    y_pred = y_pred > 0.5
    y_pred = y_pred.astype(np.uint8)
    y_pred = y_pred.reshape(-1)

    score_jaccard = jaccard_score(y_true, y_pred)
    score_f1 = f1_score(y_true, y_pred)
    score_recall = recall_score(y_true, y_pred)
    score_precision = precision_score(y_true, y_pred)
    score_acc = accuracy_score(y_true, y_pred)

    return [score_jaccard, score_f1, score_recall, score_precision, score_acc]

In [13]:
def mask_parse(mask):
    mask = np.expand_dims(mask, axis=-1)    ## (512, 512, 1)
    mask = np.concatenate([mask, mask, mask], axis=-1)    ## (512, 512, 3)
    return mask

In [21]:
if __name__ == "__main__":
    "seeding"
    seeding(42)

    "folders"
    
    results_folder = "../results/"
    if not os.path.exists(results_folder):
        create_dir(results_folder)

    "load dataset"
    test_x = sorted(glob("../new_data/test/images/*"))
    print(len(test_x))
    test_y = sorted(glob("../new_data/test/masks/*"))
    print(len(test_y))

    "hyperparameters"
    H=512
    W=512
    size=(W,H)
    checkpoint_path = "C:/Users/WhiteLight/Documents/CIE_ml/Projects/WIP/image_segmentation/UNET/files/checkpoint.pth"

    "load the checkpoint"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    
    model = build_unet()
    model =model.to(device)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()


    metrics_score = [0.0,0.0,0.0,0.0,0.0]
    time_taken = []
    
    for i, (x, y) in tqdm(enumerate(zip(test_x, test_y)), total=len(test_x)):
        """ Extract the name """
        name = x.split("/")[-1].split("\\")[-1].split(".")[0]
        print(name)
        """ Reading image """
        image = cv2.imread(x, cv2.IMREAD_COLOR)
        ##image = cv2.resize(image, size)
        x = np.transpose(image, (2, 0, 1))
        x = x/255.0
        x = np.expand_dims(x, axis=0)
        x = x.astype(np.float32)
        x = torch.from_numpy(x)
        x = x.to(device)

        """ Reading mask """
        mask = cv2.imread(y, cv2.IMREAD_GRAYSCALE)
        y = np.expand_dims(mask, axis=0)
        y=y/255.0
        y = np.expand_dims(y,axis=0)
        y = y.astype(np.float32)
        y = torch.from_numpy(y)
        y = y.to(device)


        """ Prediction and Calculating the metrics """
        with torch.no_grad():
            start_time = time.time()
            pred_y = model(x)
            pred_y = torch.sigmoid(pred_y)
            total_time = time.time() - start_time
            time_taken.append(total_time)


            score = calculate_metrics(y,pred_y)
            metrics_score = list(map(add, metrics_score, score))
            pred_y = pred_y[0].cpu().numpy()        ## (1, 512, 512)
            pred_y = np.squeeze(pred_y, axis=0)     ## (512, 512)
            pred_y = pred_y > 0.5
            pred_y = np.array(pred_y, dtype=np.uint8)

        """Saving Masks"""
        ori_mask = mask_parse(mask)
        pred_y = mask_parse(pred_y)
        line = np.ones((size[1], 10, 3)) * 128

        cat_images = np.concatenate(
            [image, line, ori_mask, line, pred_y * 255], axis=1
        )
        result_path = os.path.join(results_folder,f"{name}.png")
        cv2.imwrite(result_path, cat_images)
        


    jaccard = metrics_score[0]/len(test_x)
    f1 = metrics_score[1]/len(test_x)
    recall = metrics_score[2]/len(test_x)
    precision = metrics_score[3]/len(test_x)
    acc = metrics_score[4]/len(test_x)
    print(f"Jaccard: {jaccard:1.4f} - F1: {f1:1.4f} - Recall: {recall:1.4f} - Precision: {precision:1.4f} - Acc: {acc:1.4f}")

    fps = 1/np.mean(time_taken)
    print("FPS:",fps)

20
20


  0%|          | 0/20 [00:00<?, ?it/s]

01_test_0


  5%|▌         | 1/20 [00:00<00:17,  1.09it/s]

02_test_0


 10%|█         | 2/20 [00:01<00:14,  1.21it/s]

03_test_0


 15%|█▌        | 3/20 [00:02<00:13,  1.22it/s]

04_test_0


 20%|██        | 4/20 [00:03<00:12,  1.28it/s]

05_test_0


 25%|██▌       | 5/20 [00:03<00:11,  1.33it/s]

06_test_0


 30%|███       | 6/20 [00:04<00:10,  1.36it/s]

07_test_0


 35%|███▌      | 7/20 [00:05<00:09,  1.34it/s]

08_test_0


 40%|████      | 8/20 [00:06<00:08,  1.36it/s]

09_test_0


 45%|████▌     | 9/20 [00:06<00:08,  1.34it/s]

10_test_0


 50%|█████     | 10/20 [00:07<00:07,  1.36it/s]

11_test_0


 55%|█████▌    | 11/20 [00:08<00:06,  1.37it/s]

12_test_0


 60%|██████    | 12/20 [00:09<00:05,  1.38it/s]

13_test_0


 65%|██████▌   | 13/20 [00:09<00:05,  1.38it/s]

14_test_0


 70%|███████   | 14/20 [00:10<00:04,  1.40it/s]

15_test_0


 75%|███████▌  | 15/20 [00:11<00:03,  1.42it/s]

16_test_0


 80%|████████  | 16/20 [00:11<00:02,  1.42it/s]

17_test_0


 85%|████████▌ | 17/20 [00:12<00:02,  1.42it/s]

18_test_0


 90%|█████████ | 18/20 [00:13<00:01,  1.41it/s]

19_test_0


 95%|█████████▌| 19/20 [00:13<00:00,  1.41it/s]

20_test_0


100%|██████████| 20/20 [00:14<00:00,  1.37it/s]

Jaccard: 0.1811 - F1: 0.3016 - Recall: 0.3941 - Precision: 0.2691 - Acc: 0.8509
FPS: 25.96652517360063



