
Implementation - Padé Approximant in Log Softmax Layer #3685

Open · wants to merge 2 commits into base: master

Conversation

MarkFischinger
Contributor

Following our discussion in issue #3662, I've implemented the Padé approximant in the log softmax layer. Due to time constraints I haven't run the tests yet, but I plan to do so shortly and will update you with the results.

@MarkFischinger MarkFischinger changed the title Implementation - Pade Approximant in Log Softmax Layer Implementation - Padé Approximant in Log Softmax Layer Apr 10, 2024
Member

@shrit shrit left a comment


Would you use mlpack style for variables?

@MarkFischinger
Contributor Author

@shrit, thank you for pointing that out. The new commit includes the fix :)

@shrit
Member

shrit commented Apr 11, 2024

I approved this one too quickly; I did not see that the tests were not passing.
@MarkFischinger could you try to run the tests locally?
It would also be nice to compare the matrices generated by the original fast method and by Padé, because I think there is a good amount of difference between them; otherwise the tests would not have failed.

Member

@rcurtin rcurtin left a comment


Just preventing mlpack-bot from auto-approving until we get the fixes worked out. I'm guessing that the level of approximation might be too high, and things are not converging? (Maybe a threshold like x < 13 is needed?)

output.transform([padeApproximant](double x) {
  return padeApproximant(x);
});
Member


I think it would be a bit cleaner to just inline the whole approximant into the lambda, but it's up to you.
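Roughly something like this (untested sketch; I'm assuming the approximant body has the same form as the scaled version discussed later):

output.transform([](double x)
{
  // Same rational approximation of e^(-x), written inline in the lambda.
  const double num = 24 - 12 * x + 4 * x * x - x * x * x;
  const double den = 24 + 12 * x + 4 * x * x + x * x * x;
  return num / den;
});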

@MarkFischinger
Contributor Author

@shrit @rcurtin Sorry for the delay with the benchmarks; I needed to run a more detailed analysis to find an effective solution.
Here’s what I found:

MNIST simple

Old implementation:
Validation loss: 567.577
Duration: 79709 ms
Loss: 0.0331713
Accuracy: train = 98.463%, valid = 97.0707%

New implementation (scale 4, x < 13.0):
Validation loss: 563.98
Duration: 80412 ms
Loss: 0.049433
Accuracy: train = 98.3571%, valid = 96.9517%

The initial idea of adding only the x < 13.0 cutoff proved too broad, leading to uncontrolled error spikes from the large $X$ values I had been concerned about. The example in the discussion issue featured only small $X$ values, which the Padé approximation handles perfectly, but large values (above 8.0) do not work nearly as well. By scaling $X$ by $4$, however, I reduced the error to notably below that of the old version, as you can see in this graph:

[Graph: errors_and_time_4_4_fair — error and runtime comparison between the old implementation and the scaled Padé variant]

Despite the graph showing a seemingly doubled runtime, the actual difference in the CNN run is minor. Could this improvement be a viable option for implementation? What do you think?
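For context, the scaling step relies on the identity

$$e^{-x} = \left(e^{-x/s}\right)^{s},$$

so the approximant only ever sees the reduced argument $x/s$, where it stays accurate, and the result is raised back to the power $s$ afterwards: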

auto scaledPadeApproxExpMinusX = [](double x)
{
  if (x < 13.0)
  {
    // Shrink the argument so the approximant stays in its accurate range;
    // e^(-x) = (e^(-x/s))^s undoes the scaling afterwards.
    const double s = 4.0;
    const double xs = x / s;

    // Rational approximation of e^(-xs).
    const double numerator = 24 - 12 * xs + 4 * xs * xs - xs * xs * xs;
    const double denominator = 24 + 12 * xs + 4 * xs * xs + xs * xs * xs;

    const double pade = numerator / denominator;
    return std::pow(pade, s);
  }

  // For x >= 13, e^(-x) is small enough to treat as zero.
  return 0.0;
};

output.transform([scaledPadeApproxExpMinusX](double x) {
  return scaledPadeApproxExpMinusX(x);
});
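If anyone wants to double-check the accuracy outside of mlpack, a small standalone comparison against std::exp could look like this (throwaway sketch, not part of the PR):

#include <cmath>
#include <cstdio>

int main()
{
  // Same scaled approximant as above.
  auto scaledPadeApproxExpMinusX = [](double x)
  {
    if (x < 13.0)
    {
      const double s = 4.0;
      const double xs = x / s;
      const double num = 24 - 12 * xs + 4 * xs * xs - xs * xs * xs;
      const double den = 24 + 12 * xs + 4 * xs * xs + xs * xs * xs;
      return std::pow(num / den, s);
    }
    return 0.0; // e^(-x) is negligible for x >= 13.
  };

  // Print approximate vs. exact values across the range of interest.
  for (double x = 0.0; x <= 14.0; x += 2.0)
  {
    std::printf("x = %5.1f  pade = %.6e  exact = %.6e\n",
                x, scaledPadeApproxExpMinusX(x), std::exp(-x));
  }
  return 0;
}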

I think I will also test the algorithm on mnist_cnn soon.

@shrit
Member

shrit commented Apr 16, 2024

@MarkFischinger give it a try. What I find weird is that when we tested this separately, it was way faster, while here it looks much slower than the original one.
This is worth investigating.

@MarkFischinger
Contributor Author

Hey @shrit, I think the trouble we're seeing comes from the higher $X$ values in the MNIST examples. Originally, we only saw $X$ values above $4$ about $2.275$% of the time, based on our normal distribution setup with arma::mat output = arma::randn(1000, 1000, arma::distr_param(0, 2)). But the MNIST data frequently produced much higher $X$ values, which caused those spikes and ultimately broke the code. That's why I had to scale them down, which slowed our runtime a bit.

Here are some of the $X$ values from the MNIST example (output):

x = 8.78431
x = 23.1975
x = 22.0821
x = 16.2784
x = 18.5682
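
(As a sanity check on the $2.275$% figure above, here's a quick snippet to measure how often the threshold is exceeded under that randn setup; just a sketch, not part of the PR:)

#include <armadillo>
#include <cstdio>

int main()
{
  // Mean 0, standard deviation 2, as in the original test setup.
  arma::mat output = arma::randn(1000, 1000, arma::distr_param(0, 2));

  // Fraction of entries above 4; should land near P(Z > 2) ~= 0.02275.
  const double frac = (double) arma::accu(output > 4.0) / output.n_elem;
  std::printf("fraction above 4: %f\n", frac);
  return 0;
}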

I'm thinking that since lower $X$ values are more common and are handled better, maybe we should use the original Padé approximation for values up to, say, $4$, and keep our current scaled method as a fallback for anything higher (something like the sketch below). This way we handle the typical cases fast and still catch the outliers without problems. What do you think? Should I run some benchmarks on this mixed approach?
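
Roughly, the mixed approach I have in mind would look like this (untested sketch; the scale of $4$ and the x < 13.0 cutoff are carried over from above):

auto hybridApproxExpMinusX = [](double x)
{
  if (x <= 4.0)
  {
    // Plain approximant: accurate and cheap for the common small-x case.
    const double num = 24 - 12 * x + 4 * x * x - x * x * x;
    const double den = 24 + 12 * x + 4 * x * x + x * x * x;
    return num / den;
  }
  else if (x < 13.0)
  {
    // Fallback for larger x: the scaled variant from above.
    const double s = 4.0;
    const double xs = x / s;
    const double num = 24 - 12 * xs + 4 * xs * xs - xs * xs * xs;
    const double den = 24 + 12 * xs + 4 * xs * xs + xs * xs * xs;
    return std::pow(num / den, s);
  }
  return 0.0; // e^(-x) is negligible beyond the cutoff.
};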

@rcurtin
Member

rcurtin commented Apr 16, 2024

Yeah, a switch to the existing implementation at about x > 4 would probably do the trick for convergence too. I would be interested to see whether it would be faster as well, although to check that you'd need to ensure that the number of epochs used for training is constant (or just time a single epoch, that's fine too).

The scaling trick is definitely a good one for convergence, but I suspect the std::pow call is expensive and is what causes it to be slower.
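
If that is the culprit, one cheap thing to try (just a guess on my part, untested) would be to exploit the fact that the scale is a fixed integer, so std::pow(pade, 4.0) becomes two multiplications:

// Hypothetical replacement for std::pow(pade, s) with fixed s = 4:
// squaring twice computes pade^4 without the general pow() machinery.
const double p2 = pade * pade;
return p2 * p2;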

@MarkFischinger
Contributor Author

@rcurtin I did some backtesting, and unfortunately the results showed no or only minor improvements in runtime. Statistically, combining those two algorithms should reduce the error, but I'm still looking for faster implementations because I'm hopeful I can find a better solution. I'll update you as soon as possible, though my available time will be limited for the next few days due to exams :/
