In [1]:
from sympy import exp
from sympy import symbols, log, diff

# Symbols: y is the label in the loss, s the stability, t the time argument of R.
y, s, t = symbols("y s t")

# Retrievability R(t, s): the power forgetting curve.
R_s = 1 / (1 + t / (9 * s))

# Loss L: binary cross-entropy between the label y and the prediction R.
L_new = -(y * log(R_s) + (1 - y) * log(1 - R_s))

# dL/ds in simplified closed form; the bare name on the last line displays it.
dL_ds = L_new.diff(s).simplify()
dL_ds

(9*s*(1 - y) - t*y)/(s*(9*s + t))

In [2]:
last_s, last_d = symbols("last_s last_d")
w8, w9, w10, w15, w16 = symbols("w8 w9 w10 w15 w16")

# Retrievability as a function of last_s; reuses the symbol t from the first cell.
r = 1 / (1 + t / (9 * last_s))

# Stability after a successful review.
# NOTE(review): w15 and w16 are both plain factors here, whereas
# FSRS.stability_after_success applies at most one of them per rating
# (the other is 1) -- keep that in mind when porting these derivatives.
new_s = last_s * (
    1
    + exp(w8) * (11 - last_d) * last_s ** (-w9) * (exp((1 - r) * w10) - 1) * w15 * w16
)

# Partial derivatives of new_s. Only the w8/w9/w10 results are simplified;
# the w15/w16 results are left exactly as diff() returns them.
gradient_w8, gradient_w9, gradient_w10 = (
    diff(new_s, weight).simplify() for weight in (w8, w9, w10)
)
gradient_w15, gradient_w16 = (diff(new_s, weight) for weight in (w15, w16))

gradient_w8, gradient_w9, gradient_w10, gradient_w15, gradient_w16

(last_s**(1 - w9)*w15*w16*(1 - exp(t*w10/(9*last_s + t)))*(last_d - 11)*exp(w8),
 last_s**(1 - w9)*w15*w16*(last_d - 11)*(exp(t*w10/(9*last_s + t)) - 1)*exp(w8)*log(last_s),
 last_s**(1 - w9)*t*w15*w16*(11 - last_d)*exp(t*w10/(9*last_s + t) + w8)/(9*last_s + t),
 last_s*w16*(11 - last_d)*(exp(w10*(1 - 1/(1 + t/(9*last_s)))) - 1)*exp(w8)/last_s**w9,
 last_s*w15*(11 - last_d)*(exp(w10*(1 - 1/(1 + t/(9*last_s)))) - 1)*exp(w8)/last_s**w9)

In [3]:
w11, w12, w13, w14 = symbols("w11 w12 w13 w14")

# Stability after a failed review. Depends on the symbols r, last_s and
# last_d created in the previous cell, so that cell must have run first.
new_s = (
    w11
    * last_d ** (-w12)
    * ((last_s + 1) ** w13 - 1)
    * exp((1 - r) * w14)
)

# Simplified partial derivative of new_s with respect to each weight.
gradient_w11, gradient_w12, gradient_w13, gradient_w14 = (
    diff(new_s, weight).simplify() for weight in (w11, w12, w13, w14)
)

gradient_w11, gradient_w12, gradient_w13, gradient_w14

(((last_s + 1)**w13 - 1)*exp(t*w14/(9*last_s + t))/last_d**w12,
 w11*(1 - (last_s + 1)**w13)*exp(t*w14/(9*last_s + t))*log(last_d)/last_d**w12,
 w11*(last_s + 1)**w13*exp(t*w14/(9*last_s + t))*log(last_s + 1)/last_d**w12,
 t*w11*((last_s + 1)**w13 - 1)*exp(t*w14/(9*last_s + t))/(last_d**w12*(9*last_s + t)))

In [4]:
from typing import List
import math


def power_forgetting_curve(t, s):
    """Return the retrievability ``(1 + t / (9 * s)) ** -1``.

    The result is 1 at ``t == 0``, roughly 0.9 at ``t == s`` and 0.5 at
    ``t == 9 * s``. ``s`` must be non-zero or the division raises.

    Args:
        t: time argument of the curve (the notebook passes an elapsed time).
        s: stability that scales ``t``.
    """
    decay_base = 1 + t / (9 * s)
    return decay_base ** -1


class FSRS:
    """Minimal FSRS-style memory model with hand-derived per-review gradient steps.

    ``w`` holds 17 weights, used by this class as follows:

    * ``w[0..3]``    initial stability, indexed by ``rating - 1``
    * ``w[4], w[5]`` initial difficulty: intercept and slope in ``rating - 3``
    * ``w[6]``       difficulty change per ``rating - 3`` step
    * ``w[7]``       only clamped; no method of this class reads it
    * ``w[8..10]``   stability growth after a successful review
    * ``w[11..14]``  stability after a failed review
    * ``w[15]``      multiplier applied only when ``rating == 2`` (hard penalty)
    * ``w[16]``      multiplier applied only when ``rating == 4`` (easy bonus)

    Ratings are presumably integers 1..4 -- TODO confirm against the caller.
    ``update_weights`` treats ``last_rating > 1`` as a successful review and
    anything else as a failed one.
    """

    # (lower, upper) clamp bounds for w[0]..w[16], in index order.
    _WEIGHT_BOUNDS = (
        (0, 365),
        (0, 365),
        (0, 365),
        (0, 365),
        (1, 10),
        (0.1, 5),
        (0.1, 5),
        (0, 0.5),
        (0, 3),
        (0.1, 0.8),
        (0.01, 2.5),
        (0.5, 5),
        (0.01, 0.2),
        (0.01, 0.9),
        (0.01, 2),
        (0, 1),
        (1, 4),
    )

    def __init__(self, w: List[float]):
        """Store a private copy of the weights and set the learning rate.

        The copy keeps ``update_weights`` from silently mutating the list the
        caller passed in; read the trained values from ``self.w``.
        """
        self.w = list(w)
        self.lr = 1e-3

    def init_stability(self, rating):
        """Return the initial stability for ``rating``, i.e. ``w[rating - 1]``."""
        return self.w[rating - 1]

    def init_difficulty(self, rating):
        """Return the initial difficulty ``w[4] - w[5] * (rating - 3)``."""
        return self.w[4] - self.w[5] * (rating - 3)

    def next_difficulty(self, last_d, rating):
        """Return the updated difficulty ``last_d - w[6] * (rating - 3)``."""
        return last_d - self.w[6] * (rating - 3)

    def stability_after_success(self, last_s, last_d, r, rating):
        """Return the new stability after a successful review.

        Args:
            last_s: stability before the review.
            last_d: difficulty before the review.
            r: retrievability at review time.
            rating: review rating; 2 applies ``w[15]``, 4 applies ``w[16]``,
                any other value applies neither.
        """
        hard_penalty = self.w[15] if rating == 2 else 1
        easy_bonus = self.w[16] if rating == 4 else 1
        new_s = last_s * (
            1
            + math.exp(self.w[8])
            * (11 - last_d)
            * math.pow(last_s, -self.w[9])
            * (math.exp((1 - r) * self.w[10]) - 1)
            * hard_penalty
            * easy_bonus
        )
        return new_s

    def stability_after_failure(self, last_s, last_d, r):
        """Return the new stability after a failed review.

        Args:
            last_s: stability before the review.
            last_d: difficulty before the review.
            r: retrievability at review time.
        """
        new_s = (
            self.w[11]
            * math.pow(last_d, -self.w[12])
            * (math.pow(last_s + 1, self.w[13]) - 1)
            * math.exp((1 - r) * self.w[14])
        )
        return new_s

    def clamp_weights(self):
        """Clip ``w[0]..w[16]`` in place to the ranges in ``_WEIGHT_BOUNDS``."""
        for idx, (low, high) in enumerate(self._WEIGHT_BOUNDS):
            self.w[idx] = max(low, min(self.w[idx], high))

    def update_weights(self, last_s, last_d, cur_s, cur_d, last_rating, delta_t, y,
                       last_t=None):
        """Take one gradient-descent step on the weights, then clamp them.

        The loss is the binary cross-entropy between ``y`` and the
        retrievability ``1 / (1 + delta_t / (9 * cur_s))``. Its derivative
        with respect to the stability is chained with d(new_s)/d(w_i) of the
        formula that produced ``cur_s``.

        Args:
            last_s: stability before the review that produced ``cur_s``, or
                ``None`` when ``cur_s`` came from ``init_stability``.
            last_d: difficulty before that review (unused when ``last_s`` is
                ``None``).
            cur_s: stability whose prediction is being scored.
            cur_d: accepted for interface symmetry; not used.
            last_rating: rating of the review that produced ``cur_s``.
            delta_t: time argument of the retrievability being scored.
            y: label in the loss (presumably 1 = recalled, 0 = forgotten --
                TODO confirm).
            last_t: time argument of the retrievability ``r`` that was fed
                into ``stability_after_success`` / ``stability_after_failure``.
                Defaults to ``delta_t``, which reproduces the historical
                behaviour of evaluating d(new_s)/d(w_i) at ``delta_t``.
        """
        if last_s is None:
            # cur_s is w[last_rating - 1] itself, so dL/ds is the full gradient.
            idx = last_rating - 1
            grad = (-delta_t * y + 9 * self.w[idx] * (1 - y)) / (
                self.w[idx] * (delta_t + 9 * self.w[idx])
            )
            self.w[idx] -= grad * self.lr
        else:
            elapsed = delta_t if last_t is None else last_t
            # dL/ds = (9*s*(1 - y) - t*y) / (s*(9*s + t)) at s = cur_s, t = delta_t
            grad_s = (9 * cur_s * (1 - y) - delta_t * y) / (
                cur_s * (9 * cur_s + delta_t)
            )
            if last_rating > 1:
                self._step_success_weights(grad_s, last_s, last_d, last_rating, elapsed)
            else:
                self._step_failure_weights(grad_s, last_s, last_d, elapsed)

        self.clamp_weights()

    def _step_success_weights(self, grad_s, last_s, last_d, rating, elapsed):
        """Step w[8], w[9], w[10], w[15], w[16] using d(stability_after_success)/dw."""
        w = self.w
        # Same rating-dependent multipliers as stability_after_success.
        hard_penalty = w[15] if rating == 2 else 1
        easy_bonus = w[16] if rating == 4 else 1
        scaled_s = last_s ** (1 - w[9])
        # Equals (1 - r) * w10 for r = 1 / (1 + elapsed / (9 * last_s)).
        exponent = elapsed * w[10] / (9 * last_s + elapsed)

        # last_s**(1 - w9)*w15*w16*(1 - exp(t*w10/(9*s + t)))*(last_d - 11)*exp(w8)
        grad_8 = (
            scaled_s
            * hard_penalty
            * easy_bonus
            * (1 - math.exp(exponent))
            * (last_d - 11)
            * math.exp(w[8])
        )
        # last_s**(1 - w9)*w15*w16*(last_d - 11)*(exp(t*w10/(9*s + t)) - 1)*exp(w8)*log(last_s)
        grad_9 = (
            scaled_s
            * hard_penalty
            * easy_bonus
            * (last_d - 11)
            * (math.exp(exponent) - 1)
            * math.exp(w[8])
            * math.log(last_s)
        )
        # last_s**(1 - w9)*t*w15*w16*(11 - last_d)*exp(t*w10/(9*s + t) + w8)/(9*s + t)
        grad_10 = (
            scaled_s
            * elapsed
            * hard_penalty
            * easy_bonus
            * (11 - last_d)
            * math.exp(exponent + w[8])
            / (9 * last_s + elapsed)
        )

        # w15 only enters new_s when rating == 2 and w16 only when rating == 4;
        # in either case the *other* multiplier is 1, so the derivative must
        # not carry the other weight as a factor.
        grad_15 = 0
        grad_16 = 0
        if rating == 2 or rating == 4:
            # last_s*(11 - last_d)*(exp(w10*(1 - 1/(1 + t/(9*s)))) - 1)*exp(w8)/last_s**w9
            own_multiplier_grad = (
                last_s
                * (11 - last_d)
                * (math.exp(w[10] * (1 - 1 / (1 + elapsed / (9 * last_s)))) - 1)
                * math.exp(w[8])
                / last_s ** w[9]
            )
            if rating == 2:
                grad_15 = own_multiplier_grad
            else:
                grad_16 = own_multiplier_grad

        # All gradients are computed from the pre-step weights before any update.
        w[8] -= grad_s * grad_8 * self.lr
        w[9] -= grad_s * grad_9 * self.lr
        w[10] -= grad_s * grad_10 * self.lr
        w[15] -= grad_s * grad_15 * self.lr
        w[16] -= grad_s * grad_16 * self.lr

    def _step_failure_weights(self, grad_s, last_s, last_d, elapsed):
        """Step w[11]..w[14] using d(stability_after_failure)/dw."""
        w = self.w
        # exp((1 - r) * w14) for r = 1 / (1 + elapsed / (9 * last_s)).
        forget_term = math.exp(elapsed * w[14] / (9 * last_s + elapsed))

        # ((s + 1)**w13 - 1)*exp(t*w14/(9*s + t))/d**w12
        grad_11 = ((last_s + 1) ** w[13] - 1) * forget_term / last_d ** w[12]
        # w11*(1 - (s + 1)**w13)*exp(t*w14/(9*s + t))*log(d)/d**w12
        grad_12 = (
            w[11]
            * (1 - (last_s + 1) ** w[13])
            * forget_term
            * math.log(last_d)
            / last_d ** w[12]
        )
        # w11*(s + 1)**w13*exp(t*w14/(9*s + t))*log(s + 1)/d**w12
        grad_13 = (
            w[11]
            * (last_s + 1) ** w[13]
            * forget_term
            * math.log(last_s + 1)
            / last_d ** w[12]
        )
        # t*w11*((s + 1)**w13 - 1)*exp(t*w14/(9*s + t))/(d**w12*(9*s + t))
        grad_14 = (
            elapsed
            * w[11]
            * ((last_s + 1) ** w[13] - 1)
            * forget_term
            / (last_d ** w[12] * (9 * last_s + elapsed))
        )

        # All gradients are computed from the pre-step weights before any update.
        w[11] -= grad_s * grad_11 * self.lr
        w[12] -= grad_s * grad_12 * self.lr
        w[13] -= grad_s * grad_13 * self.lr
        w[14] -= grad_s * grad_14 * self.lr

In [5]:
# Starting weights w[0]..w[16], laid out by index range.
init_w = [
    0.4, 0.9, 2.3, 10.9,      # w[0..3]
    4.93, 0.94, 0.86, 0.01,   # w[4..7]
    1.49, 0.14, 0.94,         # w[8..10]
    2.18, 0.05, 0.34, 1.26,   # w[11..14]
    0.29, 2.61,               # w[15..16]
]

fsrs = FSRS(init_w)

In [6]:
# First review: there is no previous stability or difficulty yet.
last_s = last_d = None
last_rating = 3
delta_t = 1

new_s = fsrs.init_stability(last_rating)
print(new_s)
new_d = fsrs.init_difficulty(last_rating)

# With last_s=None only the initial-stability weight for this rating is stepped.
fsrs.update_weights(
    last_s=last_s,
    last_d=last_d,
    cur_s=new_s,
    cur_d=new_d,
    last_rating=last_rating,
    delta_t=delta_t,
    y=1,
)
print(fsrs.w)

2.3
[0.4, 0.9, 2.300020036064917, 10.9, 4.93, 0.94, 0.86, 0.01, 1.49, 0.14, 0.94, 2.18, 0.05, 0.34, 1.26, 0.29, 2.61]


In [7]:
# Successful review (rating 3) of a card with known stability and difficulty.
last_s, last_d, last_t = 2, 5, 2
last_rating = 3
delta_t = 3

r = power_forgetting_curve(last_t, last_s)
new_s = fsrs.stability_after_success(last_s, last_d, r, last_rating)
print(new_s)
new_d = fsrs.next_difficulty(last_d, last_rating)

# NOTE(review): r above is computed from last_t, but this call passes only
# delta_t -- confirm which time the weight gradients should be evaluated at.
fsrs.update_weights(
    last_s=last_s,
    last_d=last_d,
    cur_s=new_s,
    cur_d=new_d,
    last_rating=last_rating,
    delta_t=delta_t,
    y=1,
)
print(fsrs.w)

6.762504509640942
[0.4, 0.9, 2.300020036064917, 10.9, 4.93, 0.94, 0.86, 0.01, 1.490048241318444, 0.13996656166613405, 0.9400548434573033, 2.18, 0.05, 0.34, 1.26, 0.29, 2.61]


In [8]:
# Failed review (rating 1) of a card with known stability and difficulty.
last_s, last_d, last_t = 2, 5, 2
last_rating = 1
delta_t = 3

r = power_forgetting_curve(last_t, last_s)
new_s = fsrs.stability_after_failure(last_s, last_d, r)
print(new_s)
new_d = fsrs.next_difficulty(last_d, last_rating)

# NOTE(review): r above is computed from last_t, but this call passes only
# delta_t -- confirm which time the weight gradients should be evaluated at.
fsrs.update_weights(
    last_s=last_s,
    last_d=last_d,
    cur_s=new_s,
    cur_d=new_d,
    last_rating=last_rating,
    delta_t=delta_t,
    y=1,
)
print(fsrs.w)

1.033201219928722
[0.4, 0.9, 2.300020036064917, 10.9, 4.93, 0.94, 0.86, 0.01, 1.490048241318444, 0.13996656166613405, 0.9400548434573033, 2.1801181009894797, 0.049585633862277455, 0.34090744413250423, 1.260036780022438, 0.29, 2.61]
