In [None]:
from dataclasses import dataclass, field
from typing import Generic, Optional, TypeVar

A = TypeVar("A")  # Type of a single action stored on a Step.


@dataclass
class Step(Generic[A]):
    """
    One node in a tree of actions, together with the reward for its action.

    A node's "state" is the list of actions of its ancestors (root first);
    the value methods below are computed recursively from ``reward`` and
    ``children``.

    Constructing a Step does not add it to ``parent.children``; whoever
    builds the tree has to append the new node to its parent's list.
    """

    # Node this action was taken from; None for the root of the tree.
    parent: Optional["Step[A]"]
    # Action taken at this node.
    action: A
    # Immediate reward recorded for taking ``action``.
    reward: float
    # Follow-up steps taken from the state reached by this action.
    children: list["Step[A]"] = field(default_factory=list)

    @property
    def state(self) -> list[A]:
        """
        S - The state (the history of actions) before taking this action

        Returns a new list holding the actions of every ancestor, ordered
        from the root down to the direct parent; empty for the root node.
        """
        # Walk the parent links iteratively instead of recursing through
        # parent.next_state: same result, but linear in the depth and not
        # limited by the interpreter's recursion limit on long histories.
        history: list[A] = []
        ancestor = self.parent
        while ancestor is not None:
            history.append(ancestor.action)
            ancestor = ancestor.parent
        # Collected from the parent up to the root, so flip to root-first.
        history.reverse()
        return history

    @property
    def next_state(self) -> list[A]:
        """
        S' - The next state (the history of actions) after taking this action

        Returns a new list: ``state`` followed by this node's own action.
        """
        return self.state + [self.action]

    def state_value(self, gamma: float = 0.9) -> float:
        """
        V(s) - The expected future reward before taking this action
        Calculated as the mean of the Q-values of all actions recorded from
        the parent state (this node and its siblings in parent.children)
        Same as Q(s,a) if this is the root node

        gamma is the discount factor passed through to the Q-value
        calculation.
        """
        if self.parent is None:
            # The root has no recorded siblings to average over.
            return self.state_action_value(gamma)
        return self.parent.next_state_value(gamma)

    def next_state_value(self, gamma: float = 0.9) -> float:
        """
        V(s') - The expected future reward after taking this action
        Calculated as the mean of the Q-values of all children, each child
        weighted equally; 0.0 if this node has no children

        The calculation recurses through the whole subtree below this node.
        """
        if not self.children:
            return 0.0
        total = sum(child.state_action_value(gamma) for child in self.children)
        return total / len(self.children)

    def state_action_value(self, gamma: float = 0.9) -> float:
        """
        Q(s,a) - The expected future reward from taking this action in the parent state
        Calculated as immediate reward plus discounted future state value
        """
        return self.reward + gamma * self.next_state_value(gamma)

    def advantage(self, gamma: float = 0.9) -> float:
        """
        A(s, a) - The advantage of taking this action in the current state
        Calculated as the difference between Q(s,a) and V(s)

        Always 0.0 for the root node, because its V(s) is defined as its
        own Q(s,a).
        """
        return self.state_action_value(gamma) - self.state_value(gamma)