In [2]:
'''
This is a simple implementation of a deep Q-network (DQN) for the mountain car
(MountainCarContinuous-v0) environment from OpenAI Gym. The continuous action
is discretized into ACTION_SIZE evenly spaced values in [-1, 1].
'''
import random

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# --- Hyper-parameters -------------------------------------------------------
# Number of discrete action bins; the continuous action in [-1, 1] is encoded
# as (i - 10) / 10 for bin i (odd count, so 0.0 is one of the bins).
ACTION_SIZE = 21
# Discount factor applied to the target network's value in the TD target.
GAMMA = 0.9
# Adam step size.
LEARNING_RATE = 1e-3
# Cap on fresh transitions merged into the buffer (and sampled) per iteration.
BATCH_SIZE = 1000
# Episodes played per data-collection round.
MAX_EPISODES = 10

# neural network model 
class Net(nn.Module):
    """Q-network mapping an observation to one Q-value per discrete action.

    Layout: Linear(obs_size, 32) -> ReLU -> Linear(32, 64) -> ReLU ->
    Linear(64, ACTION_SIZE).
    """

    def __init__(self, obs_size, ACTION_SIZE):
        # NOTE: the parameter name shadows the module-level constant; it is
        # kept for interface compatibility with keyword callers.
        super().__init__()
        widths = (obs_size, 32, 64)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], ACTION_SIZE))
        # The attribute must stay ``net`` so state_dict keys remain
        # ``net.<index>.weight`` / ``net.<index>.bias``.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for the batch of observations ``x``."""
        return self.net(x)

def state_action(net, obs, act, action_size=None):
    """Return Q(obs, act) as predicted by ``net`` for a single transition.

    The continuous action ``act[0]`` is mapped back to the discrete bin it was
    generated from, using the inverse of the encoding used when acting:
    bin i <-> (i - half) / half, with half = action_size // 2.

    Args:
        net: callable mapping a (1, obs_dim) float32 tensor to Q-values whose
            row 0 has ``action_size`` entries.
        obs: a single observation (array-like).
        act: sequence whose first element is the continuous action.
        action_size: number of discrete bins; defaults to the module-level
            ``ACTION_SIZE``.

    Returns:
        The selected Q-value as a 0-dim tensor, still attached to ``net``'s
        autograd graph.
    """
    if action_size is None:
        action_size = ACTION_SIZE
    half = action_size // 2
    obs_v = torch.as_tensor(np.asarray([obs], dtype=np.float32))
    q_values = net(obs_v)
    # Round (not truncate) with the encoder's own scale so floating-point
    # error in ``act`` can never shift the lookup into a neighbouring bin.
    act_index = int(round(float(act[0]) * half)) + half
    # Clamp: out-of-range actions would otherwise raise IndexError (> 1) or
    # silently wrap around via negative indexing (< -1).
    act_index = min(max(act_index, 0), action_size - 1)
    return q_values[0][act_index]

def evaluate(batch, net, net_target, gamma=None):
    """Return the summed squared TD error of ``net`` over ``batch``.

    Each batch item is ``(state, next_state, action, reward, done)``. For a
    non-terminal transition the TD error is
    Q(s, a) - reward - gamma * max_a' Q_target(s', a'); for a terminal one it
    is Q(s, a) - reward.

    Args:
        batch: iterable of transition tuples as above.
        net: online Q-network (receives gradients).
        net_target: target Q-network (evaluated without gradients).
        gamma: discount factor; defaults to the module-level ``GAMMA``.

    Returns:
        The loss as a scalar tensor (a sum, not a mean); the float ``0.0`` if
        ``batch`` is empty.
    """
    if gamma is None:
        gamma = GAMMA
    value = 0.0
    for state, next_state, action, reward, done in batch:
        # Use this transition's own action (previously the terminal branch
        # referenced an unrelated global list).
        v = state_action(net, state, action)
        if done:
            # NOTE(review): `done` may also be set by a time-limit truncation,
            # which is then treated as a true terminal state — confirm intent.
            difference = v - reward
        else:
            # The target net is never optimized here; computing it without a
            # graph avoids accumulating unused gradients in its parameters.
            with torch.no_grad():
                next_q = net_target(
                    torch.as_tensor(np.asarray([next_state], dtype=np.float32)))
            v_target = next_q.max()
            difference = v - reward - gamma * v_target
        value += difference * difference
    return value

def _run_episodes(env, policy_net, n_episodes):
    """Play ``n_episodes`` complete episodes acting greedily on ``policy_net``.

    The greedy discrete index i is turned into the continuous action
    (i - half) / half with half = int(ACTION_SIZE / 2). The environment is
    reset at the start and after every finished episode.

    NOTE(review): action selection is purely greedy (no epsilon/noise), so
    exploration relies entirely on the network's initialization — consider
    epsilon-greedy if learning stalls.

    Returns:
        ``(states, next_states, actions, rewards, dones, total_reward)``: five
        parallel per-step lists plus the summed reward of all episodes.
    """
    half = int(ACTION_SIZE / 2)
    states, next_states, actions, rewards, dones = [], [], [], [], []
    total_reward = 0.0
    finished = 0
    obs = env.reset()
    while finished < n_episodes:
        states.append(obs)
        # Acting only: no gradients are needed through the policy network.
        with torch.no_grad():
            q_values = policy_net(
                torch.as_tensor(np.asarray([obs], dtype=np.float32)))
        index = q_values.max(dim=-1)[1].item()
        # A fresh array per step: appending one shared, mutated array would
        # make every stored action alias the most recent one.
        act = np.array([(index - half) / float(half)])
        actions.append(act)
        obs, reward, done, _ = env.step(act)
        next_states.append(obs)
        rewards.append(reward)
        dones.append(done)
        total_reward += reward
        if done:
            finished += 1
            obs = env.reset()
    return states, next_states, actions, rewards, dones, total_reward


if __name__ == "__main__":
    env = gym.make('MountainCarContinuous-v0')

    net = Net(env.observation_space.shape[0], ACTION_SIZE)
    net_target = Net(env.observation_space.shape[0], ACTION_SIZE)
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    print(net)

    no_iter = 0
    print("no of iterations         mean reward:")
    # Prepare the initial replay buffer from MAX_EPISODES greedy episodes.
    (batch_state, batch_next_state, batch_action, batch_reward, batch_done,
     _) = _run_episodes(env, net_target, MAX_EPISODES)
    batch = list(zip(batch_state, batch_next_state, batch_action,
                     batch_reward, batch_done))

    # Start training.
    while True:
        (batch_state_t, batch_next_state_t, batch_action_t, batch_reward_t,
         batch_done_t, batch_total_reward) = _run_episodes(
             env, net_target, MAX_EPISODES)
        fresh = list(zip(batch_state_t, batch_next_state_t, batch_action_t,
                         batch_reward_t, batch_done_t))
        if len(fresh) > BATCH_SIZE:
            batch_t = random.sample(fresh, BATCH_SIZE)
        else:
            batch_t = fresh
        size = len(batch_t)

        # Replace the buffer's tail with the fresh transitions. Clamp the
        # start so a short buffer is emptied rather than under-trimmed.
        # NOTE(review): only the tail is ever replaced, so the initial
        # transitions at the head of the buffer are never evicted.
        del batch[max(len(batch) - size, 0):]
        batch.extend(batch_t)

        mean = batch_total_reward / MAX_EPISODES
        no_iter += 1
        if mean > 90:
            torch.save(net.state_dict(), "best_solution_for_mountain_car_continuous_v0.data")
            print("problem solved with number of iterations and best mean reward: ")
            print("      ", no_iter, "           ", mean)
            break
        print("      ", no_iter, "           ", mean)

        batch_t = random.sample(batch, size)
        optimizer.zero_grad()
        loss_t = evaluate(batch_t, net, net_target)
        loss_t.backward()
        optimizer.step()
        # Hard update: the target network copies the online network each
        # iteration.
        net_target.load_state_dict(net.state_dict())


[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
Net(
  (net): Sequential(
    (0): Linear(in_features=2, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=21, bias=True)
  )
)
no of iterations         mean reward:
       1             -0.9990000000000008
       2             -48.95099999999961
       3             -48.95099999999961
       4             -48.95099999999961
       5             -48.95099999999961
       6             -48.95099999999961
       7             -54.93899999999978
       8             -58.211999999999875
       9             -59.6684999999999
       10             -60.44999999999992
       11             -61.23899999999996
       12             -62.7255
       13             -62.342