In [1]:
include("./data.jl")
using Flux, CuArrays
using Flux: @treelike

In [4]:
"""
    UNetConvBlock(in_chs, out_chs, kernel = (3, 3))

Build one U-Net convolution stage: two `relu` convolutions, `in_chs => out_chs` followed
by `out_chs => out_chs`.

Each convolution is padded with `kernel .÷ 2`. This keeps the spatial size unchanged for
odd kernel sides (stride 1). That matters because the decoder concatenates its upsampled
features with these outputs, so their spatial sizes must agree. For the default `(3, 3)`
kernel the padding is `(1, 1)`.
"""
UNetConvBlock(in_chs, out_chs, kernel = (3, 3)) =
    Chain(Conv(kernel, in_chs => out_chs, relu, pad = _samepad(kernel)),
          Conv(kernel, out_chs => out_chs, relu, pad = _samepad(kernel)))

# Per-dimension padding that preserves spatial size for a stride-1 convolution
# with odd kernel sides, e.g. (3, 3) -> (1, 1), (5, 5) -> (2, 2).
_samepad(kernel) = map(k -> k ÷ 2, kernel)

"""
    UNetUpBlock(upsample, conv_layer)

One decoder stage of the U-Net:
- `upsample` enlarges the incoming feature map.
- `conv_layer` processes it after it has been concatenated with the matching encoder
  output (the skip connection).

The field types are type parameters, so calls to the wrapped layers are specialised
rather than dynamically dispatched on untyped (`Any`) fields.
"""
struct UNetUpBlock{U, C}
    upsample::U
    conv_layer::C
end

# Register the fields with Flux so utilities such as `gpu` and `params` reach the
# wrapped layers.
@treelike UNetUpBlock

"""
    UNetUpBlock(in_chs::Int, out_chs::Int, kernel = (3, 3))

Build a decoder stage in two parts:
1. A 2×2, stride-2 `ConvTranspose` mapping `in_chs => out_chs`, which doubles the
   spatial size.
2. `UNetConvBlock(in_chs, out_chs, kernel)`.

The convolution block takes `in_chs` input channels because it sees the upsampled
features (`out_chs` channels) concatenated with an encoder output. That encoder output
is assumed to also have `out_chs` channels, i.e. `in_chs == 2 * out_chs`, which holds
for the channel progression used in `UNet()`.
"""
UNetUpBlock(in_chs::Int, out_chs::Int, kernel = (3, 3)) =
    UNetUpBlock(ConvTranspose((2, 2), in_chs => out_chs, stride = (2, 2)),
                UNetConvBlock(in_chs, out_chs, kernel))

"""
    (u::UNetUpBlock)(x, bridge)

Run one decoder stage:
1. Upsample `x`.
2. Concatenate the result with the skip connection `bridge` along dimension 3 (the
   channel dimension in Flux's WHCN layout).
3. Run the concatenation through `u.conv_layer`.

No cropping is performed, so `bridge` must already match the upsampled `x` in its
first two (spatial) dimensions. With size-preserving convolutions this holds when the
network input's spatial sides are divisible by 2^(number of 2×2 pooling steps). On a
mismatch a `DimensionMismatch` explaining this is thrown.
"""
function (u::UNetUpBlock)(x, bridge)
    up = u.upsample(x)
    if size(up, 1) != size(bridge, 1) || size(up, 2) != size(bridge, 2)
        throw(DimensionMismatch(
            "upsampled features are $(size(up, 1))×$(size(up, 2)) but the skip " *
            "connection is $(size(bridge, 1))×$(size(bridge, 2)); the input image " *
            "sides must be divisible by 2^(number of pooling steps)"))
    end
    return u.conv_layer(cat(up, bridge, dims = 3))
end

"""
    UNet(pool_layer, conv_blocks, up_blocks)

U-Net container with three fields:
- `pool_layer` downsamples between encoder stages.
- `conv_blocks` holds the encoder stages, shallowest first.
- `up_blocks` holds the decoder stages, deepest first, followed by a final output layer.

The field types are type parameters, so calls to the wrapped layers are specialised
rather than dynamically dispatched on untyped (`Any`) fields.
"""
struct UNet{P, E, D}
    pool_layer::P
    conv_blocks::E
    up_blocks::D
end

# Register the fields with Flux so utilities such as `gpu` and `params` reach the
# wrapped layers.
@treelike UNet

"""
    UNet(in_chs::Int = 1, out_chs::Int = 1; base_chs::Int = 64)

Build a five-level U-Net intended for background/foreground segmentation.

- The encoder stages use `base_chs .* (1, 2, 4, 8, 16)` channels (64 … 1024 by default),
  with 2×2 max-pooling between them.
- The decoder mirrors the encoder.
- A final 1×1 convolution with no activation maps to `out_chs` channels.

`UNet()` builds the network with one input and one output channel. A smaller `base_chs`
reduces both activation memory and parameter count. This can help when the GPU runs out
of memory.
"""
function UNet(in_chs::Int = 1, out_chs::Int = 1; base_chs::Int = 64)
    # Channel width doubles at every level: c1 = base_chs, ..., c5 = 16 * base_chs.
    c1, c2, c3, c4, c5 = ntuple(i -> base_chs * 2^(i - 1), 5)
    pool_layer = MaxPool((2, 2))
    conv_blocks = (UNetConvBlock(in_chs, c1), UNetConvBlock(c1, c2), UNetConvBlock(c2, c3),
                   UNetConvBlock(c3, c4), UNetConvBlock(c4, c5))
    # Decoder stages deepest first; the last entry is the output projection.
    up_blocks = (UNetUpBlock(c5, c4), UNetUpBlock(c4, c3), UNetUpBlock(c3, c2),
                 UNetUpBlock(c2, c1), Conv((1, 1), c1 => out_chs))
    return UNet(pool_layer, conv_blocks, up_blocks)
end

"""
    (u::UNet)(x)

Run the U-Net on `x`.

- Encoder stage `i` is applied to the pooled output of stage `i - 1`.
- Every encoder output except the deepest is passed as the skip connection to the
  decoder stage at the same depth.
- The last entry of `u.up_blocks` maps the final decoder output to the result.

The pass recurses over the layer tuples, so it works for any depth. It builds no mutable
`Vector{Any}` buffer: such a buffer is type-unstable, and array mutation is not supported
by some Flux AD backends (e.g. Zygote).

Throws `DimensionMismatch` unless there is exactly one decoder stage per pooling step.
"""
function (u::UNet)(x)
    encoders = Tuple(u.conv_blocks)
    decoders = Base.front(Tuple(u.up_blocks))   # all but the final output layer
    length(decoders) == length(encoders) - 1 || throw(DimensionMismatch(
        "UNet needs one decoder stage per pooling step: got $(length(encoders)) " *
        "encoder stages and $(length(decoders)) decoder stages"))
    return last(u.up_blocks)(_unet_level(u.pool_layer, encoders, decoders, x))
end

# Deepest level: the bottleneck encoder stage, with no skip connection or decoder.
_unet_level(pool, encoders::Tuple{Any}, decoders::Tuple{}, x) = first(encoders)(x)

# Encode `x` at this level, then recurse one level deeper on the pooled features. Finally
# decode, using this level's encoder output as the skip connection. `decoders` is ordered
# deepest first, so the stage for the current (shallowest remaining) level is its last
# element.
function _unet_level(pool, encoders::Tuple, decoders::Tuple, x)
    bridge = first(encoders)(x)
    inner = _unet_level(pool, Base.tail(encoders), Base.front(decoders), pool(bridge))
    return last(decoders)(inner, bridge)
end

In [32]:
# Build the default U-Net and move its parameters to the GPU.
test = gpu(UNet())

UNet(MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2)), (Chain(Conv((3, 3), 1=>64, NNlib.relu), Conv((3, 3), 64=>64, NNlib.relu)), Chain(Conv((3, 3), 64=>128, NNlib.relu), Conv((3, 3), 128=>128, NNlib.relu)), Chain(Conv((3, 3), 128=>256, NNlib.relu), Conv((3, 3), 256=>256, NNlib.relu)), Chain(Conv((3, 3), 256=>512, NNlib.relu), Conv((3, 3), 512=>512, NNlib.relu)), Chain(Conv((3, 3), 512=>1024, NNlib.relu), Conv((3, 3), 1024=>1024, NNlib.relu))), (UNetUpBlock(ConvTranspose((2, 2), 1024=>512), Chain(Conv((3, 3), 1024=>512, NNlib.relu), Conv((3, 3), 512=>512, NNlib.relu))), UNetUpBlock(ConvTranspose((2, 2), 512=>256), Chain(Conv((3, 3), 512=>256, NNlib.relu), Conv((3, 3), 256=>256, NNlib.relu))), UNetUpBlock(ConvTranspose((2, 2), 256=>128), Chain(Conv((3, 3), 256=>128, NNlib.relu), Conv((3, 3), 128=>128, NNlib.relu))), UNetUpBlock(ConvTranspose((2, 2), 128=>64), Chain(Conv((3, 3), 128=>64, NNlib.relu), Conv((3, 3), 64=>64, NNlib.relu))), Conv((1, 1), 64=>1)))

In [33]:
# Collect the training image paths and read the `train-rle.csv` table from a shared root.
# NOTE(review): machine-specific Windows path; adjust `root` when running elsewhere.
imgpaths, rle = let root = "D:/code/kaggle/data/train-test"
    loadimgpaths("$root/dicom-images-train"), CSV.read("$root/train-rle.csv")
end
# `dataloader` comes from data.jl. The `8` presumably is the batch size (the cell output
# shows the paths being partitioned) and images are resized to 128×128. TODO confirm.
loader = dataloader(imgpaths, rle, 8; imsize = (128, 128))

Base.Generator{Base.Iterators.PartitionIterator{Array{String,1}},getfield(Main, Symbol("##16#17")){Tuple{Int64,Int64},DataFrames.DataFrame}}(getfield(Main, Symbol("##16#17")){Tuple{Int64,Int64},DataFrames.DataFrame}((128, 128), 11582×2 DataFrames.DataFrame. Omitted printing of 1 columns
│ Row   │ ImageId                                                 │
│       │ [90mString[39m                                                  │
├───────┼─────────────────────────────────────────────────────────┤
│ 1     │ 1.2.276.0.7230010.3.1.4.8323329.5597.1517875188.959090  │
│ 2     │ 1.2.276.0.7230010.3.1.4.8323329.12515.1517875239.501137 │
│ 3     │ 1.2.276.0.7230010.3.1.4.8323329.4904.1517875185.355709  │
│ 4     │ 1.2.276.0.7230010.3.1.4.8323329.32579.1517875161.299312 │
│ 5     │ 1.2.276.0.7230010.3.1.4.8323329.32579.1517875161.299312 │
│ 6     │ 1.2.276.0.7230010.3.1.4.8323329.32579.1517875161.299312 │
│ 7     │ 1.2.276.0.7230010.3.1.4.8323329.32579.1517875161.299312 │
│ 8     │ 1.2.276.0.72

In [34]:
# Smoke test: forward pass on the first component of the first batch (presumably the
# image tensor). The recorded run hit a GPU OutOfMemoryError here; a smaller batch or
# image size may be needed.
first(loader)[1] |> gpu |> test

OutOfMemoryError: OutOfMemoryError()

In [13]:
# Inspect the first component of the first batch. The recorded MethodError comes from an
# earlier run (execution count 13, before the loader cell at 33), in which
# `first(loader)` returned a closure rather than an indexable batch.
getindex(first(loader), 1)

MethodError: MethodError: no method matching getindex(::getfield(Main, Symbol("##19#21")), ::Int64)
Closest candidates are:
  getindex(::Any, !Matched::AbstractTrees.ImplicitRootState) at C:\Users\jules\.julia\packages\AbstractTrees\z1wBY\src\AbstractTrees.jl:344