In [1]:
using Pkg
Pkg.activate("libs/")
# Pkg.instantiate()
# Pkg.add("MLUtils")
using CSV
using JLD2
using CUDA
using Glob
using Dates
# using Zygote
using DICOM
using Images
using MLUtils
using Setfield
using ImageView
using ImageDraw
using Statistics
using DataFrames
using StaticArrays
using MLDataPattern
using ChainRulesCore
using Distributions: Normal
using FastAI, FastVision, Flux, Metalhead
import CairoMakie; CairoMakie.activate!(type="png")

[32m[1m  Activating[22m[39m project at `~/Desktop/Project BAC/BAC project/libs`


/snap/core20/current/lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found (required by /lib/x86_64-linux-gnu/libproxy.so.1)
Failed to load module: /home/molloi-lab/snap/code/common/.cache/gio-modules/libgiolibproxy.so
/snap/core20/current/lib/x86_64-linux-gnu/libstdc++.so.6: version `GLIBCXX_3.4.29' not found (required by /lib/x86_64-linux-gnu/libproxy.so.1)
Failed to load module: /home/molloi-lab/snap/code/common/.cache/gio-modules/libgiolibproxy.so


Check how many threads we have

In [2]:
Threads.nthreads()

64

List all CUDA devices

In [3]:
CUDA.allowscalar(false)
CUDA.devices()

CUDA.DeviceIterator() for 4 devices:
0. NVIDIA GeForce RTX 4090
1. NVIDIA GeForce RTX 4090
2. NVIDIA GeForce RTX 4090
3. NVIDIA GeForce RTX 4090

# Notes

1. In this training, images' pixel values are zoomed into range = [0, 1]. This should not be applied to final plan since different images have different min and max.

2. In this training, only `pixel values` are being used. In future, we can feed more features to the NN to get a better information like `L_mass` or `R_mass`.

In [4]:
patch_size = 256
patch_size_half = round(Int, patch_size/2);

# Helper functions

In [5]:
"""
    This function zoom all pixel values into [0, 1].
"""
function zoom_pxiel_values(img)
    a, b = minimum(img), maximum(img)
    if b-a != 0
        img = (img .- a) / (b - a)
    end
    return img
end

zoom_pxiel_values

In [6]:
"""
    This function takes in a img of various size, 
    returns patches with size = patch_size * patch_size.
"""
function patch_image(img, lbl)
    s = size(img)
    x = ceil(Int, s[1]/patch_size) + floor(Int, (s[1]-patch_size_half)/patch_size)
    y = ceil(Int, s[2]/patch_size) + floor(Int, (s[2]-patch_size_half)/patch_size)
    num_patches = x*y
    img_patches = Array{Float32, 4}(undef, patch_size, patch_size, 1, num_patches)
    lbl_patches = Array{Float32, 4}(undef, patch_size, patch_size, 1, num_patches)
    ct = 0
    for i = 1 : x-1
        x_start = 1+(i-1)*patch_size_half
        x_end = x_start+patch_size-1
        for j = 1 : y-1
            y_start = 1+(j-1)*patch_size_half
            y_end = y_start+patch_size-1
            # save patch
            ct += 1
            img_patches[:, :, 1, ct] = zoom_pxiel_values(img[x_start:x_end, y_start:y_end])
            lbl_patches[:, :, 1, ct] = lbl[x_start:x_end, y_start:y_end]
        end
        # right col
        y_start, y_end = s[2]-patch_size+1, s[2]
        # save patch
        ct += 1
        img_patches[:, :, 1, ct] = zoom_pxiel_values(img[x_start:x_end, y_start:y_end])
        lbl_patches[:, :, 1, ct] = lbl[x_start:x_end, y_start:y_end]
    end
    # last row
    x_start, x_end = s[1]-patch_size+1, s[1]
    for j = 1 : y-1
        y_start = 1+(j-1)*patch_size_half
        y_end = y_start+patch_size-1
        # save patch
        ct += 1
        img_patches[:, :, 1, ct] = zoom_pxiel_values(img[x_start:x_end, y_start:y_end])
        lbl_patches[:, :, 1, ct] = lbl[x_start:x_end, y_start:y_end]
    end
    # right col
    y_start, y_end = s[2]-patch_size+1, s[2]
    # save patch
    ct += 1
    img_patches[:, :, 1, ct] = zoom_pxiel_values(img[x_start:x_end, y_start:y_end])
    lbl_patches[:, :, 1, ct] = lbl[x_start:x_end, y_start:y_end]
    # return
    return num_patches, img_patches, lbl_patches
end

patch_image

In [7]:
"""
    This function fixs the path to the images and labels.
"""
function fix_path!(data_set)
    num_data = size(data_set)[1]
    Threads.@threads for i = 1 : num_data
        for j = 1 : 2
            for k = 1 : 4
                # modify img path
                splited = split(deepcopy(data_set[i][j][k]), "\\")
                if size(splited)[1] > 1
                    new_path = joinpath("../collected_dataset_for_ML", joinpath(splited[4:end]))
                    data_set[i][j][k] = new_path
                end
            end
        end
    end
end

fix_path!

In [8]:
"""
    This function check how many number of images and labels there will be after patching.
"""
function get_num_of_imgs(data_set)
    num_data = size(data_set)[1]
    cts = Array{Int}(undef, num_data*4)
    Threads.@threads for i = 1 : num_data
        @views t = train_set[i]
        for j = 1 : 4
            # read dicom images
            s = size(dcm_parse(t[1][j])[(0x7fe0, 0x0010)])
            x = ceil(Int, s[1]/patch_size) + floor(Int, (s[1]-patch_size_half)/patch_size)
            y = ceil(Int, s[2]/patch_size) + floor(Int, (s[2]-patch_size_half)/patch_size)
            # save 
            cts[(i-1)*4+j] = x*y
        end
    end
    return cts
end

get_num_of_imgs

# 1. Prepare

In [9]:
@load "clean_set_step2_for_ubuntu.jld2" train_set valid_set

2-element Vector{Symbol}:
 :train_set
 :valid_set

In [10]:
data_dir = "../collected_dataset_for_ML";

## 1.1 Load train set & valid set
container format: patch_size * patch_size * 1 * num_imgs

In [12]:
# get num of total patches(train)
ct_patches_train = get_num_of_imgs(train_set)
num_patches_train = sum(ct_patches_train)

703276

In [13]:
# runtime: 50s
num_train_data = size(train_set)[1]
train_container_images = Array{Float16, 4}(undef, patch_size, patch_size, 1, num_patches_train)
train_container_masks = Array{Float16, 4}(undef, patch_size, patch_size, 1, num_patches_train)
Threads.@threads for i = 1 : num_train_data
    start_idx = sum(ct_patches_train[1:i-1])+1
    for j = 1 : 4 # 4 images each patient
        # read dicom images
        img = Float16.(dcm_parse(train_set[i][1][j])[(0x7fe0, 0x0010)])
        # read png images
        lbl = Float16.(Images.load(train_set[i][2][j]))
        # process image
        num_patches, img_patches, lbl_patches = patch_image(img, lbl)
        # save 
        end_idx = start_idx+num_patches-1
        train_container_images[:, :, 1, start_idx : end_idx] = img_patches
        train_container_masks[:, :, 1, start_idx : end_idx] = lbl_patches
        start_idx = end_idx
    end
end

In [14]:
# get num of total patches(valid)
ct_patches_valid = get_num_of_imgs(valid_set)
num_patches_valid = sum(ct_patches_valid)

113564

In [15]:
# runtime: 7.5s
num_valid_data = size(valid_set)[1]
valid_container_images = Array{Float16, 4}(undef, patch_size, patch_size, 1, num_patches_valid)
valid_container_masks = Array{Float16, 4}(undef, patch_size, patch_size, 1, num_patches_valid)
Threads.@threads for i = 1 : num_valid_data
    start_idx = sum(ct_patches_valid[1:i-1])+1
    for j = 1 : 4 # 4 images each patient
        # read dicom images
        img = Float16.(dcm_parse(valid_set[i][1][j])[(0x7fe0, 0x0010)])
        # read png images
        lbl = Float16.(Images.load(valid_set[i][2][j]))
        # process image
        num_patches, img_patches, lbl_patches = patch_image(img, lbl)
        # save 
        end_idx = start_idx+num_patches-1
        valid_container_images[:, :, 1, start_idx : end_idx] = img_patches
        valid_container_masks[:, :, 1, start_idx : end_idx] = lbl_patches
        start_idx = end_idx
    end
end

In [16]:
GC.gc(true)

## 1.2 Create dataloaders

In [17]:
b_s = 12
train_loader = MLUtils.DataLoader((data=train_container_images, label=train_container_masks), batchsize=b_s)
test_loader = MLUtils.DataLoader((data=valid_container_images, label=valid_container_masks), batchsize=b_s);

## 1.3 Create Model

In [18]:
function _random_normal(shape...)
    return Float32.(rand(Normal(0.0,0.02),shape...))
end

_conv = (stride, in, out) -> Conv((3, 3), in=>out, stride=stride, pad=SamePad();init=_random_normal)
_tran = (stride, in, out) -> ConvTranspose((2, 2), in=>out, stride=stride, pad=SamePad();init=_random_normal)
# _conv = (stride, in, out) -> Conv((3, 3), in=>out, stride=stride, pad=SamePad())
# _tran = (stride, in, out) -> ConvTranspose((2, 2), in=>out, stride=stride, pad=SamePad())

conv1 = (in, out) -> Chain(_conv(1, in, out), BatchNorm(out, leakyrelu))
conv2 = (in, out) -> Chain(_conv(2, in, out), BatchNorm(out, leakyrelu))
conv3 = (in, out) -> Chain(_conv(1, in, out), x -> softmax(x; dims = 3))
# conv3 = (in, out) -> Chain(_conv(1, in, out), sigmoid)
tran2 = (in, out) -> Chain(_tran(2, in, out), BatchNorm(out, leakyrelu))



function unet2D(in_chs, lbl_chs)
    # Contracting layers
    l1 = Chain(conv1(in_chs, 64), conv1(64, 64))
    l2 = Chain(l1, MaxPool((2,2), stride=2), conv1(64, 128), conv1(128, 128))
    l3 = Chain(l2, MaxPool((2,2), stride=2), conv1(128, 256), conv1(256, 256))
    l4 = Chain(l3, MaxPool((2,2), stride=2), conv1(256, 512), conv1(512, 512))
    l5 = Chain(l4, MaxPool((2,2), stride=2), conv1(512, 1024), conv1(1024, 1024), tran2(1024, 512))

    # Expanding layers
    l6 = Chain(Parallel(FastVision.Models.catchannels,l5,l4), 
                conv1(512+512, 512),
                conv1(512, 512),
                tran2(512, 256))
    l7 = Chain(Parallel(FastVision.Models.catchannels,l6,l3), 
                conv1(256+256, 256),
                conv1(256, 256),
                tran2(256, 128))
    l8 = Chain(Parallel(FastVision.Models.catchannels,l7,l2), 
                conv1(128+128, 128),
                conv1(128, 128),
                tran2(128, 64))
    l9 = Chain(Parallel(FastVision.Models.catchannels,l8,l1), 
                conv1(64+64, 64),
                conv1(64, 64),
                conv3(64, lbl_chs))
end

unet2D (generic function with 1 method)

## 1.4 Create Loss

In [19]:
function dice_loss(ŷ, y; ϵ=1f-5)
    @inbounds loss_dice = 
        1f0 - (muladd(2f0, sum(ŷ[:,:,2,:] .* y[:,:,1,:]), ϵ) / (sum(ŷ[:,:,2,:] .^ 2) + sum(y[:,:,1,:] .^ 2) + ϵ))
    return loss_dice
end
lossfn = dice_loss

dice_loss (generic function with 1 method)

## 1.5 Loop

In [20]:
function train_1_epoch!(epoch_idx, model, model_ps, train_dl, optimizer)
	# Epoch start
	losses = Float32[]
	step_ct = 0
	for (x, y) in train_dl
		x_gpu, y_gpu = x |> gpu, Float32.(y) |> gpu
		ls, gs = Flux.withgradient(model) do m
			lossfn(m(x_gpu), y_gpu)
		end
		push!(losses, ls)
	  	Flux.update!(optimizer, model, gs[1])
	  	# Step finished
		step_ct += 1
		if step_ct % 1000 == 0 
			@info "step $step_ct\tloss = $(mean(losses))"
			losses = Float32[]
		end
	end
end

train_1_epoch! (generic function with 1 method)

# 2. Train

[:,:,1,:] --> background  
[:,:,2,:] --> foreground

In [21]:
# @load "a_good_model.jld2" model_0 optimizer

In [22]:
# lossfn = dice_loss
model = unet2D(1, 2);

## 2.1 Debug

In [37]:
model_0_ps = Flux.params(model_0)

Params([[0.013864943 0.0024395904 -0.009355192; 0.00614296 -0.0034171205 -0.022988057; -0.041933335 0.028599536 0.0076876143;;;; -0.024341874 0.0089192055 0.024133267; -0.020765215 0.027129073 0.001777542; -0.023662323 -0.002478159 -0.006025589;;;; -0.03300408 -0.017547643 -0.0020365363; -0.012783114 0.017621964 0.026324354; 0.0055204765 -0.04018023 -0.00021446738;;;; … ;;;; -0.02251454 -0.0189265 -0.037282888; -0.02168737 0.012797994 0.02470656; -0.007238393 0.0034731731 -0.0047700927;;;; -0.0018573465 -0.004995119 -0.031599756; -0.038740292 -0.0030662715 0.0036857587; 0.0034231427 -0.061049543 -0.03760255;;;; -0.005144453 -0.0068676565 0.02899018; -0.007976617 -0.0007312118 -0.00038405063; -0.032121714 0.04017705 -0.029176876], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0

In [34]:
for (x, y) in train_loader
    x_gpu, y_gpu = x |> gpu, Float32.(y) |> gpu
    global ls, gs2 = Flux.withgradient(model_0) do m
        lossfn(m(x_gpu), y_gpu)+
    end
    break
end

In [38]:
model_0_ps[1]

3×3×1×64 CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}:
[:, :, 1, 1] =
  0.0138649    0.00243959  -0.00935519
  0.00614296  -0.00341712  -0.0229881
 -0.0419333    0.0285995    0.00768761

[:, :, 1, 2] =
 -0.0243419   0.00891921   0.0241333
 -0.0207652   0.0271291    0.00177754
 -0.0236623  -0.00247816  -0.00602559

[:, :, 1, 3] =
 -0.0330041   -0.0175476  -0.00203654
 -0.0127831    0.017622    0.0263244
  0.00552048  -0.0401802  -0.000214467

;;;; … 

[:, :, 1, 62] =
 -0.0225145   -0.0189265   -0.0372829
 -0.0216874    0.012798     0.0247066
 -0.00723839   0.00347317  -0.00477009

[:, :, 1, 63] =
 -0.00185735  -0.00499512  -0.0315998
 -0.0387403   -0.00306627   0.00368576
  0.00342314  -0.0610495   -0.0376026

[:, :, 1, 64] =
 -0.00514445  -0.00686766    0.0289902
 -0.00797662  -0.000731212  -0.000384051
 -0.0321217    0.0401771    -0.0291769

In [41]:
gs[1]

(layers = ((connection = nothing, layers = ((layers = ((connection = nothing, layers = ((layers = ((connection = nothing, layers = ((layers = ((connection = nothing, layers = ((layers = ((layers = ((layers = ((layers = ((layers = ((layers = ((σ = nothing, weight = [-1.9798158f-13 -3.6355454f-14 -6.4770793f-14; -1.7622812f-13 8.466041f-14 -3.4269137f-13; -2.8535957f-13 -6.8088173f-13 -3.0826058f-13;;;; 9.769796f-13 1.4254525f-12 1.4909893f-12; 9.489614f-13 1.454653f-12 1.0834225f-12; 1.1564234f-12 1.8277652f-12 1.4990748f-12;;;; -2.2272116f-14 1.7733951f-13 4.9710985f-14; 1.7647294f-13 3.1415986f-13 2.9869445f-13; 1.9737804f-13 2.607959f-13 2.694553f-13;;;; … ;;;; 5.9622304f-14 -4.7250232f-14 -4.1249807f-14; 8.3771416f-14 1.2661972f-13 -1.4099948f-13; -1.9625404f-13 -2.4894147f-13 -3.5188971f-13;;;; -2.149491f-14 -3.0331406f-14 -1.6826437f-14; 1.3329559f-14 -2.2255105f-14 7.8165034f-14; 2.987367f-14 -8.254871f-15 2.9794455f-14;;;; 2.4781238f-13 -3.372717f-14 -4.5250552f-13; -9.455533f-1

In [42]:
typeof(gs[1])

NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:connection, :layers), Tuple{Nothing, Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:connection, :layers), Tuple{Nothing, Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:connection, :layers), Tuple{Nothing, Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:connection, :layers), Tuple{Nothing, Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:layers,), Tuple{Tuple{NamedTuple{(:σ, :weight, :bias, :stride, :pad, :dilation, :groups), Tuple{Nothing, CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vararg{Nothing, 4}}}, NamedTuple{(:λ, :β, :γ, :μ, :σ², :ϵ, :momentum, :affine, :track_stats, :active, :chs), Tuple{Nothing, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Vararg{Nothing, 8}}}}}}, NamedTuple{(:layer

## 2.2 Actual train

In [23]:
model_0 = model |> gpu;

In [24]:
# optimizer = AdaGrad(0.01)
opt_state = Flux.setup(AdaGrad(0.01), model_0);

In [25]:
# device!(0)
# model = model_0 |> cpu
# @save "a_good_model2.jld2" model optimizer

In [26]:
model_0_ps = Flux.params(model_0)
train_1_epoch!(1, model_0, model_0_ps, train_loader, opt_state)

┌ Info: step 1000	loss = 0.99959815
└ @ Main /home/molloi-lab/Desktop/Project BAC/BAC project/4_train_single_gpu.ipynb:15


┌ Info: step 2000	loss = 0.99896157
└ @ Main /home/molloi-lab/Desktop/Project BAC/BAC project/4_train_single_gpu.ipynb:15


┌ Info: step 3000	loss = 0.9888568
└ @ Main /home/molloi-lab/Desktop/Project BAC/BAC project/4_train_single_gpu.ipynb:15


┌ Info: step 4000	loss = 0.96931994
└ @ Main /home/molloi-lab/Desktop/Project BAC/BAC project/4_train_single_gpu.ipynb:15


┌ Info: step 5000	loss = 0.9639192
└ @ Main /home/molloi-lab/Desktop/Project BAC/BAC project/4_train_single_gpu.ipynb:15


┌ Info: step 6000	loss = 0.9470065
└ @ Main /home/molloi-lab/Desktop/Project BAC/BAC project/4_train_single_gpu.ipynb:15


┌ Info: step 7000	loss = 0.9322687
└ @ Main /home/molloi-lab/Desktop/Project BAC/BAC project/4_train_single_gpu.ipynb:15


┌ Info: step 8000	loss = 0.930597
└ @ Main /home/molloi-lab/Desktop/Project BAC/BAC project/4_train_single_gpu.ipynb:15


In [None]:
train_1_epoch!(1, model_0, model_0_ps, train_loader, opt_state)

In [None]:
train_1_epoch!(1, model_0, model_0_ps, train_loader, opt_state)

In [None]:
train_1_epoch!(1, model_0, model_0_ps, train_loader, opt_state)