In [1]:
from sympy import symbols, log, diff, exp

# Symbols
y, s, t, p = symbols("y s t p")

# Retrievability (R)
R = (1 + t / s) ** -p

# Loss Function (L)
L = -(y * log(R) + (1 - y) * log(1 - R))

# Gradient of Loss w.r.t. Stability (s)
dL_ds = diff(L, s).simplify()

# Gradient of Loss w.r.t. power (p, which is w[20])
dL_dp = diff(L, p).simplify()

print(f"{dL_ds=}")
print(f"{dL_dp=}")

dL_ds=p*t*(-y*(((s + t)/s)**p - 1) - y + 1)/(s*(s + t)*(((s + t)/s)**p - 1))
dL_dp=(y*((s + t)/s)**p - 1)*log((s + t)/s)/(((s + t)/s)**p - 1)


In [2]:
# Symbols for the success case
last_s, last_d, t, p = symbols("last_s last_d t p")
w8, w9, w10, w15, w16 = symbols("w8 w9 w10 w15 w16")

# Retrievability (r)
r = (1 + t / last_s) ** -p

# new_s formula (using w15 and w16 as placeholders for hard_penalty and easy_bonus)
new_s = last_s * (
    1
    + exp(w8)
    * (11 - last_d)
    * last_s ** (-w9)
    * (exp((1 - r) * w10) - 1)
    * w15 # hard_penalty
    * w16 # easy_bonus
)

# Gradients w.r.t weights
grad_w8 = diff(new_s, w8).simplify()
grad_w9 = diff(new_s, w9).simplify()
grad_w10 = diff(new_s, w10).simplify()

# Gradient w.r.t last difficulty
grad_last_d = diff(new_s, last_d).simplify()

# Gradients for penalty/bonus are the term itself
grad_w15 = diff(new_s, w15).simplify()
grad_w16 = diff(new_s, w16).simplify()


print(f"{grad_w8=}")
print(f"{grad_w9=}")
print(f"{grad_w10=}")
print(f"{grad_last_d=}")
print(f"{grad_w15=}")
print(f"{grad_w16=}")

grad_w8=last_s**(1 - w9)*w15*w16*(1 - exp(w10*(((last_s + t)/last_s)**p - 1)/((last_s + t)/last_s)**p))*(last_d - 11)*exp(w8)
grad_w9=last_s**(1 - w9)*w15*w16*(last_d - 11)*(exp(w10*(((last_s + t)/last_s)**p - 1)/((last_s + t)/last_s)**p) - 1)*exp(w8)*log(last_s)
grad_w10=-last_s**(1 - w9)*w15*w16*(last_d - 11)*(((last_s + t)/last_s)**p - 1)*exp(w10*(((last_s + t)/last_s)**p - 1)/((last_s + t)/last_s)**p + w8)/((last_s + t)/last_s)**p
grad_last_d=last_s**(1 - w9)*w15*w16*(1 - exp(w10*(((last_s + t)/last_s)**p - 1)/((last_s + t)/last_s)**p))*exp(w8)
grad_w15=last_s**(1 - w9)*w16*(1 - exp(w10*(((last_s + t)/last_s)**p - 1)/((last_s + t)/last_s)**p))*(last_d - 11)*exp(w8)
grad_w16=last_s**(1 - w9)*w15*(1 - exp(w10*(((last_s + t)/last_s)**p - 1)/((last_s + t)/last_s)**p))*(last_d - 11)*exp(w8)


In [3]:
# Symbols for the failure case
last_s, last_d, t, p = symbols("last_s last_d t p")
w11, w12, w13, w14, w17, w18 = symbols("w11 w12 w13 w14 w17 w18")

# Retrievability (r)
r = (1 + t / last_s) ** -p

# Main failure formula for new_s
new_s_main = w11 * last_d**(-w12) * ((last_s + 1)**w13 - 1) * exp((1 - r) * w14)
# Minimum stability formula
new_s_min = last_s / exp(w17 * w18)

# Gradients for the main formula
grad_w11 = diff(new_s_main, w11).simplify()
grad_w12 = diff(new_s_main, w12).simplify()
grad_w13 = diff(new_s_main, w13).simplify()
grad_w14 = diff(new_s_main, w14).simplify()
grad_last_d_fail = diff(new_s_main, last_d).simplify()

# Gradients for the minimum stability formula
grad_w17 = diff(new_s_min, w17).simplify()
grad_w18 = diff(new_s_min, w18).simplify()


print(f"{grad_w11=}")
print(f"{grad_w12=}")
print(f"{grad_w13=}")
print(f"{grad_w14=}")
print(f"{grad_last_d_fail=}")
print(f"{grad_w17=}")
print(f"{grad_w18=}")

grad_w11=((last_s + 1)**w13 - 1)*exp(w14 - w14/((last_s + t)/last_s)**p)/last_d**w12
grad_w12=w11*(1 - (last_s + 1)**w13)*exp(w14 - w14/((last_s + t)/last_s)**p)*log(last_d)/last_d**w12
grad_w13=w11*(last_s + 1)**w13*exp(w14 - w14/((last_s + t)/last_s)**p)*log(last_s + 1)/last_d**w12
grad_w14=w11*(((last_s + t)/last_s)**p - 1)*((last_s + 1)**w13 - 1)*exp(w14*(((last_s + t)/last_s)**p - 1)/((last_s + t)/last_s)**p)/(last_d**w12*((last_s + t)/last_s)**p)
grad_last_d_fail=last_d**(-w12 - 1)*w11*w12*(1 - (last_s + 1)**w13)*exp(w14 - w14/((last_s + t)/last_s)**p)
grad_w17=-last_s*w18*exp(-w17*w18)
grad_w18=-last_s*w17*exp(-w17*w18)


In [4]:
from sympy import symbols, diff, exp

# Symbols for difficulty calculation
last_d, rating = symbols("last_d rating")
w4, w5, w6, w7 = symbols("w4 w5 w6 w7")

# Formula for initial D when rating is 4 (for mean reversion)
init_d_4 = w4 - exp(w5 * (4 - 1)) + 1

# Intermediate new_d before mean reversion
# Note: delta_d * (10 - old_d) / 9 is linear_damping from the pytorch code
delta_d = -w6 * (rating - 3)
d_intermediate = last_d + delta_d * (10 - last_d) / 9

# Final new_d after mean reversion
new_d = w7 * init_d_4 + (1 - w7) * d_intermediate

# Gradients of new_d w.r.t relevant weights
grad_d_w4 = diff(new_d, w4).simplify()
grad_d_w5 = diff(new_d, w5).simplify()
grad_d_w6 = diff(new_d, w6).simplify()
grad_d_w7 = diff(new_d, w7).simplify()

print(f"{grad_d_w4=}")
print(f"{grad_d_w5=}")
print(f"{grad_d_w6=}")
print(f"{grad_d_w7=}")

grad_d_w4=w7
grad_d_w5=-3*w7*exp(3*w5)
grad_d_w6=-(last_d - 10)*(rating - 3)*(w7 - 1)/9
grad_d_w7=-last_d + w4 - w6*(last_d - 10)*(rating - 3)/9 - exp(3*w5) + 1


In [5]:
from typing import List, Optional
import math

S_MIN = 0.001

class FSRS_one_step:
    def __init__(self, w: List[float], lr: float = 1e-2):
        self.w = w
        self.lr = lr

    def power_forgetting_curve(self, t, s):
        return (1 + t / s) ** (-self.w[20])

    def init_stability(self, rating: int) -> float:
        return max(S_MIN, self.w[rating - 1])

    def init_difficulty(self, rating: int) -> float:
        return max(1, min(10, self.w[4] - math.exp(self.w[5] * (rating - 1)) + 1))
    
    def next_difficulty(self, d: float, rating: int) -> float:
        # From PyTorch code:
        init_d_4 = self.w[4] - math.exp(self.w[5] * (4 - 1)) + 1
        delta_d = -self.w[6] * (rating - 3)
        linear_damping = delta_d * (10 - d) / 9
        d_intermediate = d + linear_damping
        # Mean reversion
        new_d = self.w[7] * init_d_4 + (1 - self.w[7]) * d_intermediate
        return max(1, min(10, new_d))

    def stability_after_success(self, s: float, d: float, r: float, rating: int) -> float:
        hard_penalty = self.w[15] if rating == 2 else 1.0
        easy_bonus = self.w[16] if rating == 4 else 1.0
        new_s = s * (
            1
            + math.exp(self.w[8])
            * (11 - d)
            * math.pow(s, -self.w[9])
            * (math.exp((1 - r) * self.w[10]) - 1)
            * hard_penalty
            * easy_bonus
        )
        return max(S_MIN, new_s)

    def stability_after_failure(self, s: float, d: float, r: float) -> float:
        s_main = (
            self.w[11]
            * math.pow(d, -self.w[12])
            * (math.pow(s + 1, self.w[13]) - 1)
            * math.exp((1 - r) * self.w[14])
        )
        s_min_penalty = s / math.exp(self.w[17] * self.w[18])
        return max(S_MIN, min(s_main, s_min_penalty))

    def update_weights(self, last_s: Optional[float], last_d: Optional[float], delta_t: int, rating: int, y: int):
        """
        Perform a single step of backpropagation.
        :param last_s: Stability before the review.
        :param last_d: Difficulty before the review.
        :param delta_t: Time elapsed in days.
        :param rating: User feedback (1:Fail, 2:Hard, 3:Good, 4:Easy).
        :param y: Actual outcome (0 for fail, 1 for success).
        """
        # Initial review case
        if last_s is None:
            s0 = self.init_stability(rating)
            d0 = self.init_difficulty(rating)
            r = self.power_forgetting_curve(delta_t, s0)
            
            # Simplified gradient for initial stabilities
            grad_s = self.w[20] * delta_t * (y * (1 - ((delta_t + s0) / s0)**self.w[20]) + y - 1) / (s0 * (delta_t + s0) * (((delta_t + s0) / s0)**self.w[20] - 1))
            self.w[rating - 1] -= self.lr * grad_s * 5 # Amplify learning for first reviews
            
            # No update for difficulty on first review as it's directly calculated
            return

        # Subsequent review
        r = self.power_forgetting_curve(delta_t, last_s)
        
        # dL/ds_new
        cur_s = (
            self.stability_after_success(last_s, last_d, r, rating)
            if rating > 1
            else self.stability_after_failure(last_s, last_d, r)
        )
        
        r_new = self.power_forgetting_curve(delta_t, cur_s)
        
        # Using a simplified, more stable gradient for dL/ds
        # Based on dL/dR * dR/ds
        # dL/dR = (R - y) / (R * (1 - R))
        # dR/ds = (p * t * R) / (s * (s + t))
        if r_new < 0.999 and r_new > 0.001:
            grad_s = (r_new - y) / (r_new * (1 - r_new)) * (self.w[20] * delta_t * r_new) / (cur_s * (cur_s + delta_t))
        else:
            grad_s = (r_new - y) * self.w[20] * delta_t / (cur_s * (cur_s + delta_t))

        # dL/dp (w20)
        grad_p = (y * ((last_s + delta_t)/last_s)**self.w[20] - 1) * math.log((last_s + delta_t)/last_s) / (((last_s + delta_t)/last_s)**self.w[20] - 1)
        self.w[20] -= self.lr * grad_s * grad_p * 0.1 # Smaller learning rate for p

        # Gradients for difficulty parameters (chain rule)
        # d(last_d)/dw_i
        grad_d_w4 = self.w[7]
        grad_d_w5 = -3 * self.w[7] * math.exp(3 * self.w[5])
        grad_d_w6 = -(last_d - 10) * (rating - 3) * (self.w[7] - 1) / 9
        init_d_4 = self.w[4] - math.exp(self.w[5] * 3) + 1
        delta_d_term = -self.w[6] * (rating - 3) * (last_d - 10) / 9
        grad_d_w7 = init_d_4 - (last_d + delta_d_term)

        if rating > 1: # Success
            # Common term: ds_new/d(last_d)
            ds_new_d_last_d = last_s**(1-self.w[9]) * (self.w[15] if rating==2 else 1) * \
                             (self.w[16] if rating==4 else 1) * \
                             (1 - math.exp(self.w[10]*(1-r))) * math.exp(self.w[8])

            # Gradients of S w.r.t. w_i
            g8 = last_s**(1-self.w[9]) * (self.w[15] if rating==2 else 1) * \
                 (self.w[16] if rating==4 else 1) * (1 - math.exp(self.w[10]*(1-r))) * \
                 (last_d-11) * math.exp(self.w[8])
            g9 = last_s**(1-self.w[9]) * (self.w[15] if rating==2 else 1) * \
                 (self.w[16] if rating==4 else 1) * (last_d-11) * \
                 (math.exp(self.w[10]*(1-r)) - 1) * math.exp(self.w[8]) * math.log(last_s)
            g10 = -last_s**(1-self.w[9]) * (self.w[15] if rating==2 else 1) * \
                  (self.w[16] if rating==4 else 1) * (last_d-11) * (1-r) * \
                  math.exp(self.w[10]*(1-r) + self.w[8])
            
            self.w[8] -= self.lr * grad_s * g8
            self.w[9] -= self.lr * grad_s * g9
            self.w[10] -= self.lr * grad_s * g10

            if rating == 2:
                g15 = last_s**(1-self.w[9]) * (self.w[16] if rating==4 else 1) * \
                      (1-math.exp(self.w[10]*(1-r))) * (last_d-11) * math.exp(self.w[8])
                self.w[15] -= self.lr * grad_s * g15
            if rating == 4:
                g16 = last_s**(1-self.w[9]) * (self.w[15] if rating==2 else 1) * \
                      (1-math.exp(self.w[10]*(1-r))) * (last_d-11) * math.exp(self.w[8])
                self.w[16] -= self.lr * grad_s * g16

        else: # Failure
            s_main = self.w[11] * math.pow(last_d, -self.w[12]) * \
                     (math.pow(last_s + 1, self.w[13]) - 1) * math.exp((1 - r) * self.w[14])
            s_min_penalty = last_s / math.exp(self.w[17] * self.w[18])
            
            if s_main < s_min_penalty:
                # Gradients for main failure formula
                ds_new_d_last_d = last_d**(-self.w[12]-1) * self.w[11] * self.w[12] * \
                                  (1-(last_s+1)**self.w[13]) * math.exp((1-r)*self.w[14])

                g11 = ((last_s+1)**self.w[13] - 1) * math.exp((1-r)*self.w[14]) / last_d**self.w[12]
                g12 = self.w[11]*(1-(last_s+1)**self.w[13]) * math.exp((1-r)*self.w[14]) * \
                      math.log(last_d) / last_d**self.w[12]
                g13 = self.w[11]*(last_s+1)**self.w[13] * math.exp((1-r)*self.w[14]) * \
                      math.log(last_s+1) / last_d**self.w[12]
                g14 = self.w[11]*((last_s+1)**self.w[13]-1)*(1-r) * \
                      math.exp((1-r)*self.w[14]) / last_d**self.w[12]
                
                self.w[11] -= self.lr * grad_s * g11
                self.w[12] -= self.lr * grad_s * g12
                self.w[13] -= self.lr * grad_s * g13
                self.w[14] -= self.lr * grad_s * g14
            else:
                # Gradients for min stability penalty
                ds_new_d_last_d = 0 # No dependency on difficulty
                g17 = -last_s * self.w[18] * math.exp(-self.w[17] * self.w[18])
                g18 = -last_s * self.w[17] * math.exp(-self.w[17] * self.w[18])
                self.w[17] -= self.lr * grad_s * g17
                self.w[18] -= self.lr * grad_s * g18

        # Update difficulty weights via chain rule: dL/dw = dL/dS * dS/dD * dD/dw
        if ds_new_d_last_d != 0:
            self.w[4] -= self.lr * grad_s * ds_new_d_last_d * grad_d_w4
            self.w[5] -= self.lr * grad_s * ds_new_d_last_d * grad_d_w5
            self.w[6] -= self.lr * grad_s * ds_new_d_last_d * grad_d_w6
            self.w[7] -= self.lr * grad_s * ds_new_d_last_d * grad_d_w7

        self.clamp_weights()

    def clamp_weights(self):
        # Clamping bounds based on provided instructions
        self.w[0] = max(S_MIN, min(self.w[0], 100))
        self.w[1] = max(S_MIN, min(self.w[1], 100))
        self.w[2] = max(S_MIN, min(self.w[2], 100))
        self.w[3] = max(S_MIN, min(self.w[3], 100))
        self.w[4] = max(1, min(self.w[4], 10))
        self.w[5] = max(0.001, min(self.w[5], 4))
        self.w[6] = max(0.001, min(self.w[6], 4))
        self.w[7] = max(0.001, min(self.w[7], 0.75))
        self.w[8] = max(0, min(self.w[8], 4.5))
        self.w[9] = max(0, min(self.w[9], 0.8))
        self.w[10] = max(0.001, min(self.w[10], 3.5))
        self.w[11] = max(0.001, min(self.w[11], 5))
        self.w[12] = max(0.001, min(self.w[12], 0.25))
        self.w[13] = max(0.001, min(self.w[13], 0.9))
        self.w[14] = max(0, min(self.w[14], 4))
        self.w[15] = max(0, min(self.w[15], 1))
        self.w[16] = max(1, min(self.w[16], 6))
        self.w[17] = max(0, min(self.w[17], 2))
        self.w[18] = max(0, min(self.w[18], 2))
        self.w[19] = max(0.01, min(self.w[19], 0.8))
        self.w[20] = max(0.1, min(self.w[20], 0.8))

In [6]:
DEFAULT_PARAMETER = [
    0.212,  # w[0] initial stability for again
    1.2931,  # w[1] initial stability for hard
    2.3065,  # w[2] initial stability for good
    8.2956,  # w[3] initial stability for easy
    6.4133,  # w[4] initial difficulty
    0.8334,  # w[5] initial difficulty rating offset
    3.0194,  # w[6] next difficulty rating offset
    0.001,  # w[7] next difficulty reversion
    1.8722,  # w[8] stability after success
    0.1666,  # w[9] stability after success S decay
    0.796,  # w[10] stability after success R bonus
    1.4835,  # w[11] stability after failure
    0.0614,  # w[12] stability after failure
    0.2629,  # w[13] stability after failure
    1.6483,  # w[14] stability after failure
    0.6014,  # w[15] stability after success
    1.8729,  # w[16] stability after success
    0.5425,  # w[17] short term stability
    0.0912,  # w[18] short term stability
    0.0658,  # w[19] short term stability
    0.1542,  # w[20] forgetting curve decay
]
fsrs = FSRS_one_step(DEFAULT_PARAMETER)

In [7]:
last_s = None
last_d = None
last_rating = 3

new_s = fsrs.init_stability(last_rating)
print(new_s)
new_d = fsrs.init_difficulty(last_rating)

delta_t = 1

fsrs.update_weights(last_s, last_d, delta_t, last_rating, 1)
print(fsrs.w)

2.3065
[0.212, 1.2931, 2.3075109563210816, 8.2956, 6.4133, 0.8334, 3.0194, 0.001, 1.8722, 0.1666, 0.796, 1.4835, 0.0614, 0.2629, 1.6483, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.1542]


In [8]:
last_s = new_s
last_d = new_d
last_t = 2
r = fsrs.power_forgetting_curve(last_t, last_s)
last_rating = 3

new_s = fsrs.stability_after_success(last_s, last_d, r, last_rating)
print(new_s)
new_d = fsrs.next_difficulty(last_d, last_rating)

delta_t = 3

fsrs.update_weights(last_s, last_d, delta_t, last_rating, 1)
print(fsrs.w)

11.09180365110379
[0.212, 1.2931, 2.3075109563210816, 8.2956, 6.413299974388176, 0.8334009362348808, 3.0194, 0.001176458667875016, 1.8724274815527482, 0.16640988656348088, 0.7962997137236069, 1.4835, 0.0614, 0.2629, 1.6483, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.1542016235475612]


In [9]:
last_s = new_s
last_d = new_d
last_t = 2
r = fsrs.power_forgetting_curve(last_t, last_s)
last_rating = 1

new_s = fsrs.stability_after_failure(last_s, last_d, r)
print(new_s)
new_d = fsrs.next_difficulty(last_d, last_rating)

delta_t = 3

fsrs.update_weights(last_s, last_d, delta_t, last_rating, 1)
print(fsrs.w)

1.36740013700412
[0.212, 1.2931, 2.3075109563210816, 8.2956, 6.413299938353644, 0.8334022534734087, 3.0193463674272074, 0.0012251500068517023, 1.8724274815527482, 0.16640988656348088, 0.7962997137236069, 1.484209935014972, 0.060612990868730676, 0.26836092728593314, 1.6483381684385614, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.15421973009588114]


In [10]:
last_s = None
last_d = None
last_rating = 3

for y in [0 if i % 10 == 0 else 1 for i in range(700)]:
    new_s = fsrs.init_stability(last_rating)
    new_d = fsrs.init_difficulty(last_rating)
    delta_t = 1
    fsrs.update_weights(last_s, last_d, delta_t, last_rating, y)

print(new_s)
print(fsrs.w)

3.701100094648134
[0.212, 1.2931, 3.701543274098277, 8.2956, 6.413299938353644, 0.8334022534734087, 3.0193463674272074, 0.0012251500068517023, 1.8724274815527482, 0.16640988656348088, 0.7962997137236069, 1.484209935014972, 0.060612990868730676, 0.26836092728593314, 1.6483381684385614, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.15421973009588114]


In [11]:
last_s = 2
last_d = 5
last_t = 2
r = fsrs.power_forgetting_curve(last_t, last_s)
last_rating = 3
for y in [0 if i % 10 == 0 else 1 for i in range(600)]:
    new_s = fsrs.stability_after_success(last_s, last_d, r, last_rating)
    new_d = fsrs.next_difficulty(last_d, last_rating)

    delta_t = 10

    fsrs.update_weights(last_s, last_d, delta_t, last_rating, y)
print(new_s)
print(fsrs.w)

6.342864079476962
[0.212, 1.2931, 3.701543274098277, 8.2956, 6.4136497486461606, 0.82084281609334, 3.0193463674272074, 0.009681091822580885, 1.7885961898479326, 0.22451731005137343, 0.6753931757123799, 1.484209935014972, 0.060612990868730676, 0.26836092728593314, 1.6483381684385614, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.17414558907713734]


In [12]:
last_s = 50
last_d = 5
last_t = 2
r = fsrs.power_forgetting_curve(last_t, last_s)
last_rating = 1
for y in [0 if i % 10 == 0 else 1 for i in range(500)]:
    new_s = fsrs.stability_after_failure(last_s, last_d, r)
    new_d = fsrs.next_difficulty(last_d, last_rating)

    delta_t = 2

    fsrs.update_weights(last_s, last_d, delta_t, last_rating, y)
print(new_s)
print(fsrs.w)

3.7886169429977454
[0.212, 1.2931, 3.701543274098277, 8.2956, 6.413645735056117, 0.8209840243848162, 3.018806098946315, 0.012603483722860518, 1.7885961898479326, 0.22451731005137343, 0.6753931757123799, 1.4833571896314541, 0.06381863076347875, 0.3441622866905756, 1.6483368526083544, 0.6014, 1.8729, 0.5425, 0.0912, 0.0658, 0.231276937759554]
