In [1]:
import json
import logging
import os
import random
import time

import click
import gym
import numpy as np
import torch
import torch.nn.functional as F

from data import Data
from policy import Q_Net, process

# Paths and hyper-parameters.
LogFolder = os.path.join(os.getcwd(), 'log')
FRAME_SKIP = 4      # number of env frames each chosen action is repeated for
GAMMA = 0.9         # discount applied to max_a' Q(s', a') in the targets
# epsilon-greedy exploration: starting value and decay settings
init_epsilon = 0.5
decay_every_timestep = 100000
epsilon_decay = 0.5
final_epsilon = 0.1

# training
batchsize = 32

# Seed every RNG the notebook draws from so runs are reproducible:
# numpy (epsilon-greedy), `random` (minibatch sampling), and torch
# (e.g. network initialisation). NOTE: the gym env itself is not seeded here.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# experience replay storage
D = Data()

In [5]:
# Build the environment and Q-network from the game config, then collect
# transitions into the replay buffer `D` until the episode ends or the buffer
# holds 2 * batchsize transitions (enough for the sampling/training cells).
with open('assault.json', 'r') as f:
    cfg = json.load(f)
env = gym.make(cfg['game']['gamename'])
action_n = env.action_space.n
model = Q_Net(action_n)

# gym.make returns a wrapper; assigning `env.frameskip` would only create an
# attribute on the wrapper, so set it on the underlying env. Frame skipping is
# instead done manually below, FRAME_SKIP frames per chosen action.
env.unwrapped.frameskip = 1
epsilon = init_epsilon
optimizer = torch.optim.RMSprop(model.parameters())

obs = env.reset()
# Start state: the first observation repeated FRAME_SKIP times.
obs_list = [obs] * FRAME_SKIP
state_now = process(obs_list)

break_is_true = False
for step in range(cfg['game']['timesteplimit']):
    # epsilon-greedy action selection; acting needs no autograd graph.
    if np.random.rand() <= epsilon:
        action = np.random.randint(action_n)
    else:
        with torch.no_grad():
            action = model(state_now).argmax().item()

    # Repeat the action for up to FRAME_SKIP frames, summing the reward.
    obs_list = []
    step_reward = 0
    for _ in range(FRAME_SKIP):
        obs, reward, done, _info = env.step(action)
        obs_list.append(obs)
        step_reward += reward
        if done:
            # Do not keep stepping an episode that has already finished.
            break_is_true = True
            break
    # If the episode ended early, pad with the last frame, mirroring how the
    # start state repeats the first frame.
    while len(obs_list) < FRAME_SKIP:
        obs_list.append(obs_list[-1])

    next_state = process(obs_list)
    D.push([state_now, action, step_reward, next_state, done])
    # Advance: the next action is chosen from the state just reached.
    state_now = next_state

    # Apply the exploration decay configured in the config cell.
    if (step + 1) % decay_every_timestep == 0:
        epsilon = max(final_epsilon, epsilon * epsilon_decay)

    # Stop once there is enough data to sample from in the following cells.
    if len(D.data) >= batchsize * 2:
        break_is_true = True

    if break_is_true:
        break


In [6]:
# The next cell draws `batchsize` transitions without replacement, so the
# buffer must hold at least that many.
assert len(D.data) >= batchsize, "replay buffer holds fewer transitions than one batch"
len(D.data)

64

In [27]:
import random
selected_data = random.sample(D.data, batchsize)
state_batch = [batch[0] for batch in selected_data]
target_q_value = None
for i in range(batchsize):
    state_ = selected_data[i][0]
    action_ = selected_data[i][1]
    reward_ = selected_data[i][2]
    next_state_ = selected_data[i][3]
    done_ = selected_data[i][4]
    q_eval = model(state_)
    if target_q_value is None:
        target_q_value = q_eval
    else:
        target_q_value = torch.cat((target_q_value, q_eval))
#     更新最新的一行就行
#     print(target_q_value[-1][action_])
    if done_:
        target_q_value[-1][action_] = reward_
    else:
        target_q_value[-1][action_] = reward_ + GAMMA*model(next_state_).max().item()

In [31]:
# Smoke test: one gradient step on the sampled minibatch.
# (No `import torch.utils.data as Data` here: it would shadow the replay-buffer
# class `Data` imported at the top of the notebook.)
x_train = torch.cat(state_batch)
# Targets are labels: detach so no gradient flows through them.
y_train = target_q_value.detach()

optimizer.zero_grad()
predicted_q_value = model(x_train)
loss = F.mse_loss(predicted_q_value, y_train)
loss.backward()
optimizer.step()
loss.item()