diff --git a/fsrs4anki_optimizer.ipynb b/fsrs4anki_optimizer.ipynb index cf05439..07e87cb 100644 --- a/fsrs4anki_optimizer.ipynb +++ b/fsrs4anki_optimizer.ipynb @@ -5,9 +5,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# FSRS4Anki v3.25.2 Optimizer\n", + "# FSRS4Anki v3.25.3 Optimizer\n", "\n", - "[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v3.25.2/fsrs4anki_optimizer.ipynb)\n", + "[![open in colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-spaced-repetition/fsrs4anki/blob/v3.25.3/fsrs4anki_optimizer.ipynb)\n", "\n", "↑ Click the above button to open the optimizer on Google Colab.\n", "\n", @@ -103,7 +103,7 @@ } ], "source": [ - "%pip install -q fsrs4anki_optimizer==3.25.2\n", + "%pip install -q fsrs4anki_optimizer==3.25.3\n", "# for local development\n", "# import os\n", "# import sys\n", @@ -159,7 +159,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "53c7d8cc155c410f9a0d84b387ef7dd9", + "model_id": "24b1f0f5916d479ab8f6069b65ca587e", "version_major": 2, "version_minor": 0 }, @@ -180,7 +180,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "146c56a60e254ef4a8febe1e65e4b4b6", + "model_id": "d663b6d68eb64255aa5488356241cf8d", "version_major": 2, "version_minor": 0 }, @@ -317,7 +317,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9cd1e8985ca2415ea9816c5e21ecc6dc", + "model_id": "a5928e08e0264a2e95a5c67089e6e33d", "version_major": 2, "version_minor": 0 }, @@ -342,7 +342,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1b43e2fbd2ec4f1d8229dd4cd999493f", + "model_id": "7e9229111d4b410ebd5f881b671d12a3", "version_major": 2, "version_minor": 0 }, @@ -363,7 +363,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d98e3ec13028493584328c4b5001ca90", + "model_id": "fa08fa6a568c4ea38016ac0bcd4a75bb", "version_major": 2, "version_minor": 0 }, @@ -415,7 +415,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "505394933bf840eeb349462918db2766", + "model_id": "8bbcffffb49341d8bb2a10476902db54", "version_major": 2, "version_minor": 0 }, @@ -436,7 +436,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "887ef137d1c348ffb749e21e66c2c5c6", + "model_id": "928bfd1eb36d4f669257c1230c54dc3c", "version_major": 2, "version_minor": 0 }, @@ -488,7 +488,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0bf34417570e41079326834d1fcc6031", + "model_id": "4017fb729a78467bad483a5818702c1e", "version_major": 2, "version_minor": 0 }, @@ -509,7 +509,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "8eb43490dac7455d89cb28a917d5f019", + "model_id": "0d5f6fd6f93d49e3a364f5ad26011247", "version_major": 2, "version_minor": 0 }, @@ -827,7 +827,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d0557f3c180a46abb5c05d9b996adffd", + "model_id": "a7de23d2b16d46378d3f5bc8e8389082", "version_major": 2, "version_minor": 0 }, @@ -1013,225 +1013,225 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", @@ -1249,257 +1249,257 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
d_bin1234567891012345678910
s_bin
0.5100002.01%0.5100002.01%
0.7100003.01%0.7100003.01%
1.000000-0.56%1.000000-0.56%
1.4000007.49%-1.94%1.4000007.49%-1.94%
1.9600001.47%-0.94%1.50%-0.92%1.9600001.47%-0.94%1.50%-0.92%
2.7400001.87%-0.06%-2.08%-1.56%0.94%2.7400001.87%-0.06%-2.08%-1.56%0.94%
3.840000-0.13%3.35%-1.06%-0.03%-0.05%2.24%3.840000-0.13%3.35%-1.06%-0.03%-0.05%2.24%
5.3800000.75%-1.12%-1.56%-0.14%-0.93%3.56%5.3800000.75%-1.12%-1.56%-0.14%-0.93%3.56%
7.530000-1.09%-1.31%-0.70%1.54%-0.18%-0.18%1.03%2.32%7.530000-1.09%-1.31%-0.70%1.54%-0.18%-0.18%1.03%2.32%
10.540000-2.02%-3.71%-2.23%-0.93%1.34%1.22%1.80%1.99%3.17%10.540000-2.02%-3.71%-2.23%-0.93%1.34%1.22%1.80%1.99%3.17%
14.760000-4.26%-3.58%-1.22%-0.49%-0.68%3.55%-0.17%1.43%2.58%2.58%14.760000-4.26%-3.58%-1.22%-0.49%-0.68%3.55%-0.17%1.43%2.58%2.58%
20.660000-2.86%-1.22%0.24%2.29%3.85%3.62%0.86%2.78%20.660000-2.86%-1.22%0.24%2.29%3.85%3.62%0.86%2.78%
28.930000-5.01%-2.90%-0.04%1.53%2.70%0.73%3.92%1.28%4.04%28.930000-5.01%-2.90%-0.04%1.53%2.70%0.73%3.92%1.28%4.04%
40.500000-3.42%-2.45%-0.89%2.81%2.57%1.75%4.93%2.36%40.500000-3.42%-2.45%-0.89%2.81%2.57%1.75%4.93%2.36%
56.690000-4.24%-1.55%-0.51%1.41%3.53%1.98%2.17%56.690000-4.24%-1.55%-0.51%1.41%3.53%1.98%2.17%
79.370000-4.92%-1.01%-2.67%5.00%2.39%79.370000-4.92%-1.01%-2.67%5.00%2.39%
111.120000-7.21%-2.62%-2.03%111.120000-7.21%-2.62%-2.03%
155.570000-8.77%-2.07%-1.76%155.570000-8.77%-2.07%-1.76%
217.800000-3.50%217.800000-3.50%
\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 13, diff --git a/package/fsrs4anki_optimizer/fsrs4anki_optimizer.py b/package/fsrs4anki_optimizer/fsrs4anki_optimizer.py index d6905b9..067caf6 100644 --- a/package/fsrs4anki_optimizer/fsrs4anki_optimizer.py +++ b/package/fsrs4anki_optimizer/fsrs4anki_optimizer.py @@ -465,6 +465,11 @@ def train(self, lr: float = 4e-2, n_epoch: int = 3, n_splits: int = 3, batch_siz self.dataset['tensor'] = self.dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]), axis=1) self.dataset['group'] = self.dataset['r_history'] + self.dataset['t_history'] print("Tensorized!") + + n_pre_train_groups = len(self.dataset[self.dataset['i'] == 2]['group'].unique()) + if n_pre_train_groups < n_splits: + print("Not enough groups for pre-training. Splitting into {} folds.".format(n_pre_train_groups)) + n_splits = n_pre_train_groups w = [] plots = [] diff --git a/package/pyproject.toml b/package/pyproject.toml index a069b29..b5ef9b2 100644 --- a/package/pyproject.toml +++ b/package/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "fsrs4anki_optimizer" -version = "3.25.2" +version = "3.25.3" readme = "README.md" dependencies = [ "matplotlib>=3.7.0",