# Generalization test on KVASIR-Seg dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os,sys,inspect
#sys.path.insert(0,"..")
os.chdir('..')

In [3]:
import numpy as np
import cv2
from PIL import Image
import torch
from torchvision import datasets, transforms, utils
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import glob
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
import torch
import random
import cv2
import torchvision.transforms as transforms

from sklearn.model_selection import train_test_split
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

torch.cuda.is_available = lambda : False
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

cpu


### Load data

In [4]:
class polyp_test_dataloader(Dataset):
    """
    KVASIR-Seg data loader
    """
    def __init__(self, data_folder, is_train=True):
        self.is_train = is_train
        self._data_folder = data_folder
        self.build_dataset()

    def build_dataset(self):
        self._input_folder = os.path.join(self._data_folder, 'images')
        self._label_folder = os.path.join(self._data_folder, 'masks')
        self._images = glob.glob(self._input_folder + "/*.png")
        self._labels = glob.glob(self._label_folder + "/*.png")
        
    def __len__(self):
        return len(self._images)

    def __getitem__(self, idx):
        img_path = self._images[idx]
        mask_path = self._labels[idx]
        
        # Read image, mask and scribble
        image = Image.open(img_path).convert('RGB')
        mask = cv2.imread(mask_path, 0)
        mask[mask<=127] = 0
        mask[mask>127] = 1
        mask = cv2.resize(mask, (224, 224), interpolation = cv2.INTER_AREA)
        mask = np.expand_dims(mask, axis=0)

        transforms_image = transforms.Compose([transforms.Resize((224, 224)), 
                                               transforms.CenterCrop((224,224)),
                                               transforms.ToTensor(),
                                               transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))])
        transforms_mask = transforms.Compose([transforms.Resize((224, 224)),
                                              transforms.CenterCrop((224,224)),
                                              transforms.ToTensor()])
        
        # Conver to torch tensors
        image = transforms_image(image)
        mask = torch.from_numpy(mask)
        
        sample = {'image': image, 
                  'mask': mask
                 }
        return sample

In [5]:
# def to_img(ten):
#     ten =(ten[0].permute(1,2,0).detach().cpu().numpy()+1)/2
#     ten=(ten*255).astype(np.uint8)
#     return ten

# a = to_img(x)
# print(a.shape)
# plt.imshow(a)
# #plt.imshow(a, cmap='gray')

In [6]:
# a = to_img(y)
# print(a.shape)
# plt.imshow(a, cmap='gray')

### Load model

In [9]:
from models.kiunet import unet
# from models.LeViTUNet128s import Build_LeViT_UNet_128s
# from models.LeViTUNet192 import Build_LeViT_UNet_192
# from models.LeViTUNet384 import Build_LeViT_UNet_384

#cvc_model_cb_ts_e/h

EXPERIMENT_NAME = "polys_unet_cb_h"
ROOT_DIR = os.path.abspath(".")
LOG_PATH = os.path.join(ROOT_DIR, "logs", EXPERIMENT_NAME)
model_path = 'logs/{}/{}.pth'.format(EXPERIMENT_NAME, EXPERIMENT_NAME)
model_path

'logs/polys_unet/polys_unet.pth'

In [10]:
model = unet()
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)
model = model.to(DEVICE)
model.eval()

unet(
  (encoder1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoder5): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder5): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [11]:
from metrics import iou_score, dice_coef, calculate_metric_percase

def test(model):
    model.eval()
    
    with torch.no_grad():
        jaccard = 0
        dice = 0
        for data_name in ['CVC-ClinicDB', 'Kvasir', 'CVC-300', 'CVC-ColonDB', 'ETIS-LaribPolypDB']:
            test_dataset = polyp_test_dataloader("datasets/POLYPS/TestDataset/"+data_name)
            test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8)
            for data in test_dataloader:
                data, target = data["image"].to(DEVICE), data["mask"].to(DEVICE)
                output = model(data.float())
                dc, jc, _ = calculate_metric_percase(output, target)
                jaccard += jc
                dice += dc
            jaccard /= len(test_dataloader)
            dice /= len(test_dataloader)
            print(f"Scores for {data_name}")
            print(f"Jaccard Index / IoU : {jaccard*100:.3f}")
            print(f"Dice Coeff / F1 : {dice*100}")
            #print('Jaccard Index / IoU : {:.3f}'.format(jaccard * 100))
            #print('Dice Coefficient / F1 : {:.3f}'.format(dice * 100))
            print('==========================================')
            print('==========================================')
        return jaccard

In [12]:
jac_score = test(model)



Scores for CVC-ClinicDB
Jaccard Index / IoU : 69.737
Dice Coeff / F1 : 77.64562318083964




Scores for Kvasir
Jaccard Index / IoU : 57.931
Dice Coeff / F1 : 69.92930437863208




Scores for CVC-300
Jaccard Index / IoU : 33.854
Dice Coeff / F1 : 45.19899484239413




Scores for CVC-ColonDB
Jaccard Index / IoU : 31.971
Dice Coeff / F1 : 41.71943736339824




Scores for ETIS-LaribPolypDB
Jaccard Index / IoU : 18.138
Dice Coeff / F1 : 25.197870553722908




In [16]:
for data_name in ['CVC-ClinicDB', 'Kvasir', 'CVC-300', 'CVC-ColonDB', 'ETIS-LaribPolypDB']:
    print(data_name)

CVC-ClinicDB
Kvasir
CVC-300
CVC-ColonDB
ETIS-LaribPolypDB


In [22]:
# Save predictions
if not os.path.exists(os.path.join(LOG_PATH, "vis_test")):
    os.mkdir(os.path.join(LOG_PATH, "vis_test"))
    for data_name in ['CVC-ClinicDB', 'Kvasir', 'CVC-300', 'CVC-ColonDB', 'ETIS-LaribPolypDB']:
        os.mkdir(os.path.join(LOG_PATH, "vis_test", data_name))
        os.mkdir(os.path.join(LOG_PATH, "vis_test", data_name, "imgs"))
        os.mkdir(os.path.join(LOG_PATH, "vis_test", data_name, "gts"))
        os.mkdir(os.path.join(LOG_PATH, "vis_test", data_name, "preds"))

for data_name in ['CVC-ClinicDB', 'Kvasir', 'CVC-300', 'CVC-ColonDB', 'ETIS-LaribPolypDB']:
    test_dataset = polyp_test_dataloader("datasets/POLYPS/TestDataset/"+data_name)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=8)
    for batch_idx, data in enumerate(test_dataloader):
        img, target = data["image"].to(DEVICE), data["mask"].to(DEVICE)
        output = torch.sigmoid(model(img.float()))

        img = (img[0].permute(1,2,0).detach().cpu().numpy()+1)/2
        img = (img*255).astype(np.uint8)
        img=cv2.cvtColor(img,cv2.COLOR_RGB2BGR)

        gt = target.permute(0, 2, 3, 1).squeeze().detach().cpu().numpy()
        gt=(gt*255).astype(np.uint8)
        gt=cv2.cvtColor(gt,cv2.COLOR_RGB2BGR)

        pred = output.permute(0, 2, 3, 1).squeeze().detach().cpu().numpy() > 0.5
        pred=(pred*255).astype(np.uint8)
        pred=cv2.cvtColor(pred,cv2.COLOR_RGB2BGR)

        cv2.imwrite(os.path.join(LOG_PATH, "vis_test", data_name, "imgs/")+str(batch_idx)+'.png', img)
        cv2.imwrite(os.path.join(LOG_PATH, "vis_test", data_name, "gts/")+str(batch_idx)+'.png', gt)
        cv2.imwrite(os.path.join(LOG_PATH, "vis_test", data_name, "preds/")+str(batch_idx)+'.png', pred)

