In [None]:
import numpy as np
from IPython.display import clear_output, display
import torch
import random
import copy
import time
import os #to get current working directory
import matplotlib.pyplot as plt
import pickle #for storing data
from wurm.envs import SingleSnake
from wurm.envs import SimpleGridworld
from gym.wrappers.monitoring.video_recorder import VideoRecorder

# Device used for every tensor, network and environment in this notebook.
# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so the
# notebook still runs (with a small `num_envs`) on machines without one.
DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Visualizing the neural network. Requires Tensorboard

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
writer.add_graph(qnet, torch.Tensor(env.reset()))  # NOTE(review): `qnet` and `env` are not defined earlier in the visible notebook, so this fails on a fresh kernel — confirm where they come from before running
writer.close()
%load_ext tensorboard
%tensorboard --logdir=runs

## Replay Buffer

In [None]:
import collections

class Trajectory():
    """Step-by-step storage for rollout data.

    Every kind of item (state, action, reward, ...) is kept in its own Python
    list on the instance; one call to ``append_to_trajectory`` adds at most one
    entry to each list.
    """

    def __init__(self):
        #data must be of the form (state,next_state,action,reward,terminal)
        self._reset_buffers()

    def _reset_buffers(self):
        """Create (or replace) every buffer with a fresh empty list."""
        self.buffer_states = []
        self.buffer_next_states = []
        self.buffer_actions = []
        self.buffer_rewards = []
        self.buffer_terminals = []
        self.buffer_log_probs = []
        # These three are written by append_to_trajectory as well, so they
        # must exist up front; otherwise passing `value`, `entropy` or
        # `hidden_state` raises AttributeError.
        self.buffer_values = []
        self.buffer_entropies = []
        self.buffer_hiddens = []

    def append_to_trajectory(self,
               state: torch.Tensor = None,
               next_state: torch.Tensor = None,
               action: torch.Tensor = None,
               log_prob: torch.Tensor = None,
               reward: torch.Tensor = None,
               value: torch.Tensor = None,
               terminal: torch.Tensor = None,
               entropy: torch.Tensor = None,
               hidden_state: torch.Tensor = None):
        """Adds a transition to the store.

        Each argument should be a vector of shape (num_envs, 1).
        Arguments left as ``None`` are skipped, so their buffers keep their
        current length.
        """
        # Pair every buffer with the item destined for it, then append the
        # ones that were actually supplied.
        destinations = (
            (self.buffer_states, state),
            (self.buffer_next_states, next_state),
            (self.buffer_actions, action),
            (self.buffer_log_probs, log_prob),
            (self.buffer_rewards, reward),
            (self.buffer_values, value),
            (self.buffer_terminals, terminal),
            (self.buffer_entropies, entropy),
            (self.buffer_hiddens, hidden_state),
        )
        for buffer, item in destinations:
            if item is not None:
                buffer.append(item)

    def clear_trajectory(self):
        """Drop everything stored so far, leaving all buffers empty."""
        self._reset_buffers()

    # The stacked-tensor accessors below are kept disabled (they sit inside a
    # string literal). Note that A2C assigns `self.values` and `self.log_probs`
    # as plain attributes, which read-only properties with those names would
    # reject.
    """@property
    def states(self):
        return torch.stack(self.buffer_states)
    @property
    def next_states(self):
        return torch.stack(self.buffer_next_states)
    @property
    def actions(self):
        return torch.stack(self.buffer_actions)
    @property
    def log_probs(self):
        return torch.stack(self.buffer_log_probs)
    @property
    def rewards(self):
        return torch.stack(self.buffer_rewards)
    @property
    def values(self):
        return torch.stack(self.buffer_values)
    @property
    def terminals(self):
        return torch.stack(self.buffer_terminals)
    @property
    def entropies(self):
        return torch.stack(self.buffer_entropies)
    @property
    def hidden_state(self):
        return torch.stack(self.buffer_hiddens)"""

## Policy Gradient Agents

def update(self):
        """Take one policy-gradient step using the stored trajectory.

        The discounted return of the whole stored trajectory is accumulated
        backwards over `buffer_rewards`, then used to weight the model output
        for the FIRST stored state only, scored against the first stored
        action. Reads `buffer_states`, `buffer_rewards`, `buffer_actions`,
        `discount_factor`, `model` and `model_optim` from `self`.
        """
        batch_size = len(self.buffer_states[0])
        discounted_return = torch.zeros(batch_size).to(DEFAULT_DEVICE)
        # Walk the trajectory from its last step to its first:
        # return <- discount * return + reward[t]
        for step in range(len(self.buffer_states) - 1, -1, -1):
            discounted_return.mul_(self.discount_factor)
            discounted_return.add_(self.buffer_rewards[step])

        # The model output is fed to NLLLoss, so it is presumably
        # log-probabilities per action — TODO confirm against the model used.
        model_output = self.model(self.buffer_states[0])
        first_actions = self.buffer_actions[0]
        # Scale each row by that environment's return before selecting the
        # entry of the action that was taken.
        weighted_output = model_output * discounted_return.unsqueeze(1)
        criterion = torch.nn.NLLLoss(reduction = 'sum')
        loss = criterion(weighted_output, first_actions)

        self.model_optim.zero_grad()
        loss.backward()
        self.model_optim.step()
        

In [None]:
class A2C(Trajectory):
    """One-step advantage actor-critic agent.

    Wraps a single network that is called as ``model(state)`` and returns the
    pair ``(action_probs, state_values)``. ``action`` samples actions and
    remembers the values and log-probabilities it produced; ``update`` then
    uses them together with the observed transition to take one optimizer step.
    """

    def __init__(self, NN: object, NN_args: tuple = (), 
                 num_envs: int = 1, buffer_size: int = 800, 
                 lr: float = 0.0005, discount: float = 0.8, tau: float = 0.01,
                 lam = 10):
        """Build the network and its Adam optimizer.

        Args:
            NN: callable that builds the network; invoked as ``NN(*NN_args)``.
            NN_args: positional arguments forwarded to ``NN``.
            lr: learning rate for the Adam optimizer.
            discount: discount factor applied to the next-state value.
            num_envs, buffer_size, tau, lam: accepted but not used by this class.
        """
        super().__init__()
        self.model = NN(*NN_args)
        self.model_optim = torch.optim.Adam(self.model.parameters(), lr=lr) #set learning rate
        self.gamma = torch.Tensor([discount]).to(DEFAULT_DEVICE) # set discount factor
        
        self.lr = lr
        
    def load(self, path):
        """Replace the model with the one stored at ``path`` and rebuild the optimizer.

        NOTE: ``torch.load`` unpickles arbitrary objects; only load files you trust.
        """
        self.model = torch.load(path)
        # The optimizer must track the parameters of the freshly loaded model.
        # (This used to read `self.qnet`, an attribute this class never sets.)
        self.model_optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def train(self):
        """Put the model in training mode and turn autograd on (globally)."""
        self.model.train()
        torch.set_grad_enabled(True)
        
    def evaluate(self):
        """Put the model in eval mode and turn autograd off (globally)."""
        self.model.eval()
        torch.set_grad_enabled(False)
        
    def action(self, state):
        """Sample one action per environment from the current policy.

        Side effect: stores ``self.values`` and ``self.log_probs`` for the
        sampled actions; ``update`` reads both.
        """
        #self.values and self.log_probs will be needed to update model in self.update()
        action_probs, self.values = self.model(state)
        dist = torch.distributions.Categorical(probs = action_probs)
        actions = dist.sample()
        self.log_probs = dist.log_prob(actions)
        return actions

    def update(self, next_states, rewards, terminals):
        """Take one actor-critic step for the transition that followed ``action``.

        Must be called after ``action`` (it uses the values and log-probs
        stored there).

        Args:
            next_states: observations returned by the environment step.
            rewards: rewards returned by the environment step.
            terminals: termination flags; ``~terminals`` is used as a mask, so
                this is expected to be a boolean tensor — TODO confirm dtype
                against the environment.
        """
        _, next_state_values = self.model(next_states)
        
        # One-step bootstrap target; finished environments contribute only
        # their reward. Detached so no gradient flows through the target.
        returns = (rewards + (~terminals)*self.gamma*next_state_values).detach() 
        values = self.values
        
        # Prediction first, detached target second, matching the documented
        # (input, target) order. The default reduction is already the mean.
        value_loss = torch.nn.functional.smooth_l1_loss(values, returns)
        advantages = returns-values
        # The advantage is treated as a constant weight on the log-probabilities.
        policy_loss = -(advantages.detach()*self.log_probs).mean()
        loss = policy_loss + value_loss
        
        self.model_optim.zero_grad()
        loss.backward()
        # If gradient clipping is re-enabled it has to sit here, after
        # backward() has produced the gradients and before step().
        #torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)
        self.model_optim.step()

## Defining Some Neural Networks

In [None]:
class SnakeNet(torch.nn.Module):
    """Convolutional network with a shared trunk and separate policy/value heads.

    The input is a single-channel ``size`` x ``size`` grid. ``forward`` returns
    a softmax distribution over 4 actions and one value per sample.
    """

    def __init__(self, size: int):
        super().__init__()
        self.layer_length = size*size

        # Shared trunk: four 3x3 convolutions that keep the spatial size
        # (stride 1, padding 1), with a ReLU between consecutive convolutions.
        conv3x3 = dict(kernel_size=(3,3), stride=(1,1), padding=(1,1))
        trunk = [torch.nn.Conv2d(1, 64, **conv3x3)]
        for _ in range(3):
            trunk.append(torch.nn.ReLU())
            trunk.append(torch.nn.Conv2d(64, 64, **conv3x3))
        self.common_layer = torch.nn.Sequential(*trunk).to(DEFAULT_DEVICE)

        # Policy head: 1x1 convolution down to 2 channels, flattened, then a
        # linear layer to 4 outputs normalised with softmax.
        policy_modules = [
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 2, kernel_size=(1,1), stride=(1,1)),
            torch.nn.Flatten(),
            torch.nn.Linear(2*self.layer_length,4),
            torch.nn.Softmax(dim=-1),
        ]
        self.policy_layer = torch.nn.Sequential(*policy_modules).to(DEFAULT_DEVICE)

        # Value head: 1x1 convolution down to 1 channel, flattened, then two
        # linear layers ending in a single output.
        value_modules = [
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 1, kernel_size=(1,1), stride=(1,1)),
            torch.nn.Flatten(),
            torch.nn.ReLU(),
            torch.nn.Linear(self.layer_length,64),
            torch.nn.Linear(64,1),
        ]
        self.value_layer = torch.nn.Sequential(*value_modules).to(DEFAULT_DEVICE)

    def forward(self,x):
        """Return ``(policy, value)``; the trailing size-1 axis of value is squeezed.

        The raw head outputs are also kept on ``self.policy`` / ``self.value``.
        """
        features = self.common_layer(x)
        self.policy = self.policy_layer(features)
        self.value = self.value_layer(features)
        squeezed_value = self.value.squeeze(-1)
        return (self.policy, squeezed_value)
        

## Initializing Environment and Agent

In [None]:
# Which environment to train on: 'SingleSnake' or 'SimpleGridworld'.
environment = 'SingleSnake'
num_envs = 1000 #Number of parallel environments to simulate. Use small value for cpu (eg. 1)
test_num_envs = 100
# Side length of the square grid. It is shared by the training env, the test
# env and SnakeNet (whose linear layers are sized from size*size), so keeping
# it in one place prevents the three from drifting apart.
GRID_SIZE = 10

if environment == 'SimpleGridworld':
    env = SimpleGridworld(num_envs=num_envs, size=GRID_SIZE, observation_mode='one_channel', device= DEFAULT_DEVICE, auto_reset=True)
    test_env = SimpleGridworld(num_envs=test_num_envs, size=GRID_SIZE, observation_mode='one_channel', device= DEFAULT_DEVICE, auto_reset=False)

    state = env.reset()
    state_dim = state.shape[1:]
    action_dim = 4

    #Effective buffer_size = buffer_size*num_envs
    # NOTE(review): `Reinforce_Agent` and `FNN_1` are not defined in the visible
    # part of this notebook — confirm they exist before selecting this branch.
    agent=Reinforce_Agent(NN = FNN_1, NN_args = (state_dim, 512, action_dim),
                           lr=0.005, discount = 1.0)

elif environment == 'SingleSnake':
    # Training env resets finished games automatically; the test env does not,
    # so validation can detect when every game has ended.
    env = SingleSnake(num_envs=num_envs, size=GRID_SIZE, observation_mode='one_channel', device= DEFAULT_DEVICE, auto_reset= True)
    test_env = SingleSnake(num_envs=test_num_envs, size=GRID_SIZE, observation_mode='one_channel', device= DEFAULT_DEVICE, auto_reset=False)

    state = env.reset()
    state_dim = state.shape[1:]
    action_dim = 4

    #Effective buffer_size = buffer_size*num_envs
    agent=A2C(NN = SnakeNet, NN_args = (GRID_SIZE,), lr = 0.0005, discount =0.99)

else:
    # Name the offending value so a typo in `environment` is easy to spot.
    raise ValueError(f"Invalid option: environment={environment!r}; "
                     "expected 'SimpleGridworld' or 'SingleSnake'")

#agent.load("models/best_model.h5")
agent.train()
print(agent.model)

## Training

In [None]:
# --- Training configuration -------------------------------------------------
# `render`, `save_model`, `epsilon` and `total_reward` are set here but not
# read anywhere in this cell.
render=False
save_model = False
number_of_steps = 100000  # total number of environment steps / agent updates
epsilon = 1.0
####Code to compute total reward####

total_reward = torch.zeros(num_envs).to(DEFAULT_DEVICE)
step_list=[]  # training steps at which validation was run
fc_list=[] #mean food collected at each validation
best_fc = 0  # best validation mean seen so far
####Code to compute total reward####


agent.train()

state=env.reset()
#Learning
# The loop bound used to be `number_of_episodes`, which is never defined; the
# step budget configured above is `number_of_steps`.
for i in range(0,number_of_steps):
    ##############Learning######################
    # One environment step on every parallel env, followed by one agent update.

    action = agent.action(state) 
    next_state, reward, terminal, _ = env.step(action)
    agent.update(next_state, reward, terminal)  
    state = next_state

    #############Validation############################
    # Every 100 steps, roll the current policy out on the separate test envs.
    if i%100 == 0:
        agent.evaluate()                        
        t_state = test_env.reset()
        fc_sum = torch.zeros((test_num_envs,)).float().to(DEFAULT_DEVICE) #food collected
        #hit_terminal = torch.zeros((test_num_envs,)).bool().to(DEFAULT_DEVICE)
        for _ in range(1000): #max steps
            t_action = agent.action(t_state)
            t_next_state, t_reward, t_terminal, _ = test_env.step(t_action)
            #anything with a positive reward is considered as food.
            fc_sum+=(t_reward>0).float()
            #hit_terminal |= t_terminal
            t_state = t_next_state
            # Stop early once every test env reports terminal.
            if t_terminal.all():
                break

        t_sum = fc_sum.cpu().numpy()
        t_mean = np.mean(t_sum)
        print('Step:', i)
        print("Episode Completed:", t_terminal.sum().cpu().numpy(), "/", test_num_envs)
        print("Mean, Median, Max, Min, std:", 
              t_mean, 
              np.median(t_sum),
              np.max(t_sum),
              np.min(t_sum),
              np.std(t_sum))
        fc_list.append(t_mean)
        step_list.append(i)
        plt.plot(step_list, fc_list)
        plt.show()
        # Back to training mode (re-enables autograd) before the next update.
        agent.train()
        # wait=True defers the clear until the next output is produced, so the
        # latest report stays visible until the following validation.
        clear_output(wait=True)
        if t_mean>best_fc:
            best_fc = t_mean
            #torch.save(agent.model,"models/best_model.h5")
        
    



## Visualize and Record Gameplay

In [None]:
%%time
# Play 100 games with the current policy on a single environment and save one
# video per game under ./videos/.
env = SingleSnake(num_envs=1, size=10, observation_mode='one_channel', device= DEFAULT_DEVICE)
agent.evaluate()
PATH = os.getcwd()
# Make sure the output directory exists before recording (harmless if it
# already does).
video_dir = os.path.join(PATH, 'videos')
os.makedirs(video_dir, exist_ok=True)
for episode in range(100):
    # Start every episode from a fresh reset. Previously the env was reset only
    # once before this loop, so later episodes began from whatever state the
    # previous one ended in (the env is built here without an explicit
    # auto_reset argument).
    state = env.reset()
    fc_sum = 0
    recorder = VideoRecorder(env, path=os.path.join(video_dir, f'snake_{episode}.mp4'))
    try:
        #env.render()
        recorder.capture_frame()
        time.sleep(0.2)
        counter = 0
        while(1):
            counter+=1
            action = agent.action(state)
            next_state, reward, terminal, _ = env.step(action)
            # Any positive reward counts as one piece of food.
            fc_sum+= (reward>0).cpu().numpy()
            #env.render()
            recorder.capture_frame()
            #time.sleep(0.2)
            state = next_state
            # End the episode on termination or after 1000 steps.
            if terminal.all() or counter==1000:
                break
    finally:
        # Close the recorder even if stepping or frame capture raises.
        recorder.close()
    print("Completed:", terminal.any().cpu().numpy())
    print('Episode:', episode, 'Food Collected:', fc_sum)

#env.close()

## Computing Average Return

In [None]:
# Roll the current policy out on `num_envs` environments that do not reset on
# their own, and report how many positive-reward steps each one collected.
# NOTE(review): this builds a SimpleGridworld although the notebook is
# configured with environment = 'SingleSnake' — confirm this is the intended
# evaluation environment.
test_env = SimpleGridworld(
    num_envs=num_envs,
    size=10,
    observation_mode='one_channel',
    device=DEFAULT_DEVICE,
    auto_reset=False,
)
agent.evaluate()

t_state = test_env.reset()
# Per-environment count of food collected.
fc_sum = torch.zeros((num_envs,)).float().to(DEFAULT_DEVICE)

for steps in range(1000):  # hard cap on rollout length
    t_action = agent.action(t_state)
    t_next_state, t_reward, t_terminal, _ = test_env.step(t_action)
    # Every strictly positive reward is treated as one piece of food.
    fc_sum += t_reward.gt(0).float()
    t_state = t_next_state
    # Nothing left to simulate once every environment is terminal.
    if t_terminal.all():
        break

t_sum = fc_sum.cpu().numpy()
t_mean = np.mean(t_sum)
print("Completed:", t_terminal.sum().cpu().numpy())
stats = (t_mean, np.median(t_sum), np.max(t_sum), np.min(t_sum), np.std(t_sum))
print("Mean, Median, Max, Min, std:", *stats)