Use Ensmallen Callbacks To train LSTM and Fit only on training data. #56
Conversation
Earlier predicted values were: (screenshot omitted) and (screenshot omitted). Nearly identical values were obtained (these may vary from machine to machine), so I think the changes made so far (after limiting epochs) haven't really changed the results.
After simplifying the save-results function, the new results are: (screenshot omitted) and (screenshot omitted). These are similar to the earlier values.
Hi @zoq, could you take a look at this? I think this is ready. I have repeated the tests twice to ensure that this works.
Hey @kartikdutt18, thanks for taking the time to work on this one. The changes look good to me. I have a handful of comments but they're all pretty minor. Let me know if I can clarify any of them. 👍
LSTM/TimeSeries-Multivariate/src/LSTMTimeSeriesMultivariate.cpp
    ens::PrintLoss(),
    ens::ProgressBar(),
    ens::EarlyStopAtMinLoss(),
    ens::StoreBestCoordinates<arma::mat>());
Does `StoreBestCoordinates` do anything here? It looks like you are creating a temporary `StoreBestCoordinates()` object to pass to the optimizer, but then there is no way to get the result out since it is temporary. Take a look at the ensmallen callback documentation; I think it has a nice example of using an instantiated `StoreBestCoordinates` callback.
Hmm, I'm removing it for now; I don't think this example warrants `StoreBestCoordinates()`. However, I'd like to use it locally to understand it a bit more clearly.
Sounds good, I think it's fine with or without `StoreBestCoordinates()`.
LSTM/TimeSeries-Multivariate/src/LSTMTimeSeriesMultivariate.cpp
Hi @rcurtin, thanks for the review. I have made the changes you suggested.
    dataset.n_cols - 1));

// Number of epochs for training.
const int EPOCHS = 500;
So, just a minor note, and whether or not you handle it is up to you. Previously, we were doing 500 epochs, but they weren't really epochs because we were only looking at 1000 points in each "epoch". Now we're using the whole dataset, and since one epoch now is equivalent to an entire pass (not just 1000 points), 500 epochs is a lot more training. I wonder if maybe 100 is a better number here. 👍
I can make the changes, but I would have to test them first. Once I do, I'll post a comment with the results. I hope that's okay.
Sure, no problem. If 500 epochs are actually necessary we can go with that, I just figured it would generally terminate way before then.
On my laptop, it took 107 epochs to stop, so I have changed it to 150 (a nice round number). If needed I can reduce it further. Thanks for the suggestion.
Thanks @kartikdutt18! I only have a few comments left. If you want to handle them before merge (or let me know what you think), that would be awesome. 👍
// Progressbar Callback prints progress bar for each epoch.
ens::ProgressBar(),
// Stops the optimization process if the loss stops decreasing
// or no improvement has been made. Useful in preventing overfitting.
This is a really pedantic comment, but I think it's worth making: actually, this is not useful for preventing overfitting in the way that people might expect. Normally you might terminate the optimization when the loss on a validation set starts increasing (while the training set loss keeps going down). However, `EarlyStopAtMinLoss()` only considers the training set, so it will terminate if the training error starts going up. In my opinion the argument that this is useful for preventing overfitting is a little weak; a more accurate thing to say might be that this will terminate the optimization once we hit a minimum.
Hmm, you are absolutely right, this will only get to a minimum on the training data. Nice, I never looked at it that way. Thanks.

// Don't reset optimizer's parameters between cycles.
optimizer.ResetPolicy() = false;
// Use Early Stopping criteria to stop training.
Another pedantic comment: this line itself doesn't specify that early stopping should be used; that's the callback. I think a more effective comment might be:

// Instead of terminating based on the tolerance of the objective function, we'll depend on
// the maximum number of iterations, and terminate early using the EarlyStopAtMinLoss
// callback.

(I didn't check if those lines fit in 80 characters, so you might need to reflow them if you want to use that text directly. Feel free to adapt the wording if you want to improve it. 👍)
Great, I will make those changes.
Second approval provided automatically after 24 hours. 👍
Hi @rcurtin, since I made some of the changes you suggested, I tested again and got the following results: (screenshot omitted). Similarly, the univariate results were: (screenshot omitted).
Hi, I have made some changes that I think might be better than what we were doing. Kindly have a look; I'll be more than happy to learn more about them or revert them. Also, sorry for the delay. Thanks for all the help.
@@ -57,7 +58,7 @@ double MSE(arma::cube& pred, arma::cube& Y)
    arma::mat temp = diff.slice(i);
    err_sum += accu(temp % temp);
  }
  return (err_sum) / (diff.n_elem + 1e-50);
Hi @rcurtin, is there a reason why we need to do this? I don't think `diff.n_elem` can ever be zero.
mlpack has an L2 distance in `core/metric`; would it be okay to switch to that rather than using a for loop to calculate it? I think writing our own L2 is a bit redundant. What do you think?
You're right, this is a lot cleaner. Thanks! 👍
And yeah, I agree, `diff.n_elem` should never be 0, so the `+ 1e-50` shouldn't ever be necessary.
y.set_size(outputSize, dataset.n_cols - rho + 1, rho);
// Split the dataset into training and validation sets.
arma::mat trainData, testData;
data::Split(dataset, trainData, testData, RATIO);
I have also switched to the `Split()` function from mlpack. I hope that's okay.
Hmm, actually I'm not sure on this one. The `Split()` function shuffles the data during splitting, but for a time series problem like this it makes more sense to keep the first part of the data as the training set and the last part as the test set. Unfortunately I don't see any option to pass to `Split()` to avoid shuffling (maybe it would be good to add one? I don't know how important it would be), so I think we should revert this bit, and then I can go ahead and merge it. 👍
Yes, you are right; I recently learned about this when I went through the data code while creating a DataLoader. This should definitely be reverted. Thanks for pointing it out.
I have made those changes and re-tested the number of epochs. Thanks for all the help.
Awesome, thanks. I'll go ahead and merge it once it passes the tests. 👍
Could you also help me answer this question: I was creating a DataLoader, and for time series analysis, since the input is transformed into the output, does it make sense to fit only on the training input rather than all of the training data? We would, however, transform both the training input and output.
> Unfortunately I don't see any option to pass to `Split()` to avoid shuffling (maybe it would be good to add one? I don't know how important it would be)

I think an optional shuffle parameter would be a good idea. It should be simple to resolve. Would it be okay for me to open a PR for this?
Hmm, I think this test failure is related to the cmake changes added in mlpack and unrelated to the models repo. I might be missing something, though; I'll take a closer look just in case it is related to this PR.
Yes, I think that due to mlpack#2247, the patch ...
Oh, wait, sorry, I didn't see #58, which solves that issue. Maybe we can just merge that, rebase this branch, and it should work. :)
Agreed, that would be great. Thanks.
Force-pushed 5da4a76 to 9f0e92f (commits: Simplify Save Results; Changed path for data, removed unnecessary header and added comment for callbacks; Removed optimizer reset; Changed epochs, better comments; Changed to internal split, reduced epochs; Changed path; Switch to L2 in MSE).
Hi, I have rebased this. I think this should pass the tests now. Thanks.
I think this fits well with the other changes from models->examples, or at least doesn't conflict with them. If we're reworking the LSTM examples, I think it would make more sense to do so after this is merged rather than before.
Agreed, that makes sense. This is ready from my side. Thanks a lot.
Oops, I didn't realize that this wasn't merged before the repository split! So I merged it here, then cherry-picked the commits into the ...
No worries, I think the models repo will be restructured with mlpack/models#3, so I'll try to take care of anything that is needed there. Thanks a lot @rcurtin, @birm for the helpful reviews and comments.
Hi everyone,
This closes #41, closes #42 and also closes #43.
I have used ensmallen callbacks instead of a for loop to train the LSTMs.
Another good idea that @shrit suggested is to use mlpack's CLI functionality to improve access to models from command line.
This has been built and locally tested.
Thanks.