Feat/stratified k fold #95
Conversation
We won't be able to use the built-in progress bar; I'll need to rewrite the progress code to handle the fact that training happens in multiple threads at once.
There is an issue in burn that is preventing us from getting optimum performance; I've sent them a PR: tracel-ai/burn#839. Currently n_splits is hard-coded to 5, which means that on CPUs with more cores, we wouldn't be taking full advantage of them. Would it make sense to allow splits to go up to 16? If so, I can do that in a follow-up PR that determines the splits based on available CPU cores.
split=16, split=5, split=1
Maybe we could look at the size of the dataset and use that to help decide how many splits? E.g., try to avoid creating splits with fewer than 5,000 items.
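The two ideas above (cap the split count at 16, and keep each split above roughly 5,000 items) could be combined in a small helper. This is only a sketch; `choose_n_splits` and both constants are illustrative, not part of the codebase:

```rust
// Hypothetical helper combining the two suggestions in this thread:
// cap n_splits at a maximum (e.g. 16), and avoid creating splits
// with fewer than ~5,000 items each.
fn choose_n_splits(item_count: usize, max_splits: usize) -> usize {
    const MIN_ITEMS_PER_SPLIT: usize = 5_000;
    (item_count / MIN_ITEMS_PER_SPLIT).clamp(1, max_splits)
}

fn main() {
    // 100k items comfortably fill 16 splits; 12k items only justify 2.
    println!("{}", choose_n_splits(100_000, 16)); // 20, capped to 16
    println!("{}", choose_n_splits(12_000, 16)); // 2
    println!("{}", choose_n_splits(1_000, 16)); // small datasets fall back to 1
}
```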
Because the py optimizer applies
I tried with: metric: LogLoss
What if we adjusted the Python optimizer to use the same logic for calculating splits, so you could still compare the two? In Anki, I think 16 should be detected from the current CPU's core count, but we could use an env var to force it to a specific number so the tests are consistent.
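The env-var override mentioned above could look roughly like this. `FSRS_N_SPLITS` is an assumed variable name (nothing in the thread fixes it), and the override is passed in explicitly so the selection logic stays easy to test:

```rust
// Sketch: "FSRS_N_SPLITS" is a hypothetical env var name. If set and
// parseable, it forces a fixed split count so test runs are
// reproducible; otherwise fall back to the detected core count,
// capped at 16 as suggested in the thread.
fn n_splits(env_override: Option<&str>, detected_cores: usize) -> usize {
    env_override
        .and_then(|v| v.parse().ok())
        .unwrap_or_else(|| detected_cores.min(16))
}

fn main() {
    let forced = std::env::var("FSRS_N_SPLITS").ok();
    println!("{}", n_splits(forced.as_deref(), 8));
}
```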
Also, there's currently one inconsistency: revlog_history-60d332b.json is generated for FSRS-rs but not for the others, for some reason. When it's removed, things improve slightly more: metric: LogLoss
Here is the result of the latest commit: Model: FSRS-rs. I will test
It seems 1,000 reviews is not enough: Model: FSRS-rs. Will try with 10,000 for comparison.
By the way, AnkiWeb's terms and conditions allow us to extract revlog data from AnkiWeb users. I could potentially build a dataset much bigger than 71 users if that would be helpful (though it may take me a while).
10,000 review split: Total number of users: 71. So it seems 5k is a reasonable cutoff?
Weird. Usually the benchmark only takes me 15 minutes, but this time it took almost an hour.
The change I suggested also needs to be applied to the non-test case:

```diff
diff --git a/src/training.rs b/src/training.rs
index 14e4d03..a8c59c7 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -18,7 +18,7 @@ use burn::train::{ClassificationOutput, TrainOutput, TrainStep, TrainingInterrup
 use burn::{config::Config, module::Param, tensor::backend::ADBackend, train::LearnerBuilder};
 use core::marker::PhantomData;
 use log::info;
-use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
+use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator, IntoParallelIterator};
 use std::path::Path;
 use std::sync::{Arc, Mutex};
@@ -206,7 +206,8 @@ impl<B: Backend> FSRS<B> {
         items: Vec<FSRSItem>,
         progress: Option<Arc<Mutex<ProgressState>>>,
     ) -> Result<Vec<f32>> {
-        let n_splits = 5;
+        let n_splits = 16.min(items.len() / 10000).max(1);
+        dbg!(n_splits);
         let average_recall = calculate_average_recall(&items);
         let (pre_trainset, trainsets) = split_data(items, n_splits);
         let initial_stability = pretrain(pre_trainset, average_recall)?;
@@ -218,24 +219,17 @@ impl<B: Backend> FSRS<B> {
             AdamConfig::new(),
         );
-        let mut weights_sets: Vec<Vec<f32>> = Vec::new();
-
-        for i in 0..n_splits {
-            let trainset = trainsets
-                .par_iter()
-                .enumerate()
-                .filter(|&(j, _)| j != i)
-                .flat_map(|(_, trainset)| trainset.clone())
-                .collect();
+        let weights_sets: Vec<Vec<f32>> = (0..n_splits).into_par_iter().map(|i| {
+            let trainset = trainsets[i].clone();
             let model = train::<ADBackendDecorator<B>>(
                 trainset,
                 &config,
                 self.device(),
                 progress.clone().map(ProgressCollector::new),
             );
-            weights_sets.push(model?.w.val().to_data().convert().value)
-        }
+            model.unwrap().w.val().to_data().convert().value
+        }).collect();
         let average_weights = weights_sets
             .iter()
```
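The diff is cut off at the averaging step. Assuming a plain equal-weight average across the per-fold weight vectors (an assumption; the truncated code might differ), the logic would be something like:

```rust
// Sketch (assumed, since the diff above is truncated): average each
// weight position across the per-fold weight vectors, giving every
// fold equal weight.
fn average_weights(weights_sets: &[Vec<f32>]) -> Vec<f32> {
    let n = weights_sets.len() as f32;
    let len = weights_sets[0].len();
    (0..len)
        .map(|k| weights_sets.iter().map(|w| w[k]).sum::<f32>() / n)
        .collect()
}

fn main() {
    let sets = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    println!("{:?}", average_weights(&sets)); // [2.0, 3.0]
}
```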
The `10000` should be `5000`.
Oh, and you also need to be using the changes in tracel-ai/burn#839 to get optimum performance. It might be worth waiting a day or two to see if that gets merged.
It's inconsistent with the previous code.
KFold divides all the samples into k groups of samples, called folds; the prediction function is learned using k − 1 folds, and the fold left out is used for testing.
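The k-fold convention being referred to (train on every fold except the held-out one, which is what the earlier `trainsets[i].clone()` suggestion accidentally dropped) can be sketched like this; the fold contents are illustrative:

```rust
// Standard k-fold training-set assembly: fold `held_out` is excluded,
// and the remaining folds are concatenated into one training set.
fn kfold_trainset(folds: &[Vec<u32>], held_out: usize) -> Vec<u32> {
    folds
        .iter()
        .enumerate()
        .filter(|&(j, _)| j != held_out)
        .flat_map(|(_, fold)| fold.iter().copied())
        .collect()
}

fn main() {
    let folds = vec![vec![1, 2], vec![3, 4], vec![5, 6]];
    // Holding out fold 1 leaves folds 0 and 2 for training.
    println!("{:?}", kfold_trainset(&folds, 1)); // [1, 2, 5, 6]
}
```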
I'm afraid I don't have enough domain knowledge to know what you mean by that. I also noticed that we're not using any verification items when training - is that related?
The verification items will be used for selecting the best weights when we implement #88 (comment)
n_splits = 4: Model: FSRS-rs
```rust
pub(crate) struct TrainingConfig {
    pub model: ModelConfig,
    pub optimizer: AdamConfig,
    #[config(default = 5)]
    pub num_epochs: usize,
    #[config(default = 512)]
    pub batch_size: usize,
    #[config(default = 1)]
    pub num_workers: usize,
    #[config(default = 42)]
    pub seed: u64,
    #[config(default = 4e-2)]
    pub learning_rate: f64,
}
```

n_splits = 5: Model: FSRS-rs

It seems to be the best result.
Port https://github.com/open-spaced-repetition/fsrs-optimizer/blob/fa63f61630a3e6a347e0576fc46b60e4561ac15e/src/fsrs_optimizer/fsrs_optimizer.py#L1049-L1082
also #88 (comment)
Consider speeding it up via #91 (comment)
By the way, the progress bar doesn't work as it did before. @dae