In [None]:
import numpy as np
from IPython.display import clear_output, display
import torch
import random
import copy
import time
import os #to get current working directory
import matplotlib.pyplot as plt
import pickle #for storing data
from wurm.envs import SingleSnake
from gym.wrappers.monitoring.video_recorder import VideoRecorder

# Prefer the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' #set device

## Visualizing the neural network. Requires Tensorboard

# Write the Q-network's computation graph to a TensorBoard run directory and
# open the TensorBoard viewer inline on the `runs` log directory.
# NOTE(review): `qnet` and `env` are not defined by any cell above this point in
# the visible notebook (`env` is first created in the "Initializing Environment
# and Agent" cell and no bare `qnet` is assigned anywhere in view), so this
# presumably only works after those objects exist in the kernel — confirm.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
# The environment's reset observation is used as the example input for tracing.
writer.add_graph(qnet, torch.Tensor(env.reset()))
writer.close()
%load_ext tensorboard
%tensorboard --logdir=runs

## Replay Buffer

In [None]:
import collections

class ReplayBuffer():
    """Fixed-capacity FIFO store of transitions.

    Every entry is a 5-tuple ``(state, next_state, action, reward, terminal)``.
    Each field is indexed along its first dimension when sub-sampling; per the
    notebook's comments that dimension enumerates the parallel environments.
    Once ``max_buffer_size`` entries are stored, the oldest entry is dropped
    for every new one (``collections.deque`` with ``maxlen``).
    """

    # Number of fields in a stored transition tuple.
    _N_FIELDS = 5

    def __init__(self, max_buffer_size: int):
        self.buffer = collections.deque(maxlen=max_buffer_size)

    def add_to_buffer(self, data):
        """Append one transition tuple (state,next_state,action,reward,terminal)."""
        self.buffer.append(data)

    def sample_minibatch(self, minibatch_length, subbatch_length):
        """Draw ``minibatch_length`` random entries (with replacement) and, from
        each, ``subbatch_length`` random rows (with replacement); return the
        five fields concatenated along dimension 0."""
        rows_per_entry = self.buffer[0][0].shape[0]
        columns = tuple([] for _ in range(self._N_FIELDS))
        for entry_idx in np.random.randint(0, len(self.buffer), minibatch_length):
            row_idx = np.random.randint(0, rows_per_entry, subbatch_length)
            entry = self.buffer[entry_idx]
            for k in range(self._N_FIELDS):
                columns[k].append(entry[k][row_idx])
        return tuple(torch.cat(column) for column in columns)

    def sample_batch(self, minibatch_length):
        """Draw ``minibatch_length`` random entries (with replacement) and return
        their five fields, each concatenated whole along dimension 0."""
        columns = tuple([] for _ in range(self._N_FIELDS))
        for entry_idx in np.random.randint(0, len(self.buffer), minibatch_length):
            entry = self.buffer[entry_idx]
            for k in range(self._N_FIELDS):
                columns[k].append(entry[k])
        return tuple(torch.cat(column) for column in columns)

    def sample(self, subbatch_length):
        """Pick one random entry and return ``subbatch_length`` random rows
        (with replacement) of each of its five fields as a tuple."""
        entry_idx = np.random.randint(0, len(self.buffer))
        row_idx = np.random.randint(0, len(self.buffer[0][0]), subbatch_length)
        entry = self.buffer[entry_idx]
        return tuple(entry[k][row_idx] for k in range(self._N_FIELDS))

        
#A buffer with lower correlation between samples. Implemented with PyTorch.
#Presently not working properly. Not sure why.
class BetterBuffer():
    """Ring buffer that stores all transitions in five flat tensors.

    Each call to ``add_to_buffer`` appends one *slot* holding ``data[i].shape[0]``
    rows to each of ``buffer_0`` .. ``buffer_4`` (state, next_state, action,
    reward, terminal). After ``max_envs`` slots have been written the buffer is
    full and further writes overwrite slots in order, starting from the oldest.
    ``sample`` draws individual rows uniformly over everything stored, so a
    batch mixes rows from different slots.

    The slot size is taken from the data passed in rather than from a
    module-level ``num_envs`` variable, so the class no longer depends on a
    global defined in a later notebook cell. Every slot is assumed to have the
    same number of rows; a differently sized batch written after the buffer is
    full will raise from the tensor slice assignment.
    """

    def __init__(self, max_envs: int = 1000):
        #data must be of the form (state,next_state,action,reward,terminal)
        # Empty 1-D tensors act as the seed for torch.cat; the dtypes match the
        # fields that are concatenated onto them.
        self.buffer_0 = torch.empty(0).to(DEFAULT_DEVICE)
        self.buffer_1 = torch.empty(0).to(DEFAULT_DEVICE)
        self.buffer_2 = torch.empty(0).long().to(DEFAULT_DEVICE)
        self.buffer_3 = torch.empty(0).to(DEFAULT_DEVICE)
        self.buffer_4 = torch.empty(0).bool().to(DEFAULT_DEVICE)
        self.max_length = max_envs   # capacity, counted in slots
        self.pointer = 0             # index of the next slot to write
        self.full = False            # True once max_length slots are stored

    def add_to_buffer(self, data):
        """Store one slot of transitions.

        Args:
            data: tuple ``(state, next_state, action, reward, terminal)`` of
                tensors sharing the same first-dimension length.
        """
        slot_rows = data[0].shape[0]
        if self.full:
            # Wrap around to the oldest slot once the end has been reached.
            if self.pointer == self.max_length:
                self.pointer = 0
            start = self.pointer * slot_rows
            stop = start + slot_rows
            self.buffer_0[start:stop] = data[0]
            self.buffer_1[start:stop] = data[1]
            self.buffer_2[start:stop] = data[2]
            self.buffer_3[start:stop] = data[3]
            self.buffer_4[start:stop] = data[4]
            self.pointer += 1
        else:
            # Still filling: grow each tensor by one slot.
            self.buffer_0 = torch.cat((self.buffer_0, data[0]))
            self.buffer_1 = torch.cat((self.buffer_1, data[1]))
            self.buffer_2 = torch.cat((self.buffer_2, data[2]))
            self.buffer_3 = torch.cat((self.buffer_3, data[3]))
            self.buffer_4 = torch.cat((self.buffer_4, data[4]))
            self.pointer += 1
            if self.pointer == self.max_length:
                self.full = True

    def sample(self, batch_size):
        """Return ``batch_size`` rows drawn uniformly with replacement.

        Returns:
            Tuple ``(state, next_state, action, reward, terminal)``; the same
            random row indices are used for every field.
        """
        # Number of rows currently stored; equals pointer*slot_rows while
        # filling and max_length*slot_rows once full.
        stored_rows = self.buffer_0.shape[0]
        randint = torch.randint(0, stored_rows, (batch_size,), device=self.buffer_0.device)
        return self.buffer_0[randint], self.buffer_1[randint], self.buffer_2[randint], self.buffer_3[randint], self.buffer_4[randint]
    

## DQN Agent

In [None]:
#################Simple DQN Agent########################################
class DQNAgent():
    """Deep Q-learning agent with a target network and replay buffers.

    Args:
        NN: callable that builds and returns the Q-network module.
        NN_args: positional arguments forwarded to ``NN``.
        num_envs: number of parallel environments; ``update`` derives its
            batch size from it.
        buffer_size: capacity (in stored steps) of the main replay buffer.
        lr: Adam learning rate.
        discount: discount factor applied to the bootstrapped next-state value.
        tau: interpolation factor used by ``soft_target_update``.
    """

    def __init__(self, NN: object, NN_args: tuple = (), 
                 num_envs: int = 1, buffer_size: int = 800, 
                 lr: float = 0.0005, discount: float = 0.8, tau: float = 0.01):
        #self.qnet = torch.load("dqn80x80.h5")
        self.qnet = NN(*NN_args)
        self.qnet_target = copy.deepcopy(self.qnet)
        # train()/evaluate() only toggle the online network, so the target
        # network is put in eval mode once here: otherwise dropout layers
        # would inject noise into every bootstrap target.
        self.qnet_target.eval()
        self.qnet_optim = torch.optim.Adam( self.qnet.parameters(), lr=lr) #set learning rate
        self.discount_factor = torch.Tensor([discount]).to(DEFAULT_DEVICE) # set discount factor
        self.MSELoss_function = torch.nn.MSELoss()
        self.replay_buffer = ReplayBuffer(buffer_size) #set size of replay buffer.
        self.old_buffer = ReplayBuffer(300)
        
        self.target_update_interval = 500 #set for target update interval for hard target network updates
        self.update_count= 0 #internal working variable. Don't change
        self.tau = tau # set tau for soft target network updates
        self.num_envs = num_envs
    
    def add_to_buffer(self, data):
        """Store a transition tuple in the main replay buffer."""
        self.replay_buffer.add_to_buffer(data)

    def add_to_old_buffer(self, data):
        """Store a transition tuple in the secondary ("old") buffer."""
        self.old_buffer.add_to_buffer(data)
    
    def train(self):
        """Put the online Q-network in training mode."""
        self.qnet.train()
        
    def evaluate(self):
        """Put the online Q-network in evaluation mode."""
        self.qnet.eval()
        
    #Hard update target network
    def target_update(self,network,target_network):
        """Copy every parameter of ``network`` into ``target_network``."""
        with torch.no_grad():
            for net_params, target_net_params in zip(network.parameters(), target_network.parameters()):
                target_net_params.copy_(net_params)
     
    #Soft update target network
    def soft_target_update(self,network,target_network):
        """Move ``target_network`` towards ``network``:
        ``target = tau * network + (1 - tau) * target``."""
        # no_grad keeps the interpolation out of the autograd graph; the
        # original mixed `.data` with a grad-requiring tensor.
        with torch.no_grad():
            for net_params, target_net_params in zip(network.parameters(), target_network.parameters()):
                target_net_params.copy_(net_params*self.tau + (1-self.tau)*target_net_params)
    
    def epsilon_greedy_action(self, state, epsilon):
        """Return random actions with probability ``epsilon``, else greedy ones.

        A single random draw decides for the whole batch. The random branch
        calls ``env.random_action()`` on the module-level ``env``.
        NOTE(review): that makes the random branch depend on whichever
        environment the global name ``env`` currently refers to — confirm it
        matches the batch in ``state`` when epsilon > 0.
        """
        if np.random.uniform(0, 1) < epsilon:
            return env.random_action()  # choose random action
        # Action selection needs no gradients.
        with torch.no_grad():
            return torch.argmax(self.qnet(state), dim=1)  # choose greedy action
    
    #update Q Network by calculating gradient of neural network loss
    def update_Q_Network(self, state, next_state, action, reward, terminals):
        """Take one gradient step on the MSE between Q(s,a) and the target
        ``reward + (not terminal) * discount * max_a' Q_target(s', a')``."""
        # squeeze(-1) rather than squeeze(): a batch of one must stay 1-D so
        # that it matches the target's shape.
        qsa = torch.gather(self.qnet(state), dim=1, index=action.unsqueeze(-1)).squeeze(-1)
        with torch.no_grad():
            qsa_next_action = torch.max(self.qnet_target(next_state), dim=1)[0]
            not_terminals = ~terminals
            qsa_next_target = reward + not_terminals * self.discount_factor * qsa_next_action
        q_network_loss = self.MSELoss_function(qsa, qsa_next_target)
        self.qnet_optim.zero_grad()
        q_network_loss.backward()
        self.qnet_optim.step()
     
    #call this to update Q network (train) and then make hard update of target network
    def update(self, update_rate):
        """Run ``update_rate`` gradient steps; hard-update the target network
        every ``target_update_interval`` steps."""
        # num_envs // 3 is 0 for fewer than three environments (including the
        # default num_envs=1), which would train on an empty batch.
        batch_size = max(1, self.num_envs//3)
        for _ in range(update_rate):
            states, next_states, actions, rewards, terminals = self.replay_buffer.sample(batch_size)
            self.update_Q_Network(states, next_states, actions, rewards, terminals)
            self.update_count+=1
            if self.update_count==self.target_update_interval:
                self.target_update(self.qnet, self.qnet_target)
                self.update_count=0
                
    #call this to update Q network (train) and then make soft update of target network
    def soft_update(self, update_rate, batch_size):
        """Run ``update_rate`` gradient steps on batches of ``batch_size``,
        soft-updating the target network after each step."""
        for _ in range(update_rate):
            states, next_states, actions, rewards, terminals = self.replay_buffer.sample(batch_size)
            self.update_Q_Network(states, next_states, actions, rewards, terminals)
            self.soft_target_update(self.qnet, self.qnet_target)

###############Simple DQN Agent######################################################            
#################Double DQN Agent smooth##############################################
#Based on https://arxiv.org/abs/1509.06461v3
class DDQNAgent_smooth(DQNAgent):
    """Double DQN agent with soft (Polyak) target-network updates.

    Follows the Double DQN target of arXiv:1509.06461: the online network
    selects the next action and the target network evaluates it. Constructor
    arguments are identical to ``DQNAgent``.
    """

    def __init__(self, NN: object, NN_args: tuple = (), 
                 num_envs: int = 1, buffer_size: int = 800,
                 lr: float = 0.0005, discount: float = 0.8, tau: float = 0.01):
        super().__init__(NN, NN_args, num_envs, buffer_size, lr, discount, tau)
    
    #update Q Network by calculating gradient of neural network loss
    def update_Q_Network(self, state, next_state, action, reward, terminals):
        """Take one gradient step towards the Double DQN target
        ``reward + (not terminal) * discount * Q_target(s', argmax_a Q(s', a))``."""
        # squeeze(-1) keeps a batch of one 1-D so it matches the target shape.
        qsa = torch.gather(self.qnet(state), dim=1, index=action.unsqueeze(-1)).squeeze(-1)
        with torch.no_grad():
            # Select with the online network, evaluate with the target network.
            # The original had the two roles swapped, so the bootstrap value
            # came from the online network and the target network only chose
            # the action.
            best_next_action = torch.argmax(self.qnet(next_state), dim=1)
            q_next_state_a = torch.gather(self.qnet_target(next_state), dim=1,
                                          index=best_next_action.unsqueeze(-1)).squeeze(-1)
            not_terminals = ~terminals
            qsa_next_target = reward + not_terminals * self.discount_factor * q_next_state_a
        q_network_loss = self.MSELoss_function(qsa, qsa_next_target)
        self.qnet_optim.zero_grad()
        q_network_loss.backward()
        self.qnet_optim.step()
        
    #call this to update Q network (train) and then make soft update of target network
    def update(self, update_rate, batch_size, rehearse_rate = 0):
        """Run ``update_rate`` steps on the main replay buffer followed by
        ``rehearse_rate`` steps on the old buffer, soft-updating the target
        network after every step."""
        schedule = ((self.replay_buffer, update_rate), (self.old_buffer, rehearse_rate))
        for buffer, repeats in schedule:
            for _ in range(repeats):
                states, next_states, actions, rewards, terminals = buffer.sample(batch_size)
                self.update_Q_Network(states, next_states, actions, rewards, terminals)
                self.soft_target_update(self.qnet, self.qnet_target)
#################Double DQN Agent smooth##################################

#################Double DQN Agent########################################
#Based on https://arxiv.org/abs/1509.06461v1
class DDQNAgent(DQNAgent):
    """Double Q-learning agent with two independently initialised networks.

    Each update step trains either ``Q_A`` or ``Q_B`` (chosen with equal
    probability); the network being trained selects the next action and the
    other one evaluates it. Actions are always chosen with ``Q_A``.

    The base-class constructor is not called, so this agent has no target
    network, no ``tau`` and no ``old_buffer``.
    """

    def __init__(self, NN: object, NN_args: tuple = (), 
                 num_envs: int = 1, buffer_size: int = 800,
                 lr: float = 0.0005, discount: float = 0.8):
        #self.qnet = torch.load("dqn80x80.h5")
        self.Q_A = NN(*NN_args)
        self.Q_B = NN(*NN_args)
        # Alias for the acting network so code that reads `agent.qnet`
        # (e.g. to print or save the model) also works with this agent.
        self.qnet = self.Q_A
        self.Q_A_optim = torch.optim.Adam( self.Q_A.parameters(), lr=lr) #set learning rate
        self.Q_B_optim = torch.optim.Adam( self.Q_B.parameters(), lr=lr) #set learning rate
        self.discount_factor = torch.Tensor([discount]).to(DEFAULT_DEVICE) # set discount factor
        self.MSELoss_function = torch.nn.MSELoss()
        self.replay_buffer = ReplayBuffer(buffer_size) #set size of replay buffer.
        self.num_envs = num_envs
    
    def train(self):
        """Put both networks in training mode."""
        self.Q_A.train()
        # The original called Q_A.train() twice, leaving Q_B stuck in eval
        # mode after the first call to evaluate().
        self.Q_B.train()
        
    def evaluate(self):
        """Put both networks in evaluation mode."""
        self.Q_A.eval()
        self.Q_B.eval()
    
    def epsilon_greedy_action(self, state, epsilon):
        """Return random actions with probability ``epsilon``, else the greedy
        actions of ``Q_A``. The random branch uses the module-level ``env``."""
        if np.random.uniform(0, 1) < epsilon:
            return env.random_action()  # choose random action
        with torch.no_grad():
            return torch.argmax(self.Q_A(state), dim=1)  # choose greedy action

    def _double_q_step(self, learner, evaluator, optimizer,
                       state, next_state, action, reward, terminals):
        """One gradient step on ``learner``: it selects the next action, and
        ``evaluator`` supplies that action's value for the target."""
        # squeeze(-1) keeps a batch of one 1-D so it matches the target shape.
        q_s_a = torch.gather(learner(state), dim=1, index=action.unsqueeze(-1)).squeeze(-1)
        with torch.no_grad():
            best_next_action = torch.argmax(learner(next_state), dim=1)
            q_next = torch.gather(evaluator(next_state), dim=1,
                                  index=best_next_action.unsqueeze(-1)).squeeze(-1)
            not_terminals = ~terminals
            q_target = reward + not_terminals * self.discount_factor * q_next
        q_network_loss = self.MSELoss_function(q_s_a, q_target)
        optimizer.zero_grad()
        q_network_loss.backward()
        optimizer.step()
    
    #update Q Network by calculating gradient of neural network loss
    def update_Q_A_Network(self, state, next_state, action, reward, terminals):
        """Train ``Q_A`` using ``Q_B`` to evaluate the action ``Q_A`` selects."""
        self._double_q_step(self.Q_A, self.Q_B, self.Q_A_optim,
                            state, next_state, action, reward, terminals)
        
    def update_Q_B_Network(self, state, next_state, action, reward, terminals):
        """Train ``Q_B`` using ``Q_A`` to evaluate the action ``Q_B`` selects."""
        self._double_q_step(self.Q_B, self.Q_A, self.Q_B_optim,
                            state, next_state, action, reward, terminals)
        
    #call this to train; each step updates Q_A or Q_B with equal probability
    def update(self, update_rate, batch_size):
        """Run ``update_rate`` gradient steps on batches of ``batch_size``."""
        for _ in range(update_rate):
            states, next_states, actions, rewards, terminals = self.replay_buffer.sample(batch_size)
            if np.random.uniform()<0.5:
                self.update_Q_A_Network(states, next_states, actions, rewards, terminals)
            else:
                self.update_Q_B_Network(states, next_states, actions, rewards, terminals)
#################Double DQN Agent##############################################



## Defining Some Neural Networks

In [None]:
# defining a fully connected neural network
def FNN_1(shape, hidden_dim, action_dim):
    """Build a fully connected Q-network on ``DEFAULT_DEVICE``.

    Args:
        shape: shape of a single observation; it is flattened before the
            first linear layer.
        hidden_dim: width of the two hidden layers.
        action_dim: number of outputs.

    Returns:
        A ``torch.nn.Sequential`` model.
    """
    # np.product is deprecated and was removed in NumPy 2.0; np.prod is the
    # supported name. int() hands Linear a plain Python integer.
    flat_shape = int(np.prod(shape)) #length of the flattened state
    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(flat_shape,hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim,hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim, action_dim),
         ).to(DEFAULT_DEVICE)
    return model

def CNN_1():
    """Build a three-convolution Q-network on ``DEFAULT_DEVICE``.

    Takes single-channel input and widens it 1 -> 32 -> 64 -> 128 with 3x3,
    stride-1, padding-1 convolutions. Each channel is then max-pooled to one
    value, followed by a 128-unit hidden layer and 4 outputs. Dropout is
    applied after the second and third convolutions and after the hidden
    linear layer.
    """
    nn = torch.nn

    def conv3x3(in_channels, out_channels):
        return nn.Conv2d(in_channels, out_channels,
                         kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    # Layers are created in forward order so that module indices match the
    # positions in the Sequential container.
    layers = [
        conv3x3(1, 32),
        nn.ReLU(),
        conv3x3(32, 64),
        nn.Dropout(0.1),
        nn.ReLU(),
        conv3x3(64, 128),
        nn.Dropout(0.2),
        nn.ReLU(),
        nn.AdaptiveMaxPool2d((1,1)),
        nn.Flatten(),
        nn.Linear(128, 128),
        nn.Dropout(0.5),
        nn.ReLU(),
        nn.Linear(128, 4),
    ]
    return nn.Sequential(*layers).to(DEFAULT_DEVICE)

def CNN_2():
    """Build a four-convolution Q-network on ``DEFAULT_DEVICE``.

    Takes single-channel input through four 3x3, stride-1, padding-1
    convolutions of 32 channels each, with a ReLU after every one. Each
    channel is then max-pooled to one value, followed by a 64-unit hidden
    layer and 4 outputs. No dropout is used.
    """
    nn = torch.nn

    def conv3x3(in_channels, out_channels):
        return nn.Conv2d(in_channels, out_channels,
                         kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    # Convolutional trunk: (in_channels, out_channels) per stage, each
    # followed by a ReLU.
    layers = []
    for in_channels, out_channels in ((1, 32), (32, 32), (32, 32), (32, 32)):
        layers.append(conv3x3(in_channels, out_channels))
        layers.append(nn.ReLU())

    # Global max-pool plus a small fully connected head.
    layers += [
        nn.AdaptiveMaxPool2d((1,1)),
        nn.Flatten(),
        nn.Linear(32, 64),
        nn.ReLU(),
        nn.Linear(64, 4),
    ]
    return nn.Sequential(*layers).to(DEFAULT_DEVICE)

def CNN_3():
    """Build a two-convolution Q-network on ``DEFAULT_DEVICE``.

    Takes single-channel input and widens it 1 -> 32 -> 256 with 3x3,
    stride-1, padding-1 convolutions. Each channel is then max-pooled to one
    value, followed by a 64-unit hidden layer and 4 outputs. Dropout is
    applied after the second convolution and after the hidden linear layer.
    """
    nn = torch.nn

    def conv3x3(in_channels, out_channels):
        return nn.Conv2d(in_channels, out_channels,
                         kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    feature_layers = [
        conv3x3(1, 32),
        nn.ReLU(),
        conv3x3(32, 256),
        nn.Dropout(0.2),
        nn.ReLU(),
        nn.AdaptiveMaxPool2d((1,1)),
        nn.Flatten(),
    ]
    head_layers = [
        nn.Linear(256, 64),
        nn.Dropout(0.1),
        nn.ReLU(),
        nn.Linear(64, 4),
    ]
    return nn.Sequential(*feature_layers, *head_layers).to(DEFAULT_DEVICE)


## Initializing Environment and Agent

In [None]:
# Seed the Python, NumPy and PyTorch random number generators before anything
# stochastic is created, so that a fresh "Restart & Run All" is repeatable.
# NOTE(review): this assumes the environment draws its randomness from these
# generators — confirm against the wurm implementation.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

num_envs = 1300 #Number of parallel environments to simulate. Use small value for cpu (eg. 1)
test_num_envs = 100 #Number of parallel environments used for periodic evaluation
env = SingleSnake(num_envs=num_envs, size=10, observation_mode='one_channel', device= DEFAULT_DEVICE)
# The evaluation environment is created with auto_reset disabled.
test_env = SingleSnake(num_envs=test_num_envs, size=10, observation_mode='one_channel', device= DEFAULT_DEVICE, auto_reset=False)

state = env.reset()
state_dim = state.shape[1:] #observation shape without the leading batch dimension
action_dim = 4

#Effective buffer_size = buffer_size*num_envs
agent=DDQNAgent_smooth(NN = CNN_1, num_envs = num_envs, buffer_size = 600, lr = 0.0005, discount = 0.8, tau =0.01)

agent.train()
print(agent.qnet)

## Training

In [None]:
#%%time
# Training cell: pre-fills the replay buffer with random actions, then
# alternates one environment step with a batch of gradient updates. Every 100
# iterations the agent is evaluated greedily on `test_env`, and the network
# with the best mean evaluation return is saved to models/best_model.h5.
render=False
save_model = False
# NOTE: despite the name, this counts training iterations (one env.step per
# iteration across all parallel environments), not completed episodes.
number_of_episodes = 100000

####Code to compute total reward####
# np.int was removed in NumPy 1.24; the builtin int is the type it aliased.
counter = np.zeros(num_envs, dtype = int)
reward_buffer = np.zeros((200,num_envs), dtype=np.float32)
total_reward = torch.zeros(num_envs).to(DEFAULT_DEVICE)
reward_list=[]
episode_list=[]
range_constant = np.arange(num_envs)
best_reward=-10
####Code to compute total reward####

# torch.save below writes into models/; make sure the directory exists.
os.makedirs("models", exist_ok=True)

state=env.reset()
agent.train()

#Filling up the buffer
for i in range(30):
    action = agent.epsilon_greedy_action( state , 1.0) #set epsilon
    next_state, reward, terminal, _ = env.step(action)
    agent.add_to_buffer((state,next_state,action,reward,terminal))
    state = next_state


state=env.reset()

#Learning
for i in range(1,number_of_episodes):
    action = agent.epsilon_greedy_action( state , 0.05) #set epsilon
    next_state, reward, terminal, _ = env.step(action)
    
    agent.add_to_buffer((state,next_state,action,reward,terminal))
    
    agent.update(10,400)
    
    state = next_state
    
    
    if render:
        env.render()
    # The string below is the older, disabled reward-tracking code; it is a
    # no-op expression kept for reference.
    """    
    #Storing and printing average snake reward
    #Code is a bit complicated and contrived.
    total_reward.add_(reward)
    term = terminal.cpu().numpy()
    if term.any():
        print('episode:', i)
        if reward_list != []:
            print('Mean Reward: ', reward_list[-1], best_reward)
        reward_buffer[counter, range_constant] = total_reward.cpu().numpy()
        total_reward[terminal] = 0
        counter[counter<199]+=term[counter<199] #prevent buffer overflow.
        clear_output(wait=True) 
        if (counter>0).all():
            mean = np.mean(reward_buffer[0])
            reward_buffer = np.roll(reward_buffer,-1,axis=0)
            counter-=1
            reward_list.append(mean)
            episode_list.append(i)
            #To deal with catastrophic forgetting, let's save the neural network with highest score.
            if mean>best_reward:
                best_reward = mean
                torch.save(agent.qnet,"models/best_model.h5")
    """
    #New code for recording data. Slower but more accurate.
    if (i%100==0):
        agent.evaluate()                        
        t_state = test_env.reset()
        reward_sum = torch.zeros((test_num_envs,)).to(DEFAULT_DEVICE)
        hit_terminal = torch.zeros((test_num_envs,)).bool().to(DEFAULT_DEVICE)

        # Evaluation needs no gradients.
        with torch.no_grad():
            for steps in range(1000): #max steps
                t_action = agent.epsilon_greedy_action( t_state , 0.0)
                t_next_state, t_reward, t_terminal, _ = test_env.step(t_action)
                # Only count reward for environments that have not finished yet.
                reward_sum+= (~hit_terminal)*t_reward
                hit_terminal |= t_terminal
                t_state = t_next_state
                # Stop once every environment has terminated at least once.
                # The accumulated mask is used rather than the current step's
                # flags, so this also works if a finished environment stops
                # reporting terminal on later steps.
                if hit_terminal.all():
                    break
        r_sum = reward_sum.cpu().numpy()
        mean = np.mean(r_sum)
        print('episode:', i, "Mean, Median, Max, Min, std:", 
              mean, 
              np.median(r_sum),
              np.max(r_sum),
              np.min(r_sum),
              np.std(r_sum))
        reward_list.append(mean)
        episode_list.append(i)
        agent.train()
        clear_output(wait=True)
        #To deal with catastrophic forgetting, keep the network with the highest score.
        if mean>best_reward:
            best_reward = mean
            torch.save(agent.qnet,"models/best_model.h5")
    
if render:    
    env.close()
    
if save_model:
    # Load from the same path the loop saves to (the original read
    # "best_model.h5" from the working directory, which is never written).
    model = torch.load("models/best_model.h5")
    with open('models/best_model.pickle', 'wb') as f:
        pickle.dump((model, episode_list,reward_list), f)

## Storing Data of best model and associated runtime data.

In [None]:
#Visualizing reward
# reward_list holds the mean evaluation return recorded by the training cell
# and episode_list the training iteration at which each value was measured.
fig, ax = plt.subplots()
ax.plot(episode_list, reward_list)
ax.set_xlabel("Training iteration")
ax.set_ylabel("Mean evaluation return")
ax.set_title("Evaluation return during training")
plt.show()

#Store data about the best model as a pickle file
import pickle
model = torch.load("models/best_model.h5")
# Mode 'xb' creates the file exclusively: it raises FileExistsError instead of
# overwriting an earlier result stored under the same name.
with open('models/cnn_256_avg_822.pickle', 'xb') as f:
    pickle.dump((model, episode_list,reward_list), f)

In [None]:
# Restore a previously pickled (model, iterations, rewards) tuple, install the
# model as the agent's online network and re-plot its training curve.
# NOTE(review): unpickling executes arbitrary code from the file — only load
# pickles you created yourself.
with open('models/cnn_256_avg_8.5.pickle', 'rb') as f:
    data = pickle.load(f)

saved_model, saved_iterations, saved_rewards = data[0], data[1], data[2]
agent.qnet = saved_model

plt.plot(saved_iterations, saved_rewards)
plt.show()
print(agent.qnet)

## Visualize and Record Gameplay

In [None]:
# Play five greedy episodes in a single environment, rendering every step and
# recording each episode to videos/<episode>.mp4 under the working directory.
# NOTE: this rebinds the notebook-wide name `env` to a one-environment
# instance; re-run the initialisation cell before training again.
env = SingleSnake(num_envs=1, size=10, observation_mode='one_channel', device= DEFAULT_DEVICE)
agent.evaluate()
PATH = os.getcwd()
video_dir = os.path.join(PATH, 'videos')
# The recorder writes into videos/; create it if it does not exist yet.
os.makedirs(video_dir, exist_ok=True)
state = env.reset()
for episode in range(5):
    reward_sum = 0
    recorder = VideoRecorder(env, path=os.path.join(video_dir, f'{episode}.mp4'))
    # try/finally guarantees the recorder is closed even if a step or render
    # call raises.
    try:
        env.render()
        recorder.capture_frame()
        time.sleep(0.2)
        while True:
            action = agent.epsilon_greedy_action( state , 0.0)
            next_state, reward, terminal, _ = env.step(action)
            reward_sum+= reward.cpu().numpy()
            env.render()
            recorder.capture_frame()
            time.sleep(0.2)
            state = next_state
            if terminal.all():
                break
    finally:
        recorder.close()
    print('episode:', episode, 'sum_of_rewards_for_episode:', reward_sum)

env.close()

## Computing Average Return

In [None]:
# Estimate the greedy policy's return: run num_envs environments in parallel
# (auto_reset disabled) for up to 1000 steps, five times over, and print
# summary statistics of the per-environment return.
# NOTE: this rebinds the notebook-wide name `env`.
env = SingleSnake(num_envs=num_envs, size=10, observation_mode='one_channel', device= DEFAULT_DEVICE, auto_reset=False)
agent.evaluate()

for episode in range(5):
    state = env.reset()
    reward_sum = torch.zeros((num_envs,)).to(DEFAULT_DEVICE)
    hit_terminal = torch.zeros((num_envs,)).bool().to(DEFAULT_DEVICE)

    # Evaluation needs no gradients.
    with torch.no_grad():
        for steps in range(1000):
            action = agent.epsilon_greedy_action( state , 0)
            next_state, reward, terminal, _ = env.step(action)
            # Only count reward for environments that have not finished yet.
            reward_sum+= (~hit_terminal)*reward
            hit_terminal |= terminal
            state = next_state
            # Stop once every environment has terminated at least once. The
            # accumulated mask is used rather than the current step's flags,
            # so this also works if a finished environment stops reporting
            # terminal on later steps.
            if hit_terminal.all():
                break
    print(terminal.sum())
    print('episode:', episode, "Mean, Max, Min, Median:", 
          torch.mean(reward_sum).cpu().numpy(), 
          torch.max(reward_sum).cpu().numpy(),
          torch.min(reward_sum).cpu().numpy(),
          torch.median(reward_sum).cpu().numpy())
