# Demo: Knapsack Problem

In this notebook, we demonstrate how to solve the knapsack problem, a classic Operations Research problem. We have a knapsack with a fixed weight capacity and a set of items, each with a weight and a value. The goal is to choose which items to pack so that the total value of the packed items is as high as possible, subject to the constraint that their total weight does not exceed the knapsack's capacity.
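For reference, the 0/1 knapsack problem described above can be written as an integer program. The symbols here are introduced for illustration: $x_i \in \{0,1\}$ indicates whether item $i$ is packed, $r_i$ is its value, $w_i$ its weight, and $W$ the capacity. The actions `a_i` defined below play the role of $x_i$.

$$
\max_{x \in \{0,1\}^n} \; \sum_{i=1}^{n} r_i x_i
\quad \text{subject to} \quad
\sum_{i=1}^{n} w_i x_i \le W
$$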

While this problem is traditionally solved with integer linear programming, we can recast it as a contextual bandit problem (a single-step Markov decision process) and solve it with pymdp.

Let the action `a_i` indicate whether item `i` is included. The state `s_i` records whether item `i` is included, i.e., each transition simply copies the action variable into the corresponding state variable. We also need an additional state variable `z` that indicates whether the knapsack capacity is exceeded; its transition depends jointly on all of the actions.

Including item `i` (`s_i = 1`) yields a reward `r_i`, while leaving it out (`s_i = 0`) yields a reward of 0. We therefore define the preference for including valuable items to be proportional to the exponential of the reward: `C[s_i] = softmax([0, r_i])`. Our preference for the capacity variable `z` is that the constraint is never violated: `C[z] = [1, 0]`. Since the system is fully observable, we set all the Categorical/Dirichlet parameters of the emission model (the `A` tensors) to be diagonal.
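The preference encoding above can be checked numerically. The sketch below uses plain NumPy with a locally defined `softmax` (not pymdp's) and made-up rewards. It shows that the log-ratio of the two preference entries recovers `r_i`, and that, if `C` is read as a preference distribution, summing its logs over items reproduces the knapsack objective up to an additive constant.

```python
import numpy as np

def softmax(x):
    # numerically stable softmax over the last axis
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rewards = np.array([0.2, 0.5, 0.9])  # illustrative rewards, not the notebook's
# one preference vector per item: [not include, include]
C = softmax(np.stack([np.zeros_like(rewards), rewards], axis=-1))

# log C[s_i = 1] - log C[s_i = 0] equals the item reward r_i
log_ratio = np.log(C[:, 1]) - np.log(C[:, 0])
print(np.round(log_ratio, 6))

# for a selection x, sum_i log C[i, x_i] = sum_i r_i x_i - sum_i log(1 + e^{r_i})
x = np.array([1, 0, 1])
log_pref = np.log(C[np.arange(len(x)), x]).sum()
const = np.log1p(np.exp(rewards)).sum()
print(log_pref + const, rewards @ x)
```

Because the offset `const` does not depend on the selection, ranking selections by summed log-preference is the same as ranking them by total reward.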

In [1]:
import numpy as np
from jax import numpy as jnp, random as jr
from jax import tree_util as jtu
import jax.nn as nn
import itertools

from pymdp.agent import Agent
from pymdp import distribution

## Specify model structure

In [2]:
# knapsack problem setup
num_items = 5
max_capacity = 20
key = jr.PRNGKey(0)
key_w, key_r = jr.split(key)
item_weights = max_capacity / num_items + jr.uniform(key_w, shape=(num_items,), minval=0.0, maxval=5.0)
rewards = jr.uniform(key_r, shape=(num_items,), minval=0.0, maxval=1.0)

print("item rewards", rewards)
print("item weights", item_weights)
print("item weight sum", item_weights.sum())

# mdp config
state_config = {
    f"s_{i}": {"elements": ["not include", "include"], "depends_on": [f"s_{i}"], "controlled_by": [f"a_{i}"]} 
    for i in range(num_items)
}
state_config["z"] = {
    "elements": ["not violated", "violated"], "depends_on": ["z"], "controlled_by": [f"a_{i}" for i in range(num_items)]
} 

obs_config = {
    k: {"elements": v["elements"], "depends_on": [k]} for k, v in state_config.items()
}

act_config = {
    f"a_{i}": {"elements": ["not include", "include"]} 
    for i in range(num_items)
}

model_description = {
    "observations": obs_config,
    "controls": act_config,
    "states": state_config,
}

item rewards [0.00729382 0.02089119 0.5814265  0.36183798 0.22303772]
item weights [8.211571  4.9118934 5.1358905 4.603628  4.9590673]
item weight sum 27.822052


## Specify model parameters

In [3]:
model = distribution.compile_model(model_description)

print("A shapes", [a.data.shape for a in model.A])
print("B shapes", [a.data.shape for a in model.B])

A shapes [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]
B shapes [(2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2, 2, 2, 2, 2)]


In [4]:
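def create_identity_transition_factor(mat):
    # mat has shape (next state, previous state, action); for every previous
    # state, the next state is set equal to the chosen action (include / not include)
    for i in range(mat.shape[1]):
        mat[:, i] = np.eye(len(mat))
    return mat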

def create_constraint_factor_z_greater_than(act_dim, maximum, num_items, weights):
    # Create an array of shape (2, act_dim, act_dim, ..., act_dim)
    tensor_shape = (2,) + (act_dim,) * num_items
    
    # Create an array with indices from 0 to act_dim - 1 along each dimension
    indices = np.indices(tensor_shape[1:])

    # Reshape weights so they broadcast against indices (one weight per item axis)
    weights_reshaped = np.array(weights).reshape((-1,) + (1,) * (indices.ndim - 1))
    # Keep item i's weight only where a_i == "include" (index act_dim - 1)
    result = np.array(indices == (act_dim - 1)) * weights_reshaped

    # Total weight of the included items for each combination of actions
    total = np.sum(result, axis=0)

    # 1.0 where the total weight exceeds the capacity ("violated"), else 0.0
    tensor = np.where(total > maximum, 1.0, 0.0)

    # Axis 0 is the next value of z: [not violated, violated]
    tensor = np.stack((1 - tensor, tensor), axis=0)

    # Insert the previous-z axis; the next z does not depend on the previous z
    tensor = np.stack([tensor, tensor], axis=1)
    return tensor

# update A tensor
for i in range(len(model.A)):
    model.A[i].data = np.eye(len(model.A[i].data))

# update B tensors
for i in range(num_items):
    model.B[i].data = create_identity_transition_factor(model.B[i].data)

model.B[-1].data = create_constraint_factor_z_greater_than(2, max_capacity, num_items, item_weights)

# create C tensors: one preference vector softmax([0, r_i]) per item
preferences = nn.softmax(np.stack([np.zeros_like(rewards), rewards], axis=-1), axis=-1)
Cs = [None for _ in range(len(model.A))]
for i in range(num_items):  # the last modality (z) is set separately below
    Cs[i] = preferences[i]

Cs[-1] = np.array([1., 0])  # capacity constraint cannot be violated

print("A shapes", [a.data.shape for a in model.A])
print("B shapes", [a.data.shape for a in model.B])

print("A normalized", [np.isclose(a.data.sum(0), 1.).all() for a in model.A])
print("B normalized", [np.isclose(b.data.sum(0), 1.).all() for b in model.B])
print("C normalized", [np.isclose(a.sum(0), 1.).all() for a in Cs])

A shapes [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]
B shapes [(2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2, 2, 2, 2, 2)]
A normalized [True, True, True, True, True, True]
B normalized [True, True, True, True, True, True]
C normalized [True, True, True, True, True, True]
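To make the structure of these `B` tensors concrete, here is a standalone NumPy sketch for a toy two-item problem. The weights, capacity and the helper names `copy_action_factor` and `capacity_factor` are illustrative and not part of pymdp. It assumes the axis order `[next state, previous state, action(s)]`, which matches the shapes printed above and the way `create_identity_transition_factor` and `create_constraint_factor_z_greater_than` fill their tensors.

```python
import itertools
import numpy as np

def copy_action_factor():
    # B_i[s_next, s_prev, a]: the next state equals the action, whatever s_prev was
    B = np.zeros((2, 2, 2))
    for s_prev in range(2):
        B[:, s_prev, :] = np.eye(2)
    return B

def capacity_factor(weights, capacity):
    n = len(weights)
    # B_z[z_next, z_prev, a_0, ..., a_{n-1}]
    B = np.zeros((2, 2) + (2,) * n)
    for actions in itertools.product([0, 1], repeat=n):
        violated = int(np.dot(actions, weights) > capacity)
        # z_next is deterministic given the joint action and ignores z_prev
        B[(violated, slice(None)) + actions] = 1.0
    return B

B_item = copy_action_factor()
B_z = capacity_factor(weights=[3.0, 4.0], capacity=5.0)
print(B_item[:, 0, 1])  # next-state distribution after choosing "include"
print(B_z[:, 0, 1, 1])  # both items included: 7 > 5, so "violated"
print(B_z[:, 0, 1, 0])  # only item 0 included: 3 <= 5, "not violated"
```

The explicit loop is slower than the vectorized construction in the notebook but makes the rule "violated if and only if the included weight exceeds the capacity" easy to read.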


## Run agent

In [5]:
B_action_dependencies = [
    [list(model_description["controls"].keys()).index(i) for i in s["controlled_by"]] 
    for s in model_description["states"].values()
]
num_controls = [len(c["elements"]) for c in model_description["controls"].values()]

agent = Agent(
    model.A, model.B, Cs,
    B_action_dependencies=B_action_dependencies,
    num_controls=num_controls,
)
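`B_action_dependencies` lists, for each state factor, the indices of the control factors that drive it. The standalone sketch below reruns the same comprehensions on a hypothetical two-item configuration that keeps only the keys the comprehensions read. For the five-item model above, the same code gives `[[0], [1], [2], [3], [4], [0, 1, 2, 3, 4]]`: each `s_i` is driven by `a_i` alone, and `z` by all five actions.

```python
# toy config mirroring the structure of model_description (2 items)
num_items = 2
controls = {f"a_{i}": {"elements": ["not include", "include"]} for i in range(num_items)}
states = {f"s_{i}": {"controlled_by": [f"a_{i}"]} for i in range(num_items)}
states["z"] = {"controlled_by": [f"a_{i}" for i in range(num_items)]}

control_names = list(controls.keys())
# for every state factor, the indices of the control factors that drive it
B_action_dependencies = [
    [control_names.index(a) for a in s["controlled_by"]] for s in states.values()
]
num_controls = [len(c["elements"]) for c in controls.values()]
print(B_action_dependencies)  # [[0], [1], [0, 1]]
print(num_controls)           # [2, 2]
```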

In [6]:
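# use the prior over initial states (D) as the current belief, with an extra leading axis
qs = jtu.tree_map(lambda x: jnp.expand_dims(x, axis=0), agent.D)
# score every policy and compute the posterior over policies
q_pi, G = agent.infer_policies(qs)
# select an action from the policy posterior and decode it into per-item choices
action = agent.sample_action(q_pi)
action_multi = agent.decode_multi_actions(action)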

In [7]:
print("best action", action)
print("best action multi", action_multi)
print("item weights")
print(item_weights)
print("item rewards")
print(rewards)

best action [[ 0  1  1  1  1 15]]
best action multi [[0 1 1 1 1]]
item weights
[8.211571  4.9118934 5.1358905 4.603628  4.9590673]
item rewards
[0.00729382 0.02089119 0.5814265  0.36183798 0.22303772]
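The printed action has one entry per item plus a trailing integer. Judging from the outputs in this notebook (15 for `[0 1 1 1 1]`, 7 for `[0 0 1 1 1]`, 22 for `[1 0 1 1 0]`), that trailing integer appears to be the flat index of the joint action driving `z`, in row-major order with `a_0` as the most significant bit. This is an inference from the outputs, not documented pymdp behavior. The correspondence can be reproduced with NumPy's `unravel_index` and `ravel_multi_index`:

```python
import numpy as np

num_items = 5
joint_shape = (2,) * num_items  # one binary control per item

for flat in (15, 7, 22):
    bits = np.unravel_index(flat, joint_shape)
    # round trip back to the flat index
    assert np.ravel_multi_index(bits, joint_shape) == flat
    print(flat, [int(b) for b in bits])
```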


### Verify that the agent's solution is optimal using brute-force search

In [8]:
def brute_force_knapsack(weights, rewards, capacity):
    # enumerate all 2^n include/exclude patterns and keep the best feasible one
    best_r = -jnp.inf
    best_bits = None
    for bits in itertools.product([0,1], repeat=len(weights)):
        w = jnp.dot(jnp.array(bits), weights)
        if w <= capacity + 1e-9:
            r = jnp.dot(jnp.array(bits), rewards)
            if r > best_r:
                best_r, best_bits = float(r), jnp.array(bits)
    return best_bits, best_r

bf_bits, bf_reward = brute_force_knapsack(item_weights, rewards, max_capacity)
print("Brute-force best bits:", bf_bits, "reward:", bf_reward)
print("Agent best bits:", action_multi[0], "reward:", float(jnp.dot(action_multi[0], rewards)))


Brute-force best bits: [0 1 1 1 1] reward: 1.1871933937072754
Agent best bits: [0 1 1 1 1] reward: 1.1871933937072754
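The loop above evaluates the 2^n subsets one at a time. A vectorized NumPy variant is sketched below; `brute_force_knapsack_vectorized` is a hypothetical helper and is still exponential in the number of items. It builds the full 0/1 selection matrix once, with rows in the same order as `itertools.product([0, 1], repeat=n)`, and masks out infeasible rows. The toy data in the usage line is illustrative only.

```python
import itertools
import numpy as np

def brute_force_knapsack_vectorized(weights, rewards, capacity):
    weights = np.asarray(weights, dtype=float)
    rewards = np.asarray(rewards, dtype=float)
    # all 2^n selections, one per row, in itertools.product order
    bits = np.array(list(itertools.product([0, 1], repeat=len(weights))))
    total_w = bits @ weights
    total_r = bits @ rewards
    # infeasible selections get -inf so argmax never picks them
    total_r = np.where(total_w <= capacity + 1e-9, total_r, -np.inf)
    best = int(np.argmax(total_r))
    return bits[best], float(total_r[best])

bits, reward = brute_force_knapsack_vectorized([2.0, 3.0, 4.0], [3.0, 4.0, 5.0], capacity=5.0)
print(bits, reward)
```

The empty selection is always feasible, so the returned reward is never `-inf` for a non-negative capacity.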


In [9]:
# print the top 6 policies, sorted by decreasing posterior probability (lowest EFE first)
from pymdp import utils
for idx in np.argsort(q_pi[0])[::-1][:6]:
    # the last entry of each policy indexes the joint action; decode it into per-item choices
    action_multi_f = utils.index_to_combination(agent.policies[idx, 0][-1].tolist(), agent.num_controls_multi)
    print("\naction", agent.policies[idx, 0])
    print("action multi", action_multi_f)
    print("efe: -{:.2f}, reward: {:.2f}".format(G[0, idx], np.sum(rewards * action_multi_f)))


action [ 0  1  1  1  1 15]
action multi [0 1 1 1 1]
efe: -3.79, reward: 1.19

action [0 0 1 1 1 7]
action multi [0 0 1 1 1]
efe: -3.78, reward: 1.17

action [ 0  1  1  1  0 14]
action multi [0 1 1 1 0]
efe: -3.68, reward: 0.96

action [ 1  0  1  1  0 22]
action multi [1 0 1 1 0]
efe: -3.67, reward: 0.95

action [0 0 1 1 0 6]
action multi [0 0 1 1 0]
efe: -3.67, reward: 0.94

action [ 0  1  1  0  1 13]
action multi [0 1 1 0 1]
efe: -3.61, reward: 0.83
