模拟数据集中的动作,就是传统的深度学习而已

In [1]:
import torch


#封装数据集
class Dataset(torch.utils.data.Dataset):

    def __init__(self):
        import numpy as np
        data = np.loadtxt('离散动作.txt')
        self.state = torch.FloatTensor(data[:, :4])
        self.action = torch.LongTensor(data[:, -1])

    def __len__(self):
        return len(self.state)

    def __getitem__(self, i):
        return self.state[i], self.action[i]


dataset = Dataset()

len(dataset), dataset[0]

(20000, (tensor([-0.0028,  0.0180, -0.0188, -0.0368]), tensor(0)))

In [2]:
#数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=8,
                                     shuffle=True,
                                     drop_last=True)

len(loader), next(iter(loader))

(2500,
 [tensor([[ 0.1513,  0.3574, -0.0094, -0.4717],
          [ 0.0492,  0.1657, -0.0149, -0.1564],
          [ 0.1058, -0.0175, -0.0058,  0.1005],
          [ 0.0723,  0.1691, -0.0016, -0.1711],
          [ 0.0157,  0.2248,  0.0029, -0.2410],
          [ 0.1097,  0.1393, -0.0061, -0.1300],
          [ 0.1481, -0.0563, -0.0111,  0.1441],
          [ 0.1530, -0.2372, -0.0022,  0.4056]]),
  tensor([0, 0, 1, 0, 0, 0, 0, 1])])

In [3]:
#定义模型
model = torch.nn.Sequential(
    torch.nn.Linear(4, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2),
)

model(torch.randn(2, 4)).shape

torch.Size([2, 2])

In [4]:
#训练
def train():
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(10):
        for i, (state, action) in enumerate(loader):
            out = model(state)

            loss = loss_fn(out, action)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        if epoch % 1 == 0:
            out = out.argmax(dim=1)
            acc = (out == action).sum().item() / len(action)
            print(epoch, loss.item(), acc)


train()

0 0.29450610280036926 0.875



KeyboardInterrupt



In [None]:
import gym


#定义环境
class MyWrapper(gym.Wrapper):

    def __init__(self):
        env = gym.make('CartPole-v1', render_mode='rgb_array')
        super().__init__(env)
        self.env = env
        self.step_n = 0

    def reset(self):
        state, _ = self.env.reset()
        self.step_n = 0
        return state

    def step(self, action):
        state, reward, terminated, truncated, info = self.env.step(action)
        over = terminated or truncated

        #限制最大步数
        self.step_n += 1
        if self.step_n >= 200:
            over = True
        
        #没坚持到最后,扣分
        if over and self.step_n < 200:
            reward = -1000

        return state, reward, over

    #打印游戏图像
    def show(self):
        from matplotlib import pyplot as plt
        plt.figure(figsize=(3, 3))
        plt.imshow(self.env.render())
        plt.show()


env = MyWrapper()

env.reset()

env.show()

In [None]:
from IPython import display
import random


#玩一局游戏并记录数据
def play(show=False):
    reward_sum = 0

    state = env.reset()
    over = False
    while not over:
        action = model(torch.FloatTensor(state).reshape(1, 4)).argmax().item()
        if random.random() < 0.1:
            action = env.action_space.sample()

        state, reward, over = env.step(action)
        reward_sum += reward

        if show:
            display.clear_output(wait=True)
            env.show()

    return reward_sum


#测试
sum([play() for _ in range(20)]) / 20

In [None]:
play(True)