Skip to content

Commit

Permalink
Feat/update default weights & flat power forgetting curve (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Dec 15, 2023
1 parent 2ca347d commit 8ac5d5b
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 70 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "4.19.2"
version = "4.20.0"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
88 changes: 47 additions & 41 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,19 @@
import warnings

try:
from .fsrs_simulator import optimal_retention, simulate
except ImportError:
from fsrs_simulator import optimal_retention, simulate
from .fsrs_simulator import (
optimal_retention,
simulate,
next_interval,
power_forgetting_curve,
)
except:
from fsrs_simulator import (
optimal_retention,
simulate,
next_interval,
power_forgetting_curve,
)

warnings.filterwarnings("ignore", category=UserWarning)

Expand All @@ -34,13 +44,25 @@
Review = 2
Relearning = 3


def power_forgetting_curve(t, s):
return (1 + t / (9 * s)) ** -1


def next_interval(s, r):
return np.maximum(1, np.round(9 * s * (1 / r - 1)))
DEFAULT_WEIGHT = [
0.27,
0.74,
1.3,
5.52,
5.1,
1.02,
0.78,
0.06,
1.57,
0.14,
0.94,
2.16,
0.06,
0.31,
1.34,
0.21,
2.69,
]


class FSRS(nn.Module):
Expand Down Expand Up @@ -629,7 +651,10 @@ def remove_outliers(group: pd.DataFrame) -> pd.DataFrame:
has_been_removed += count
group = group[
group["delta_t"].isin(
grouped_group[grouped_group[("y", "count")] >= count]["delta_t"]
grouped_group[
(grouped_group[("y", "count")] > count)
& (grouped_group[("y", "mean")] < 1)
]["delta_t"]
)
]
return group
Expand Down Expand Up @@ -776,25 +801,7 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame:

def define_model(self):
"""Step 3"""
self.init_w = [
0.4,
0.9,
2.3,
10.9,
4.93,
0.94,
0.86,
0.01,
1.49,
0.14,
0.94,
2.18,
0.05,
0.34,
1.26,
0.29,
2.61,
]
self.init_w = DEFAULT_WEIGHT.copy()
"""
For details about the parameters, please see:
https://github.com/open-spaced-repetition/fsrs4anki/wiki/The-Algorithm
Expand All @@ -821,7 +828,7 @@ def pretrain(self, dataset=None, verbose=True):
rating_count = {}
average_recall = self.dataset["y"].mean()
plots = []
r_s0_default = {"1": 0.4, "2": 0.9, "3": 2.3, "4": 10.9}
r_s0_default = {str(i): DEFAULT_WEIGHT[i - 1] for i in range(1, 5)}

for first_rating in ("1", "2", "3", "4"):
group = self.S0_dataset_group[
Expand All @@ -837,30 +844,29 @@ def pretrain(self, dataset=None, verbose=True):
group["y"]["count"] + 1
)
count = group["y"]["count"]
total_count = sum(count)
weight = np.sqrt(count)

init_s0 = r_s0_default[first_rating]

def loss(stability):
y_pred = power_forgetting_curve(delta_t, stability)
logloss = sum(
-(recall * np.log(y_pred) + (1 - recall) * np.log(1 - y_pred))
* count
/ total_count
* weight
)
l1 = np.abs(stability - init_s0) / total_count / 16
l1 = np.abs(stability - init_s0) / 16
return logloss + l1

res = minimize(
loss,
x0=init_s0,
bounds=((0.1, 365),),
options={"maxiter": int(np.sqrt(total_count))},
bounds=((0.1, 100),),
options={"maxiter": int(sum(weight))},
)
params = res.x
stability = params[0]
rating_stability[int(first_rating)] = stability
rating_count[int(first_rating)] = total_count
rating_count[int(first_rating)] = sum(count)
predict_recall = power_forgetting_curve(delta_t, *params)
rmse = mean_squared_error(
recall, predict_recall, sample_weight=count, squared=False
Expand All @@ -875,15 +881,15 @@ def loss(stability):
power_forgetting_curve(np.linspace(0, 30), *params),
label=f"Weighted fit (RMSE: {rmse:.4f})",
)
count_percent = np.array([x / total_count for x in count])
count_percent = np.array([x / sum(count) for x in count])
ax.scatter(delta_t, recall, s=count_percent * 1000, alpha=0.5)
ax.legend(loc="upper right", fancybox=True, shadow=False)
ax.grid(True)
ax.set_ylim(0, 1)
ax.set_xlabel("Interval")
ax.set_ylabel("Recall")
ax.set_title(
f"Forgetting curve for first rating {first_rating} (n={total_count}, s={stability:.2f})"
f"Forgetting curve for first rating {first_rating} (n={sum(count)}, s={stability:.2f})"
)
plots.append(fig)
tqdm.write(str(rating_stability))
Expand Down Expand Up @@ -983,7 +989,7 @@ def loss(stability):
item[1] for item in sorted(rating_stability.items(), key=lambda x: x[0])
]

self.init_w[0:4] = init_s0
self.init_w[0:4] = list(map(lambda x: max(min(100, x), 0.1), init_s0))

tqdm.write(f"Pretrain finished!")
return plots
Expand Down
65 changes: 37 additions & 28 deletions src/fsrs_optimizer/fsrs_simulator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
import numpy as np
from tqdm import trange, tqdm


DECAY = -0.5
FACTOR = 0.9 ** (1 / DECAY) - 1


def power_forgetting_curve(t, s):
return (1 + FACTOR * t / s) ** DECAY


def next_interval(s, r):
ivl = s / FACTOR * (r ** (1 / DECAY) - 1)
return np.maximum(1, np.round(ivl))


columns = [
"difficulty",
"stability",
Expand Down Expand Up @@ -74,13 +88,10 @@ def stability_after_failure(s, r, d):
card_table[col["delta_t"]][has_learned] = (
today - card_table[col["last_date"]][has_learned]
)
card_table[col["retrievability"]][has_learned] = np.power(
1
+ card_table[col["delta_t"]][has_learned]
/ (9 * card_table[col["stability"]][has_learned]),
-1,
card_table[col["retrievability"]][has_learned] = power_forgetting_curve(
card_table[col["delta_t"]][has_learned],
card_table[col["stability"]][has_learned],
)

card_table[col["cost"]] = 0
need_review = card_table[col["due"]] <= today
card_table[col["rand"]][need_review] = np.random.rand(np.sum(need_review))
Expand Down Expand Up @@ -139,11 +150,9 @@ def stability_after_failure(s, r, d):
)

card_table[col["ivl"]][true_review | true_learn] = np.clip(
np.round(
9
* card_table[col["stability"]][true_review | true_learn]
* (1 / request_retention - 1),
0,
next_interval(
card_table[col["stability"]][true_review | true_learn],
request_retention,
),
1,
max_ivl,
Expand Down Expand Up @@ -418,23 +427,23 @@ def brent(tol=0.01, maxiter=20, **kwargs):
if __name__ == "__main__":
default_params = {
"w": [
0.4,
0.9,
2.3,
10.9,
4.93,
0.94,
0.86,
0.01,
1.49,
0.14,
0.94,
2.18,
0.05,
0.34,
1.26,
0.29,
2.61,
0.5888,
1.4616,
3.8226,
14.1364,
4.9214,
1.0325,
0.8731,
0.0613,
1.57,
0.1395,
0.988,
2.212,
0.0658,
0.3439,
1.3098,
0.2837,
2.7766,
],
"deck_size": 10000,
"learn_span": 365,
Expand Down

0 comments on commit 8ac5d5b

Please sign in to comment.