In [1]:
import random
from fsrs_optimizer import (
    Collection,
    DEFAULT_WEIGHT,
    next_interval,
    power_forgetting_curve,
)


col = Collection(DEFAULT_WEIGHT)
desired_retention = 0.9
card_id = 0


def get_fuzz_range(interval):
    if interval <= 7:
        factor = 0.15
    elif interval <= 20:
        factor = 0.1
    else:
        factor = 0.05
    min_ivl = max(2, int(round(interval * (1 - factor) - 1)))
    max_ivl = int(round(interval * (1 + factor) + 1))
    min_ivl = min(min_ivl, max_ivl)
    return min_ivl, max_ivl


class Card:
    def __init__(self):
        global card_id
        self.id = card_id
        card_id += 1
        self.stability = 0
        self.difficulty = 0
        self.last_date = 0
        self.interval = 0
        self.due_date = 0
        self.t_history = []
        self.r_history = []

    def retrievability(self, date):
        elapsed_days = date - self.last_date
        return power_forgetting_curve(elapsed_days, self.stability)

    def fuzz_factor(self):
        return random.random()

    def review(self, rating, date):
        self.r_history.append(rating)
        if len(self.t_history) == 0:
            self.t_history.append(0)
        else:
            self.t_history.append(date - self.last_date)
        self.last_date = date
        self.stability, self.difficulty = tuple(
            col.predict(
                ",".join(map(str, self.t_history)), ",".join(map(str, self.r_history))
            )
            .detach()
            .numpy()
        )
        interval = next_interval(self.stability, desired_retention)
        min_ivl, max_ivl = get_fuzz_range(interval)
        self.interval = int(self.fuzz_factor() * (max_ivl - min_ivl + 1) + min_ivl)
        self.due_date = date + interval


cards = [Card() for _ in range(5)]

In [2]:
from copy import deepcopy

copy_cards = deepcopy(cards)
random.seed(42)

for today in range(0, 30):
    learn_cnt = 0
    for card in copy_cards:
        if card.due_date <= today and card.interval > 0:
            retrievability = card.retrievability(today)
            rating = 3 if random.random() < retrievability else 1
            card.review(rating, today)
            print(f"today {today:>5}\tcard {card.id:>5}\tdue date {card.due_date:>5}\trating {rating:>5}\treview")
            learn_cnt += 1
    if learn_cnt == 0:
        for card in copy_cards:
            if card.interval == 0:
                rating = random.randint(1, 4)
                card.review(rating, today)
                print(f"today {today:>5}\tcard {card.id:>5}\tdue date {card.due_date:>5}\trating {rating:>5}\tlearn")
                break

today     0	card     0	due date   1.0	rating     1	learn
today     1	card     0	due date   4.0	rating     3	review
today     2	card     1	due date   3.0	rating     1	learn
today     3	card     1	due date   4.0	rating     1	review
today     4	card     0	due date  12.0	rating     3	review
today     4	card     1	due date   5.0	rating     3	review
today     5	card     1	due date   8.0	rating     3	review
today     6	card     2	due date  14.0	rating     4	learn
today     7	card     3	due date  11.0	rating     3	learn
today     8	card     1	due date  14.0	rating     3	review
today     9	card     4	due date  17.0	rating     4	learn
today    11	card     3	due date  25.0	rating     3	review
today    12	card     0	due date  32.0	rating     3	review
today    14	card     1	due date  27.0	rating     3	review
today    14	card     2	due date  45.0	rating     3	review
today    17	card     4	due date  48.0	rating     3	review
today    25	card     3	due date  28.0	rating     1	review
today    27	card   