diff --git a/pyproject.toml b/pyproject.toml index 24285ff..d95de0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "FSRS-Optimizer" -version = "4.13.2" +version = "4.13.3" readme = "README.md" dependencies = [ "matplotlib>=3.7.0", diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 1e475d3..1448f95 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -511,8 +511,11 @@ def define_model(self): https://github.com/open-spaced-repetition/fsrs4anki/wiki/The-Algorithm ''' - def pretrain(self, verbose=True): - self.dataset = pd.read_csv("./revlog_history.tsv", sep='\t', index_col=None, dtype={'r_history': str ,'t_history': str} ) + def pretrain(self, dataset=None, verbose=True): + if dataset is None: + self.dataset = pd.read_csv("./revlog_history.tsv", sep='\t', index_col=None, dtype={'r_history': str ,'t_history': str} ) + else: + self.dataset = dataset self.dataset = self.dataset[(self.dataset['i'] > 1) & (self.dataset['delta_t'] > 0) & (self.dataset['t_history'].str.count(',0') == 0)] if self.dataset.empty: raise ValueError('Training data is inadequate.')