From a7504d088333f71345dd32b73ce39f44dd4b91cc Mon Sep 17 00:00:00 2001 From: Jarrett Ye Date: Sun, 7 Apr 2024 12:18:46 +0800 Subject: [PATCH] Fix/fit power forgetting curve in analysis (#102) --- pyproject.toml | 2 +- src/fsrs_optimizer/fsrs_optimizer.py | 22 ++++++++++------------ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8503653..e09b485 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "4.27.3" +version = "4.27.4" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index ff89cc5..6201ec9 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -18,7 +18,7 @@ from torch.nn.utils.rnn import pad_sequence from sklearn.model_selection import TimeSeriesSplit from sklearn.metrics import root_mean_squared_error, mean_absolute_error, r2_score -from scipy.optimize import minimize +from scipy.optimize import minimize, curve_fit from itertools import accumulate from tqdm.auto import tqdm import warnings @@ -709,15 +709,15 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame: return pd.DataFrame() group["group_cnt"] = group_cnt if group["i"].values[0] > 1: - r_ivl_cnt = sum( - group["delta_t"] - * group["retention"].map(np.log) - * pow(group["total_cnt"], 2) + group["stability"] = round( + curve_fit( + power_forgetting_curve, + group["delta_t"], + group["retention"], + sigma=1 / group["total_cnt"], + )[0][0], + 1, ) - ivl_ivl_cnt = sum( - group["delta_t"].map(lambda x: x**2) * pow(group["total_cnt"], 2) - ) - group["stability"] = round(np.log(0.9) / (r_ivl_cnt / ivl_ivl_cnt), 1) else: group["stability"] = 0.0 group["avg_retention"] = round( @@ -1011,9 +1011,7 @@ def train( w.append(trainer.train(verbose=verbose)) self.w = w[-1] self.evaluate() - metrics, figures = self.calibration_graph( - self.dataset.iloc[test_index] - ) + metrics, figures = self.calibration_graph(self.dataset.iloc[test_index]) for j, f in enumerate(figures): f.savefig(f"graph_{j}_test_{i}.png") plt.close(f)