In [2]:
import json, copy, cv2
import networkx as nx
import sys,glob
import matplotlib.image as mpimg
import numpy as np
import gudhi as gd
from gudhi.representations import vector_methods
import ripser
import persim
from TDA_filtrations import level_set_flooding, save_BD, image_to_pointcloud
from custom_functions import *
import multiprocessing as mp
import skimage.measure

In [3]:
###### To select a dataset to analyze, set `dataset` to one of the keys of DATASET_PATHS

# dataset name -> (NEFI graph folder glob, mask image folder, results folder)
DATASET_PATHS = {
    # STARE, Expert #1
    "STARE1": ("../Data/Dataset_1/NEFI_graphs/*/",
               "../Data/Dataset_1/Provided_masks/",
               "../Results/Dataset_1/"),
    # STARE, Expert #2
    "STARE2": ("../Data/Dataset_1/NEFI_graphs_VK/*/",
               "../Data/Dataset_1/Provided_masks_VK/",
               "../Results/Dataset_1_VK/"),
    # HRF
    "HRF": ("../Data/HRF_Dataset_1/NEFI_graphs/*/",
            "../Data/HRF_Dataset_1/Provided_masks/",
            "../Results/HRF_Dataset_1/"),
    # All datasets combined
    "all": ("../Data/all/NEFI_graphs/*/",
            "../Data/all/Provided_masks/",
            "../Results/all/"),
}

dataset = "STARE1"

# Fail here with a clear message instead of a later NameError on `nums`
if dataset not in DATASET_PATHS:
    raise ValueError(f"Unknown dataset {dataset!r}; choose one of {sorted(DATASET_PATHS)}")
nefi_output_folder, image_folder, write_folder = DATASET_PATHS[dataset]

# Image ids to process (for "all", these come from the disease classifications)
if dataset == "HRF":
    nums = np.arange(1, 46)
    #mat = np.load("../Data/Diagnoses/image_diagnoses_HRF.npy",allow_pickle=True).item()
elif "STARE" in dataset:
    nums = np.array([1, 2, 3, 4, 5, 44, 77, 81, 82, 139, 162, 163, 235, 236, 239, 240, 255, 291, 319, 324])
    #mat = np.load("../Data/Diagnoses/image_diagnoses.npy",allow_pickle=True).item()
else:  # "all": the image ids are the keys of the diagnoses dictionary
    # allow_pickle is needed to load a saved dict; only use it on trusted project data
    mat = np.load("../Data/Diagnoses/image_diagnoses_all.npy", allow_pickle=True).item()
    nums = list(mat['image_diagnoses'].keys())

# Filename components shared by the input masks and the output files:
# <folder><data_name><file_name><id>...
data_name = "DS1_"
file_name = "im"
# Sorted so the file list (and any downstream iteration order) is deterministic
nefi_outputs = sorted(glob.glob(f"{nefi_output_folder}*.txt"))

## Compute the Radial inward and Radial outward filtrations

In [None]:
def radial_filtrations(num):
    """Run the radial inward and outward filtrations on one image's NEFI graph.

    For each direction, this calls ``radius_filtration``, ``betti_curve`` and
    ``Persist_im``. Each call is given a ``filename_save`` path under
    ``write_folder``. Nothing is returned.

    Args:
        num: Image identifier from ``nums``. For the "all" dataset it is used
            verbatim; otherwise it is zero-padded to 4 digits.

    Raises:
        ValueError: if the identifier does not match exactly one path in
            ``nefi_outputs``.
    """
    num_str = num if "all" in dataset else str(num).zfill(4)

    # Maximum radius for radius_filtration. The larger HRF value presumably
    # reflects larger HRF images (TODO confirm).
    max_rad = 3000 if dataset == "HRF" else 700

    # The NEFI file is located by substring match on the id. Require exactly one
    # hit; a plain assert would vanish under `python -O` and give no context.
    matches = [path for path in nefi_outputs if num_str in path]
    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one NEFI output for image {num_str!r}, "
            f"found {len(matches)}: {matches}"
        )
    graph_in = nx.read_multiline_adjlist(matches[0], delimiter='|')

    for direction in ('inward', 'outward'):
        filename_header = f"{write_folder}{data_name}{file_name}{num_str}_{direction}"

        diag = radius_filtration(graph_in, max_rad=max_rad,
                                 filename_save=filename_header + "_persistence",
                                 direction=direction)

        # Return values are not used in this cell. These calls presumably run
        # for the files they write via filename_save.
        betti_curve(diag, r0=0, r1=40, filename_save=filename_header + "_Betti")
        Persist_im(diag=diag, inf_val=40, sigma=1.0,
                   filename_save=[filename_header + "_PIO", filename_header + "_PIR"])

print(f"Computing radial filtrations for {dataset}")
# The context manager terminates the workers even if a task raises. With a bare
# pool.close() after map(), an exception left the worker processes running.
with mp.Pool(mp.cpu_count()) as pool:
    results = pool.map(radial_filtrations, nums)

## Compute the Flooding filtration

In [None]:
def flood_filtration(num):
    """Run the level-set flooding filtration on one image's vessel mask.

    Reads ``<image_folder><file_name><id>.png``, keeps only its first channel
    when it has three dimensions, and 3x3-max-pools it for HRF. It then calls
    ``level_set_flooding``, ``betti_curve`` and ``Persist_im`` with output
    paths under ``write_folder``. Nothing is returned.

    Args:
        num: Image identifier from ``nums``. For the "all" dataset it is used
            verbatim; otherwise it is zero-padded to 4 digits.
    """
    # Same id handling as the radial and VR workers
    num_str = num if "all" in dataset else str(num).zfill(4)

    image = mpimg.imread(f"{image_folder}{file_name}{num_str}.png")

    # Select the channel explicitly. A bare `except:` retry hides genuine errors
    # raised inside level_set_flooding. Doing this first also hands block_reduce
    # a 2-D array, matching its 2-element block size.
    if image.ndim == 3:
        image = image[:, :, 0]

    if dataset == "HRF":
        # Downsample to ease computation
        image = skimage.measure.block_reduce(image, (3, 3), np.max)

    filename_header = f"{write_folder}{data_name}{file_name}{num_str}_flooding"

    # The original used 35 as iter_num, as the Betti-curve upper bound and as
    # inf_val; keep them tied together.
    n_iter = 35
    diag = level_set_flooding(image, iter_num=n_iter, steps=2,
                              filename=filename_header + "_persistence")

    # Return values are not used in this cell. These calls presumably run for
    # the files they write via filename_save.
    betti_curve(diag, r0=0, r1=n_iter, filename_save=filename_header + "_Betti")
    Persist_im(diag=diag, sigma=1.0, inf_val=n_iter,
               filename_save=[filename_header + "_PIO", filename_header + "_PIR"])

print(f"Computing flooding filtration for {dataset}")
# The context manager terminates the workers even if a task raises. With a bare
# pool.close() after map(), an exception left the worker processes running.
with mp.Pool(mp.cpu_count()) as pool:
    results = pool.map(flood_filtration, nums)


## Compute the Vietoris–Rips (VR) filtration

In [None]:
# Weight function for the ramped persistence images built in VR_filtration
def weight_ramp(x):
    """Return a linear-ramp weight for a single persistence-diagram point.

    Args:
        x: Two-coordinate point handed to the weight function by gudhi's
            ``PersistenceImage``. NOTE(review): gudhi applies a
            birth/persistence transform before weighting, so ``x[1]`` is
            presumably the persistence; confirm for the installed version.

    Returns:
        ``1.0`` if any coordinate is infinite. Otherwise ``x[1] / 185``, where
        185 matches the upper bound of the image range used in VR_filtration.
    """
    # Finite points are scaled linearly. An infinite point gets a fixed weight,
    # so the weight itself is never infinite.
    if not np.any(np.isinf(x)):
        return x[1] / 185
    return 1.0

def VR_filtration(num):
    """Compute averaged Vietoris-Rips persistence images for one image's vessel mask.

    Steps:
    - Convert the mask to a point cloud.
    - Run ripser on a greedy-permutation subsample; repeat for 50 shuffles of
      the cloud with a fixed seed.
    - Save every persistence diagram.
    - Average two 50x50 persistence images over the runs: one with
      ``weight_ramp`` weighting and one with the default weighting.
    - Save the averages as ``<header>_PIR.npy`` (ramp) and ``<header>_PIO.npy``
      (default). Each file holds ``{'Ip': [H0 image, H1 image]}``.

    Args:
        num: Image identifier from ``nums``. For the "all" dataset it is used
            verbatim; otherwise it is zero-padded to 4 digits.
    """
    num_str = num if "all" in dataset else str(num).zfill(4)

    image = mpimg.imread(f"{image_folder}{file_name}{num_str}.png")

    # Select the channel explicitly instead of retrying on any exception.
    # Doing it first also hands block_reduce a 2-D array.
    if image.ndim == 3:
        image = image[:, :, 0]

    if dataset == "HRF":
        # Downsample for HRF to ease computation
        image = skimage.measure.block_reduce(image, (3, 3), np.max)

    filename_header = f"{write_folder}{data_name}{file_name}{num_str}_VR"

    pointcloud = np.array(image_to_pointcloud(image))

    n_subsamples = 50
    resolution = [50, 50]
    # The 185 upper bound is also the normaliser inside weight_ramp
    im_range = [0, 185, 0, 185]
    # ripser rejects a greedy permutation larger than the point cloud
    n_perm = min(2000, len(pointcloud))

    # The image settings do not change between subsamples, so build them once
    pi_ramp = vector_methods.PersistenceImage(resolution=resolution,
                                              im_range=im_range,
                                              weight=weight_ramp)
    pi_ones = vector_methods.PersistenceImage(resolution=resolution,
                                              im_range=im_range)

    n_pixels = resolution[0] * resolution[1]
    im0_ripser_ramp = np.zeros((n_pixels,))
    im1_ripser_ramp = np.zeros((n_pixels,))
    im0_ripser_ones = np.zeros((n_pixels,))
    im1_ripser_ones = np.zeros((n_pixels,))

    # Global seed kept (rather than a Generator) so results match earlier runs
    np.random.seed(10)

    for i in range(n_subsamples):
        # Shuffling changes the point order, and so presumably the starting
        # point of ripser's greedy permutation, giving a different subsample.
        np.random.shuffle(pointcloud)

        dgms = ripser.ripser(pointcloud, n_perm=n_perm)['dgms']
        save_BD(dgms, filename=f"{filename_header}_persistence_{i}")

        im0_ripser, im1_ripser = pi_ramp.transform(dgms)
        im0_ripser_ramp += im0_ripser
        im1_ripser_ramp += im1_ripser

        im0_ripser, im1_ripser = pi_ones.transform(dgms)
        im0_ripser_ones += im0_ripser
        im1_ripser_ones += im1_ripser

    im0_ripser_ramp /= n_subsamples
    im1_ripser_ramp /= n_subsamples
    im0_ripser_ones /= n_subsamples
    im1_ripser_ones /= n_subsamples

    # Save the averaged persistence images (ramp weighting first, then ones)
    np.save(f"{filename_header}_PIR", {'Ip': [im0_ripser_ramp, im1_ripser_ramp]})
    np.save(f"{filename_header}_PIO", {'Ip': [im0_ripser_ones, im1_ripser_ones]})

    
print(f"Computing VR filtration for {dataset}")
# The context manager terminates the workers even if a task raises. With a bare
# pool.close() after map(), an exception left the worker processes running.
with mp.Pool(mp.cpu_count()) as pool:
    results = pool.map(VR_filtration, nums)
    