In [1]:
from sympy import symbols, log, diff, exp, Pow

# Define symbols for Loss and Retrievability
y, s, t, p = symbols("y s t p")

# Retrievability (R) and Loss (L)
factor = Pow(0.9, 1 / p) - 1
R = (1 + factor * t / s) ** p
L = -(y * log(R) + (1 - y) * log(1 - R))

# ∂L/∂s (Gradient of Loss w.r.t. Stability)
dL_ds = diff(L, s).simplify()
# ∂L/∂p (Gradient of Loss w.r.t. decay parameter)
dL_dp = diff(L, p).simplify()

print("--- Foundational Gradients ---")
print(f"∂L/∂S = {dL_ds}")
print(f"∂L/∂p = {dL_dp}\n")

--- Foundational Gradients ---
∂L/∂S = p*t*(0.9**(1/p) - 1)*(y*(((s + t*(0.9**(1/p) - 1))/s)**p - 1) - ((s + t*(0.9**(1/p) - 1))/s)**p*(y - 1))/(s*(s + t*(0.9**(1/p) - 1))*(((s + t*(0.9**(1/p) - 1))/s)**p - 1))
∂L/∂p = (0.105360515657826*0.9**(1/p)*t + p*(s + t*(0.9**(1/p) - 1))*log((s + t*(0.9**(1/p) - 1))/s))*(-y*(((s + t*(0.9**(1/p) - 1))/s)**p - 1) + ((s + t*(0.9**(1/p) - 1))/s)**p*(y - 1))/(p*(s + t*(0.9**(1/p) - 1))*(((s + t*(0.9**(1/p) - 1))/s)**p - 1))



In [2]:
from sympy import symbols, exp, log, diff, Pow

# Define symbols
s_last, d_last, t, p, rating = symbols("s_last d_last t p rating")
w8, w9, w10, w15, w16 = symbols("w8 w9 w10 w15 w16")

# Retrievability (r) at the last step
factor = Pow(0.9, 1 / p) - 1
r = (1 + factor * t / s_last) ** p

# S_new for success case
s_new_success = s_last * (
    1 + exp(w8) * (11 - d_last) * s_last ** (-w9) * (exp((1 - r) * w10) - 1) * w15 * w16
)

# Derivatives w.r.t weights
ds_dw8 = diff(s_new_success, w8).simplify()
ds_dw9 = diff(s_new_success, w9).simplify()
ds_dw10 = diff(s_new_success, w10).simplify()
ds_dw15 = diff(s_new_success, w15).simplify()
ds_dw16 = diff(s_new_success, w16).simplify()

# Derivative w.r.t last difficulty (for chain rule)
ds_d_last = diff(s_new_success, d_last).simplify()

print("--- Gradients for stability_after_success ---")
print(f"∂S_new/∂w8 = {ds_dw8}")
print(f"∂S_new/∂w9 = {ds_dw9}")
print(f"∂S_new/∂w10 = {ds_dw10}")
print(f"∂S_new/∂w15 = {ds_dw15}")
print(f"∂S_new/∂w16 = {ds_dw16}")
print(f"∂S_new/∂D_last = {ds_d_last}\n")

--- Gradients for stability_after_success ---
∂S_new/∂w8 = s_last**(1 - w9)*w15*w16*(d_last - 11)*(exp(w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)) - 1)*exp(-w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1) + w8)
∂S_new/∂w9 = s_last**(1 - w9)*w15*w16*(1 - exp(w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)))*(d_last - 11)*exp(-w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1) + w8)*log(s_last)
∂S_new/∂w10 = s_last**(1 - w9)*w15*w16*(d_last - 11)*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)*exp(-w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1) + w8)
∂S_new/∂w15 = s_last**(1 - w9)*w16*(d_last - 11)*(exp(w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)) - 1)*exp(-w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1) + w8)
∂S_new/∂w16 = s_last**(1 - w9)*w15*(d_last - 11)*(exp(w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)) - 1)*exp(-w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1) + w8)
∂S_new/∂D_last = s_last**(1 - w9)*w15*w16*exp(w8) - s_last**(1 - w9)*w15

In [3]:
from sympy import symbols, exp, log, diff, Pow

# Define symbols
s_last, d_last, t, p = symbols("s_last d_last t p")
w11, w12, w13, w14, w17, w18 = symbols("w11 w12 w13 w14 w17 w18")

# Retrievability (r) at the last step
factor = Pow(0.9, 1 / p) - 1
r = (1 + factor * t / s_last) ** p

# Branch 1: The main failure formula
s_main = w11 * d_last ** (-w12) * ((s_last + 1) ** w13 - 1) * exp((1 - r) * w14)
ds_main_dw11 = diff(s_main, w11).simplify()
ds_main_dw12 = diff(s_main, w12).simplify()
ds_main_dw13 = diff(s_main, w13).simplify()
ds_main_dw14 = diff(s_main, w14).simplify()
ds_main_d_last = diff(s_main, d_last).simplify()

# Branch 2: The minimum stability penalty
s_min = s_last / exp(w17 * w18)
ds_min_dw17 = diff(s_min, w17).simplify()
ds_min_dw18 = diff(s_min, w18).simplify()

print("--- Gradients for stability_after_failure ---")
print("--- Branch 1: Main Formula ---")
print(f"∂S_main/∂w11 = {ds_main_dw11}")
print(f"∂S_main/∂w12 = {ds_main_dw12}")
print(f"∂S_main/∂w13 = {ds_main_dw13}")
print(f"∂S_main/∂w14 = {ds_main_dw14}")
print(f"∂S_main/∂D_last = {ds_main_d_last}")
print("\n--- Branch 2: Minimum Penalty ---")
print(f"∂S_min/∂w17 = {ds_min_dw17}")
print(f"∂S_min/∂w18 = {ds_min_dw18}\n")

--- Gradients for stability_after_failure ---
--- Branch 1: Main Formula ---
∂S_main/∂w11 = ((s_last + 1)**w13 - 1)*exp(-w14*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1))/d_last**w12
∂S_main/∂w12 = -w11*((s_last + 1)**w13 - 1)*exp(-w14*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1))*log(d_last)/d_last**w12
∂S_main/∂w13 = w11*(s_last + 1)**w13*exp(-w14*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1))*log(s_last + 1)/d_last**w12
∂S_main/∂w14 = -w11*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)*((s_last + 1)**w13 - 1)*exp(-w14*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1))/d_last**w12
∂S_main/∂D_last = -d_last**(-w12 - 1)*w11*w12*((s_last + 1)**w13 - 1)*exp(-w14*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1))

--- Branch 2: Minimum Penalty ---
∂S_min/∂w17 = -s_last*w18*exp(-w17*w18)
∂S_min/∂w18 = -s_last*w17*exp(-w17*w18)



In [4]:
from sympy import symbols, exp, log, diff

s_last, rating = symbols("s_last rating")
w17, w18, w19 = symbols("w17 w18 w19")

s_new_short = s_last ** (1 - w19) * exp(w17 * (rating - 3 + w18))

ds_short_dw17 = diff(s_new_short, w17).simplify()
ds_short_dw18 = diff(s_new_short, w18).simplify()
ds_short_dw19 = diff(s_new_short, w19).simplify()

print("--- Gradients for stability_short_term ---")
print(f"∂S_short/∂w17 = {ds_short_dw17}")
print(f"∂S_short/∂w18 = {ds_short_dw18}")
print(f"∂S_short/∂w19 = {ds_short_dw19}\n")

--- Gradients for stability_short_term ---
∂S_short/∂w17 = s_last**(1 - w19)*(rating + w18 - 3)*exp(w17*(rating + w18 - 3))
∂S_short/∂w18 = s_last**(1 - w19)*w17*exp(w17*(rating + w18 - 3))
∂S_short/∂w19 = -s_last**(1 - w19)*exp(w17*(rating + w18 - 3))*log(s_last)



In [5]:
from sympy import symbols, exp, diff

d_last, rating = symbols("d_last rating")
w4, w5, w6, w7 = symbols("w4 w5 w6 w7")

init_d_4 = w4 - exp(w5 * (4 - 1)) + 1
delta_d = -w6 * (rating - 3)
d_intermediate = d_last + delta_d * (10 - d_last) / 9
d_new = w7 * init_d_4 + (1 - w7) * d_intermediate

dd_dw4 = diff(d_new, w4).simplify()
dd_dw5 = diff(d_new, w5).simplify()
dd_dw6 = diff(d_new, w6).simplify()
dd_dw7 = diff(d_new, w7).simplify()

print("--- Gradients for next_d ---")
print(f"∂D_new/∂w4 = {dd_dw4}")
print(f"∂D_new/∂w5 = {dd_dw5}")
print(f"∂D_new/∂w6 = {dd_dw6}")
print(f"∂D_new/∂w7 = {dd_dw7}")

--- Gradients for next_d ---
∂D_new/∂w4 = w7
∂D_new/∂w5 = -3*w7*exp(3*w5)
∂D_new/∂w6 = -(d_last - 10)*(rating - 3)*(w7 - 1)/9
∂D_new/∂w7 = -d_last + w4 - w6*(d_last - 10)*(rating - 3)/9 - exp(3*w5) + 1


In [6]:
from typing import List
import math

S_MIN = 0.001


class FSRS_one_step:
    def __init__(self, w: List[float], lr: float = 1e-2):
        self.w = w
        self.lr = lr
        self.s_min = S_MIN

    def forgetting_curve(self, t: float, s: float) -> float:
        decay = -self.w[20]
        if decay == 0 or s == 0:
            return 1.0
        factor = math.pow(0.9, 1 / decay) - 1
        return math.pow(1 + factor * t / s, decay)

    def init_stability(self, rating: int) -> float:
        return max(self.s_min, self.w[rating - 1])

    def init_difficulty(self, rating: int) -> float:
        val = self.w[4] - (math.exp(self.w[5] * (rating - 1)) - 1)
        return max(1, min(10, val))

    def next_difficulty(self, d: float, rating: int) -> float:
        init_d_4 = self.w[4] - (math.exp(self.w[5] * 3) - 1)
        delta_d = -self.w[6] * (rating - 3)
        linear_damping = delta_d * (10 - d) / 9 if d < 10 else 0
        d_intermediate = d + linear_damping
        new_d = self.w[7] * init_d_4 + (1 - self.w[7]) * d_intermediate
        return max(1, min(10, new_d))

    def stability_short_term(self, s: float, rating: int) -> float:
        if s <= 0:
            return self.s_min
        sinc = math.exp(self.w[17] * (rating - 3 + self.w[18])) * math.pow(
            s, -self.w[19]
        )
        new_s = s * (max(1, sinc) if rating >= 3 else sinc)
        return max(self.s_min, new_s)

    def stability_after_success(
        self, s: float, d: float, r: float, rating: int
    ) -> float:
        hard_penalty = self.w[15] if rating == 2 else 1.0
        easy_bonus = self.w[16] if rating == 4 else 1.0
        new_s = s * (
            1
            + math.exp(self.w[8])
            * (11 - d)
            * math.pow(s, -self.w[9])
            * (math.exp((1 - r) * self.w[10]) - 1)
            * hard_penalty
            * easy_bonus
        )
        return max(self.s_min, new_s)

    def stability_after_failure(self, s: float, d: float, r: float) -> float:
        s_main = (
            self.w[11]
            * math.pow(d, -self.w[12])
            * (math.pow(s + 1, self.w[13]) - 1)
            * math.exp((1 - r) * self.w[14])
        )
        s_min_penalty = s / math.exp(self.w[17] * self.w[18])
        return max(self.s_min, min(s_main, s_min_penalty))

    def step(self, delta_t, rating, last_s, last_d):
        if last_s is None:
            return self.init_stability(rating), self.init_difficulty(rating)
        elif delta_t < 1:
            return self.stability_short_term(last_s, rating), self.next_difficulty(
                last_d, rating
            )
        else:
            new_d = self.next_difficulty(last_d, rating)

            r = self.forgetting_curve(delta_t, last_s)
            if rating == 1:
                new_s = self.stability_after_failure(last_s, new_d, r)
            else:
                new_s = self.stability_after_success(last_s, new_d, r, rating)

            return new_s, new_d

    def forward(self, inputs):
        last_s = None
        last_d = None
        outputs = []
        for delta_t, rating in inputs:
            last_s, last_d = self.step(delta_t, rating, last_s, last_d)
            outputs.append((last_s, last_d))

        self.last_s, self.last_d = outputs[-2] if len(outputs) > 1 else (None, None)
        self.new_s, self.new_d = outputs[-1]
        self.last_delta_t = inputs[-1][0]
        self.last_rating = inputs[-1][1]
        return outputs

    def backward(self, delta_t, y):
        """
        Perform a single step of backpropagation.
        :param delta_t: Time elapsed in days.
        :param y: Actual outcome (0 for fail, 1 for success).
        """
        self.grad = [0.0] * len(self.w)
        if self.new_s <= S_MIN:
            return

        r = self.forgetting_curve(delta_t, self.new_s)
        r = min(max(r, 1e-9), 1.0 - 1e-9)
        dL_dr = (r - y) / (r * (1 - r))

        s = self.new_s
        decay = -self.w[20]
        factor = 0.9 ** (1 / decay) - 1
        dr_ds = (
            decay
            * math.pow(1 + factor * delta_t / s, decay - 1)
            * (-factor * delta_t / (s**2))
        )
        C = dL_dr * dr_ds
        rating = self.last_rating

        if self.last_s is None:
            if self.w[rating - 1] > S_MIN:
                self.grad[rating - 1] = C
        else:
            last_r = self.forgetting_curve(self.last_delta_t, self.last_s)
            s = self.last_s
            d = self.new_d
            if rating == 1:
                term1 = math.pow(d, -self.w[12])
                term3 = math.exp((1 - last_r) * self.w[14])
                self.grad[11] = C * (self.new_s / self.w[11])
                self.grad[12] = C * (-self.new_s * math.log(d))
                self.grad[13] = C * (
                    self.w[11]
                    * term1
                    * math.pow(s + 1, self.w[13])
                    * math.log(s + 1)
                    * term3
                )
                self.grad[14] = C * (self.new_s * (1 - last_r))
                ds_new_d_new = self.new_s * (-self.w[12] / d)
            else:
                hard_penalty = self.w[15] if rating == 2 else 1.0
                easy_bonus = self.w[16] if rating == 4 else 1.0
                ds_new_d_new = (
                    s
                    * math.exp(self.w[8])
                    * (-1)
                    * math.pow(s, -self.w[9])
                    * (math.exp((1 - last_r) * self.w[10]) - 1)
                    * hard_penalty
                    * easy_bonus
                )
                term_exp_w10 = math.exp((1 - last_r) * self.w[10])
                term_s_pow_w9 = math.pow(s, -self.w[9])
                common_factor = (
                    s
                    * math.exp(self.w[8])
                    * (11 - d)
                    * term_s_pow_w9
                    * (term_exp_w10 - 1)
                )
                self.grad[8] = C * common_factor * hard_penalty * easy_bonus
                self.grad[9] = (
                    C * common_factor * (-math.log(s)) * hard_penalty * easy_bonus
                )
                self.grad[10] = (
                    C
                    * s
                    * math.exp(self.w[8])
                    * (11 - d)
                    * term_s_pow_w9
                    * (term_exp_w10 * (1 - last_r))
                    * hard_penalty
                    * easy_bonus
                )
                if rating == 2:
                    self.grad[15] = C * (common_factor * easy_bonus)
                if rating == 4:
                    self.grad[16] = C * (common_factor * hard_penalty)

            last_d = self.last_d
            init_d_4 = self.w[4] - math.exp(self.w[5] * (4 - 1)) + 1
            d_intermediate = last_d + (-self.w[6] * (rating - 3) * (10 - last_d) / 9)

            d_newd_dw4 = self.w[7]
            self.grad[4] = C * ds_new_d_new * d_newd_dw4

            d_newd_dw5 = self.w[7] * (-math.exp(self.w[5] * 3) * 3)
            self.grad[5] = C * ds_new_d_new * d_newd_dw5

            d_newd_dw6 = (1 - self.w[7]) * (-(rating - 3) * (10 - last_d) / 9)
            self.grad[6] = C * ds_new_d_new * d_newd_dw6

            d_newd_dw7 = init_d_4 - d_intermediate
            self.grad[7] = C * ds_new_d_new * d_newd_dw7

        t = delta_t
        s = self.new_s
        log_term = math.log(1 + factor * t / s)
        d_factor_d_decay = math.pow(0.9, 1 / decay) * math.log(0.9) * (-1 / decay**2)
        dr_d_decay = r * (
            log_term + decay * (t / s) * d_factor_d_decay / (1 + factor * t / s)
        )
        dr_dw20 = -dr_d_decay
        self.grad[20] = dL_dr * dr_dw20
        for i in range(len(self.w)):
            self.w[i] -= self.lr * self.grad[i]

        self.clamp_weights()

    def clamp_weights(self):
        # Clamping bounds based on provided instructions
        self.w[0] = max(S_MIN, min(self.w[0], 100))
        self.w[1] = max(S_MIN, min(self.w[1], 100))
        self.w[2] = max(S_MIN, min(self.w[2], 100))
        self.w[3] = max(S_MIN, min(self.w[3], 100))
        self.w[4] = max(1, min(self.w[4], 10))
        self.w[5] = max(0.001, min(self.w[5], 4))
        self.w[6] = max(0.001, min(self.w[6], 4))
        self.w[7] = max(0.001, min(self.w[7], 0.75))
        self.w[8] = max(0, min(self.w[8], 4.5))
        self.w[9] = max(0, min(self.w[9], 0.8))
        self.w[10] = max(0.001, min(self.w[10], 3.5))
        self.w[11] = max(0.001, min(self.w[11], 5))
        self.w[12] = max(0.001, min(self.w[12], 0.25))
        self.w[13] = max(0.001, min(self.w[13], 0.9))
        self.w[14] = max(0, min(self.w[14], 4))
        self.w[15] = max(0, min(self.w[15], 1))
        self.w[16] = max(1, min(self.w[16], 6))
        self.w[17] = max(0, min(self.w[17], 2))
        self.w[18] = max(0, min(self.w[18], 2))
        self.w[19] = max(0.01, min(self.w[19], 0.8))
        self.w[20] = max(0.1, min(self.w[20], 0.8))

In [7]:
DEFAULT_PARAMETER = [
    0.212,  # w[0] initial stability for again
    1.2931,  # w[1] initial stability for hard
    2.3065,  # w[2] initial stability for good
    8.2956,  # w[3] initial stability for easy
    6.4133,  # w[4] initial difficulty
    0.8334,  # w[5] initial difficulty rating offset
    3.0194,  # w[6] next difficulty rating offset
    0.001,  # w[7] next difficulty reversion
    1.8722,  # w[8] stability after success
    0.1666,  # w[9] stability after success S decay
    0.796,  # w[10] stability after success R bonus
    1.4835,  # w[11] stability after failure
    0.0614,  # w[12] stability after failure
    0.2629,  # w[13] stability after failure
    1.6483,  # w[14] stability after failure
    0.6014,  # w[15] stability after success
    1.8729,  # w[16] stability after success
    0.5425,  # w[17] short term stability
    0.0912,  # w[18] short term stability
    0.0658,  # w[19] short term stability
    0.1542,  # w[20] forgetting curve decay
]

The learner recalled the card in the 2nd repetition, so the initial stability parameter should increase. The retrievability will decay slowly.

The learner recalled the card before the 90% point, so the decay parameter should increase, too. The retrievability before the 90% point will decay slowly.

In [8]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = []

last_rating = 3
inputs.append((0, last_rating))

outputs = fsrs.forward(inputs)
print(outputs[-1])

delta_t = 1
fsrs.backward(delta_t, 1)

for i, w in enumerate(fsrs.w):
    print(f"w[{i}] = {w} - {DEFAULT_PARAMETER[i]} = {w - DEFAULT_PARAMETER[i]}")

(2.3065, 2.118103970459015)
w[0] = 0.212 - 0.212 = 0.0
w[1] = 1.2931 - 1.2931 = 0.0
w[2] = 2.3066994027269585 - 2.3065 = 0.00019940272695828654
w[3] = 8.2956 - 8.2956 = 0.0
w[4] = 6.4133 - 6.4133 = 0.0
w[5] = 0.8334 - 0.8334 = 0.0
w[6] = 3.0194 - 3.0194 = 0.0
w[7] = 0.001 - 0.001 = 0.0
w[8] = 1.8722 - 1.8722 = 0.0
w[9] = 0.1666 - 0.1666 = 0.0
w[10] = 0.796 - 0.796 = 0.0
w[11] = 1.4835 - 1.4835 = 0.0
w[12] = 0.0614 - 0.0614 = 0.0
w[13] = 0.2629 - 0.2629 = 0.0
w[14] = 1.6483 - 1.6483 = 0.0
w[15] = 0.6014 - 0.6014 = 0.0
w[16] = 1.8729 - 1.8729 = 0.0
w[17] = 0.5425 - 0.5425 = 0.0
w[18] = 0.0912 - 0.0912 = 0.0
w[19] = 0.0658 - 0.0658 = 0.0
w[20] = 0.1547747837234733 - 0.1542 = 0.0005747837234733044


The learner recalled the card in the 2nd repetition, so the initial stability parameter should increase. The retrievability will decay slowly.

The learner recalled the card after the 90% point, so the decay parameter should decrease. The retrievability after the 90% point will decay slowly.

In [9]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = []
last_rating = 3
inputs.append((0, last_rating))

outputs = fsrs.forward(inputs)
print(outputs[-1])

delta_t = 3
fsrs.backward(delta_t, 1)

for i, w in enumerate(fsrs.w):
    print(f"w[{i}] = {w} - {DEFAULT_PARAMETER[i]} = {w - DEFAULT_PARAMETER[i]}")

(2.3065, 2.118103970459015)
w[0] = 0.212 - 0.212 = 0.0
w[1] = 1.2931 - 1.2931 = 0.0
w[2] = 2.3068746934098394 - 2.3065 = 0.00037469340983919963
w[3] = 8.2956 - 8.2956 = 0.0
w[4] = 6.4133 - 6.4133 = 0.0
w[5] = 0.8334 - 0.8334 = 0.0
w[6] = 3.0194 - 3.0194 = 0.0
w[7] = 0.001 - 0.001 = 0.0
w[8] = 1.8722 - 1.8722 = 0.0
w[9] = 0.1666 - 0.1666 = 0.0
w[10] = 0.796 - 0.796 = 0.0
w[11] = 1.4835 - 1.4835 = 0.0
w[12] = 0.0614 - 0.0614 = 0.0
w[13] = 0.2629 - 0.2629 = 0.0
w[14] = 1.6483 - 1.6483 = 0.0
w[15] = 0.6014 - 0.6014 = 0.0
w[16] = 1.8729 - 1.8729 = 0.0
w[17] = 0.5425 - 0.5425 = 0.0
w[18] = 0.0912 - 0.0912 = 0.0
w[19] = 0.0658 - 0.0658 = 0.0
w[20] = 0.1537154302886968 - 0.1542 = -0.00048456971130320103


The learner forgot the card in the 2nd repetition, so the initial stability parameter should decrease. The retrievability will decay quickly.

The learner forget the card before the 90% point, so the decay parameter should decrease, too. The retrievability before the 90% point will decay quickly.

In [10]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = []
last_rating = 3
inputs.append((0, last_rating))

outputs = fsrs.forward(inputs)
print(outputs[-1])

delta_t = 1
fsrs.backward(delta_t, 0)

for i, w in enumerate(fsrs.w):
    print(f"w[{i}] = {w} - {DEFAULT_PARAMETER[i]} = {w - DEFAULT_PARAMETER[i]}")

(2.3065, 2.118103970459015)
w[0] = 0.212 - 0.212 = 0.0
w[1] = 1.2931 - 1.2931 = 0.0
w[2] = 2.302947881638729 - 2.3065 = -0.003552118361271006
w[3] = 8.2956 - 8.2956 = 0.0
w[4] = 6.4133 - 6.4133 = 0.0
w[5] = 0.8334 - 0.8334 = 0.0
w[6] = 3.0194 - 3.0194 = 0.0
w[7] = 0.001 - 0.001 = 0.0
w[8] = 1.8722 - 1.8722 = 0.0
w[9] = 0.1666 - 0.1666 = 0.0
w[10] = 0.796 - 0.796 = 0.0
w[11] = 1.4835 - 1.4835 = 0.0
w[12] = 0.0614 - 0.0614 = 0.0
w[13] = 0.2629 - 0.2629 = 0.0
w[14] = 1.6483 - 1.6483 = 0.0
w[15] = 0.6014 - 0.6014 = 0.0
w[16] = 1.8729 - 1.8729 = 0.0
w[17] = 0.5425 - 0.5425 = 0.0
w[18] = 0.0912 - 0.0912 = 0.0
w[19] = 0.0658 - 0.0658 = 0.0
w[20] = 0.14396092328800497 - 0.1542 = -0.010239076711995032


The learner forgot the card in the 2nd repetition, so the initial stability parameter should decrease. The retrievability will decay quickly.

The learner forget the card after the 90% point, so the decay parameter should increase. The retrievability after the 90% point will decay quickly.

In [11]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = []
last_rating = 3
inputs.append((0, last_rating))

outputs = fsrs.forward(inputs)
print(outputs[-1])

delta_t = 3
fsrs.backward(delta_t, 0)

for i, w in enumerate(fsrs.w):
    print(f"w[{i}] = {w} - {DEFAULT_PARAMETER[i]} = {w - DEFAULT_PARAMETER[i]}")

(2.3065, 2.118103970459015)
w[0] = 0.212 - 0.212 = 0.0
w[1] = 1.2931 - 1.2931 = 0.0
w[2] = 2.303727385757546 - 2.3065 = -0.002772614242454008
w[3] = 8.2956 - 8.2956 = 0.0
w[4] = 6.4133 - 6.4133 = 0.0
w[5] = 0.8334 - 0.8334 = 0.0
w[6] = 3.0194 - 3.0194 = 0.0
w[7] = 0.001 - 0.001 = 0.0
w[8] = 1.8722 - 1.8722 = 0.0
w[9] = 0.1666 - 0.1666 = 0.0
w[10] = 0.796 - 0.796 = 0.0
w[11] = 1.4835 - 1.4835 = 0.0
w[12] = 0.0614 - 0.0614 = 0.0
w[13] = 0.2629 - 0.2629 = 0.0
w[14] = 1.6483 - 1.6483 = 0.0
w[15] = 0.6014 - 0.6014 = 0.0
w[16] = 1.8729 - 1.8729 = 0.0
w[17] = 0.5425 - 0.5425 = 0.0
w[18] = 0.0912 - 0.0912 = 0.0
w[19] = 0.0658 - 0.0658 = 0.0
w[20] = 0.15778566456666127 - 0.1542 = 0.003585664566661262


The learner recalled the card in the 3rd repetition, and the 2nd repetition's rating is easy, so the SInc parameters should increase. The retrievability will decay slowly.

- w[4]: The initial difficulty should decrease.
- w[5]: The initial difficulty rating offset should increase.
- w[6]: The next difficulty rating offset should increase.
- w[7]: The next difficulty reversion should increase.
- w[8]: The stability after success factor should increase.
- w[9]: The stability after success S decay should decrease.
- w[10]: The stability after success R bonus should increase.
- w[16]: The stability after success easy bonus should increase.

In [12]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = [(0, 3), (3, 4)]

outputs = fsrs.forward(inputs)
print(outputs[-1])

delta_t = 100
fsrs.backward(delta_t, 1)

for i, w in enumerate(fsrs.w):
    print(f"w[{i}] = {w} - {DEFAULT_PARAMETER[i]} = {w - DEFAULT_PARAMETER[i]}")

(26.599245611290105, 1)
w[0] = 0.212 - 0.212 = 0.0
w[1] = 1.2931 - 1.2931 = 0.0
w[2] = 2.3065 - 2.3065 = 0.0
w[3] = 8.2956 - 8.2956 = 0.0
w[4] = 6.413299889226701 - 6.4133 = -1.1077329897801746e-07
w[5] = 0.8334040492949065 - 0.8334 = 4.049294906494083e-06
w[6] = 3.019496914502205 - 3.0194 = 9.691450220472575e-05
w[7] = 0.0014702820726741222 - 0.001 = 0.0004702820726741222
w[8] = 1.8733077329873358 - 1.8722 = 0.0011077329873356856
w[9] = 0.16567423295461134 - 0.1666 = -0.0009257670453886591
w[10] = 0.7974586045936835 - 0.796 = 0.0014586045936834102
w[11] = 1.4835 - 1.4835 = 0.0
w[12] = 0.0614 - 0.0614 = 0.0
w[13] = 0.2629 - 0.2629 = 0.0
w[14] = 1.6483 - 1.6483 = 0.0
w[15] = 0.6014 - 0.6014 = 0.0
w[16] = 1.8734914533543359 - 1.8729 = 0.000591453354335858
w[17] = 0.5425 - 0.5425 = 0.0
w[18] = 0.0912 - 0.0912 = 0.0
w[19] = 0.0658 - 0.0658 = 0.0
w[20] = 0.14961173986155998 - 0.1542 = -0.00458826013844002


In [13]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = [(0, 3), (3, 2)]

outputs = fsrs.forward(inputs)
print(outputs[-1])

delta_t = 100
fsrs.backward(delta_t, 1)

for i, w in enumerate(fsrs.w):
    print(f"w[{i}] = {w} - {DEFAULT_PARAMETER[i]} = {w - DEFAULT_PARAMETER[i]}")

(7.179616345177827, 4.752858488532556)
w[0] = 0.212 - 0.212 = 0.0
w[1] = 1.2931 - 1.2931 = 0.0
w[2] = 2.3065 - 2.3065 = 0.0
w[3] = 8.2956 - 8.2956 = 0.0
w[4] = 6.413299843896175 - 6.4133 = -1.5610382497754927e-07
w[5] = 0.8334057063428596 - 0.8334 = 5.706342859568281e-06
w[6] = 3.019263426253108 - 3.0194 = -0.00013657374689213242
w[7] = 0.002488297488680252 - 0.001 = 0.001488297488680252
w[8] = 1.8731752026834179 - 1.8722 = 0.0009752026834177752
w[9] = 0.1657849926677238 - 0.1666 = -0.0008150073322761853
w[10] = 0.797284095653075 - 0.796 = 0.001284095653074968
w[11] = 1.4835 - 1.4835 = 0.0
w[12] = 0.0614 - 0.0614 = 0.0
w[13] = 0.2629 - 0.2629 = 0.0
w[14] = 1.6483 - 1.6483 = 0.0
w[15] = 0.603021554179278 - 0.6014 = 0.0016215541792780064
w[16] = 1.8729 - 1.8729 = 0.0
w[17] = 0.5425 - 0.5425 = 0.0
w[18] = 0.0912 - 0.0912 = 0.0
w[19] = 0.0658 - 0.0658 = 0.0
w[20] = 0.14021303609688318 - 0.1542 = -0.013986963903116822


In [14]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = [(0, 3), (3, 3)]

outputs = fsrs.forward(inputs)
print(outputs[-1])

delta_t = 100
fsrs.backward(delta_t, 0)

for i, w in enumerate(fsrs.w):
    print(f"w[{i}] = {w} - {DEFAULT_PARAMETER[i]} = {w - DEFAULT_PARAMETER[i]}")

(13.835840133660223, 2.1112142357853942)
w[0] = 0.212 - 0.212 = 0.0
w[1] = 1.2931 - 1.2931 = 0.0
w[2] = 2.3065 - 2.3065 = 0.0
w[3] = 8.2956 - 8.2956 = 0.0
w[4] = 6.413300333115854 - 6.4133 = 3.331158540120782e-07
w[5] = 0.833387823019197 - 0.8334 = -1.2176980803024762e-05
w[6] = 3.0194 - 3.0194 = 0.0
w[7] = 0.001 - 0.001 = 0.0
w[8] = 1.869239004536278 - 1.8722 = -0.0029609954637221936
w[9] = 0.16907459636320157 - 0.1666 = 0.0024745963632015755
w[10] = 0.7921011167540938 - 0.796 = -0.0038988832459062595
w[11] = 1.4835 - 1.4835 = 0.0
w[12] = 0.0614 - 0.0614 = 0.0
w[13] = 0.2629 - 0.2629 = 0.0
w[14] = 1.6483 - 1.6483 = 0.0
w[15] = 0.6014 - 0.6014 = 0.0
w[16] = 1.8729 - 1.8729 = 0.0
w[17] = 0.5425 - 0.5425 = 0.0
w[18] = 0.0912 - 0.0912 = 0.0
w[19] = 0.0658 - 0.0658 = 0.0
w[20] = 0.1773547833850453 - 0.1542 = 0.023154783385045286


In [15]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = [(0, 3), (3, 1)]

outputs = fsrs.forward(inputs)
print(outputs[-1])

delta_t = 100
fsrs.backward(delta_t, 1)

for i, w in enumerate(fsrs.w):
    print(f"w[{i}] = {w} - {DEFAULT_PARAMETER[i]} = {w - DEFAULT_PARAMETER[i]}")

(0.589793227905589, 7.394502741279718)
w[0] = 0.212 - 0.212 = 0.0
w[1] = 1.2931 - 1.2931 = 0.0
w[2] = 2.3065 - 2.3065 = 0.0
w[3] = 8.2956 - 8.2956 = 0.0
w[4] = 6.413299987272626 - 6.4133 = -1.272737382151945e-08
w[5] = 0.8334004652465133 - 0.8334 = 4.652465133148098e-07
w[6] = 3.0193777298837023 - 3.0194 = -2.2270116297740117e-05
w[7] = 0.001154997929308168 - 0.001 = 0.000154997929308168
w[8] = 1.8722 - 1.8722 = 0.0
w[9] = 0.1666 - 0.1666 = 0.0
w[10] = 0.796 - 0.796 = 0.0
w[11] = 1.4845332177566457 - 1.4835 = 0.0010332177566456657
w[12] = 0.05833331348610436 - 0.0614 = -0.0030666865138956403
w[13] = 0.26969473607626765 - 0.2629 = 0.006794736076267627
w[14] = 1.6484824804187814 - 1.6483 = 0.0001824804187813278
w[15] = 0.6014 - 0.6014 = 0.0
w[16] = 1.8729 - 1.8729 = 0.0
w[17] = 0.5425 - 0.5425 = 0.0
w[18] = 0.0912 - 0.0912 = 0.0
w[19] = 0.0658 - 0.0658 = 0.0
w[20] = 0.11672685047474332 - 0.1542 = -0.03747314952525668


In [16]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = [(0, 3), (3, 1)]

outputs = fsrs.forward(inputs)
print(outputs[-1])

delta_t = 1
fsrs.backward(delta_t, 0)

for i, w in enumerate(fsrs.w):
    print(f"w[{i}] = {w} - {DEFAULT_PARAMETER[i]} = {w - DEFAULT_PARAMETER[i]}")

(0.589793227905589, 7.394502741279718)
w[0] = 0.212 - 0.212 = 0.0
w[1] = 1.2931 - 1.2931 = 0.0
w[2] = 2.3065 - 2.3065 = 0.0
w[3] = 8.2956 - 8.2956 = 0.0
w[4] = 6.413300049051702 - 6.4133 = 4.9051702255553664e-08
w[5] = 0.8333982069252138 - 0.8334 = -1.79307478620494e-06
w[6] = 3.0194858297329956 - 3.0194 = 8.582973299553487e-05
w[7] = 0.001 - 0.001 = 0.0
w[8] = 1.8722 - 1.8722 = 0.0
w[9] = 0.1666 - 0.1666 = 0.0
w[10] = 0.796 - 0.796 = 0.0
w[11] = 1.4795179460675592 - 1.4835 = -0.003982053932440888
w[12] = 0.07321910687623685 - 0.0614 = 0.01181910687623685
w[13] = 0.23671287212848716 - 0.2629 = -0.026187127871512866
w[14] = 1.647596714691044 - 1.6483 = -0.0007032853089561364
w[15] = 0.6014 - 0.6014 = 0.0
w[16] = 1.8729 - 1.8729 = 0.0
w[17] = 0.5425 - 0.5425 = 0.0
w[18] = 0.0912 - 0.0912 = 0.0
w[19] = 0.0658 - 0.0658 = 0.0
w[20] = 0.1614014236370198 - 0.1542 = 0.007201423637019783


In [17]:
last_rating = 3
delta_t = 1
inputs = [(0, last_rating)]

for y in [0 if i % 10 == 0 else 1 for i in range(100)]:
    outputs = fsrs.forward(inputs)
    fsrs.backward(delta_t, y)
    last_rating = 3 if y == 1 else 1
    inputs.append((delta_t, last_rating))
    last_s = outputs[-1][0]
    delta_t = round(last_s)


print(outputs[-1])
print(fsrs.w)

(379.91706249008826, 7.272065100357654)
[0.212, 1.2931, 2.3029164701964655, 8.2956, 6.413318298717482, 0.8327308232137631, 3.019444953629405, 0.015703305503781552, 1.8658653183045282, 0.22794597436940295, 0.7881691654641193, 1.4867349679708937, 0.049464717497093164, 0.310032751898678, 1.6486600515208725, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.10676680649777816]
