In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import twosf_solver as solver
from sample_mallows import get_mallows_utilities
from tqdm import tqdm


def sample_curves(phis, m, n, k, delta, num_pts, num_samples):
    """Sample solver curves for each Mallows dispersion value.

    For every ``phi`` in ``phis`` (iterated under a tqdm progress bar), this
    draws ``num_samples`` independent utility matrices with
    ``get_mallows_utilities(m, n, phi)``. Each matrix is passed to
    ``solver.get_user_curve(U, k, delta, num_pts)``.

    Returns
    -------
    list of np.ndarray
        One entry per ``phi``: the ``num_samples`` curves for that phi,
        stacked with ``np.array``.
    """
    curves_per_phi = []
    for phi in tqdm(phis):
        # Each iteration draws a fresh utility matrix and immediately solves
        # it, so sampling and solving stay interleaved as before.
        curves = [
            solver.get_user_curve(
                get_mallows_utilities(m, n, phi), k, delta, num_pts
            )
            for _ in range(num_samples)
        ]
        curves_per_phi.append(np.array(curves))
    return curves_per_phi


def plot_curves(phis, all_pairs, fig_file_name):
    """Plot the mean curve per phi with a +/- 2 standard-error band, then save and show it.

    Parameters
    ----------
    phis : sequence of float
        One value per entry of ``all_pairs``. Each value is passed directly to
        the viridis colormap to choose the line color, so values are expected
        to lie in [0, 1].
    all_pairs : sequence of np.ndarray
        One array per phi, indexed as ``[sample, point, coord]`` (the layout
        returned by ``sample_curves``). ``coord`` 0 is plotted on the x-axis
        and ``coord`` 1 is averaged over samples for the y-axis.
    fig_file_name : str
        Path passed to ``plt.savefig`` before the figure is shown.

    Notes
    -----
    The band uses the sample standard deviation (``ddof=1``). An array with a
    single sample therefore yields a NaN band, and numpy emits a runtime
    warning.
    """
    viridis = mpl.colormaps["viridis"]

    for phi, sample_pairs in zip(phis, all_pairs):
        color = viridis(phi)

        # x-values are taken from the first sample. This assumes every sample
        # shares the same x-grid — TODO confirm against solver.get_user_curve.
        x_vals = sample_pairs[0, :, 0]
        y_vals = sample_pairs[:, :, 1]

        # Use the number of samples actually present in this array. The
        # previous version read the module-level ``num_samples`` global, which
        # silently mis-scaled the band whenever the two disagreed.
        n_samples = sample_pairs.shape[0]

        sp_mean = np.mean(y_vals, axis=0)
        sp_std = np.std(y_vals, axis=0, ddof=1)
        # Two standard errors of the mean on each side of the curve.
        half_width = 2 * sp_std / np.sqrt(n_samples)

        plt.plot(x_vals, sp_mean, color=color, label=(r"$\phi =$ %.2f" % phi))
        plt.fill_between(
            x_vals,
            sp_mean - half_width,
            sp_mean + half_width,
            color=color,
            alpha=0.3,
        )

    plt.ylabel("Minimum normalized user utility")
    plt.xlabel(r"Minimum normalized item utility guaranteed ($\gamma_I$)")
    plt.legend()
    plt.savefig(fig_file_name)
    plt.show()


# Experiment configuration. n, m, k, delta and num_pts are forwarded unchanged
# to sample_curves; 5 evenly spaced phi values span [0.1, 0.9].
n = 300
m = 40
k = 1
delta = 1
num_pts = 50
num_samples = 10
phis = np.linspace(0.1, 0.9, 5)

fig_file = "fig1.png"

# NOTE(review): the recorded run took roughly 11 minutes per phi. It then
# failed on the last phi with SolverError from the 'SCS' solver, so no
# figure was produced.
curves = sample_curves(phis, m, n, k, delta, num_pts, num_samples)
plot_curves(phis, curves, fig_file)

 80%|████████  | 4/5 [46:01<11:30, 690.25s/it]

Failure:interrupted





SolverError: Solver 'SCS' failed. Try another solver, or solve with verbose=True for more information.