In [2]:
import cv2
import torch
import torch.nn as nn
import numpy as np
import segmentation_models_pytorch as smp
import os
import glob
import matplotlib.pyplot as plt
import random
import time
from torch.utils.data import DataLoader
from sklearn import metrics
from sklearn.metrics import confusion_matrix
import datetime

os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [3]:
class WoundData(torch.utils.data.Dataset):
    
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # Read image and mask
        im = cv2.imread(self.data[idx][0],-1)
        mask = cv2.imread(self.data[idx][1],0)

        # From np.array (HxWxC) to torch.tensor (CxHxW). From [0,255] to [0,1]
        im = torch.from_numpy(np.float32(im/255).transpose(2,0,1))
        mask = torch.from_numpy(np.float32(mask/255)).unsqueeze(0)
        
        return im, mask

In [4]:
test_folder_imgs = np.array(glob.glob(os.path.join("test128", "images", "*")))
test_folder_labs = np.array(glob.glob(os.path.join("test128", "labels", "*")))

print(len(test_folder_imgs))
test_paths = []
for i in range(len(test_folder_imgs)):
    test_paths.append([test_folder_imgs[i], test_folder_labs[i]])

261


In [18]:
models = glob.glob(os.path.join("manytrainruns_p2psmb128_04*", "*.pt"))
print(len(models))

results = []

for mi, model_path in enumerate(models):
    
    with open(os.path.join(os.path.split(model_path)[0], "architecture.txt"), "r") as f:
        txt = f.read().split('\n')
        ENCODER = txt[0]
        ENCODER_WEIGHTS = txt[1]
        
    CLASSES = ['vein']
    DEVICE = 'cuda'

    model = smp.Unet(
        encoder_name=ENCODER, 
        encoder_weights=ENCODER_WEIGHTS, 
        classes=len(CLASSES), 
        in_channels=3,
    )

    model.load_state_dict(torch.load(model_path))
    model.to(DEVICE);
    
    test_ds = WoundData(test_paths)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)
    
    collated_pred, collated_mask = [], []
    collated_img = []
    for i, (sampi, sampm) in enumerate(test_dl):
        sampo = torch.sigmoid(model(sampi.to(DEVICE)))
        collated_pred.append(np.array(sampo.cpu().detach()).squeeze())
        collated_mask.append(np.array(sampm.detach()).squeeze())
        collated_img.append(np.array(sampi.cpu().detach().squeeze()).transpose(1,2,0))
        
    collated_pred = np.array(collated_pred)
    collated_mask = np.array(collated_mask)
    collated_img = np.array(collated_img)
    collated_pred_bin = (collated_pred > 0.5).astype(int)
    
    prec = []
    rec = []
    dice = []
    sizes = []
    ious = []

    for pred, mask, img in zip(collated_pred_bin, collated_mask, collated_img):
        contours, hierarchy = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)    
        tp = np.sum((pred + mask) == 2)
        tn = np.sum((pred + mask) == 0)
        fp = np.sum((pred - mask) > 0)
        fn = np.sum((pred - mask) < 0)

        prec.append(tp/ (tp + fp + 0.00001))
        rec.append(tp/ (tp + fn + 0.00001))
        dice.append(2*tp/ (2*tp + fn + fp))
        ious.append(tp/(tp+fp+fn))
        sizes.append(np.sqrt(np.sum(mask)))
 
    print(model_path)
    print(np.mean(prec), np.mean(rec), np.mean(dice), np.mean(ious))
    results.append([np.mean(prec), np.mean(rec), np.mean(dice), np.mean(ious)])

35
manytrainruns_p2psmb128_04122022-003232\ganwoundmodel.pt
0.791410297424617 0.8572735407240317 0.7756329450890073 0.6678306155934379
manytrainruns_p2psmb128_04122022-003528\ganwoundmodel.pt
0.869752901132766 0.8215912124658615 0.8095689422405379 0.7104273787248317
manytrainruns_p2psmb128_04122022-004143\ganwoundmodel.pt
0.8068472331998916 0.8768362960616659 0.7998987436118085 0.7009984317507709
manytrainruns_p2psmb128_04122022-004612\ganwoundmodel.pt
0.8109040252492988 0.8519759405974515 0.7873040074896307 0.6836068000379965
manytrainruns_p2psmb128_04122022-004903\ganwoundmodel.pt
0.80568725864945 0.8767315835838536 0.7993600536250811 0.6976926113246986
manytrainruns_p2psmb128_04122022-005322\ganwoundmodel.pt
0.8485426000167067 0.8405931367492464 0.8098725862427693 0.7134937713911261
manytrainruns_p2psmb128_04122022-010014\ganwoundmodel.pt
0.8523678541427198 0.8528296612185314 0.8180458265204578 0.7264470533343101
manytrainruns_p2psmb128_04122022-010723\ganwoundmodel.pt
0.77897392136

In [7]:
models2 = glob.glob(os.path.join("manytrainruns_train128_04*", "*.pt"))
print(len(models2))

results2 = []

for mi, model_path in enumerate(models2):
    
    with open(os.path.join(os.path.split(model_path)[0], "architecture.txt"), "r") as f:
        txt = f.read().split('\n')
        ENCODER = txt[0]
        ENCODER_WEIGHTS = txt[1]
        
    CLASSES = ['vein']
    DEVICE = 'cuda'

    model = smp.Unet(
        encoder_name=ENCODER, 
        encoder_weights=ENCODER_WEIGHTS, 
        classes=len(CLASSES), 
        in_channels=3,
    )

    model.load_state_dict(torch.load(model_path))
    model.to(DEVICE);
    
    test_ds = WoundData(test_paths)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)
    
    collated_pred, collated_mask = [], []
    collated_img = []
    for i, (sampi, sampm) in enumerate(test_dl):
        sampo = torch.sigmoid(model(sampi.to(DEVICE)))
        collated_pred.append(np.array(sampo.cpu().detach()).squeeze())
        collated_mask.append(np.array(sampm.detach()).squeeze())
        collated_img.append(np.array(sampi.cpu().detach().squeeze()).transpose(1,2,0))
        
    collated_pred = np.array(collated_pred)
    collated_mask = np.array(collated_mask)
    collated_img = np.array(collated_img)
    collated_pred_bin = (collated_pred > 0.5).astype(int)
    
    prec = []
    rec = []
    dice = []
    sizes = []
    ious = []

    for pred, mask, img in zip(collated_pred_bin, collated_mask, collated_img):
        contours, hierarchy = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)    
        tp = np.sum((pred + mask) == 2)
        tn = np.sum((pred + mask) == 0)
        fp = np.sum((pred - mask) > 0)
        fn = np.sum((pred - mask) < 0)

        prec.append(tp/ (tp + fp + 0.00001))
        rec.append(tp/ (tp + fn + 0.00001))
        dice.append(2*tp/ (2*tp + fn + fp))
        ious.append(tp/(tp+fp+fn))
        sizes.append(np.sqrt(np.sum(mask)))
 
    print(np.mean(prec), np.mean(rec), np.mean(dice), np.mean(ious))
    results2.append([np.mean(prec), np.mean(rec), np.mean(dice), np.mean(ious)])

35
0.803058632172808 0.8691586050281312 0.7891033815743561 0.6871601246553296
0.8037010082464845 0.8792541748676711 0.797145534242938 0.6965248315636553
0.7556190811741358 0.8960106587987426 0.7737100131461148 0.6668245292194606
0.7955719997945979 0.8607344487226245 0.7801414246245107 0.675082054953318
0.7779693780142364 0.8792155272725375 0.7779064388808891 0.6715309487512774
0.8135970575710284 0.8621475066986389 0.7950767283629437 0.6938575183540335
0.8143727918804015 0.8677254487537023 0.7968452605862072 0.697426483241024
0.8184265391524692 0.8745018668876844 0.8054869419003916 0.7089657502788287
0.8441757195884152 0.8529949825938222 0.8100382387130065 0.7136231402703227
0.8318487854690507 0.8600796222378136 0.8057103539731129 0.707634004164182
0.8240547080108546 0.8703924689228796 0.807629743458211 0.7090998576645257
0.7946955726093131 0.8802416777243494 0.7916581580902818 0.6901948673629796
0.8088424381213724 0.8565405520583524 0.7874491182056488 0.6865006324175019
0.8084530766876

In [8]:
models3 = glob.glob(os.path.join("manytrainruns_train128_FPN_*", "*.pt"))
print(len(models3))

results3 = []

for mi, model_path in enumerate(models3):
    
    with open(os.path.join(os.path.split(model_path)[0], "architecture.txt"), "r") as f:
        txt = f.read().split('\n')
        ENCODER = txt[0]
        ENCODER_WEIGHTS = txt[1]
        
    CLASSES = ['vein']
    DEVICE = 'cuda'

    model = smp.FPN(
        encoder_name=ENCODER, 
        encoder_weights=ENCODER_WEIGHTS, 
        classes=len(CLASSES), 
        in_channels=3,
    )

    model.load_state_dict(torch.load(model_path))
    model.to(DEVICE);
    
    test_ds = WoundData(test_paths)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)
    
    collated_pred, collated_mask = [], []
    collated_img = []
    for i, (sampi, sampm) in enumerate(test_dl):
        sampo = torch.sigmoid(model(sampi.to(DEVICE)))
        collated_pred.append(np.array(sampo.cpu().detach()).squeeze())
        collated_mask.append(np.array(sampm.detach()).squeeze())
        collated_img.append(np.array(sampi.cpu().detach().squeeze()).transpose(1,2,0))
        
    collated_pred = np.array(collated_pred)
    collated_mask = np.array(collated_mask)
    collated_img = np.array(collated_img)
    collated_pred_bin = (collated_pred > 0.5).astype(int)
    
    prec = []
    rec = []
    dice = []
    sizes = []
    ious = []

    for pred, mask, img in zip(collated_pred_bin, collated_mask, collated_img):
        contours, hierarchy = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)    
        tp = np.sum((pred + mask) == 2)
        tn = np.sum((pred + mask) == 0)
        fp = np.sum((pred - mask) > 0)
        fn = np.sum((pred - mask) < 0)

        prec.append(tp/ (tp + fp + 0.00001))
        rec.append(tp/ (tp + fn + 0.00001))
        dice.append(2*tp/ (2*tp + fn + fp))
        ious.append(tp/(tp+fp+fn))
        sizes.append(np.sqrt(np.sum(mask)))
 
    print(np.mean(prec), np.mean(rec), np.mean(dice), np.mean(ious))
    results3.append([np.mean(prec), np.mean(rec), np.mean(dice), np.mean(ious)])

10
0.8124564222696982 0.8421388099768795 0.7785468385872898 0.6746105914610402
0.8061234947600389 0.8617913013421475 0.7879804979167286 0.6842950355325107
0.8359540004784152 0.8384369227834834 0.7928516343398464 0.6949071210818776
0.8067497859154358 0.8521900994077689 0.7810040401133524 0.6767506672455154
0.8253651987639741 0.844597070762433 0.7891127516817881 0.6874759049133562
0.7902150881161933 0.8595524996604994 0.7747178491928848 0.6682300245414021
0.785904188629299 0.8540967414142092 0.7687626638552983 0.6579221220505371
0.7889502024131986 0.8401506397032008 0.7607361592222173 0.6521076614793715
0.8038573949377261 0.8344131009078667 0.7662622294834055 0.6566004664897152
0.7847260118261944 0.857737550066366 0.7704619928203205 0.6620911908391585


In [16]:
models4 = glob.glob(os.path.join("manytrainruns_p2psmb128_FPN_0*", "*.pt"))
print(len(models4))

results4 = []

for mi, model_path in enumerate(models4):
    
    with open(os.path.join(os.path.split(model_path)[0], "architecture.txt"), "r") as f:
        txt = f.read().split('\n')
        ENCODER = txt[0]
        ENCODER_WEIGHTS = txt[1]
        
    CLASSES = ['vein']
    DEVICE = 'cuda'

    model = smp.FPN(
        encoder_name=ENCODER, 
        encoder_weights=ENCODER_WEIGHTS, 
        classes=len(CLASSES), 
        in_channels=3,
    )

    model.load_state_dict(torch.load(model_path))
    model.to(DEVICE);
    
    test_ds = WoundData(test_paths)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)
    
    collated_pred, collated_mask = [], []
    collated_img = []
    for i, (sampi, sampm) in enumerate(test_dl):
        sampo = torch.sigmoid(model(sampi.to(DEVICE)))
        collated_pred.append(np.array(sampo.cpu().detach()).squeeze())
        collated_mask.append(np.array(sampm.detach()).squeeze())
        collated_img.append(np.array(sampi.cpu().detach().squeeze()).transpose(1,2,0))
        
    collated_pred = np.array(collated_pred)
    collated_mask = np.array(collated_mask)
    collated_img = np.array(collated_img)
    collated_pred_bin = (collated_pred > 0.5).astype(int)
    
    prec = []
    rec = []
    dice = []
    sizes = []
    ious = []

    for pred, mask, img in zip(collated_pred_bin, collated_mask, collated_img):
        contours, hierarchy = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)    
        tp = np.sum((pred + mask) == 2)
        tn = np.sum((pred + mask) == 0)
        fp = np.sum((pred - mask) > 0)
        fn = np.sum((pred - mask) < 0)

        prec.append(tp/ (tp + fp + 0.00001))
        rec.append(tp/ (tp + fn + 0.00001))
        dice.append(2*tp/ (2*tp + fn + fp))
        ious.append(tp/(tp+fp+fn))
        sizes.append(np.sqrt(np.sum(mask)))
 
    print(np.mean(prec), np.mean(rec), np.mean(dice), np.mean(ious))
    results4.append([np.mean(prec), np.mean(rec), np.mean(dice), np.mean(ious)])

10
0.8652708531058912 0.7891038957778721 0.7808216206807231 0.6782985772341128
0.868105586194172 0.7746490581355997 0.7714066047348364 0.6669115492139572
0.8878036071185255 0.7686402781311327 0.78017262827587 0.6790334146411907
0.8908916052717274 0.7737357451206193 0.7881035287379492 0.6917414859157733
0.9166859757756819 0.7101707035975763 0.7565330913923681 0.6477108549886402
0.9001761106199357 0.7617140007193388 0.7828533070703685 0.6819446164036482
0.919782181250404 0.7320035302510569 0.7758156037056833 0.673814570211171
0.8864835142474804 0.7851589682704124 0.7938502596662873 0.6954952753053439
0.8989942557546244 0.7542242046193052 0.7784722594971932 0.6754683544391115
0.8942463351582627 0.7714302397345506 0.7899137732037571 0.6910987419150264


In [23]:
stats = np.mean(results, axis=0)
print(2*stats[0]*stats[1]/(stats[0]+stats[1]))
np.mean(results, axis=0) # unet + gan

0.842034871054116


array([0.83424734, 0.84996917, 0.80361316, 0.70464741])

In [24]:
stats = np.mean(results2, axis=0)
print(2*stats[0]*stats[1]/(stats[0]+stats[1]))
np.mean(results2, axis=0) # unet no gan

0.8350701937399583


array([0.80364028, 0.86905858, 0.79073449, 0.68859446])

In [25]:
stats = np.mean(results3, axis=0)
print(2*stats[0]*stats[1]/(stats[0]+stats[1]))
np.mean(results3, axis=0) # fpn no gan

0.8256717035278338


array([0.80403018, 0.84851047, 0.77704367, 0.67149908])

In [26]:
stats = np.mean(results4, axis=0)
print(2*stats[0]*stats[1]/(stats[0]+stats[1]))
np.mean(results4, axis=0) # fpn + gan

0.8222976179452897


array([0.892844  , 0.76208306, 0.77979427, 0.67815174])

In [27]:
# prec rec dice iou

In [21]:
len(results), len(results2), len(results3), len(results4)

(35, 25, 10, 10)