
Feat/stratified k fold #95

Merged 11 commits into main on Oct 2, 2023

Conversation

@L-M-Sherlock (Member)

@L-M-Sherlock L-M-Sherlock added the enhancement New feature or request label Oct 1, 2023
Review threads on src/optimal_retention.rs and src/training.rs (outdated, resolved).
@dae (Collaborator) commented Oct 1, 2023

> the progress bar doesn't work as before

We won't be able to use the built-in progress bar - I'll need to rewrite the progress code to handle the fact that training is happening in multiple threads at once.
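For illustration, the multi-threaded progress reporting described above can be sketched with a shared counter; this is a minimal standalone example using only the standard library, not this PR's actual progress code (`run_parallel` is a hypothetical helper):

```rust
// Sketch: several worker threads report progress through a shared
// Arc<Mutex<usize>> counter; one reader could render a combined bar.
use std::sync::{Arc, Mutex};
use std::thread;

fn run_parallel(total: usize) -> usize {
    let progress = Arc::new(Mutex::new(0usize)); // completed work units
    let handles: Vec<_> = (0..total)
        .map(|_| {
            let p = Arc::clone(&progress);
            thread::spawn(move || {
                // ... one split's training work would happen here ...
                *p.lock().unwrap() += 1; // report one unit done
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
    let done = *progress.lock().unwrap();
    done
}

fn main() {
    // All workers finished, so the counter matches the split count.
    assert_eq!(run_parallel(4), 4);
    println!("all splits reported progress");
}
```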

@dae (Collaborator) commented Oct 2, 2023

There is an issue in burn that is preventing us from getting optimum performance. I've sent them a PR: tracel-ai/burn#839

Currently n_splits is hard-coded to 5, which means that on CPUs with more cores, we wouldn't be fully taking advantage of them. Would it make sense to allow splits to go up to 16? If so, I can do that in a follow-up PR that determines the splits based on available CPU cores.
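A core-count-based choice of splits, as suggested here, might look like the sketch below; `n_splits_from_cores` is a hypothetical helper and not part of this PR:

```rust
// Sketch: derive the number of training splits from the available
// CPU cores, clamped to the range 1..=cap.
use std::thread;

fn n_splits_from_cores(cap: usize) -> usize {
    let cores = thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(1); // fall back to a single split if undetectable
    cores.clamp(1, cap)
}

fn main() {
    let n_splits = n_splits_from_cores(16);
    assert!((1..=16).contains(&n_splits));
    println!("n_splits = {n_splits}");
}
```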

Review thread on src/training.rs (outdated, resolved).
@dae (Collaborator) commented Oct 2, 2023

split=16
metric: LogLoss
FSRS-rs mean: 0.3899
metric: RMSE
FSRS-rs mean: 0.3345
metric: RMSE(bins)
FSRS-rs mean: 0.0624

split=5
metric: LogLoss
FSRS-rs mean: 0.3894
metric: RMSE
FSRS-rs mean: 0.3343
metric: RMSE(bins)
FSRS-rs mean: 0.0612

split=1
metric: LogLoss
FSRS-rs mean: 0.3878
metric: RMSE
FSRS-rs mean: 0.3337
metric: RMSE(bins)
FSRS-rs mean: 0.0598

Maybe we could look at the size of the dataset and use that to help decide how many splits? E.g., try to avoid creating splits with fewer than 5,000 items.

@L-M-Sherlock (Member, Author)

> Currently n_splits is hard-coded to 5

Because the Python optimizer uses n_splits=5. I want to compare the two in the benchmark with the same hyperparameters.

@dae (Collaborator) commented Oct 2, 2023

I tried with let n_splits = 16.min(items.len() / 5000).max(1);, and it looks like we get the best of both worlds: better performance on large revlogs, and weights equal to the split=1 case:

metric: LogLoss
FSRS-rs mean: 0.3877
metric: RMSE
FSRS-rs mean: 0.3337
metric: RMSE(bins)
FSRS-rs mean: 0.0597
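For illustration, the one-liner above can be wrapped in a helper and checked against a few dataset sizes (this is a standalone sketch; `choose_n_splits` is a hypothetical name, not the PR's code):

```rust
// Sketch of the suggested heuristic: one split per ~5,000 items,
// clamped to the range 1..=16.
fn choose_n_splits(item_count: usize) -> usize {
    16.min(item_count / 5000).max(1)
}

fn main() {
    assert_eq!(choose_n_splits(3_000), 1); // small revlog: single split
    assert_eq!(choose_n_splits(25_000), 5); // mid-sized: 5 splits
    assert_eq!(choose_n_splits(200_000), 16); // large: capped at 16
    println!("ok");
}
```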

What if we adjusted the Python optimizer to use the same logic for calculating splits, so you could still compare the two? In Anki, I think 16 should be derived from the current CPU's core count, but we could use an env var to force a specific number so the tests are consistent.
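A minimal sketch of that env-var override is below; the variable name FSRS_N_SPLITS and the helper functions are assumptions for illustration, not anything defined in this PR:

```rust
// Sketch: a core-count default with an env-var override, so tests
// can pin the split count for reproducibility.
use std::env;
use std::thread;

fn n_splits_from(override_var: Option<String>) -> usize {
    override_var
        .and_then(|v| v.parse::<usize>().ok()) // forced value, if valid
        .unwrap_or_else(|| {
            thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(1) // default: one split per core
        })
        .clamp(1, 16)
}

fn n_splits() -> usize {
    n_splits_from(env::var("FSRS_N_SPLITS").ok())
}

fn main() {
    assert_eq!(n_splits_from(Some("5".to_string())), 5); // forced to 5
    assert_eq!(n_splits_from(Some("999".to_string())), 16); // clamped
    println!("default n_splits = {}", n_splits());
}
```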

@dae (Collaborator) commented Oct 2, 2023

Also, there's currently one inconsistency: revlog_history-60d332b.json is generated for FSRS-rs but not for the others, for some reason. When it's removed, things improve slightly more:

metric: LogLoss
FSRS-rs mean: 0.3859
metric: RMSE
FSRS-rs mean: 0.3328
metric: RMSE(bins)
FSRS-rs mean: 0.0587

@L-M-Sherlock (Member, Author)

Here is the result of the latest commit:

Model: FSRS-rs
Total number of users: 71
Total number of reviews: 4632965
metric: LogLoss
FSRS-rs mean: 0.3858
metric: RMSE
FSRS-rs mean: 0.3327
metric: RMSE(bins)
FSRS-rs mean: 0.0594

I will test let n_splits = 16.min(items.len() / 5000).max(1);.

@dae (Collaborator) commented Oct 2, 2023

It seems a cutoff of 1,000 reviews is not enough:

Model: FSRS-rs
Total number of users: 71
Total number of reviews: 4632965
metric: LogLoss
FSRS-rs mean: 0.3874
metric: RMSE
FSRS-rs mean: 0.3334
metric: RMSE(bins)
FSRS-rs mean: 0.0613

Will try with 10,000 for comparison.

@dae (Collaborator) commented Oct 2, 2023

By the way, AnkiWeb's terms and conditions allow us to extract revlog data from AnkiWeb users. I could potentially build a dataset much bigger than 71 if that were helpful (though it may take me a while to do).

@dae (Collaborator) commented Oct 2, 2023

10,000 review split:

Total number of users: 71
Total number of reviews: 4632965
metric: LogLoss
FSRS-rs mean: 0.3858
metric: RMSE
FSRS-rs mean: 0.3327
metric: RMSE(bins)
FSRS-rs mean: 0.0586

So it seems 5k is a reasonable cutoff?

@L-M-Sherlock (Member, Author)

Weird. The benchmark usually takes me about 15 minutes, but this time it took almost an hour.

@dae (Collaborator) commented Oct 2, 2023

The change I suggested also needs to be applied to the non-test case:

diff --git a/src/training.rs b/src/training.rs
index 14e4d03..a8c59c7 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -18,7 +18,7 @@ use burn::train::{ClassificationOutput, TrainOutput, TrainStep, TrainingInterrup
 use burn::{config::Config, module::Param, tensor::backend::ADBackend, train::LearnerBuilder};
 use core::marker::PhantomData;
 use log::info;
-use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
+use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator, IntoParallelIterator};
 use std::path::Path;
 use std::sync::{Arc, Mutex};

@@ -206,7 +206,8 @@ impl<B: Backend> FSRS<B> {
         items: Vec<FSRSItem>,
         progress: Option<Arc<Mutex<ProgressState>>>,
     ) -> Result<Vec<f32>> {
-        let n_splits = 5;
+        let n_splits = 16.min(items.len() / 10000).max(1);
+        dbg!(n_splits);
         let average_recall = calculate_average_recall(&items);
         let (pre_trainset, trainsets) = split_data(items, n_splits);
         let initial_stability = pretrain(pre_trainset, average_recall)?;
@@ -218,24 +219,17 @@ impl<B: Backend> FSRS<B> {
             AdamConfig::new(),
         );

-        let mut weights_sets: Vec<Vec<f32>> = Vec::new();
-
-        for i in 0..n_splits {
-            let trainset = trainsets
-                .par_iter()
-                .enumerate()
-                .filter(|&(j, _)| j != i)
-                .flat_map(|(_, trainset)| trainset.clone())
-                .collect();

+        let weights_sets: Vec<Vec<f32>> = (0..n_splits).into_par_iter().map(|i| {
+            let trainset = trainsets[i].clone();
             let model = train::<ADBackendDecorator<B>>(
                 trainset,
                 &config,
                 self.device(),
                 progress.clone().map(ProgressCollector::new),
             );
-            weights_sets.push(model?.w.val().to_data().convert().value)
-        }
+            model.unwrap().w.val().to_data().convert().value
+        }).collect();

         let average_weights = weights_sets
             .iter()

@dae (Collaborator) commented Oct 2, 2023

The 10000 in the diff above should be 5000.

@dae (Collaborator) commented Oct 2, 2023

Oh, and you also need to be using the changes in tracel-ai/burn#839 to get optimum performance. It might be worth waiting a day or two to see if that gets merged.

@L-M-Sherlock (Member, Author)

> +            let trainset = trainsets[i].clone();

It's inconsistent with the previous code: the old loop trained split i on every fold except fold i, whereas this trains on fold i alone.

@dae (Collaborator) commented Oct 2, 2023

#95 (comment)

@L-M-Sherlock (Member, Author) commented Oct 2, 2023

KFold divides all the samples into $k$ groups of samples, called folds (if $k=n$, this is equivalent to the Leave-One-Out strategy), of equal size where possible. The prediction function is learned using $k-1$ folds, and the fold left out is used for testing.
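A minimal sketch of plain k-fold index partitioning, for illustration only (this PR's stratified variant additionally balances label proportions across folds, which this sketch omits; `k_fold` is a hypothetical helper):

```rust
// Sketch: partition item indices into k folds; the training set for
// fold i is every fold except fold i, and fold i itself is the test set.
fn k_fold(n_items: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    // Assign index idx to fold idx % k, giving near-equal fold sizes.
    let folds: Vec<Vec<usize>> = (0..k)
        .map(|i| (0..n_items).filter(|idx| idx % k == i).collect())
        .collect();
    (0..k)
        .map(|i| {
            let test = folds[i].clone();
            let train = (0..k)
                .filter(|&j| j != i)
                .flat_map(|j| folds[j].iter().copied())
                .collect();
            (train, test)
        })
        .collect()
}

fn main() {
    let splits = k_fold(10, 5);
    assert_eq!(splits.len(), 5);
    for (train, test) in &splits {
        assert_eq!(train.len(), 8); // k-1 folds of 2 items each
        assert_eq!(test.len(), 2); // the held-out fold
    }
    println!("5 folds, train=8/test=2 each");
}
```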

@dae (Collaborator) commented Oct 2, 2023

I'm afraid I don't have enough domain knowledge to know what you mean by that. I also noticed that we're not using any verification items when training - is that related?

@L-M-Sherlock (Member, Author)

The verification items will be used for selecting the best weights when we implement #88 (comment).

@L-M-Sherlock (Member, Author)

n_splits = 4:

Model: FSRS-rs
Total number of users: 71
Total number of reviews: 4632965
metric: LogLoss
FSRS-rs mean: 0.3859
metric: RMSE
FSRS-rs mean: 0.3327
metric: RMSE(bins)
FSRS-rs mean: 0.0593

@L-M-Sherlock (Member, Author) commented Oct 2, 2023

pub(crate) struct TrainingConfig {
    pub model: ModelConfig,
    pub optimizer: AdamConfig,
    #[config(default = 5)]
    pub num_epochs: usize,
    #[config(default = 512)]
    pub batch_size: usize,
    #[config(default = 1)]
    pub num_workers: usize,
    #[config(default = 42)]
    pub seed: u64,
    #[config(default = 4e-2)]
    pub learning_rate: f64,
}

n_splits = 5

Model: FSRS-rs
Total number of users: 71
Total number of reviews: 4632965
metric: LogLoss
FSRS-rs mean: 0.3853
metric: RMSE
FSRS-rs mean: 0.3325
metric: RMSE(bins)
FSRS-rs mean: 0.0582

It seems to be the best result.

@asukaminato0721 asukaminato0721 merged commit 0aac1cf into main Oct 2, 2023
3 checks passed
@asukaminato0721 asukaminato0721 deleted the Feat/StratifiedKFold branch October 2, 2023 13:39
L-M-Sherlock added a commit to open-spaced-repetition/srs-benchmark that referenced this pull request Oct 11, 2023