In [1]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Categorical
import torchvision
import torchvision.transforms as transforms
import numpy as np
from Simulation import Level
import math

# Two hand-built level layouts used to test policy learning:
#   L1 -- a small, simple level (7 rows);
#   L2 -- a longer, more complex level (13 rows).
# Each layout is a list of equal-width (5-character) strings, one per row.
# NOTE(review): the tile legend is defined by Simulation.Level, not here;
# presumably '1' is solid ground, '0' is a gap and 'X' marks the agent's
# start -- confirm against Simulation.Level.
L1 = ["11111",
     "11011",
     "11111",
     "11111",
     "11X11",
     "11111",
     "11111"]

L2 = [
     "11111",
     "11011",
     "11011",
     "11011",
     "11011",
     "11011",
     "11011",
     "11001",
     "11111",
     "11111",
     "11X11",
     "11111",
     "11111"]

# Wrap the raw layouts in Simulation.Level objects.
simple_level = Level(L1)
complex_level = Level(L2)

In [2]:
# Render both layouts to sanity-check how Level parsed them.
for lvl in (simple_level, complex_level):
    print(lvl)

11111
11011
11111
11111
11X11
11111
11111

11111
11011
11011
11011
11011
11011
11011
11001
11111
11111
11X11
11111
11111



In [3]:
# ---- Hyperparameters -------------------------------------------------
# Network shape
agent_view = 5 * 5 * 3   # 75 inputs; presumably a 5x5 window x 3 features -- confirm with Level.getVector()
agent_choices = 8        # number of discrete actions (one per ActDictionary entry)
hidden_size = 128        # width of both hidden layers

# Optimisation
learning_rate, gamma = 1e-3, 0.99   # Adam step size; reward discount factor
dropout_prob = 0                    # 0 disables dropout

# Exploration / bookkeeping -- main() re-anneals epsilon and advances episodeNumber
epsilon = 0.1
episodeNumber = 0

class Policy(nn.Module):
    """Two-hidden-layer MLP mapping an observation vector to action probabilities.

    Besides the network, the module carries the REINFORCE bookkeeping used by
    ``select_action`` and ``update_policy``:

    * ``policy_history`` -- 1-D tensor of log-probabilities of the actions
      taken in the current episode (appended to by ``select_action``).
    * ``reward_episode`` -- list of per-step rewards for the current episode.
    * ``reward_history`` / ``loss_history`` -- one entry per finished episode.
    """

    def __init__(self):
        super(Policy, self).__init__()
        self.state_space = agent_view  # input vector length
        self.action_space = agent_choices  # number of discrete actions

        # Layers: state -> hidden -> hidden -> action scores
        self.l1 = nn.Linear(self.state_space, hidden_size, bias=True)
        self.l2 = nn.Linear(hidden_size, hidden_size, bias=True)
        self.l3 = nn.Linear(hidden_size, self.action_space, bias=False)

        self.gamma = gamma  # discount factor read by update_policy

        # Per-episode history (cleared by update_policy after each episode)
        self.policy_history = torch.Tensor()
        self.reward_episode = []
        # Per-run history (one entry per episode)
        self.reward_history = []
        self.loss_history = []

    def forward(self, x):
        """Return action probabilities (softmax over the last dimension) for ``x``.

        Layer order is Linear -> Dropout -> SELU (twice), then Linear -> Softmax.
        Dropout goes through the functional API with ``training=self.training``
        so that ``policy.eval()`` actually turns it off; building fresh
        ``nn.Dropout`` modules on every call would leave them permanently in
        training mode.
        """
        x = F.selu(F.dropout(self.l1(x), p=dropout_prob, training=self.training))
        x = F.selu(F.dropout(self.l2(x), p=dropout_prob, training=self.training))
        return F.softmax(self.l3(x), dim=-1)
        

In [4]:
# A single global policy network and its Adam optimiser; select_action,
# update_policy and main all read these globals.
policy = Policy()
optimizer = optim.Adam(params=policy.parameters(), lr=learning_rate)

In [5]:
def select_action(state):
    """Choose an action for ``state`` with epsilon-greedy exploration.

    With probability ``epsilon`` (a module-level global that ``main`` anneals)
    the action is drawn uniformly from all ``policy.action_space`` actions;
    otherwise it is sampled from the policy's output distribution. In both
    cases the log-probability of the chosen action *under the policy* is
    appended to ``policy.policy_history`` for the REINFORCE update.

    Args:
        state: array-like observation vector (converted to float32).

    Returns:
        int: index of the chosen action.
    """
    state = torch.as_tensor(state, dtype=torch.float32)
    dist = Categorical(policy(state))

    if random.random() < epsilon:
        # Uniform exploration over however many actions the policy has,
        # rather than a hard-coded 8-entry probability table.
        action = torch.tensor(random.randrange(policy.action_space))
    else:
        action = dist.sample()

    # log_prob of a single (0-d) action is 0-d; view(1) makes it concatenable
    # onto the 1-D history (which starts out empty).
    log_prob = dist.log_prob(action).view(1)
    policy.policy_history = torch.cat([policy.policy_history, log_prob])

    return int(action)

In [6]:
# Monte-Carlo policy gradient (REINFORCE): after each episode, push up the
# log-probability of actions in proportion to their standardised return.
def update_policy():
    """Apply one REINFORCE update for the episode that just finished.

    Computes discounted returns ``G_t = r_t + gamma * G_{t+1}`` from
    ``policy.reward_episode``, standardises them, and minimises
    ``loss = -sum_t log pi(a_t | s_t) * G_t`` using the log-probabilities
    stored in ``policy.policy_history``. Afterwards the episode's total reward
    and loss are appended to ``policy.reward_history`` / ``policy.loss_history``
    and the per-episode buffers are cleared.

    Episodes with fewer than two steps are recorded (loss 0.0) but not trained
    on: the standard deviation of a single return is NaN, and stepping the
    optimiser with a NaN loss would turn every network weight into NaN.
    """
    # Discount future rewards back to the present, walking the episode backwards.
    returns = []
    running = 0.0
    for r in reversed(policy.reward_episode):
        running = r + policy.gamma * running
        returns.append(running)
    returns.reverse()
    returns = torch.tensor(returns, dtype=torch.float32)

    if returns.numel() > 1:
        # Standardise returns. eps is float32's so that a near-zero std cannot
        # inflate float32 rounding noise by ~1e16 (float64 eps is far too small).
        eps = np.finfo(np.float32).eps
        returns = (returns - returns.mean()) / (returns.std() + eps)

        loss = -(policy.policy_history * returns).sum()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
    else:
        # No meaningful standardised return: skip the gradient step.
        loss_value = 0.0

    policy.loss_history.append(loss_value)

    # Record the episode total and reset the per-episode buffers.
    policy.reward_history.append(np.sum(policy.reward_episode))
    policy.policy_history = torch.Tensor()
    policy.reward_episode = []

In [7]:
def rfunc0(x, steps, done):
    """Sparse terminal reward plus a small shaping term.

    Terminal steps score +1000 when ``x <= 0`` and -1000 otherwise; every
    non-terminal step scores ``1 / (x + 2)``. ``steps`` is accepted for
    interface compatibility but ignored.
    """
    if not done:
        return 1/(x+2)
    return 1000 if x <= 0 else -1000


def rfunc1(x, steps, done):
    """Linear progress reward ``20 - 3x`` with a step-dependent terminal bonus.

    On the terminal step the reward gains ``40 - steps`` when ``x <= 0`` and
    loses ``40 + steps`` otherwise; non-terminal steps get the base term only.
    """
    base = 20 - x*3
    if not done:
        return base
    if x <= 0:
        return base + (40 - steps)
    return base - (40 + steps)

def rfunc2(x, steps, done):
    """Dense shaping reward ``1 / (x + 2)``; ``steps`` and ``done`` are ignored."""
    denominator = x + 2
    return 1 / denominator

def rfunc3(x, steps, done):
    """Uniform random reward in [0, 1) -- a baseline that ignores all inputs."""
    noise = random.random()
    return noise

def rfunc4(x, steps, done):
    """Linear reward ``5 - x``; ``steps`` and ``done`` are ignored."""
    remaining = 5 - x
    return remaining

# Level trained on by main(). NOTE(review): max_reward is not referenced by
# any code in this notebook -- confirm whether it is still needed.
level, max_reward = complex_level, 1

# Action index -> short label used when logging episodes. Presumably the first
# letter is S(tep)/J(ump) and the second the direction U/L/R/D -- confirm
# against Simulation.Level.Act.
ActDictionary = dict(enumerate(["SU", "SL", "SR", "SD", "JU", "JL", "JR", "JD"]))

def main(episodes, reward_fn=None):
    """Train ``policy`` on ``level`` for ``episodes`` episodes with REINFORCE.

    Each episode resets the level, then repeatedly picks an action with
    ``select_action``, applies it with ``level.Act`` and stores the reward
    until the level reports ``done``; ``update_policy`` is then called once.
    Exploration ``epsilon`` is annealed as ``1 / (log2(episodeNumber + 1) + 1)``,
    where the global ``episodeNumber`` keeps counting across calls.

    Args:
        episodes: number of episodes to run.
        reward_fn: callable ``(x, steps, done) -> reward`` applied to the values
            returned by ``level.Act``; defaults to ``rfunc0``.
    """
    global episodeNumber
    global epsilon
    if reward_fn is None:
        reward_fn = rfunc0

    for _ in range(episodes):
        episodeNumber += 1
        epsilon = 1 / (math.log(episodeNumber + 1, 2) + 1)
        level.Reset()
        actions_taken = []
        done = False
        while not done:
            state = np.asarray(level.getVector())
            # select_action returns a plain int, which has no .item();
            # int() also accepts a 0-d tensor should it ever return one.
            action = int(select_action(state))
            x, steps, done = level.Act(action)
            policy.reward_episode.append(reward_fn(x, steps, done))
            actions_taken.append(ActDictionary[action])
            if x <= 0:
                print("\nReached the end!", end=" ")
        update_policy()
        print("Episode Done!")
        print(actions_taken)

In [8]:
main(episodes=1000)

Episode Done!
['JR', 'SR']
Episode Done!
['JL', 'SR', 'JU', 'SR', 'SL', 'SR', 'SD', 'JR', 'SR']
Episode Done!
['SU', 'JL', 'JR', 'JL', 'SD', 'SU', 'JU', 'JR']
Episode Done!
['JU', 'SL', 'SU', 'JL']
Episode Done!
['JR', 'JD', 'JL', 'JU', 'SU', 'JR', 'SD', 'SL', 'SR', 'SU', 'SL', 'JL', 'JR', 'JR']
Episode Done!
['SR', 'SU', 'SL', 'JR', 'JU', 'SD', 'JU', 'SR']
Episode Done!
['SU', 'SR', 'SR', 'SU', 'SU', 'SD', 'SD', 'SR']
Episode Done!
['JU', 'SL', 'SD', 'SD', 'JR', 'SU', 'JR']
Episode Done!
['JU', 'JL', 'JR', 'SU']
Episode Done!
['SD', 'JL', 'SL']
Episode Done!
['JR', 'SD', 'SU', 'JL', 'SD', 'JD']
Episode Done!
['JU', 'SL', 'SD', 'JR', 'JU']
Episode Done!
['SR', 'JR']
Episode Done!
['JL', 'SD', 'JU', 'JL']
Episode Done!
['JR', 'JU', 'SU', 'JL']
Episode Done!
['JU', 'JU']
Episode Done!
['JD', 'SD']
Episode Done!
['JR', 'JD', 'JU', 'JD', 'SL', 'JD']
Episode Done!
['SU', 'JU']
Episode Done!
['SU', 'SR', 'JD', 'JU', 'JU']
Episode Done!
['JL', 'SD', 'SU', 'SD', 'SR', 'SR', 'SL', 'SD', 'JU', '

Episode Done!
['SR', 'SL', 'SR', 'SL', 'JR', 'SL', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR']
Episode Done!
['SL', 'SL', 'SR', 'SR', 'JR', 'SL', 'JL', 'SR', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SU', 'JR']
Episode Done!
['SR', 'SL', 'JL', 'SR', 'SR', 'JR', 'SL', 'SL', 'SR', 'SL', 'SR', 'JL', 'SR', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SU', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL']
Episode Done!
['SR', 'SL', 'JD', 'SU', 'SR', 'SL', 'JR', 'SL', 'JU', 'SL', 'SR', 'SL', 'SR', 'JR']
Episode Done!
['SL', 'SR', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'SU', 'JR', 'SL', 'SL', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', 'JR', 'SL', 'JR']
Episode Done!
['JU', 'SL', 'SL', 'SR', 'SR', 'SL', 'JU', 'SL', 'SR', 'JU', 'SD', 'SL', 'SR', 'SL', 'SL']
Episode Done!
['SL', 'SR', 'SL', 'SR', 'JL', 'SR', 'SR', 'JR', 'SL', 'SL', 'SR', 'SL', 'SL', 'SR', 'SR', 'SL', 'SR', 'SL', 'SR', 'SL', '

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'SU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JD', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SD', 'SL', 'JU', 'JU', 'JU', 'JU', 'JU', 'JU']
Episode Done!
['SL', 'JU', 'JL']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'J

Episode Done!
['JR', 'SL', 'SL', 'SL', 'JU', 'JU', 'JL']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']
Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JL']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'SR', 'SL', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JD', 'JU', 'JR', 'SL', 'SL', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'SL', 'JU']

Reached t

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']
Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'SR']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'SD', 'JU', 'J

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']
Episode Done!
['SL', 'JU', 'JU', 'JR', 'SL']

Reached the end! Episode Done!
['JU', 'SL', 'SU', 'JU', 'JU', 'JU', 'JU']
Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'SR']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'SR', 'SL', 'JU', 'JU', 'JU', 'JU']
Episode Done!
['SL', 'JU', 'JL']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'SL', 'JU', 'JU']

Reached the end! Ep

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'SL', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JR', 'SL', 'SL', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['JL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'SD', 'JU', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']
Episode Done!
['SL', 'JU', 'SR', 'SU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'SD', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']
Episode Done!
['SL', 'JU', 'JU', 'JR', 'SL']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached

Reached the end! Episode Done!
['JR', 'SL', 'SL', 'SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'SD', 'JU', 'JU', 'JU', 'JU', 'JU', 'JU']
Episode Done!
['SL', 'JU', 'JU', 'JR', 'SL']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JR', 'SL', 'SL', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']
Episode Done!
['SL', 'JU', 'JU', 'SU', 'SR']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['SL', 'JU', 'JU', 'JU', 'JU', 'JU']

Reached the end! Episode Done!
['S