In [2]:
from pathlib import Path
import numpy as np
import json
import cvxpy as cp

import torch
import torch
from src.problems.utils import sample_data_for_group
from src.problems.problems import compute_producer_optimal_solution, _compute_consumer_optimal_solution_cvar
from src.problems.gradient_problem import compute_consumer_optimal_solution_cvar_grad
from src.problems.problems import (
    _compute_consumer_optimal_solution_cvar,
    _compute_consumer_optimal_solution_mean,
    _compute_consumer_optimal_solution_min
)

In [3]:
DATA_PATH_ROOT = Path("../../data")

In [4]:
# load data
with open(DATA_PATH_ROOT / "amazon_predictions.npy", "rb") as f:
    REL_MATRIX = np.load(f)

with open(DATA_PATH_ROOT / "amazon_user_groups.json", "r") as f:
    GROUPS_MAP = json.load(f)

In [248]:
N_CONSUMERS = 500
N_PRODUCERS = 500

In [6]:
sampled_matrix

NameError: name 'sampled_matrix' is not defined

In [249]:
sampled_matrix, consumer_ids, group_assignments = sample_data_for_group(
    n_consumers=N_CONSUMERS,
    n_producers=N_PRODUCERS,
    groups_map=GROUPS_MAP,
    group_key="top_category",
    data=REL_MATRIX,
    seed=3
)

In [None]:
mean_allocations = _compute_consumer_optimal_solution_mean(
    rel_matrix=sampled_matrix,
    k_rec=10,
    producer_max_min_utility=10,
    gamma=0.5,
    solver=cp.GUROBI
)

In [113]:
import numpy as np
import math

def async_augmented_lagrangian(
    r: np.ndarray,
    k: int,
    U: float,
    gamma: float,
    rho0: float = 1.0,
    rho_mult: float = 1.05,
    max_epochs: int = 200,
    tol: float = 1e-3,
    jitter: float = 1e-3
):
    """
    Asynchronous row-by-row Augmented Lagrangian for the
    max-mean utility problem with (i) exactly k picks per
    consumer and (ii) >= h picks per producer.

    Returns final A, mean utility, and final beta.
    """
    n, m = r.shape
    h = math.ceil(gamma * U)

    # initialize
    A    = np.zeros((n, m), dtype=float)
    beta = np.zeros(m)
    rho  = rho0
    cov  = A.sum(axis=0)

    for epoch in range(max_epochs):
        any_change = False

        for i in range(n):
            # build this consumer's score
            # note: cov and beta include all previous updates
            score = (r[i]/n) + beta + rho * np.maximum(0, h - cov)
            score += np.random.uniform(0, jitter, size=m)

            # pick top-k for row i
            new_row = np.zeros(m, dtype=float)
            topk_idxs = np.argpartition(-score, k-1)[:k]
            new_row[topk_idxs] = 1.0

            # if this row really changed, commit it
            if not np.array_equal(new_row, A[i]):
                # remove old coverage, add new
                cov -= A[i]
                cov += new_row
                A[i] = new_row
                any_change = True

                # immediate dual update on all producers
                # g_j = h - cov_j  (positive if still under-covered)
                g = h - cov
                beta = np.maximum(0.0, beta + rho * g)

                # ramp penalty a bit every row
                rho *= rho_mult ** (1.0/n)

        # check stopping: are all cov >= h?
        min_cov = cov.min()
        if min_cov >= h - tol:
            break
        # if completely no row changed, bump rho heavier to force movement
        if not any_change:
            rho *= 2

    mean_util = (A * r).sum(axis=1).mean()
    return A, mean_util, beta

r = async_augmented_lagrangian(
    r=sampled_matrix,
    k=10,
    U=3,
    gamma=1,
    rho0=1.0,
    rho_mult=2.0,
    max_epochs=200,
    tol=1e-3,
    jitter=1e-6
)

In [111]:
(mean_allocations[1] * sampled_matrix).sum(axis=1).mean()

np.float64(9.528643500775212)

In [87]:
mean_allocations[1].sum(axis=0)

array([  3.,  13.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,
         3.,   3.,   3.,   3.,   3.,   3.,  14.,   3.,   3.,   3.,   3.,
         8.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,  23.,   3.,
         3.,  93.,   3.,   3.,  16.,   3.,   3.,   3.,   3.,   3.,   3.,
         3., 345.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,
         3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,
         3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,
         3.,   3.,  46., 255.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,
       231.,   3.,   3.,   3.,  23.,   3.,   3.,   3.,   3.,   3.,   3.,
       187.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,
        11.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,
         3.,  31.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,   3.,
         3.,   3.,   3.,   3.,   3.,   3.,   3.,   3., 423.,   3.,   3.,
         3.,   3.,   3.,   3.,   3.,   3.,  68.,   

In [114]:
r[0].sum(axis=0)  # sum of all producers

array([  6.,  11.,   6.,   6.,   6.,   5.,   6.,   6.,   5.,   6.,   5.,
         6.,   5.,   6.,   6.,   6.,   6.,  18.,   6.,   6.,   5.,   6.,
         7.,   6.,   5.,   6.,   6.,   6.,   6.,   6.,   5.,  18.,   6.,
         5.,  50.,   5.,   6.,  19.,   6.,   6.,   6.,   6.,   6.,   5.,
         6., 194.,   5.,   6.,   5.,   5.,   5.,   6.,   6.,   5.,   5.,
         6.,   5.,   6.,   6.,   6.,   6.,   6.,   6.,   6.,   6.,   6.,
         6.,   6.,   5.,   6.,   5.,   6.,   6.,   6.,   6.,   6.,   6.,
         6.,   6.,  39., 169.,   5.,   6.,   6.,   6.,   5.,   6.,   6.,
       150.,   6.,   6.,   6.,  21.,   5.,   6.,   5.,   6.,   6.,   6.,
       130.,   6.,   5.,   6.,   5.,   6.,   6.,   6.,   5.,   6.,   6.,
        12.,   6.,   6.,   6.,   6.,   6.,   6.,   6.,   5.,   6.,   6.,
         6.,  24.,   6.,   6.,   6.,   6.,   6.,   6.,   6.,   6.,   6.,
         6.,   6.,   6.,   6.,   6.,   6.,   6.,   6., 222.,   6.,   6.,
         6.,   6.,   6.,   5.,   6.,   6.,  54.,   

In [104]:
import numpy as np
import math

def async_aug_lagrange_tame(
    r: np.ndarray,
    k: int,
    U: float,
    gamma: float,
    rho: float = 0.5,       # small, fixed penalty
    alpha: float = 0.75,    # exponent for diminishing step size
    max_epochs: int = 200,
    tol: float = 1e-3,
    jitter: float = 1e-4
):
    """
    Asynchronous AL with small, fixed rho and diminishing dual steps.
    """
    n, m = r.shape
    h      = math.ceil(gamma * U)
    A      = np.zeros((n, m), dtype=float)
    beta   = np.zeros(m)
    coverage = A.sum(axis=0)

    for epoch in range(1, max_epochs + 1):
        any_change = False

        # dual step size
        eta = rho / (epoch ** alpha)

        for i in range(n):
            # violation vector (positive if under)
            viol = np.maximum(0, h - coverage)

            # per-entry score
            score = (r[i] / n) + beta + rho * viol
            score += np.random.rand(m) * jitter

            # pick top-k
            new_row = np.zeros(m)
            topk   = np.argpartition(-score, k-1)[:k]
            new_row[topk] = 1.0

            if not np.array_equal(new_row, A[i]):
                # commit change
                coverage -= A[i]
                coverage += new_row
                A[i] = new_row
                any_change = True

                # dual update (projected subgradient ascent)
                # g_j = h - coverage_j
                g = h - coverage
                beta = np.maximum(0.0, beta + eta * g)

        # stop if all columns ≥ h
        if coverage.min() >= h - tol:
            break

        # if nothing moved, you might bump rho slightly or just continue

    mean_util = (A * r).sum(axis=1).mean()
    return A, mean_util, beta, coverage

r = async_aug_lagrange_tame(
    r=sampled_matrix,
    k=10,
    U=5,
    gamma=0.5,
    rho=0.5,
    alpha=0.75,
    max_epochs=200,
    tol=1e-3,
    jitter=1e-4
)

In [108]:
r[0].sum(axis=0).min()  # sum of all producers

np.float64(5.0)

In [None]:
import numpy as np
from scipy.optimize import minimize

def optimize_allocations_with_grad(rel_matrix, k_rec, producer_max_min_utility, gamma):
    N, M = rel_matrix.shape
    x0 = np.full(N * M, 1 / M)

    # Objective (negative utility since we minimize)
    def objective(x):
        x_mat = x.reshape(N, M)
        return -np.mean(np.sum(x_mat * rel_matrix, axis=1))

    # Gradient of objective
    def grad_objective(x):
        grad = - (rel_matrix / N).flatten()
        return grad

    # Equality constraints (consumer-level)
    cons_eq = [{
        'type': 'eq',
        'fun': lambda x, i=i: np.sum(x.reshape(N, M)[i]) - k_rec,
        'jac': lambda x, i=i: np.array(
            [(1 if (idx // M) == i else 0) for idx in range(N*M)]
        )
    } for i in range(N)]

    # Inequality constraints (producer-level)
    cons_ineq = [{
        'type': 'ineq',
        'fun': lambda x, j=j: np.sum(x.reshape(N, M)[:, j]) - gamma * producer_max_min_utility,
        'jac': lambda x, j=j: np.array(
            [(1 if (idx % M) == j else 0) for idx in range(N*M)]
        )
    } for j in range(M)]

    constraints = cons_eq + cons_ineq

    # Bounds
    bounds = [(0, 1) for _ in range(N*M)]

    # Solve with SLSQP
    result = minimize(
        objective, x0, method='SLSQP', jac=grad_objective,
        constraints=constraints, bounds=bounds,
        options={'ftol': 1e-9, 'disp': True, 'maxiter': 500}
    )

    if not result.success:
        raise ValueError("Optimization failed:", result.message)

    optimal_allocations = result.x.reshape(N, M)
    optimal_value = -result.fun  # revert sign back

    return optimal_value, optimal_allocations

# Example usage:
rel_matrix = sampled_matrix[:100, :100]  # Use a smaller matrix for testing
k_rec = 2
producer_max_min_utility = 1.0
gamma = 0.5

optimal_value, allocations = optimize_allocations_with_grad(
    rel_matrix, k_rec, producer_max_min_utility, gamma
)

print("Optimal Mean Utility:", optimal_value)
print("Optimal Allocations:\n", allocations.round())


In [31]:
import numpy as np


import numpy as np


def solve_with_dual_ascent(
    rel_matrix: np.ndarray,
    k_rec: int,
    producer_max_min_utility: float,
    gamma: float,
    max_iters: int = 1000,
    step_size: float = 0.1,
    tolerance: float = 1e-4,
):
    n_consumers, n_producers = rel_matrix.shape

    lambda_consumer = np.zeros(n_consumers)
    mu_producer = np.zeros(n_producers)

    for iteration in range(max_iters):
        prev_lambda = lambda_consumer.copy()
        prev_mu = mu_producer.copy()

        allocations = (rel_matrix / n_consumers) + mu_producer - lambda_consumer[:, None]
        allocations = np.clip(allocations, 0, 1)

        lambda_gradient = allocations.sum(axis=1) - k_rec
        mu_gradient = gamma * producer_max_min_utility - allocations.sum(axis=0)

        lambda_consumer += step_size * lambda_gradient
        mu_producer += step_size * mu_gradient
        mu_producer = np.maximum(0, mu_producer)

        dual_change = np.linalg.norm(lambda_consumer - prev_lambda) + np.linalg.norm(mu_producer - prev_mu)
        if dual_change < tolerance:
            break

    allocations = (rel_matrix / n_consumers) + mu_producer - lambda_consumer[:, None]
    allocations = np.clip(allocations, 0, 1)

    for i in range(n_consumers):
        top_indices = np.argsort(allocations[i])[::-1][:k_rec]
        allocations[i, :] = 0
        allocations[i, top_indices] = 1

    col_sums = allocations.sum(axis=0)
    required_allocations = gamma * producer_max_min_utility

    # Robust producer constraint enforcement
    while np.any(col_sums < required_allocations):
        deficit_producers = np.where(col_sums < required_allocations)[0]
        for producer in deficit_producers:
            deficit = int(required_allocations - col_sums[producer])
            potential_consumers = np.argsort(-rel_matrix[:, producer])
            allocated = 0
            for consumer in potential_consumers:
                if allocations[consumer].sum() < k_rec and allocations[consumer, producer] == 0:
                    allocations[consumer, producer] = 1
                    allocated += 1
                    if allocated >= deficit:
                        break
            col_sums = allocations.sum(axis=0)

        # If stuck, free allocations from overloaded producers
        if np.any(col_sums > required_allocations):
            overloaded_producers = np.where(col_sums > required_allocations)[0]
            for producer in overloaded_producers:
                excess = int(col_sums[producer] - required_allocations)
                consumers_assigned = np.where(allocations[:, producer] == 1)[0]
                for consumer in consumers_assigned:
                    if allocations[consumer].sum() > k_rec:
                        allocations[consumer, producer] = 0
                        excess -= 1
                        if excess <= 0:
                            break
            col_sums = allocations.sum(axis=0)

    mean_utility = np.mean((allocations * rel_matrix).sum(axis=1))

    return mean_utility, allocations



mean_utility, optimal_allocations = solve_with_dual_ascent(
    sampled_matrix, 10, 10, 0.5, tolerance=1e-20
)




KeyboardInterrupt: 

In [26]:
mean_allocations = _compute_consumer_optimal_solution_mean(
    rel_matrix=sampled_matrix,
    k_rec=10,
    producer_max_min_utility=10,
    gamma=0.5,
    solver=cp.GUROBI
)

In [27]:
mean_allocations[1].sum(axis=0)  # sum of all producers

array([ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5., 21., 51., 22.,  5.,
        5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5., 76.,
        5., 85.,  5.,  5.,  5.,  5.,  5.,  5., 76.,  5.,  5.,  5.,  5.,
        5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5., 12.,  5.,  5., 23.,
       40., 26.,  5.,  5.,  5., 82.,  5.,  5.,  5.,  5.,  7.,  5.,  5.,
        5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,  5., 24.,  5.,  5.,  6.,
        5.,  5., 11.,  5.,  5.,  6.,  5.,  5.,  5.,  5.,  5.,  5.,  5.,
        5.,  5.,  5., 13.,  9.,  5.,  5.,  5.,  5.])

In [30]:
optimal_allocations.sum(axis=0)  # sum of all producers

array([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   9.,  19.,
         6.,   0.,   1.,   0.,   0.,   0.,   0.,   2.,   0.,   0.,   0.,
         0.,   4.,   0.,  22.,  78., 100.,  78.,  78.,  78.,  78.,  78.,
        78.,  22.,   0.,  79.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
         0.,   0.,   1.,   0.,   6.,   0.,   0.,  11.,  16.,  10.,   0.,
         0.,   0.,  22.,   0.,   0.,   0.,   0.,   6.,   0.,   0.,   0.,
         0.,   0.,   0.,   0.,   1.,   0.,   0.,   0.,   9.,   0.,   0.,
         7.,   0.,   0.,   2.,   0.,   0.,   6.,   0.,   0.,   0.,   0.,
         1.,   0.,   0.,   0.,   1.,   0.,   4.,   9.,   0.,   0.,   0.,
        78.])

In [167]:
import torch
import torch.nn as nn
import torch.optim as optim

def compute_consumer_optimal_with_lagrangian(
    rel_matrix: torch.Tensor,
    k_rec: int,
    producer_max_min_utility: float,
    gamma: float,
    lr_primal: float = 1e-2,
    lr_dual: float = 1e-2,
    max_epochs: int = 5000,
    verbose: bool = False
) -> torch.Tensor:
    """
    Primal-dual via explicit Lagrangian multipliers for mean consumer utility maximization.

    Args:
        rel_matrix: (n, m) tensor of relevances.
        k_rec: number of producers per consumer.
        producer_max_min_utility: U_max reference.
        gamma: fraction of U_max for floor constraint.
        lr_primal: learning rate for primal (allocations) optimizer.
        lr_dual: learning rate for dual (multipliers) optimizer.
        max_epochs: maximum training epochs.
        verbose: print diagnostics every 500 epochs.

    Returns:
        allocations: (n, m) tensor in [0,1] approx binary selections.
    """
    n, m = rel_matrix.shape

    # Primal: allocations logits
    logits = nn.Parameter(torch.zeros(n, m))

    # Dual multipliers: alpha for equality (size n), beta for inequality (size m)
    alpha = nn.Parameter(torch.zeros(n))       # unconstrained
    beta = nn.Parameter(torch.zeros(m))        # will clamp to >=0

    # Optimizers
    opt_primal = optim.Adam([logits], lr=lr_primal)
    opt_dual = optim.Adam([alpha, beta], lr=lr_dual)

    # Precompute threshold for producer constraint
    min_prod = gamma * producer_max_min_utility

    for epoch in range(1, max_epochs+1):
        # Compute allocations via sigmoid
        A = torch.sigmoid(logits)

        # Compute objective (mean consumer utility)
        util = (A * rel_matrix).sum(dim=1).mean()

        # Constraints:
        #   per-consumer: sum_j A_ij == k_rec  =>  ci = sum - k_rec
        ci = A.sum(dim=1) - k_rec
        #   per-producer: sum_i A_ij >= min_prod => gj = min_prod - sum
        gj = min_prod - A.sum(dim=0)

        # Lagrangian: maximize util subject to constraints
        # We minimize L = -util + alpha^T ci + beta^T gj
        L = -util + (alpha * ci).mean() + (beta * gj).mean()

        # ---- Primal update (minimize L wrt logits) ----
        opt_primal.zero_grad()
        L.backward(retain_graph=True)
        opt_primal.step()

        # ---- Dual update (maximize L wrt alpha,beta) ----
        # we perform gradient ascent by minimizing -L
        opt_dual.zero_grad()
        dual_loss = -L
        dual_loss.backward()
        opt_dual.step()

        # Enforce beta >= 0
        with torch.no_grad():
            beta.clamp_(min=0.0)

        if verbose and epoch % 500 == 0:
            with torch.no_grad():
                min_c = ci.abs().max().item()
                min_p = A.sum(dim=0).min().item()
                print(f"Epoch {epoch}: util={util.item():.4f}, max|ci|={min_c:.4f}, min_prod={min_p:.4f}")

    # Final allocations
    return torch.sigmoid(logits).detach()

allocations = compute_consumer_optimal_with_lagrangian(
    rel_matrix=torch.tensor(sampled_matrix),
    k_rec=10,
    producer_max_min_utility=10,
    gamma=0.5,
    lr_primal=1e-2,
    lr_dual=1e-2,
    max_epochs=100000,
    verbose=True
)



Epoch 500: util=3.9690, max|ci|=6.3529, min_prod=3.9996
Epoch 1000: util=2.3538, max|ci|=7.5739, min_prod=2.3680
Epoch 1500: util=12.6000, max|ci|=19.3618, min_prod=11.5018
Epoch 2000: util=2.5822, max|ci|=8.4394, min_prod=2.5751
Epoch 2500: util=6.1808, max|ci|=39.9784, min_prod=4.0663
Epoch 3000: util=4.2114, max|ci|=7.8515, min_prod=3.7652
Epoch 3500: util=4.5326, max|ci|=35.7553, min_prod=3.0172
Epoch 4000: util=4.0862, max|ci|=53.5875, min_prod=3.9011
Epoch 4500: util=5.0665, max|ci|=38.4551, min_prod=4.2985
Epoch 5000: util=2.6227, max|ci|=48.8652, min_prod=2.5539
Epoch 5500: util=11.5624, max|ci|=40.4783, min_prod=6.0378
Epoch 6000: util=2.3059, max|ci|=26.7210, min_prod=1.7196
Epoch 6500: util=17.2119, max|ci|=50.6612, min_prod=6.9722
Epoch 7000: util=2.5568, max|ci|=44.8316, min_prod=1.9707
Epoch 7500: util=7.5786, max|ci|=44.9026, min_prod=4.7659
Epoch 8000: util=5.4476, max|ci|=38.7051, min_prod=2.8088
Epoch 8500: util=5.6993, max|ci|=73.5864, min_prod=4.1271
Epoch 9000: uti

KeyboardInterrupt: 

In [251]:
def cvar_util(
    rel_matrix: torch.Tensor,
    allocations: torch.Tensor,
    group_assignments: torch.Tensor,
    k_rec: int,
    rho: torch.Tensor,
    alpha: float,
) -> torch.Tensor:
    """
    Vectorized, fully-differentiable CVaR‐style objective over groups.
    """
    # 1) Greedy top‐k sum per example
    greedy_allocs = rel_matrix.topk(k_rec, dim=1).values.sum(dim=1)  # (N,)

    eps = 1e-8
    loss_per_item = 1 - allocations / (greedy_allocs + eps)  # (N,)

    # 3) Remap group IDs to 0…G−1
    unique_groups, inverse = torch.unique(group_assignments, return_inverse=True)
    G = unique_groups.numel()

    # 4) Sum losses and counts per group via scatter_add_
    device = rel_matrix.device
    dtype = loss_per_item.dtype
    sum_losses = torch.zeros(G, device=device, dtype=dtype).scatter_add_(0, inverse, loss_per_item)
    counts = torch.zeros(G, device=device, dtype=dtype).scatter_add_(
        0, inverse, torch.ones_like(loss_per_item)
    )
    norm_losses = sum_losses / (counts + eps)

    # 5) CVaR objective
    rho_clamped = torch.clamp(rho, min=0.0)
    excess = torch.relu(norm_losses - rho_clamped).sum()
    cvar_obj = rho_clamped + excess / ((1 - alpha) * G)

    return cvar_obj

In [168]:
import torch
import torch.nn as nn
import torch.optim as optim

def compute_consumer_optimal_with_augmented_lagrangian(
    rel_matrix: torch.Tensor,
    k_rec: int,
    producer_max_min_utility: float,
    gamma: float,
    lr_primal: float = 1e-2,
    lr_dual: float = 1e-2,
    rho: float = 10.0,
    max_epochs: int = 5000,
    verbose: bool = False
) -> torch.Tensor:
    """
    Augmented Lagrangian primal-dual solver:
      maximize mean utility under sum_j A[i,j]=k_rec and sum_i A[i,j]>=gamma*U constraints.

    Args:
        rel_matrix: (n, m) tensor of relevances.
        allocations: (n, m) tensor for storing allocation results.
        group_assignments: (n,) tensor for consumer group assignments.
        k_rec: per-consumer recommendation count.
        producer_max_min_utility: U_max reference.
        gamma: fraction of U_max floor.
        lr_primal: LR for primal (alloc) Adam.
        lr_dual: LR for dual ascent.
        rho: penalty parameter for augmented terms.
        max_epochs: training steps.
        verbose: print diagnostics.
    Returns:
        A: (n, m) allocation matrix in [0,1].
    """
    n, m = rel_matrix.shape
    # primal logits
    logits = nn.Parameter(torch.zeros(n, m))
    # dual vars
    alpha = torch.zeros(n, requires_grad=False)
    beta  = torch.zeros(m, requires_grad=False)
    min_prod = gamma * producer_max_min_utility

    opt = optim.Adam([logits], lr=lr_primal)

    for epoch in range(1, max_epochs+1):
        A = torch.sigmoid(logits)
        # objective
        util = (A * rel_matrix).sum(dim=1).mean()
        # constraints residuals
        ci = A.sum(dim=1) - k_rec                 # target zero
        gj = torch.relu(min_prod - A.sum(dim=0))  # positive slack if violation

        # augmented Lagrangian
        L = -util + (alpha * ci).mean() + (beta * gj).mean() \
            + (rho/2)*(ci.pow(2).mean()) + (rho/2)*(gj.pow(2).mean())

        opt.zero_grad()
        L.backward()
        # gradient step on logits
        opt.step()
        # project logits to encourage binarity (optional)
        with torch.no_grad():
            logits.clamp_(-5, 5)

        # dual ascent on alpha, beta
        with torch.no_grad():
            alpha += lr_dual * ci
            beta  += lr_dual * gj
            beta.clamp_(min=0.0)

        if verbose and epoch % 500 == 0:
            with torch.no_grad():
                max_ci = ci.abs().max().item()
                min_prod_rec = (A.sum(dim=0)).min().item()
                print(f"Epoch {epoch}: util={util.item():.4f}, max|ci|={max_ci:.4f}, min_prod={min_prod_rec:.4f}")

    # final allocation
    return torch.sigmoid(logits).detach()

A = compute_consumer_optimal_with_augmented_lagrangian(
    rel_matrix=torch.tensor(sampled_matrix),
    k_rec=10,
    producer_max_min_utility=10,
    gamma=0.5,
    lr_primal=1e-2,
    lr_dual=1e-2,
    rho=10.0,
    max_epochs=30000,
    verbose=True
)


Epoch 500: util=9.1009, max|ci|=0.2335, min_prod=10.2269
Epoch 1000: util=6.0822, max|ci|=3.1755, min_prod=6.8299
Epoch 1500: util=5.9762, max|ci|=3.2986, min_prod=6.7023
Epoch 2000: util=7.1862, max|ci|=1.9355, min_prod=8.0394
Epoch 2500: util=8.1647, max|ci|=0.8273, min_prod=9.1013
Epoch 3000: util=8.5212, max|ci|=0.4252, min_prod=9.4566
Epoch 3500: util=8.6877, max|ci|=0.2386, min_prod=9.5868
Epoch 4000: util=8.7788, max|ci|=0.1374, min_prod=9.6172
Epoch 4500: util=8.8313, max|ci|=0.0803, min_prod=9.5853
Epoch 5000: util=8.8630, max|ci|=0.0473, min_prod=9.5061
Epoch 5500: util=8.8832, max|ci|=0.0281, min_prod=9.3848
Epoch 6000: util=8.8973, max|ci|=0.0168, min_prod=9.2206
Epoch 6500: util=8.9085, max|ci|=0.0101, min_prod=9.0098
Epoch 7000: util=8.9186, max|ci|=0.0060, min_prod=8.7465
Epoch 7500: util=8.9291, max|ci|=0.0036, min_prod=8.4242
Epoch 8000: util=8.9410, max|ci|=0.0022, min_prod=8.0368
Epoch 8500: util=8.9553, max|ci|=0.0013, min_prod=7.5801
Epoch 9000: util=8.9729, max|ci

In [174]:
A.round().sum(axis=1)  # sum of all producers

tensor([ 9., 10.,  9.,  9.,  9.,  9.,  9.,  9.,  9.,  9., 10.,  9., 10.,  9.,
         9.,  9., 10.,  9.,  9.,  9.,  9.,  8.,  9., 10.,  9.,  9.,  9.,  9.,
         9., 10.,  9.,  9.,  9.,  9.,  8.,  9.,  9.,  9.,  9., 10.,  9.,  9.,
         9.,  9.,  9., 10., 10.,  9.,  9., 10.,  9.,  9.,  9.,  9., 10.,  9.,
        10., 10., 10.,  9., 10.,  9.,  9., 10.,  9.,  9., 10.,  9.,  9.,  9.,
         9.,  9.,  9.,  9., 10., 10.,  9.,  9.,  9.,  9., 10.,  8.,  9.,  9.,
         8., 10., 10.,  9.,  9., 10.,  9.,  9.,  9., 10.,  9.,  9., 10.,  8.,
         9.,  9.])

In [165]:
import torch
import torch.nn as nn
import torch.optim as optim

def compute_consumer_optimal_with_augmented_lagrangian_cvar(
    rel_matrix: torch.Tensor,
    k_rec: int,
    producer_max_min_utility: float,
    gamma: float,
    group_assignments: torch.Tensor,
    alpha: float,
    lr_primal: float = 1e-2,
    lr_dual: float = 1e-2,
    rho: float = 10.0,
    max_epochs: int = 5000,
    verbose: bool = False
) -> torch.Tensor:
    """
    Augmented Lagrangian solver with CVaR group objective:
      minimizes cvar_util + constraints via Lagrangian multipliers.
    """
    n, m = rel_matrix.shape
    # primal logits
    logits = nn.Parameter(torch.zeros(n, m))
    # dual vars (no grad for multipliers)
    alpha_user = torch.zeros(n, requires_grad=False)
    beta_prod  = torch.zeros(m, requires_grad=False)
    min_prod = gamma * producer_max_min_utility

    log_rho = nn.Parameter(torch.log(torch.tensor(0.0)))
    opt = optim.Adam([logits, log_rho], lr=lr_primal)

    for epoch in range(1, max_epochs+1):
        A = torch.sigmoid(logits)
        consumer_allocations = (A * rel_matrix).sum(dim=1)
        # Compute CVaR loss over groups
        rho_val = torch.exp(log_rho)

        cvar_loss = cvar_util(rel_matrix, consumer_allocations, group_assignments, k_rec, rho_val, alpha)
        # Constraint residuals
        ci = A.sum(dim=1) - k_rec
        gj = torch.relu(min_prod - A.sum(dim=0))

        # Augmented Lagrangian
        L = cvar_loss + (alpha_user * ci).mean() + (beta_prod * gj).mean() \
            + (rho/2)*(ci.pow(2).mean()) + (rho/2)*(gj.pow(2).mean())

        # Primal step
        opt.zero_grad()
        L.backward()
        opt.step()
        with torch.no_grad():
            logits.clamp_(-5, 5)
            log_rho.clamp_(min=torch.log(torch.tensor(1e-3)), max=torch.log(torch.tensor(1e3)))


        # Dual ascent
        with torch.no_grad():
            alpha_user += lr_dual * ci
            beta_prod  += lr_dual * gj
            beta_prod.clamp_(min=0.0)

        if verbose and epoch % 500 == 0:
            with torch.no_grad():
                max_ci = ci.abs().max().item()
                min_prod_rec = A.sum(dim=0).min().item()
                print(f"Epoch {epoch}: cvar_loss={cvar_loss.item():.4f}, "
                      f"max|ci|={max_ci:.4f}, min_prod={min_prod_rec:.4f}")

    return torch.sigmoid(logits).detach()


A = compute_consumer_optimal_with_augmented_lagrangian_cvar(
    rel_matrix=torch.tensor(sampled_matrix),
    group_assignments=torch.tensor(group_assignments),
    alpha=0.95,
    k_rec=10,
    producer_max_min_utility=10,
    gamma=0.5,
    lr_primal=1e-2,
    lr_dual=1e-2,
    rho=10.0,
    max_epochs=30000,
    verbose=True
)

Epoch 500: cvar_loss=0.7080, max|ci|=0.3564, min_prod=10.2114
Epoch 1000: cvar_loss=0.3854, max|ci|=3.1995, min_prod=6.8058
Epoch 1500: cvar_loss=0.3990, max|ci|=3.3566, min_prod=6.6652
Epoch 2000: cvar_loss=0.2820, max|ci|=2.0007, min_prod=8.0255
Epoch 2500: cvar_loss=0.1858, max|ci|=0.8919, min_prod=9.1461
Epoch 3000: cvar_loss=0.1484, max|ci|=0.4737, min_prod=9.5471
Epoch 3500: cvar_loss=0.1304, max|ci|=0.2741, min_prod=9.7271
Epoch 4000: cvar_loss=0.1204, max|ci|=0.1554, min_prod=9.8200
Epoch 4500: cvar_loss=0.1129, max|ci|=0.0827, min_prod=9.8681
Epoch 5000: cvar_loss=0.1089, max|ci|=0.1038, min_prod=9.8914
Epoch 5500: cvar_loss=0.1032, max|ci|=0.1488, min_prod=9.8990
Epoch 6000: cvar_loss=0.0990, max|ci|=0.1904, min_prod=9.8933
Epoch 6500: cvar_loss=0.0946, max|ci|=0.2248, min_prod=9.8766
Epoch 7000: cvar_loss=0.0898, max|ci|=0.2378, min_prod=9.8506
Epoch 7500: cvar_loss=0.0849, max|ci|=0.2187, min_prod=9.8178
Epoch 8000: cvar_loss=0.0816, max|ci|=0.1981, min_prod=9.7792
Epoch 85

In [396]:
import torch
import torch.nn as nn
import torch.optim as optim

def compute_consumer_optimal_with_augmented_lagrangian_cvar(
    rel_matrix: torch.Tensor,
    group_assignments: torch.Tensor,
    alpha_cvar: float,
    k_rec: int,
    producer_max_min_utility: float,
    gamma: float,
    lr_primal: float = 1e-2,
    lr_dual: float = 1e-2,
    rho: float = 10.0,
    max_epochs: int = 5000,
    verbose: bool = False,
) -> torch.Tensor:
    """
    Augmented Lagrangian primal-dual solver:
      maximize mean utility under sum_j A[i,j]=k_rec and sum_i A[i,j]>=gamma*U constraints.

    Args:
        rel_matrix: (n, m) tensor of relevances.
        allocations: (n, m) tensor for storing allocation results.
        group_assignments: (n,) tensor for consumer group assignments.
        k_rec: per-consumer recommendation count.
        producer_max_min_utility: U_max reference.
        gamma: fraction of U_max floor.
        lr_primal: LR for primal (alloc) Adam.
        lr_dual: LR for dual ascent.
        rho: penalty parameter for augmented terms.
        max_epochs: training steps.
        verbose: print diagnostics.
    Returns:
        A: (n, m) allocation matrix in [0,1].
    """
    class CVaRModule(nn.Module):
        def __init__(self, n: int, m: int):
            super().__init__()
            self.rho_cvar = nn.Parameter(torch.zeros(1))
            self.logits = nn.Parameter(torch.zeros(n, m))

    model = CVaRModule(rel_matrix.shape[0], rel_matrix.shape[1])

    n, m = rel_matrix.shape
    # primal logits
    # dual vars
    alpha = torch.zeros(n, requires_grad=False)
    beta  = torch.zeros(m, requires_grad=False)
    min_prod = gamma * producer_max_min_utility


    opt = optim.AdamW(model.parameters(), lr=lr_primal, weight_decay=1e-4)

    for epoch in range(1, max_epochs+1):
        A = torch.sigmoid(model.logits / 0.1)
        consumer_allocations = (A * rel_matrix).sum(dim=1)
        # objective
        util = cvar_util(
            rel_matrix, consumer_allocations, group_assignments, k_rec, model.rho_cvar, alpha_cvar
        )
        ci = A.sum(dim=1) - k_rec
        gj = torch.relu(min_prod - A.sum(dim=0))

        # augmented Lagrangian
        loss = util + (alpha * ci).mean() + (beta * gj).mean() \
            + (rho/2)*(ci.pow(2).mean()) + (rho/2)*(gj.pow(2).mean())

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=3.0
        )
        opt.step()


        with torch.no_grad():
            alpha += lr_dual * ci
            beta  += lr_dual * gj
            beta.clamp_(min=0.0)


        if verbose and epoch % 500 == 0:
            with torch.no_grad():
                max_ci = ci.abs().max().item()
                min_prod_rec = (A.round().sum(axis=0)).min().item()
                print(f"Epoch {epoch}: util={util.item():.4f}, max|ci|={max_ci:.4f}, min_prod={min_prod_rec:.4f}")

    return torch.sigmoid(model.logits).detach().numpy()



sampled_matrix, consumer_ids, group_assignments = sample_data_for_group(
    n_consumers=100,
    n_producers=100,
    groups_map=GROUPS_MAP,
    group_key="top_category",
    data=REL_MATRIX,
    seed=1
)


A = compute_consumer_optimal_with_augmented_lagrangian_cvar(
    rel_matrix=torch.tensor(sampled_matrix),
    group_assignments=torch.tensor(group_assignments),
    alpha_cvar=0.95,
    k_rec=25,
    producer_max_min_utility=10,
    gamma=0.5,
    lr_primal=1e-3,
    lr_dual=1e-3,
    rho=10.0,
    max_epochs=40000,
    verbose=True,
)

Epoch 500: util=0.7602, max|ci|=0.2951, min_prod=0.0000
Epoch 1000: util=0.5435, max|ci|=0.2810, min_prod=0.0000
Epoch 1500: util=0.4055, max|ci|=0.2107, min_prod=0.0000
Epoch 2000: util=0.3104, max|ci|=0.1993, min_prod=0.0000
Epoch 2500: util=0.2409, max|ci|=0.1444, min_prod=0.0000
Epoch 3000: util=0.1833, max|ci|=0.1944, min_prod=0.0000
Epoch 3500: util=0.1453, max|ci|=0.1533, min_prod=0.0000
Epoch 4000: util=0.1152, max|ci|=0.1327, min_prod=0.0000
Epoch 4500: util=0.0945, max|ci|=0.1097, min_prod=0.0000
Epoch 5000: util=0.0776, max|ci|=0.1201, min_prod=0.0000
Epoch 5500: util=0.0669, max|ci|=0.0914, min_prod=0.0000
Epoch 6000: util=0.0549, max|ci|=0.1090, min_prod=0.0000
Epoch 6500: util=0.0482, max|ci|=0.0823, min_prod=0.0000
Epoch 7000: util=0.0434, max|ci|=0.0602, min_prod=0.0000
Epoch 7500: util=0.0368, max|ci|=0.0874, min_prod=1.0000
Epoch 8000: util=0.0331, max|ci|=0.0980, min_prod=1.0000
Epoch 8500: util=0.0317, max|ci|=0.0559, min_prod=1.0000
Epoch 9000: util=0.0295, max|ci|

In [399]:
A.round().sum(axis=0)

array([ 5.,  5.,  6., 29.,  5., 19., 53.,  5., 10., 50.,  5., 20.,  5.,
       30., 10., 86., 14.,  5., 77.,  5., 98.,  5., 93.,  5., 93., 68.,
        5., 38.,  5.,  5.,  5.,  5., 96.,  5.,  5.,  5., 79.,  5.,  5.,
       70.,  5.,  5., 98., 44.,  5., 54.,  5.,  5., 71.,  5.,  5.,  5.,
       83.,  5., 95., 76.,  5.,  5.,  5.,  5.,  5., 35., 90.,  5.,  5.,
       45.,  5., 29.,  5.,  5.,  5., 31.,  5.,  5.,  6.,  8., 95.,  5.,
       19.,  5., 98.,  4., 81.,  5.,  5.,  5.,  5.,  5.,  5., 15.,  4.,
       48.,  5.,  9., 16.,  5.,  5.,  5.,  5., 31.], dtype=float32)