In [1]:
from fsrs_optimizer import (
    next_interval,
    power_forgetting_curve,
    lineToTensor,
    FSRS,
    DEFAULT_PARAMETER,
)
from datetime import datetime
import torch
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt

# Relative review load allowed on each weekday, indexed Monday=0 .. Sunday=6
# (the order of pandas' ``dayofweek``). 1 = normal day, 0.5 = half load,
# 0 = "easy day" that load_balance steers reviews away from.
easy_days_percentages = np.array([1, 1, 0.5, 1, 1, 0, 0])  # Mon Tue Wed Thu Fri Sat Sun

# Calendar date of simulation day 0 (a Monday); day numbers are offsets from it.
initial_date = datetime(2024, 1, 1)

# Number of independent simulations, one per random seed.
sample_size = 1

# Target probability of recall passed to next_interval by the FSRS scheduler.
requestRetention = 0.9

# Anki scheduler settings. In the visible code only leechThreshold and
# leechSuspend are read; the others are kept for reference.
graduatingInterval = 1
easyInterval = 4
easyBonus = 1.3
hardInterval = 1.2
intervalModifier = 1
newInterval = 0
minimumInterval = 1
leechThreshold = 8  # lapse count at which a card is reported as a leech
leechSuspend = False  # when True, leeches are excluded from daily reviews

# Limits shared by all schedulers.
maximumInterval = 36500  # upper bound on scheduled intervals, in days
new_cards_limits = 40  # new cards introduced per day
review_limits = 400  # reviews per day
max_time_limts = 10000  # daily time budget, same unit as review_costs/learn_costs
learn_days = 90  # simulated days
deck_size = 1200  # cards available to learn


# Fuzz bands: an interval gains ``factor`` days of fuzz for every day of it
# that lies between ``start`` and ``end``. Bands accumulate, so a long
# interval collects fuzz from every band it overlaps.
FUZZ_RANGES = [
    {"start": 2.5, "end": 7.0, "factor": 0.15},
    {"start": 7.0, "end": 20.0, "factor": 0.1},
    {"start": 20.0, "end": np.inf, "factor": 0.05},
]


def get_fuzz_range(interval):
    """Return the inclusive ``(min_ivl, max_ivl)`` fuzz window around ``interval``.

    The half-width starts at 1 day and grows by ``factor`` for each day of
    ``interval`` inside every FUZZ_RANGES band. Both bounds are rounded to
    whole days with Python's ``round`` (round-half-to-even). The lower bound
    is clamped to at least 2 and never exceeds the upper bound.
    """
    delta = 1.0
    # Named ``fuzz_range`` rather than ``range`` so the builtin is not shadowed.
    for fuzz_range in FUZZ_RANGES:
        overlap = min(interval, fuzz_range["end"]) - fuzz_range["start"]
        delta += fuzz_range["factor"] * max(overlap, 0.0)
    min_ivl = int(round(interval - delta))
    max_ivl = int(round(interval + delta))
    min_ivl = max(2, min_ivl)
    min_ivl = min(min_ivl, max_ivl)
    return min_ivl, max_ivl


def check_review_distribution(actual_reviews, percentages):
    """Return per-day weights for how far each day falls below its target load.

    The current reviews plus one (the review being placed) are shared out
    across days in proportion to ``percentages``. Each day's weight is its
    target minus its ``actual_reviews``, floored at 0. If every percentage is
    zero, there is no target, so all days get weight 1.
    """
    share_total = np.sum(percentages)
    if share_total == 0:
        return np.ones_like(actual_reviews)
    reviews_to_place = np.sum(actual_reviews) + 1
    target = percentages * (reviews_to_place / share_total)
    shortfall = target - actual_reviews
    return shortfall.clip(min=0)

def load_balance(card, delta_t, today):
    """Pick a fuzzed interval that spreads reviews across upcoming days.

    Intervals below 2.5 days are returned unchanged. Otherwise every
    whole-day interval in ``get_fuzz_range(delta_t)`` is a candidate.
    Each candidate is weighted by:

    * how many cards are already due that day: an empty day gets weight 1,
      a busy day gets ``1 / count**2 / interval``;
    * the weekday shortfall from ``check_review_distribution`` against
      ``easy_days_percentages``.

    One candidate is then drawn with ``random.choices``.

    Args:
        card: DataFrame of cards with a ``due`` column holding day numbers.
        delta_t: Interval proposed by the scheduler, in days.
        today: Current simulation day (offset from ``initial_date``).

    Returns:
        The chosen interval in days.
    """
    if delta_t < 2.5:
        return delta_t

    lo_ivl, hi_ivl = get_fuzz_range(delta_t)
    candidates = np.array(range(lo_ivl, hi_ivl + 1))
    candidate_dates = initial_date + pd.to_timedelta(candidates + today, unit="D")

    due_days = card["due"].to_numpy()
    due_counts = np.array(
        [int((due_days == today + offset).sum()) for offset in candidates]
    )

    # Empty days get full weight; busy days are penalised quadratically by
    # their load and additionally scaled by 1 / interval.
    load_weights = np.array(
        [
            1 if count == 0 else (1 / np.power(count, 2)) * (1 / offset)
            for count, offset in zip(due_counts, candidates)
        ]
    )
    weekday_weights = check_review_distribution(
        due_counts, easy_days_percentages[candidate_dates.dayofweek % 7]
    )
    return random.choices(candidates, load_weights * weekday_weights)[0]


# Rating distributions. first_rating_prob covers ratings 1-4 on a card's
# first review; review_rating_prob covers ratings 2-4 on a successful recall
# (the simulation loop always records a failed recall as rating 1).
first_rating_prob = np.array([0.24, 0.094, 0.495, 0.171])
review_rating_prob = np.array([0.224, 0.631, 0.145])
# Time cost per review, indexed by rating - 1. Presumably seconds, since the
# totals are divided by 60 when reported as minutes.
review_costs = np.array([23.0, 11.68, 7.33, 5.6])
learn_costs = np.array([33.79, 24.3, 13.68, 6.5])


def generate_rating(review_type):
    """Draw a random rating for ``review_type`` using numpy's global RNG.

    Args:
        review_type: ``"new"`` for a card's first review (returns 1-4, drawn
            from ``first_rating_prob``) or ``"recall"`` for a successful
            review (returns 2-4, drawn from ``review_rating_prob``).

    Raises:
        ValueError: if ``review_type`` is neither ``"new"`` nor ``"recall"``.
    """
    if review_type == "new":
        return np.random.choice([1, 2, 3, 4], p=first_rating_prob)
    if review_type == "recall":
        return np.random.choice([2, 3, 4], p=review_rating_prob)
    # Fail loudly instead of returning None, which would otherwise only
    # surface later as a confusing error when used as a cost index.
    raise ValueError(f"unknown review_type: {review_type!r}")


class Collection:
    """Simulated learner backed by an FSRS model built from ``DEFAULT_PARAMETER``."""

    def __init__(self):
        self.model = FSRS(DEFAULT_PARAMETER)
        # Evaluation mode: the model is only used for inference here.
        self.model.eval()

    def states(self, t_history, r_history):
        """Run the model over a stringified (intervals, ratings) history.

        Returns ``output[-1][0]`` of the model's output. Callers read element
        0 as stability and element 1 as difficulty. NOTE(review): presumably
        this is the final state of the single sequence; confirm against
        ``FSRS.forward``.
        """
        with torch.no_grad():
            # lineToTensor receives one (t_history, r_history) string pair;
            # unsqueeze(1) adds the batch dimension.
            line_tensor = lineToTensor((str(t_history), str(r_history))).unsqueeze(1)
            output_t = self.model(line_tensor)
            return output_t[-1][0]

    def next_states(self, states, t, r):
        """Advance ``states`` by one review after ``t`` days with rating ``r``."""
        with torch.no_grad():
            # torch.tensor replaces the legacy torch.FloatTensor constructor.
            step_input = torch.tensor([[t, r]], dtype=torch.float32)
            return self.model.step(step_input, states.unsqueeze(0))[0]

    def init(self):
        """Simulate a card's first review.

        Returns:
            Tuple ``(rating, interval, probability, states)``:

            * ``interval`` is always 0.
            * ``probability`` is the rating's probability in
              ``first_rating_prob``, rounded to 2 decimals.
            * ``states`` is the model state after that first review.
        """
        t = 0
        r = generate_rating("new")
        p = round(first_rating_prob[r - 1], 2)
        new_states = self.states(t, r)
        return r, t, p, new_states


# Column layout of the simulated card table. field_map turns a column name
# into its positional index, which is what DataFrame.iat expects.
feature_list = [
    "difficulty", "stability", "retrievability", "delta_t",
    "reps", "lapses", "last_date", "due",
    "r_history", "t_history", "p_history",
    "states", "time", "factor",
]
field_map = dict(zip(feature_list, range(len(feature_list))))


def fsrs_scheduler(stability):
    """Map an FSRS stability to the next interval in days.

    A positive stability is converted with ``next_interval`` at
    ``requestRetention`` and capped at ``maximumInterval``. Any other value
    (zero, negative, or NaN, since it is not ``> 0``) yields 1 day.
    """
    # Written as "not > 0" so NaN falls into the 1-day branch.
    if not stability > 0:
        return 1
    return min(next_interval(stability, requestRetention), maximumInterval)


def scheduler(scheduler_name, fsrs_inputs, anki_inputs):
    """Return ``(interval, factor)`` for a reviewed card.

    Only the FSRS path is implemented. ``fsrs_inputs`` is the card's
    stability. ``scheduler_name`` and ``anki_inputs`` are accepted but
    ignored, and the returned factor is always 2.5.
    """
    interval = fsrs_scheduler(fsrs_inputs)
    factor = 2.5
    return interval, factor


scheduler_name = "fsrs"

for seed in range(0, sample_size):
    # Per-day metrics for this seed, indexed by simulation day.
    new_card_per_day = np.array([0] * learn_days)
    review_card_per_day = np.array([0.0] * learn_days)
    std_dev_per_day = np.array([0.0] * learn_days)
    time_per_day = np.array([0.0] * learn_days)
    learned_per_day = np.array([0.0] * learn_days)
    retention_per_day = np.array([0.0] * learn_days)
    expected_memorization_per_day = np.array([0.0] * learn_days)

    # One row per card. The index is range(deck_size), so row labels double
    # as the positional indices used with DataFrame.iat below.
    card_df = pd.DataFrame(
        np.zeros((deck_size, len(feature_list))),
        index=range(deck_size),
        columns=feature_list,
    )
    card_df["states"] = card_df["states"].astype(object)

    card_df["r_history"] = card_df["r_history"].astype(str)
    card_df["t_history"] = card_df["t_history"].astype(str)
    card_df["p_history"] = card_df["p_history"].astype(str)
    card_df["reps"] = 0
    card_df["lapses"] = 0
    # Unlearned cards are parked at day learn_days, so they never satisfy
    # ``due <= today`` during the simulation.
    card_df["due"] = learn_days

    student = Collection()
    random.seed(seed)
    np.random.seed(seed)

    for today in range(learn_days):
        reviewed = 0
        learned = 0
        review_time_today = 0
        learn_time_today = 0

        # Elapsed days and current retrievability for every card.
        card_df["delta_t"] = today - card_df["last_date"]
        card_df["retrievability"] = power_forgetting_curve(
            card_df["delta_t"], card_df["stability"]
        )
        due_mask = card_df["due"] <= today
        if leechSuspend:
            # Suspended leeches are left out of the review queue.
            due_mask = due_mask & (card_df["lapses"] < leechThreshold)
        need_review = card_df[due_mask]
        # NaN when no card is due today.
        retention_per_day[today] = need_review["retrievability"].mean()
        last_date_list = []
        for idx in need_review.index:
            # Stop once the daily review count or time budget is used up.
            if (
                reviewed >= review_limits
                or review_time_today + learn_time_today >= max_time_limts
            ):
                break

            reviewed += 1
            last_date = card_df.iat[idx, field_map["last_date"]]
            last_date_list.append(last_date)
            due = card_df.iat[idx, field_map["due"]]
            factor = card_df.iat[idx, field_map["factor"]]
            card_df.iat[idx, field_map["last_date"]] = today
            ivl = card_df.iat[idx, field_map["delta_t"]]
            card_df.iat[idx, field_map["t_history"]] += f",{ivl}"

            stability = card_df.iat[idx, field_map["stability"]]
            retrievability = card_df.iat[idx, field_map["retrievability"]]
            card_df.iat[idx, field_map["p_history"]] += f",{retrievability:.2f}"
            reps = card_df.iat[idx, field_map["reps"]]
            lapses = card_df.iat[idx, field_map["lapses"]]
            states = card_df.iat[idx, field_map["states"]]

            # The card is recalled with probability equal to its retrievability.
            if random.random() < retrievability:
                rating = generate_rating("recall")
                recall_time = review_costs[rating - 1]
                review_time_today += recall_time
                card_df.iat[idx, field_map["r_history"]] += f",{rating}"
                new_states = student.next_states(states, ivl, rating)
                new_stability = float(new_states[0])
                new_difficulty = float(new_states[1])
                card_df.iat[idx, field_map["stability"]] = new_stability
                card_df.iat[idx, field_map["difficulty"]] = new_difficulty
                card_df.iat[idx, field_map["states"]] = new_states
                card_df.iat[idx, field_map["reps"]] = reps + 1
                card_df.iat[idx, field_map["time"]] += recall_time

                delta_t, factor = scheduler(
                    scheduler_name,
                    new_stability,
                    (due - last_date, ivl, factor, rating),
                )
                delta_t = load_balance(card_df, delta_t, today)

                card_df.iat[idx, field_map["factor"]] = factor
                card_df.iat[idx, field_map["due"]] = today + delta_t

            else:
                # Forgotten: rated 1 (Again), reps reset, lapse counted.
                review_time_today += review_costs[0]

                rating = 1
                card_df.iat[idx, field_map["r_history"]] += f",{rating}"

                new_states = student.next_states(states, ivl, 1)
                new_stability = float(new_states[0])
                new_difficulty = float(new_states[1])

                card_df.iat[idx, field_map["stability"]] = new_stability
                card_df.iat[idx, field_map["difficulty"]] = new_difficulty
                card_df.iat[idx, field_map["states"]] = new_states

                reps = 0
                lapses = lapses + 1

                card_df.iat[idx, field_map["reps"]] = reps
                card_df.iat[idx, field_map["lapses"]] = lapses

                delta_t, factor = scheduler(
                    scheduler_name,
                    new_stability,
                    (due - last_date, ivl, factor, rating),
                )
                delta_t = load_balance(card_df, delta_t, today)

                card_df.iat[idx, field_map["due"]] = today + delta_t
                card_df.iat[idx, field_map["factor"]] = factor
                card_df.iat[idx, field_map["time"]] += review_costs[0]

        # Cards never learned still have stability 0.
        need_learn = card_df[card_df["stability"] == 0]

        for idx in need_learn.index:
            if (
                learned >= new_cards_limits
                or review_time_today + learn_time_today >= max_time_limts
            ):
                break
            learned += 1
            r, t, p, new_states = student.init()
            learn_time_today += learn_costs[r - 1]
            card_df.iat[idx, field_map["last_date"]] = today

            card_df.iat[idx, field_map["reps"]] = 1
            card_df.iat[idx, field_map["lapses"]] = 0

            new_stability = float(new_states[0])
            new_difficulty = float(new_states[1])

            card_df.iat[idx, field_map["r_history"]] = str(r)
            card_df.iat[idx, field_map["t_history"]] = str(t)
            card_df.iat[idx, field_map["p_history"]] = str(p)
            card_df.iat[idx, field_map["stability"]] = new_stability
            card_df.iat[idx, field_map["difficulty"]] = new_difficulty
            card_df.iat[idx, field_map["states"]] = new_states

            delta_t, factor = scheduler(
                scheduler_name, new_stability, (None, None, None, r)
            )
            delta_t = load_balance(card_df, delta_t, today)
            card_df.iat[idx, field_map["due"]] = today + delta_t
            card_df.iat[idx, field_map["time"]] = learn_costs[r - 1]
            card_df.iat[idx, field_map["factor"]] = factor

        new_card_per_day[today] = learned
        review_card_per_day[today] = reviewed
        # Running total of learned cards. Day 0 has no previous day, so start
        # from 0 explicitly rather than reading index -1.
        previous_learned = learned_per_day[today - 1] if today > 0 else 0.0
        learned_per_day[today] = previous_learned + learned
        time_per_day[today] = review_time_today + learn_time_today
        expected_memorization_per_day[today] = sum(
            card_df[card_df["retrievability"] > 0]["retrievability"]
        )
        if len(last_date_list) > 0:
            std_dev_per_day[today] = np.std(last_date_list)

    # Mean relative day-to-day change in review count. Only days that had
    # reviews are used, which avoids dividing by zero.
    review_changes = np.abs(np.diff(review_card_per_day))
    later_reviews = review_card_per_day[1:]
    busy_days = later_reviews > 0
    volatility = (
        np.mean(review_changes[busy_days] / later_reviews[busy_days])
        if busy_days.any()
        else 0.0
    )
    # Day 0 is skipped (nothing is due yet). nanmean also ignores later days
    # with no due cards, whose mean retention is NaN.
    average_retention = np.nanmean(retention_per_day[1:])
    # Review-count-weighted mean of the daily std-dev of last review dates.
    total_reviews = review_card_per_day.sum()
    std_dev = (
        np.sum(std_dev_per_day * review_card_per_day) / total_reviews
        if total_reviews > 0
        else 0.0
    )

    total_learned = sum(new_card_per_day)
    total_time = sum(time_per_day)
    total_remembered = int(card_df["retrievability"].sum())
    total_leeches = len(card_df[card_df["lapses"] >= leechThreshold])

    plt.figure(1)
    plt.bar(range(learn_days), review_card_per_day, label=f"{scheduler_name}")
    plt.figure(2)
    plt.bar(range(learn_days), time_per_day / 60, label=f"{scheduler_name}")
    plt.figure(3)
    plt.plot(learned_per_day, label=f"{scheduler_name}")
    plt.figure(4)
    plt.plot(retention_per_day, label=f"{scheduler_name}")
    plt.figure(5)
    plt.plot(expected_memorization_per_day, label=f"{scheduler_name}")

    print("scheduler:", scheduler_name)
    print("learned cards:", total_learned)
    print("time in minutes:", round(total_time / 60, 1))
    print("remembered cards:", total_remembered)
    if total_remembered > 0:
        print("time per remembered card:", round(total_time / 60 / total_remembered, 2))
    else:
        # Avoid ZeroDivisionError when nothing is remembered.
        print("time per remembered card:", "n/a")
    print("leeches:", total_leeches)

    # Snapshot of learned cards with key numeric columns rounded for display.
    save = card_df[card_df["retrievability"] > 0].copy()
    save["stability"] = round(save["stability"], 2)
    save["retrievability"] = round(save["retrievability"], 2)
    save["difficulty"] = round(save["difficulty"], 2)
    save["factor"] = round(save["factor"], 2)
    save["time"] = round(save["time"], 2)

# Titles and axis labels for the five figures drawn by the simulation loop.
# Each entry: (figure number, title, y-axis label, clamp the y-axis at 0).
for fig_num, fig_title, y_label, clamp_at_zero in (
    (1, "Review Count per Day", "Review Count", True),
    (2, "Time Cost in minutes per Day", "Time (minutes)", True),
    (3, "Cumulative Learn Count per Day", "Cumulative Learned Cards", False),
    (4, "Retention per Day", "Retention", False),
    (5, "Memorized Count per Day", "Memorized Cards", False),
):
    plt.figure(fig_num)
    plt.title(fig_title)
    plt.xlabel("Day")
    plt.ylabel(y_label)
    plt.legend()
    plt.grid(True)
    if clamp_at_zero:
        plt.ylim(0, None)
plt.show()
plt.close("all")