Feat/cross comparison (#78)
L-M-Sherlock committed Feb 1, 2024
1 parent bf30a6e commit 5108d83
Showing 2 changed files with 77 additions and 66 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "4.22.0"
version = "4.23.0"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
141 changes: 76 additions & 65 deletions src/fsrs_optimizer/fsrs_optimizer.py
@@ -1634,76 +1634,16 @@ def compare_with_sm2(self):
            axis=1,
        )
        tqdm.write(f"Loss of SM-2: {self.dataset['log_loss'].mean():.4f}")
-        cross_comparison = self.dataset[["sm2_p", "p", "y"]].copy()
+        dataset = self.dataset[["sm2_p", "p", "y"]].copy()
+        dataset.rename(columns={"sm2_p": "R (SM2)", "p": "R (FSRS)"}, inplace=True)
        fig1 = plt.figure()
        plot_brier(
-            cross_comparison["sm2_p"],
-            cross_comparison["y"],
+            dataset["R (SM2)"],
+            dataset["y"],
            bins=20,
            ax=fig1.add_subplot(111),
        )

-        fig2 = plt.figure(figsize=(6, 6))
-        ax = fig2.gca()
-
-        def get_bin(x, bins=20):
-            return (
-                np.log(
-                    np.minimum(np.floor(np.exp(np.log(bins + 1) * x) - 1), bins - 1) + 1
-                )
-                / np.log(bins)
-            ).round(3)
-
-        cross_comparison["SM2_B-W"] = cross_comparison["sm2_p"] - cross_comparison["y"]
-        cross_comparison["SM2_bin"] = cross_comparison["sm2_p"].map(get_bin)
-        cross_comparison["FSRS_B-W"] = cross_comparison["p"] - cross_comparison["y"]
-        cross_comparison["FSRS_bin"] = cross_comparison["p"].map(get_bin)
-
-        ax.axhline(y=0.0, color="black", linestyle="-")
-
-        cross_comparison_group = cross_comparison.groupby(by="SM2_bin").agg(
-            {"y": ["mean"], "FSRS_B-W": ["mean"], "p": ["mean", "count"]}
-        )
-        tqdm.write(
-            f"Universal Metric of FSRS: {mean_squared_error(cross_comparison_group['y', 'mean'], cross_comparison_group['p', 'mean'], sample_weight=cross_comparison_group['p', 'count'], squared=False):.4f}"
-        )
-        cross_comparison_group["p", "percent"] = (
-            cross_comparison_group["p", "count"]
-            / cross_comparison_group["p", "count"].sum()
-        )
-        ax.scatter(
-            cross_comparison_group.index,
-            cross_comparison_group["FSRS_B-W", "mean"],
-            s=cross_comparison_group["p", "percent"] * 1024,
-            alpha=0.5,
-        )
-        ax.plot(cross_comparison_group["FSRS_B-W", "mean"], label="FSRS by SM2")
-
-        cross_comparison_group = cross_comparison.groupby(by="FSRS_bin").agg(
-            {"y": ["mean"], "SM2_B-W": ["mean"], "sm2_p": ["mean", "count"]}
-        )
-        tqdm.write(
-            f"Universal Metric of SM2: {mean_squared_error(cross_comparison_group['y', 'mean'], cross_comparison_group['sm2_p', 'mean'], sample_weight=cross_comparison_group['sm2_p', 'count'], squared=False):.4f}"
-        )
-        cross_comparison_group["sm2_p", "percent"] = (
-            cross_comparison_group["sm2_p", "count"]
-            / cross_comparison_group["sm2_p", "count"].sum()
-        )
-        ax.scatter(
-            cross_comparison_group.index,
-            cross_comparison_group["SM2_B-W", "mean"],
-            s=cross_comparison_group["sm2_p", "percent"] * 1024,
-            alpha=0.5,
-        )
-        ax.plot(cross_comparison_group["SM2_B-W", "mean"], label="SM2 by FSRS")
-
-        ax.legend(loc="lower center")
-        ax.grid(linestyle="--")
-        ax.set_title("SM2 vs. FSRS")
-        ax.set_xlabel("Predicted R")
-        ax.set_ylabel("B-W Metric")
-        ax.set_xlim(0, 1)
-        ax.set_xticks(np.arange(0, 1.1, 0.1))
+        _, fig2 = cross_comparison(dataset, "SM2", "FSRS")
        return fig1, fig2
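
For reference, the refactored method now just renames the probability columns and delegates to the new cross_comparison helper added below. A minimal sketch of that calling pattern with toy data (the DataFrame values and the import path are illustrative assumptions, not part of this commit):

import pandas as pd

from fsrs_optimizer.fsrs_optimizer import cross_comparison

# Hypothetical per-review data: each algorithm's predicted recall
# probability plus the observed outcome y (1 = recalled, 0 = forgotten).
dataset = pd.DataFrame(
    {
        "R (SM2)": [0.72, 0.85, 0.91, 0.60],
        "R (FSRS)": [0.70, 0.88, 0.93, 0.55],
        "y": [1, 1, 0, 1],
    }
)

# Returns ([UM of FSRS, UM of SM2], matplotlib figure); compare_with_sm2
# keeps only the figure.
universal_metrics, fig2 = cross_comparison(dataset, "SM2", "FSRS")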


@@ -1855,3 +1795,74 @@ def sm2(history):
        ef = max(1.3, ef + (0.1 - (5 - rating) * (0.08 + (5 - rating) * 0.02)))
        ivl = max(1, round(ivl + 0.01))
    return ivl


+def cross_comparison(dataset, algoA, algoB):
+    if algoA != algoB:
+        cross_comparison_record = dataset[[f"R ({algoA})", f"R ({algoB})", "y"]].copy()
+        bin_algo = (
+            algoA,
+            algoB,
+        )
+        pair_algo = [(algoA, algoB), (algoB, algoA)]
+    else:
+        cross_comparison_record = dataset[[f"R ({algoA})", "y"]].copy()
+        bin_algo = (algoA,)
+        pair_algo = [(algoA, algoA)]

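+    # Map a predicted R in [0, 1] to one of `bins` log-spaced bin labels;
+    # the binning is coarse at low R and fine near R = 1, where most
+    # retention predictions cluster.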
+    def get_bin(x, bins=20):
+        return (
+            np.log(np.minimum(np.floor(np.exp(np.log(bins + 1) * x) - 1), bins - 1) + 1)
+            / np.log(bins)
+        ).round(3)

+    for algo in bin_algo:
+        cross_comparison_record[f"{algo}_B-W"] = (
+            cross_comparison_record[f"R ({algo})"] - cross_comparison_record["y"]
+        )
+        cross_comparison_record[f"{algo}_bin"] = cross_comparison_record[
+            f"R ({algo})"
+        ].map(get_bin)
+
+    fig = plt.figure(figsize=(6, 6))
+    ax = fig.gca()
+    ax.axhline(y=0.0, color="black", linestyle="-")
+
+    universal_metric_list = []
+
+    for algoA, algoB in pair_algo:
+        cross_comparison_group = cross_comparison_record.groupby(by=f"{algoA}_bin").agg(
+            {"y": ["mean"], f"{algoB}_B-W": ["mean"], f"R ({algoB})": ["mean", "count"]}
+        )
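+        # Universal Metric: count-weighted RMSE between the mean observed
+        # outcome and algoB's mean prediction inside each of algoA's bins.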
+        universal_metric = mean_squared_error(
+            y_true=cross_comparison_group["y", "mean"],
+            y_pred=cross_comparison_group[f"R ({algoB})", "mean"],
+            sample_weight=cross_comparison_group[f"R ({algoB})", "count"],
+            squared=False,
+        )
+        cross_comparison_group[f"R ({algoB})", "percent"] = (
+            cross_comparison_group[f"R ({algoB})", "count"]
+            / cross_comparison_group[f"R ({algoB})", "count"].sum()
+        )
+        ax.scatter(
+            cross_comparison_group.index,
+            cross_comparison_group[f"{algoB}_B-W", "mean"],
+            s=cross_comparison_group[f"R ({algoB})", "percent"] * 1024,
+            alpha=0.5,
+        )
+        ax.plot(
+            cross_comparison_group[f"{algoB}_B-W", "mean"],
+            label=f"{algoB} by {algoA}, UM={universal_metric:.4f}",
+        )
+        universal_metric_list.append(universal_metric)
+
+        tqdm.write(f"Universal Metric of {algoB}: {universal_metric:.4f}")
+
+    ax.legend(loc="lower center")
+    ax.grid(linestyle="--")
+    ax.set_title(f"{algoA} vs {algoB}")
+    ax.set_xlabel("Predicted R")
+    ax.set_ylabel("B-W Metric")
+    ax.set_xlim(0, 1)
+    ax.set_xticks(np.arange(0, 1.1, 0.1))
+    return universal_metric_list, fig
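
The binning deserves a sanity check: exp(log(bins + 1) * x) - 1 grows from 0 to bins as x goes from 0 to 1, so low predictions share a few coarse bins while high predictions are split finely. A standalone sketch (same formula as above; the printed values follow from it with the default bins=20):

import numpy as np

def get_bin(x, bins=20):
    # Same mapping as in cross_comparison above.
    return (
        np.log(np.minimum(np.floor(np.exp(np.log(bins + 1) * x) - 1), bins - 1) + 1)
        / np.log(bins)
    ).round(3)

for x in [0.1, 0.5, 0.9, 0.99]:
    print(x, get_bin(x))
# 0.1  -> 0.0    (coarse at low predicted R)
# 0.5  -> 0.463
# 0.9  -> 0.904
# 0.99 -> 1.0    (fine-grained near R = 1)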

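The "Universal Metric" reported above is a count-weighted RMSE between each bin's mean observed outcome and the scored algorithm's mean prediction in that bin. A hand-checkable toy computation (the per-bin numbers are made up for illustration; squared=False requests RMSE in the scikit-learn versions contemporary with this commit, before root_mean_squared_error replaced it):

import numpy as np
from sklearn.metrics import mean_squared_error

mean_y = np.array([0.80, 0.90, 0.95])  # mean observed outcome per bin
mean_r = np.array([0.78, 0.93, 0.94])  # mean predicted R per bin
count = np.array([10, 30, 60])         # reviews per bin, used as weights

um = mean_squared_error(mean_y, mean_r, sample_weight=count, squared=False)
print(f"UM = {um:.4f}")  # 0.0192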