In [1]:
from sympy import symbols, log, diff, exp, Pow

# Define symbols for Loss and Retrievability
y, s, t, p = symbols("y s t p")

# Retrievability (R) and Loss (L)
factor = Pow(0.9, 1 / p) - 1
R = (1 + factor * t / s) ** p
L = -(y * log(R) + (1 - y) * log(1 - R))

# ∂L/∂s (Gradient of Loss w.r.t. Stability)
dL_ds = diff(L, s).simplify()
# ∂L/∂p (Gradient of Loss w.r.t. decay parameter)
dL_dp = diff(L, p).simplify()

print("--- Foundational Gradients ---")
print(f"∂L/∂S = {dL_ds}")
print(f"∂L/∂p = {dL_dp}\n")

--- Foundational Gradients ---
∂L/∂S = p*t*(0.9**(1/p) - 1)*(y*(((s + t*(0.9**(1/p) - 1))/s)**p - 1) - ((s + t*(0.9**(1/p) - 1))/s)**p*(y - 1))/(s*(s + t*(0.9**(1/p) - 1))*(((s + t*(0.9**(1/p) - 1))/s)**p - 1))
∂L/∂p = (0.105360515657826*0.9**(1/p)*t + p*(s + t*(0.9**(1/p) - 1))*log((s + t*(0.9**(1/p) - 1))/s))*(-y*(((s + t*(0.9**(1/p) - 1))/s)**p - 1) + ((s + t*(0.9**(1/p) - 1))/s)**p*(y - 1))/(p*(s + t*(0.9**(1/p) - 1))*(((s + t*(0.9**(1/p) - 1))/s)**p - 1))



In [2]:
from sympy import symbols, exp, log, diff, Pow

# Define symbols
s_last, d_last, t, p, rating = symbols("s_last d_last t p rating")
w8, w9, w10, w15, w16 = symbols("w8 w9 w10 w15 w16")

# Retrievability (r) at the last step
factor = Pow(0.9, 1 / p) - 1
r = (1 + factor * t / s_last) ** p

# S_new for success case
s_new_success = s_last * (
    1 + exp(w8) * (11 - d_last) * s_last ** (-w9) * (exp((1 - r) * w10) - 1) * w15 * w16
)

# Derivatives w.r.t weights
ds_dw8 = diff(s_new_success, w8).simplify()
ds_dw9 = diff(s_new_success, w9).simplify()
ds_dw10 = diff(s_new_success, w10).simplify()
ds_dw15 = diff(s_new_success, w15).simplify()
ds_dw16 = diff(s_new_success, w16).simplify()

# Derivative w.r.t last difficulty (for chain rule)
ds_d_last = diff(s_new_success, d_last).simplify()

print("--- Gradients for stability_after_success ---")
print(f"∂S_new/∂w8 = {ds_dw8}")
print(f"∂S_new/∂w9 = {ds_dw9}")
print(f"∂S_new/∂w10 = {ds_dw10}")
print(f"∂S_new/∂w15 = {ds_dw15}")
print(f"∂S_new/∂w16 = {ds_dw16}")
print(f"∂S_new/∂D_last = {ds_d_last}\n")

--- Gradients for stability_after_success ---
∂S_new/∂w8 = s_last**(1 - w9)*w15*w16*(d_last - 11)*(exp(w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)) - 1)*exp(-w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1) + w8)
∂S_new/∂w9 = s_last**(1 - w9)*w15*w16*(1 - exp(w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)))*(d_last - 11)*exp(-w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1) + w8)*log(s_last)
∂S_new/∂w10 = s_last**(1 - w9)*w15*w16*(d_last - 11)*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)*exp(-w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1) + w8)
∂S_new/∂w15 = s_last**(1 - w9)*w16*(d_last - 11)*(exp(w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)) - 1)*exp(-w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1) + w8)
∂S_new/∂w16 = s_last**(1 - w9)*w15*(d_last - 11)*(exp(w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)) - 1)*exp(-w10*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1) + w8)
∂S_new/∂D_last = s_last**(1 - w9)*w15*w16*exp(w8) - s_last**(1 - w9)*w15

In [3]:
from sympy import symbols, exp, log, diff, Pow

# Define symbols
s_last, d_last, t, p = symbols("s_last d_last t p")
w11, w12, w13, w14, w17, w18 = symbols("w11 w12 w13 w14 w17 w18")

# Retrievability (r) at the last step
factor = Pow(0.9, 1 / p) - 1
r = (1 + factor * t / s_last) ** p

# Branch 1: The main failure formula
s_main = w11 * d_last ** (-w12) * ((s_last + 1) ** w13 - 1) * exp((1 - r) * w14)
ds_main_dw11 = diff(s_main, w11).simplify()
ds_main_dw12 = diff(s_main, w12).simplify()
ds_main_dw13 = diff(s_main, w13).simplify()
ds_main_dw14 = diff(s_main, w14).simplify()
ds_main_d_last = diff(s_main, d_last).simplify()

# Branch 2: The minimum stability penalty
s_min = s_last / exp(w17 * w18)
ds_min_dw17 = diff(s_min, w17).simplify()
ds_min_dw18 = diff(s_min, w18).simplify()

print("--- Gradients for stability_after_failure ---")
print("--- Branch 1: Main Formula ---")
print(f"∂S_main/∂w11 = {ds_main_dw11}")
print(f"∂S_main/∂w12 = {ds_main_dw12}")
print(f"∂S_main/∂w13 = {ds_main_dw13}")
print(f"∂S_main/∂w14 = {ds_main_dw14}")
print(f"∂S_main/∂D_last = {ds_main_d_last}")
print("\n--- Branch 2: Minimum Penalty ---")
print(f"∂S_min/∂w17 = {ds_min_dw17}")
print(f"∂S_min/∂w18 = {ds_min_dw18}\n")

--- Gradients for stability_after_failure ---
--- Branch 1: Main Formula ---
∂S_main/∂w11 = ((s_last + 1)**w13 - 1)*exp(-w14*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1))/d_last**w12
∂S_main/∂w12 = -w11*((s_last + 1)**w13 - 1)*exp(-w14*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1))*log(d_last)/d_last**w12
∂S_main/∂w13 = w11*(s_last + 1)**w13*exp(-w14*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1))*log(s_last + 1)/d_last**w12
∂S_main/∂w14 = -w11*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1)*((s_last + 1)**w13 - 1)*exp(-w14*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1))/d_last**w12
∂S_main/∂D_last = -d_last**(-w12 - 1)*w11*w12*((s_last + 1)**w13 - 1)*exp(-w14*(((s_last + t*(0.9**(1/p) - 1))/s_last)**p - 1))

--- Branch 2: Minimum Penalty ---
∂S_min/∂w17 = -s_last*w18*exp(-w17*w18)
∂S_min/∂w18 = -s_last*w17*exp(-w17*w18)



In [4]:
from sympy import symbols, exp, log, diff

s_last, rating = symbols("s_last rating")
w17, w18, w19 = symbols("w17 w18 w19")

s_new_short = s_last ** (1 - w19) * exp(w17 * (rating - 3 + w18))

ds_short_dw17 = diff(s_new_short, w17).simplify()
ds_short_dw18 = diff(s_new_short, w18).simplify()
ds_short_dw19 = diff(s_new_short, w19).simplify()

print("--- Gradients for stability_short_term ---")
print(f"∂S_short/∂w17 = {ds_short_dw17}")
print(f"∂S_short/∂w18 = {ds_short_dw18}")
print(f"∂S_short/∂w19 = {ds_short_dw19}\n")

--- Gradients for stability_short_term ---
∂S_short/∂w17 = s_last**(1 - w19)*(rating + w18 - 3)*exp(w17*(rating + w18 - 3))
∂S_short/∂w18 = s_last**(1 - w19)*w17*exp(w17*(rating + w18 - 3))
∂S_short/∂w19 = -s_last**(1 - w19)*exp(w17*(rating + w18 - 3))*log(s_last)



In [5]:
from sympy import symbols, exp, diff

d_last, rating = symbols("d_last rating")
w4, w5, w6, w7 = symbols("w4 w5 w6 w7")

init_d_4 = w4 - exp(w5 * (4 - 1)) + 1
delta_d = -w6 * (rating - 3)
d_intermediate = d_last + delta_d * (10 - d_last) / 9
d_new = w7 * init_d_4 + (1 - w7) * d_intermediate

dd_dw4 = diff(d_new, w4).simplify()
dd_dw5 = diff(d_new, w5).simplify()
dd_dw6 = diff(d_new, w6).simplify()
dd_dw7 = diff(d_new, w7).simplify()

print("--- Gradients for next_d ---")
print(f"∂D_new/∂w4 = {dd_dw4}")
print(f"∂D_new/∂w5 = {dd_dw5}")
print(f"∂D_new/∂w6 = {dd_dw6}")
print(f"∂D_new/∂w7 = {dd_dw7}")

--- Gradients for next_d ---
∂D_new/∂w4 = w7
∂D_new/∂w5 = -3*w7*exp(3*w5)
∂D_new/∂w6 = -(d_last - 10)*(rating - 3)*(w7 - 1)/9
∂D_new/∂w7 = -d_last + w4 - w6*(d_last - 10)*(rating - 3)/9 - exp(3*w5) + 1


In [6]:
from typing import List
import math

S_MIN = 0.001


class FSRS_one_step:
    def __init__(self, w: List[float], lr: float = 1e-2):
        self.w = w
        self.lr = lr
        self.s_min = S_MIN

    def forgetting_curve(self, t: float, s: float) -> float:
        decay = -self.w[20]
        if decay == 0 or s == 0:
            return 1.0
        factor = math.pow(0.9, 1 / decay) - 1
        return math.pow(1 + factor * t / s, decay)

    def init_stability(self, rating: int) -> float:
        return max(self.s_min, self.w[rating - 1])

    def init_difficulty(self, rating: int) -> float:
        val = self.w[4] - (math.exp(self.w[5] * (rating - 1)) - 1)
        return max(1, min(10, val))

    def next_difficulty(self, d: float, rating: int) -> float:
        init_d_4 = self.w[4] - (math.exp(self.w[5] * 3) - 1)
        delta_d = -self.w[6] * (rating - 3)
        linear_damping = delta_d * (10 - d) / 9 if d < 10 else 0
        d_intermediate = d + linear_damping
        new_d = self.w[7] * init_d_4 + (1 - self.w[7]) * d_intermediate
        return max(1, min(10, new_d))

    def stability_short_term(self, s: float, rating: int) -> float:
        if s <= 0:
            return self.s_min
        sinc = math.exp(self.w[17] * (rating - 3 + self.w[18])) * math.pow(
            s, -self.w[19]
        )
        new_s = s * (max(1, sinc) if rating >= 3 else sinc)
        return max(self.s_min, new_s)

    def stability_after_success(
        self, s: float, d: float, r: float, rating: int
    ) -> float:
        hard_penalty = self.w[15] if rating == 2 else 1.0
        easy_bonus = self.w[16] if rating == 4 else 1.0
        new_s = s * (
            1
            + math.exp(self.w[8])
            * (11 - d)
            * math.pow(s, -self.w[9])
            * (math.exp((1 - r) * self.w[10]) - 1)
            * hard_penalty
            * easy_bonus
        )
        return max(self.s_min, new_s)

    def stability_after_failure(self, s: float, d: float, r: float) -> float:
        s_main = (
            self.w[11]
            * math.pow(d, -self.w[12])
            * (math.pow(s + 1, self.w[13]) - 1)
            * math.exp((1 - r) * self.w[14])
        )
        s_min_penalty = s / math.exp(self.w[17] * self.w[18])
        return max(self.s_min, min(s_main, s_min_penalty))

    def step(self, delta_t, rating, last_s, last_d):
        if last_s is None:
            return self.init_stability(rating), self.init_difficulty(rating)
        elif delta_t == 0:
            return self.stability_short_term(last_s, rating), self.next_difficulty(
                last_d, rating
            )
        else:
            r = self.forgetting_curve(delta_t, last_s)
            if rating == 1:
                return self.stability_after_failure(
                    last_s, last_d, r
                ), self.next_difficulty(last_d, rating)
            else:
                return self.stability_after_success(
                    last_s, last_d, r, rating
                ), self.next_difficulty(last_d, rating)

    def forward(self, inputs):
        last_s = None
        last_d = None
        outputs = []
        for delta_t, rating in inputs:
            last_s, last_d = self.step(delta_t, rating, last_s, last_d)
            outputs.append((last_s, last_d))

        self.last_s, self.last_d = outputs[-2] if len(outputs) > 1 else (None, None)
        self.new_s, self.new_d = outputs[-1]
        self.last_delta_t = inputs[-1][0]
        self.last_rating = inputs[-1][1]
        return outputs

    def backward(self, delta_t, y):
        """
        Perform a single step of backpropagation.
        :param delta_t: Time elapsed in days.
        :param y: Actual outcome (0 for fail, 1 for success).
        """
        p = -self.w[20]
        factor = math.pow(0.9, 1 / p) - 1
        new_r = self.forgetting_curve(delta_t, self.new_s)
        dL_dR = (new_r - y) / (new_r * (new_r - 1))

        # --- START: ADDED GRADIENT CALCULATION FOR w[20] ---
        # This gradient depends on the state BEFORE the review (last_s, delta_t)
        if new_r > 1e-6 and new_r < 1.0 - 1e-6:
            # Using a more stable, manually derived formula for dL/dp
            log_r = math.log(new_r)
            log_0_9 = math.log(0.9)

            # d(logR)/dp
            d_log_r_dp = log_r / p - (delta_t * (factor + 1) * log_0_9) / (
                p * (self.new_s + delta_t * factor)
            )

            # dL/dp = (dL/dR) * (dR/dp) = (R-y)/(R(1-R)) * (R * d(logR)/dp)
            grad_p = dL_dR * new_r * d_log_r_dp

            self.w[20] -= self.lr * grad_p
        # --- END: ADDED GRADIENT CALCULATION FOR w[20] ---

        dR_ds = (
            new_r
            * p
            * factor
            * delta_t
            / (self.new_s * (factor * delta_t + self.new_s))
        )
        grad_s = dL_dR * dR_ds
        if self.last_s is None:
            if new_r > 1e-6 and new_r < 1.0 - 1e-6:
                self.w[self.last_rating - 1] -= self.lr * grad_s
        else:
            if p == 0:
                return

            # --- Update weights based on the path taken ---
            if delta_t < 1:  # Short-term path
                if self.last_s > 0:
                    g17 = self.last_s * (self.last_rating - 3 + self.w[18])
                    g18 = self.last_s * self.w[17]
                    g19 = -self.last_s * math.log(self.last_s)
                    self.w[17] -= self.lr * grad_s * g17
                    self.w[18] -= self.lr * grad_s * g18
                    self.w[19] -= self.lr * grad_s * g19
            else:  # Long-term path
                last_r = self.forgetting_curve(self.last_delta_t, self.last_s)
                ds_new_d_last_d = 0.0  # This will connect S to D gradients

                if self.last_rating > 1:  # Success
                    ds_new_d_last_d = (
                        -(self.last_s ** (1 - self.w[9]))
                        * (self.w[15] if self.last_rating == 2 else 1.0)
                        * (self.w[16] if self.last_rating == 4 else 1.0)
                        * math.exp(self.w[8])
                        * (math.exp((1 - last_r) * self.w[10]) - 1)
                    )
                    g8 = ds_new_d_last_d * (11 - self.last_d)
                    g9 = g8 * math.log(self.last_s) if self.last_s > 0 else 0
                    g10 = (
                        self.last_s
                        * math.exp(self.w[8])
                        * (11 - self.last_d)
                        * math.pow(self.last_s, -self.w[9])
                        * (1 - last_r)
                        * math.exp((1 - last_r) * self.w[10])
                        * (self.w[15] if self.last_rating == 2 else 1.0)
                        * (self.w[16] if self.last_rating == 4 else 1.0)
                    )
                    self.w[8] -= self.lr * grad_s * g8
                    self.w[9] -= self.lr * grad_s * g9
                    self.w[10] -= self.lr * grad_s * g10
                    if self.last_rating == 2 and self.w[15] > 0:
                        self.w[15] -= (
                            self.lr * grad_s * (self.new_s - self.last_s) / self.w[15]
                        )
                    if self.last_rating == 4 and self.w[16] > 0:
                        self.w[16] -= (
                            self.lr * grad_s * (self.new_s - self.last_s) / self.w[16]
                        )
                else:  # Failure
                    s_main = (
                        self.w[11]
                        * math.pow(self.last_d, -self.w[12])
                        * (math.pow(self.last_s + 1, self.w[13]) - 1)
                        * math.exp((1 - last_r) * self.w[14])
                    )
                    s_min_penalty = self.last_s / math.exp(self.w[17] * self.w[18])

                    if s_main < s_min_penalty:
                        ds_new_d_last_d = (
                            -s_main * self.w[12] / self.last_d if self.last_d > 0 else 0
                        )
                        g11 = s_main / self.w[11] if self.w[11] != 0 else 0
                        g12 = -s_main * math.log(self.last_d) if self.last_d > 0 else 0
                        g13 = (
                            self.w[11]
                            * math.pow(self.last_d, -self.w[12])
                            * math.pow(self.last_s + 1, self.w[13])
                            * math.log(self.last_s + 1)
                            * math.exp((1 - last_r) * self.w[14])
                            if self.last_s >= 0
                            else 0
                        )
                        g14 = s_main * (1 - last_r)
                        self.w[11] -= self.lr * grad_s * g11
                        self.w[12] -= self.lr * grad_s * g12
                        self.w[13] -= self.lr * grad_s * g13
                        self.w[14] -= self.lr * grad_s * g14
                    else:
                        g17 = -s_min_penalty * self.w[18]
                        g18 = -s_min_penalty * self.w[17]
                        self.w[17] -= self.lr * grad_s * g17
                        self.w[18] -= self.lr * grad_s * g18

                # Update difficulty weights via chain rule
                if ds_new_d_last_d != 0:
                    grad_d_w4 = self.w[7]
                    grad_d_w5 = -3 * self.w[7] * math.exp(3 * self.w[5])
                    grad_d_w6 = (
                        -(self.last_d - 10)
                        * (self.last_rating - 3)
                        * (self.w[7] - 1)
                        / 9
                    )
                    init_d_4 = self.w[4] - (math.exp(self.w[5] * 3) - 1)
                    delta_d_term = (
                        -self.w[6] * (self.last_rating - 3) * (self.last_d - 10) / 9
                    )
                    grad_d_w7 = init_d_4 - (self.last_d + delta_d_term)

                    self.w[4] -= self.lr * grad_s * ds_new_d_last_d * grad_d_w4
                    self.w[5] -= self.lr * grad_s * ds_new_d_last_d * grad_d_w5
                    self.w[6] -= self.lr * grad_s * ds_new_d_last_d * grad_d_w6
                    self.w[7] -= self.lr * grad_s * ds_new_d_last_d * grad_d_w7

        self.clamp_weights()

    def clamp_weights(self):
        # Clamping bounds based on provided instructions
        self.w[0] = max(S_MIN, min(self.w[0], 100))
        self.w[1] = max(S_MIN, min(self.w[1], 100))
        self.w[2] = max(S_MIN, min(self.w[2], 100))
        self.w[3] = max(S_MIN, min(self.w[3], 100))
        self.w[4] = max(1, min(self.w[4], 10))
        self.w[5] = max(0.001, min(self.w[5], 4))
        self.w[6] = max(0.001, min(self.w[6], 4))
        self.w[7] = max(0.001, min(self.w[7], 0.75))
        self.w[8] = max(0, min(self.w[8], 4.5))
        self.w[9] = max(0, min(self.w[9], 0.8))
        self.w[10] = max(0.001, min(self.w[10], 3.5))
        self.w[11] = max(0.001, min(self.w[11], 5))
        self.w[12] = max(0.001, min(self.w[12], 0.25))
        self.w[13] = max(0.001, min(self.w[13], 0.9))
        self.w[14] = max(0, min(self.w[14], 4))
        self.w[15] = max(0, min(self.w[15], 1))
        self.w[16] = max(1, min(self.w[16], 6))
        self.w[17] = max(0, min(self.w[17], 2))
        self.w[18] = max(0, min(self.w[18], 2))
        self.w[19] = max(0.01, min(self.w[19], 0.8))
        self.w[20] = max(0.1, min(self.w[20], 0.8))

In [7]:
DEFAULT_PARAMETER = [
    0.212,  # w[0] initial stability for again
    1.2931,  # w[1] initial stability for hard
    2.3065,  # w[2] initial stability for good
    8.2956,  # w[3] initial stability for easy
    6.4133,  # w[4] initial difficulty
    0.8334,  # w[5] initial difficulty rating offset
    3.0194,  # w[6] next difficulty rating offset
    0.001,  # w[7] next difficulty reversion
    1.8722,  # w[8] stability after success
    0.1666,  # w[9] stability after success S decay
    0.796,  # w[10] stability after success R bonus
    1.4835,  # w[11] stability after failure
    0.0614,  # w[12] stability after failure
    0.2629,  # w[13] stability after failure
    1.6483,  # w[14] stability after failure
    0.6014,  # w[15] stability after success
    1.8729,  # w[16] stability after success
    0.5425,  # w[17] short term stability
    0.0912,  # w[18] short term stability
    0.0658,  # w[19] short term stability
    0.1542,  # w[20] forgetting curve decay
]

The learner recalled the card, so the stability parameter should increase. The retrievability will decay slowly.

The learner recalled the card before the 90% point, so the decay parameter should increase, too. The retrievability before the 90% point will decay slowly.

In [8]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = []

last_s = None
last_d = None
last_rating = 3
inputs.append((0, last_rating))

new_s = fsrs.init_stability(last_rating)
print(new_s)
new_d = fsrs.init_difficulty(last_rating)

delta_t = 1


fsrs.forward(inputs)
fsrs.backward(delta_t, 1)
print(fsrs.w)

print(fsrs.w[2] - DEFAULT_PARAMETER[2])
print(fsrs.w[20] - DEFAULT_PARAMETER[20])

2.3065
[0.212, 1.2931, 2.3066994027269585, 8.2956, 6.4133, 0.8334, 3.0194, 0.001, 1.8722, 0.1666, 0.796, 1.4835, 0.0614, 0.2629, 1.6483, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.15477478372347328]
0.00019940272695828654
0.0005747837234732767


The learner recalled the card, so the stability parameter should increase. The retrievability will decay slowly.

The learner recalled the card after the 90% point, so the decay parameter should decrease. The retrievability after the 90% point will decay slowly.

In [9]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = []

last_s = None
last_d = None
last_rating = 3
inputs.append((0, last_rating))

new_s = fsrs.init_stability(last_rating)
print(new_s)
new_d = fsrs.init_difficulty(last_rating)

delta_t = 3


fsrs.forward(inputs)
fsrs.backward(delta_t, 1)
print(fsrs.w)

print(fsrs.w[2] - DEFAULT_PARAMETER[2])
print(fsrs.w[20] - DEFAULT_PARAMETER[20])

2.3065
[0.212, 1.2931, 2.3068746934098394, 8.2956, 6.4133, 0.8334, 3.0194, 0.001, 1.8722, 0.1666, 0.796, 1.4835, 0.0614, 0.2629, 1.6483, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.1537154302886968]
0.00037469340983919963
-0.00048456971130320103


The learner forgot the card, so the stability parameter should decrease. The retrievability will decay quickly.

The learner forget the card before the 90% point, so the decay parameter should decrease, too. The retrievability before the 90% point will decay quickly.

In [10]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = []

last_s = None
last_d = None
last_rating = 3
inputs.append((0, last_rating))

new_s = fsrs.init_stability(last_rating)
print(new_s)
new_d = fsrs.init_difficulty(last_rating)

delta_t = 1


fsrs.forward(inputs)
fsrs.backward(delta_t, 0)
print(fsrs.w)

print(fsrs.w[2] - DEFAULT_PARAMETER[2])
print(fsrs.w[20] - DEFAULT_PARAMETER[20])

2.3065
[0.212, 1.2931, 2.3029478816387297, 8.2956, 6.4133, 0.8334, 3.0194, 0.001, 1.8722, 0.1666, 0.796, 1.4835, 0.0614, 0.2629, 1.6483, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.143960923288005]
-0.003552118361270562
-0.010239076711995004


The learner forgot the card, so the stability parameter should decrease. The retrievability will decay quickly.

The learner forget the card after the 90% point, so the decay parameter should increase, too. The retrievability after the 90% point will decay quickly.

In [11]:
fsrs = FSRS_one_step(DEFAULT_PARAMETER.copy())

inputs = []

last_s = None
last_d = None
last_rating = 3
inputs.append((0, last_rating))

new_s = fsrs.init_stability(last_rating)
print(new_s)
new_d = fsrs.init_difficulty(last_rating)

delta_t = 3


fsrs.forward(inputs)
fsrs.backward(delta_t, 0)
print(fsrs.w)

print(fsrs.w[2] - DEFAULT_PARAMETER[2])
print(fsrs.w[20] - DEFAULT_PARAMETER[20])

2.3065
[0.212, 1.2931, 2.303727385757546, 8.2956, 6.4133, 0.8334, 3.0194, 0.001, 1.8722, 0.1666, 0.796, 1.4835, 0.0614, 0.2629, 1.6483, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.1577856645666613]
-0.002772614242454008
0.0035856645666612896


In [12]:
last_s = new_s
last_d = new_d
last_t = 2
r = fsrs.forgetting_curve(last_t, last_s)
last_rating = 3
inputs.append((last_t, last_rating))

new_s = fsrs.stability_after_success(last_s, last_d, r, last_rating)
print(new_s)
new_d = fsrs.next_difficulty(last_d, last_rating)

delta_t = 3

fsrs.forward(inputs)
fsrs.backward(delta_t, 1)
print(fsrs.w)

10.958313738080541
[0.212, 1.2931, 2.303727385757546, 8.2956, 6.413299971043636, 0.8334010584938395, 3.0194, 0.0011995016600122384, 1.8719428125920823, 0.16638536980045215, 0.7963348792460392, 1.4835, 0.0614, 0.2629, 1.6483, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.15830317463973395]


In [13]:
last_s = new_s
last_d = new_d
last_t = 2
r = fsrs.forgetting_curve(last_t, last_s)
last_rating = 1
inputs.append((last_t, last_rating))

new_s = fsrs.stability_after_failure(last_s, last_d, r)
print(new_s)
new_d = fsrs.next_difficulty(last_d, last_rating)

delta_t = 3

fsrs.forward(inputs)
fsrs.backward(delta_t, 1)
print(fsrs.w)

1.358322448354761
[0.212, 1.2931, 2.303727385757546, 8.2956, 6.413299933670745, 0.8334024246564425, 3.0193454357079843, 0.001248959311664799, 1.8719428125920823, 0.16638536980045215, 0.7963348792460392, 1.484221683979635, 0.06060066935390815, 0.2684440036518798, 1.6483267388708307, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.15628668652239472]


In [14]:
last_s = None
last_d = None
last_rating = 3
inputs = [(0, last_rating)]

for y in [0 if i % 10 == 0 else 1 for i in range(700)]:
    new_s = fsrs.init_stability(last_rating)
    new_d = fsrs.init_difficulty(last_rating)
    delta_t = 1
    fsrs.forward(inputs)
    fsrs.backward(delta_t, y)
    inputs.append((delta_t, last_rating))


print(new_s)
print(fsrs.w)

2.3001622760377325
[0.212, 1.2931, 2.3001622760377325, 8.2956, 6.413300932597923, 0.8333659107173972, 3.0193454357079843, 0.0010000073250437235, 1.8811349346582065, 0.2204486571206611, 0.7847058197258312, 1.484221683979635, 0.06060066935390815, 0.2684440036518798, 1.6483267388708307, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.10006540653431892]
