In [1]:
from collections import deque
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

import gym
from tqdm.notebook import tqdm
import numpy as np
from typing import NamedTuple
from itertools import chain

In [2]:
from policy_generator.policy_instances.envs.simple_arena import ActionSpace

In [3]:
#device to run model on 
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
class ObsSpace(NamedTuple):
    agent: np.ndarray
    agent_direction: int
    target: np.ndarray
    velocity: int

In [5]:
#Using a neural network to learn our policy parameters
class QLearningNetwork(nn.Module):
    #Takes in observations and outputs actions
    def __init__(self, observation_space, action_space, shape):
        super(QLearningNetwork, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(observation_space, shape),
            nn.ReLU(),
            nn.Linear(shape, shape),
            nn.ReLU(),
            nn.Linear(shape, shape),
            nn.ReLU(),
            nn.Linear(shape, action_space)
        )
    
    #forward pass
    def forward(self, x):
        return self.model(x)

class QBot:
    def __init__(self):
        self.model = None
        self.train_config = None
        
    def training_config(self, **kwargs):
        self.train_config = {
                'discount_factor': 0.923,
                'eps': 0.42,
                'eps_min': 0.075,
                'eps_decay': 0.96,
                'learning_rate': 0.053,
                'num_episodes': 50,
                'batch_size': 32,
                'network_shape': 128,
                }
        if kwargs:
            self.train_config.update(**kwargs)
            
#             b_params = {'discount_factor': 0.9232432057242249,
#                          'eps': 0.41987501329393667,
#                          'eps_min': 0.07487696385957002,
#                          'eps_decay': 0.9625856506789202,
#                          'learning_rate': 0.052947112503709155,
#                          'network_shape': 27}
            
    def _init_environment(self):
        #Make environment
        env = gym.make("policy_instances/SimpleArena-v0")

        network = QLearningNetwork(env.shape, env.action_space.n, self.train_config['network_shape']).to(DEVICE)

        loss_fn = nn.MSELoss()
        optimizer = optim.Adam(network.parameters(), lr=self.train_config['learning_rate'])
        return env, network, loss_fn, optimizer
    
    def train(self, verbose=0):
        env, network, loss_fn, optimizer = self._init_environment()
        
        scores = []

        memory = deque(maxlen=4000)

        for i in tqdm(range(self.train_config['num_episodes']), position=0, leave=True):
            eps = self.train_config['eps']
            state = env.reset()
            done = False
            eps *= self.train_config['eps_decay']
            score = 0
            actions_dist = []
            #while game not ended
            while not done:
                env.render()
                if isinstance(state, tuple):
                    state = state[0]
                #choose move with epsilon greedy
                if np.random.random() < eps:
                    #exploration
                    action = np.random.randint(0, env.action_space.n)
                else:
                    #exploitation
                    #use expand_dims here to add a dimension for input layer
                    action = select_action(network, np.expand_dims(state, axis=0))[0]
                    actions_dist.append(action)

                #execute move
                new_state, reward, done, _, _ = env.step(action)
                score += reward

                #memorize
                memory.append([np.expand_dims(state, axis=0), action, reward, np.expand_dims(new_state, axis=0), done])

                #update
                #instead of training every state, we train in batch_size
                if len(memory) > self.train_config['batch_size']:
                    #sample batch_size so model could be fit on any random states in memory not just the latest state
                    minibatch = random.sample(memory, self.train_config['batch_size'])
                    loss = 0

                    #iterate through the sampled batch
                    for b_state, b_action, b_reward, b_new_state, b_done in minibatch:
                        #if current game is done then target = reward cuz theres no future utility
                        if b_done:
                            target = b_reward
                        else:
                            #what we think the state's q_val should be, reward + discounted future utility
                            target = b_reward + self.train_config['discount_factor'] * select_action(network, b_new_state)[2]

                        #what we thought the current state's q_val should be
                        target_vector = select_action(network, b_state)[1]

                        #update the target_vector 
                        target_vector[0][b_action] = target

                        #instead of finding temporal difference between new q_val and old q_val, we train the model by giving it the new q_val
                        # and let the network do the updating 
                        #train the model with the batch
                        loss = loss_fn(select_action(network, b_state)[1], target_vector)
                        optimizer.zero_grad()
                        loss.backward()

                        step = optimizer.step()
#                     optimizer.zero_grad()
#                     loss.backward()

#                     step = optimizer.step()

                    #update epsilon
                    if eps > self.train_config['eps_min']:
                        eps *= self.train_config['eps_decay']

                #new state
                state = new_state
            if verbose > 0:
                print(score)
                print(list(zip(*np.unique(actions_dist, return_counts=True))))
            scores.append(score)
        return scores

            
def select_action(network, state):
    ''' Selects an action given current state
    Args:
    - network (Torch NN): network to process state
    - state (Array): Array of action space in an environment
    
    Return:
    - (int): action that is selected
    - (float): log probability of selecting that action given state and network
    '''
    #convert state to float tensor, add 1 dimension, allocate tensor on device
    state = ObsSpace(**state[0] if isinstance(state, (tuple, list, np.ndarray)) else state)
    unpack_state = list(chain(state.agent, state.target, [state.velocity, state.agent_direction]))
    state = torch.Tensor(unpack_state).float().unsqueeze(0).to(DEVICE)
    
    #use network to predict action probabilities
    q_vals = network(state)
    
    #sample an action using the probability distribution
    action = torch.argmax(q_vals)
    max_value = torch.max(q_vals)
    
    #return action
    return action.item(), q_vals, max_value

In [6]:
bot = QBot()
bot.training_config(**{"num_episodes": 200})
bot.train(1)

  0%|          | 0/200 [00:00<?, ?it/s]

-34
[(0, 6), (1, 40), (2, 7), (4, 29)]
26
[(0, 12), (1, 43), (2, 16), (3, 4), (4, 15)]
-36
[(0, 11), (1, 35), (2, 19), (3, 8), (4, 13)]
-28
[(0, 11), (1, 38), (2, 13), (3, 8), (4, 16)]
-40
[(0, 13), (1, 33), (2, 17), (3, 9), (4, 20)]


KeyboardInterrupt: 

In [16]:
import optuna

def objective(trial):
    bot = QBot()
    bot.training_config(**{
        'discount_factor': trial.suggest_float('discount_factor', 0.90, 0.99),
        'eps': trial.suggest_float('eps', 0.4, 0.6),
        'eps_min': trial.suggest_float('eps_min', 0.005, 0.1),
        'eps_decay': trial.suggest_float('eps_decay', 0.90, 0.99),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-1),
        'num_episodes': 20,
        'batch_size': 32,
        'network_shape': trial.suggest_int('network_shape', 16, 128),
                    })
    score = bot.train()
    return -1 * np.mean(score)

study = optuna.create_study()
study.optimize(objective, n_trials=100)

study.best_params

[32m[I 2022-09-19 12:08:52,772][0m A new study created in memory with name: no-name-36dc40b2-ab18-4de0-b16d-0e7bbcc87e24[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:12:12,123][0m Trial 0 finished with value: 333.0 and parameters: {'discount_factor': 0.961149283518651, 'eps': 0.5039940575634445, 'eps_min': 0.09187487028585081, 'eps_decay': 0.9368721683610204, 'learning_rate': 0.06780245382687256, 'network_shape': 46}. Best is trial 0 with value: 333.0.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:15:30,832][0m Trial 1 finished with value: 287.25 and parameters: {'discount_factor': 0.9204618347705573, 'eps': 0.5541758911752037, 'eps_min': 0.02966008613644321, 'eps_decay': 0.9052466827620687, 'learning_rate': 0.0023205613233050378, 'network_shape': 27}. Best is trial 1 with value: 287.25.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:18:15,788][0m Trial 2 finished with value: 290.15 and parameters: {'discount_factor': 0.9657599568701325, 'eps': 0.5760968127381298, 'eps_min': 0.022180455926810384, 'eps_decay': 0.9758470524431863, 'learning_rate': 0.06805540409416005, 'network_shape': 82}. Best is trial 1 with value: 287.25.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:21:02,719][0m Trial 3 finished with value: 357.2 and parameters: {'discount_factor': 0.9755657022088203, 'eps': 0.5026651397028264, 'eps_min': 0.08810196789618005, 'eps_decay': 0.9256995934260474, 'learning_rate': 0.06677159485268187, 'network_shape': 27}. Best is trial 1 with value: 287.25.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:23:44,732][0m Trial 4 finished with value: 291.35 and parameters: {'discount_factor': 0.9154736206094876, 'eps': 0.5282040482250558, 'eps_min': 0.09664420840470955, 'eps_decay': 0.9541866090127227, 'learning_rate': 0.04175005796072797, 'network_shape': 46}. Best is trial 1 with value: 287.25.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:26:16,115][0m Trial 5 finished with value: 243.5 and parameters: {'discount_factor': 0.9232432057242249, 'eps': 0.41987501329393667, 'eps_min': 0.07487696385957002, 'eps_decay': 0.9625856506789202, 'learning_rate': 0.052947112503709155, 'network_shape': 27}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:29:03,607][0m Trial 6 finished with value: 314.3 and parameters: {'discount_factor': 0.9868209047525768, 'eps': 0.5400615465493351, 'eps_min': 0.020663074849635824, 'eps_decay': 0.9246991413940536, 'learning_rate': 0.07993476848982771, 'network_shape': 67}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:31:55,259][0m Trial 7 finished with value: 306.7 and parameters: {'discount_factor': 0.9628895212745423, 'eps': 0.4041883969749285, 'eps_min': 0.06662333921525006, 'eps_decay': 0.9257158569469242, 'learning_rate': 0.08431223683141345, 'network_shape': 104}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:34:48,741][0m Trial 8 finished with value: 300.5 and parameters: {'discount_factor': 0.972701445799947, 'eps': 0.48564461978575524, 'eps_min': 0.05894362227766433, 'eps_decay': 0.980417836759394, 'learning_rate': 0.07235538608777568, 'network_shape': 100}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:37:37,294][0m Trial 9 finished with value: 284.8 and parameters: {'discount_factor': 0.9218981272306215, 'eps': 0.5274292596565957, 'eps_min': 0.024921208879632673, 'eps_decay': 0.9728961938094166, 'learning_rate': 0.050953253496474754, 'network_shape': 34}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:40:28,510][0m Trial 10 finished with value: 312.0 and parameters: {'discount_factor': 0.9382937119557917, 'eps': 0.40040910627871346, 'eps_min': 0.07079996153507256, 'eps_decay': 0.9565450689623964, 'learning_rate': 0.022674871534790836, 'network_shape': 67}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:43:06,484][0m Trial 11 finished with value: 259.55 and parameters: {'discount_factor': 0.9020403421283949, 'eps': 0.45260838531611824, 'eps_min': 0.04200571980399199, 'eps_decay': 0.9677347211244547, 'learning_rate': 0.04452318659690049, 'network_shape': 19}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:46:10,177][0m Trial 12 finished with value: 337.05 and parameters: {'discount_factor': 0.9024333296156373, 'eps': 0.4464886831248577, 'eps_min': 0.046468622859365386, 'eps_decay': 0.9627665705450368, 'learning_rate': 0.04202512330223472, 'network_shape': 18}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:49:27,838][0m Trial 13 finished with value: 280.15 and parameters: {'discount_factor': 0.900527639311822, 'eps': 0.44552892504072783, 'eps_min': 0.043007392903319334, 'eps_decay': 0.9898998637765115, 'learning_rate': 0.028205780539233324, 'network_shape': 123}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:52:34,812][0m Trial 14 finished with value: 285.8 and parameters: {'discount_factor': 0.9361690368371527, 'eps': 0.4424294212584185, 'eps_min': 0.007001599072509093, 'eps_decay': 0.9637444923924804, 'learning_rate': 0.09901368143375255, 'network_shape': 48}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:56:27,846][0m Trial 15 finished with value: 292.1 and parameters: {'discount_factor': 0.9106980188405855, 'eps': 0.4700349340460984, 'eps_min': 0.07631616611539907, 'eps_decay': 0.9442749857298575, 'learning_rate': 0.05496416902124114, 'network_shape': 17}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 12:59:51,773][0m Trial 16 finished with value: 308.55 and parameters: {'discount_factor': 0.9297243343899873, 'eps': 0.4258698824652537, 'eps_min': 0.038451238828430025, 'eps_decay': 0.9498841725936474, 'learning_rate': 0.026162876037118983, 'network_shape': 54}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:03:08,990][0m Trial 17 finished with value: 289.5 and parameters: {'discount_factor': 0.9106002695190426, 'eps': 0.4679101458391433, 'eps_min': 0.05521292736993094, 'eps_decay': 0.9878001447323808, 'learning_rate': 0.03730962742237273, 'network_shape': 33}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:06:45,301][0m Trial 18 finished with value: 505.8 and parameters: {'discount_factor': 0.9445001153833145, 'eps': 0.4294505544030271, 'eps_min': 0.08063505796740161, 'eps_decay': 0.9666171516290221, 'learning_rate': 0.05581853884426979, 'network_shape': 83}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:09:56,296][0m Trial 19 finished with value: 323.1 and parameters: {'discount_factor': 0.9275072657267993, 'eps': 0.5976543969842596, 'eps_min': 0.06347910566561282, 'eps_decay': 0.9416328870151569, 'learning_rate': 0.011219886830063683, 'network_shape': 59}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:12:53,189][0m Trial 20 finished with value: 299.9 and parameters: {'discount_factor': 0.952652728536153, 'eps': 0.420203049228468, 'eps_min': 0.05077711728165197, 'eps_decay': 0.9707611054548358, 'learning_rate': 0.03611895244501147, 'network_shape': 40}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:15:39,365][0m Trial 21 finished with value: 256.55 and parameters: {'discount_factor': 0.9000415814386163, 'eps': 0.456587989850518, 'eps_min': 0.04203199406175129, 'eps_decay': 0.9862980020057288, 'learning_rate': 0.02768486309277752, 'network_shape': 114}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:18:30,265][0m Trial 22 finished with value: 313.95 and parameters: {'discount_factor': 0.9072753662928454, 'eps': 0.4662708596168669, 'eps_min': 0.0373415484859001, 'eps_decay': 0.9813942079123562, 'learning_rate': 0.018245174690514857, 'network_shape': 126}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:21:26,770][0m Trial 23 finished with value: 267.8 and parameters: {'discount_factor': 0.9194533035760254, 'eps': 0.48399316210284793, 'eps_min': 0.03103786076313951, 'eps_decay': 0.9814570762289655, 'learning_rate': 0.04771409561624446, 'network_shape': 113}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:24:18,024][0m Trial 24 finished with value: 268.8 and parameters: {'discount_factor': 0.9012800331669323, 'eps': 0.4607826104685591, 'eps_min': 0.010070367208906791, 'eps_decay': 0.9596626127908405, 'learning_rate': 0.05996226492844958, 'network_shape': 95}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:27:22,986][0m Trial 25 finished with value: 270.05 and parameters: {'discount_factor': 0.9122928607921339, 'eps': 0.4183008390131988, 'eps_min': 0.05165645936944031, 'eps_decay': 0.9696403051422513, 'learning_rate': 0.03281426476286201, 'network_shape': 87}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:30:18,407][0m Trial 26 finished with value: 329.35 and parameters: {'discount_factor': 0.9264760890719698, 'eps': 0.4382793254203726, 'eps_min': 0.08330906650323627, 'eps_decay': 0.9765816949415721, 'learning_rate': 0.018684630620693386, 'network_shape': 16}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:33:09,362][0m Trial 27 finished with value: 294.5 and parameters: {'discount_factor': 0.9073267742267811, 'eps': 0.45487241225147934, 'eps_min': 0.04172621730019866, 'eps_decay': 0.952524432360578, 'learning_rate': 0.04595211942905824, 'network_shape': 111}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:36:10,264][0m Trial 28 finished with value: 305.25 and parameters: {'discount_factor': 0.9363572617069099, 'eps': 0.4824076024799333, 'eps_min': 0.07281555599798896, 'eps_decay': 0.9845004031163413, 'learning_rate': 0.010770537021020432, 'network_shape': 26}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:39:12,792][0m Trial 29 finished with value: 273.2 and parameters: {'discount_factor': 0.9165558538115819, 'eps': 0.4138259566930562, 'eps_min': 0.06102644011539608, 'eps_decay': 0.9342665323529074, 'learning_rate': 0.0352906528929673, 'network_shape': 76}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:42:17,268][0m Trial 30 finished with value: 302.6 and parameters: {'discount_factor': 0.9510967328014727, 'eps': 0.4928877523353116, 'eps_min': 0.01516365086271073, 'eps_decay': 0.9664775419927011, 'learning_rate': 0.06103910678061443, 'network_shape': 39}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:45:19,562][0m Trial 31 finished with value: 362.2 and parameters: {'discount_factor': 0.9203687768478876, 'eps': 0.47626870164185153, 'eps_min': 0.0345160393478457, 'eps_decay': 0.978564405653511, 'learning_rate': 0.04638630191703614, 'network_shape': 115}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:48:17,827][0m Trial 32 finished with value: 330.25 and parameters: {'discount_factor': 0.906950892101215, 'eps': 0.5160748490191691, 'eps_min': 0.032399134950359953, 'eps_decay': 0.9853657393778402, 'learning_rate': 0.04989973982223846, 'network_shape': 112}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:51:08,205][0m Trial 33 finished with value: 269.35 and parameters: {'discount_factor': 0.9186456770554765, 'eps': 0.45440240869106774, 'eps_min': 0.04703867333546632, 'eps_decay': 0.9775742659260322, 'learning_rate': 0.02994066704438118, 'network_shape': 92}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:54:01,782][0m Trial 34 finished with value: 271.45 and parameters: {'discount_factor': 0.9244072086056627, 'eps': 0.4322863092090061, 'eps_min': 0.029096730946900903, 'eps_decay': 0.9018239744600379, 'learning_rate': 0.04225950549883806, 'network_shape': 119}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:56:53,081][0m Trial 35 finished with value: 285.55 and parameters: {'discount_factor': 0.9317242332340007, 'eps': 0.5063001407175899, 'eps_min': 0.030275069540249062, 'eps_decay': 0.9720773327449259, 'learning_rate': 0.06692193918218131, 'network_shape': 104}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 13:59:50,120][0m Trial 36 finished with value: 314.5 and parameters: {'discount_factor': 0.9130708689244466, 'eps': 0.49067481249556977, 'eps_min': 0.09768047979610572, 'eps_decay': 0.9840742640187423, 'learning_rate': 0.05965018613555617, 'network_shape': 26}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:02:44,666][0m Trial 37 finished with value: 296.9 and parameters: {'discount_factor': 0.9054394862006312, 'eps': 0.45360723570576916, 'eps_min': 0.021080485690601693, 'eps_decay': 0.9587936078500843, 'learning_rate': 0.05058156949718323, 'network_shape': 76}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:05:36,986][0m Trial 38 finished with value: 346.95 and parameters: {'discount_factor': 0.9168175949072784, 'eps': 0.5642941292950694, 'eps_min': 0.08826504954950629, 'eps_decay': 0.9738840964345181, 'learning_rate': 0.00040500787233980345, 'network_shape': 128}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:08:41,844][0m Trial 39 finished with value: 312.65 and parameters: {'discount_factor': 0.9440328343442953, 'eps': 0.41109742633107915, 'eps_min': 0.02543615038650547, 'eps_decay': 0.9899236506984791, 'learning_rate': 0.07201471323649261, 'network_shape': 105}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:12:28,614][0m Trial 40 finished with value: 272.2 and parameters: {'discount_factor': 0.9219037269100138, 'eps': 0.5044273495452218, 'eps_min': 0.0562133102419478, 'eps_decay': 0.9826067046756578, 'learning_rate': 0.07937222676946318, 'network_shape': 59}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:15:37,896][0m Trial 41 finished with value: 312.75 and parameters: {'discount_factor': 0.9013431734907722, 'eps': 0.4615000093739598, 'eps_min': 0.005808852407811296, 'eps_decay': 0.9608798246848349, 'learning_rate': 0.062036986229461716, 'network_shape': 93}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:19:07,288][0m Trial 42 finished with value: 294.6 and parameters: {'discount_factor': 0.9057768529022464, 'eps': 0.4786465985430229, 'eps_min': 0.014514961619319842, 'eps_decay': 0.9517135204024784, 'learning_rate': 0.055750959847004185, 'network_shape': 97}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:22:46,074][0m Trial 43 finished with value: 343.8 and parameters: {'discount_factor': 0.900789493045145, 'eps': 0.45897788377538934, 'eps_min': 0.015252198615228724, 'eps_decay': 0.96749495696047, 'learning_rate': 0.040667975487691625, 'network_shape': 119}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:26:14,987][0m Trial 44 finished with value: 298.35 and parameters: {'discount_factor': 0.9137958022158995, 'eps': 0.4374551811633576, 'eps_min': 0.04625093647854328, 'eps_decay': 0.9477984192071232, 'learning_rate': 0.04805228255295631, 'network_shape': 108}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:29:48,947][0m Trial 45 finished with value: 317.55 and parameters: {'discount_factor': 0.9866778050962999, 'eps': 0.4737707657080069, 'eps_min': 0.04030676839287313, 'eps_decay': 0.9106943125753981, 'learning_rate': 0.06250407997686803, 'network_shape': 22}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:33:16,652][0m Trial 46 finished with value: 304.6 and parameters: {'discount_factor': 0.9040969723161207, 'eps': 0.4482009409729332, 'eps_min': 0.025854325849780018, 'eps_decay': 0.9560568050383423, 'learning_rate': 0.05365657595505758, 'network_shape': 33}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:36:34,176][0m Trial 47 finished with value: 313.35 and parameters: {'discount_factor': 0.9097249343003336, 'eps': 0.49259441403971, 'eps_min': 0.010723752278286593, 'eps_decay': 0.9623970759463727, 'learning_rate': 0.07055394991547854, 'network_shape': 100}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:39:20,009][0m Trial 48 finished with value: 293.8 and parameters: {'discount_factor': 0.9000826459380852, 'eps': 0.40648518877872186, 'eps_min': 0.03685284724522832, 'eps_decay': 0.9763056366298405, 'learning_rate': 0.07623923523708045, 'network_shape': 121}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:42:01,309][0m Trial 49 finished with value: 381.2 and parameters: {'discount_factor': 0.9091355282062648, 'eps': 0.42726773686503433, 'eps_min': 0.06474562555167401, 'eps_decay': 0.9576879150770636, 'learning_rate': 0.04310553268499151, 'network_shape': 115}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:44:51,687][0m Trial 50 finished with value: 299.2 and parameters: {'discount_factor': 0.9316101195480816, 'eps': 0.517870532049557, 'eps_min': 0.04404294110616024, 'eps_decay': 0.9386275836682137, 'learning_rate': 0.022351673440121702, 'network_shape': 68}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:47:43,037][0m Trial 51 finished with value: 312.2 and parameters: {'discount_factor': 0.9187592673143874, 'eps': 0.4493980049431178, 'eps_min': 0.048587882165265196, 'eps_decay': 0.9800207037408522, 'learning_rate': 0.0299396732238222, 'network_shape': 89}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:50:35,274][0m Trial 52 finished with value: 299.6 and parameters: {'discount_factor': 0.9162756494932406, 'eps': 0.46232285392364547, 'eps_min': 0.05636433031647497, 'eps_decay': 0.9743037324119515, 'learning_rate': 0.03242652930384, 'network_shape': 97}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:53:18,601][0m Trial 53 finished with value: 279.85 and parameters: {'discount_factor': 0.9230953650135364, 'eps': 0.44112393622353885, 'eps_min': 0.049671284593094056, 'eps_decay': 0.9646557603474704, 'learning_rate': 0.08956135245981461, 'network_shape': 88}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:55:47,375][0m Trial 54 finished with value: 254.35 and parameters: {'discount_factor': 0.9042163803103526, 'eps': 0.4831590258243272, 'eps_min': 0.04563629520407415, 'eps_decay': 0.9692136679088366, 'learning_rate': 0.039184918664191735, 'network_shape': 80}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 14:58:26,343][0m Trial 55 finished with value: 295.25 and parameters: {'discount_factor': 0.9023777662805514, 'eps': 0.48549984048866085, 'eps_min': 0.03380418503450741, 'eps_decay': 0.9681984433735727, 'learning_rate': 0.037329596348282715, 'network_shape': 82}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:01:05,297][0m Trial 56 finished with value: 284.0 and parameters: {'discount_factor': 0.9037450970614067, 'eps': 0.4711076221770268, 'eps_min': 0.04386861714430511, 'eps_decay': 0.9866478393706781, 'learning_rate': 0.05790413188998715, 'network_shape': 77}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:03:33,451][0m Trial 57 finished with value: 261.85 and parameters: {'discount_factor': 0.9117066033774038, 'eps': 0.5426877866983916, 'eps_min': 0.05438511146286216, 'eps_decay': 0.9715612808838497, 'learning_rate': 0.05258922134777687, 'network_shape': 51}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:06:24,961][0m Trial 58 finished with value: 316.05 and parameters: {'discount_factor': 0.912719587973138, 'eps': 0.5543994361013974, 'eps_min': 0.05333762243865384, 'eps_decay': 0.9711386427290304, 'learning_rate': 0.039747862506020304, 'network_shape': 45}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:09:17,079][0m Trial 59 finished with value: 339.55 and parameters: {'discount_factor': 0.9083720347241839, 'eps': 0.5337757470216922, 'eps_min': 0.03909478304643466, 'eps_decay': 0.9804327099374831, 'learning_rate': 0.04560285491115749, 'network_shape': 52}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:11:57,464][0m Trial 60 finished with value: 330.15 and parameters: {'discount_factor': 0.9099906593143684, 'eps': 0.5960002381134404, 'eps_min': 0.05968865570212383, 'eps_decay': 0.9745887068550294, 'learning_rate': 0.05258845236163396, 'network_shape': 21}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:14:42,331][0m Trial 61 finished with value: 299.9 and parameters: {'discount_factor': 0.9044094889239448, 'eps': 0.49856878967972684, 'eps_min': 0.06906760143804301, 'eps_decay': 0.9644345921668364, 'learning_rate': 0.06379852477446413, 'network_shape': 29}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:17:25,010][0m Trial 62 finished with value: 269.85 and parameters: {'discount_factor': 0.9153903089615268, 'eps': 0.46777712810407096, 'eps_min': 0.0758048511826406, 'eps_decay': 0.9602905304430747, 'learning_rate': 0.04903311066969456, 'network_shape': 71}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:20:19,292][0m Trial 63 finished with value: 299.45 and parameters: {'discount_factor': 0.9001025964874425, 'eps': 0.5477350855574329, 'eps_min': 0.029233655802492115, 'eps_decay': 0.9550558965180961, 'learning_rate': 0.05743921278037831, 'network_shape': 62}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:23:08,285][0m Trial 64 finished with value: 278.9 and parameters: {'discount_factor': 0.9119250092596446, 'eps': 0.5135788259796936, 'eps_min': 0.0178849925448216, 'eps_decay': 0.967347193983475, 'learning_rate': 0.0447129299739174, 'network_shape': 39}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:25:53,001][0m Trial 65 finished with value: 360.9 and parameters: {'discount_factor': 0.9066557258796758, 'eps': 0.4806790789916853, 'eps_min': 0.09480094650007655, 'eps_decay': 0.9697696695964584, 'learning_rate': 0.025694129914588446, 'network_shape': 23}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:28:40,265][0m Trial 66 finished with value: 319.65 and parameters: {'discount_factor': 0.9046245306377942, 'eps': 0.4981727141745808, 'eps_min': 0.05348020487277615, 'eps_decay': 0.9782698155315969, 'learning_rate': 0.03355224562060747, 'network_shape': 83}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:31:48,866][0m Trial 67 finished with value: 276.25 and parameters: {'discount_factor': 0.9261019602719679, 'eps': 0.43331405277915286, 'eps_min': 0.036167798999535705, 'eps_decay': 0.9728848836178643, 'learning_rate': 0.038909634205955, 'network_shape': 43}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:35:29,622][0m Trial 68 finished with value: 309.45 and parameters: {'discount_factor': 0.9603990100126655, 'eps': 0.42079334318356193, 'eps_min': 0.010017561307250462, 'eps_decay': 0.9874174851164934, 'learning_rate': 0.053270204652352264, 'network_shape': 108}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:38:47,975][0m Trial 69 finished with value: 247.45 and parameters: {'discount_factor': 0.9188465738665691, 'eps': 0.5777231998722732, 'eps_min': 0.0413751696725178, 'eps_decay': 0.9818785808150706, 'learning_rate': 0.06498457067052898, 'network_shape': 30}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:42:14,470][0m Trial 70 finished with value: 283.15 and parameters: {'discount_factor': 0.9297355436246113, 'eps': 0.5681470393790147, 'eps_min': 0.04302313363039839, 'eps_decay': 0.9826473178739942, 'learning_rate': 0.06611018393059111, 'network_shape': 30}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:45:45,829][0m Trial 71 finished with value: 307.4 and parameters: {'discount_factor': 0.9189630211389542, 'eps': 0.5729470788473932, 'eps_min': 0.0817187388082237, 'eps_decay': 0.9648568084559452, 'learning_rate': 0.058327836576360634, 'network_shape': 35}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:49:17,663][0m Trial 72 finished with value: 288.9 and parameters: {'discount_factor': 0.914134238096938, 'eps': 0.5823892771299087, 'eps_min': 0.040584377861766455, 'eps_decay': 0.9766657068125331, 'learning_rate': 0.06514569143633744, 'network_shape': 18}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:52:48,091][0m Trial 73 finished with value: 305.3 and parameters: {'discount_factor': 0.9075000109242606, 'eps': 0.5852530894906038, 'eps_min': 0.04719503183017766, 'eps_decay': 0.9853302235071416, 'learning_rate': 0.06958348445601502, 'network_shape': 50}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:56:02,515][0m Trial 74 finished with value: 278.15 and parameters: {'discount_factor': 0.9029241015356163, 'eps': 0.526529435900279, 'eps_min': 0.032874792959691986, 'eps_decay': 0.9808585187176648, 'learning_rate': 0.05122393613065901, 'network_shape': 26}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 15:59:35,586][0m Trial 75 finished with value: 274.85 and parameters: {'discount_factor': 0.9112573820956965, 'eps': 0.4653945781053541, 'eps_min': 0.04496400999042894, 'eps_decay': 0.9705303667350929, 'learning_rate': 0.048113939034630016, 'network_shape': 35}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[32m[I 2022-09-19 16:02:46,627][0m Trial 76 finished with value: 265.9 and parameters: {'discount_factor': 0.9402241326062796, 'eps': 0.5424744667471213, 'eps_min': 0.022253502283242216, 'eps_decay': 0.961973286393732, 'learning_rate': 0.07479473831652023, 'network_shape': 79}. Best is trial 5 with value: 243.5.[0m


  0%|          | 0/20 [00:00<?, ?it/s]

[33m[W 2022-09-19 16:03:39,184][0m Trial 77 failed because of the following error: KeyboardInterrupt()[0m
Traceback (most recent call last):
  File "/home/beast/.local/lib/python3.8/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_685/1136845993.py", line 15, in objective
    score = bot.train()
  File "/tmp/ipykernel_685/1305696908.py", line 117, in train
    loss.backward()
  File "/home/beast/.local/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/beast/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
KeyboardInterrupt


KeyboardInterrupt: 

In [17]:
study.best_params

{'discount_factor': 0.9232432057242249,
 'eps': 0.41987501329393667,
 'eps_min': 0.07487696385957002,
 'eps_decay': 0.9625856506789202,
 'learning_rate': 0.052947112503709155,
 'network_shape': 27}

In [3]:
discount_factor = 0.95
eps = 0.5
eps_min = 0.01
eps_decay = 0.99
learning_rate = 0.8
num_episodes = 50
batch_size = 32

In [5]:
#Using a neural network to learn our policy parameters
class QLearningNetwork(nn.Module):
    
    #Takes in observations and outputs actions
    def __init__(self, observation_space, action_space):
        super(QLearningNetwork, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(observation_space, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, action_space)
        )
    
    #forward pass
    def forward(self, x):
        return self.model(x)

In [6]:
def select_action(network, state):
    ''' Selects an action given current state
    Args:
    - network (Torch NN): network to process state
    - state (Array): Array of action space in an environment
    
    Return:
    - (int): action that is selected
    - (float): log probability of selecting that action given state and network
    '''
    #convert state to float tensor, add 1 dimension, allocate tensor on device
    state = ObsSpace(**state[0] if isinstance(state, (tuple, list, np.ndarray)) else state)
    unpack_state = list(chain(state.agent, state.target, [state.velocity, state.agent_direction]))
    state = torch.Tensor(unpack_state).float().unsqueeze(0).to(DEVICE)
    
    #use network to predict action probabilities
    q_vals = network(state)
    
    #sample an action using the probability distribution
    action = torch.argmax(q_vals)
    max_value = torch.max(q_vals)
    
    #return action
    return action.item(), q_vals, max_value

In [7]:
#Make environment
env = gym.make("policy_instances/SimpleArena-v0")

network = QLearningNetwork(env.shape, env.action_space.n).to(DEVICE)

loss_fn = nn.MSELoss()
optimizer = optim.Adam(network.parameters(), lr=1e-3)

In [11]:
def check_availability():
    state = env.reset()
    print(select_action(network, state))
check_availability()

(1, tensor([[-49.4963, -48.6610, -49.2828, -51.9491, -49.0963]], device='cuda:0',
       grad_fn=<AddmmBackward0>), tensor(-48.6610, device='cuda:0', grad_fn=<MaxBackward1>))


In [9]:
scores = []

memory = deque(maxlen=4000)

for i in tqdm(range(num_episodes), position=0, leave=True):
    state = env.reset()
    done = False
    eps *= eps_decay
    score = 0
    actions_dist = []
    #while game not ended
    while not done:
        env.render()
        if isinstance(state, tuple):
            state = state[0]
        #choose move with epsilon greedy
        if np.random.random() < eps:
            #exploration
            action = np.random.randint(0, env.action_space.n)
        else:
            #exploitation
            #use expand_dims here to add a dimension for input layer
#             q_vals = model(state_torch)#.to('cpu').detach().numpy()
#             action = torch.argmax(q_vals).item()
            action = select_action(network, np.expand_dims(state, axis=0))[0]
            actions_dist.append(action)
        
        #execute move
        new_state, reward, done, _, _ = env.step(action)
        score += reward
        
        #modify reward so it scales with pole angle. Pole angle range [-0.418, 0.418]
        # reward = 1 - abs(state[2])/0.418
        
        #memorize
        memory.append([np.expand_dims(state, axis=0), action, reward, np.expand_dims(new_state, axis=0), done])
        
        #update
        #instead of training every state, we train in batch_size
        if len(memory) > batch_size:
            #sample batch_size so model could be fit on any random states in memory not just the latest state
            minibatch = random.sample(memory, batch_size)
            loss = 0
            
            #iterate through the sampled batch
            for b_state, b_action, b_reward, b_new_state, b_done in minibatch:
                #if current game is done then target = reward cuz theres no future utility
                if b_done:
                    target = b_reward
                else:
                    #what we think the state's q_val should be, reward + discounted future utility
                    target = b_reward + discount_factor * select_action(network, b_new_state)[2]
                
                #what we thought the current state's q_val should be
                target_vector = select_action(network, b_state)[1]
                
                #update the target_vector 
                target_vector[0][b_action] = target
                
                #instead of finding temporal difference between new q_val and old q_val, we train the model by giving it the new q_val
                # and let the network do the updating 
                #train the model with the batch
                loss += loss_fn(select_action(network, b_state)[1], target_vector)
                # optimizer.zero_grad()
                # loss.backward()
                
                # step = optimizer.step()
            optimizer.zero_grad()
            loss.backward()

            step = optimizer.step()

            #update epsilon
            if eps > eps_min:
                eps *= eps_decay
                
        #new state
        state = new_state
    print(score)
    print(list(zip(*np.unique(actions_dist, return_counts=True))))
    scores.append(score)

  0%|          | 0/50 [00:00<?, ?it/s]

  logger.warn(f"{pre} is not within the observation space.")


-379
[(0, 27), (1, 6), (2, 31)]
-511
[(0, 11), (1, 9), (2, 63)]
-429
[(0, 2), (3, 6), (4, 89)]
-446
[(0, 1), (4, 98)]
-380
[(0, 14), (1, 23), (2, 14), (3, 2), (4, 47)]
-420
[(0, 2), (1, 8), (2, 17), (3, 54), (4, 18)]
-385
[(0, 18), (1, 15), (2, 22), (3, 1), (4, 45)]
-310
[(0, 23), (1, 20), (2, 26), (3, 7), (4, 25)]
-375
[(0, 10), (1, 2), (2, 25), (3, 26), (4, 37)]
-377
[(0, 9), (1, 9), (2, 30), (3, 28), (4, 25)]
-376
[(0, 3), (1, 4), (2, 23), (3, 27), (4, 43)]
-392
[(0, 13), (1, 7), (2, 37), (3, 11), (4, 33)]
-429
[(0, 17), (1, 2), (2, 46), (3, 9), (4, 26)]
-412
[(0, 15), (1, 4), (2, 43), (3, 22), (4, 15)]
-361
[(0, 13), (1, 15), (2, 22), (3, 7), (4, 43)]
-308
[(0, 34), (1, 13), (2, 16), (3, 15), (4, 22)]
-370
[(0, 18), (1, 4), (2, 8), (3, 50), (4, 20)]
-383
[(0, 11), (1, 4), (2, 18), (3, 22), (4, 46)]
-385
[(0, 12), (1, 5), (2, 10), (3, 8), (4, 66)]
-379
[(0, 13), (1, 14), (2, 28), (3, 5), (4, 41)]
-377
[(0, 12), (1, 3), (2, 14), (3, 20), (4, 51)]
-319
[(0, 15), (1, 6), (2, 20), (3, 3

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

plt.plot(scores)
plt.ylabel('score')
plt.xlabel('episodes')
plt.title('Score of RL Agent over episodes')

reg = LinearRegression().fit(np.arange(len(scores)).reshape(-1, 1), np.array(scores).reshape(-1, 1))
y_pred = reg.predict(np.arange(len(scores)).reshape(-1, 1))
plt.plot(y_pred)

In [None]:
scores = []
while len(scores) < 50:
    state = env.reset()
    done = False
    score = 0
    while not done:
        env.render()
        action = select_action(network, state)[0]
        print(action)

        new_state, reward, done, _,_ = env.step(action)
        score += reward
        state = new_state
    scores.append(score)

In [None]:
np.array(scores).mean()