In [1]:
import random
from fsrs_optimizer import (
    Collection,
    DEFAULT_WEIGHT,
    next_interval,
    power_forgetting_curve,
)


col = Collection(DEFAULT_WEIGHT)
desired_retention = 0.9
card_id = 0


def get_fuzz_range(interval):
    if interval <= 7:
        factor = 0.15
    elif interval <= 20:
        factor = 0.1
    else:
        factor = 0.05
    min_ivl = max(2, int(round(interval * (1 - factor) - 1)))
    max_ivl = int(round(interval * (1 + factor) + 1))
    min_ivl = min(min_ivl, max_ivl)
    return min_ivl, max_ivl


class Card:
    def __init__(self):
        global card_id
        self.id = card_id
        card_id += 1
        self.stability = 0
        self.difficulty = 0
        self.last_date = 0
        self.interval = 0
        self.due_date = 0
        self.t_history = []
        self.r_history = []
        self.fuzz_range = (0, 0)

    def retrievability(self, date):
        elapsed_days = date - self.last_date
        return power_forgetting_curve(elapsed_days, self.stability)

    def apply_fuzz(self, interval):
        if interval < 2.5:
            self.fuzz_range = (interval, interval)
            return interval
        min_ivl, max_ivl = get_fuzz_range(interval)
        self.fuzz_range = (min_ivl, max_ivl)
        return int(round(random.random() * (max_ivl - min_ivl + 1) + min_ivl))

    def review(self, rating, date):
        self.r_history.append(rating)
        if len(self.t_history) == 0:
            self.t_history.append(0)
        else:
            self.t_history.append(date - self.last_date)
        self.last_date = date
        self.stability, self.difficulty = tuple(
            col.predict(
                ",".join(map(str, self.t_history)), ",".join(map(str, self.r_history))
            )
            .detach()
            .numpy()
        )
        interval = int(next_interval(self.stability, desired_retention))
        self.interval = self.apply_fuzz(interval)
        self.due_date = int(date + interval)


cards = [Card() for _ in range(4)]
print("Total siblings:", len(cards))

Total siblings: 4


In [2]:
from copy import deepcopy

copy_cards = deepcopy(cards)
random.seed(42)

def disperse_siblings(current_card_id, today, siblings, messages):
    for sibling in sorted(siblings, key=lambda x: (x.due_date, -x.interval)):
        if sibling.id != current_card_id and sibling.interval > 0 and sibling.due_date <= today + 1:
            delay = today + 2 - sibling.due_date
            sibling.interval += delay
            sibling.due_date += delay
            messages.append(f"card:{sibling.id:<3}delay {delay} day")

for today in range(0, 60):
    learn_cnt = 0
    messages = []
    for card in copy_cards:
        if card.due_date <= today and card.interval > 0:
            retrievability = card.retrievability(today)
            rating = 3 if random.random() < retrievability else 1
            card.review(rating, today)
            messages.append(
                f"R card:{card.id:<3}R={retrievability * 100:6.2f}% grade:{rating:<3}ivl:{card.interval:>3} ({card.fuzz_range[0]:>3}, {card.fuzz_range[1]:<3})\tnext due:{card.due_date:>3}"
            )
            learn_cnt += 1

            """ #TODO: disperse siblings here
            We need to avoid the situation where siblings are due too close to each other.
            Because it feels like wasting time to review the related cards in a row.
            The challenge is that we couldn't delay the siblings too much, otherwise the
            siblings' retention will be lower than the desired retention too much.
            """
            disperse_siblings(card.id, today, copy_cards, messages)

    if learn_cnt == 0:
        for card in copy_cards:
            if card.interval == 0:
                rating = random.randint(1, 4)
                card.review(rating, today)
                messages.append(
                    f"L card:{card.id:<3}R={100:6.2f}% grade:{rating:<3}ivl:{card.interval:>3} ({card.fuzz_range[0]:>3}, {card.fuzz_range[1]:<3})\tnext due:{card.due_date:>3}"
                )
                disperse_siblings(card.id, today, copy_cards, messages)
                break
    if len(messages) > 0:
        messages.insert(0, f"Day {today}")
        print("\t".join(messages))

Day 0	L card:0  R=100.00% grade:1  ivl:  1 (  1, 1  )	next due:  1
Day 1	R card:0  R= 83.98% grade:3  ivl:  3 (  2, 4  )	next due:  4
Day 2	L card:1  R=100.00% grade:2  ivl:  1 (  1, 1  )	next due:  3
Day 3	R card:1  R= 91.79% grade:3  ivl:  3 (  2, 6  )	next due:  7	card:0  delay 1 day
Day 4	L card:2  R=100.00% grade:1  ivl:  1 (  1, 1  )	next due:  5	card:0  delay 1 day
Day 5	R card:2  R= 83.98% grade:3  ivl:  2 (  2, 4  )	next due:  8	card:0  delay 1 day
Day 6	L card:3  R=100.00% grade:1  ivl:  1 (  1, 1  )	next due:  7	card:0  delay 1 day	card:1  delay 1 day
Day 7	R card:3  R= 83.98% grade:3  ivl:  4 (  2, 4  )	next due: 10	card:0  delay 1 day	card:1  delay 1 day	card:2  delay 1 day
Day 9	R card:0  R= 75.74% grade:3  ivl: 13 ( 12, 16 )	next due: 23	card:1  delay 2 day	card:2  delay 2 day	card:3  delay 1 day
Day 11	R card:1  R= 81.81% grade:3  ivl: 20 ( 16, 22 )	next due: 30	card:2  delay 2 day	card:3  delay 2 day
Day 13	R card:2  R= 75.74% grade:3  ivl: 15 ( 12, 16 )	next due: 27	c