In [1]:
import quant_inf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# Fixed colour per relaxation label, so the same method keeps the same colour
# across every figure (the values are matplotlib's "tab10" cycle, in order).
# Greedy variants are keyed as "quant_greedy_<n added features>".
dict_color = dict(
    quantum='#1f77b4',
    quant_diag='#ff7f0e',
    quant_greedy='#2ca02c',
    logdet='#d62728',
    TRW='#9467bd',
    quant_exp='#8c564b',
    quant_greedy_3='#e377c2',
    quant_greedy_6='#7f7f7f',
    quant_greedy_8='#bcbd22',
    quant_greedy_10='#17becf',
)


In [2]:
# Seaborn's "notebook" sizing preset for every figure below; font_scale=1 keeps its font sizes unscaled.
sns.set_context(context="notebook", font_scale=1)

# Experiment 1: Gaussian coefficients and varying temperature

In [3]:
def get_result_experiment_1(
        d:int,
        max_features_greedy:int,
        n_sample:int,n_points:int,
        min_temperature:float,
        max_temperature:float
        ):
    """Compare greedy relaxation with different numbers of added
    features, logdet and TRW relaxations, with random normal
    coefficients and temperature scaling.

    For each of the ``n_sample`` coefficient draws and each of the
    ``n_points`` temperatures (passed to every solver as ``eps``), the
    result gets one row each for "quantum", "logdet" and "TRW"
    (``n_added_features == 0``), followed by one "quant_greedy" row per
    greedy step ``i = 1..max_features_greedy`` (``n_added_features == i``).

    Returns:
        pandas.core.frame.DataFrame: a dataframe with columns
        ['temperature','relaxation','logp_bound','true_logp','logp_bound_base','l1_error','d','n_added_features'],
        where ``true_logp`` is the log-partition of ``ExactBruteForce`` and
        ``logp_bound_base`` the log-partition of the base ``QuantumRelaxation``.

    Args:
        d (int): number of variables
        max_features_greedy (int): number of features selected for greedy algorithm
        n_sample (int): number of random coefficient draws
        n_points (int): number of temperature values on the x axis
        min_temperature (float): minimal value for temperature
        max_temperature (float): maximal value for temperature
    """
    columns = ['temperature','relaxation','logp_bound','true_logp','logp_bound_base','l1_error','d','n_added_features']
    # Rows are collected in a plain list and converted once at the end:
    # enlarging a DataFrame with ``results.loc[len(results)] = ...`` copies
    # the whole frame on every insertion (quadratic in the number of rows).
    records = []

    def add_row(eps, relaxation, inference, exact_inference, quant_inference, n_added_features=0):
        # One row: the bound of `inference`, the exact log-partition, the
        # base quantum bound and the L1 error of `inference`'s marginals.
        records.append([
            eps,
            relaxation,
            inference.log_partition,
            exact_inference.log_partition,
            quant_inference.log_partition,
            inference.l1_error(exact_inference.marginals),
            d,
            n_added_features,
            ])

    list_eps = np.linspace(start=min_temperature,stop=max_temperature,num=n_points,endpoint=True)
    # Features of the base relaxations: the empty set plus every singleton.
    features_1 = [set()] + [{i} for i in range(1,d+1)]
    # Every singleton and every pair of variables (complete interaction graph).
    complete_graph_features = (
        [{i} for i in range(1,d+1)]
        + [{i,j} for i in range(1,d+1) for j in range(i+1,d+1)]
    )

    # TRW weights: 2/d off the diagonal, 0 on it. NOTE(review): presumably the
    # uniform edge-appearance probabilities of spanning trees of the complete
    # graph (d-1 tree edges out of d(d-1)/2) — confirm against TRWRelaxation.
    rho_fixed = 2*(np.ones((d,d)) - np.eye(d))/d

    for _ in tqdm(range(n_sample)):
        # One coefficient draw per sample, reused for every temperature.
        coefficients = quant_inf.tools.random_coefficients_gaussian(
                graph_features=complete_graph_features
                )
        for eps in list_eps:
            exact_inference = quant_inf.algorithms.ExactBruteForce(
                coefficients=coefficients,
                d=d,
                features=features_1,
                eps=eps)
            exact_inference.solve()

            quant_inference = quant_inf.algorithms.QuantumRelaxation(
                coefficients=coefficients,
                d=d,
                features=features_1,
                eps=eps)
            quant_inference.solve(tol = 1e-4)
            add_row(eps, "quantum", quant_inference, exact_inference, quant_inference)

            quant_greedy_inference = quant_inf.algorithms.QuantGreedyRelaxation(
                    coefficients=coefficients,
                    d=d,
                    eps=eps)

            logdet_inference = quant_inf.algorithms.LogDetRelaxation(
                coefficients=coefficients,
                d=d,
                features=features_1,
                eps=eps)
            logdet_inference.solve()
            add_row(eps, "logdet", logdet_inference, exact_inference, quant_inference)

            TRW_inference = quant_inf.algorithms.TRWRelaxation(
                coefficients=coefficients,
                d=d,
                eps=eps)
            TRW_inference.solve_rho_fixed(rho=rho_fixed)
            add_row(eps, "TRW", TRW_inference, exact_inference, quant_inference)

            # The same greedy object is refined in place; presumably each
            # `_select_feature` call adds one feature, so row i is labelled
            # with i added features.
            for i in range(1,max_features_greedy+1):
                quant_greedy_inference._select_feature(tol_search=1e-1,tol = 1e-3)
                add_row(eps, "quant_greedy", quant_greedy_inference, exact_inference, quant_inference, i)

    return pd.DataFrame(records, columns=columns)


In [19]:
def rename(name:str,n_added_features):
    """Return the plotting label of one relaxation row.

    Greedy rows get their number of added features appended
    (``'quant_greedy'`` with 3 -> ``'quant_greedy_3'``) so they match the keys
    of the colour palette; every other name is returned unchanged.

    Vectorized with ``np.vectorize``, so it also accepts arrays / Series
    and returns an object ndarray of labels.
    """
    if name == 'quant_greedy':
        # int(): a float-typed column (3.0) must still give 'quant_greedy_3'.
        return f'quant_greedy_{int(n_added_features)}'
    return name

# otypes is given explicitly: without it np.vectorize probes the first
# element to infer the output type and raises on size-0 inputs.
rename = np.vectorize(rename, otypes=[object])

def plot_experiment_1a(result_experiment,title:str="",file:str=""):
    """Plot how the greedy relaxation's bound and marginals evolve with temperature.

    Keeps only the "quant_greedy" rows with 3, 6, 8 or 10 added features and
    draws two panels against temperature, one line per number of added
    features (seaborn mean with a 1-standard-deviation band):

    * top panel (two thirds of the height): ``(logp_bound - logp_bound_base) / d``;
    * bottom panel: ``l1_error``, the error on the marginals.

    Args:
        result_experiment (pandas.DataFrame): output of ``get_result_experiment_1``.
        title (str): figure title.
        file (str): if non-empty, path the figure is saved to.
    """
    fig = plt.figure(figsize=(5,4.5))

    greedy_rows = result_experiment.query("relaxation == 'quant_greedy' & (n_added_features in [3,6,8,10])")
    to_plot = greedy_rows.assign(
        normalized_gain_in_bound=lambda frame: (frame.logp_bound - frame.logp_bound_base) / frame.d
    )
    # Arguments common to both panels.
    shared = dict(data=to_plot, x="temperature", hue="n_added_features", errorbar=('sd',1))

    bound_ax = plt.subplot(3,1,(1,2))
    sns.lineplot(y="normalized_gain_in_bound", ax=bound_ax, **shared)
    bound_ax.set_ylabel("Normalized gain in bound")
    bound_ax.set_xlabel(None)
    # Only the bottom panel (which shares this x axis) shows tick labels.
    plt.setp(bound_ax.get_xticklabels(), visible=False)
    bound_ax.axhline(y=0,color="grey")

    error_ax = plt.subplot(3,1,3,sharex=bound_ax)
    sns.lineplot(y="l1_error", ax=error_ax, legend=False, **shared)
    error_ax.set_ylabel("Error in\n marginals")
    error_ax.set_xlabel("Temperature")

    bound_ax.legend(loc = 'best')

    fig.suptitle(title)
    fig.tight_layout()
    if file != "":
        plt.savefig(file)
    plt.show()

def plot_experiment_1b(result_experiment,title:str="",file:str=""):
    """Compare every relaxation's bound error and marginal error against temperature.

    Keeps the rows with 0, 3 or 10 added features. Greedy rows are relabelled
    ``quant_greedy_<n>`` with ``rename`` so each gets its own colour from
    ``dict_color``. Two panels are drawn against temperature (seaborn mean with
    a 1-standard-deviation band):

    * top panel (two thirds of the height): ``(logp_bound - true_logp) / d``;
    * bottom panel: ``l1_error``, the error on the marginals.

    The legend is moved just outside the top panel, on its right-hand side.

    Args:
        result_experiment (pandas.DataFrame): output of ``get_result_experiment_1``.
        title (str): figure title.
        file (str): if non-empty, path the figure is saved to.
    """
    fig = plt.figure(figsize=(7,4.5))

    selected = result_experiment.query("n_added_features in [0,3,10]")
    relabelled = selected.assign(
        relaxation=lambda frame: rename(frame.relaxation, frame.n_added_features)
    )
    to_plot = relabelled.assign(
        normalized_error_in_bound=lambda frame: (frame.logp_bound - frame.true_logp) / frame.d
    )
    # Arguments common to both panels.
    shared = dict(data=to_plot, x="temperature", hue="relaxation", palette=dict_color, errorbar=('sd',1))

    bound_ax = plt.subplot(3,1,(1,2))
    sns.lineplot(y="normalized_error_in_bound", ax=bound_ax, **shared)
    bound_ax.set_ylabel("Normalized error in bound")
    bound_ax.set_xlabel(None)
    # Only the bottom panel (which shares this x axis) shows tick labels.
    plt.setp(bound_ax.get_xticklabels(), visible=False)
    bound_ax.axhline(y=0,color="grey")

    error_ax = plt.subplot(3,1,3,sharex=bound_ax)
    sns.lineplot(y="l1_error", ax=error_ax, legend=False, **shared)
    error_ax.set_ylabel("Error in\n marginals")
    error_ax.set_xlabel("Temperature")

    sns.move_legend(bound_ax, "upper left", bbox_to_anchor=(1, 1))

    fig.suptitle(title)
    fig.tight_layout()
    if file != "":
        plt.savefig(file)
    plt.show()

In [None]:
d=10
# NOTE(review): seeding assumes quant_inf draws its random coefficients from
# NumPy's global RNG — TODO confirm. Without a seed, every rerun of this cell
# samples new coefficients and the figures below cannot be reproduced exactly.
np.random.seed(0)
result_experiment_1 = get_result_experiment_1(
    d=d,
    max_features_greedy=10,
    n_sample=10,
    n_points=10,
    min_temperature=.1,
    max_temperature=10,
    )

## Plots

In [None]:
# Gain of the greedy relaxation over the base quantum bound, per temperature.
plot_experiment_1a(result_experiment=result_experiment_1)

In [None]:
# Error of every relaxation relative to the exact log-partition, per temperature.
plot_experiment_1b(result_experiment=result_experiment_1)

# Experiment 2: coefficients from the log-det relaxation benchmark and varying coupling strength

In [11]:
def get_result_experiment_2(
        d:int,
        max_features_greedy:int,
        n_sample:int,n_points:int,
        max_coupling_strenght:float,
        interaction:str):
    """Compare greedy relaxation with different numbers of added
    features, logdet and TRW relaxations, with random
    coefficients from [1].

    For each of the ``n_sample`` repetitions and each of the ``n_points``
    coupling strengths in ``[0, max_coupling_strenght]``, fresh coefficients
    are drawn on the complete graph. The result gets one row each for
    "quantum", "logdet" and "TRW" (``n_added_features == 0``), followed by
    one "quant_greedy" row per greedy step ``i = 1..max_features_greedy``
    (``n_added_features == i``).

    Returns:
        pandas.core.frame.DataFrame: a dataframe with columns
        ['coupling_strenght','relaxation','logp_bound','true_logp','logp_bound_base','l1_error','d','n_added_features']
        (the column name ``coupling_strenght`` is spelt as above and is used by
        the plotting code). ``true_logp`` is the log-partition of
        ``ExactBruteForce`` and ``logp_bound_base`` that of the base
        ``QuantumRelaxation``.

    Args:
        d (int): number of variables
        max_features_greedy (int): number of features selected for greedy algorithm
        n_sample (int): number of repetitions for each coupling strength
        n_points (int): number of coupling-strength values on the x axis
        max_coupling_strenght (float): maximal value for coupling strength
        interaction (str): type of interaction, one of "mixed", "attractive", "repulsive"

    Raises:
        AssertionError: if ``interaction`` is not one of the accepted values.

    Reference:
    [1] Martin J. Wainwright and Michael I. Jordan. "Semidefinite Relaxations
        for Approximate Inference on Graphs with Cycles". In: Advances in
        Neural Information Processing Systems. Vol. 16. MIT Press, 2003.
    """

    assert interaction in ["mixed","attractive","repulsive"], f"unknown interaction {interaction!r}"

    columns = ['coupling_strenght','relaxation','logp_bound','true_logp','logp_bound_base','l1_error','d','n_added_features']
    # Rows are collected in a plain list and converted once at the end:
    # enlarging a DataFrame with ``results.loc[len(results)] = ...`` copies
    # the whole frame on every insertion (quadratic in the number of rows).
    records = []

    def add_row(w, relaxation, inference, exact_inference, quant_inference, n_added_features=0):
        # One row: the bound of `inference`, the exact log-partition, the
        # base quantum bound and the L1 error of `inference`'s marginals.
        records.append([
            w,
            relaxation,
            inference.log_partition,
            exact_inference.log_partition,
            quant_inference.log_partition,
            inference.l1_error(exact_inference.marginals),
            d,
            n_added_features,
            ])

    list_w = np.linspace(start=0,stop=max_coupling_strenght,num=n_points,endpoint=True)

    # Features of the base relaxations: the empty set plus every singleton.
    features_1 = [set()] + [{i} for i in range(1,d+1)]

    # TRW weights: 2/d off the diagonal, 0 on it. NOTE(review): presumably the
    # uniform edge-appearance probabilities of spanning trees of the complete
    # graph (d-1 tree edges out of d(d-1)/2) — confirm against TRWRelaxation.
    rho_fixed = 2*(np.ones((d,d)) - np.eye(d))/d

    for _ in tqdm(range(n_sample)):
        for w in list_w:
            # Fresh coefficients for every (repetition, strength) pair.
            coefficients = quant_inf.tools.ramdom_coefficient_logdet(
                d=d,
                strenght=w,
                graph="complete",
                interaction=interaction
                )

            exact_inference = quant_inf.algorithms.ExactBruteForce(
                coefficients=coefficients,
                d=d,
                features=features_1)
            exact_inference.solve()

            quant_inference = quant_inf.algorithms.QuantumRelaxation(
                coefficients=coefficients,
                d=d,
                features=features_1
                )
            quant_inference.solve(tol = 1e-4)
            add_row(w, "quantum", quant_inference, exact_inference, quant_inference)

            quant_greedy_inference = quant_inf.algorithms.QuantGreedyRelaxation(
                    coefficients=coefficients,
                    d=d
                    )

            logdet_inference = quant_inf.algorithms.LogDetRelaxation(
                coefficients=coefficients,
                d=d,
                features=features_1
                )
            logdet_inference.solve()
            add_row(w, "logdet", logdet_inference, exact_inference, quant_inference)

            TRW_inference = quant_inf.algorithms.TRWRelaxation(
                coefficients=coefficients,
                d=d
                )
            TRW_inference.solve_rho_fixed(rho=rho_fixed)
            add_row(w, "TRW", TRW_inference, exact_inference, quant_inference)

            # The same greedy object is refined in place; presumably each
            # `_select_feature` call adds one feature, so row i is labelled
            # with i added features.
            for i in range(1,max_features_greedy+1):
                quant_greedy_inference._select_feature(tol_search=1e-1,tol = 1e-3)
                add_row(w, "quant_greedy", quant_greedy_inference, exact_inference, quant_inference, i)

    return pd.DataFrame(records, columns=columns)


In [12]:
# ``rename`` is also defined in the experiment-1 plotting cell; it is repeated
# here so this cell can be run on its own.
def rename(name:str,n_added_features):
    """Return the plotting label of one relaxation row.

    Greedy rows get their number of added features appended
    (``'quant_greedy'`` with 3 -> ``'quant_greedy_3'``) so they match the keys
    of the colour palette; every other name is returned unchanged.

    Vectorized with ``np.vectorize``, so it also accepts arrays / Series
    and returns an object ndarray of labels.
    """
    if name == 'quant_greedy':
        # int(): a float-typed column (3.0) must still give 'quant_greedy_3'.
        return f'quant_greedy_{int(n_added_features)}'
    return name

# otypes is given explicitly: without it np.vectorize probes the first
# element to infer the output type and raises on size-0 inputs.
rename = np.vectorize(rename, otypes=[object])

def plot_experiment_2a(result_experiment,title:str="",file:str=""):
    """Plot how the greedy relaxation's bound and marginals evolve with coupling strength.

    Keeps only the "quant_greedy" rows with 3, 6, 8 or 10 added features and
    draws two panels against the ``coupling_strenght`` column (the column
    name produced by ``get_result_experiment_2``), one line per number of
    added features (seaborn mean with a 1-standard-deviation band):

    * top panel (two thirds of the height): ``(logp_bound - logp_bound_base) / d``;
    * bottom panel: ``l1_error``, the error on the marginals.

    Args:
        result_experiment (pandas.DataFrame): output of ``get_result_experiment_2``.
        title (str): figure title.
        file (str): if non-empty, path the figure is saved to.
    """
    fig = plt.figure(figsize=(5,4.5))

    to_plot = (
        result_experiment
        .query("relaxation == 'quant_greedy' & (n_added_features in [3,6,8,10])")
        .assign(normalized_gain_in_bound=lambda x: (x.logp_bound - x.logp_bound_base)/x.d)
    )

    ax1 = plt.subplot(3,1,(1,2))
    sns.lineplot(
        data=to_plot,
        x="coupling_strenght",
        y="normalized_gain_in_bound",
        hue="n_added_features",
        errorbar=('sd',1),
        ax=ax1
    )
    ax1.set_ylabel("Normalized gain in bound")
    ax1.set_xlabel(None)
    # Only the bottom panel (which shares this x axis) shows tick labels.
    plt.setp(ax1.get_xticklabels(), visible=False)
    ax1.axhline(y=0,color="grey")

    ax2 = plt.subplot(3,1,3,sharex=ax1)
    sns.lineplot(
        data=to_plot,
        x="coupling_strenght",
        y="l1_error",
        hue="n_added_features",
        errorbar=('sd',1),
        ax=ax2,
        legend=False
    )
    ax2.set_ylabel("Error in\n marginals")
    # Displayed label is spelt correctly; the data column keeps its original key.
    ax2.set_xlabel("Coupling strength")

    ax1.legend(loc = 'best')

    fig.suptitle(title)
    fig.tight_layout()
    if file != "":
        plt.savefig(file)
    plt.show()

def plot_experiment_2b(result_experiment,title:str="",file:str=""):
    """Compare every relaxation's bound error and marginal error against coupling strength.

    Keeps the rows with 0, 3 or 10 added features. Greedy rows are relabelled
    ``quant_greedy_<n>`` with ``rename`` so each gets its own colour from
    ``dict_color``. Two panels are drawn against the ``coupling_strenght``
    column (the column name produced by ``get_result_experiment_2``), using
    the seaborn mean with a 1-standard-deviation band:

    * top panel (two thirds of the height): ``(logp_bound - true_logp) / d``;
    * bottom panel: ``l1_error``, the error on the marginals.

    The legend is moved above the top panel, in two columns without a frame.

    Args:
        result_experiment (pandas.DataFrame): output of ``get_result_experiment_2``.
        title (str): figure title.
        file (str): if non-empty, path the figure is saved to.
    """
    fig = plt.figure(figsize=(5,6))

    to_plot = (
        result_experiment
        .query("n_added_features in [0,3,10]")
        .assign(relaxation = lambda x: rename(x.relaxation,x.n_added_features) )
        .assign(normalized_error_in_bound=lambda x: (x.logp_bound - x.true_logp)/x.d)
    )

    ax1 = plt.subplot(3,1,(1,2))
    sns.lineplot(
        data=to_plot,
        x="coupling_strenght",
        y="normalized_error_in_bound",
        hue="relaxation",
        palette=dict_color,
        errorbar=('sd',1),
        ax=ax1
    )
    ax1.set_ylabel("Normalized error in bound")
    ax1.set_xlabel(None)
    # Only the bottom panel (which shares this x axis) shows tick labels.
    plt.setp(ax1.get_xticklabels(), visible=False)
    ax1.axhline(y=0,color="grey")

    ax2 = plt.subplot(3,1,3,sharex=ax1)
    sns.lineplot(
        data=to_plot,
        x="coupling_strenght",
        y="l1_error",
        hue="relaxation",
        palette=dict_color,
        errorbar=('sd',1),
        ax=ax2,
        legend=False
    )
    ax2.set_ylabel("Error in\n marginals")
    # Displayed label is spelt correctly; the data column keeps its original key.
    ax2.set_xlabel("Coupling strength")

    sns.move_legend(
        ax1, "lower center",
        bbox_to_anchor=(.5, 1), ncol=2, title=None, frameon=False,
    )

    fig.suptitle(title)
    fig.tight_layout()
    if file != "":
        plt.savefig(file)
    plt.show()

In [None]:
d=10
# Shared sweep settings: the three runs differ only in `interaction`, so the
# settings live in one place and cannot drift apart between runs.
experiment_2_settings = dict(
    d=d,
    max_features_greedy=10,
    n_sample=10,
    n_points=10,
    max_coupling_strenght=.5,
    )
# NOTE(review): seeding assumes quant_inf draws its random coefficients from
# NumPy's global RNG — TODO confirm. Without a seed, every rerun of this cell
# samples new coefficients and the figures below cannot be reproduced exactly.
np.random.seed(0)
result_experiment_2_attractive = get_result_experiment_2(interaction="attractive", **experiment_2_settings)
result_experiment_2_mixed = get_result_experiment_2(interaction="mixed", **experiment_2_settings)
result_experiment_2_repulsive = get_result_experiment_2(interaction="repulsive", **experiment_2_settings)

## Plots

In [None]:
# Gain of the greedy relaxation, one figure per interaction type on the complete ("full") graph.
for interaction, result in (
    ("attractive", result_experiment_2_attractive),
    ("mixed", result_experiment_2_mixed),
    ("repulsive", result_experiment_2_repulsive),
):
    print(f"Full with {interaction} coupling")
    plot_experiment_2a(result)

In [None]:
# Error of every relaxation, one figure per interaction type on the complete ("full") graph.
for interaction, result in (
    ("attractive", result_experiment_2_attractive),
    ("mixed", result_experiment_2_mixed),
    ("repulsive", result_experiment_2_repulsive),
):
    print(f"Full with {interaction} coupling")
    plot_experiment_2b(result)