In [1]:
import os
import time
import numpy as np
from grain_track import utility
from grain_track.inference_track_net import *
from skimage.measure import label

def grain_track_for_gt():
    """
    We evaluate the performance of different tracking methods on gt slices
    """
    cwd = os.getcwd()
    parameter_address = os.path.join(cwd, "grain_track", "parameter")
    cnn_device = "cuda:0"

    # Performance of tracking on real data set with different algorithms.
    print("For real data")
    data_address = os.path.join(cwd, "datasets", "grain_track", "net_test", "real")
    input_address_pred = os.path.join(data_address, "real_boundary")
    input_address_gt = os.path.join(data_address, "real_gt_label_stack.npy")
    label_stack_gt = np.load(input_address_gt)
    print("The number of grain in GT is {}".format(len(np.unique(label_stack_gt) - 1)))
    grain_track = GrainTrack(input_address_pred, reverse_label=False)

    # method = 1 min centroid dis
    print("Analyzing by min centroid dis")
    start_time = time.time()
    label_stack_pred, label_num_pred = grain_track.get_tracked_label_stack(method=1)
    end_time = time.time()
    print("The number of grain is {}".format(label_num_pred))
    r_index, adjust_r_index, v_index, merger_error, split_error = utility.validate_label_stack_by_rvi(label_stack_pred, label_stack_gt)
    print("The ri is {:.8f}, ari is {:.8f}, vi is {:.8f}, merger_error is {:.8f}, split_error is {:.8f}"
        .format(r_index, adjust_r_index, v_index, merger_error, split_error))
    print("The duriation of min centroid dis is {:.2f}'s".format(end_time - start_time))
    np.save(os.path.join(data_address, "real_gt_min_centroid_dis_label_stack.npy"), label_stack_pred)

    # method = 2 max overlap area
    print("Analyzing by max overlap area")
    start_time = time.time()
    label_stack_pred, label_num_pred = grain_track.get_tracked_label_stack(method=2)
    end_time = time.time()
    print("The number of grain is {}".format(label_num_pred))
    r_index, adjust_r_index, v_index, merger_error, split_error = utility.validate_label_stack_by_rvi(label_stack_pred, label_stack_gt)
    print("The ri is {:.8f}, ari is {:.8f}, vi is {:.8f}, merger_error is {:.8f}, split_error is {:.8f}"
        .format(r_index, adjust_r_index, v_index, merger_error, split_error))
    print("The duriation of max overlap area is {:.2f}'s".format(end_time - start_time))
    np.save(os.path.join(data_address, "real_gt_max_overlap_area_label_stack.npy"), label_stack_pred)

    # method = 3 cnn vgg13_bn
    print("Analyzing by vgg13_bn")
    start_time = time.time()
    grain_track.set_cnn_tracker(model=0, pretrain_address=os.path.join(parameter_address, "real_vgg13_bn.pkl"), device=cnn_device, need_augment=False, max_num_tensor=30)
    label_stack_pred, label_num_pred = grain_track.get_tracked_label_stack(method=3)
    end_time = time.time()
    print("The number of grain is {}".format(label_num_pred))
    r_index, adjust_r_index, v_index, merger_error, split_error = utility.validate_label_stack_by_rvi(label_stack_pred, label_stack_gt)
    print("The ri is {:.8f}, ari is {:.8f}, vi is {:.8f}, merger_error is {:.8f}, split_error is {:.8f}"
        .format(r_index, adjust_r_index, v_index, merger_error, split_error))
    print("The duriation of vgg13_bn is {:.2f}'s".format(end_time - start_time))
    np.save(os.path.join(data_address, "real_gt_vgg13_bn_label_stack.npy"), label_stack_pred)

    # method = 3 cnn densenet161
    print("Analyzing by densenet161")
    start_time = time.time()
    grain_track.set_cnn_tracker(model=1, pretrain_address=os.path.join(parameter_address, "real_densenet161.pkl"), device=cnn_device, need_augment=False, max_num_tensor=30)
    label_stack_pred, label_num_pred = grain_track.get_tracked_label_stack(method=3)
    end_time = time.time()
    print("The number of grain is {}".format(label_num_pred))
    r_index, adjust_r_index, v_index, merger_error, split_error = utility.validate_label_stack_by_rvi(label_stack_pred, label_stack_gt)
    print("The ri is {:.8f}, ari is {:.8f}, vi is {:.8f}, merger_error is {:.8f}, split_error is {:.8f}"
        .format(r_index, adjust_r_index, v_index, merger_error, split_error))
    print("The duriation of densenet161 is {:.2f}'s".format(end_time - start_time))
    np.save(os.path.join(data_address, "real_gt_densenet161_label_stack.npy"), label_stack_pred)

    # Performance of tracking on real data set with different algorithms.
    print("For simulated data")
    data_address = os.path.join(cwd, "datasets", "grain_track", "net_test", "simulated")
    input_address_pred = os.path.join(data_address, "simulated_boundary")
    input_address_gt = os.path.join(data_address, "simulated_gt_label_stack.npy")
    label_stack_gt = np.load(input_address_gt)
    print("The number of grain in GT is {}".format(len(np.unique(label_stack_gt) - 1)))
    grain_track = GrainTrack(input_address_pred, reverse_label=False)

    # method = 1 min centroid dis
    print("Analyzing by min centroid dis")
    start_time = time.time()
    label_stack_pred, label_num_pred = grain_track.get_tracked_label_stack(method=1)
    end_time = time.time()
    print("The number of grain is {}".format(label_num_pred))
    r_index, adjust_r_index, v_index, merger_error, split_error = utility.validate_label_stack_by_rvi(label_stack_pred, label_stack_gt)
    print("The ri is {:.8f}, ari is {:.8f}, vi is {:.8f}, merger_error is {:.8f}, split_error is {:.8f}"
        .format(r_index, adjust_r_index, v_index, merger_error, split_error))
    print("The duriation of min centroid dis is {:.2f}'s".format(end_time - start_time))
    np.save(os.path.join(data_address, "simulated_gt_min_centroid_dis_label_stack.npy"), label_stack_pred)

    # method = 2 max overlap area
    print("Analyzing by max overlap area")
    start_time = time.time()
    label_stack_pred, label_num_pred = grain_track.get_tracked_label_stack(method=2)
    end_time = time.time()
    print("The number of grain is {}".format(label_num_pred))
    r_index, adjust_r_index, v_index, merger_error, split_error = utility.validate_label_stack_by_rvi(label_stack_pred, label_stack_gt)
    print("The ri is {:.8f}, ari is {:.8f}, vi is {:.8f}, merger_error is {:.8f}, split_error is {:.8f}"
        .format(r_index, adjust_r_index, v_index, merger_error, split_error))
    print("The duriation of max overlap area is {:.2f}'s".format(end_time - start_time))
    np.save(os.path.join(data_address, "simulated_gt_max_overlap_area_label_stack.npy"), label_stack_pred)

    # method = 3 cnn vgg13_bn
    print("Analyzing by vgg13_bn")
    start_time = time.time()
    grain_track.set_cnn_tracker(model=0, pretrain_address=os.path.join(parameter_address, "simulated_vgg13_bn.pkl"), device=cnn_device, need_augment=False, max_num_tensor=30)
    label_stack_pred, label_num_pred = grain_track.get_tracked_label_stack(method=3)
    end_time = time.time()
    print("The number of grain is {}".format(label_num_pred))
    r_index, adjust_r_index, v_index, merger_error, split_error = utility.validate_label_stack_by_rvi(label_stack_pred, label_stack_gt)
    print("The ri is {:.8f}, ari is {:.8f}, vi is {:.8f}, merger_error is {:.8f}, split_error is {:.8f}"
        .format(r_index, adjust_r_index, v_index, merger_error, split_error))
    print("The duriation of vgg13_bn is {:.2f}'s".format(end_time - start_time))
    np.save(os.path.join(data_address, "simulated_gt_vgg13_bn_label_stack.npy"), label_stack_pred)

    # method = 3 cnn densenet161
    print("Analyzing by densenet161")
    start_time = time.time()
    grain_track.set_cnn_tracker(model=1, pretrain_address=os.path.join(parameter_address, "simulated_densenet161.pkl"), device=cnn_device, need_augment=False, max_num_tensor=30)
    label_stack_pred, label_num_pred = grain_track.get_tracked_label_stack(method=3)
    end_time = time.time()
    print("The number of grain is {}".format(label_num_pred))
    r_index, adjust_r_index, v_index, merger_error, split_error = utility.validate_label_stack_by_rvi(label_stack_pred, label_stack_gt)
    print("The ri is {:.8f}, ari is {:.8f}, vi is {:.8f}, merger_error is {:.8f}, split_error is {:.8f}"
        .format(r_index, adjust_r_index, v_index, merger_error, split_error))
    print("The duriation of densenet161 is {:.2f}'s".format(end_time - start_time))
    np.save(os.path.join(data_address, "simulated_gt_densenet161_label_stack.npy"), label_stack_pred)

def grain_track_for_real_pred():
    """
    We evaluate the performance of WPU-net using densenet161 on real test set(149 - 296). Because Unet-Bdelstm needs 
    3 slices as input and output 1 slice segmentation result. Thus, for fair comparison, we only track and analyse 150 - 295 slices.
    """
    cwd = os.getcwd()
    
    # Prepapre gt label stack
    data_address = os.path.join(cwd, "datasets", "grain_track", "net_test", "real")
    input_address_gt = os.path.join(data_address, "real_gt_label_stack.npy")
    label_stack_gt = np.load(input_address_gt)[:, :, 1: -1] # 150 - 295
    label_stack_gt, label_num_gt = label(label_stack_gt, return_num=True)
    print("The number of grain in GT is {} and the shape is {}".format(label_num_gt, label_stack_gt.shape))
    
    pretrain_address = os.path.join(cwd, "grain_track", "parameter", "real_densenet161.pkl") # we only test this tracking method
    cnn_device = "cuda:0"
    
    result_dir = os.path.join(cwd, "datasets", "segmentation", "result")
    methods_list = os.listdir(result_dir)
    for method_item in methods_list:
        print("For " + method_item)
        if method_item == "unet_bdclstm" or method_item == "att_unet"  or method_item == "unet" or method_item == ".ipynb_checkpoints":
            continue
        method_address = os.path.join(result_dir, method_item)
        grain_track = GrainTrack(method_address, reverse_label=False)
        grain_track.set_cnn_tracker(model=1, pretrain_address=pretrain_address, device=cnn_device, need_augment=False, max_num_tensor=30)
        label_stack_pred, label_num_pred = grain_track.get_tracked_label_stack(method=3)
        np.save(os.path.join(data_address, "real_" + method_item + "_densenet161_label_stack.npy"), label_stack_pred)
        print("The number of grain is {} and the shape is {}".format(label_num_pred, label_stack_gt.shape))
        r_index, adjust_r_index, v_index, merger_error, split_error = utility.validate_label_stack_by_rvi(label_stack_pred, label_stack_gt)
        print("The ri is {:.8f}, ari is {:.8f}, vi is {:.8f}, merger_error is {:.8f}, split_error is {:.8f}"
            .format(r_index, adjust_r_index, v_index, merger_error, split_error))     

if __name__ == "__main__":
#     grain_track_for_gt()
    grain_track_for_real_pred()

The number of grain in GT is 2224 and the shape is (1024, 1024, 146)
For unet_bdclstm
For att_unet
For unet
For .ipynb_checkpoints
For ffc


  nn.init.kaiming_normal(m.weight.data)


Start tracking
Tracking done
The number of grain is 2912 and the shape is (1024, 1024, 146)
The ri is 0.99928520, ari is 0.70498288, vi is 1.75165387, merger_error is 0.76841989, split_error is 0.98323397
For unet_3d
Start tracking
Tracking done
The number of grain is 3378 and the shape is (1024, 1024, 146)
The ri is 0.99936887, ari is 0.74054115, vi is 1.48643765, merger_error is 0.62845898, split_error is 0.85797867
For wpu_net_abw_mask_l1_5
Start tracking
Tracking done
The number of grain is 3165 and the shape is (1024, 1024, 146)
The ri is 0.99937266, ari is 0.75767839, vi is 1.47858783, merger_error is 0.70387330, split_error is 0.77471453
