In [29]:
from pathlib import Path
import numpy as np
import json
import math

import torch
from src.problems.problems import compute_producer_optimal_solution, compute_consumer_optimal_solution

In [117]:
def compute_correlation(a, b):
    a = a.flatten()
    b = b.flatten()

    a_mean = np.mean(a)
    b_mean = np.mean(b)

    numerator = np.sum((a - a_mean) * (b - b_mean))
    denominator = np.sqrt(np.sum((a - a_mean) ** 2) * np.sum((b - b_mean) ** 2))

    return numerator / denominator

In [2]:
DATA_PATH_ROOT = Path("../../data")

In [3]:
# load data
with open(DATA_PATH_ROOT / "amazon_predictions.npy", "rb") as f:
    REL_MATRIX = np.load(f)

with open(DATA_PATH_ROOT / "amazon_user_groups.json", "r") as f:
    GROUPS_MAP = json.load(f)

In [545]:
rel_mat = REL_MATRIX[0:200, 0:200]

In [32]:
compute_producer_optimal_solution(
    rel_matrix=rel_mat,
    k_rec=10,
)

(9.999999999999998,
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], shape=(200, 200)))

In [33]:
_, mean_allocations = compute_consumer_optimal_solution(
    rel_matrix=rel_mat,
    k_rec=10,
    producer_max_min_utility=10,
    gamma=0.5,
    method="mean"
)

In [191]:
np.mean(np.sum(mean_allocations * rel_mat, axis=1)) / 10

np.float64(0.9466594812127049)

In [None]:
# hyper‐parameters
k_rec  = 10
u_min  = 10.0
gamma  = 0.5
prod_min = math.ceil(gamma * u_min)

# hyper‐params
tau0, tau_min, decay = 1.0, 0.01, 1e-3
lr, lr_dual          = 1e-2, 1e-2
λ = torch.zeros(n);  μ = torch.zeros(m)

for epoch in range(E):
    tau     = max(tau_min, tau0 * exp(-decay * epoch))
    Z       = z / tau
    A       = torch.softmax(Z, dim=1) * k_rec

    util    = torch.mean((A*R).sum(1) / k_rec)
    col_s   = A.sum(0)
    L_prod  = torch.mean(F.relu(prod_min - col_s)**2)
    L_bin   = torch.mean(A*(1 - A))

    # Augmented Lagrangian
    L_prim  = -util \
            + torch.dot(λ,  A.sum(1) - k_rec) \
            + torch.dot(μ,  F.relu(prod_min - col_s)) \
            + γ * L_bin

    opt.zero_grad()
    L_prim.backward()
    opt.step()

    # dual ascent
    with torch.no_grad():
        λ += lr_dual * (A.sum(1) - k_rec)
        μ += lr_dual * F.relu(prod_min - col_s)

    # optional STE rounding every so often
    if epoch > warmup and epoch % project_every == 0:
        A_hard = top_k_allocations(A.detach().cpu().numpy(), k_rec)
        z      = z + (torch.tensor(A_hard) - A).to(z.device)


In [212]:
import numpy as np
import torch
import torch.nn.functional as F
import math

# === hyper‐parameters ===
tau0         = 1.0        # initial softmax “temperature”
tau_min      = 0.01       # minimum temperature
decay        = 1e-3       # temperature decay rate per epoch
lr           = 1e-2       # primal (z) learning rate
lr_dual      = 1e-2       # dual (λ, μ) learning rate
num_epochs   = 1500       # total training epochs
warmup_frac  = 0.5        # fraction of epochs before we start STE projection
project_every= 100        # how often to project (straight‐through)
# ========================

def top_k_allocations(allocations: np.ndarray, k_rec: int) -> np.ndarray:
    """
    Given a continuous allocation matrix, returns the hard 0/1 top‐k per row.
    """
    idxs = allocations.argsort(axis=1)[:, -k_rec:]
    alls = np.zeros_like(allocations, dtype=np.int32)
    alls[np.arange(allocations.shape[0])[:, None], idxs] = 1
    return alls

def optim_augmented(
    rel_matrix: np.ndarray,
    k_rec: int,
    prod_min: float,
    gamma: float,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Augmented‐Lagrangian + row‐softmax relaxation for
    your Boolean LP:
      max mean_i ∑_j a_ij R_ij
      s.t. ∑_j a_ij = k_rec,   ∑_i a_ij ≥ prod_min,   a_ij ∈ {0,1}
    Returns:
      A_disc (n×m, int): final 0/1 allocations
      A_cont (n×m, float): final continuous allocations
    """
    n, m = rel_matrix.shape
    R = torch.tensor(rel_matrix, dtype=torch.float32)

    # optimization variable
    z = torch.zeros(n, m, requires_grad=True)
    opt = torch.optim.Adam([z], lr=lr)

    # dual multipliers for each constraint
    lambda_row = torch.ones(n)  # for ∑_j a_ij = k_rec
    mu_col     = torch.ones(m)  # for ∑_i a_ij ≥ prod_min

    warmup_epochs = int(warmup_frac * num_epochs)

    for epoch in range(1, num_epochs+1):
        # 1) temperature schedule
        tau = max(tau_min, tau0 * math.exp(-decay * epoch))

        # 2) row‐softmax to enforce ∑_j a_ij = k_rec
        Z = z / tau
        A = torch.softmax(Z, dim=1) * k_rec

        # 3) compute objectives
        util     = torch.mean(torch.sum(A * R, dim=1) / k_rec)
        col_sums = torch.sum(A, dim=0)
        # hinge‐squared for ∑_i a_ij ≥ prod_min
        L_prod   = torch.mean(F.relu(prod_min - col_sums)**2)
        # binarity penalty
        L_bin    = torch.mean(A * (1 - A))

        # 4) augmented Lagrangian (primal loss)
        row_res   = torch.sum(A, dim=1) - k_rec
        col_res   = F.relu(prod_min - col_sums)
        L_prim    = (
            -util
            + torch.dot(lambda_row,   row_res)
            + torch.dot(mu_col,       col_res)
            + gamma * L_bin
        )

        # 5) gradient step on z
        opt.zero_grad()
        L_prim.backward()
        opt.step()

        # 6) dual ascent
        with torch.no_grad():
            lambda_row += lr_dual * row_res
            mu_col     += lr_dual * col_res

        # 7) Straight‐through “hardening” every so often
        if epoch > warmup_epochs and epoch % project_every == 0:
            A_hard = top_k_allocations(A.detach().cpu().numpy(), k_rec)
            A_hard_t = torch.tensor(A_hard, dtype=torch.float32)
            with torch.no_grad():
                z += (A_hard_t - A)  # STE update

        # (optional) logging
        if epoch % 500 == 0 or epoch == 1:
            print(f"Epoch {epoch:4d} | util: {util.item():.4f} "
                  f"| L_prod: {L_prod.item():.4f} | L_bin: {L_bin.item():.4f}")

    # Final continuous allocation at lowest temperature
    Z_final = z / tau_min
    A_cont  = torch.softmax(Z_final, dim=1) * k_rec
    A_cont_np = A_cont.detach().cpu().numpy()
    A_disc    = top_k_allocations(A_cont_np, k_rec)

    return A_disc, A_cont_np



disc, cont = optim_augmented(rel_matrix=rel_mat,
                                 k_rec=10,
                                 prod_min=math.ceil(0.5 * 10.0),
                                 gamma=0.05)
print("Discrete allocations shape:", disc.shape)


Epoch    1 | util: 0.8922 | L_prod: 0.0000 | L_bin: 0.0475
Epoch  500 | util: 0.9102 | L_prod: 0.0000 | L_bin: 0.0466
Epoch 1000 | util: 0.9093 | L_prod: 0.0000 | L_bin: 0.0410
Epoch 1500 | util: 0.8990 | L_prod: 0.0000 | L_bin: 0.0196
Discrete allocations shape: (200, 200)


In [582]:
import torch
import torch.nn as nn
import torch.optim as optim

class TwoLayerSelector(nn.Module):
    def __init__(self, n_users, hidden_dim, n_items):
        super().__init__()
        # first “layer”: maps each user to a hidden representation
        self.z1 = nn.Parameter(torch.randn(n_users, hidden_dim))
        # second “layer”: maps hidden reps into item-scores
        self.z2 = nn.Parameter(torch.randn(hidden_dim, n_items))

    def forward(self):
        # pre-sigmoid logits in item-space:
        logits = self.z1 @ self.z2       # shape (n_users, n_items)
        return torch.sigmoid(logits)     # a ∈ (0,1)^{n×m}

R = torch.tensor(rel_mat)
n, m = R.shape    # for example
hidden_dim = 200       # you choose
model = TwoLayerSelector(n, hidden_dim, m)

opt = optim.Adam(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(opt, patience=100, factor=0.5, verbose=True)

# your constants…
lambda_util, lambda_card, lambda_prod = 10, 100, 1000
k_rec, decay = 10, 1
R = torch.rand(n, m)          # whatever rewards matrix
prod_min = torch.ones(m) * 5  # example

for epoch in range(1, 5001):
    a = model()  # now two-layer rather than single z

    # === exactly the same losses as before ===
    util   = (a * R).sum(dim=1).mean() / k_rec
    L_util = -util

    row_sums = a.sum(dim=1)
    L_card   = ((row_sums - k_rec)**2).mean()

    col_sums = a.sum(dim=0)
    L_prod   = torch.relu(prod_min - col_sums).pow(2).mean()

    L_bin    = (a * (1 - a)).mean()

    loss = (lambda_util * L_util
            + lambda_card * L_card
            + lambda_prod * L_prod
            + L_bin)

    opt.zero_grad()
    loss.backward()
    opt.step()
    scheduler.step(loss)

    if epoch % 500 == 0:
        print(f"Epoch {epoch:5d} — "
              f"loss: {loss.item():.4f}, "
              f"util: {util.item():.2f}, "
              f"card: {L_card.item():.4f}, "
              f"prod: {L_prod.item():.4f}, "
              f"bin: {L_bin.item():.4f}")


Epoch   500 — loss: 6914.6748, util: 0.58, card: 38.6275, prod: 3.0577, bin: 0.0001
Epoch  1000 — loss: 5828.4238, util: 0.58, card: 35.0871, prod: 2.3255, bin: 0.0002
Epoch  1500 — loss: 5520.0186, util: 0.58, card: 34.2992, prod: 2.0959, bin: 0.0001
Epoch  2000 — loss: 5167.5361, util: 0.58, card: 33.8881, prod: 1.7845, bin: 0.0001
Epoch  2500 — loss: 4904.4922, util: 0.58, card: 31.3538, prod: 1.7749, bin: 0.0001
Epoch  3000 — loss: 4719.0176, util: 0.58, card: 29.7777, prod: 1.7470, bin: 0.0001
Epoch  3500 — loss: 4655.2881, util: 0.58, card: 29.8909, prod: 1.6720, bin: 0.0001
Epoch  4000 — loss: 4400.0166, util: 0.58, card: 30.4515, prod: 1.3607, bin: 0.0001
Epoch  4500 — loss: 4322.3301, util: 0.58, card: 29.3516, prod: 1.3929, bin: 0.0001
Epoch  5000 — loss: 4288.4893, util: 0.57, card: 28.2400, prod: 1.4702, bin: 0.0002


In [589]:
np.sort(A_cont[0])[::-1]

array([1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 1.0000000e+00,
       1.0000000e+00, 1.0000000e+00, 1.0000000e+00, 9.9999940e-01,
       9.9999249e-01, 8.6736975e-08, 3.3401788e-11, 3.4550032e-13,
       5.5672348e-14, 3.9424988e-15, 8.4031952e-16, 2.7273479e-16,
       5.0404889e-17, 4.6134803e-18, 1.3354737e-18, 9.7296582e-19,
       5.0736260e-19, 1.7643549e-19, 7.1620997e-23, 5.5525855e-23,
       4.2024223e-23, 1.7201043e-23, 9.5035897e-25, 8.8954332e-25,
       6.8952616e-25, 4.6912843e-25, 2.3436868e-25, 1.7843234e-25,
       7.1527239e-26, 4.6985175e-26, 1.5371144e-26, 4.8078646e-27,
       3.1385659e-27, 1.2347193e-27, 1.1496957e-27, 6.8806948e-28,
       1.5683688e-28, 3.9537964e-29, 1.4174376e-29, 7.8123219e-30,
       5.1484843e-30, 4.5798904e-30, 3.7377410e-30, 2.1696022e-30,
       1.9388040e-30, 1.2963986e-30, 3.6371838e-31, 2.4597840e-31,
       7.2596333e-32, 3.5179557e-32, 3.3808114e-32, 1.0104752e-32,
       8.8565264e-33, 8.6458227e-33, 4.9887995e-33, 4.1283255e

In [583]:
A_cont = model.forward().detach().cpu().numpy()
A_cont.shape

(200, 200)

In [584]:
A_cont = model.forward().detach().cpu().numpy()

A_opt = np.zeros_like(A_cont)
for i, row in enumerate(A_cont):
    A_opt[i, row.argsort()[-k_rec:]] = 1

In [585]:
np.mean(np.sum(A_opt * rel_mat, axis=1) / k_rec)

np.float64(0.8951087220039388)

In [551]:

R = torch.tensor(rel_mat)
n, m = R.shape
z = torch.rand(n, m, requires_grad=True, device="cpu")
opt = torch.optim.Adam([z], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=100, factor=0.5, verbose=True)
prod_min = 5


lambda_util = 10
lambda_card = 100
lambda_prod = 10
k_rec = 10
decay = 1e-2

for epoch in range(1, 3001):
    a = torch.sigmoid(z)

    util   = torch.mean((a * R).sum(dim=1) / k_rec)
    L_util = -util

    row_sums = a.sum(dim=1)
    L_card   = torch.mean((row_sums - k_rec)**2)

    col_sums = a.sum(dim=0)
    L_prod   = torch.mean(torch.relu(prod_min - col_sums)**2)

    L_bin  = torch.mean(a * (1 - a))

    loss = lambda_util * L_util + lambda_card * L_card + lambda_prod * L_prod + L_bin

    opt.zero_grad()
    loss.backward()

    if epoch % 500 == 0:
        print(f"Epoch {epoch:5d} — loss: {loss.item():.4f}, util: {-L_util.item():.2f}, card: {L_card.item():.4f}, prod: {L_prod.item():.4f}, bin: {L_bin.item():.4f}")

    opt.step()
    scheduler.step(loss.item())

A_cont = torch.sigmoid(z).detach().cpu().numpy()
# projection to binary
A_opt = np.zeros_like(A_cont)
for i, row in enumerate(A_cont):
    A_opt[i, row.argsort()[-k_rec:]] = 1




Epoch   500 — loss: -8.8511, util: 0.89, card: 0.0004, prod: 0.0000, bin: 0.0476
Epoch  1000 — loss: -8.8768, util: 0.89, card: 0.0000, prod: 0.0000, bin: 0.0475
Epoch  1500 — loss: -8.8769, util: 0.89, card: 0.0000, prod: 0.0000, bin: 0.0475
Epoch  2000 — loss: -8.8771, util: 0.89, card: 0.0000, prod: 0.0000, bin: 0.0475
Epoch  2500 — loss: -8.8773, util: 0.89, card: 0.0000, prod: 0.0000, bin: 0.0475
Epoch  3000 — loss: -8.8776, util: 0.89, card: 0.0000, prod: 0.0000, bin: 0.0475


In [553]:
np.sort(A_cont[0])[::-1]

array([0.05779154, 0.05771403, 0.05765368, 0.05763114, 0.05748032,
       0.05746737, 0.05745109, 0.0573871 , 0.05737695, 0.05736849,
       0.05735333, 0.0573409 , 0.05724973, 0.05698105, 0.0569311 ,
       0.05686256, 0.05677284, 0.05677127, 0.05671323, 0.0565905 ,
       0.05654738, 0.05640062, 0.0563875 , 0.05636477, 0.05631623,
       0.05626244, 0.0561929 , 0.05607791, 0.0560635 , 0.05598119,
       0.055796  , 0.05579099, 0.05566983, 0.05551579, 0.05549545,
       0.05544922, 0.05526333, 0.05515146, 0.05514026, 0.05512378,
       0.0551208 , 0.05511221, 0.05501353, 0.05500968, 0.0549847 ,
       0.0549709 , 0.05496861, 0.05494679, 0.05487379, 0.05481775,
       0.05479058, 0.05460422, 0.05456612, 0.05438355, 0.05427985,
       0.0542602 , 0.05414438, 0.05396806, 0.0538949 , 0.05344047,
       0.05329363, 0.05327445, 0.05327269, 0.05324512, 0.05314672,
       0.05307278, 0.0529814 , 0.05295432, 0.05279669, 0.05271975,
       0.05261503, 0.05256554, 0.05239647, 0.05221638, 0.05215

In [520]:
np.mean(np.sum(A_opt * rel_mat, axis=1) / k_rec)

np.float64(0.9046699930702641)

In [524]:
compute_correlation(mean_allocations, A_opt)

np.float64(0.023157894478032476)

In [522]:
A_opt.sum(axis=0)

array([ 9.,  9.,  4.,  8., 13., 13., 10.,  6.,  6., 12., 11., 11., 21.,
        4.,  7.,  8.,  6.,  8., 16.,  5., 15., 13., 20.,  9., 12., 10.,
        6.,  2., 11.,  9., 14., 17., 21., 21., 10.,  8., 11.,  5.,  9.,
       12.,  9.,  9., 14., 12., 13.,  7.,  9.,  3.,  8.,  6.,  9., 14.,
        8., 12.,  9., 15.,  7., 19., 10.,  5., 14.,  9., 10.,  9., 12.,
       15.,  9., 11., 10.,  9., 13., 10.,  6., 10.,  7., 10.,  6., 10.,
       14.,  6.,  9., 12., 13.,  4.,  9.,  7.,  6.,  9.,  9.,  8., 14.,
        9.,  9., 10.,  5., 16., 11.,  8., 10.,  7.,  9., 18., 12.,  4.,
       11., 10., 12.,  9.,  4., 18., 14., 18., 17., 16.,  8.,  8., 15.,
        7., 11.,  8., 10., 12.,  7.,  6.,  8., 12.,  3., 10.,  4.,  9.,
        9., 11., 10., 14.,  4.,  9., 10., 12., 10., 10.,  6., 13.,  9.,
       11.,  3.,  7.,  3.,  9.,  8., 14.,  6., 14.,  9.,  9., 10., 17.,
       18., 11.,  5.,  9.,  9., 10., 11., 12., 16., 14., 14.,  1.,  5.,
       13., 12.,  5., 11., 10.,  8., 12., 16., 13.,  6., 13.,  7

In [386]:
n, m = rel_mat.shape  # ensure rel_matrix is defined
R = torch.tensor(rel_mat, dtype=torch.float32, device="cpu")

z = (torch.randn(n, m, device="cpu") * 0.01 + 0.1).requires_grad_(True)
log_lambda_prod = torch.zeros(m, device="cpu", requires_grad=False)

opt = torch.optim.Adam([z], lr=0.05)


# STAGE 1: Allow the model to find a meaningful initial solution (no lambdas initially)
for epoch in range(1, 1001):
    tau = max(0.05, 0.995 ** epoch)
    a = torch.sigmoid(z / tau)

    util = torch.mean((a * R).sum(dim=1) / k_rec)
    L_util = -util

    row_sums = a.sum(dim=1)
    L_card = ((row_sums - k_rec)**2).mean()

    col_sums = a.sum(dim=0)
    prod_shortfall = torch.relu(prod_min - col_sums)
    L_prod = (prod_shortfall**2).mean()  # soft quadratic penalty, no lambda yet

    loss = L_util + 5.0 * L_card + 5.0 * L_prod  # balanced weights, moderate penalties

    opt.zero_grad()
    loss.backward()
    opt.step()

    if epoch % 500 == 0:
        violated_count = (col_sums.detach().cpu().numpy() < prod_min).sum()
        print(f"[Stage 1] Epoch {epoch:4d} | Loss: {loss.item():.3f} | "
              f"Util: {-L_util.item():.3f} | Violations: {violated_count}")

# STAGE 2: Now introduce controlled dual updates to enforce constraints strictly
for epoch in range(1001, 4001):
    tau = max(0.01, 0.995 ** epoch)
    a = torch.sigmoid(z / tau)

    util = torch.mean((a * R).sum(dim=1) / k_rec)
    L_util = -util

    row_sums = a.sum(dim=1)
    L_card = ((row_sums - k_rec)**2).mean()

    col_sums = a.sum(dim=0)
    prod_shortfall = torch.relu(prod_min - col_sums)

    lambda_prod = torch.exp(log_lambda_prod)
    L_prod = (lambda_prod * (prod_shortfall**2)).mean()

    loss = L_util + 10.0 * L_card + L_prod  # slightly stronger constraints now

    opt.zero_grad()
    loss.backward()
    opt.step()

    with torch.no_grad():
        log_lambda_prod += 0.01 * prod_shortfall  # controlled dual updates
        log_lambda_prod.clamp_(min=-2, max=5)

    if epoch % 500 == 0:
        violated_count = (col_sums.detach().cpu().numpy() < prod_min).sum()
        avg_shortfall = prod_shortfall.mean().item()
        print(f"[Stage 2] Epoch {epoch:4d} | Loss: {loss.item():.3f} | "
              f"Util: {-L_util.item():.3f} | Violations: {violated_count} | Avg Shortfall: {avg_shortfall:.3f}")

# Final binary projection
A_cont = torch.sigmoid(z).detach().cpu().numpy()
A_opt = np.zeros_like(A_cont)
for i, row in enumerate(A_cont):
    A_opt[i, row.argsort()[-k_rec:]] = 1

producer_counts = A_opt.sum(axis=0)
violations = (producer_counts < prod_min).sum()
print(f"Final violations after projection: {violations}")


[Stage 1] Epoch  500 | Loss: 625.000 | Util: 0.000 | Violations: 200
[Stage 1] Epoch 1000 | Loss: 625.000 | Util: 0.000 | Violations: 200
[Stage 2] Epoch 1500 | Loss: 4710.328 | Util: 0.000 | Violations: 200 | Avg Shortfall: 5.000
[Stage 2] Epoch 2000 | Loss: 4710.328 | Util: 0.000 | Violations: 200 | Avg Shortfall: 5.000
[Stage 2] Epoch 2500 | Loss: 4710.328 | Util: 0.000 | Violations: 200 | Avg Shortfall: 5.000
[Stage 2] Epoch 3000 | Loss: 4710.328 | Util: 0.000 | Violations: 200 | Avg Shortfall: 5.000
[Stage 2] Epoch 3500 | Loss: 4710.328 | Util: 0.000 | Violations: 200 | Avg Shortfall: 5.000
[Stage 2] Epoch 4000 | Loss: 4710.328 | Util: 0.000 | Violations: 200 | Avg Shortfall: 5.000
Final violations after projection: 5


In [387]:
np.mean(np.sum(A_opt * rel_mat, axis=1) / k_rec)

np.float64(0.8925312649793051)

In [388]:
A_opt.sum(axis=0)

array([ 7.,  5., 12.,  9., 13.,  8., 12.,  7., 11.,  6.,  9., 16., 13.,
        9.,  8.,  4., 12.,  8.,  9., 12.,  9., 10., 11., 13., 12.,  7.,
       11., 10., 10.,  9., 12., 11.,  9., 11., 10., 11., 11., 13.,  7.,
        5.,  7.,  8., 14.,  6.,  3., 16., 11., 12.,  8., 13., 10., 13.,
        7.,  8.,  6.,  6., 15., 11., 13., 11., 13., 13.,  4.,  9., 16.,
       14., 13., 11.,  8., 12., 10., 11., 11., 10.,  9.,  7., 18., 10.,
        9., 10., 15., 13.,  8., 11.,  9., 11.,  7., 12., 16., 11., 11.,
       10., 13.,  7., 10., 14., 12.,  8., 10., 10.,  6., 14., 15.,  6.,
       15., 17.,  6., 18., 12.,  9., 13., 13., 11., 10., 10., 15.,  8.,
       10., 11.,  8., 14.,  9., 11.,  7.,  9., 13., 11., 11.,  5.,  6.,
        7.,  6.,  8.,  8.,  9.,  9., 10.,  4., 11., 13., 11.,  6.,  9.,
        6.,  6.,  9.,  7.,  9.,  7., 10.,  6., 17., 11.,  3.,  7., 19.,
        7., 10.,  9., 11., 10.,  6.,  9., 11.,  6.,  9., 11., 13.,  9.,
        8.,  8.,  9.,  5., 11., 11., 15.,  7.,  9., 12., 12., 10

In [281]:
np.mean(np.sum(mean_allocations * rel_mat, axis=1) / k_rec)

np.float64(0.9466594812127047)

In [276]:

# for each rel_mat row, take top-k based on the allocation
sorted_indices = np.argsort(mean_allocations, axis=1)[::-1][:, :k_rec]
print(np.mean(rel_mat[np.arange(rel_mat.shape[0])[:, None], sorted_indices], axis=1))
print(np.sum(mean_allocations, axis=0))
print(np.sum(mean_allocations, axis=1))


[0.88454426 0.82073556 0.92958885 0.95041268 0.90130418 0.95360957
 0.90437249 0.90215369 0.84169271 0.81099343 0.74843775 0.96578593
 0.94725356 0.95360391 0.89828264 0.8910981  0.78405916 0.89805986
 0.94508075 0.97740006 0.79811105 0.87104322 0.98142324 0.85312095
 0.8231652  0.91341474 0.89999115 0.8845513  0.9452226  0.85785696
 0.91790804 0.97983441 0.91876659 0.89139027 0.89914497 0.8987039
 0.81594811 0.89713778 0.82567525 0.91429423 0.9268033  0.91292709
 0.90771376 0.97021217 0.92682126 0.94819362 0.8867817  0.87495038
 0.93166013 0.7806246  0.73209766 0.78123458 0.92761685 0.86109238
 0.96749913 0.79566761 0.89570176 0.93509685 0.9707759  0.72826213
 0.84841753 0.72617826 0.92275587 0.95187544 0.91133795 0.87837534
 0.89899835 0.80371097 0.71420634 0.93934083 0.80872552 0.78192086
 0.92679767 0.93672771 0.96834136 0.74314832 0.73710272 0.84708724
 0.88626303 0.84994237 0.88079671 0.79758221 0.81891578 0.87057058
 0.97260963 0.97045069 0.7622217  0.96137044 0.73829369 0.96344

In [182]:
# compute correlation between the two allocations


corr = compute_correlation(new_alls, mean_allocations)
corr

np.float64(0.3194736703774324)

In [482]:
ws = torch.randn(n, m, device="cpu")
gumbel = F.gumbel_softmax(ws, tau=tau, hard=True, dim=-1)

gumbel[0]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])