Import the necessary packages

In [1]:
# using Pkg
# Pkg.add("Optimisers")

In [2]:
using Pkg
Pkg.activate("libs/")
using Lux, Random, NNlib, Zygote, LuxCUDA, CUDA, FluxMPI, JLD2, DICOM
using Images
using MLUtils
using Optimisers

[32m[1m  Activating[22m[39m project at `~/Desktop/Project BAC/BAC project/libs`


└ @ FluxMPI /home/molloi-lab/.julia/packages/FluxMPI/OM5f6/src/FluxMPI.jl:28


In [3]:
patch_size = 256
patch_size_half = round(Int, patch_size/2);

# 1. Prepare Data

## 1.1 Helper functions

In [4]:
"""
    This function zoom all pixel values into [0, 1].
"""
function zoom_pxiel_values(img)
    a, b = minimum(img), maximum(img)
    if b-a != 0
        img = (img .- a) / (b - a)
    end
    return img
end

"""
    This function takes in a img of various size, 
    returns patches with size = patch_size * patch_size.
"""
function patch_image(img, lbl)
    s = size(img)
    x = ceil(Int, s[1]/patch_size) + floor(Int, (s[1]-patch_size_half)/patch_size)
    y = ceil(Int, s[2]/patch_size) + floor(Int, (s[2]-patch_size_half)/patch_size)
    num_patches = x*y
    img_patches = Array{Float32, 4}(undef, patch_size, patch_size, 1, num_patches)
    lbl_patches = Array{Float32, 4}(undef, patch_size, patch_size, 1, num_patches)
    ct = 0
    for i = 1 : x-1
        x_start = 1+(i-1)*patch_size_half
        x_end = x_start+patch_size-1
        for j = 1 : y-1
            y_start = 1+(j-1)*patch_size_half
            y_end = y_start+patch_size-1
            # save patch
            ct += 1
            img_patches[:, :, 1, ct] = zoom_pxiel_values(img[x_start:x_end, y_start:y_end])
            lbl_patches[:, :, 1, ct] = lbl[x_start:x_end, y_start:y_end]
        end
        # right col
        y_start, y_end = s[2]-patch_size+1, s[2]
        # save patch
        ct += 1
        img_patches[:, :, 1, ct] = zoom_pxiel_values(img[x_start:x_end, y_start:y_end])
        lbl_patches[:, :, 1, ct] = lbl[x_start:x_end, y_start:y_end]
    end
    # last row
    x_start, x_end = s[1]-patch_size+1, s[1]
    for j = 1 : y-1
        y_start = 1+(j-1)*patch_size_half
        y_end = y_start+patch_size-1
        # save patch
        ct += 1
        img_patches[:, :, 1, ct] = zoom_pxiel_values(img[x_start:x_end, y_start:y_end])
        lbl_patches[:, :, 1, ct] = lbl[x_start:x_end, y_start:y_end]
    end
    # right col
    y_start, y_end = s[2]-patch_size+1, s[2]
    # save patch
    ct += 1
    img_patches[:, :, 1, ct] = zoom_pxiel_values(img[x_start:x_end, y_start:y_end])
    lbl_patches[:, :, 1, ct] = lbl[x_start:x_end, y_start:y_end]
    # return
    return num_patches, img_patches, lbl_patches
end

"""
    This function fixs the path to the images and labels.
"""
function fix_path!(data_set)
    num_data = size(data_set)[1]
    Threads.@threads for i = 1 : num_data
        for j = 1 : 2
            for k = 1 : 4
                # modify img path
                splited = split(deepcopy(data_set[i][j][k]), "\\")
                if size(splited)[1] > 1
                    new_path = joinpath("../collected_dataset_for_ML", joinpath(splited[4:end]))
                    data_set[i][j][k] = new_path
                end
            end
        end
    end
end

"""
    This function check how many number of images and labels there will be after patching.
"""
function get_num_of_imgs(data_set)
    num_data = size(data_set)[1]
    cts = Array{Int}(undef, num_data*4)
    Threads.@threads for i = 1 : num_data
        @views t = train_set[i]
        for j = 1 : 4
            # read dicom images
            s = size(dcm_parse(t[1][j])[(0x7fe0, 0x0010)])
            x = ceil(Int, s[1]/patch_size) + floor(Int, (s[1]-patch_size_half)/patch_size)
            y = ceil(Int, s[2]/patch_size) + floor(Int, (s[2]-patch_size_half)/patch_size)
            # save 
            cts[(i-1)*4+j] = x*y
        end
    end
    return cts
end

get_num_of_imgs

## 1.2 Prepare Data

In [5]:
@load "clean_set_step2_for_ubuntu.jld2" train_set valid_set

data_dir = "../collected_dataset_for_ML";

### 1.2.1 Load train set & valid set
container format: patch_size * patch_size * 1 * num_imgs

In [6]:
# get num of total patches(train)
ct_patches_train = get_num_of_imgs(train_set)
num_patches_train = sum(ct_patches_train)

703276

In [7]:
# runtime: 50s
num_train_data = size(train_set)[1]
train_container_images = Array{Float16, 4}(undef, patch_size, patch_size, 1, num_patches_train)
train_container_masks = Array{Float16, 4}(undef, patch_size, patch_size, 1, num_patches_train)
Threads.@threads for i = 1 : num_train_data
    start_idx = sum(ct_patches_train[1:i-1])+1
    for j = 1 : 4 # 4 images each patient
        # read dicom images
        img = Float16.(dcm_parse(train_set[i][1][j])[(0x7fe0, 0x0010)])
        # read png images
        lbl = Float16.(Images.load(train_set[i][2][j]))
        # process image
        num_patches, img_patches, lbl_patches = patch_image(img, lbl)
        # save 
        end_idx = start_idx+num_patches-1
        train_container_images[:, :, 1, start_idx : end_idx] = img_patches
        train_container_masks[:, :, 1, start_idx : end_idx] = lbl_patches
        start_idx = end_idx
    end
end

In [8]:
# # get num of total patches(valid)
# ct_patches_valid = get_num_of_imgs(valid_set)
# num_patches_valid = sum(ct_patches_valid)

In [9]:
# # runtime: 7.5s
# num_valid_data = size(valid_set)[1]
# valid_container_images = Array{Float16, 4}(undef, patch_size, patch_size, 1, num_patches_valid)
# valid_container_masks = Array{Float16, 4}(undef, patch_size, patch_size, 1, num_patches_valid)
# Threads.@threads for i = 1 : num_valid_data
#     start_idx = sum(ct_patches_valid[1:i-1])+1
#     for j = 1 : 4 # 4 images each patient
#         # read dicom images
#         img = Float16.(dcm_parse(valid_set[i][1][j])[(0x7fe0, 0x0010)])
#         # read png images
#         lbl = Float16.(Images.load(valid_set[i][2][j]))
#         # process image
#         num_patches, img_patches, lbl_patches = patch_image(img, lbl)
#         # save 
#         end_idx = start_idx+num_patches-1
#         valid_container_images[:, :, 1, start_idx : end_idx] = img_patches
#         valid_container_masks[:, :, 1, start_idx : end_idx] = lbl_patches
#         start_idx = end_idx
#     end
# end

### 1.2.2 Create dataloaders

In [10]:
b_s = 20
# test_loader = MLUtils.DataLoader((data=valid_container_images, label=valid_container_masks), batchsize=b_s)
train_loader = MLUtils.DataLoader((data=train_container_images, label=train_container_masks), batchsize=b_s);

### 1.2.3 Create Model

In [11]:
_conv = (in, out) -> Conv((3, 3), in=>out, pad=SamePad())

conv1 = (in, out) -> Chain(_conv(in, out), BatchNorm(out, leakyrelu))
conv2 = (in, out) -> Chain(_conv(in, out), x -> softmax(x; dims = 3))

_tran = (in, out) -> ConvTranspose((2, 2), in => out, stride = 2)
tran = (in, out) -> Chain(_tran(in, out), BatchNorm(out, leakyrelu))

my_cat = (x, y) -> cat(x, y; dims=Val(3))

function unet2D(in_chs, lbl_chs)    
    # Contracting layers
    l1 = Chain(conv1(in_chs, 64), conv1(64, 64))
    l2 = Chain(l1, MaxPool((2,2), stride=2), conv1(64, 128), conv1(128, 128))
    l3 = Chain(l2, MaxPool((2,2), stride=2), conv1(128, 256), conv1(256, 256))
    l4 = Chain(l3, MaxPool((2,2), stride=2), conv1(256, 512), conv1(512, 512))
    l5 = Chain(l4, MaxPool((2,2), stride=2), conv1(512, 1024), conv1(1024, 1024), tran(1024, 512))
    
    # Expanding layers
    l6 = Chain(Parallel(my_cat,l5,l4), conv1(512+512, 512), conv1(512, 512), tran(512, 256))
    l7 = Chain(Parallel(my_cat,l6,l3), conv1(256+256, 256), conv1(256, 256), tran(256, 128))
    l8 = Chain(Parallel(my_cat,l7,l2), conv1(128+128, 128), conv1(128, 128), tran(128, 64))
    l9 = Chain(Parallel(my_cat,l8,l1), conv1(64+64, 64), conv1(64, 64), conv2(64, lbl_chs))
end

unet2D (generic function with 1 method)

### 1.2.4 Create Loss

In [12]:
function dice_loss(ŷ, y; ϵ=1f-5)
    @inbounds loss_dice = 
        1f0 - (muladd(2f0, sum(ŷ[:,:,2,:] .* y[:,:,1,:]), ϵ) / (sum(ŷ[:,:,2,:] .^ 2) + sum(y[:,:,1,:] .^ 2) + ϵ))
    return loss_dice
end
lossfn = dice_loss

dice_loss (generic function with 1 method)

# 2. Train

## 2.1 Setup

In [11]:
FluxMPI.Init()
CUDA.allowscalar(false)
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)

TaskLocalRNG()

In [14]:
# device = gpu_device()
model = unet2D(1, 2)

ps, st = Lux.setup(rng, model) .|> gpu
ps = FluxMPI.synchronize!(ps; root_rank = 0)
st = FluxMPI.synchronize!(st; root_rank = 0)

└ @ Lux /home/molloi-lab/.julia/packages/Lux/5YzHA/src/deprecated.jl:28


(layer_1 = (layer_1 = (layer_1 = (layer_1 = (layer_1 = (layer_1 = (layer_1 = (layer_1 = (layer_1 = NamedTuple(), layer_2 = (running_mean = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], running_var = Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], training = Val{true}()), layer_3 = NamedTuple(), layer_4 = (running_mean = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], running_var = Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], training = Val{true}()), layer_5 = NamedTuple(), layer_6 = NamedTuple(), layer_7 = (running_mean = Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], running_var = Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0,

In [15]:
opt = DistributedOptimizer(AdaGrad(0.01))
st_opt = Optimisers.setup(opt, ps)

st_opt = FluxMPI.synchronize!(st_opt; root_rank = 0)

(layer_1 = (layer_1 = (layer_1 = (layer_1 = (layer_1 = (layer_1 = (layer_1 = (layer_1 = (layer_1 = (weight = [32mLeaf(DistributedOptimizer{AdaGrad{Float64}}(AdaGrad{Float64}(0.01, 2.22045e-16)), [39m[2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16;;;; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16;;;; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16;;;; … ;;;; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16;;;; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16;;;; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16; 2.22045f-16 2.22045f-16 2.22045f-16][32m)[39m, bias = [32mLeaf(DistributedOptimizer{AdaGrad{Float64}}(AdaGrad{Float64}(0.01, 2.22045e-16)), [39

## 2.2 Train

In [16]:
for epoch in 1:1
	for (_x, _y) in train_loader
		global ps, st_opt
		x, y = Float32.(_x) |> gpu, Float32.(_y) |> gpu
		lossfn2(p) = dice_loss(model(x, p, st)[1], y)
		l, back = Zygote.pullback(lossfn2, ps)
		FluxMPI.fluxmpi_println("\tLoss $l")
		gs = back(one(l))[1]
		st_opt, ps = Optimisers.update(st_opt, ps, gs)
	end
end

OutOfGPUMemoryError: Out of GPU memory trying to allocate 160.000 MiB
Effective GPU memory usage: 99.78% (22.078 GiB/22.126 GiB)
Memory pool usage: 19.750 GiB (20.156 GiB reserved)
