In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as Fn
import torchvision.transforms.functional as vF
from time import time
import os
import matplotlib.pyplot as plt
from torchvision.io.image import read_image
from torchvision.transforms import Compose, Resize, CenterCrop, RandomCrop,ToTensor, Normalize, Grayscale, RandomRotation
from torchvision.transforms.v2 import  RandomResize
from torch.utils.data import TensorDataset,DataLoader
from torchvision.io.image import read_image,write_png
from scipy.io import loadmat,savemat
from skimage.draw import line
from scipy import ndimage

from PIL import Image
from MSLNet import log2file
from scipy import ndimage

# Resampling filter for bicubic resizing: use torchvision's InterpolationMode
# enum when it can be imported, otherwise fall back to PIL's bicubic constant.
try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC

def load_image(name, method='PIL'):
    """Load an image file as a (C, H, W) tensor.

    Parameters
    ----------
    name : str
        Path of the image file.
    method : str
        'PIL' decodes with PIL + ToTensor; any other value uses
        torchvision's read_image.

    Returns
    -------
    torch.Tensor
        ToTensor divides uint8 data by 255. read_image returns raw uint8
        0-255, so its uint8 output is rescaled the same way here. As a
        result, `method` only selects the decoder, not the value range.
    """
    if method == 'PIL':
        # Context manager closes the file; ToTensor reads the pixels before that.
        with Image.open(name) as img:
            return ToTensor()(img)
    x = read_image(name)
    if x.dtype == torch.uint8:
        x = x.float().div(255)
    return x

def iou1(a,b):
    """
    Intersection-over-union of two masks with the same shape (any rank).

    The intersection is sum(a*b) and the union is sum(max(a, b)), which is
    the usual IoU for {0, 1} masks. The 1e-5 added to the denominator makes
    two empty masks score 0 instead of NaN. Returns a 0-dim tensor.
    """
    overlap = (a * b).sum()
    union = torch.max(a, b).sum()
    return overlap / (union + 0.00001)

def Dice(a,b):
    """
    Dice coefficient 2*sum(a*b) / (sum(a) + sum(b)) of two masks of the same shape.

    When both masks are empty the ratio would be 0/0 = NaN. In that case 0
    is returned, the same value iou1 gives for two empty masks, so one empty
    frame cannot turn an np.mean over all frames into NaN.
    Returns a 0-dim tensor.
    """
    total = torch.sum(a) + torch.sum(b)
    if total == 0:
        return torch.tensor(0.0)
    return 2 * torch.sum(a * b) / total

def load_mask1(path,name):
    """Load masks/<name>.npy under `path` and binarize it.

    `path` is concatenated as-is, so it must end with a separator.
    Returns a float32 tensor that is 1 where the stored array is > 0, else 0.
    """
    binary = np.load(path + 'masks/' + name + '.npy') > 0
    return torch.tensor(binary).float()

def load_mask(path,name):
    """Load the PNG mask masks/1-7/<name>.png under `path`.

    The image is converted with ToTensor, giving a (C, H, W) tensor. The
    leading channel axis is dropped when it has size 1.
    """
    mask_file = path + 'masks/1-7/' + name + '.png'
    return ToTensor()(Image.open(mask_file)).squeeze(0)

def keeponly(names,st):
    """Return, in their original order, the entries of `names` that contain the substring `st`."""
    return [entry for entry in names if st in entry]

# Compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
# (Set device = 'cpu' after this line to force CPU execution.)
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

In [None]:
path='d:/datasets/cathaction/'
# Read the train/test split files: one sample name per line.
split_names=[]
for split_file in ('train.txt', 'test.txt'):
    with open(path+split_file, 'r') as fh:
        split_names.append(fh.read().splitlines())
train,test=split_names
# Optional: restrict the test set to names containing 'pt'.
#test=keeponly(test,'pt')
print(len(train),len(test))

In [None]:
from skimage.morphology import thin,medial_axis,skeletonize
def precision(a,dy):
    """Fraction of detected pixels that count as correct.

    a  : detection mask (non-zero where something was detected).
    dy : mask of pixels that count as correct. In eval1 these are the pixels
         within the tolerance distance of the ground truth.
    Returns sum(a*dy) / (sum(a) + 1e-5) as a Python float.
    """
    correct = np.sum(a * dy)
    detected = np.sum(a)
    return (correct / (detected + 0.00001)).item()

def recall(dr,y):
    """Fraction of ground-truth pixels that were detected.

    dr : mask of pixels covered by the detection. In eval1 these are the
         pixels within the tolerance distance of a detected pixel.
    y  : ground-truth mask; pixels with y > 0 are the positives.
    Returns (#positives also covered by dr) / (#positives + 1e-5) as a float.
    """
    hits = np.sum((y * dr) > 0).item()
    positives = np.sum(y > 0).item()
    return hits / (positives + 0.00001)

def eval1(da,dy,dist):
    """Tolerance-based precision and recall of a detection against ground truth.

    da   : distance map to the nearest detected pixel (0 on detected pixels).
           The callers build it with distance_transform_edt(1 - result).
    dy   : distance map to the nearest ground-truth pixel (0 on GT pixels).
    dist : tolerance. A detected pixel counts as correct if dy <= dist there.
           A GT pixel counts as found if da <= dist there.
    Returns (precision, recall).
    """
    detected = da == 0
    near_gt = dy <= dist
    near_detection = da <= dist
    on_gt = dy == 0
    return precision(detected, near_gt), recall(near_detection, on_gt)

def AHD(da,dy):
    """Symmetric mean distance between a detection and the ground truth.

    Computes two averages and returns their mean as a float:
      - the mean of da over GT pixels (dy == 0), i.e. GT -> nearest detection;
      - the mean of dy over detected pixels (da == 0), i.e. detection -> nearest GT.
    Returns NaN (with a numpy RuntimeWarning) if either pixel set is empty.
    """
    gt_to_det = da[dy == 0].mean()
    det_to_gt = dy[da == 0].mean()
    return float(gt_to_det + det_to_gt) / 2

# Frame-level evaluation of the runs of each method in `bases`.
# Run j of method `base` is read from path+'Resultsloc/'+base+str(j)+'/'+<test name>+'.png'
# and compared with the thinned ground-truth mask at distance tolerance 0.
bases=['MslLorNN_']
for base in bases:
    names=['%s%d'%(base,k) for k in range(4)]
    print(names)
    n=len(test)
    nclf=len(names)
    allp=np.zeros((n,nclf))
    allr=np.zeros((n,nclf))
    ious=np.zeros((n,nclf))
    dices=np.zeros((n,nclf))
    ahds=np.zeros((n,nclf))
    # valid[i,j]==0 marks a missing result image. Its metric entries stay 0
    # and must be excluded from every average.
    valid=np.ones((n,nclf),dtype=np.int32)
    thr=0  # result pixels with value > thr count as detections
    for i in range(n):
        yi=load_mask1(path,test[i])
        # compare against the one-pixel-wide skeleton of the ground truth
        yi=torch.tensor(thin(yi>0)).float()
        dy=ndimage.distance_transform_edt(1-yi)  # distance to the nearest GT pixel
        nr,nc=yi.shape
        for j in range(nclf):
            paths=path+'Resultsloc/'+names[j]+'/'+test[i]+'.png'
            if not os.path.exists(paths):
                print(paths,'not found')
                valid[i,j]=0
                continue
            xr=load_image(paths).squeeze()*255
            if len(xr.shape)>2:
                xr=xr[0,:,:]  # multi-channel result: keep the first channel
            xr=(xr[:nr,:nc]>thr).float()
            dx=ndimage.distance_transform_edt(1-xr)  # distance to the nearest detection
            p,r=eval1(dx,dy,0)
            dices[i,j]=Dice(yi,xr)
            ious[i,j]=iou1(yi,xr)
            ahds[i,j]=AHD(dx,dy)
            allp[i,j]=p
            allr[i,j]=r
        if i%200==0:
            # Running summary for the first run. Only valid rows are used;
            # missing results are zeros and would bias the means downwards.
            ok=valid[:i+1,0]==1
            p,r=np.mean(allp[:i+1,0][ok]),np.mean(allr[:i+1,0][ok])
            d,io=np.mean(dices[:i+1,0][ok]),np.mean(ious[:i+1,0][ok])
            ahd_col=ahds[:i+1,0]
            ahd=np.mean(ahd_col[ok&~np.isnan(ahd_col)])
            f=2*p*r/(p+r) if p+r!=0 else 0.0
            print('%d %1.2f %1.2f %1.2f %1.2f %1.2f %.1f'%(i,p*100,r*100,f*100,d*100,io*100,ahd))
    p,r,f,d,io,ahd=[],[],[],[],[],[]
    for j in range(nclf):
        ok=valid[:,j]==1
        p.append(np.mean(allp[ok,j])*100)
        r.append(np.mean(allr[ok,j])*100)
        d.append(np.mean(dices[ok,j])*100)
        io.append(np.mean(ious[ok,j])*100)
        # AHD is NaN when a result or GT mask is empty; skip those rows
        ahd.append(np.mean(ahds[ok&~np.isnan(ahds[:,j]),j]))
        # F-score, taken as 0 when precision and recall are both 0 (instead of 0/0)
        f.append(2*p[j]*r[j]/(p[j]+r[j]) if p[j]+r[j]!=0 else 0.0)
        log2file('evalloc.txt','%s %.2f %1.2f %1.2f %1.2f %1.2f %1.2f %.1f'%(names[j],thr,p[j],r[j],f[j],d[j],io[j],ahd[j]))
    if len(p)>1:
        # mean (std) over runs, formatted as a LaTeX table row
        log2file('evalloc.txt','%.2f &%1.2f (%.2f) &%1.2f (%.2f) &%1.2f (%1.2f) &%1.2f (%1.2f) &%1.2f (%1.2f) &%1.1f (%1.1f)'%(thr,np.mean(p),np.std(p),np.mean(r),np.std(r),np.mean(f),np.std(f),np.mean(d),np.std(d),np.mean(io),np.std(io),np.mean(ahd),np.std(ahd)))
        

In [None]:
# Make every result mask binary with foreground value 255, rewriting the PNG files in place.
import glob
# Set explicitly: this cell used to rely on a `pathseg` left over from another cell.
pathseg='D:\\Datasets\\CathAction\\Results\\SwinCA_2\\'
#pathseg='D:\\Datasets\\CathAction\\masks\\'
t=ToTensor()
# NOTE: os.chdir changes the kernel's working directory for every later cell.
os.chdir(pathseg)
files = glob.glob("*.png")
for fname in files:
    name=pathseg+fname
    with Image.open(name) as img:
        im = t(img)
    # Threshold rather than scale. ToTensor stores an 8-bit value v as v/255,
    # so the previous im*255**2 cast to uint8 gives 255*v mod 256. That maps
    # v=1 to 255 but v=255 back to 1, so re-running the cell undid its own work.
    # Thresholding is idempotent.
    im1 = (im > 0).to(torch.uint8) * 255
    write_png(im1,name)

In [None]:
# combine results with masks
from utils_MSL import merge
import glob
path='D:\\Datasets\\CathAction\\'
pathano=path+'masks\\'                 # ground-truth masks
todo='SwinCA_0'                        # result folder to overlay
pathseg=path+'Results\\'+todo
pathout=path+'ResultsAno\\'+todo
os.makedirs(pathout, exist_ok=True)
t=ToTensor()
# NOTE: changes the kernel's working directory; the glob below is relative to it.
os.chdir(pathseg)
files = glob.glob("*.png")
for fname in files:
    # predicted segmentation: first channel only, binarized at > 0
    seg = t(Image.open(pathseg+'\\'+fname)).squeeze()
    if len(seg.shape)>2:
        seg=seg[0,:,:]
    seg=(seg>0).float()
    ano = t(Image.open(pathano+fname)).squeeze()
    out=merge([ano,seg])
    # NOTE(review): the permute(2,0,1) implies merge returns (H, W, C) in [0, 1] -- TODO confirm
    im1 = (out*255).long().permute(2,0,1)
    write_png(im1.to(torch.uint8),pathout+'\\'+fname)

In [None]:
# run Ambrosini's TACE centerline extraction and grouping program
import subprocess
path='D:\\Datasets\\CathAction\\'
# NOTE(review): relative path, resolved against the current working directory,
# which earlier cells change with os.chdir -- TODO confirm the intended base directory.
pathexe='../../CNN-2D-X-Ray-Catheter-Detection\\cpp\\Release\\tace.exe'
pathseg=path+'Results\\SwinCA_0\\'      # input segmentations <name>.png
pathloc=path+'Results_loc\\SwinCA_0\\'  # output curves <name>.txt
os.makedirs(pathloc, exist_ok=True)
# Failures are tolerated (a missing .txt is skipped when the curves are drawn),
# but they are collected and reported instead of being silently ignored.
failed=[]
for i in range(len(test)):
    res=subprocess.run([pathexe, pathseg+test[i]+'.png',pathloc+test[i]+'.txt'])
    if res.returncode!=0:
        failed.append(test[i])
    if i%1000==0:
        print(i,end=' ')
print()
print(len(failed),'TACE runs returned a non-zero exit code')

In [None]:
# load curves from text files and save them as images
def draw(m,p):
    """Rasterize a polyline into mask `m` (in place) and return `m`.

    m : 2-D tensor/array indexed as m[row, col], i.e. shape (height, width).
    p : (K, 2) array of points. p[:, 0] is the x (column) and p[:, 1] the
        y (row) coordinate; points are rounded to the nearest pixel.
    Consecutive points are joined with skimage.draw.line and every covered
    pixel is set to 1. Pixels outside the image are dropped on every side.
    Negative coordinates were previously not clipped and wrapped around to
    the opposite border through negative indexing.
    """
    height, width = m.shape
    pts = np.round(p).astype(np.int32)
    for k in range(1, pts.shape[0]):
        xs, ys = line(pts[k-1,0], pts[k-1,1], pts[k,0], pts[k,1])
        inside = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)
        m[ys[inside], xs[inside]] = 1
    return m
path='D:\\Datasets\\CathAction\\'
pathseg=path+'Results\\SwinCA_3\\'      # segmentations; only their size is used
pathloc=path+'Results_loc\\SwinCA_3\\'  # TACE curves (.txt) in, rasterized curves (.png) out
for i in range(len(test)):
    name=pathseg+test[i]+'.png'
    # Only the size is read. Without the context manager the lazily-opened
    # file would stay open until the Image object is garbage-collected.
    with Image.open(name) as im:
        nc,nr=im.size   # PIL size is (width, height)
    m=torch.zeros(nr,nc)
    name=pathloc+test[i]+'.txt'
    if os.path.isfile(name):
        cv=np.loadtxt(name)
        # A file with fewer than two rows does not load as a 2-D array; skip it.
        if len(cv.shape)>=2:
            m=draw(m,cv)
            m=m*255
    name=pathloc+test[i]+'.png'
    write_png(m.unsqueeze(0).to(torch.uint8),name)
    if i%1000==0:
        print(i,end=' ')

In [None]:
# evaluate on a sequence level
from utils_MSL import log2file
from scipy import ndimage
import os
   
def load_masks(path,seq,ext,dothin=0):
    """Load the ground-truth masks of a sequence and their distance maps.

    path   : dataset root passed to load_mask.
    seq    : iterable of frame names.
    ext    : unused (load_mask fixes the '.png' extension); kept for call compatibility.
    dothin : if truthy, thin each mask (skimage thin of mask > 0) before use.

    Returns (ys, dts):
      ys  : the mask tensors.
      dts : for each mask, distance_transform_edt(1 - mask), i.e. the distance
            from every pixel to the nearest pixel where the mask equals 1.
    """
    ys=[]
    dts=[]
    for frame in seq:
        yi=load_mask(path,frame)
        if dothin:
            yi=torch.tensor(thin(yi>0)).float()
        ys.append(yi)
        dts.append(ndimage.distance_transform_edt(1-yi))
    return ys,dts

def load_images(path,seq,ext,channel=0,method='PIL'):
    """Load the images path/<name><ext> for each name in `seq`, as float tensors.

    For 8-bit images both decoders yield values in 0-255:
      - 'PIL': ToTensor gives 0-1, which is multiplied back by 255;
      - any other method: read_image already returns uint8 0-255, so it is
        only converted to float. Multiplying that uint8 tensor by 255 would
        overflow.
    For tensors with a channel axis only `channel` is kept.
    Missing files are reported and skipped, so the returned list can be
    shorter than `seq` and is then no longer index-aligned with it.
    """
    t=ToTensor()
    im=[]
    for n in seq:
        name=path+'/'+n+ext
        if not os.path.exists(name):
            print('not found',name)
            continue
        if method=='PIL':
            x = t(Image.open(name))*255
        else:
            x = read_image(name).float()
        im.append(x[channel,:,:] if len(x.shape)>2 else x)
    return im

# Sequence-level evaluation of the runs of each method in `bases`.
# Run j of method `base` is read from path+'Resultsloc/'+base+str(j)+'/'. Every
# frame is scored at distance tolerance `tol` and the frame scores are averaged per sequence.
# Method lists used in earlier experiments (only the last assignment took effect):
#   ['MslLorNN_','mfold','nnunet','2phase','SwinCa_','scnn25_']
#   ['MslDCNN_','MslCENN_']
#   ['SwinGw_','nnUNetGw','MslNetNN_Gw','MslLorNN_Gw']
#   ['SwinCa_','mfold_','MslLorNN_']
#   ['SCNN25_gw','2phase','SwinGw_','MslNetNN_','MslLorNN_']
#   ['SCNN25_gw']
#   ['SwinCA_','nnUNetCA','MslNetNNCA_','MslLorNNCA_']
bases=['SwinGw_','MslNetNN_','MslLorNN_']
# NOTE(review): `te` and `tek` are not defined in this notebook. From their use,
# tek[s] is presumably a sequence key and te[key] that sequence's list of frame
# names -- TODO define them before running this cell.
n=len(tek)
nb=len(bases)
nclf=4   # runs per method: <base>0 ... <base>3
allp=np.zeros((n,nb,nclf))
allr=np.zeros((n,nb,nclf))
ious=np.zeros((n,nb,nclf))
dices=np.zeros((n,nb,nclf))
ahds=np.zeros((n,nb,nclf))
# valid[s,b,j]==0: no result frame found (or empty sequence); excluded from all averages
valid=np.ones((n,nb,nclf),dtype=np.int32)
thr=0   # result pixels with value > thr count as detections
tol=3   # distance tolerance in pixels for precision/recall
for s in range(n):
    frames=te[tek[s]]
    y,dy=load_masks(path,frames,'.png')
    if len(y)==0:
        valid[s]=0
        continue
    nr,nc=y[0].shape
    for b in range(nb):
        names=['%s%d'%(bases[b],j) for j in range(nclf)]
        for j in range(nclf):
            paths=path+'Resultsloc/'+names[j]+'/'
            # Load frame by frame and keep the frame index. load_images skips
            # missing files, so loading the whole sequence at once would pair
            # every later result with the wrong ground-truth frame.
            pairs=[]
            for k,frame in enumerate(frames):
                found=load_images(paths,[frame],'.png')
                if found:
                    pairs.append((k,found[0]))
            if len(pairs)==0:
                valid[s,b,j]=0
                continue
            ps,rs,ds,ios,ahd_frames=[],[],[],[],[]
            for k,xr in pairs:
                xr=(xr[:nr,:nc]>thr).float()
                dx=ndimage.distance_transform_edt(1-xr)  # distance to the nearest detection
                p,r=eval1(dx,dy[k],tol)
                ps.append(p)
                rs.append(r)
                ds.append(Dice(y[k],xr))
                ios.append(iou1(y[k],xr))
                ahd_frames.append(AHD(dx,dy[k]))
            allp[s,b,j]=np.mean(ps)
            allr[s,b,j]=np.mean(rs)
            ious[s,b,j]=np.mean(ios)
            dices[s,b,j]=np.mean(ds)
            ahds[s,b,j]=np.mean(ahd_frames)
    if s%10==0:
        # Running summary for run 0 of the first method, over valid sequences only.
        ok=valid[:s+1,0,0]==1
        p,r=np.mean(allp[:s+1,0,0][ok]),np.mean(allr[:s+1,0,0][ok])
        d,io=np.mean(dices[:s+1,0,0][ok]),np.mean(ious[:s+1,0,0][ok])
        ahd_col=ahds[:s+1,0,0]
        ahd=np.mean(ahd_col[ok&~np.isnan(ahd_col)])
        f=2*p*r/(p+r) if p+r!=0 else 0.0
        print('%d %1.2f %1.2f %1.2f %1.2f %1.2f %.1f'%(s,p*100,r*100,f*100,d*100,io*100,ahd))
for b in range(nb):
    p,r,f,d,io,ahd=[],[],[],[],[],[]
    for j in range(nclf):
        name='%s%d'%(bases[b],j)
        ok=valid[:,b,j]==1
        p.append(np.mean(allp[ok,b,j])*100)
        r.append(np.mean(allr[ok,b,j])*100)
        d.append(np.mean(dices[ok,b,j])*100)
        io.append(np.mean(ious[ok,b,j])*100)
        # AHD is NaN when a frame has an empty result or GT mask; skip those sequences
        ahd.append(np.mean(ahds[ok&~np.isnan(ahds[:,b,j]),b,j]))
        # F-score, taken as 0 when precision and recall are both 0 (instead of 0/0)
        f.append(2*p[j]*r[j]/(p[j]+r[j]) if p[j]+r[j]!=0 else 0.0)
        log2file('evalseq.txt','%s %.2f %1.2f %1.2f %1.2f %1.2f %1.2f %.1f'%(name,thr,p[j],r[j],f[j],d[j],io[j],ahd[j]))
    if len(p)>1:
        # mean (std) over runs, formatted as a LaTeX table row
        log2file('evalseq.txt','%.2f &%1.2f (%.2f) &%1.2f (%.2f) &%1.2f (%1.2f) &%1.2f (%1.2f) &%1.2f (%1.2f) &%1.1f (%1.1f)'%(thr,np.mean(p),np.std(p),np.mean(r),np.std(r),np.mean(f),np.std(f),np.mean(d),np.std(d),np.mean(io),np.std(io),np.mean(ahd),np.std(ahd)))