In [1]:
import numpy as np
from IPython.display import clear_output, display
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import copy
import time
from wurm.envs import SingleSnake
# Device used for the environments and the networks. Prefer the GPU when one is
# available and fall back to the CPU otherwise, so the notebook still runs on
# machines without CUDA.
DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Initializing Environment for Visualization

# Single-environment instance, used here only to obtain an example observation
# and to size the network that is drawn in the TensorBoard cell below.
env = SingleSnake(num_envs=1, size=10, observation_mode='one_channel', device=DEFAULT_DEVICE)
state = env.reset()
state_dim = state.shape[1:]  # observation shape without its first axis (presumably the per-environment axis)
action_dim = 4
# `Q_Network` is not defined or imported anywhere in this notebook, so the original
# call raised NameError. FNN_1 accepts the same (shape, action_dim, hidden_dim)
# arguments and is used instead.
# NOTE(review): FNN_1 is defined in the "Defining Some Neural Networks" cell further
# down, which has to be run before this cell. Confirm FNN_1 is the intended network.
qnet = FNN_1(state_dim, action_dim, 30)

## Visualizing the neural network (requires TensorBoard)

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
writer.add_graph(qnet, torch.Tensor(env.reset()))
writer.close()
%load_ext tensorboard
%tensorboard --logdir=runs

## Replay Buffer

In [169]:
import collections

class ReplayBuffer():
    """Fixed-capacity FIFO store of transitions for experience replay.

    Every stored item is a sequence ``(state, next_state, action, reward, terminal)``
    whose elements can be indexed along their first axis with an array of row
    indices (one row per parallel environment). Once ``max_buffer_size`` items are
    held, adding another one discards the oldest.
    """

    def __init__(self, max_buffer_size: int):
        # A bounded deque drops the oldest transition automatically when full.
        self.buffer = collections.deque(maxlen=max_buffer_size)

    def add_to_buffer(self, data):
        """Append one transition ``(state, next_state, action, reward, terminal)``."""
        self.buffer.append(data)

    @staticmethod
    def _subsample(transition, subbatch_length):
        """Pick ``subbatch_length`` rows (with replacement) from one stored transition.

        The row indices are drawn from the length of *this* transition, so stored
        items with differing numbers of rows are handled correctly.
        """
        rows = np.random.randint(0, len(transition[0]), subbatch_length)
        return tuple(field[rows] for field in transition[:5])

    def sample_minibatch(self, minibatch_length, subbatch_length):
        """Sample several stored transitions and sub-sample rows from each.

        Args:
            minibatch_length: number of stored transitions to draw (with replacement).
            subbatch_length: number of rows to draw (with replacement) from each.

        Returns:
            Tuple ``(states, next_states, actions, rewards, terminals)``; each entry
            is the ``torch.cat`` of the rows drawn from every selected transition.

        Raises:
            IndexError: if the buffer is empty.
        """
        if not self.buffer:
            raise IndexError("cannot sample from an empty replay buffer")
        picks = np.random.randint(0, len(self.buffer), minibatch_length)
        drawn = [self._subsample(self.buffer[pick], subbatch_length) for pick in picks]
        return tuple(torch.cat([item[field] for item in drawn]) for field in range(5))

    def sample(self, subbatch_length):
        """Sub-sample rows from a single, randomly selected stored transition.

        Args:
            subbatch_length: number of rows to draw (with replacement).

        Returns:
            Tuple ``(states, next_states, actions, rewards, terminals)``.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty replay buffer")
        pick = np.random.randint(0, len(self.buffer))
        return self._subsample(self.buffer[pick], subbatch_length)

## DQN Agent

In [170]:
class DQNAgent():
    """Deep Q-learning agent with a target network and a replay buffer.

    The agent owns an online Q-network (``qnet``), a copy of it used to build
    bootstrap targets (``qnet_target``), an Adam optimizer over the online
    network and a ``ReplayBuffer`` of past transitions.

    Args:
        num_envs: number of parallel environments; a third of this (at least 1)
            is used as the training sub-batch size.
        buffer_size: capacity of the replay buffer, in stored transitions.
        NN: callable returning the Q-network when called as ``NN(*NN_args)``.
        *NN_args: positional arguments forwarded to ``NN``.
    """

    def __init__(self, num_envs: int, buffer_size: int, NN: object, *NN_args):
        # To resume from a saved network, load it here instead of calling NN,
        # e.g. torch.load("dqn80x80.h5").
        self.qnet = NN(*NN_args)
        self.qnet_target = copy.deepcopy(self.qnet)
        self.qnet_optim = torch.optim.Adam(self.qnet.parameters(), lr=0.0005)  # set learning rate
        # Keep the discount factor on the same device as the network it is combined with.
        network_device = next(self.qnet.parameters()).device
        self.discount_factor = torch.Tensor([0.8]).to(network_device)  # set discount factor
        # The attribute name is historical: the loss actually used is SmoothL1 (Huber), not MSE.
        self.MSELoss_function = nn.SmoothL1Loss()
        self.replay_buffer = ReplayBuffer(buffer_size)  # set size of replay buffer
        self.target_update_interval = 1  # gradient steps between hard target-network updates
        self.update_count = 0
        self.tau = 0.1  # interpolation weight for soft target-network updates
        self.num_envs = num_envs

    def add_to_buffer(self, data):
        """Store one transition ``(state, next_state, action, reward, terminal)``."""
        self.replay_buffer.add_to_buffer(data)

    def target_update(self, network, target_network):
        """Hard update: copy every parameter of ``network`` into ``target_network``."""
        with torch.no_grad():
            for source, target in zip(network.parameters(), target_network.parameters()):
                target.copy_(source)

    def soft_target_update(self, network, target_network):
        """Soft update: ``target <- tau * network + (1 - tau) * target``, per parameter."""
        # no_grad keeps the blend out of autograd; the original mixed `.data` with a
        # grad-tracking parameter and so built a throw-away graph on every call.
        with torch.no_grad():
            for source, target in zip(network.parameters(), target_network.parameters()):
                target.copy_(self.tau * source + (1 - self.tau) * target)

    def epsilon_greedy_action(self, state, epsilon):
        """Return random actions with probability ``epsilon``, else greedy ones.

        A single random draw decides for the whole batch, so all rows of ``state``
        get either random or greedy actions together.
        """
        if np.random.uniform(0, 1) < epsilon:
            # NOTE(review): relies on the notebook-level global `env`, not on anything
            # owned by the agent; whichever environment is bound to `env` is used.
            return env.random_action()  # choose random action
        # Action selection needs no gradients.
        with torch.no_grad():
            return torch.argmax(self.qnet(state), dim=1)  # choose greedy action

    def update_Q_Network(self, state, next_state, action, reward, terminals):
        """Run one optimizer step on the temporal-difference loss for a batch.

        The target is ``reward + (not terminal) * discount * max_a Q_target(next_state, a)``.
        """
        # squeeze(-1) removes only the gathered column; a bare squeeze() would also
        # drop the batch axis when the batch holds a single row.
        qsa = torch.gather(self.qnet(state), dim=1, index=action.unsqueeze(-1)).squeeze(-1)
        # The bootstrap target is treated as a constant, so no graph is needed for it.
        with torch.no_grad():
            best_next_q = torch.max(self.qnet_target(next_state), dim=1)[0]
            not_terminals = ~terminals
            td_target = reward + not_terminals * self.discount_factor * best_next_q
        q_network_loss = self.MSELoss_function(qsa, td_target)
        self.qnet_optim.zero_grad()
        q_network_loss.backward()
        self.qnet_optim.step()

    def _subbatch_size(self):
        """Rows drawn per training step: a third of the environments, never fewer than 1."""
        # num_envs // 3 is 0 for num_envs < 3, which made every update run on an
        # empty batch.
        return max(1, self.num_envs // 3)

    def _train_step(self):
        """Sample one sub-batch from the replay buffer and fit the Q-network on it."""
        batch = self.replay_buffer.sample(self._subbatch_size())
        self.update_Q_Network(*batch)

    def update(self, update_rate):
        """Train for ``update_rate`` steps, hard-updating the target network.

        The target network is overwritten every ``target_update_interval`` steps.
        """
        for _ in range(update_rate):
            self._train_step()
            self.update_count += 1
            if self.update_count >= self.target_update_interval:
                self.target_update(self.qnet, self.qnet_target)
                self.update_count = 0

    def soft_update(self, update_rate):
        """Train for ``update_rate`` steps, soft-updating the target network after each."""
        for _ in range(update_rate):
            self._train_step()
            self.soft_target_update(self.qnet, self.qnet_target)


## Defining Some Neural Networks

In [171]:
def FNN_1(shape, action_dim, hidden_dim, device=None):
    """Build a fully connected Q-network with one hidden layer.

    The input is flattened, passed through ``Linear -> ReLU -> Linear`` and
    yields ``action_dim`` values per input row.

    Args:
        shape: shape of a single observation (without the batch axis); its
            product is the size of the flattened input.
        action_dim: number of outputs.
        hidden_dim: width of the hidden layer.
        device: device to move the model to; defaults to ``DEFAULT_DEVICE``.

    Returns:
        The ``nn.Sequential`` model, already moved to the chosen device.
    """
    # np.prod replaces np.product, which was removed in NumPy 2.0; int() gives
    # nn.Linear a plain Python integer.
    flat_shape = int(np.prod(shape))  # length of the flattened state
    target_device = DEFAULT_DEVICE if device is None else device
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(flat_shape, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, action_dim),
    ).to(target_device)
    return model

def CNN_1(in_channels=1, action_dim=4, hidden_dim=64, device=None):
    """Build a convolutional Q-network with three 3x3 convolution layers.

    Layout: three ``Conv2d(3x3, stride 1, padding 1) -> ReLU`` pairs with 32
    channels, a global max pool to 1x1, then ``Linear -> ReLU -> Linear``.
    Called with no arguments it builds the same 1-channel-in, 4-outputs model
    as before.

    Args:
        in_channels: channels of the input image.
        action_dim: number of outputs.
        hidden_dim: width of the fully connected hidden layer.
        device: device to move the model to; defaults to ``DEFAULT_DEVICE``.

    Returns:
        The ``nn.Sequential`` model, already moved to the chosen device.
    """
    conv_channels = 32
    layers = []
    channels = in_channels
    for _ in range(3):
        layers.append(nn.Conv2d(channels, conv_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
        layers.append(nn.ReLU())
        channels = conv_channels
    layers += [
        # Pooling to 1x1 makes the head independent of the board size.
        nn.AdaptiveMaxPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(conv_channels, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, action_dim),
    ]
    target_device = DEFAULT_DEVICE if device is None else device
    return nn.Sequential(*layers).to(target_device)

def CNN_2(in_channels=1, action_dim=4, hidden_dim=64, device=None):
    """Build a convolutional Q-network with four 3x3 convolution layers.

    Layout: four ``Conv2d(3x3, stride 1, padding 1) -> ReLU`` pairs with 32
    channels, a global max pool to 1x1, then ``Linear -> ReLU -> Linear``.
    Called with no arguments it builds the same 1-channel-in, 4-outputs model
    as before.

    Args:
        in_channels: channels of the input image.
        action_dim: number of outputs.
        hidden_dim: width of the fully connected hidden layer.
        device: device to move the model to; defaults to ``DEFAULT_DEVICE``.

    Returns:
        The ``nn.Sequential`` model, already moved to the chosen device.
    """
    conv_channels = 32
    layers = []
    channels = in_channels
    for _ in range(4):
        layers.append(nn.Conv2d(channels, conv_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
        layers.append(nn.ReLU())
        channels = conv_channels
    layers += [
        # Pooling to 1x1 makes the head independent of the board size.
        nn.AdaptiveMaxPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(conv_channels, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, action_dim),
    ]
    target_device = DEFAULT_DEVICE if device is None else device
    return nn.Sequential(*layers).to(target_device)

## Initializing Environment and Agent

In [174]:
# Number of parallel environments to simulate. Use a small value on CPU (e.g. 1).
num_envs = 1300
env = SingleSnake(
    num_envs=num_envs,
    size=10,
    observation_mode='one_channel',
    device=DEFAULT_DEVICE,
)

# Take one observation to read off the per-environment state shape.
state = env.reset()
state_dim = state.shape[1:]
action_dim = 4

# The agent builds its Q-network by calling CNN_2 with no arguments.
agent = DQNAgent(num_envs=num_envs, buffer_size=100, NN=CNN_2)
# Switch the network to training mode; as the cell's last expression this also
# displays the module summary.
agent.qnet.train()

Sequential(
  (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU()
  (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): ReLU()
  (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU()
  (8): AdaptiveMaxPool2d(output_size=(1, 1))
  (9): Flatten()
  (10): Linear(in_features=32, out_features=64, bias=True)
  (11): ReLU()
  (12): Linear(in_features=64, out_features=4, bias=True)
)

## Training

In [None]:
%%time
render=False
save_model = True
number_of_episodes = 2000

reward_sum=torch.zeros([num_envs]).to(DEFAULT_DEVICE)
total_moves = torch.zeros([num_envs]).to(DEFAULT_DEVICE)
list_reward =[]

state=env.reset()
agent.qnet.train()


for i in range(1,number_of_episodes):
    action = agent.epsilon_greedy_action( state , 0.1) #set epsilon
    next_state, reward, terminal, _ = env.step(action)
    
    agent.add_to_buffer((state,next_state,action,reward,terminal))
    
    agent.soft_update(40)
    
    state = next_state
    
    reward_sum.add_(reward)
    total_moves.add_(1)
    
    if render:
        env.render()
        
    if terminal.any():
        print('episode:', i)
        print('\navg_reward:', np.mean(reward_sum[terminal].cpu().numpy()),'\navg_moves:', np.mean(total_moves[terminal].cpu().numpy()))       
        clear_output(wait=True)
        list_reward.append(np.mean(reward_sum[terminal].cpu().numpy()))
        reward_sum[terminal]=0
        total_moves[terminal]=0
    
if render:    
    env.close()
    
if save_model:
    torch.save(agent.qnet,"save_model.h5")

episode: 294

avg_reward: 0.75757575 
avg_moves: 14.787879


import matplotlib.pyplot as plt
# Moving average of the logged rewards over a 100-entry window. The range ends at
# len(list_reward) + 1 so the last window, which ends on the most recent entry,
# is included (the original range stopped one window short).
window = 100
reward_1 = [np.mean(list_reward[i - window:i]) for i in range(window, len(list_reward) + 1)]

plt.plot(reward_1, label='reward_1')
# reward_2 and reward_3 are never assigned in this notebook, so plotting them
# unconditionally raises NameError on a fresh kernel. They are drawn only when
# they already exist in the session.
# NOTE(review): presumably curves kept from other training runs - confirm.
for other_run in ('reward_2', 'reward_3'):
    if other_run in globals():
        plt.plot(globals()[other_run], label=other_run)
plt.xlabel('logged step (at least one environment finished)')
plt.ylabel('mean reward (100-entry moving average)')
plt.legend()
plt.show()

## Visualize Agent

In [161]:
# Watch the trained agent play in a single rendered environment.
env = SingleSnake(num_envs=1, size=10, observation_mode='one_channel', device=DEFAULT_DEVICE)
agent.qnet.eval()
state = env.reset()
# try/finally makes sure the environment is closed even when the cell is
# interrupted or a step raises.
try:
    for episode in range(5):
        reward_sum = 0

        while True:
            # epsilon = 0.0: always take the greedy action.
            action = agent.epsilon_greedy_action(state, 0.0)
            next_state, reward, terminal, _ = env.step(action)
            reward_sum += reward.cpu().numpy()
            env.render()
            time.sleep(0.5)  # slow the playback down so each move is visible
            state = next_state
            if terminal:
                print(reward)  # reward received on the final step
                break
        print('episode:', episode, 'sum_of_rewards_for_episode:', reward_sum)
finally:
    env.close()

tensor([[3.8948, 3.9775, 3.3580, 7.7497]], device='cuda:0',
       grad_fn=<AddmmBackward>)
tensor([[4.9992, 5.8710, 4.6602, 9.2249]], device='cuda:0',
       grad_fn=<AddmmBackward>)
tensor([[5.0585, 6.5257, 6.4130, 5.4418]], device='cuda:0',
       grad_fn=<AddmmBackward>)
tensor([[3.6331, 5.6519, 4.9770, 4.1272]], device='cuda:0',
       grad_fn=<AddmmBackward>)
tensor([[ 0.8641,  3.9909,  3.9817, -8.8551]], device='cuda:0',
       grad_fn=<AddmmBackward>)
tensor([-10.], device='cuda:0')
episode: 0 sum_of_rewards_for_episode: [-5.]
tensor([[3.3407, 5.9662, 3.8319, 4.9328]], device='cuda:0',
       grad_fn=<AddmmBackward>)
tensor([[2.3820, 4.7629, 2.8067, 4.0156]], device='cuda:0',
       grad_fn=<AddmmBackward>)
tensor([[2.3867, 3.8322, 2.2346, 2.7260]], device='cuda:0',
       grad_fn=<AddmmBackward>)
tensor([[ 1.2927,  2.5089,  1.8961, -9.6555]], device='cuda:0',
       grad_fn=<AddmmBackward>)
tensor([-10.], device='cuda:0')
episode: 1 sum_of_rewards_for_episode: [-10.]
tensor([[

In [49]:
# Closes whichever environment is currently bound to `env`.
# NOTE(review): presumably a manual cleanup after an interrupted visualization
# run - confirm whether this cell is still needed.
env.close()