Skip to content

Commit

Permalink
[R] Image segmentation example and test. close apache#5003 (apache#7096)
Browse files Browse the repository at this point in the history
  • Loading branch information
thirdwing committed Jul 20, 2017
1 parent e3fd434 commit c40442a
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 84 deletions.
18 changes: 18 additions & 0 deletions R-package/tests/testthat/get_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ GetMNIST_ubyte <- function() {
!file.exists('data/t10k-labels-idx1-ubyte')) {
download.file('http://data.mxnet.io/mxnet/data/mnist.zip', destfile = 'data/mnist.zip')
unzip('data/mnist.zip', exdir = 'data/')
file.remove('data/mnist.zip')
}
}

Expand All @@ -21,6 +22,7 @@ GetMNIST_csv <- function() {
download.file('https://s3-us-west-2.amazonaws.com/apache-mxnet/R/data/mnist_csv.zip',
destfile = 'data/mnist_csv.zip')
unzip('data/mnist_csv.zip', exdir = 'data/')
file.remove('data/mnist_csv.zip')
}
}

Expand All @@ -35,6 +37,7 @@ GetCifar10 <- function() {
download.file('http://data.mxnet.io/mxnet/data/cifar10.zip',
destfile = 'data/cifar10.zip')
unzip('data/cifar10.zip', exdir = 'data/')
file.remove('data/cifar10.zip')
}
}

Expand All @@ -61,6 +64,7 @@ GetCatDog <- function() {
download.file('https://s3-us-west-2.amazonaws.com/apache-mxnet/R/data/cats_dogs.zip',
destfile = 'data/cats_dogs.zip')
unzip('data/cats_dogs.zip', exdir = 'data/')
file.remove('data/cats_dogs.zip')
}
}

Expand All @@ -72,5 +76,19 @@ GetMovieLens <- function() {
download.file('http://files.grouplens.org/datasets/movielens/ml-100k.zip',
destfile = 'data/ml-100k.zip')
unzip('data/ml-100k.zip', exdir = 'data/')
file.remove('data/ml-100k.zip')
}
}

GetISBI_data <- function() {
if (!dir.exists("data")) {
dir.create("data/")
}
if (!file.exists('data/ISBI/train-volume.tif') |
!file.exists('data/ISBI/train-labels.tif')) {
download.file('https://s3-us-west-2.amazonaws.com/apache-mxnet/R/data/ISBI.zip',
destfile = 'data/ISBI.zip')
unzip('data/ISBI.zip', exdir = 'data/')
file.remove('data/ISBI.zip')
}
}
130 changes: 130 additions & 0 deletions R-package/tests/testthat/test_img_seg.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
require(mxnet)

source("get_data.R")

print_inferred_shape <- function(net) {
slist <- mx.symbol.infer.shape(symbol = net, data = c(168, 168, 1, 2))
print(slist$out.shapes)
}

convolution_module <- function(net, kernel_size, pad_size, filter_count,
stride = c(1, 1), work_space = 2048, batch_norm = TRUE,
down_pool = FALSE, up_pool = FALSE, act_type = "relu",
convolution = TRUE) {
if (up_pool) {
net = mx.symbol.Deconvolution(net, kernel = c(2, 2), pad = c(0, 0),
stride = c(2, 2), num_filter = filter_count,
workspace = work_space)
net = mx.symbol.BatchNorm(net)
if (act_type != "") {
net = mx.symbol.Activation(net, act_type = act_type)
}
}
if (convolution) {
conv = mx.symbol.Convolution(data = net, kernel = kernel_size, stride = stride,
pad = pad_size, num_filter = filter_count,
workspace = work_space)
net = conv
}
if (batch_norm) {
net = mx.symbol.BatchNorm(net)
}

if (act_type != "") {
net = mx.symbol.Activation(net, act_type = act_type)
}

if (down_pool) {
pool = mx.symbol.Pooling(net, pool_type = "max", kernel = c(2, 2), stride = c(2, 2))
net = pool
}
print_inferred_shape(net)
return(net)
}

get_unet <- function() {
data = mx.symbol.Variable('data')
kernel_size = c(3, 3)
pad_size = c(1, 1)
filter_count = 32
pool1 = convolution_module(data, kernel_size, pad_size, filter_count = filter_count, down_pool = TRUE)
net = pool1
pool2 = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 2, down_pool = TRUE)
net = pool2
pool3 = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 4, down_pool = TRUE)
net = pool3
pool4 = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 4, down_pool = TRUE)
net = pool4
net = mx.symbol.Dropout(net)
pool5 = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 8, down_pool = TRUE)
net = pool5
net = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 4, up_pool = TRUE)
net = convolution_module(net, kernel_size, pad_size = c(2, 2), filter_count = filter_count * 4, up_pool = TRUE)
net = mx.symbol.Crop(net, pool3, num.args = 2)
net = mx.symbol.concat(c(pool3, net), num.args = 2)
net = mx.symbol.Dropout(net)
net = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 4)
net = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 4, up_pool = TRUE)

net = mx.symbol.Concat(c(pool2, net), num.args = 2)
net = mx.symbol.Dropout(net)
net = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 4)
net = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 4, up_pool = TRUE)
convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 4)
net = mx.symbol.Concat(c(pool1, net), num.args = 2)
net = mx.symbol.Dropout(net)
net = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 2)
net = convolution_module(net, kernel_size, pad_size, filter_count = filter_count * 2, up_pool = TRUE)
net = convolution_module(net, kernel_size, pad_size, filter_count = 1, batch_norm = FALSE, act_type = "")
net = mx.symbol.SoftmaxOutput(data = net, name = 'sm')
return(net)
}

context("Image segmentation")

test_that("UNET", {
list.of.packages <- c("imager")
new.packages <- list.of.packages[!(list.of.packages %in% installed.packages()[,"Package"])]
if(length(new.packages)) install.packages(new.packages)
GetISBI_data()
library(imager)
IMG_SIZE <- 168
files <- list.files(path = "data/ISBI/train-volume/")
a = 'data/ISBI/train-volume/'
filess = paste(a, files, sep = '')
list_of_images = lapply(filess, function(x) {
x <- load.image(x)
y <- resize(x, size_x = IMG_SIZE, size_y = IMG_SIZE)
})

train.x = do.call('cbind', lapply(list_of_images, as.vector))
train.array <- train.x
dim(train.array) <- c(IMG_SIZE, IMG_SIZE, 1, 30)

files <- list.files(path = "data/ISBI/train-labels")
b = 'data/ISBI/train-labels/'
filess = paste(b, files, sep = '')
list_of_images = lapply(filess, function(x) {
x <- load.image(x)
y <- resize(x, size_x = IMG_SIZE, size_y = IMG_SIZE)
})

train.y = do.call('cbind', lapply(list_of_images, as.vector))

train.y[which(train.y < 0.5)] = 0
train.y[which(train.y > 0.5)] = 1
train.y.array = train.y
dim(train.y.array) = c(IMG_SIZE, IMG_SIZE, 1, 30)

devices <- mx.cpu()
mx.set.seed(0)

net <- get_unet()

model <- mx.model.FeedForward.create(net, X = train.array, y = train.y.array,
ctx = devices, num.round = 2,
initializer = mx.init.normal(sqrt(2 / 576)),
learning.rate = 0.05,
momentum = 0.99,
array.batch.size = 2)
})
6 changes: 3 additions & 3 deletions R-package/tests/testthat/test_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ test_that("Regression", {
})
mx.set.seed(0)
model <- mx.model.FeedForward.create(lro, X = train.x, y = train.y,
ctx = mx.cpu(), num.round = 50,
ctx = mx.cpu(), num.round = 5,
array.batch.size = 20,
learning.rate = 2e-6,
momentum = 0.9,
Expand All @@ -103,7 +103,7 @@ test_that("Classification", {
mx.set.seed(0)
model <- mx.mlp(train.x, train.y, hidden_node = 10,
out_node = 2, out_activation = "softmax",
num.round = 20, array.batch.size = 15,
num.round = 5, array.batch.size = 15,
learning.rate = 0.07,
momentum = 0.9,
eval.metric = mx.metric.accuracy)
Expand Down Expand Up @@ -218,7 +218,7 @@ test_that("Matrix Factorization", {
train_iter <- CustomIter$new(user_iter, item_iter)

model <- mx.model.FeedForward.create(pred3, X = train_iter, ctx = devices,
num.round = 10, initializer = mx.init.uniform(0.07),
num.round = 5, initializer = mx.init.uniform(0.07),
learning.rate = 0.07,
eval.metric = mx.metric.rmse,
momentum = 0.9,
Expand Down
81 changes: 0 additions & 81 deletions example/image-classification/symbol_unet.R

This file was deleted.

0 comments on commit c40442a

Please sign in to comment.