In [5]:
import numpy as np
from nptyping import NDArray, Float, Shape

# Reference state-transition model.
# Entry [i, j] is the probability of moving from state S(j+1) to state
# S(i+1)': columns are the current state, rows are the next state.
T: NDArray[Shape["2, 2"], Float] = np.array(
    (
        # from S1, from S2
        (0.1, 0.1),  # to S1'
        (0.9, 0.9),  # to S2'
    )
)

# Reference reward model.
# Entry [i, j] is the reward for going from state S(j+1) to state S(i+1)':
# columns are the current state, rows are the next state.
# The dtype is given explicitly: every literal below is an integer, so
# without it NumPy would build an integer array, contradicting the Float
# annotation.
R: NDArray[Shape["2, 2"], Float] = np.array(
    [  #  S1  S2
        [ 1, -1],  # S1'
        [-1,  1],  # S2'
    ],
    dtype=float,
)

# ===========================================
# We want to estimate the transition probabilities and reward functions,
# so the matrices defined above will just be for reference.

# Q-value estimates, all starting at zero.
# Layout is Q_hat[action, state]:
#   row 0 -> Action 1 (move to S1), row 1 -> Action 2 (move to S2)
#   column 0 -> S1,                 column 1 -> S2
# NOTE: the array must have a floating-point dtype. NumPy casts every
# assigned value to the array's dtype, so a float written into an integer
# array silently loses its fractional part (0.75 would be stored as 0).
Q_hat: NDArray[Shape["2, 2"], Float] = np.zeros((2, 2), dtype=float)

# Exponential moving average factor: weight given to the newest sample
# (the previous estimate keeps a weight of 1 - alpha).
alpha: float = 0.75

# Discount factor applied to the best Q value of the future state.
gamma: float = 0.5

# Collected transitions, one per row: (S, S', R).
# States are numbered from 1.
samples = np.array(
    (
        (1, 1, 1),
        (1, 2, -1),
        (2, 1, 1),
    )
)

# Q-learning pass over the collected samples.
# Each sample row is (S, S', R); Q_hat is updated in place and printed
# after every sample.
for i, sample in enumerate(samples):

    # Sample reward
    R_s = sample[2]

    # Sample "Current" state S and "Future" state S'.
    # States are numbered from 1 in the samples, so subtract 1 to get
    # 0-based array indices; int() turns the NumPy scalars into plain ints.
    S: int = int(sample[0]) - 1
    S_prime: int = int(sample[1]) - 1

    # Action taken.
    # NOTE(review): the samples carry no action column, so the action index
    # is derived from the current state S. The Q_hat layout labels its rows
    # "move to S1 / move to S2", which suggests the action might be meant to
    # follow S' instead -- confirm before relying on the off-diagonal
    # entries of Q_hat.
    A: int = S

    # Sample k: reward plus the discounted best Q value reachable from S'.
    # Q_hat is indexed [action, state], so the values of all actions in S'
    # are the column Q_hat[:, S_prime]; a row slice would instead compare
    # one action across different states.
    S_k = R_s + gamma * np.max(Q_hat[:, S_prime])

    # Exponential moving average update of the "current action and state"
    # pair, Q(S, a).
    Q_hat[A, S] = alpha * S_k + (1 - alpha) * Q_hat[A, S]

    print(f"(Epoch {i}) S: {S+1} / A: {A+1} / S': {S_prime+1}")
    print(Q_hat)


(Epoch 0) S: 1 / A: 1 / S': 1
[[0.75 0.  ]
 [0.   0.  ]]
(Epoch 1) S: 1 / A: 1 / S': 2
[[-0.5625  0.    ]
 [ 0.      0.    ]]
(Epoch 2) S: 2 / A: 2 / S': 1
[[-0.5625  0.    ]
 [ 0.      0.75  ]]
