Skip to content

Commit

Permalink
Feat/additive smoothing (#358)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed Jul 17, 2023
1 parent 487eee1 commit dfd11f2
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 300 deletions.
600 changes: 306 additions & 294 deletions fsrs4anki_optimizer.ipynb

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions package/fsrs4anki_optimizer/fsrs4anki_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,15 +475,20 @@ def define_model(self):
'''

def pretrain(self, verbose=True):
self.dataset = pd.read_csv("./revlog_history.tsv", sep='\t', index_col=None, dtype={'r_history': str ,'t_history': str} )
self.dataset = self.dataset[(self.dataset['i'] > 1) & (self.dataset['delta_t'] > 0) & (self.dataset['t_history'].str.count(',0') == 0)]
if self.dataset.empty:
raise ValueError('Training data is inadequate.')
rating_stability = {}
rating_count = {}
average_recall = self.dataset['y'].mean()

for first_rating in ("1", "2", "3", "4"):
group = self.S0_dataset_group[self.S0_dataset_group['r_history'] == first_rating]
if group.empty:
continue
delta_t = group['delta_t']
recall = group['y']['mean']
recall = (group['y']['mean'] * group['y']['count'] + average_recall * 1) / (group['y']['count'] + 1)
count = group['y']['count']
total_count = sum(count)
if total_count < 100:
Expand Down Expand Up @@ -564,10 +569,6 @@ def S0_rating_curve(rating, a, b, c):

def train(self, lr: float = 4e-2, n_epoch: int = 5, n_splits: int = 5, batch_size: int = 512, verbose: bool = True):
"""Step 4"""
self.dataset = pd.read_csv("./revlog_history.tsv", sep='\t', index_col=None, dtype={'r_history': str ,'t_history': str} )
self.dataset = self.dataset[(self.dataset['i'] > 1) & (self.dataset['delta_t'] > 0) & (self.dataset['t_history'].str.count(',0') == 0)]
if self.dataset.empty:
raise ValueError('Training data is inadequate.')
self.dataset['tensor'] = self.dataset.progress_apply(lambda x: lineToTensor(list(zip([x['t_history']], [x['r_history']]))[0]), axis=1)
self.dataset['group'] = self.dataset['r_history'] + self.dataset['t_history']
tqdm.write("Tensorized!")
Expand Down
2 changes: 1 addition & 1 deletion package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "fsrs4anki_optimizer"
version = "4.0.4"
version = "4.1.0"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down

0 comments on commit dfd11f2

Please sign in to comment.