This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

custom loss symbol in R/Python #3368

Closed
uzhao opened this issue Sep 24, 2016 · 23 comments · Fixed by #4181

Comments

@uzhao

uzhao commented Sep 24, 2016

I tried to create a custom loss symbol in R or Python. I found an example using MakeLoss in Python at https://zhuanlan.zhihu.com/p/21725762?refer=xlvector. I tried to create a network that minimizes MSE for linear regression, but it never worked. Could anyone please provide an example? Thanks.

@winstywang winstywang added the R label Sep 27, 2016
@uzhao uzhao changed the title custom loss symbol in R custom loss symbol in R/Python Sep 28, 2016
@ijkguo
Contributor

ijkguo commented Sep 29, 2016

If the goal is to do linear regression, mxnet.symbol.LinearRegressionOutput may be worth considering.
Another usage of composition loss can be found at https://github.com/dmlc/mxnet/blob/master/example/rcnn/rcnn/symbol.py#L111

@uzhao
Author

uzhao commented Sep 29, 2016

Thank you for the suggestion. In my case the objective function could be complex, and I just want an example for MakeLoss. I believe the rcnn code is the only official example that uses MakeLoss, and it is still too complex to show how to create an MLP with a custom loss function. In tflearn this is quite simple: you can provide a function directly. I also read the example on creating a new operator, but it looks very complex and unnecessary.

@ijkguo
Contributor

ijkguo commented Sep 30, 2016

Well, simplicity comes at the price of slower speed and higher memory cost.

@uzhao
Author

uzhao commented Sep 30, 2016

I don't get this part. For a customized objective function, using MakeLoss is easier than writing a new layer. Anyway, I have figured out how to use it now, but the performance looks weird. I tried the regression sample in http://mxnet.readthedocs.io/en/latest/packages/r/fiveMinutesNeuralNetwork.html and changed the loss to
lro <- mx.symbol.MakeLoss(mx.symbol.square(mx.symbol.square(mx.symbol.Reshape(fc1, shape = -1) - label)))
The result is unacceptably bad.

@ijkguo
Contributor

ijkguo commented Oct 3, 2016

How about trying something simpler, such as an L2 loss written as MakeLoss(square(fc - label)), to see whether other components of the learning process are affecting performance?

@uzhao
Author

uzhao commented Oct 3, 2016

I tried the square loss and it doesn't work. Here is another test, which doesn't work either.
With lro <- mx.symbol.LinearRegressionOutput(fc1):
sqrt(mean((preds-test.y)^2)) = 7.800502
With lro <- mx.symbol.MakeLoss(mx.symbol.square(mx.symbol.Reshape(fc1, shape = 0) - label)):
sqrt(mean((preds-test.y)^2)) = 557.1939

@ijkguo
Contributor

ijkguo commented Oct 3, 2016

Then perhaps try symbol.LinearRegressionOutput to make sure the other parts, such as IO or symbol composition, are correct? If that works, the problem could be the composite loss. It would be most helpful if you could give an example where the loss produces a wrong gradient. Refer to symbol.simple_bind for this purpose.
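
For reference, a gradient check of this kind can be sketched in plain Python without MXNet. The snippet is only an illustration under stated assumptions: the 1-D linear model, the helper names (predict, loss, analytic_grad, numeric_grad) and the toy data are hypothetical and are not taken from the script in this thread. It compares a hand-derived gradient of a squared-error loss with a central finite-difference estimate.

```python
# Plain-Python gradient check for a squared-error loss on a 1-D linear model.
# Nothing here calls MXNet; the model and helper names are illustrative only.

def predict(w, b, x):
    """Linear model: pred = w * x + b."""
    return w * x + b

def loss(w, b, xs, ys):
    """Sum of squared errors over the batch."""
    return sum((predict(w, b, x) - y) ** 2 for x, y in zip(xs, ys))

def analytic_grad(w, b, xs, ys):
    """Hand-derived gradient of the loss with respect to (w, b)."""
    dw = sum(2.0 * (predict(w, b, x) - y) * x for x, y in zip(xs, ys))
    db = sum(2.0 * (predict(w, b, x) - y) for x, y in zip(xs, ys))
    return dw, db

def numeric_grad(w, b, xs, ys, eps=1e-6):
    """Central finite differences, used as the reference value."""
    dw = (loss(w + eps, b, xs, ys) - loss(w - eps, b, xs, ys)) / (2 * eps)
    db = (loss(w, b + eps, xs, ys) - loss(w, b - eps, xs, ys)) / (2 * eps)
    return dw, db

xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]
w, b = 0.5, -0.25

a_dw, a_db = analytic_grad(w, b, xs, ys)
n_dw, n_db = numeric_grad(w, b, xs, ys)
grad_ok = abs(a_dw - n_dw) < 1e-4 and abs(a_db - n_db) < 1e-4

print("analytic:", (a_dw, a_db))
print("numeric: ", (n_dw, n_db))
print("gradients agree:", grad_ok)
```

If the two gradients disagree for a composite loss, the loss expression itself is the first suspect; if they agree, the problem more likely lies in the data pipeline or the optimizer settings.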

@uzhao
Author

uzhao commented Oct 3, 2016

Here is the full script.
https://gist.github.com/uzhao/8ceeedfc5b87d45987b9fcaa2847299b

The two models are:

{
"nodes": [
{
"op": "null",
"param": {},
"name": "data",
"inputs": [],
"backward_source_id": -1
},
{
"op": "null",
"param": {},
"name": "fullyconnected5_weight",
"inputs": [],
"backward_source_id": -1
},
{
"op": "null",
"param": {},
"name": "fullyconnected5_bias",
"inputs": [],
"backward_source_id": -1
},
{
"op": "FullyConnected",
"param": {
"no_bias": "False",
"num_hidden": "1"
},
"name": "fullyconnected5",
"inputs": [[0, 0], [1, 0], [2, 0]],
"backward_source_id": -1
},
{
"op": "null",
"param": {},
"name": "linearregressionoutput2_label",
"inputs": [],
"backward_source_id": -1
},
{
"op": "LinearRegressionOutput",
"param": {"grad_scale": "1"},
"name": "linearregressionoutput2",
"inputs": [[3, 0], [4, 0]],
"backward_source_id": -1
}
],
"arg_nodes": [0, 1, 2, 4],
"heads": [[5, 0]]

}

{
"nodes": [
{
"op": "null",
"param": {},
"name": "data",
"inputs": [],
"backward_source_id": -1
},
{
"op": "null",
"param": {},
"name": "fullyconnected6_weight",
"inputs": [],
"backward_source_id": -1
},
{
"op": "null",
"param": {},
"name": "fullyconnected6_bias",
"inputs": [],
"backward_source_id": -1
},
{
"op": "FullyConnected",
"param": {
"no_bias": "False",
"num_hidden": "1"
},
"name": "fullyconnected6",
"inputs": [[0, 0], [1, 0], [2, 0]],
"backward_source_id": -1
},
{
"op": "Reshape",
"param": {
"keep_highest": "False",
"reverse": "False",
"shape": "(0,)",
"target_shape": "(0,0)"
},
"name": "reshape2",
"inputs": [[3, 0]],
"backward_source_id": -1
},
{
"op": "null",
"param": {},
"name": "label",
"inputs": [],
"backward_source_id": -1
},
{
"op": "_Minus",
"param": {},
"name": "_minus5",
"inputs": [[4, 0], [5, 0]],
"backward_source_id": -1
},
{
"op": "square",
"param": {},
"name": "square5",
"inputs": [[6, 0]],
"backward_source_id": -1
},
{
"op": "MakeLoss",
"param": {"grad_scale": "1"},
"name": "makeloss5",
"inputs": [[7, 0]],
"backward_source_id": -1
}
],
"arg_nodes": [0, 1, 2, 5],
"heads": [[8, 0]]
}

@khalida

khalida commented Dec 10, 2016

Hi, I'm also trying to implement mean-squared error as a custom loss function, as a path to understanding how to implement custom loss functions in general. I have run the example code you posted above and got the same results.

As far as I can tell something which is missing is attaching the label/response values to the label symbol. In line 26 of your code you assign a label variable, but as far as I can tell this is never attached to the labels of the training data.

Unless by giving it the name label it is automatically treated as such (in the same way the data symbol is attached to the features because it has been given the name data). I'm afraid I'm not very familiar with the mx.symbol API so the above may not make that much sense.

@thirdwing
Contributor

@uzhao

First, your usage of mx.symbol.MakeLoss is correct.

Second, you can try a different initializer and learning.rate after changing the loss function.

I just picked one at random, and the result, 23.88418, seems reasonable. You can try more.

data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden=1)
lro <- mx.symbol.MakeLoss(mx.symbol.square(mx.symbol.Reshape(fc1, shape = 0) - label))

mx.set.seed(0)
model <- mx.model.FeedForward.create(lro, X=train.x, y=train.y,
                                     initializer=mx.init.uniform(0.002),
                                     ctx=mx.cpu(), num.round=50, array.batch.size=20,
                                     learning.rate=2e-12, momentum=0.9,
                                     eval.metric=mx.metric.rmse)

preds = predict(model, test.x)
sqrt(mean((preds-test.y)^2))
## [1] 23.88418

@thirdwing
Contributor

I think we might need to write a small vignette on how to use mx.symbol.MakeLoss.

@JianyangZhao

@thirdwing
I think your initializer works by coincidence: if you change to 1 round you get 23.98614, and increasing the number of rounds doesn't improve the result at all.

@khalida

khalida commented Dec 20, 2016

Apologies for cross-posting, but I think the issue I have raised here is also relevant to this discussion.

I have given full details in this SO question, but basically even with the standard LinearRegressionOutput layer the performance on simple regression problems seems to be poor.

Note that I'm only looking at training performance, so it should be possible to get very low errors, and this is the case with other neural network tools, but I have been unable to get good regression performance from mxnet.

@khalida

khalida commented Dec 21, 2016

Thanks to some help provided on the other issue I have been able to resolve the problems I've had with regression performance of mxnet.

However I have noticed something unexpected (to me) when using MakeLoss to define custom loss functions. It appears that if you define a squared loss function (as in example above), the network outputs approximations for y^2 (where y is the target of regression).

To illustrate this, consider the example below, in which I train 4 networks to perform simple polynomial regression. The first (top-left) uses nnet, just as a sanity check. The second (top-right) uses an mxnet LinearRegressionOutput layer. The third (bottom-left) uses a MakeLoss layer as defined above with an mx.symbol.square loss function, and the fourth (bottom-right) uses a MakeLoss layer with an mx.symbol.sqrt(mx.symbol.square()) loss function.

On all plots I have included a y=x (perfect fit) line (red), and on the third plot I have included a y=x^2 line (blue). As you can see the model trained with a square loss function appears to output the square of the response. This does not make sense to me, is it expected behaviour? If so my understanding of how MakeLoss works is very poor.

Plots:
mxnet_regression_makeloss

RMSE results:

$nnet
[1] 0.03222651

$mxModel_lro
[1] 0.03263213

$mxModel_makeloss
[1] 0.8708557

$mxModel_makeloss_sqrt
[1] 0.03802211

Produced from this code:

## REGRESSION WITH MAKELOSS
# Check mxnet regression performance with custom loss function

library(mxnet)
library(tictoc)
library(reshape)
library(nnet)

# Data config
nObservations <- 1000
noiseLvl <- 0.1

# Network config
nHidden <- 3
batchSize <- 100
nRound <- 400
verbose <- FALSE
array.layout = "rowmajor"
optimizer <- "rmsprop"

# GENERATE DATA:
set.seed(0)
df <- data.frame(x1=runif(nObservations),
                 x2=runif(nObservations),
                 x3=runif(nObservations))

df$y <- df$x1 + df$x2^2 + df$x3^3 + noiseLvl*runif(nObservations)
# normalize data columns
# df <- scale(df)

# Separate data into train/test
test.ind = seq(1, nObservations, 10)    # 1 in 10 samples for testing
train.x = data.matrix(df[-test.ind, -which(colnames(df) %in% c("y"))])
train.y = df[-test.ind, "y"]
test.x = data.matrix(df[test.ind, -which(colnames(df) %in% c("y"))])
test.y = df[test.ind, "y"]

# Define mxnet network, following 5-minute regression example from here:
# http://mxnet-tqchen.readthedocs.io/en/latest//packages/r/fiveMinutesNeuralNetwork.html#regression
# but with additional hidden layer
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden=nHidden, name="fc1")
tanh1 <- mx.symbol.Activation(fc1, act_type="tanh", name="tanh1")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden=1, name="fc2")
lro <- mx.symbol.LinearRegressionOutput(data=fc2, label=label, name="lro2")
tic("mxnet training")
mx.set.seed(0)
mxModel_lro <- mx.model.FeedForward.create(lro, X=train.x, y=train.y,
                                           eval.data=list(data=test.x, label=test.y),
                                           ctx=mx.cpu(), num.round=nRound,
                                           array.batch.size=batchSize,
                                           eval.metric=mx.metric.rmse,
                                           verbose=verbose,
                                           array.layout=array.layout,
                                           optimizer=optimizer)
toc()

lro_makeLoss <- mx.symbol.MakeLoss(
  mx.symbol.square(label - mx.symbol.Reshape(fc2, shape=0)), name="lro3")

tic("mxnet training, make loss squared")
mx.set.seed(0)
mxModel_makeloss <- mx.model.FeedForward.create(lro_makeLoss, X=train.x, y=train.y,
                                                eval.data=list(data=test.x, label=test.y),
                                                ctx=mx.cpu(), num.round=nRound,
                                                array.batch.size=batchSize,
                                                eval.metric=mx.metric.rmse,
                                                verbose=verbose,
                                                array.layout=array.layout,
                                                optimizer=optimizer)
toc()

lro_makeLoss_sqrt <- mx.symbol.MakeLoss(
  mx.symbol.sqrt(mx.symbol.square(label - mx.symbol.Reshape(fc2, shape=0))),
  name="lro3")

tic("mxnet training, make loss sqrt(squared)")
mx.set.seed(0)
mxModel_makeloss_sqrt <- mx.model.FeedForward.create(lro_makeLoss_sqrt,
                                                     X=train.x, y=train.y,
                                                     eval.data=list(data=test.x, label=test.y),
                                                     ctx=mx.cpu(), num.round=nRound,
                                                     array.batch.size=batchSize,
                                                     eval.metric=mx.metric.rmse,
                                                     verbose=verbose,
                                                     array.layout=array.layout,
                                                     optimizer=optimizer)
toc()

# Train nnet model
set.seed(0)
tic("nnet training")
nnetModel <- nnet(y~x1+x2+x3, data=df[-test.ind, ], size=nHidden, trace=F,
                  linout=TRUE)
toc()

# Check response VS targets on training data:
par(mfrow=c(2,2))

plot(train.y, predict(nnetModel, train.x), 
     main="nnet Train Fit", xlab="Target", ylab="Response")
abline(0,1, col="red", lwd=2)

plot(train.y, predict(mxModel_lro, train.x, array.layout=array.layout), 
     main="LRO mxnet Train Fit", xlab="Target",
     ylab="Response")
abline(0,1, col="red", lwd=2)

plot(train.y, predict(mxModel_makeloss, train.x, array.layout=array.layout),
     main="MakeLoss square mxnet Train Fit", xlab="Target",
     ylab="Response")
abline(0,1, col="red", lwd=2)
curve(x^2, add=TRUE, col="blue", lwd=2)

plot(train.y, predict(mxModel_makeloss_sqrt, train.x, array.layout=array.layout),
     main="MakeLoss sqrt(square) Train Fit",
     xlab="Target", ylab="Response")
abline(0,1, col="red", lwd=2)

# Create and print table of results:
results <- list()
rmse <- function(target, response) {
  return(sqrt(mean((target - response)^2)))
}

results$nnet <- rmse(train.y, predict(nnetModel, train.x))

results$mxModel_lro <- rmse(train.y, predict(mxModel_lro, train.x,
                                          array.layout=array.layout))
results$mxModel_makeloss <- rmse(train.y, predict(mxModel_makeloss, train.x,
                                          array.layout=array.layout))
results$mxModel_makeloss_sqrt <- rmse(train.y, predict(mxModel_makeloss_sqrt,
                                                       train.x, 
                                                       array.layout=array.layout))


print(results)

@thirdwing
Contributor

thirdwing commented Dec 21, 2016

@piiswrong Do you have any idea on @khalida 's question?

@thirdwing thirdwing reopened this Dec 21, 2016
@uzhao
Author

uzhao commented Dec 21, 2016

I guess MakeLoss was used as grad function rather than loss function?

@khalida

khalida commented Dec 21, 2016

One possible explanation from what I can see (keep in mind I don't understand the internal workings of mxnet very well) is that when a MakeLoss layer is used, the output of a call to predict is not the desired network response (pred), but rather the loss function itself when a label of zero is assumed.
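
The effect described above can be reproduced without MXNet. The following plain-Python sketch is an illustration, not MXNet's implementation: pred_layer, loss_head and the toy data are hypothetical names, and the assumption that a missing label is filled with zeros at prediction time comes from the explanation above rather than from the library source.

```python
# Framework-free sketch: fit y ~ w*x + b by minimizing (pred - label)^2, then
# compare what a loss-valued output emits with what the prediction layer emits.
# All names are illustrative; no MXNet call is used.

def pred_layer(w, b, x):
    """The last layer before the loss (plays the role of fc2 in the thread)."""
    return w * x + b

def loss_head(w, b, x, label):
    """What a loss-valued output node would return for one example."""
    return (pred_layer(w, b, x) - label) ** 2

xs = [0.0, 0.5, 1.0, 1.5, 2.0]
ys = [2.0 * x + 1.0 for x in xs]  # noiseless target: y = 2x + 1

w, b, lr = 0.0, 0.0, 0.05
for _ in range(2000):
    # Full-batch gradient descent on the mean squared error.
    gw = sum(2.0 * (pred_layer(w, b, x) - y) * x for x, y in zip(xs, ys)) / len(xs)
    gb = sum(2.0 * (pred_layer(w, b, x) - y) for x, y in zip(xs, ys)) / len(xs)
    w -= lr * gw
    b -= lr * gb

preds = [pred_layer(w, b, x) for x in xs]
# ASSUMPTION: at prediction time no label is supplied, so it is taken as zero.
head_out = [loss_head(w, b, x, 0.0) for x in xs]

for y, p, h in zip(ys, preds, head_out):
    print("target=%.3f  pred=%.3f  loss_head=%.3f" % (y, p, h))
```

Under that assumption the loss-valued output equals the square of the prediction, which would match the y=x^2 curve in the plots, while the prediction layer itself tracks the target.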

This is made a little clearer if we normalize (to zero mean and unit variance) the data of the excellent MakeLoss example which @thirdwing wrote (link to that example). Details of this example included below.

If this is the case then what we really want when we call predict is to extract the data from the fc2 layer (or whatever is the last layer before we start computing the loss function). Is there a way to do that so I can check?

The plots from this example are shown below:
mxnet_regression_squared_bostonhousing

And the code used to produce these plots (based on the BostonHousing example provided by @thirdwing) is pasted below

# Custom loss function tutorial from:
# https://github.com/dmlc/mxnet/blob/master/docs/tutorials/r/CustomLossFunction.md

# Network config
optimizer <- "rmsprop"
batchSize <- 100
nRounds <- 500
normalize <- TRUE
nHidden <- 14
verbose <- FALSE
array.layout <- "rowmajor"

library(mxnet)
data(BostonHousing, package="mlbench")
if (normalize) {
  BostonHousing[, sapply(BostonHousing, is.factor)] <- 
    as.numeric(as.character(BostonHousing[, sapply(BostonHousing, is.factor)]))
  BostonHousing <- data.frame(scale(BostonHousing))
}
test.ind = seq(1, 506, 5)    # 1 pt in 5 used for testing
train.x = data.matrix(BostonHousing[-test.ind, -14])
train.y = BostonHousing[-test.ind, 14]
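test.x = data.matrix(BostonHousing[test.ind, -14])
test.y = BostonHousing[test.ind, 14]
data <- mx.symbol.Variable("data")
# label must be defined before it is used in the MakeLoss expression below
label <- mx.symbol.Variable("label")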
fc1 <- mx.symbol.FullyConnected(data, num_hidden=nHidden)
tanh1 <- mx.symbol.Activation(fc1, act_type="tanh")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden=1)
lro <- mx.symbol.LinearRegressionOutput(fc2)
mx.set.seed(0)
model <- mx.model.FeedForward.create(lro,
                                     X=train.x, y=train.y,
                                     eval.data=list(data=test.x, label=test.y),
                                     ctx=mx.cpu(), num.round=nRounds,
                                     array.batch.size=batchSize,
                                     eval.metric=mx.metric.rmse,
                                     optimizer=optimizer, verbose=verbose,
                                     array.layout=array.layout)

lro2 <- mx.symbol.MakeLoss(mx.symbol.square(
  mx.symbol.Reshape(fc2, shape = 0) - label))

model2 <- mx.model.FeedForward.create(lro2,
                                      X=train.x, y=train.y,
                                      eval.data=list(data=test.x, label=test.y),
                                      ctx=mx.cpu(), num.round=nRounds,
                                      array.batch.size=batchSize,
                                      eval.metric=mx.metric.rmse,
                                      optimizer=optimizer, verbose=verbose,
                                      array.layout=array.layout)


# Plotting of fits

par(mfrow=c(1,2))

plot(train.y, predict(model, train.x, array.layout=array.layout),
     main="nnet Train Fit", xlab="Target", ylab="Response")
abline(0,1, col="red", lwd=2)

plot(train.y, predict(model2, train.x, array.layout=array.layout),
     main="nnet MakeLoss square Train Fit", xlab="Target",  ylab="Response")
abline(0,1, col="red", lwd=2)
curve(x^2, col="blue", lwd=2, add=TRUE)

plot(test.y, predict(model, test.x, array.layout=array.layout),
     main="nnet Test Fit", xlab="Target", ylab="Response")
abline(0,1, col="red", lwd=2)

plot(test.y, predict(model2, test.x, array.layout=array.layout),
     main="nnet MakeLoss square Test Fit", xlab="Target",  ylab="Response")
abline(0,1, col="red", lwd=2)
curve(x^2, col="blue", lwd=2, add=TRUE)

@khalida

khalida commented Dec 21, 2016

Ok, I have managed to work this out by following the example for extracting the data from internal layers in R given here.

Below is some example code which trains a regression model using a LinearRegressionOutput layer, and then trains another using a MakeLoss output with a squared-error loss function.

For the MakeLoss model a mx.symbol.Group and an executor are used to perform the forward pass of the network for extracting the pred (network response) for plotting purposes. Results between the two models are similar as one might hope.

If someone with some more mxnet experience (@thirdwing, @piiswrong, @Lodewic) could give the code below a quick check, I would be happy to update the MakeLoss example given here and submit a PR.

A few things which feel a little hacky about my implementation below:

  1. I have created separate train and test executors for plotting the train and test data
  2. In a number of places I have to transpose my data, presumably because of an [examples x features] vs. [features x examples] conflict somewhere (elsewhere I have seen this handled more elegantly by setting the array.layout argument, but that didn't seem to be available here).
  3. Is there some way we could modify MakeLoss and/or the behaviour of predict to do this automatically? When people specify a loss layer they will still expect their network to output a response (not the loss of that response with a zero label).

Plots:
mxnet_regression_square_loss_extract

Code:

# Custom loss function tutorial from:
# https://github.com/dmlc/mxnet/blob/master/docs/tutorials/r/CustomLossFunction.md

# modified to extract the predicted response from within the network

# Network config
optimizer <- "rmsprop"
batchSize <- 100
nRounds <- 500
normalize <- TRUE
nHidden <- 14
verbose <- FALSE
array.layout <- "rowmajor"

library(mxnet)
data(BostonHousing, package="mlbench")
if (normalize) {
  BostonHousing[, sapply(BostonHousing, is.factor)] <- 
    as.numeric(as.character(BostonHousing[, sapply(BostonHousing, is.factor)]))
  BostonHousing <- data.frame(scale(BostonHousing))
}
test.ind = seq(1, 506, 5)    # 1 pt in 5 used for testing
train.x = data.matrix(BostonHousing[-test.ind, -14])
train.y = BostonHousing[-test.ind, 14]
test.x = data.matrix(BostonHousing[test.ind, -14])
test.y = BostonHousing[test.ind, 14]
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden=nHidden, name="fc1")
tanh1 <- mx.symbol.Activation(fc1, act_type="tanh", name="tanh1")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden=1, name="fc2")
lro <- mx.symbol.LinearRegressionOutput(fc2, name="lro")
mx.set.seed(0)
model <- mx.model.FeedForward.create(lro,
                                     X=train.x, y=train.y,
                                     eval.data=list(data=test.x, label=test.y),
                                     ctx=mx.cpu(), num.round=nRounds,
                                     array.batch.size=batchSize,
                                     eval.metric=mx.metric.rmse,
                                     optimizer=optimizer, verbose=verbose,
                                     array.layout=array.layout)

lro2 <- mx.symbol.MakeLoss(
  mx.symbol.square(mx.symbol.Reshape(fc2, shape = 0) - label), name="lro2")

# Create output of network
out <- mx.symbol.Group(c(fc2, lro2))

# Create an executor (responsible for forward pass test)
# note transpose because of rowmajor / colmajor conflict - HACK?
testExecutor <- mx.simple.bind(symbol=out, data=dim(t(test.x)), ctx=mx.cpu())
trainExecutor <- mx.simple.bind(symbol=out, data=dim(t(train.x)), ctx=mx.cpu())

# Train the model (based on custom loss)
mx.set.seed(0)
model2 <- mx.model.FeedForward.create(lro2,
                                      X=train.x, y=train.y,
                                      eval.data=list(data=test.x, label=test.y),
                                      ctx=mx.cpu(), num.round=nRounds,
                                      array.batch.size=batchSize,
                                      eval.metric=mx.metric.rmse,
                                      optimizer=optimizer, verbose=verbose,
                                      array.layout=array.layout)


# Update executor parameters and outputs
# Update parameters:
mx.exec.update.arg.arrays(testExecutor, model2$arg.params, match.name=TRUE)
mx.exec.update.arg.arrays(trainExecutor, model2$arg.params, match.name=TRUE)

mx.exec.update.aux.arrays(testExecutor, model2$aux.params, match.name=TRUE)
mx.exec.update.aux.arrays(trainExecutor, model2$aux.params, match.name=TRUE)

# Select data to use
mx.exec.update.arg.arrays(testExecutor, list(data=mx.nd.array(t(test.x))),
                          match.name=TRUE)

mx.exec.update.arg.arrays(trainExecutor, list(data=mx.nd.array(t(train.x))),
                          match.name=TRUE)

# Do a forward pass with the current parameters and data
mx.exec.forward(testExecutor, is.train=FALSE)
mx.exec.forward(trainExecutor, is.train=FALSE)

# Plotting of fits
par(mfrow=c(1,2))

# Train fits
plot(train.y, predict(model, train.x, array.layout=array.layout),
     main="nnet Train Fit", xlab="Target", ylab="Response")
abline(0,1, col="red", lwd=2)

plot(train.y, as.array(trainExecutor$ref.outputs$fc2_output),
     main="nnet MakeLoss square Train Fit", xlab="Target",  ylab="Response")
abline(0,1, col="red", lwd=2)

# Test fits
plot(test.y, predict(model, test.x, array.layout=array.layout),
     main="nnet Test Fit", xlab="Target", ylab="Response")
abline(0,1, col="red", lwd=2)

plot(test.y, as.array(testExecutor$ref.outputs$fc2_output),
     main="nnet MakeLoss square Test Fit", xlab="Target",  ylab="Response")
abline(0,1, col="red", lwd=2)

@thirdwing
Contributor

@khalida Thank you for what you have done. I have to admit I have never used MakeLoss myself, so I might have misunderstood it.

If we can confirm your solution, let's fix the documents first. Then I will try to provide some helper functions to make it easy.

@khalida

khalida commented Dec 24, 2016

Hi @thirdwing many thanks for the response. Could you provide some links to appropriate parts of the mxnet docs so that I can check my understanding and confirm that it's correct?

In the meantime the example below might be able to form the basis of some updated documentation for the use of MakeLoss.

It's a bit long, but essentially it attempts to train 4 neural networks using mxnet via R:

  1. Using LinearRegressionOutput as a reference
  2. Using MakeLoss to minimize squared-error
  3. Using MakeLoss to minimize quartic-error (fourth-power error)
  4. Using MakeLoss to minimize mean-absolute-error

Each of the networks is then used for prediction (on both the in-sample training set and a held-out test set), and its responses are assessed using 3 error metrics. As might be hoped, each network performs well on the metric it has been trained to minimize.

The results (errors have been normalized to the errors of the LRO model):

[1] "Train results:"
     model.lro model.se.loss model.qe.loss model.mae.loss
rmse         1             1          1.27           1.19
rmqe         1             1          0.91           2.00
mae          1             1          1.58           0.95
[1] "Test results:"
     model.lro model.se.loss model.qe.loss model.mae.loss
rmse         1             1           1.3           1.07
rmqe         1             1           0.9           1.11
mae          1             1           1.6           0.94
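
To make the table easier to read, here is a small plain-Python sketch of the three metrics and of the normalization to a reference model. The data and the baseline/other predictions are made up for illustration and are unrelated to the numbers above.

```python
# Three error metrics and their normalization to a reference model.
# The toy values below are hypothetical and chosen only for illustration.

def rmse(actual, response):
    """Root mean squared error."""
    n = len(actual)
    return (sum((a - r) ** 2 for a, r in zip(actual, response)) / n) ** 0.5

def rmqe(actual, response):
    """Root mean quartic (fourth-power) error."""
    n = len(actual)
    return (sum((a - r) ** 4 for a, r in zip(actual, response)) / n) ** 0.25

def mae(actual, response):
    """Mean absolute error."""
    n = len(actual)
    return sum(abs(a - r) for a, r in zip(actual, response)) / n

actual = [0.0, 1.0, 2.0, 3.0]
baseline = [0.5, 1.5, 1.5, 2.5]  # reference model: every error has size 0.5
other = [0.0, 1.0, 2.0, 5.0]     # second model: three exact hits, one error of 2

metrics = {"rmse": rmse, "rmqe": rmqe, "mae": mae}

# Divide each metric of the second model by the same metric of the reference.
normalized = {
    name: fn(actual, other) / fn(actual, baseline)
    for name, fn in metrics.items()
}

for name in ("rmse", "rmqe", "mae"):
    print("%-4s %.2f" % (name, normalized[name]))
```

A value below 1 means the model beats the reference on that metric. The single large error in the second model is penalized most by rmqe and least by mae, which is why a network trained on one loss can look worse on the other two.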

The plots:
mxnet_custom_loss_regression

The code:

## Specific custom loss example given at:
# https://github.com/dmlc/mxnet/pull/4181/commits/52df9498d1efd0a0e6373a7a61cff632fc435b99
# modified to train a network to minimize each of 3 loss functions,
# and also to use 'predict' layer output rather than loss layer output

library(mxnet)

# Network training config
verbose <- FALSE
nRounds <- 1000
nHidden <- 3
batchSize <- 128
optimizer <- "rmsprop"
array.layout <- "rowmajor"

# Data config
nObservations <- 5000
noiseLvl <- 0.2

# Loss functions
# Root Mean Squared Error - rmse
rmse <- function(actual, response) {
  return(mean((actual - response)^2)^(1/2))
}

# Root Mean Quartic Error - rmqe
rmqe <- function(actual, response) {
  return(mean((actual - response)^4)^(1/4))
}

# Mean Absolute Error - mae
mae <- function(actual, response) {
  return(mean(abs(actual - response)))
}

# Generate some random data
set.seed(0)
df <- data.frame(x1=rnorm(nObservations), 
                 x2=rnorm(nObservations), 
                 x3=rnorm(nObservations),
                 x4=rnorm(nObservations))
df$y <- df$x1 + df$x2^2 + df$x3^3 + df$x4^4 + noiseLvl*rnorm(nObservations)

# Scale data to zero-mean unit-variance
df <- data.frame(scale(df))

# Split into training and test sets
test.ind = seq(1, nObservations, 10) # 10% data for testing
train.x = data.matrix(df[-test.ind, -which(names(df) %in% c("y"))])
train.y = df[-test.ind, "y"]
test.x = data.matrix(df[test.ind, -which(names(df) %in% c("y"))])
test.y = df[test.ind, "y"]

# Set up mxnet network object
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden=nHidden, name="fc1")
act1 <- mx.symbol.Activation(fc1, act_type="tanh", name="act1")
fc2 <- mx.symbol.FullyConnected(act1, num_hidden=1, name="fc2")
lro <- mx.symbol.LinearRegressionOutput(fc2, name="lro")

# Training Networks
# Generic training function to avoid repeating settings
# Train function and then output the trained model, as well as test and train
# executors
# use of executors as per:
# https://github.com/dmlc/mxnet/issues/1152#issuecomment-170563052
trainMxnetRegression <- function(lastLayer, predLayer, modelName) {
  print(paste0("== Training ", modelName, " =="))
  mx.set.seed(0)
  out <- list()
  out$testExecutor <- mx.simple.bind(symbol=predLayer, data=dim(t(test.x)), 
                                     ctx=mx.cpu())
  out$trainExecutor <- mx.simple.bind(symbol=predLayer, data=dim(t(train.x)), 
                                      ctx=mx.cpu())
  out$model <- mx.model.FeedForward.create(lastLayer, X=train.x, y=train.y,
                                       eval.data=list(data=test.x, label=test.y),
                                       ctx=mx.cpu(), num.round=nRounds,
                                       eval.metric=mx.metric.rmse,
                                       verbose=verbose, optimizer=optimizer,
                                       array.batch.size=batchSize,
                                       array.layout=array.layout)
  
  # Set parameter and data for executors:
  mx.exec.update.arg.arrays(out$testExecutor, out$model$arg.params,
                            match.name=TRUE)
  mx.exec.update.aux.arrays(out$testExecutor, out$model$aux.params,
                            match.name=TRUE)
  mx.exec.update.arg.arrays(out$testExecutor, list(data=mx.nd.array(t(test.x))),
                            match.name=TRUE)
  mx.exec.update.arg.arrays(out$trainExecutor, out$model$arg.params,
                            match.name=TRUE)
  mx.exec.update.aux.arrays(out$trainExecutor, out$model$aux.params,
                            match.name=TRUE)
  mx.exec.update.arg.arrays(out$trainExecutor, list(data=mx.nd.array(t(train.x))),
                            match.name=TRUE)
  # Do a forward pass with the current parameters and data
  mx.exec.forward(out$testExecutor, is.train=FALSE)
  mx.exec.forward(out$trainExecutor, is.train=FALSE)
  
  return(out)
}

# Train conventional LRO model
model.lro <- trainMxnetRegression(lro, fc2, "LRO")

# Define squared-error custom loss layer, and train model
se.loss.layer <- mx.symbol.MakeLoss(mx.symbol.square(
  mx.symbol.Reshape(fc2, shape=0) - label), name="se.loss.layer")

model.se.loss <- trainMxnetRegression(se.loss.layer, fc2, "SE Loss")

# Define quartic-error custom loss layer, and train model
qe.loss.layer <- mx.symbol.MakeLoss(mx.symbol.square(mx.symbol.square(
  mx.symbol.Reshape(fc2, shape=0) - label)), name="qe.loss.layer")

model.qe.loss <- trainMxnetRegression(qe.loss.layer, fc2, "QE Loss")

# Define absolute-error custom loss layer, and train model
mae.loss.layer <- mx.symbol.MakeLoss(mx.symbol.abs(
  mx.symbol.Reshape(fc2, shape=0) - label), name="mae.loss.layer")

model.mae.loss <- trainMxnetRegression(mae.loss.layer, fc2, "MAE Loss")

list.of.models <- list(model.lro=model.lro,
                       model.se.loss=model.se.loss,
                       model.qe.loss=model.qe.loss,
                       model.mae.loss=model.mae.loss)

list.of.metrics <- list(rmse=rmse,
                        rmqe=rmqe,
                        mae=mae)

resultsTrain <- resultsTest <- matrix(data=0, nrow=length(list.of.metrics),
                                      ncol=length(list.of.models))

rownames(resultsTrain) <- rownames(resultsTest) <- names(list.of.metrics)
colnames(resultsTrain) <- colnames(resultsTest) <- names(list.of.models)

for(modelIdx in seq_along(list.of.models)) {
  model <- list.of.models[[modelIdx]]
  for(metricIdx in seq_along(list.of.metrics)) {
    metric <- list.of.metrics[[metricIdx]]
    resultsTest[metricIdx, modelIdx] <- metric(
      test.y, as.array(model$testExecutor$ref.outputs$fc2_output))
    resultsTrain[metricIdx, modelIdx] <- metric(
      train.y, as.array(model$trainExecutor$ref.outputs$fc2_output))
  }
}

# Normalize to LRO results:
resultsTrain <- resultsTrain / resultsTrain[, 1]
resultsTest <- resultsTest / resultsTest[, 1]

# Display results to console
print("Train results:")
print(resultsTrain, digits=2)

print("Test results:")
print(resultsTest, digits=2)

# Plot results
barplot(resultsTrain, beside=TRUE, legend.text=rownames(resultsTrain),
        main="Train Performance")

barplot(resultsTest, beside=TRUE, legend.text=rownames(resultsTest),
        main="Test Performance")

@khalida

khalida commented Feb 1, 2017

As pointed out here (regarding training to minimize custom loss functions in Julia), the above works but is rather limited. In particular, I have been unable to find a way to log the training and validation error during training. I tried to follow the example here, but doing so runs into the problem identified above: the output of a network with a custom loss is the loss itself, not a predicted value.

I have two questions:

  1. Does anyone have an example of using MakeLoss that also logs the train/validation error during training?

  2. Does anyone have an example of using a mx.io.DataIter in R? I suspect this may help with the problem above, but I have been unable to get this working, and the docs for mx.io.arrayiter in R are very limited.

Any pointers greatly appreciated.

@yxzf

yxzf commented Mar 2, 2017

I was also confused about how to add custom loss functions in Python, the way you can in TensorFlow. Sometimes I feel that some functions in MXNet are a black box...

@thirdwing
Contributor

@khalida I have updated the documentation using your example.

# Network config
optimizer <- "rmsprop"
batchSize <- 60
nRounds <- 50
nHidden <- 14
verbose <- FALSE
array.layout <- "rowmajor"

library(mxnet)
data(BostonHousing, package="mlbench")
BostonHousing[, sapply(BostonHousing, is.factor)] <-
  as.numeric(as.character(BostonHousing[, sapply(BostonHousing, is.factor)]))
BostonHousing <- data.frame(scale(BostonHousing))

test.ind = seq(1, 506, 5)    # 1 pt in 5 used for testing
train.x = data.matrix(BostonHousing[-test.ind, -14])
train.y = BostonHousing[-test.ind, 14]
test.x = data.matrix(BostonHousing[test.ind, -14])
test.y = BostonHousing[test.ind, 14]
data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden=nHidden, name="fc1")
tanh1 <- mx.symbol.Activation(fc1, act_type="tanh", name="tanh1")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden=1, name="fc2")
lro <- mx.symbol.LinearRegressionOutput(fc2, name="lro")

mx.set.seed(0)
model <- mx.model.FeedForward.create(lro,
                                     X=train.x, y=train.y,
                                     eval.data=list(data=test.x, label=test.y),
                                     ctx=mx.cpu(), num.round=nRounds,
                                     array.batch.size=batchSize,
                                     eval.metric=mx.metric.rmse,
                                     optimizer=optimizer, verbose=verbose,
                                     array.layout=array.layout)

pred <- predict(model, test.x)

lro2 <- mx.symbol.MakeLoss(mx.symbol.square(mx.symbol.Reshape(fc2, shape = 0) - label), name="lro2")

mx.set.seed(0)
model2 <- mx.model.FeedForward.create(lro2,
                                     X=train.x, y=train.y,
                                     eval.data=list(data=test.x, label=test.y),
                                     ctx=mx.cpu(), num.round=nRounds,
                                     array.batch.size=batchSize,
                                     eval.metric=mx.metric.rmse,
                                     optimizer=optimizer, verbose=verbose,
                                     array.layout=array.layout)


internals = internals(model2$symbol)
fc_symbol = internals[[match("fc2_output", outputs(internals))]]

model3 <- list(symbol = fc_symbol,
               arg.params = model2$arg.params,
               aux.params = model2$aux.params)

class(model3) <- "MXFeedForwardModel"

pred3 <- predict(model3, test.x)

# Plotting of fits
par(mfrow=c(1,2))

# Test fits (both plots use the held-out test set)
plot(test.y, pred[1,], main="nnet Test Fit", xlab="Target", ylab="Response")
abline(0,1, col="red", lwd=2)

plot(test.y, pred3[1,], main="nnet MakeLoss square Test Fit", xlab="Target",  ylab="Response")
abline(0,1, col="red", lwd=2)


The output of mx.symbol.MakeLoss is the gradient of loss with respect to the input data.

So, for now, the eval.metric passed to mx.model.FeedForward.create does not work with MakeLoss during training.
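
To illustrate why the metric breaks, here is a minimal sketch in plain Python. It makes no MXNet calls, the numbers are made up, and the names `fc2_output` and `loss_output` are only illustrative. It assumes, as reported earlier in this thread, that a network ending in a squared-error MakeLoss head hands the metric something derived from the loss rather than the prediction. RMSE computed against that output is not the error of the model, whereas RMSE computed against the fc2 output (which is what the model3 trick above recovers) is.

```python
import math


def rmse(label, pred):
    """Root-mean-squared error between two equal-length sequences."""
    squared = [(l - p) ** 2 for l, p in zip(label, pred)]
    return math.sqrt(sum(squared) / len(squared))


# Hypothetical labels and fc2 predictions, for illustration only.
labels = [1.0, 2.0, 3.0, 4.0]
fc2_output = [1.5, 2.0, 2.5, 4.0]

# Assumption: the loss head emits the per-sample squared error
# instead of the prediction itself.
loss_output = [(p - l) ** 2 for p, l in zip(fc2_output, labels)]

# Comparing labels with loss values: a number, but not a model error.
metric_on_loss_head = rmse(labels, loss_output)

# Comparing labels with predictions: the RMSE that was actually wanted.
metric_on_fc2 = rmse(labels, fc2_output)

print("metric on loss head:", metric_on_loss_head)
print("metric on fc2 output:", metric_on_fc2)
```

Under these assumptions the two numbers differ, which is why the fits above are evaluated only after extracting fc2_output from the trained symbol.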
