In [1]:
import numpy as np
from IPython.display import clear_output, display
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import copy
import time
from wurm.envs import SingleSnake
DEFAULT_DEVICE = 'cpu' #set device

## Defining Neural Network.

In [2]:
# defining a fully connected neural network
def Q_Network(shape, action_dim, hidden_dim):
    """Build a fully connected Q-network with two hidden ReLU layers.

    Args:
        shape: Shape of a single state (without the batch dimension); the
            input is flattened, so only the product of its entries matters.
        action_dim: Number of outputs, one Q-value per action.
        hidden_dim: Width of each of the two hidden layers.

    Returns:
        An ``nn.Sequential`` model moved to ``DEFAULT_DEVICE``.
    """
    # np.prod replaces np.product (deprecated alias, removed in NumPy 2.0).
    # int() guarantees nn.Linear receives a plain Python integer.
    flat_shape = int(np.prod(shape)) #length of the flattened state
    model = nn.Sequential(
                        nn.Flatten(),
                        nn.Linear(flat_shape,hidden_dim),
                        nn.ReLU(),
                        nn.Linear(hidden_dim,hidden_dim),
                        nn.ReLU(),
                        nn.Linear(hidden_dim, action_dim),
                         ).to(DEFAULT_DEVICE)
    return model
    

## Initializing Environment for Visualization

In [17]:
# Throwaway single environment and small network, used only for the
# TensorBoard graph visualization in the next cell.
env = SingleSnake(num_envs=1, size=10, observation_mode='one_channel', device= DEFAULT_DEVICE)
state = env.reset()
# Drop the leading dimension of the returned state
# (presumably the per-environment batch axis -- TODO confirm against wurm).
state_dim = state.shape[1:]
action_dim = 4
# Use action_dim instead of repeating the literal 4 so the two cannot drift apart.
qnet = Q_Network(state_dim, action_dim, 30)

## Visualizing the neural network. Requires Tensorboard

In [9]:
# Export the network graph so it can be inspected in TensorBoard.
# NOTE(review): this import lives mid-notebook; consider moving it to the
# import cell at the top so a fresh "Run All" surfaces missing dependencies early.
from torch.utils.tensorboard import SummaryWriter
# SummaryWriter() with no arguments writes under ./runs, which is the
# directory the %tensorboard magic below points at.
writer = SummaryWriter()
# add_graph traces qnet with an example input taken from a fresh env.reset().
# NOTE(review): env.reset() presumably already returns a tensor, making the
# torch.Tensor(...) wrapper redundant -- confirm before removing.
writer.add_graph(qnet, torch.Tensor(env.reset()))
writer.close()
# NOTE(review): re-running this cell prints "already loaded"; %reload_ext
# tensorboard would avoid that message.
%load_ext tensorboard
%tensorboard --logdir=runs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 30600), started 14:01:49 ago. (Use '!kill 30600' to kill it.)

## Replay Buffer

In [3]:
import collections

class ReplayBuffer():
    """Fixed-capacity FIFO store of transitions for experience replay.

    Each stored item is a tuple ``(state, next_state, action, reward, terminal)``
    of tensors whose first dimension can be concatenated with ``torch.cat``.
    Once ``max_buffer_size`` transitions are held, adding another discards the
    oldest one.
    """

    def __init__(self, max_buffer_size: int):
        # maxlen makes the deque drop its oldest entry automatically on overflow.
        # (The previous manual pop() removed the *newest* entry instead.)
        self.buffer = collections.deque(maxlen=max_buffer_size)
        self.max_buffer_size = max_buffer_size

    def add_to_buffer(self, data):
        """Append one transition, evicting the oldest if the buffer is full."""
        #data must be of the form (state,next_state,action,reward,terminal)
        self.buffer.append(data)

    def sample_minibatch(self,minibatch_length):
        """Sample transitions uniformly at random, with replacement.

        Args:
            minibatch_length: Number of transitions to draw.

        Returns:
            Tuple ``(states, next_states, actions, rewards, terminals)``, each
            the ``torch.cat`` of the corresponding field of the drawn transitions.

        Raises:
            ValueError: If the buffer is empty.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty replay buffer")
        # randint's upper bound is exclusive, so len(self.buffer) (not len - 1)
        # is required for the newest transition to be reachable and for a
        # single-element buffer to be sampleable.
        indices = np.random.randint(0, len(self.buffer), minibatch_length)
        batch = [self.buffer[int(index)] for index in indices]
        # Regroup the per-transition tuples into one concatenated tensor per field.
        return tuple(
            torch.cat([transition[field] for transition in batch])
            for field in range(5)
        )

## DQN Agent

In [4]:
class DQNAgent():
    """DQN agent with an online Q-network, a soft-updated target network and a replay buffer.

    Args:
        state_dim: Shape of a single state (without the batch dimension).
        action_dim: Number of discrete actions.
    """

    def __init__(self, state_dim, action_dim):
        # Kept so random exploration does not depend on any global environment.
        self.action_dim = action_dim
        self.qnet = Q_Network(state_dim, action_dim, 50) #Set dimension of hidden layer
        self.qnet_target = copy.deepcopy(self.qnet)
        self.qnet_optim = torch.optim.Adam( self.qnet.parameters(), lr=0.0005) #set learning rate
        self.discount_factor = torch.Tensor([0.99]).to(DEFAULT_DEVICE) # set discount factor
        self.MSELoss_function = nn.MSELoss()
        self.replay_buffer = ReplayBuffer(1000) #set size of replay buffer.
        # With tau = 0.95 each soft update moves the target 95% of the way
        # towards the online network (see soft_target_update).
        self.tau = torch.Tensor([0.95]).to(DEFAULT_DEVICE) #set soft update tau.

    def soft_target_update(self,network,target_network,tau):
        """Blend ``network`` into ``target_network`` in place.

        Every target parameter becomes ``tau * online + (1 - tau) * target``.
        ``network`` itself is left untouched.
        """
        # no_grad instead of touching .data: same numerical result, but the
        # in-place copy is explicitly excluded from autograd.
        with torch.no_grad():
            for net_params, target_net_params in zip(network.parameters(), target_network.parameters()):
                target_net_params.copy_(net_params * tau + target_net_params * (1 - tau))

    def epsilon_greedy_action(self, state, epsilon):
        """Pick actions for a batch of states using an epsilon-greedy policy.

        A single coin flip decides for the whole batch: with probability
        ``epsilon`` every entry gets a uniformly random action, otherwise every
        entry gets the action with the highest Q-value.

        Args:
            state: Batch of states; the first dimension is the batch size.
            epsilon: Probability of acting randomly.

        Returns:
            1-D int64 tensor holding one action index per batch entry.
        """
        if np.random.uniform(0, 1) < epsilon:
            # Sample directly rather than calling a notebook-global `env`, so the
            # agent works with whichever environment produced `state`. Shape,
            # dtype and device match what the greedy branch returns.
            return torch.randint(0, self.action_dim, (state.shape[0],), device=state.device)
        # Action selection needs no gradients.
        with torch.no_grad():
            return torch.argmax(self.qnet(state), dim=1)  # choose greedy action

    def update_Q_Network(self, state, next_state, action, reward, terminals):
        """Run one gradient step on the online network for a batch of transitions.

        The regression target is ``reward + discount * max_a Q_target(next_state, a)``,
        with the bootstrap term dropped where ``terminals`` is True.

        Args:
            state: Batch of states.
            next_state: Batch of successor states.
            action: 1-D int64 tensor of the actions taken.
            reward: 1-D tensor of rewards.
            terminals: 1-D bool tensor, True where the transition ended an episode.
        """
        # squeeze(-1), not squeeze(): a batch of one must stay 1-D so it lines up
        # with the target instead of silently broadcasting inside the loss.
        qsa = torch.gather(self.qnet(state), dim=1, index=action.unsqueeze(-1)).squeeze(-1)
        # The target is a constant with respect to the online network.
        with torch.no_grad():
            qsa_next_action = torch.max(self.qnet_target(next_state), dim=1)[0]
            not_terminals = ~terminals
            qsa_next_target = reward + not_terminals * self.discount_factor * qsa_next_action
        q_network_loss = self.MSELoss_function(qsa, qsa_next_target)
        self.qnet_optim.zero_grad()
        q_network_loss.backward()
        self.qnet_optim.step()

    def update(self, update_rate, batch_size=128):
        """Perform ``update_rate`` learning steps from replayed experience.

        Each step samples ``batch_size`` transitions from the replay buffer,
        takes one gradient step on the online network and then soft-updates
        the target network.

        Args:
            update_rate: Number of learning steps to run.
            batch_size: Transitions per minibatch (default 128, the value that
                was previously hard-coded).
        """
        for _ in range(update_rate):
            states, next_states, actions, rewards, terminals = self.replay_buffer.sample_minibatch(batch_size)
            self.update_Q_Network(states, next_states, actions, rewards, terminals)
            self.soft_target_update(self.qnet, self.qnet_target, self.tau)

## Initializing Environment and Agent

In [5]:
# Create the training environment and the agent that will learn in it.
num_envs = 1
env = SingleSnake(num_envs=num_envs, size=10, observation_mode='one_channel', device= DEFAULT_DEVICE)
state = env.reset()
# Drop the leading dimension of the returned state
# (presumably the per-environment batch axis -- TODO confirm against wurm).
state_dim = state.shape[1:]
action_dim = 4
agent=DQNAgent(state_dim,action_dim)
# Put the online network in training mode; as the cell's last expression this
# also displays the network's layer summary.
agent.qnet.train()

Sequential(
  (0): Flatten()
  (1): Linear(in_features=100, out_features=50, bias=True)
  (2): ReLU()
  (3): Linear(in_features=50, out_features=50, bias=True)
  (4): ReLU()
  (5): Linear(in_features=50, out_features=4, bias=True)
)

## Training

In [6]:
%%time
# Training loop: act epsilon-greedily, store every transition in the replay
# buffer and learn from replayed minibatches after each environment step.
render=False

# NOTE(review): despite its name this counts loop iterations, and each
# iteration performs exactly one env.step (not one full episode).
number_of_episodes = 1000

# Running per-environment return for the episode currently in progress.
reward_sum=torch.zeros([num_envs]).to(DEFAULT_DEVICE)
state=env.reset()
agent.qnet.train()

# range(1, N) runs N - 1 iterations.
for i in range(1,number_of_episodes):
    # Epsilon decays linearly from just under 0.5 towards 0 as i grows.
    action = agent.epsilon_greedy_action( state , 0.5-0.5*i/number_of_episodes) #set epsilon
    next_state, reward, terminal, _ = env.step(action)
    reward_sum+=reward
    agent.replay_buffer.add_to_buffer( (state,next_state,action,reward,terminal) )
    state = next_state
    if render:
        env.render()
    if terminal.any():
        print('episode:', i, 'sum_of_rewards_for_episode:', reward_sum[terminal])        
        # wait=True defers the clear until the next output arrives, so the
        # line printed above stays visible until it is replaced.
        clear_output(wait=True)
        # NOTE(review): presumably resets only the environments flagged in
        # `terminal` and returns the full batch of states -- confirm against wurm.
        state=env.reset(terminal)
        reward_sum[terminal]=0
    # Start learning once a few transitions are stored; two minibatch
    # updates are run per environment step.
    if i>3:
        agent.update(2)

if render:    
    env.close()

Wall time: 7.59 s


## Visualize Agent

In [7]:
# Roll out the trained agent greedily (epsilon = 0) for a few rendered episodes.
env = SingleSnake(num_envs=1, size=10, observation_mode='one_channel', device= DEFAULT_DEVICE)
agent.qnet.eval()
try:
    for episode in range(5):
        reward_sum = 0
        state = env.reset()

        while True:
            action = agent.epsilon_greedy_action( state , 0.0)
            next_state, reward, terminal, _ = env.step(action)
            # .cpu() keeps this working when DEFAULT_DEVICE is not the CPU;
            # calling .numpy() directly on a non-CPU tensor raises.
            reward_sum += reward.cpu().numpy()
            env.render()
            time.sleep(0.5)  # slow the animation down so it can be followed
            state = next_state
            # Same termination check as the training loop.
            if terminal.any():
                break
        print('episode:', episode, 'sum_of_rewards_for_episode:', reward_sum)
finally:
    # Close the render window even if the rollout is interrupted or fails.
    env.close()

episode: 0 sum_of_rewards_for_episode: [0.]
episode: 1 sum_of_rewards_for_episode: [0.]
episode: 2 sum_of_rewards_for_episode: [0.]
episode: 3 sum_of_rewards_for_episode: [1.]
episode: 4 sum_of_rewards_for_episode: [0.]


In [56]:
# Manual cleanup cell: closes the environment from the visualization cell above.
# NOTE(review): that cell already calls env.close(), and this cell's execution
# count is out of sequence -- it looks like a leftover and may be removable.
env.close()