In [1]:
# using Pkg
# Pkg.add(url = "https://github.com/MolloiLab/imageToolBox.jl")

In [2]:
using imageToolBox, Images, ImageMorphology, DICOM, Statistics, ImageView, JLD2

In [3]:
# Data locations.
# `label_dir`: per-SID label PNGs (and the DICOMs read by the patch-rendering cell).
label_dir = "/media/molloi-lab/2TB/BAC_processed_clean"
raw_data_dir = "/media/molloi-lab/2TB/Clean_Dataset_full"
# `pred_dir`: one sub-directory per SID containing "<name>.pred.<ext>" masks.
pred_dir = "/media/molloi-lab/1TB/Output"
output_dir = "/media/molloi-lab/2TB/BAC_diff_patch_for_paper"

# Every entry of `pred_dir` is processed as one SID.
SIDs = readdir(pred_dir)
s = length(SIDs)

5109

In [4]:
"""
    keep_largest_component(img; factor=10)

Return a mask that keeps only the largest connected foreground component of `img`.

Connected-component labelling is done on a copy of `img` shrunk by `factor` in each
dimension and rounded to 0/1, which is much cheaper than labelling at full
resolution. The selected component is resized back to `size(img)` and rounded again,
so the result is a coarse 0/1 mask with the same size as `img`.

# Arguments
- `img`: 2-D image; after shrinking, pixels that round to 1 are foreground
  (assumes intensities in [0, 1] — TODO confirm for all callers).

# Keywords
- `factor=10`: shrink factor applied before labelling. Component sizes are measured
  on the shrunken image.

# Returns
A 0/1 array of size `size(img)`. If no foreground survives the shrink, an all-zero
mask is returned instead of throwing.
"""
function keep_largest_component(img; factor=10)
    full_size = size(img)
    # Clamp to at least 1 pixel per dimension so tiny inputs do not shrink to nothing.
    small_size = max.(1, round.(Int, full_size ./ factor))
    small = round.(imresize(img, small_size))

    labels = label_components(small)

    # Pixel count per label in one pass over the image (label 0 is background).
    counts = zeros(Int, maximum(labels))
    for label in labels
        label > 0 && (counts[label] += 1)
    end

    # `argmax` keeps the first label on ties; with no components, keep nothing.
    keep = isempty(counts) ? falses(size(labels)) : labels .== argmax(counts)

    return round.(imresize(keep, full_size))
end

"""
    find_surface_indices(arr)

Locate the ring of pixels that `dilate` adds around the foreground of `arr`.

Computes `ring = dilate(arr) .- arr`. For a 0/1 mask this is 1 exactly on background
pixels neighbouring the foreground, as defined by `dilate`'s default structuring
element. Returns the tuple `(findall(isone, ring), ring)`.
"""
function find_surface_indices(arr)
    ring = dilate(arr) .- arr
    ring_idx = findall(isone, ring)
    return (ring_idx, ring)
end

"""
    mark_surfaces(img, mask, color)

Draw the outline of `mask` on a grayscale image and return the result as RGB.

The pixels reported by `find_surface_indices(mask)` form the outline. A `Float32`
buffer of size `(3, size(img)...)` gets the gray value of `img` in every channel.
Outline pixels are then overwritten with `color`, whose first three entries are used
as `(r, g, b)`. The buffer is returned through `colorview(RGB, ...)`.
"""
function mark_surfaces(img, mask, color)
    outline, _ = find_surface_indices(mask)
    red, green, blue = color

    dims = size(img)
    canvas = Array{Float32, 3}(undef, 3, dims...)

    # Replicate each gray value into all three colour channels.
    for col in 1:dims[2], row in 1:dims[1]
        gray = img[row, col]
        for ch in 1:3
            canvas[ch, row, col] = gray
        end
    end

    # Paint the outline pixels with the requested colour.
    for idx in outline
        i, j = idx[1], idx[2]
        canvas[1, i, j] = red
        canvas[2, i, j] = green
        canvas[3, i, j] = blue
    end

    return colorview(RGB, canvas)
end

mark_surfaces (generic function with 1 method)

In [5]:
"""
    locate_diff_and_save_patches(img, mask, pred, lbl; patch_size=256, cover_radius=128, threshold=3)

Collect side-by-side visualisation patches around places where `pred` and `lbl` differ.

Differing pixels are visited in `findall` (column-major) order; pixels already covered
by an earlier patch are skipped. For each remaining pixel, a
`patch_size × patch_size` window starting `patch_size ÷ 2` rows/columns before it is
considered, and windows that would cross the image border are skipped.

A window is kept when its false-positive or false-negative pixel count exceeds
`1/threshold` of either the predicted or the labelled area inside it. For a kept
window, the `img` crop is rendered twice: with the `pred` outline in blue and the
`lbl` outline in red (via `mark_surfaces`). The two renderings are concatenated
horizontally. The square of half-width `cover_radius` around the pixel is then marked
as covered, so nearby differences do not produce near-duplicate patches.

# Arguments
- `img`: 2-D grayscale image; its size defines the search area.
- `mask`: unused; kept so existing call sites keep working.
- `pred`, `lbl`: prediction and label masks indexed like `img`
  (assumed binary with 0 = background — TODO confirm).

# Keywords
- `patch_size=256`: side length of each extracted window (odd sizes are supported).
- `cover_radius=128`: half-width of the region marked as covered after a patch is kept.
- `threshold=3`: a window is kept when an error area exceeds `area / threshold`.

# Returns
A `Vector` of concatenated RGB patches, possibly empty. A patch whose rendering fails
is logged with `@warn` and skipped.
"""
function locate_diff_and_save_patches(img, mask, pred, lbl;
                                      patch_size = 256, cover_radius = 128, threshold = 3)
    h, w = size(img)
    half_patch = patch_size ÷ 2
    patches = []

    # Tracks regions already represented by a kept patch.
    covered = falses(h, w)

    # Compare directly instead of `abs.(pred - lbl)`: subtraction wraps for unsigned
    # element types (e.g. UInt8 / N0f8 masks loaded from PNG).
    diff_coords = findall(pred .!= lbl)

    for coord in diff_coords
        row, col = Tuple(coord)
        covered[row, col] && continue

        # Derive the far edges from `patch_size` (not `half_patch`) so odd sizes work.
        top = row - half_patch
        left = col - half_patch
        bottom = top + patch_size - 1
        right = left + patch_size - 1
        # Only full-size windows are used; windows crossing the border are skipped.
        (top >= 1 && left >= 1 && bottom <= h && right <= w) || continue

        pred_patch = pred[top:bottom, left:right]
        lbl_patch = lbl[top:bottom, left:right]

        # Error areas inside the window.
        fp_area = count((pred_patch .> 0) .& (lbl_patch .== 0))
        fn_area = count((pred_patch .== 0) .& (lbl_patch .> 0))

        # Total mask areas inside the window.
        pred_area = sum(pred_patch)
        lbl_area = sum(lbl_patch)

        large_error = fp_area > pred_area / threshold || fn_area > pred_area / threshold ||
                      fp_area > lbl_area / threshold || fn_area > lbl_area / threshold
        large_error || continue

        try
            img_patch = img[top:bottom, left:right]
            img_with_pred = mark_surfaces(img_patch, pred_patch, [0, 0, 1])  # blue
            img_with_lbl = mark_surfaces(img_patch, lbl_patch, [1, 0, 0])    # red
            push!(patches, cat(img_with_pred, img_with_lbl; dims = 2))
        catch e
            # Best effort: record the failure with its backtrace and keep scanning.
            @warn "Failed to render difference patch" row col exception = (e, catch_backtrace())
            continue
        end

        covered[max(1, row - cover_radius):min(h, row + cover_radius),
                max(1, col - cover_radius):min(w, col + cover_radius)] .= true
    end

    return patches
end

locate_diff_and_save_patches (generic function with 1 method)

In [6]:
# # fix wrong masks 
# @Threads.threads for i = 1 : s
# # for i = 1 : 3
#     sid = SIDs[i]
#     curr_pred_dir = joinpath(pred_dir, sid)
#     curr_img_dir = joinpath(raw_data_dir, sid)
#     for f in readdir(curr_pred_dir)
#         f_name = rsplit(f, '.'; limit = 3)[1] 
#         # setup path
#         curr_pred_path = joinpath(curr_pred_dir, f)
#         curr_img_path = joinpath(curr_img_dir, f_name * ".dcm")
#         new_pred_path = joinpath(curr_pred_dir, f_name * ".fixedpred.png")
#         # read 
#         curr_pred = Float32.(Images.load(curr_pred_path))
#         curr_img = Float32.(dcm_parse(curr_img_path)[(0x7fe0, 0x0010)])
#         # fix pred
#         mask = keep_largest_component(1 .- round.(zoom_pixel_values(curr_img)))
#         new_pred = curr_pred .* mask
#         # save
#         save(new_pred_path, new_pred)
#     end
# end

In [7]:
GC.gc(true)
# For every SID: compare each "<name>.pred.<ext>" mask in `pred_dir/<SID>` with its
# label PNG in `label_dir/<SID>`, and save the rendered disagreement patches to
# `output_dir/<SID>/<name>_<k>.png`. Each thread iteration handles one SID, so each
# iteration writes only into that SID's output directory.
@Threads.threads for i = 1:s
    if i % 100 == 0
        @info i
    end
    sid = SIDs[i]
    sid_pred_dir = joinpath(pred_dir, sid)
    for f in readdir(sid_pred_dir)
        parts = rsplit(f, '.'; limit = 3)
        # Only prediction files ("<name>.pred.<ext>"); names without a dot are skipped.
        (length(parts) >= 2 && parts[2] == "pred") || continue
        f_name = parts[1]

        lbl_path = joinpath(label_dir, sid, f_name * ".png")
        isfile(lbl_path) || continue

        pred_path = joinpath(sid_pred_dir, f)
        img_path = joinpath(label_dir, sid, f_name * ".dcm")

        img = zoom_pixel_values(dcm_parse(img_path)[(0x7fe0, 0x0010)])
        pred = read_png(pred_path)
        lbl = read_png(lbl_path)

        patches = locate_diff_and_save_patches(img, nothing, pred, lbl)
        if !isempty(patches)
            curr_output_dir = joinpath(output_dir, sid)
            mkpath(curr_output_dir)
            # Prefix with the image name: a SID can hold several images, and numbering
            # from 1 per image would overwrite the previous image's patches.
            for (k, patch) in enumerate(patches)
                save(joinpath(curr_output_dir, f_name * "_" * string(k) * ".png"), patch)
            end
        end
        GC.gc(true)
    end
end

In [None]:
# Pixel-level confusion counts, one row per SID; columns are [TN, TP, FP, FN].
pixel_data = zeros(Int, s, 4)

@Threads.threads for i = 1:s
    if i % 100 == 0
        @info i
    end
    sid = SIDs[i]
    sid_pred_dir = joinpath(pred_dir, sid)
    for f in readdir(sid_pred_dir)
        parts = rsplit(f, '.'; limit = 3)
        # Only prediction files ("<name>.pred.<ext>"); names without a dot are skipped.
        (length(parts) >= 2 && parts[2] == "pred") || continue
        f_name = parts[1]

        # NOTE(review): the patch-rendering cell reads labels from `label_dir`, while this
        # cell reads them from `raw_data_dir`. Confirm which holds the ground-truth PNGs.
        lbl_path = joinpath(raw_data_dir, sid, f_name * ".png")
        if !isfile(lbl_path)
            @warn "Missing label PNG, skipping" lbl_path
            continue
        end

        # Binarize once so TP/FP/FN/TN use the same definition of "positive"
        # (identical to the previous counts for 0/1 masks).
        pred_mask = .!iszero.(read_png(joinpath(sid_pred_dir, f)))
        lbl_mask = .!iszero.(read_png(lbl_path))

        tp = count(pred_mask .& lbl_mask)
        fp = count(pred_mask .& .!lbl_mask)
        fn = count(.!pred_mask .& lbl_mask)
        tn = length(pred_mask) - tp - fp - fn

        # Each thread writes only row `i`, so no locking is needed.
        pixel_data[i, 1] += tn
        pixel_data[i, 2] += tp
        pixel_data[i, 3] += fp
        pixel_data[i, 4] += fn

        GC.gc(true)
    end
end

# Overall (micro-averaged) precision, recall, and F1 from the summed counts.
total_TN = sum(pixel_data[:, 1])
total_TP = sum(pixel_data[:, 2])
total_FP = sum(pixel_data[:, 3])
total_FN = sum(pixel_data[:, 4])

# Not named `precision`: assigning that name in Main shadows `Base.precision` and
# errors if `Base.precision` has already been used in the session.
precision_score = total_TP / (total_TP + total_FP)
recall = total_TP / (total_TP + total_FN)
f1_score = 2 * (precision_score * recall) / (precision_score + recall)

println("Precision: ", precision_score)
println("Recall: ", recall)
println("F1-score: ", f1_score)

In [None]:
total_TN  # display: true-negative pixels summed over all SIDs (column 1 of pixel_data)

In [None]:
total_TP  # display: true-positive pixels summed over all SIDs (column 2 of pixel_data)

In [None]:
total_FP  # display: false-positive pixels summed over all SIDs (column 3 of pixel_data)

In [None]:
total_FN  # display: false-negative pixels summed over all SIDs (column 4 of pixel_data)

- #64:
    - total_TN = 201470811526
    - total_TP = 46735536
    - total_FP = 55212524
    - total_FN = 34972708
    - Precision: 0.4584249665957351  
    - Recall: 0.5719806681930406  
    - F1-score: 0.5089456226887806  

In [None]:
@save "pixel_data2.jld2" pixel_data  # persist the per-SID [TN, TP, FP, FN] count matrix

In [None]:
# all_patches[tmp][1]

In [None]:
# all_patches[tmp][2]

In [None]:
# all_patches[tmp][3]

In [None]:
# @save "saved_patches_for_paper_600.jld2" all_patches