In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Fn
from time import time
import os
import matplotlib.pyplot as plt
from torchvision.io.image import read_image
from torchvision.transforms import Compose, Resize, CenterCrop, RandomCrop,ToTensor, Normalize, Grayscale, RandomRotation
from torchvision.transforms.v2 import  RandomResize
from torch.utils.data import TensorDataset,DataLoader
from torchvision.io.image import read_image,write_png
from scipy.io import loadmat,savemat
from skimage.draw import line
from skimage.morphology import thin,medial_axis,skeletonize
from scipy import ndimage
from skimage import measure
from MSLNet import log2file

from PIL import Image

# Pick the bicubic interpolation constant: use torchvision's InterpolationMode
# enum when it can be imported, otherwise fall back to the PIL constant.
# NOTE(review): BICUBIC is not referenced anywhere else in the visible cells.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC # PIL's bicubic resampling constant, used when torchvision has no InterpolationMode

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
#device = 'cpu'
print(device)
def load_mask(path,name):
    """Read the PNG at ``path + name + '.png'`` and return it as a tensor.

    The image is converted with torchvision's ``ToTensor`` and dimension 0 is
    squeezed away when it has size 1.
    """
    to_tensor = ToTensor()
    mask_file = path + name + '.png'
    mask = to_tensor(Image.open(mask_file))
    return mask.squeeze(0)
    
def load_mask1(path,name):
    """Load ``path + 'masks/' + name + '.npy'`` as a binary float tensor.

    Every strictly positive entry of the stored array becomes 1.0 and every
    other entry becomes 0.0.
    """
    mask_file = path + 'masks/' + name + '.npy'
    raw = np.load(mask_file)
    foreground = raw > 0
    return torch.tensor(foreground).float()


In [None]:
# Dataset root. It can be overridden with the CATHACTION_DIR environment
# variable; the default keeps the original hard-coded location.
# os.path.join(..., '') guarantees a trailing separator, because every later
# use builds file names by concatenating directly onto `path`.
path=os.path.join(os.environ.get('CATHACTION_DIR','d:/datasets/cathaction/'),'')

# The split files are read line by line into the `train` and `test` lists.
with open(path+'train.txt', 'r') as f:
    train = f.read().splitlines()
with open(path+'test.txt', 'r') as f:
    test = f.read().splitlines()
print(len(train),len(test))

In [None]:
from scipy.optimize import linear_sum_assignment
from percgrouping import draw, extractCurves, merge_curves, removeshortcurves, draw_cvs

def load_image(name):
    """Open the image file ``name`` and convert it with torchvision's ``ToTensor``."""
    return ToTensor()(Image.open(name))

def save_result(name,im):
    """Write the tensor ``im`` to the file ``name`` as an 8-bit PNG.

    Values are clipped to the range [0, 255], a leading channel dimension is
    added, and the result is cast to ``uint8`` and moved to the CPU before
    being passed to ``write_png``.

    The clipping is done on a copy, so the caller's tensor is left untouched
    (the previous version overwrote ``im`` in place).
    """
    out = im.clone()  # work on a copy instead of mutating the argument
    out[out<=0]=0
    out[out >= 255] = 255
    write_png(out.unsqueeze(0).to(torch.uint8).cpu(),name)

def iou1(a,b):
    """
    Intersection over union of two tensors of the same shape.

    The intersection is the sum of the element-wise product and the union is
    the sum of the element-wise maximum; a small constant in the denominator
    avoids a division by zero.
    """
    intersection = torch.sum(a * b)
    union = torch.sum(torch.max(a, b))
    return intersection / (union + 0.00001)

def Dice(a,b):
    """Dice coefficient ``2*sum(a*b) / (sum(a)+sum(b))`` of two tensors.

    When the denominator is zero (e.g. both masks are empty) the plain formula
    is 0/0 = NaN, which would turn any mean taken over the scores into NaN.
    In that case 0 is returned instead, the same value ``iou1`` gives for two
    empty masks.
    """
    overlap = torch.sum(a*b)
    total = torch.sum(a)+torch.sum(b)
    if total == 0:
        # Multiplying by 0.0 yields a floating-point zero tensor on the
        # same device as the inputs.
        return overlap*0.0
    return 2*overlap/total
    
def precision(a,dy):
    """Fraction of the entries set in ``a`` that are also set in ``dy``.

    ``a`` marks the detections and ``dy`` marks the positions that count as
    correct. The result is a Python float; a small constant in the
    denominator avoids a division by zero.
    """
    detections = np.sum(a)
    correct = np.sum(a * dy)
    ratio = correct / (detections + 0.00001)
    return ratio.item()

def recall(dr,y):
    """Fraction of the positive entries of ``y`` that are also set in ``dr``.

    ``y`` marks the positives and ``dr`` marks the positions counted as
    detected; the distance tolerance is whatever the caller used to build
    ``dr``. A small constant in the denominator avoids a division by zero.
    """
    positives = np.sum(y > 0).item()
    detected = np.sum((y * dr) > 0).item()
    return detected / (positives + 0.00001)

def eval1(da,dy,dist):
    """Precision and recall with a distance tolerance of ``dist``.

    da   -- distance transform of the result
    dy   -- distance transform of the ground truth
    dist -- largest distance still counted as a match

    Precision: share of result pixels (``da == 0``) that lie within ``dist``
    of the ground truth. Recall: share of ground-truth pixels (``dy == 0``)
    that lie within ``dist`` of the result.
    """
    result_pixels = da == 0
    gt_pixels = dy == 0
    prec = precision(result_pixels, dy <= dist)
    rec = recall(da <= dist, gt_pixels)
    return prec, rec

def AHD(da,dy):
    """Symmetric average distance between two sets given by distance maps.

    Averages ``da`` over the positions where ``dy == 0`` and ``dy`` over the
    positions where ``da == 0``, then returns the mean of the two averages as
    a Python float. If either set of positions is empty, the result is NaN.
    """
    where_gt = dy == 0
    where_result = da == 0
    mean_on_gt = np.mean(da[where_gt])
    mean_on_result = np.mean(dy[where_result])
    return (mean_on_gt + mean_on_result).item() / 2

# Earlier experiment, kept for reference:
#m=load_images('d:/training/guidewire/resultMSL886/',test[:],'.png',0,'cpu')
#plt.imshow(m,cmap='gray')
todo=test

# Settings passed to merge_curves (nit, npts, dmax, rho, tau, poly) and to
# removeshortcurves (lmin) in the evaluation loop.
nit,npts=3,10
dmax=40*np.arange(1,nit+1)      # alternative: dmax=np.ones(nit)*40
rho,tau=0.7,10
lmin,poly=0,0
print(nit,npts,dmax,rho,tau,poly,lmin)
#log2file('eval_loc.txt','merge_curves nit=%d thr=%.1f'%(nit,thr))
#for name in names:
#path='d:/datasets/cathaction/'

# Names of the result sets to evaluate; the output folders for each of them
# are created up front.
names=[]
for i in range(4):
    name='MSLLorNN_%d'%i
    names.append(name)
    pathl=path+'resultsLoc/'+name+'/'
    path1=path+'resultsCv/'+name+'/'
    for out_dir in (pathl,path1):
        os.makedirs(out_dir, exist_ok=True)
print(names)

# One row per test sample, one column per result set.
n,nclf=len(test),len(names)
allp,allr,alln,allt,ious,dices,ahds=(np.zeros((n,nclf)) for _ in range(7))
valid=np.ones((n,nclf),dtype=np.int32)   # set to 0 where a result file is missing
thr=0
# Main evaluation loop: for every test sample, compare the curves extracted
# from each result set against the thinned ground-truth mask and store the
# per-sample metrics in the (n, nclf) arrays allocated above.
for i in range(n):
    fname=test[i]
    # Ground truth: binary mask -> thinned -> distance to the nearest
    # ground-truth pixel (distance_transform_edt measures the distance to the
    # nearest zero, hence the 1-yi).
    yi=load_mask1(path,fname)
    yi=thin(yi>0)
    dy=ndimage.distance_transform_edt(1-yi)
    yi=torch.tensor(yi).float()
    nr,nc=yi.shape  # NOTE(review): nr and nc are not used below
    for j in range(nclf):
        pathr=path+'Results/'+names[j]+'/'+fname+'.png'
        #paths='d:/tmp/Results/'+test[i]+'.bmp'
        pathl=path+'resultsLoc/'+names[j]+'/'
        path1=path+'resultsCv/'+names[j]+'/'
        # A missing result file is reported and excluded from the final
        # averages through the `valid` flag.
        if not os.path.exists(pathr):
            print(pathr,'not found')
            valid[i,j]=0
            continue
        # ToTensor output scaled by 255; if more than two dimensions remain
        # after squeeze(), only the first channel is kept.
        im=load_image(pathr).squeeze()*255
        if len(im.shape)>2:
            im=im[0,:,:]
        t=0  # NOTE(review): not used below
        #fname=alls[k][i]
        #cv=load_cvs(path+'/'+fname+'.mat')
        #print(len(cv),cv[0].shape)
        t0=time()
        # `0./255` evaluates to 0.0, so the mask passed on is simply im>0.
        # NOTE(review): the second argument (8) is presumably a connectivity
        # or minimum-length setting of extractCurves - confirm.
        cv=extractCurves(im>0./255,8)
        #print(len(cv),cv[0].shape)
        #cv1
        # Disabled branch (`if 0`): would reload the input image instead of
        # the result image.
        if 0:
            pathr=path+'Images/'+fname+'.png'
            im=load_image(pathr).squeeze()
            if len(im.shape)>2:
                im=im[0,:,:]
        #print(im.shape)
        cv0=cv  # NOTE(review): the unmerged curves are kept but not used below
        cv=merge_curves(cv,nit,npts,dmax,rho,tau,poly)
        cv=removeshortcurves(cv,lmin)
        # The recorded time covers curve extraction, merging and short-curve removal.
        allt[i,j]=time()-t0
        # The curves are drawn on a transposed canvas and transposed back.
        # NOTE(review): presumably draw_cvs indexes the canvas as (x, y) - confirm.
        xr=torch.zeros(dy.shape).t()
        xr=draw_cvs(xr,cv).t()
        xr=(xr>thr).float()
        # Distance to the nearest drawn curve pixel.
        dx=ndimage.distance_transform_edt(1-xr)
        # Precision/recall with a tolerance of 3 (in distance-transform units, i.e. pixels).
        pi,ri=eval1(dx,dy,3)
        dices[i,j]=Dice(yi,xr)
        ious[i,j]=iou1(yi,xr)
        ahds[i,j]=AHD(dx,dy)
        allp[i,j]=pi
        allr[i,j]=ri
        #break
        #fi=2*pi*ri/(pi+ri)
        alln[i,j]=len(cv)  # number of curves left after merging and filtering
        # The output names below are only consumed by the commented-out save calls.
        name1=pathl+fname+'.png'
        #save_result(name1,(xr>0).float()*255)
        name1=path1+fname+'.png'
        #save_plot(name1,cv,im,200)
        name1=path1+fname+'.pth'
        #torch.save(cv,name1)
    # Progress report every 100 samples, for the first result set (column 0) only.
    # NOTE(review): rows flagged invalid are still included here (as zeros),
    # unlike in the final summary.
    if i%100==0:
        p,r=np.mean(allp[:i+1,0]),np.mean(allr[:i+1,0])
        d,io=np.mean(dices[:i+1,0]),np.mean(ious[:i+1,0])
        f=2*p*r/(p+r)
        print('%d %s %1.2f %1.2f %1.2f %1.2f %1.2f %.1f'%(i,fname,p*100,r*100,f*100,d*100,io*100, np.mean(alln[:i+1,0])))
# Final summary: one log line per result set, computed over the samples whose
# result file was found. Note that `n` is rebound here from the sample count
# to the list of mean curve counts.
p,r,f,d,io,n,ah=[],[],[],[],[],[],[]
for j in range(nclf):
    ok=valid[:,j]==1
    p.append(100*np.mean(allp[ok,j]))
    r.append(100*np.mean(allr[ok,j]))
    d.append(100*np.mean(dices[ok,j]))
    io.append(100*np.mean(ious[ok,j]))
    # AHD can be NaN for a sample, so NaN entries are skipped as well.
    ah.append(np.mean(ahds[ok&~np.isnan(ahds[:,j]),j]))
    n.append(np.mean(alln[ok,j]))
    # F-score from the averaged precision and recall.
    f.append(2*p[-1]*r[-1]/(p[-1]+r[-1]))
    log2file('eval_loc.txt','%s %.2f %1.2f %1.2f %1.2f %1.2f %1.2f %.1f %.1f'%(names[j],thr,p[j],r[j],f[j],d[j],io[j],ah[j],n[j]))
# With more than one result set, also log the mean (std) of every metric
# across the result sets.
if len(p)>1:
    stats=[v for m in (p,r,f,d,io,ah,n) for v in (np.mean(m),np.std(m))]
    log2file('eval_loc.txt','nit=%d,pts=%d,dmax=%d,rho=%.1f,tau=%d,poly=%d,lmin=%d: \n &%1.2f (%.2f) &%1.2f (%.2f) &%1.2f (%1.2f) &%1.2f (%1.2f) &%1.2f (%1.2f) &%1.1f (%1.1f) &%1.1f (%1.1f)'%((nit,npts,dmax[0],rho,tau,poly,lmin)+tuple(stats)))