Skip to content

Commit

Permalink
sqrt(count) as weights for pretrain
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Dec 16, 2023
1 parent c0937b4 commit 4fcfeb5
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/pre_training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,9 @@ fn loss(
let y_pred = power_forgetting_curve(delta_t, init_s0);
let logloss = (-(recall * y_pred.clone().mapv_into(|v| v.ln())
+ (1.0 - recall) * (1.0 - &y_pred).mapv_into(|v| v.ln()))
* count
/ count.sum())
* count.mapv(|v| v.sqrt()))
.sum();
let l1 = (init_s0 - default_s0).abs() / count.sum() / 16.0;
let l1 = (init_s0 - default_s0).abs() / 16.0;
logloss + l1
}

Expand Down Expand Up @@ -324,8 +323,8 @@ mod tests {
let count = Array1::from(vec![100.0, 100.0, 100.0]);
let init_s0 = 1.0;
let actual = loss(&delta_t, &recall, &count, init_s0, init_s0);
assert_eq!(actual, 0.45414436);
assert_eq!(loss(&delta_t, &recall, &count, 2.0, init_s0), 0.48402837);
assert_eq!(actual, 13.6243305);
assert_eq!(loss(&delta_t, &recall, &count, 2.0, init_s0), 14.577101);
}

#[test]
Expand Down Expand Up @@ -356,7 +355,7 @@ mod tests {
],
)]);
let actual = search_parameters(pretrainset, 0.9);
let expected = [(4, 1.2733965)].into_iter().collect();
let expected = [(4, 1.452559)].into_iter().collect();
assert_eq!(actual, expected);
}

Expand All @@ -368,7 +367,7 @@ mod tests {
let pretrainset = split_data(items, 1).0;
assert_eq!(
pretrain(pretrainset, average_recall).unwrap(),
[0.89360625, 1.6562619, 4.1792974, 9.724018],
[0.9773208, 1.7733966, 4.3346996, 14.136393],
)
}

Expand Down

0 comments on commit 4fcfeb5

Please sign in to comment.