In [None]:
import gym
from gym.spaces import Discrete, Box

import numpy as np

In [None]:
def step_items(items, state, action):
    """Apply one discrete action to the item counts and return the result.

    The action is decoded as ``idx = action // 2`` (which item) and
    ``act = action % 2`` (1 adds one unit, 0 removes one unit). An action
    whose decoded index does not address an entry of ``items`` (including
    the trailing "no-op" action ``2 * len(items)`` and any negative action)
    leaves the counts unchanged.

    Args:
        items: Sequence of item descriptors; only its length is used here.
        state: Per-item counts (anything with ``.copy()`` and item
            assignment, e.g. a list). It is NOT modified.
        action: Integer action id.

    Returns:
        A copy of ``state`` with the addressed count changed by +1 or -1,
        or an unchanged copy for a no-op action.
    """
    idx, act = divmod(action, 2)

    # Work on a copy so callers that still hold the previous state (e.g. an
    # observation returned earlier) do not see it change underneath them.
    new_state = state.copy()

    # The lower bound keeps a negative action from wrapping around to the
    # end of the list through negative indexing.
    if 0 <= idx < len(items):
        new_state[idx] += 1 if act == 1 else -1

    return new_state

def calc_reward(items, state, max_weight, burst_reward):
    """Score a knapsack state.

    Args:
        items: Sequence of item descriptors where ``item[0]`` is the value
            and ``item[1]`` is the weight of one unit.
        state: Per-item counts, aligned by index with ``items``.
        max_weight: Total weight above which the state is invalid.
        burst_reward: Reward reported for an invalid state.

    Returns:
        ``(reward, weight)``. ``weight`` is always the summed weight of the
        state. ``reward`` is the summed value, unless the weight exceeds
        ``max_weight`` or any count is negative, in which case it is
        ``burst_reward``. An empty state scores ``(0, 0)`` (subject to the
        same weight check).
    """
    reward = 0
    weight = 0
    has_negative = False

    for i, count in enumerate(state):
        reward += items[i][0] * count
        weight += items[i][1] * count
        # Tracked inside the loop instead of calling min(state), which
        # raises ValueError when the state is empty.
        if count < 0:
            has_negative = True

    if weight > max_weight or has_negative:
        reward = burst_reward

    return reward, weight

# Env

In [None]:
class Knapsack(gym.Env):
    """Gym environment in which an agent adds or removes knapsack items.

    The observation is the list of per-item counts. Each step applies one
    action through ``step_items`` and is rewarded through ``calc_reward``
    for the resulting state. An episode ends after ``max_count`` steps.

    Expected ``config`` keys:
        items: Sequence of item descriptors passed to ``step_items`` and
            ``calc_reward``.
        max_weight: Weight limit passed to ``calc_reward``.
        max_count: Number of steps per episode; also used as the bound of
            the observation space.
        burst_reward: Reward passed to ``calc_reward`` for invalid states.
    """

    def __init__(self, config):
        self.items = config["items"]
        self.max_weight = config["max_weight"]
        self.max_count = config["max_count"]
        self.burst_reward = config["burst_reward"]

        bound = self.max_count
        n_items = len(self.items)

        # Two actions per item (remove / add) plus one trailing action that
        # step_items treats as a no-op.
        self.action_space = Discrete(n_items * 2 + 1)
        # A count changes by at most 1 per step, so within one episode of
        # max_count steps it stays inside [-max_count, max_count].
        # The dtype is given explicitly rather than left to autodetection.
        self.observation_space = Box(
            low=-bound, high=bound, shape=(n_items,), dtype=np.float32
        )

        self.reset()

    def reset(self):
        """Start a new episode and return the initial (all-zero) observation."""
        self.count = 0
        self.state = [0 for _ in self.items]

        # Hand out a copy so the caller cannot alias the internal state.
        return list(self.state)

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``."""
        self.state = step_items(self.items, self.state, action)

        reward, _ = calc_reward(
            self.items, self.state, self.max_weight, self.burst_reward
        )

        self.count += 1
        done = self.count >= self.max_count

        # Copy for the same reason as in reset(): earlier observations must
        # not change when the environment advances.
        return list(self.state), reward, done, {}

In [None]:
# Candidate knapsack items, one [value, weight] pair per item:
# calc_reward sums element 0 into the reward and element 1 into the weight.
items = [
    [120, 10],
    [130, 12],
    [80, 7],
    [100, 9],
    [250, 21],
    [185, 16]
]


In [None]:
# Trainer configuration passed to the PPOTrainer / DQNTrainer constructors.
# "env" is the environment class; "env_config" is the dict whose keys
# (items, max_count, max_weight, burst_reward) Knapsack.__init__ reads.
# The evaluation cells also read max_count and max_weight from here.
config = {
    "env": Knapsack, 
    "env_config": {"items": items, "max_count": 20, "max_weight": 65, "burst_reward": -100}
}


In [None]:
import ray

# ignore_reinit_error makes this cell safe to re-run on a live kernel: a
# second plain ray.init() raises because Ray is already initialised.
# To force a fresh Ray instance instead, call ray.shutdown() first.
#ray.shutdown()
ray.init(ignore_reinit_error=True)


In [None]:
from ray.rllib.agents.ppo import PPOTrainer

# Build a PPO trainer from the shared `config`.
# NOTE: the DQN cell also assigns to `trainer`; whichever of the two cells
# ran last is the trainer used by the train / evaluate / save cells.
trainer = PPOTrainer(config = config)

In [None]:
from ray.rllib.agents.dqn import DQNTrainer

# Build a DQN trainer from the shared `config`.
# NOTE: this rebinds `trainer`; if the PPO cell was run before, the PPO
# trainer it created is no longer referenced by that name.
trainer = DQNTrainer(config = config)

# Train

In [None]:
from ray.tune.logger import pretty_print

# Run 30 training iterations and print the formatted result of each one.
# After the loop, `r` holds the result of the last iteration.
for _ in range(30):
    r = trainer.train()
    print(pretty_print(r))

# Evaluate

In [None]:
# Roll the trained policy out once, outside the gym environment, starting
# from an empty knapsack. Each line printed is: action, state, reward, weight.
# Invalid states are marked with a reward of -1 here (not the -100 used
# during training).
s = [0] * len(items)
env_config = config["env_config"]

for _ in range(env_config["max_count"]):
    a = trainer.compute_action(s)
    s = step_items(items, s, a)
    r, w = calc_reward(items, s, env_config["max_weight"], -1)
    print(f"{a}, {s}, {r}, {w}")
    

In [None]:
import collections

# Run 100 independent rollouts and, for each one, keep the best per-step
# reward reached at any point of the rollout. The final Counter shows how
# often each best reward occurred. Invalid states score -1 here.
episode_limit = config["env_config"]["max_count"]
weight_limit = config["env_config"]["max_weight"]

rs = []

for _ in range(100):
    s = [0] * len(items)
    ts = []

    for _ in range(episode_limit):
        a = trainer.compute_action(s)
        s = step_items(items, s, a)
        r, w = calc_reward(items, s, weight_limit, -1)

        # Snapshot the state so later steps cannot alter the recorded entry.
        ts.append((r, s.copy()))

    # Index of the first step that reached the highest reward.
    idx = np.argmax([reward for reward, _ in ts])
    t = ts[idx]
    rs.append(t[0])

collections.Counter(rs)


# Save

In [None]:
# Save the trainer and keep the value returned by save() (presumably the
# checkpoint path; confirm against the RLlib version in use) so the
# restore cell can pass it back. The bare name displays it as cell output.
checkpoint = trainer.save()
checkpoint

# Load

In [None]:
# Restore the trainer from the value produced by the save cell.
# Requires `trainer` and `checkpoint` to exist in the current kernel.
trainer.restore(checkpoint)
