In [44]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import dgl
from dgl.nn.pytorch import GraphConv
import networkx as nx

In [58]:
class GnnDqn(nn.Module):
    """Dueling Q-network over graphs: one GraphConv layer, mean readout, V and A heads.

    Each node carries an `in_dim`-long vector in ``g.ndata['features']``. The
    network returns a state value ``V`` (1 output) and per-action advantages
    ``A`` (``out_dim`` outputs); the caller combines them into Q-values.
    The module owns its own Adam optimizer and MSE loss.
    """

    def __init__(self, lr, hidden_dim=100, in_dim=3, out_dim=3):
        """
        Args:
            lr: learning rate for the Adam optimizer over this module's parameters.
            hidden_dim: width of the GraphConv output / readout vector.
            in_dim: number of features per node.
            out_dim: number of actions (width of the advantage head).
        """
        super().__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.A = nn.Linear(hidden_dim, out_dim)
        self.V = nn.Linear(hidden_dim, 1)
        # Built after all layers exist so it sees every parameter.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()

    def forward(self, g):
        """Return ``(V, A)`` for graph (or batched graph) ``g``.

        The readout averages node embeddings per graph. NOTE(review): for a
        batched graph this presumably yields one row per member graph — confirm
        against the installed DGL version.
        """
        h = F.relu(self.conv1(g, g.ndata['features']))
        # local_scope() discards the temporary 'h' field on exit, so the
        # caller's graph is left exactly as it was passed in.
        with g.local_scope():
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')
        return self.V(hg), self.A(hg)

In [59]:
def _obs_to_node_features(obs):
    """Split a 6-element observation into a (2, 3) float32 node-feature tensor.

    Node 0 receives obs[0], obs[1], obs[4]; node 1 receives obs[2], obs[3], obs[5].
    NOTE(review): presumably this is the Acrobot-v1 layout
    [cos th1, sin th1, cos th2, sin th2, th1_dot, th2_dot], i.e. one node per
    link — confirm against the environment version in use.
    """
    return torch.tensor(np.array([[obs[0], obs[1], obs[4]],
                                  [obs[2], obs[3], obs[5]]]), dtype=torch.float32)


def make_empty_graph():
    """Return a two-node graph with directed edges 0->1 and 1->0 and no node data."""
    src = np.array([0, 1])
    dst = np.array([1, 0])
    return dgl.DGLGraph((src, dst))


def make_graph(obs):
    """Build the two-node state graph for one observation, with 'features' set."""
    graph = make_empty_graph()
    graph.ndata['features'] = _obs_to_node_features(obs)
    return graph


def make_graph_batch(graphs, states):
    """Fill preallocated graphs with `states` and return them as one batched graph.

    Args:
        graphs: reusable graphs from `make_empty_graph`; the 'features' of the
            first ``len(states)`` of them are overwritten in place.
        states: sequence of observations, one per graph in the batch.

    Raises:
        ValueError: if there are fewer preallocated graphs than states.
    """
    if len(states) > len(graphs):
        raise ValueError(
            f"need at least {len(states)} preallocated graphs, got {len(graphs)}")
    # Iterate over the states (not the pool) so the batch size follows the input.
    for graph, obs in zip(graphs, states):
        graph.ndata['features'] = _obs_to_node_features(obs)
    return dgl.batch(graphs[:len(states)])


# Size of the shared graph pool; must be >= any batch passed with GRAPHS.
GRAPH_POOL_SIZE = 64
GRAPHS = [make_empty_graph() for _ in range(GRAPH_POOL_SIZE)]

In [60]:
class ReplayBuffer:
    """Fixed-capacity circular store of (state, action, reward, next_state, done) tuples.

    Each field lives in its own preallocated NumPy array. Once `mem_size`
    transitions have been stored, new ones overwrite the oldest slots.
    """

    def __init__(self, max_size, input_shape):
        """
        Args:
            max_size: number of transitions kept.
            input_shape: shape of a single observation (e.g. ``[6]``).
        """
        self.mem_size = max_size
        # Total number of transitions ever stored; not capped at mem_size.
        self.mem_cntr = 0
        obs_shape = (max_size, *input_shape)
        self.state_memory = np.zeros(obs_shape, dtype=np.float32)
        self.new_state_memory = np.zeros(obs_shape, dtype=np.float32)
        self.action_memory = np.zeros(max_size, dtype=np.int64)
        self.reward_memory = np.zeros(max_size, dtype=np.float32)
        self.terminal_memory = np.zeros(max_size, dtype=np.int64)

    def store_transition(self, state, action, reward, state_, done):
        """Write one transition into the next circular slot."""
        slot = self.mem_cntr % self.mem_size
        for column, value in ((self.state_memory, state),
                              (self.new_state_memory, state_),
                              (self.reward_memory, reward),
                              (self.action_memory, action),
                              (self.terminal_memory, done)):
            column[slot] = value
        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        """Draw `batch_size` distinct stored transitions uniformly at random.

        Returns:
            (states, actions, rewards, next_states, terminals) as NumPy arrays.
            np.random.choice raises ValueError if fewer than `batch_size`
            transitions are stored.
        """
        filled = self.mem_cntr if self.mem_cntr < self.mem_size else self.mem_size
        picks = np.random.choice(filled, batch_size, replace=False)
        return (self.state_memory[picks],
                self.action_memory[picks],
                self.reward_memory[picks],
                self.new_state_memory[picks],
                self.terminal_memory[picks])

In [61]:
class Agent():
    """Dueling double-DQN agent whose online and target networks are `GnnDqn`s.

    Observations are converted to two-node graphs (`make_graph` /
    `make_graph_batch`) before being fed to the networks.
    """

    def __init__(self, gamma, epsilon, lr, n_actions, input_dims,
                    mem_size, batch_size, eps_min=0.01, eps_dec=5e-7,
                    replace=1000):
        """
        Args:
            gamma: discount factor for bootstrapped targets.
            epsilon: initial exploration rate for epsilon-greedy selection.
            lr: Adam learning rate passed to the networks.
            n_actions: number of discrete actions (advantage-head width).
            input_dims: shape of one observation, used to size the replay buffer.
            mem_size: replay-buffer capacity.
            batch_size: transitions sampled per learning step.
            eps_min: lower bound for epsilon.
            eps_dec: amount subtracted from epsilon after each learning step.
            replace: copy online weights to the target net every `replace` learn steps.
        """
        self.gamma = gamma
        self.epsilon = epsilon
        self.lr = lr
        self.n_actions = n_actions
        self.input_dims = input_dims
        self.learn_step_counter = 0
        self.batch_size = batch_size
        self.eps_min = eps_min
        self.eps_dec = eps_dec
        self.replace_target_cnt = replace

        self.action_space = list(range(self.n_actions))

        self.memory = ReplayBuffer(mem_size, input_dims)

        self.q_eval = GnnDqn(self.lr, out_dim=n_actions)
        self.q_next = GnnDqn(self.lr, out_dim=n_actions)

        # Reusable graphs sized to this agent's batch, so learn() works for any
        # batch_size (the module-level GRAPHS pool is fixed at 64).
        self.graph_pool = [make_empty_graph() for _ in range(batch_size)]

    def choose_action(self, observation):
        """Epsilon-greedy action: random with probability epsilon, else greedy."""
        if np.random.random() > self.epsilon:
            with torch.no_grad():
                _, advantage = self.q_eval(make_graph(observation))
            # V(s) is shared by every action, so argmax over A equals argmax over Q.
            action = torch.argmax(advantage).item()
        else:
            action = np.random.choice(self.action_space)
        return action

    def store_transition(self, state, action, reward, state_, done):
        """Forward one transition to the replay buffer."""
        self.memory.store_transition(state, action, reward, state_, done)

    def replace_target_network(self):
        """Copy online weights into the target net every `replace_target_cnt` learn steps."""
        if self.learn_step_counter % self.replace_target_cnt == 0:
            self.q_next.load_state_dict(self.q_eval.state_dict())

    def decrement_epsilon(self):
        """Linearly decay epsilon, never letting it fall below eps_min."""
        self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)

    def learn(self):
        """Run one dueling double-DQN update on a sampled minibatch.

        Does nothing until the buffer holds at least `batch_size` transitions.
        """
        if self.memory.mem_cntr < self.batch_size:
            return

        self.q_eval.optimizer.zero_grad()

        self.replace_target_network()

        states, actions, rewards, states_, dones = self.memory.sample_buffer(self.batch_size)

        state_graphs = make_graph_batch(self.graph_pool, states)
        next_state_graphs = make_graph_batch(self.graph_pool, states_)
        actions = torch.tensor(actions)
        rewards = torch.tensor(rewards)
        dones = torch.tensor(dones, dtype=torch.float32)
        indices = torch.arange(self.batch_size)

        V_s, A_s = self.q_eval(state_graphs)
        # Dueling aggregation: Q = V + (A - mean(A)).
        q_pred = (V_s + (A_s - A_s.mean(dim=1, keepdim=True)))[indices, actions]

        # The target is a constant for this update; without no_grad the target
        # net's .grad buffers would accumulate on every call and never be used.
        with torch.no_grad():
            V_s_, A_s_ = self.q_next(next_state_graphs)
            V_s_eval, A_s_eval = self.q_eval(next_state_graphs)
            q_next = V_s_ + (A_s_ - A_s_.mean(dim=1, keepdim=True))
            q_eval = V_s_eval + (A_s_eval - A_s_eval.mean(dim=1, keepdim=True))
            # Double DQN: the online net picks the action, the target net scores it.
            max_actions = torch.argmax(q_eval, dim=1)
            q_target = rewards + (1 - dones) * self.gamma * q_next[indices, max_actions]

        loss = self.q_eval.loss(q_pred, q_target)

        loss.backward()
        self.q_eval.optimizer.step()
        self.learn_step_counter += 1
        self.decrement_epsilon()

In [62]:
if __name__ == '__main__':
    # Train the GNN dueling double-DQN agent on Acrobot-v1, printing the
    # 100-game moving-average return and plotting it every 50 games after 300.
    SEED = 0
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    env = gym.make('Acrobot-v1')
    env.seed(SEED)
    no_games = 1000
    agent = Agent(gamma=0.99, epsilon=1.0,
                    lr=5e-3, input_dims=[6], n_actions=3,
                    mem_size=1000000, eps_min=0.01,
                    batch_size=64, eps_dec=1e-4,
                    replace=100)

    scores = []
    avg_scores = []
    for i in range(no_games):
        done = False
        observation = env.reset()
        score = 0

        while not done:
            action = agent.choose_action(observation)
            observation_, reward, done, info = env.step(action)
            score += reward
            # A time-limit cut ends the episode but is not a true terminal state;
            # storing it as terminal would wrongly drop the bootstrap term.
            terminal = done and not info.get('TimeLimit.truncated', False)
            agent.store_transition(observation, action, reward, observation_, terminal)
            agent.learn()
            observation = observation_

        scores.append(score)
        avg_score = np.mean(scores[-100:])
        avg_scores.append(avg_score)
        print(i, avg_score, agent.epsilon)

        if i >= 300 and i % 50 == 0:
            fig, ax = plt.subplots()
            ax.plot(np.arange(len(avg_scores)), avg_scores)
            ax.set_xlabel("No. of games played")
            ax.set_ylabel("Avg. returns")
            plt.show()
            plt.close(fig)

    env.close()

0 -500.0 0.9563000000000048
1 -500.0 0.9063000000000103
2 -500.0 0.8563000000000158
3 -500.0 0.8063000000000213
4 -500.0 0.7563000000000268
5 -500.0 0.7063000000000323
6 -500.0 0.6563000000000379
7 -500.0 0.6063000000000434
8 -500.0 0.5563000000000489
9 -500.0 0.5063000000000544
10 -500.0 0.4563000000000599
11 -500.0 0.4063000000000654
12 -480.3076923076923 0.3818000000000681
13 -477.92857142857144 0.337000000000073
14 -467.26666666666665 0.30510000000007653
15 -460.9375 0.2684000000000806
16 -454.52941176470586 0.23310000000008446
17 -445.8333333333333 0.20320000000008775
18 -429.4736842105263 0.18960000000008925
19 -433.0 0.13960000000009476
20 -422.6190476190476 0.11800000000009617
21 -413.0 0.09680000000009556
22 -403.17391304347825 0.07800000000009502
23 -401.5416666666667 0.041500000000093976
24 -392.64 0.02350000000009373
25 -384.5769230769231 0.01
26 -388.85185185185185 0.01
27 -381.32142857142856 0.01
28 -385.41379310344826 0.01
29 -385.2 0.01
30 -378.51612903225805 0.01
31 -3

KeyboardInterrupt: 