Run retention calculation in parallel (#91)
* Run retention calculation in parallel

Roughly a 10x speed-up if you have enough CPU cores (see the sketch below).

* Use DEFAULT_WEIGHTS in more places
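
The speed-up comes from fanning the per-seed simulation runs out across cores with rayon's parallel iterators. A minimal sketch of that pattern, with a stand-in expensive_simulation in place of the crate's simulate:

    use rayon::prelude::*;

    // Stand-in for a CPU-bound simulation; any pure function works here.
    fn expensive_simulation(seed: u64) -> f64 {
        (0..1_000_000u64).map(|i| ((seed + i) as f64).sin()).sum()
    }

    fn main() {
        // Sequential: one core does all eight runs.
        let seq: f64 = (0..8u64).map(expensive_simulation).sum();
        // Parallel: rayon distributes the runs across its thread pool.
        let par: f64 = (0..8u64).into_par_iter().map(expensive_simulation).sum();
        // The totals agree up to floating-point summation order.
        println!("{seq} {par}");
    }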
dae committed Sep 30, 2023
1 parent fa39f0b commit 5d67e1c
Showing 5 changed files with 44 additions and 27 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -27,6 +27,7 @@ log = "0.4"
 ndarray = "0.15.6"
 ndarray-rand = "0.14.0"
 rand = "0.8.5"
+rayon = "1.8.0"
 serde = "1.0.188"
 snafu = "0.7.5"
 strum = { version = "0.25.0", features = ["derive"] }
2 changes: 1 addition & 1 deletion src/inference.rs
@@ -16,7 +16,7 @@ use burn::tensor::ElementConversion;
 /// This is a slice for efficiency, but should always be 17 in length.
 pub type Weights = [f32];
 
-pub static DEFAULT_WEIGHTS: &[f32] = &[
+pub static DEFAULT_WEIGHTS: [f32; 17] = [
     0.4, 0.6, 2.4, 5.8, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26, 0.29,
     2.61,
 ];
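
Typing DEFAULT_WEIGHTS as a sized array makes the length part of the type; call sites then borrow typed sub-arrays or convert to a slice explicitly. A small sketch of the two conversions the rest of the commit relies on, using a local WEIGHTS stand-in:

    static WEIGHTS: [f32; 17] = [
        0.4, 0.6, 2.4, 5.8, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26,
        0.29, 2.61,
    ];

    fn takes_slice(w: &[f32]) -> usize {
        w.len()
    }

    fn main() {
        // An array converts to a slice explicitly where &[f32] is expected...
        assert_eq!(takes_slice(WEIGHTS.as_slice()), 17);
        // ...and a length-checked sub-array can be carved out of a slice.
        let initial = <[f32; 4]>::try_from(&WEIGHTS[0..4]).unwrap();
        assert_eq!(initial, [0.4, 0.6, 2.4, 5.8]);
    }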
14 changes: 4 additions & 10 deletions src/model.rs
@@ -43,7 +43,7 @@ impl<B: Backend> Model<B> {
     pub fn new(config: ModelConfig) -> Self {
         let initial_params = config
             .initial_stability
-            .unwrap_or([0.4, 0.6, 2.4, 5.8])
+            .unwrap_or(<[f32; 4]>::try_from(&DEFAULT_WEIGHTS[0..4]).unwrap())
             .into_iter()
             .chain([
                 4.93, 0.94, 0.86, 0.01, // difficulty
@@ -216,7 +216,7 @@ impl<B: Backend> FSRS<B> {
         ) -> Result<FSRS<B2>> {
         if let Some(weights) = &mut weights {
             if weights.is_empty() {
-                *weights = DEFAULT_WEIGHTS
+                *weights = DEFAULT_WEIGHTS.as_slice()
             } else if weights.len() != 17 {
                 return Err(FSRSError::InvalidWeights);
             }
@@ -257,13 +257,7 @@ mod tests {
     #[test]
     fn w() {
         let model = Model::new(ModelConfig::default());
-        assert_eq!(
-            model.w.val().to_data(),
-            Data::from([
-                0.4, 0.6, 2.4, 5.8, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34,
-                1.26, 0.29, 2.61
-            ])
-        )
+        assert_eq!(model.w.val().to_data(), Data::from(DEFAULT_WEIGHTS))
     }
 
     #[test]
@@ -370,6 +364,6 @@ mod tests {
     fn fsrs() {
         assert!(FSRS::new(Some(&[])).is_ok());
         assert!(FSRS::new(Some(&[1.])).is_err());
-        assert!(FSRS::new(Some(DEFAULT_WEIGHTS)).is_ok());
+        assert!(FSRS::new(Some(DEFAULT_WEIGHTS.as_slice())).is_ok());
     }
 }
53 changes: 37 additions & 16 deletions src/optimal_retention.rs
@@ -11,6 +11,9 @@ use rand::{
     rngs::StdRng,
     SeedableRng,
 };
+use rayon::iter::IntoParallelIterator;
+use rayon::iter::ParallelIterator;
+use std::sync::{Arc, Mutex};
 use strum::EnumCount;
 
 #[derive(Debug, EnumCount)]
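
These imports support sharing the mutable progress callback with rayon's worker threads: a FnMut closure can't be called from several threads at once, so it goes behind Arc<Mutex<_>> and each worker locks around the call. A condensed sketch of the idiom (the counter and names are illustrative, not the crate's API):

    use rayon::prelude::*;
    use std::sync::{Arc, Mutex};

    fn main() {
        let mut calls = 0u32;
        // The FnMut closure mutates captured state, so wrap it for shared use.
        let tick = Arc::new(Mutex::new(move || {
            calls += 1;
            calls // pretend this reports progress
        }));
        let done: Vec<u32> = (0..8u32)
            .into_par_iter()
            .map(|i| i + tick.lock().unwrap()()) // lock, call, unlock per item
            .collect();
        assert_eq!(done.len(), 8);
    }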
@@ -348,10 +351,10 @@ impl<B: Backend> FSRS<B> {
         mut progress: F,
     ) -> Result<f64>
     where
-        F: FnMut(ItemProgress) -> bool,
+        F: FnMut(ItemProgress) -> bool + Send,
     {
         let weights = if weights.is_empty() {
-            DEFAULT_WEIGHTS
+            &DEFAULT_WEIGHTS
         } else if weights.len() != 17 {
             return Err(FSRSError::InvalidWeights);
         } else {
@@ -365,35 +368,53 @@ impl<B: Backend> FSRS<B> {
         let mut optimal_retention = 0.85;
         let epsilon = 0.01;
         let mut iter = 0;
+
         let mut progress_info = ItemProgress {
             current: 0,
-            total: 10,
+            total: 100,
         };
+        let inc_progress = Arc::new(Mutex::new(move || {
+            progress_info.current += 1;
+            progress(progress_info)
+        }));
+
         while high - low > epsilon && iter < 10 {
             iter += 1;
-            progress_info.current += 1;
             let mid1 = low + (high - low) / 3.0;
             let mid2 = high - (high - low) / 3.0;
-            let sample_several = |n, mid| {
-                (0..n)
-                    .map(|i| simulate(config, &weights, mid, Some((i + 42).try_into().unwrap())))
-                    .sum::<f64>()
-                    / n as f64
+
+            let sample_several = |n: usize, mid: f64| -> Result<f64, FSRSError> {
+                let out: Vec<f64> = (0..n)
+                    .into_par_iter()
+                    .map(|i| {
+                        let result =
+                            simulate(config, &weights, mid, Some((i + 42).try_into().unwrap()));
+                        if !(inc_progress.lock().unwrap()()) {
+                            return Err(FSRSError::Interrupted);
+                        }
+                        Ok(result)
+                    })
+                    .collect::<Result<Vec<_>, _>>()?;
+                Ok(out.iter().sum::<f64>() / n as f64)
             };
-            let memorization1 = sample_several(5, mid1);
-            let memorization2 = sample_several(5, mid2);
-
-            if memorization1 > memorization2 {
+            let mut memorization1 = None;
+            let mut memorization2 = None;
+            rayon::scope(|s| {
+                s.spawn(|_| {
+                    memorization1 = Some(sample_several(5, mid1));
+                });
+                s.spawn(|_| {
+                    memorization2 = Some(sample_several(5, mid2));
+                });
+            });
+            if memorization1.unwrap()? > memorization2.unwrap()? {
                 high = mid2;
            } else {
                 low = mid1;
             }
 
             optimal_retention = (high + low) / 2.0;
             // dbg!(iter, optimal_retention);
-            if !(progress(progress_info)) {
-                return Err(FSRSError::Interrupted);
-            }
         }
         Ok(optimal_retention)
     }
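
The loop above is a ternary search for the retention that maximizes average memorization, with each probe now averaging five seeded simulations run in parallel. The same shape in a self-contained sketch, with a placeholder objective standing in for simulate:

    use rayon::prelude::*;

    // Placeholder objective: unimodal in r, with a tiny seed-dependent wobble.
    fn memorization(r: f64, seed: u64) -> f64 {
        -(r - 0.83).powi(2) + seed as f64 * 1e-9
    }

    fn main() {
        let (mut low, mut high) = (0.75f64, 0.95f64);
        // Average five seeded runs in parallel, as the commit does.
        let sample = |mid: f64| -> f64 {
            (0..5u64).into_par_iter().map(|i| memorization(mid, i + 42)).sum::<f64>() / 5.0
        };
        while high - low > 0.01 {
            let mid1 = low + (high - low) / 3.0;
            let mid2 = high - (high - low) / 3.0;
            if sample(mid1) > sample(mid2) {
                high = mid2; // the maximum lies to the left of mid2
            } else {
                low = mid1; // the maximum lies to the right of mid1
            }
        }
        println!("optimal retention ~ {}", (low + high) / 2.0);
    }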
