Feat/save metrics (#12)
L-M-Sherlock committed Aug 1, 2023
1 parent 118c2f5, commit 5843a20
Showing 4 changed files with 28 additions and 7 deletions.
.gitignore (1 addition, 0 deletions)

@@ -147,4 +147,5 @@ meta
 *.csv
 *.tsv
 *.png
+*.json
 .fsrs_optimizer
pyproject.toml (2 additions, 2 deletions)

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "FSRS-Optimizer"
-version = "4.5.5"
+version = "4.6.0"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
@@ -16,7 +16,7 @@ dependencies = [
     "tqdm>=4.64.1",
     "statsmodels>=0.13.5",
 ]
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 
 [project.urls]
 Homepage = "https://github.com/open-spaced-repetition/fsrs-optimizer"
src/fsrs_optimizer/__main__.py (17 additions, 3 deletions)

@@ -118,11 +118,25 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
     loss_before, loss_after = optimizer.evaluate()
     print(f"Loss before training: {loss_before:.4f}")
     print(f"Loss after training: {loss_after:.4f}")
+    metrics, figures = optimizer.calibration_graph()
+    metrics['Log loss'] = loss_after
     if save_graphs:
-        for i, f in enumerate(optimizer.calibration_graph()):
+        for i, f in enumerate(figures):
             f.savefig(f"calibration_{i}.png")
-        for i, f in enumerate(optimizer.compare_with_sm2()):
+    figures = optimizer.compare_with_sm2()
+    if save_graphs:
+        for i, f in enumerate(figures):
             f.savefig(f"compare_with_sm2_{i}.png")
+
+    evaluation = {
+        "filename": filename,
+        "size": optimizer.dataset.shape[0],
+        "parameters": optimizer.w,
+        "metrics": metrics
+    }
+
+    with open("evaluation.json", "w+") as f:
+        json.dump(evaluation, f)
 
 if __name__ == "__main__":
 
@@ -141,7 +155,7 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
     curdir = os.getcwd()
     for filename in args.filenames:
         if os.path.isdir(filename):
-            files = [f for f in os.listdir(filename) if f.lower().endswith('.apkg')]
+            files = [f for f in os.listdir(filename) if f.lower().endswith('.apkg') or f.lower().endswith('.colpkg')]
             files = [os.path.join(filename, f) for f in files]
             for file_path in files:
                 try:
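
For context: the block added above writes the run's evaluation results to evaluation.json in the working directory. A minimal sketch of reading that file back afterwards (this is not part of the commit; it assumes an evaluation.json produced by the code above and that json is available, as it already is in __main__.py):

import json

# Load the metrics that __main__.py now saves after optimization.
with open("evaluation.json") as f:
    evaluation = json.load(f)

print(evaluation["filename"])             # the input file that was optimized
print(evaluation["size"])                 # optimizer.dataset.shape[0], i.e. the number of rows used
print(evaluation["parameters"])           # the trained FSRS weights (optimizer.w)
print(evaluation["metrics"]["Log loss"])  # loss after training
print(evaluation["metrics"]["RMSE"])      # calibration metric returned by plot_brier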
src/fsrs_optimizer/fsrs_optimizer.py (8 additions, 2 deletions)

@@ -837,7 +837,7 @@ def evaluate(self):
 
    def calibration_graph(self):
        fig1 = plt.figure()
-        plot_brier(self.dataset['p'], self.dataset['y'], bins=40, ax=fig1.add_subplot(111))
+        metrics = plot_brier(self.dataset['p'], self.dataset['y'], bins=40, ax=fig1.add_subplot(111))
        fig2 = plt.figure(figsize=(16, 12))
        for last_rating in ("1","2","3","4"):
            calibration_data = self.dataset[self.dataset['r_history'].str.endswith(last_rating)]
@@ -910,7 +910,7 @@ def to_percent(temp, position):
        ax2.yaxis.set_major_formatter(ticker.FuncFormatter(to_percent))
        ax2.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
 
-        return fig1, fig2, fig3, fig4
+        return metrics, (fig1, fig2, fig3, fig4)
 
    def bw_matrix(self):
        B_W_Metric_raw = self.dataset[['difficulty', 'stability', 'p', 'y']].copy()
@@ -1025,6 +1025,12 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
    ax2.legend(loc='lower center')
    if title:
        ax.set_title(title)
+    metrics = {
+        "R-squared": r2,
+        "RMSE": rmse,
+        "MAE": mae
+    }
+    return metrics
 
 def sm2(history):
    ivl = 0
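
The return-type change in calibration_graph is worth calling out for downstream callers: it now returns a (metrics, figures) pair instead of a bare tuple of figures, so code that previously iterated over the result directly has to unpack it first. A minimal usage sketch under the new convention (the optimizer variable is illustrative):

# calibration_graph() now returns (metrics_dict, (fig1, fig2, fig3, fig4)).
metrics, figures = optimizer.calibration_graph()

# plot_brier() fills the dict with three calibration statistics.
print(metrics["R-squared"], metrics["RMSE"], metrics["MAE"])

# Figures are saved the same way as before, just from the unpacked tuple.
for i, fig in enumerate(figures):
    fig.savefig(f"calibration_{i}.png")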
