# Demo: Knapsack Problem

In this notebook, we demonstrate how to solve the knapsack problem, a classic Operations Research problem. We have a knapsack with a fixed weight capacity and a set of items, each with a weight and a value. The goal is to choose which items to put into the knapsack so that the total value of the chosen items is as high as possible, subject to the constraint that the sum of their weights does not exceed the knapsack's capacity.
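
For reference, the same problem can be written as a 0/1 integer program. Here $x_i \in \{0, 1\}$ indicates whether item $i$ is packed, $r_i$ is its value, $w_i$ its weight and $W$ the capacity (in the code below these are `rewards`, `item_weights` and `max_capacity`); the binary decisions $x_i$ play the role of the actions `a_i` introduced below.

$$
\max_{x \in \{0,1\}^n} \; \sum_{i=1}^{n} r_i x_i
\quad \text{subject to} \quad
\sum_{i=1}^{n} w_i x_i \le W
$$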

While this problem is traditionally solved with integer linear programming or dynamic programming, we can convert it into a contextual bandit problem (a simplified, single-stage Markov decision problem) and solve it using pymdp.

Let us define an action `a_i` for each item i that indicates whether to include that item. The state `s_i` records whether item i is included, i.e., each action variable is copied over to the corresponding state variable. We also need an additional state variable `z` that indicates whether the knapsack capacity is exceeded. If item i is included (`s_i = 1`), we receive a reward `r_i`; otherwise (`s_i = 0`), we receive a reward of 0. We therefore define the preference for including valuable items to be proportional to the exponential of the reward: `C[s_i] = softmax([0, r_i])`. Our preference on the capacity constraint variable `z` is to never violate it, i.e., `C[z] = [1, 0]`. Since the system is fully observable, all observation matrices are set to the identity.
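
As a quick check of what `softmax([0, r_i])` encodes: its two entries are $1/(1+e^{r_i})$ and $e^{r_i}/(1+e^{r_i})$, so the log-preference gap between "include" and "not include" is exactly $r_i$. The sketch below verifies this with plain NumPy on hypothetical reward values (not the ones sampled later). With rewards drawn from $[0, 1]$, as in the setup below, the preference for including any single item therefore stays between 0.5 and about 0.73.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over the last axis
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


# Hypothetical rewards for three items (illustrative values only)
rewards = np.array([0.2, 0.5, 0.9])

# One preference vector [C(not include), C(include)] per item: softmax([0, r_i])
C_items = softmax(np.stack([np.zeros_like(rewards), rewards], axis=-1))

# The log-preference gap between "include" and "not include" recovers r_i
log_gap = np.log(C_items[:, 1]) - np.log(C_items[:, 0])
print(C_items.round(3))
print(log_gap)  # equal to rewards up to floating-point error

# Capacity constraint: all preference mass on "not violated"
C_z = np.array([1.0, 0.0])
```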

In [25]:
import numpy as np
from jax import numpy as jnp
from jax import tree_util as jtu
import jax.nn as nn

from pymdp.jax.agent import Agent
from pymdp.jax import distribution

## Specify model structure

In [28]:
# knapsack problem setup
num_items = 5  # number of candidate items
max_capacity = 20  # knapsack weight capacity
item_weights = np.random.uniform(2, 8, size=(num_items,))  # unseeded: values change between runs
rewards = np.random.uniform(0, 1, size=(num_items,))

state_config = {
    f"s_{i}": {"elements": ["not include", "include"], "depends_on": [f"s_{i}"], "controlled_by": [f"a_{i}"]} 
    for i in range(num_items)
}
state_config["z"] = {
    "elements": ["not violated", "violated"], "depends_on": ["z"], "controlled_by": [f"a_{i}" for i in range(num_items)]
} 

obs_config = {
    k: {"elements": v["elements"], "depends_on": [k]} for k, v in state_config.items()
}

act_config = {
    f"a_{i}": {"elements": ["not include", "include"]} 
    for i in range(num_items)
}

model = {
    "observations": obs_config,
    "controls": act_config,
    "states": state_config,
}

## Specify model parameters

In [10]:
As, Bs = distribution.compile_model(model)

print("A shapes", [a.data.shape for a in As])
print("B shapes", [a.data.shape for a in Bs])

A shapes [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]
B shapes [(2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2, 2, 2, 2, 2)]
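
The printed shapes follow directly from the config. They are consistent with each A tensor being laid out as (observation, state dependencies...) and each B tensor as (next state, state dependencies..., controlling actions...), which is why `z`, depending on itself and controlled by all five actions, gets seven axes. The sketch below reproduces the shapes from a plain-dict copy of the config under that assumed layout; it does not call pymdp, and `expected_shapes` is a hypothetical helper, not part of the library.

```python
num_items = 5

# Plain-dict copy of the notebook's config (only the fields needed for shapes)
states = {
    f"s_{i}": {"elements": ["not include", "include"], "depends_on": [f"s_{i}"],
               "controlled_by": [f"a_{i}"]}
    for i in range(num_items)
}
states["z"] = {"elements": ["not violated", "violated"], "depends_on": ["z"],
               "controlled_by": [f"a_{i}" for i in range(num_items)]}
controls = {f"a_{i}": {"elements": ["not include", "include"]} for i in range(num_items)}
observations = {k: {"elements": v["elements"], "depends_on": [k]} for k, v in states.items()}


def expected_shapes(states, controls, observations):
    """Hypothetical helper: infer A/B shapes from the config, assuming
    A = (obs, *state deps) and B = (next state, *state deps, *controls)."""

    def size(group, name):
        return len(group[name]["elements"])

    A_shapes = [
        (size(observations, k),) + tuple(size(states, d) for d in o["depends_on"])
        for k, o in observations.items()
    ]
    B_shapes = [
        (size(states, k),)
        + tuple(size(states, d) for d in s["depends_on"])
        + tuple(size(controls, c) for c in s["controlled_by"])
        for k, s in states.items()
    ]
    return A_shapes, B_shapes


A_shapes, B_shapes = expected_shapes(states, controls, observations)
print("A shapes", A_shapes)
print("B shapes", B_shapes)
```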


In [117]:
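def create_identity_transition_factor(mat):
    """Fill a (next_state, prev_state, action) tensor so that next_state == action.

    For every previous state i, the slice mat[:, i, :] is set to the identity,
    i.e. the state simply copies the chosen action.
    """
    for i in range(mat.shape[1]):
        mat[:, i] = np.eye(len(mat))
    return mat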

def create_constraint_factor_z_greater_than(act_dim, maximum, num_items, weights):
    # Target shape (2, act_dim, ..., act_dim): z' followed by one axis per item action
    tensor_shape = (2,) + (act_dim,) * num_items
    
    # indices[i] holds the value of action a_i at every position of the action grid
    indices = np.indices(tensor_shape[1:])

    # Reshape weights to (num_items, 1, ..., 1) so they broadcast against indices
    weights_reshaped = np.array(weights).reshape((-1,) + (1,) * (indices.ndim - 1))
    # Keep weight w_i wherever a_i == 1 ("include"), zero elsewhere
    result = (indices == 1) * weights_reshaped

    # Total weight of the included items for each combination of actions
    total = np.sum(result, axis=0)
    
    # 1 where the total weight exceeds the capacity, 0 otherwise
    tensor = np.where(total > maximum, 1, 0)

    # Stack along the first axis: index 0 = "not violated", index 1 = "violated"
    tensor = np.stack((1 - tensor, tensor), axis=0)

    # Duplicate along axis 1 (previous z) so that z' does not depend on the previous z
    tensor = np.stack([tensor, tensor], axis=1)
    return tensor

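# A: identity observation matrices (every state factor is fully observed)
for i in range(len(As)):
    As[i].data = np.eye(len(As[i].data))

# B for each item: the next state s_i copies the chosen action a_i
for i in range(num_items):
    Bs[i].data = create_identity_transition_factor(Bs[i].data)

# B for z: "violated" iff the total weight of the included items exceeds the capacity
Bs[-1].data = create_constraint_factor_z_greater_than(2, max_capacity, num_items, item_weights)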

# C: softmax([0, r_i]) for each item, so including item i is preferred in proportion to exp(r_i)
preferences = nn.softmax(np.stack([np.zeros_like(rewards), rewards], axis=-1), axis=-1)
Cs = [preferences[i] for i in range(num_items)]

Cs.append(np.array([1., 0.]))  # capacity constraint cannot be violated

print("A shapes", [a.data.shape for a in As])
print("B shapes", [a.data.shape for a in Bs])

print("A normalized", [np.isclose(a.data.sum(0), 1.).all() for a in As])
print("B normalized", [np.isclose(b.data.sum(0), 1.).all() for b in Bs])
print("C normalized", [np.isclose(a.sum(0), 1.).all() for a in Cs])

A shapes [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]
B shapes [(2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2, 2, 2, 2, 2)]
A normalized [True, True, True, True, True, True]
B normalized [True, True, True, True, True, True]
C normalized [True, True, True, True, True, True]
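
The vectorized construction in `create_constraint_factor_z_greater_than` is compact but easy to get subtly wrong, so one way to cross-check it is an explicit loop over all $2^n$ include/exclude combinations. The sketch below implements both on a hypothetical three-item instance: `constraint_factor_vectorized` mirrors the notebook function with `act_dim = 2`, and `constraint_factor_loop` is an illustrative alternative. Note that the constraint uses a strict `>`, so a total weight exactly equal to the capacity is not a violation.

```python
import itertools

import numpy as np


def constraint_factor_vectorized(weights, maximum):
    """Mirrors create_constraint_factor_z_greater_than with act_dim = 2."""
    n = len(weights)
    indices = np.indices((2,) * n)                        # shape (n, 2, ..., 2)
    w = np.asarray(weights, dtype=float).reshape((n,) + (1,) * n)
    total = np.sum((indices == 1) * w, axis=0)            # total included weight
    violated = np.where(total > maximum, 1, 0)
    tensor = np.stack((1 - violated, violated), axis=0)   # axis 0: z'
    return np.stack([tensor, tensor], axis=1)             # axis 1: previous z (ignored)


def constraint_factor_loop(weights, maximum):
    """Explicit loop over all 2^n include/exclude combinations."""
    n = len(weights)
    B_z = np.zeros((2, 2) + (2,) * n, dtype=int)
    for actions in itertools.product([0, 1], repeat=n):
        total = sum(w for w, a in zip(weights, actions) if a == 1)
        z_next = int(total > maximum)
        # Same z' for both values of the previous z (axis 1)
        B_z[(z_next, slice(None)) + actions] = 1
    return B_z


# Hypothetical three-item instance; 3 + 5 == 8 is NOT a violation (strict >)
weights, maximum = [3.0, 4.0, 5.0], 8.0
B_vec = constraint_factor_vectorized(weights, maximum)
B_loop = constraint_factor_loop(weights, maximum)
print(np.array_equal(B_vec, B_loop))  # True
```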


## Run agent

In [123]:
B_action_dependencies = [
    [list(model["controls"].keys()).index(i) for i in s["controlled_by"]] 
    for s in model["states"].values()
]
num_controls = [len(c["elements"]) for c in model["controls"].values()]

agent = Agent(
    As, Bs, Cs,
    B_action_dependencies=B_action_dependencies,
    num_controls=num_controls,
)
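
`B_action_dependencies` tells the agent which control factors drive each state factor: each `s_i` is driven only by `a_i`, while `z` is driven by all five actions. The stripped-down sketch below evaluates the same comprehension on a plain-dict copy of the relevant config fields, without pymdp, to make the resulting lists explicit.

```python
num_items = 5

# Plain-dict copy of the "controlled_by" fields and the control factors
states = {f"s_{i}": {"controlled_by": [f"a_{i}"]} for i in range(num_items)}
states["z"] = {"controlled_by": [f"a_{i}" for i in range(num_items)]}
controls = {f"a_{i}": {"elements": ["not include", "include"]} for i in range(num_items)}

# For every state factor, the indices of the control factors that drive it
control_names = list(controls.keys())
B_action_dependencies = [
    [control_names.index(c) for c in s["controlled_by"]] for s in states.values()
]
# Number of options per control factor
num_controls = [len(c["elements"]) for c in controls.values()]

print(B_action_dependencies)  # [[0], [1], [2], [3], [4], [0, 1, 2, 3, 4]]
print(num_controls)           # [2, 2, 2, 2, 2]
```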

In [124]:
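# Add a leading batch dimension to the initial state beliefs D
qs = jtu.tree_map(lambda x: jnp.expand_dims(x, axis=0), agent.D)
# Score all policies (G) and compute the posterior over policies q_pi
q_pi, G = agent.infer_policies(qs)
# Sample an action from q_pi and decode the joint action into per-item choices
action = agent.sample_action(q_pi)
action_multi = agent.decode_multi_actions(action)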
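
The ranking printed further down maps joint-action index 28 to `[1, 1, 1, 0, 0]` and 27 to `[1, 1, 0, 1, 1]`. That is consistent with the joint index being a row-major (C-order) encoding of the five binary item choices, with item 0 as the most significant digit. Assuming that convention (inferred from the printed output, not from the pymdp source), the index can be decoded and encoded with NumPy alone; `decode` and `encode` are hypothetical helpers.

```python
import numpy as np

num_items = 5
shape = (2,) * num_items  # one binary control factor per item


def decode(joint_index):
    # Joint index -> per-item include flags (C order: item 0 is most significant)
    return [int(v) for v in np.unravel_index(joint_index, shape)]


def encode(flags):
    # Per-item include flags -> joint index
    return int(np.ravel_multi_index(tuple(flags), shape))


print(decode(28))               # [1, 1, 1, 0, 0]
print(encode([0, 1, 1, 1, 0]))  # 14
```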

In [162]:
print("best action", action)
print("best action multi", action_multi)
print("item weights")
print(item_weights)
print("item rewards")
print(rewards)

best action [[1 1 1 1 1 0]]
best action multi [[0 0 0 0 0]]
item weights
[4.29400262 6.53872627 7.64026691 3.47688306 4.93437645]
item rewards
[0.17008879 0.85895517 0.66385495 0.90542861 0.28839602]
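
With only five items, the exact optimum can be found by enumerating all $2^5 = 32$ subsets, which gives a reference point for the agent's choice printed above. The sketch below uses the weights and rewards printed in this run; because the notebook's random generator is unseeded, a re-run will sample different items and the optimum will change accordingly.

```python
import itertools

import numpy as np

# Values copied from the printed output of this run
item_weights = np.array([4.29400262, 6.53872627, 7.64026691, 3.47688306, 4.93437645])
rewards = np.array([0.17008879, 0.85895517, 0.66385495, 0.90542861, 0.28839602])
max_capacity = 20

best_mask, best_value = None, -np.inf
for mask in itertools.product([0, 1], repeat=len(item_weights)):
    x = np.array(mask)
    # Keep the feasible subset with the highest total reward
    if x @ item_weights <= max_capacity and x @ rewards > best_value:
        best_mask, best_value = mask, float(x @ rewards)

print("optimal selection", list(best_mask))
print("total weight", float(np.array(best_mask) @ item_weights))
print("total reward", best_value)
```

For these values the enumeration selects items 1, 2 and 3 (total weight about 17.66, total reward about 2.43).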


In [166]:
from pymdp import utils

# Print the 21 most probable policies: the raw policy entry, its decoded
# per-item choices, and the corresponding entry of G
for idx in np.argsort(q_pi[0])[::-1][:21]:
    print("action", agent.policies[idx, 0])
    print("action multi", utils.index_to_combination(agent.policies[idx, 0][-1].tolist(), agent.num_controls_multi))
    print(G[0, idx])

action [ 1  1  1  1  1 28]
action multi [1, 1, 1, 0, 0]
4.1886554
action [ 1  1  1  1  1 27]
action multi [1, 1, 0, 1, 1]
4.1886554
action [ 1  1  1  1  1 26]
action multi [1, 1, 0, 1, 0]
4.1886554
action [ 1  1  1  1  1 25]
action multi [1, 1, 0, 0, 1]
4.1886554
action [ 1  1  1  1  1 24]
action multi [1, 1, 0, 0, 0]
4.1886554
action [ 1  1  1  1  1 22]
action multi [1, 0, 1, 1, 0]
4.1886554
action [ 1  1  1  1  1 21]
action multi [1, 0, 1, 0, 1]
4.1886554
action [ 1  1  1  1  1 20]
action multi [1, 0, 1, 0, 0]
4.1886554
action [ 1  1  1  1  1 19]
action multi [1, 0, 0, 1, 1]
4.1886554
action [ 1  1  1  1  1 18]
action multi [1, 0, 0, 1, 0]
4.1886554
action [ 1  1  1  1  1 17]
action multi [1, 0, 0, 0, 1]
4.1886554
action [ 1  1  1  1  1 16]
action multi [1, 0, 0, 0, 0]
4.1886554
action [ 1  1  1  1  1 14]
action multi [0, 1, 1, 1, 0]
4.1886554
action [ 1  1  1  1  1 13]
action multi [0, 1, 1, 0, 1]
4.1886554
action [ 1  1  1  1  1 12]
action multi [0, 1, 1, 0, 0]
4.1886554
action [ 1