#### 考虑连续 State 空间、离散 Action 空间的 Q 函数

In [None]:
from collections import defaultdict
from typing import Callable, List, Tuple, Optional
from pathlib import Path

import numpy as np
from tqdm import tqdm
import gymnasium as gym
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import seaborn as sns
import torch


# Type aliases used by the Q-function interfaces in this file.
# NOTE(review): State is an int (a discrete state index), which does not match
# the continuous state space named in this section's heading -- DeepQFunc's
# network consumes feature vectors of length state_dim. Confirm how callers
# use State before changing it.
State = int
# Dimension of the state feature vector.
StateDim = int

Action = int
Reward = float
# Values returned per state by get_action_distribute; presumably one float per
# action index -- TODO confirm against callers.
ActionProbDistribution = List[float]

class AbstractQFunc:
    """Interface for a Q function over (state, action) pairs.

    Every method is a placeholder that raises ``NotImplementedError``;
    concrete Q functions override the operations they support.
    """

    def get_value(self, state: State, action: Action) -> float:
        """Return the Q-value of taking ``action`` in ``state``."""
        raise NotImplementedError

    def get_action_distribute(self, state: State) -> ActionProbDistribution:
        """Return the per-action values for ``state``."""
        raise NotImplementedError

    def get_actions_count(self) -> int:
        """Return how many discrete actions this Q function covers."""
        raise NotImplementedError

    def set_value(self, state: State, action: Action, value: float) -> None:
        """Store ``value`` as the Q-value of (``state``, ``action``)."""
        raise NotImplementedError

class DeepQFunc(AbstractQFunc, torch.nn.Module):
    """Q function approximated by a small fully-connected network.

    The network maps a state feature vector of length ``state_dim`` to one
    Q-value per discrete action, i.e. it targets a continuous (vector) state
    space with a discrete action space.

    Args:
        state_dim: Length of the state feature vector fed to the network.
        action_nums: Number of discrete actions (size of the output layer).
        hidden_dim: Width of the single hidden layer. Defaults to 128, the
            value that used to be hard-coded.
    """

    def __init__(
        self, state_dim: int, action_nums: int, hidden_dim: int = 128
    ) -> None:
        # AbstractQFunc defines no __init__, so this reaches
        # torch.nn.Module.__init__, which must run before sub-modules are set.
        super().__init__()
        self._state_dims = state_dim
        self._action_nums = action_nums

        # Two-layer MLP: state -> hidden (ReLU) -> one Q-value per action.
        self._fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self._fc2 = torch.nn.Linear(hidden_dim, action_nums)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values for every action.

        Args:
            x: State features whose last dimension equals ``state_dim``
                (a leading batch dimension is optional).

        Returns:
            A tensor whose last dimension has size ``action_nums``.
        """
        x = torch.nn.functional.relu(self._fc1(x))
        return self._fc2(x)

    def _q_values(self, state: State) -> torch.Tensor:
        """Evaluate the network on one state, without tracking gradients.

        ``state`` may be anything ``torch.as_tensor`` accepts (list, NumPy
        array, tensor, or a scalar when ``state_dim == 1``) as long as it has
        exactly ``state_dim`` elements; otherwise ``reshape`` raises.
        """
        weight = self._fc1.weight
        # Match the parameters' dtype/device so Linear accepts the input.
        x = torch.as_tensor(state, dtype=weight.dtype, device=weight.device)
        x = x.reshape(self._state_dims)
        # Callers receive plain floats, so no autograd graph is needed.
        with torch.no_grad():
            return self(x)

    def get_value(self, state: State, action: Action) -> float:
        """Return the Q-value of taking ``action`` in ``state``.

        Raises:
            IndexError: if ``action`` is outside ``[0, action_nums)``.
        """
        # Reject negative indices explicitly; tensor indexing would otherwise
        # silently wrap around to the last actions.
        if not 0 <= action < self._action_nums:
            raise IndexError(
                f"action {action} out of range [0, {self._action_nums})"
            )
        return float(self._q_values(state)[action])

    def set_value(self, state: State, action: Action, value: float) -> None:
        """Unsupported for a network-based Q function.

        Values are produced by the network weights and must be learned with
        gradient updates; they cannot be assigned per (state, action) pair.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "DeepQFunc values are learned by gradient updates and cannot be "
            "assigned per (state, action) pair"
        )

    def get_action_distribute(self, state: State) -> ActionProbDistribution:
        """Return the Q-value of every action in ``state`` as a list of floats.

        These are raw network outputs, not normalized probabilities.
        """
        return self._q_values(state).tolist()

    def get_actions_count(self) -> int:
        """Return the number of discrete actions."""
        return self._action_nums