In [1]:
import quant_inf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import time 

# Fixed colour for each relaxation name so every plot uses consistent legend
# colours (passed as `palette=` to seaborn). The values follow matplotlib's
# tab10 cycle; the C1..C4 entries are presumably spare slots — no relaxation
# in this notebook is labelled with them.
dict_color = dict(
    quantum='#1f77b4',
    quant_diag='#ff7f0e',
    quant_greedy='#2ca02c',
    logdet='#d62728',
    TRW='#9467bd',
    quant_exp='#8c564b',
    C1='#e377c2',
    C2='#7f7f7f',
    C3='#bcbd22',
    C4='#17becf',
)

In [2]:
# Seaborn's "notebook" sizing preset for fonts and lines, with no extra font scaling.
sns.set_context(context="notebook", font_scale=1)

# Experiment 1: random coefficients from the log-det relaxation article (Jordan & Wainwright, 2003)

In [45]:
def get_result_experiment_1(
        d:int,
        max_features_greedy:int,
        n_sample:int,n_points:int,
        max_coupling_strenght:float,
        interaction:str,
        include_greedy:bool=False):
    """Compare different relaxations with random coefficients from [1].

    For each of ``n_points`` coupling strengths evenly spaced in
    ``[0, max_coupling_strenght]``, ``n_sample`` random coefficient sets are
    drawn on the complete graph. For each draw, every relaxation's
    log-partition bound is recorded next to the LogDet bound computed on the
    same coefficients, which serves as the reference. The cumulative
    wall-clock solve time of each relaxation is printed at the end.

    Returns:
        pandas.core.frame.DataFrame: a dataframe with columns
        ['coupling_strenght','relaxation','logp_bound','logdet_bound','d'].
        There is one row per (sample, relaxation), in the order "TRW",
        "quantum", then "quant_greedy" when ``include_greedy`` is True.

    Args:
        d (int): number of variables
        max_features_greedy (int): number of extra features selected by the
            greedy algorithm (only used when ``include_greedy`` is True)
        n_sample (int): number of samples for each coupling strength
        n_points (int): number of points on the x axis
        max_coupling_strenght (float): maximal value for the coupling strength
        interaction (str): type of interaction, one of
            "mixed", "attractive", "repulsive"
        include_greedy (bool): also run ``QuantGreedyRelaxation``. This is
            off by default, which matches the historical behaviour of this
            function.

    Raises:
        AssertionError: if ``interaction`` is not a supported value.

    Reference:
    [1] Michael Jordan and Martin J Wainwright. “Semidefinite Relaxations
        for Approximate Inference on Graphs with Cycles”. In: Advances in
        Neural Information Processing Systems. Vol. 16. MIT Press, 2003.
    """
    assert interaction in ["mixed","attractive","repulsive"]
    columns = ['coupling_strenght','relaxation','logp_bound','logdet_bound','d']
    # Rows are collected in a list and turned into a DataFrame once at the end.
    # Appending with `.loc[len(df)] = ...` copies on every insert and leaves
    # every column with object dtype.
    records = []

    list_w = np.linspace(start=0,stop=max_coupling_strenght,num=n_points,endpoint=True)

    # The empty feature plus every singleton {1}, ..., {d}. Both the LogDet and
    # the quantum relaxation receive this same feature set.
    features_1 = [set()] + [{i} for i in range(1,d+1)]

    # Uniform weight 2/d on every off-diagonal pair, zero on the diagonal.
    # On the complete graph K_d, a uniformly random spanning tree contains each
    # edge with probability (d-1) / (d(d-1)/2) = 2/d. That is presumably why
    # this is used as TRW's fixed rho.
    rho_fixed = 2*(np.ones((d,d)) - np.eye(d))/d

    # Cumulative solve time per relaxation, in seconds. Each timer stops before
    # the result is recorded, so bookkeeping is never counted.
    elapsed = {"logdet": 0.0, "TRW": 0.0, "quant": 0.0, "quant_greedy": 0.0}

    for w in tqdm(list_w):
        for _ in range(n_sample):

            coefficients = quant_inf.tools.ramdom_coefficient_logdet(
                d=d,
                strenght=w,
                graph="complete",
                interaction=interaction
                )

            start = time.perf_counter()
            logdet_inference = quant_inf.algorithms.LogDetRelaxation(
                coefficients=coefficients,
                d=d,
                features=features_1)
            logdet_inference.solve()
            elapsed["logdet"] += time.perf_counter() - start
            # Reference value stored alongside every other relaxation's bound.
            logdet_bound = logdet_inference.log_partition

            start = time.perf_counter()
            TRW_inference = quant_inf.algorithms.TRWRelaxation(
                coefficients=coefficients,
                d=d)
            TRW_inference.solve_rho_fixed(rho_fixed)
            elapsed["TRW"] += time.perf_counter() - start
            records.append((w, "TRW", TRW_inference.log_partition, logdet_bound, d))

            start = time.perf_counter()
            quant_inference = quant_inf.algorithms.QuantumRelaxation(
                coefficients=coefficients,
                d=d,
                features=features_1)
            quant_inference.solve(tol=1e-5)
            elapsed["quant"] += time.perf_counter() - start
            records.append((w, "quantum", quant_inference.log_partition, logdet_bound, d))

            if include_greedy:
                start = time.perf_counter()
                quant_greedy_inference = quant_inf.algorithms.QuantGreedyRelaxation(
                    coefficients=coefficients,
                    d=d)
                quant_greedy_inference.solve(number_extra_features=max_features_greedy,tol=1e-5)
                elapsed["quant_greedy"] += time.perf_counter() - start
                records.append(
                    (w, "quant_greedy", quant_greedy_inference.log_partition, logdet_bound, d))

    print(f"Time logdet: {elapsed['logdet']:.2f}")
    print(f"Time TRW: {elapsed['TRW']:.2f}")
    print(f"Time quant: {elapsed['quant']:.2f}")
    if include_greedy:
        print(f"Time quant_greedy: {elapsed['quant_greedy']:.2f}")

    return pd.DataFrame(records, columns=columns)

In [None]:
def plot_experiment_1(result_experiment,title:str="",file:str=""):
    """Plot each relaxation's bound minus the LogDet bound, divided by d.

    Args:
        result_experiment (pandas.DataFrame): output of
            ``get_result_experiment_1``, with columns 'coupling_strenght',
            'relaxation', 'logp_bound', 'logdet_bound' and 'd'.
        title (str): figure title ("" for none).
        file (str): if non-empty, the path the figure is saved to before it
            is shown.
    """
    # Explicit Figure/Axes objects: saving and layout target this figure, not
    # whatever pyplot currently considers the "current" one.
    fig, ax1 = plt.subplots(figsize=(5,5))

    # Difference to the LogDet bound on the same coefficients, per variable.
    to_plot = result_experiment.assign(
        normalized_error_in_bound=lambda x: (x.logp_bound - x.logdet_bound)/x.d)
    sns.lineplot(
        data=to_plot,
        x="coupling_strenght",
        y="normalized_error_in_bound",
        hue="relaxation",
        palette=dict_color,
        errorbar=('sd',1),  # band = ±1 standard deviation across samples
        ax=ax1
    )
    ax1.set_ylabel("Relative error in bound")
    ax1.set_xlabel("Coupling strength")
    ax1.axhline(y=0,color="grey")  # y=0: bound equal to the LogDet bound

    fig.suptitle(title)
    fig.tight_layout()
    if file:
        fig.savefig(file)
    plt.show()

## Running the experiment

In [None]:
# Settings shared by the three runs; only the interaction type changes.
d = 30
experiment_1_settings = dict(
    d=d,
    max_features_greedy=3,
    n_sample=10,
    n_points=10,
    max_coupling_strenght=.5,
)
result_experiment_1_attractive = get_result_experiment_1(
    **experiment_1_settings, interaction="attractive")
result_experiment_1_mixed = get_result_experiment_1(
    **experiment_1_settings, interaction="mixed")
result_experiment_1_repulsive = get_result_experiment_1(
    **experiment_1_settings, interaction="repulsive")

## Plots

In [None]:
# The label goes on the figure itself so it survives saving and skimming.
plot_experiment_1(result_experiment_1_attractive, title="Complete graph, attractive coupling")

In [None]:
# The label goes on the figure itself so it survives saving and skimming.
plot_experiment_1(result_experiment_1_mixed, title="Complete graph, mixed coupling")

In [None]:
# The label goes on the figure itself so it survives saving and skimming.
plot_experiment_1(result_experiment_1_repulsive, title="Complete graph, repulsive coupling")