In [None]:
import gym
from gym.spaces import Discrete, Box

import numpy as np

In [None]:
def next_state(items, state, action):
    """Return the item counts that result from applying ``action`` to ``state``.

    Actions are encoded as ``2 * item_index + direction``: an odd action adds one
    of item ``action // 2``, an even action removes one. An action whose item
    index falls outside ``items`` leaves the counts unchanged. This covers the
    extra last action ``2 * len(items)`` of the Knapsack action space, and
    negative values.

    ``state`` is NOT modified; a new list is returned. Observations already
    handed out (e.g. by ``Knapsack.reset``/``step``) or held by the caller are
    therefore never changed behind their back.
    """
    new_state = list(state)
    idx, act = divmod(action, 2)

    # Guard the lower bound too: a negative idx would otherwise silently
    # modify an item counted from the end of the list.
    if 0 <= idx < len(items):
        new_state[idx] += 1 if act == 1 else -1

    return new_state

def calc_reward(items, state, max_weight, burst_reward):
    """Score a knapsack configuration.

    Args:
        items: sequence of ``[value, weight]`` pairs, one per item type.
        state: per-item counts, indexed like ``items``.
        max_weight: capacity of the knapsack.
        burst_reward: reward returned instead of the total value when the
            total weight exceeds ``max_weight`` or any count is negative.

    Returns:
        ``(reward, weight)``. ``weight`` is always the actual total weight,
        even when ``reward`` has been replaced by ``burst_reward``.
    """
    reward = 0
    weight = 0

    for i, count in enumerate(state):
        reward += items[i][0] * count
        weight += items[i][1] * count

    # any() instead of min(): min() raises ValueError on an empty state.
    if weight > max_weight or any(count < 0 for count in state):
        reward = burst_reward

    return reward, weight

# Env

In [None]:
class Knapsack(gym.Env):
    """Bounded-knapsack environment.

    The observation is the per-item count vector. Action ``2*i + 1`` adds one
    of item ``i``, ``2*i`` removes one, and ``2*len(items)`` changes nothing
    (see ``next_state``). Every step is rewarded with ``calc_reward`` of the
    resulting counts. An episode ends after ``max_count`` steps.

    config keys: ``items`` (``[value, weight]`` pairs), ``max_weight``,
    ``max_count`` (episode length), ``burst_reward``.
    """

    def __init__(self, config):
        self.items = config["items"]
        self.max_weight = config["max_weight"]
        self.max_count = config["max_count"]
        self.burst_reward = config["burst_reward"]

        # One +/-1 change per step for at most max_count steps keeps every
        # count within [-max_count, max_count].
        n = self.max_count
        shape = (len(self.items),)

        self.action_space = Discrete(len(self.items) * 2 + 1)
        # float32 bounds so they match the space's dtype and the float32
        # observations produced by _observation().
        self.observation_space = Box(
            low=np.full(shape, -n, dtype=np.float32),
            high=np.full(shape, n, dtype=np.float32),
            shape=shape,
            dtype=np.float32,
        )

        self.reset()

    def _observation(self):
        """Return a fresh array copy of the counts so later steps cannot alter it."""
        return np.array(self.state, dtype=self.observation_space.dtype)

    def reset(self):
        """Empty the knapsack and restart the step counter; return the initial observation."""
        self.count = 0
        self.state = [0 for _ in self.items]

        return self._observation()

    def step(self, action):
        """Apply ``action``; return ``(observation, reward, done, info)``."""
        self.state = next_state(self.items, self.state, action)

        reward, _ = calc_reward(self.items, self.state, self.max_weight, self.burst_reward)

        self.count += 1
        done = self.count >= self.max_count

        return self._observation(), reward, done, {}

In [None]:
# Each entry is [value, weight]: calc_reward adds items[i][0] to the reward
# and items[i][1] to the weight for every unit of item i in the state.
items = [[105, 10], [74, 7], [164, 15], [32, 3], [235, 22]]


In [None]:
# RLlib trainer config: the env class plus the dict passed to Knapsack(config).
config = dict(
    env=Knapsack,
    env_config=dict(
        items=items,
        max_count=10,      # steps per episode
        max_weight=35,     # knapsack capacity
        burst_reward=-100, # reward when over capacity or a count goes negative
    ),
)


In [None]:
import ray

# ignore_reinit_error makes this cell safe to re-run in a live kernel. A second
# bare ray.init() raises instead of reusing the already-running Ray instance.
ray.init(ignore_reinit_error=True)


In [None]:
# Show the installed Ray version. NOTE(review): the trainers below are imported
# from `ray.rllib.agents`; confirm this version still provides that module.
ray.__version__

# PPO

In [None]:
from ray.rllib.agents.ppo import PPOTrainer

# Release any trainer already bound to `trainer` (e.g. from a previous run of
# this cell). Otherwise its workers keep running after the name is rebound.
if globals().get("trainer") is not None:
    trainer.stop()

trainer = PPOTrainer(config = config)

# DQN

In [None]:
from ray.rllib.agents.dqn import DQNTrainer

# Release any trainer already bound to `trainer` (e.g. the PPO trainer created
# in the PPO cell, or a previous run of this cell). Otherwise its workers keep
# running after the name is rebound.
if globals().get("trainer") is not None:
    trainer.stop()

trainer = DQNTrainer(config = config)

# Train

In [None]:
# Per-iteration episode-reward statistics, appended to by the training cell.
r_max, r_min, r_mean = [], [], []


In [None]:
from ray.tune.logger import pretty_print

# Run 10 training iterations. Dump each result, then record its
# episode-reward statistics into the matching history list.
for _ in range(10):
    r = trainer.train()
    print(pretty_print(r))

    for history, metric in (
        (r_max, "episode_reward_max"),
        (r_min, "episode_reward_min"),
        (r_mean, "episode_reward_mean"),
    ):
        history.append(r[metric])


In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

plt.plot(r_max, label = "reward_max", color = "red")
plt.plot(r_min, label = "reward_min", color = "green")
plt.plot(r_mean, label = "reward_mean", color = "blue")

plt.legend(loc = "upper left")

plt.ylim([-1000, 3700])
plt.ylabel("reward")

plt.show()

# Evaluate

In [None]:
# Roll the trained policy out once from an empty knapsack and print
# "action, counts, reward, weight" after every step.
# NOTE(review): compute_action may sample exploratory actions depending on the
# trainer's `explore` setting. Pass explore=False for a greedy rollout if that is intended.
env_cfg = config["env_config"]
s = [0] * len(items)

for _ in range(env_cfg["max_count"]):
    a = trainer.compute_action(s)
    s = next_state(items, s, a)
    r, w = calc_reward(items, s, env_cfg["max_weight"], env_cfg["burst_reward"])
    print(f"{a}, {s}, {r}, {w}")
    

In [None]:
import collections

# Run 100 rollouts of the policy. For each, keep the best per-step reward
# seen (starting from burst_reward), then tally how often each best value occurs.
env_cfg = config["env_config"]
rs = []

for _ in range(100):
    s = [0] * len(items)
    best = env_cfg["burst_reward"]

    for _ in range(env_cfg["max_count"]):
        a = trainer.compute_action(s)
        s = next_state(items, s, a)
        r, w = calc_reward(items, s, env_cfg["max_weight"], env_cfg["burst_reward"])
        if r > best:
            best = r

    rs.append(best)

collections.Counter(rs)

# Save

In [None]:
# Save the trainer and display the value trainer.save() returned. The Load cell
# passes this `checkpoint` value to trainer.restore().
checkpoint = trainer.save()
checkpoint

# Load

In [None]:
# Restore the trainer from the `checkpoint` value produced by the Save cell.
trainer.restore(checkpoint)
