In [1]:
import numpy as np
import psutil
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Adam
from torch.distributions import Categorical
from tqdm import tqdm_notebook as tqdm

import sys; sys.path.append("../screeps_rl_env")
from screeps_rl_env import ScreepsEnv
from screeps_rl_env.utils import kill_backend_processes

In [2]:
class SimplePolicy(nn.Module):

    def __init__(self, env, H = 30, gamma = 0.99):
        super().__init__()

        in_dim = 4  # two set of (x, y) coords
        out_dim = 8  # can move in 8 directions

        self.linear1 = torch.nn.Linear(in_dim, H)
        self.linear2 = torch.nn.Linear(H, out_dim)

        self.gamma = gamma

        # Episode policy and reward history
        self.policy_history = Variable(torch.Tensor())
        self.reward_episode = []
        
        # Overall reward and loss history
        self.reward_history = []
        self.loss_history = []
        
    def select_action(self, state):
#         state = torch.from_numpy(state).type(torch.FloatTensor)
        state = torch.FloatTensor(state)
        state = self.forward(state)
        c = Categorical(state)
        action = c.sample()
        log_prob = c.log_prob(action).unsqueeze(0)
        
        # Add log probability of our chosen action to our history    
        if torch.numel(self.policy_history) == 0:
            self.policy_history = log_prob
        else:
            self.policy_history = torch.cat((self.policy_history, log_prob))
        return action
        
    def update_policy(self):
        R = 0
        rewards = []
        
        # Discount future rewards back to the present using gamma
        for r in self.reward_episode[::-1]:
            R = r + self.gamma * R
            rewards.insert(0,R)
            
        # Scale rewards
        rewards = torch.FloatTensor(rewards)
        rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
        
        # Calculate loss
        loss = torch.sum(torch.mul(self.policy_history, Variable(rewards)).mul(-1), -1)
        
        # Update network weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        #Save and intialize episode history counters
        self.loss_history.append(loss.item())
        self.reward_history.append(np.sum(self.reward_episode))
        self.policy_history = Variable(torch.Tensor())
        self.reward_episode = []
    
    def forward(self, x):
        """Returns a size-8 vector of one-hot probabilities to move in whichever direction"""
        out = self.linear1(x)
        out = nn.ReLU()(out)
        out = self.linear2(out)
        out = nn.Softmax(dim=0)(out)
        return out


In [3]:
def train(episodes):
    running_reward = 10
    
    for episode in range(episodes):
        
        print(f"Starting episode {episode}")
        
        state = env.reset() # Reset environment and record the starting state
    
        iterator = tqdm(range(250))
        for time in iterator:
            
            action = policy.select_action(state)
            
            # Step through environment using chosen action
            state, reward, done, _ = env.step(action.item())
#             print(state)
            # Save reward
            policy.reward_episode.append(reward)
            if done:
                break
                
            iterator.set_description("â„›={:.2f}".format(reward), refresh=False)
        
        # Used to determine when the environment is solved.
        running_reward = (running_reward * 0.99) + (time * 0.01)

        policy.update_policy()

        if episode % 50 == 0:
            print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(episode, time, running_reward))

In [None]:
# Hyperparameters
learning_rate = 0.01

env = ScreepsEnv(index=0)
policy = SimplePolicy(env)
optimizer = Adam(policy.parameters(), lr=learning_rate)

train(10)

None
starting interface with index 0
Starting remote server at 22025...
Resetting training environment
Starting episode 0
Resetting training environment


HBox(children=(IntProgress(value=0, max=250), HTML(value='')))


Episode 0	Last length:   249	Average length: 12.39
Starting episode 1
Resetting training environment


HBox(children=(IntProgress(value=0, max=250), HTML(value='')))

In [6]:
env.close()

Stopping
Exiting
Polling
Response: 0


In [9]:
env = ScreepsEnv(index=0)

None
starting interface with index 0
Starting remote server at 22025...
Connected; response: [None]


In [31]:
c.tick()

LostRemote: Lost remote after 10s heartbeat

In [10]:
env.close()

Stopping
Exiting
Polling
Response: None
