diff --git a/pyproject.toml b/pyproject.toml
index 01a29b3..1fa840d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "FSRS-Optimizer"
-version = "4.22.0"
+version = "4.23.0"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py
index 82ca346..ae0751d 100644
--- a/src/fsrs_optimizer/fsrs_optimizer.py
+++ b/src/fsrs_optimizer/fsrs_optimizer.py
@@ -1634,76 +1634,16 @@ def compare_with_sm2(self):
             axis=1,
         )
         tqdm.write(f"Loss of SM-2: {self.dataset['log_loss'].mean():.4f}")
-        cross_comparison = self.dataset[["sm2_p", "p", "y"]].copy()
+        dataset = self.dataset[["sm2_p", "p", "y"]].copy()
+        dataset.rename(columns={"sm2_p": "R (SM2)", "p": "R (FSRS)"}, inplace=True)
         fig1 = plt.figure()
         plot_brier(
-            cross_comparison["sm2_p"],
-            cross_comparison["y"],
+            dataset["R (SM2)"],
+            dataset["y"],
             bins=20,
             ax=fig1.add_subplot(111),
         )
-
-        fig2 = plt.figure(figsize=(6, 6))
-        ax = fig2.gca()
-
-        def get_bin(x, bins=20):
-            return (
-                np.log(
-                    np.minimum(np.floor(np.exp(np.log(bins + 1) * x) - 1), bins - 1) + 1
-                )
-                / np.log(bins)
-            ).round(3)
-
-        cross_comparison["SM2_B-W"] = cross_comparison["sm2_p"] - cross_comparison["y"]
-        cross_comparison["SM2_bin"] = cross_comparison["sm2_p"].map(get_bin)
-        cross_comparison["FSRS_B-W"] = cross_comparison["p"] - cross_comparison["y"]
-        cross_comparison["FSRS_bin"] = cross_comparison["p"].map(get_bin)
-
-        ax.axhline(y=0.0, color="black", linestyle="-")
-
-        cross_comparison_group = cross_comparison.groupby(by="SM2_bin").agg(
-            {"y": ["mean"], "FSRS_B-W": ["mean"], "p": ["mean", "count"]}
-        )
-        tqdm.write(
-            f"Universal Metric of FSRS: {mean_squared_error(cross_comparison_group['y', 'mean'], cross_comparison_group['p', 'mean'], sample_weight=cross_comparison_group['p', 'count'], squared=False):.4f}"
-        )
-        cross_comparison_group["p", "percent"] = (
-            cross_comparison_group["p", "count"]
-            / cross_comparison_group["p", "count"].sum()
-        )
-        ax.scatter(
-            cross_comparison_group.index,
-            cross_comparison_group["FSRS_B-W", "mean"],
-            s=cross_comparison_group["p", "percent"] * 1024,
-            alpha=0.5,
-        )
-        ax.plot(cross_comparison_group["FSRS_B-W", "mean"], label="FSRS by SM2")
-
-        cross_comparison_group = cross_comparison.groupby(by="FSRS_bin").agg(
-            {"y": ["mean"], "SM2_B-W": ["mean"], "sm2_p": ["mean", "count"]}
-        )
-        tqdm.write(
-            f"Universal Metric of SM2: {mean_squared_error(cross_comparison_group['y', 'mean'], cross_comparison_group['sm2_p', 'mean'], sample_weight=cross_comparison_group['sm2_p', 'count'], squared=False):.4f}"
-        )
-        cross_comparison_group["sm2_p", "percent"] = (
-            cross_comparison_group["sm2_p", "count"]
-            / cross_comparison_group["sm2_p", "count"].sum()
-        )
-        ax.scatter(
-            cross_comparison_group.index,
-            cross_comparison_group["SM2_B-W", "mean"],
-            s=cross_comparison_group["sm2_p", "percent"] * 1024,
-            alpha=0.5,
-        )
-        ax.plot(cross_comparison_group["SM2_B-W", "mean"], label="SM2 by FSRS")
-
-        ax.legend(loc="lower center")
-        ax.grid(linestyle="--")
-        ax.set_title("SM2 vs. FSRS")
-        ax.set_xlabel("Predicted R")
-        ax.set_ylabel("B-W Metric")
-        ax.set_xlim(0, 1)
-        ax.set_xticks(np.arange(0, 1.1, 0.1))
+        _, fig2 = cross_comparison(dataset, "SM2", "FSRS")
         return fig1, fig2
@@ -1855,3 +1795,74 @@ def sm2(history):
         ef = max(1.3, ef + (0.1 - (5 - rating) * (0.08 + (5 - rating) * 0.02)))
         ivl = max(1, round(ivl + 0.01))
     return ivl
+
+
+def cross_comparison(dataset, algoA, algoB):
+    if algoA != algoB:
+        cross_comparison_record = dataset[[f"R ({algoA})", f"R ({algoB})", "y"]].copy()
+        bin_algo = (
+            algoA,
+            algoB,
+        )
+        pair_algo = [(algoA, algoB), (algoB, algoA)]
+    else:
+        cross_comparison_record = dataset[[f"R ({algoA})", "y"]].copy()
+        bin_algo = (algoA,)
+        pair_algo = [(algoA, algoA)]
+
+    def get_bin(x, bins=20):
+        return (
+            np.log(np.minimum(np.floor(np.exp(np.log(bins + 1) * x) - 1), bins - 1) + 1)
+            / np.log(bins)
+        ).round(3)
+
+    for algo in bin_algo:
+        cross_comparison_record[f"{algo}_B-W"] = (
+            cross_comparison_record[f"R ({algo})"] - cross_comparison_record["y"]
+        )
+        cross_comparison_record[f"{algo}_bin"] = cross_comparison_record[
+            f"R ({algo})"
+        ].map(get_bin)
+
+    fig = plt.figure(figsize=(6, 6))
+    ax = fig.gca()
+    ax.axhline(y=0.0, color="black", linestyle="-")
+
+    universal_metric_list = []
+
+    for algoA, algoB in pair_algo:
+        cross_comparison_group = cross_comparison_record.groupby(by=f"{algoA}_bin").agg(
+            {"y": ["mean"], f"{algoB}_B-W": ["mean"], f"R ({algoB})": ["mean", "count"]}
+        )
+        universal_metric = mean_squared_error(
+            y_true=cross_comparison_group["y", "mean"],
+            y_pred=cross_comparison_group[f"R ({algoB})", "mean"],
+            sample_weight=cross_comparison_group[f"R ({algoB})", "count"],
+            squared=False,
+        )
+        cross_comparison_group[f"R ({algoB})", "percent"] = (
+            cross_comparison_group[f"R ({algoB})", "count"]
+            / cross_comparison_group[f"R ({algoB})", "count"].sum()
+        )
+        ax.scatter(
+            cross_comparison_group.index,
+            cross_comparison_group[f"{algoB}_B-W", "mean"],
+            s=cross_comparison_group[f"R ({algoB})", "percent"] * 1024,
+            alpha=0.5,
+        )
+        ax.plot(
+            cross_comparison_group[f"{algoB}_B-W", "mean"],
+            label=f"{algoB} by {algoA}, UM={universal_metric:.4f}",
+        )
+        universal_metric_list.append(universal_metric)
+
+        tqdm.write(f"Universal Metric of {algoB}: {universal_metric:.4f}")
+
+    ax.legend(loc="lower center")
+    ax.grid(linestyle="--")
+    ax.set_title(f"{algoA} vs {algoB}")
+    ax.set_xlabel("Predicted R")
+    ax.set_ylabel("B-W Metric")
+    ax.set_xlim(0, 1)
+    ax.set_xticks(np.arange(0, 1.1, 0.1))
+    return universal_metric_list, fig