In [1]:
import quant_inf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from scipy.linalg import logm
from quant_inf.tools.list import unique

# Fixed colour for each relaxation name; used as the seaborn `palette` so that
# a given method is drawn with the same colour in every figure.
dict_color = dict(
    quantum='#1f77b4',
    quant_diag='#ff7f0e',
    quant_greedy='#2ca02c',
    logdet='#d62728',
    TRW='#9467bd',
    quant_exp='#8c564b',
    quant_exp_tronc='#e377c2',
    quant_exp_mean_field='#bcbd22',
    C4='#17becf',
)

In [2]:
# Global seaborn plotting context for every figure in this notebook.
sns.set_context(context="notebook", font_scale=1)

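# Experiment 1

For each graph structure (independent variables, tree, complete graph) and several random Gaussian coefficient draws, compare the `quantum` and `quant_greedy` relaxations against the value from exact brute-force inference. Both relaxations are computed with $k$ features, for $k$ growing from $d+1$.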

In [3]:
def _normalized_trace_sigma_log_sigma(moment_matrix, k: int):
    """Return tr(S @ logm(S)) / k, where S is the leading k x k block of `moment_matrix`.

    This is the value stored in the ``KL_bound`` column of experiment 1.
    """
    sigma = moment_matrix[:k, :k]
    return np.trace(sigma @ logm(sigma)) / k


def _greedy_step(coefficients, d: int, k: int, features_greedy, singletons):
    """Pick the next feature for the greedy relaxation at size ``k``.

    Candidates are the symmetric differences of an already selected feature
    with a singleton. They are passed through ``unique`` and ``diff_list``
    (presumably de-duplicating and removing features already in
    ``features_greedy``; see quant_inf). For each candidate, exact brute-force
    inference is run on ``features_greedy + [candidate]``. The candidate with
    the largest normalized tr(S log S) is kept. Because the comparison is
    strict, the first candidate wins ties.

    Returns:
        tuple: ``(best_bound, selected_feature)``. This is ``(-inf, None)`` when
        no candidate gives a bound greater than -inf.
    """
    candidates = unique(
        [feat1 ^ feat2 for feat1 in features_greedy for feat2 in singletons]
        )
    candidates = quant_inf.algorithms.quant_greedy.diff_list(candidates, features_greedy)

    best_bound = -np.inf
    best_feature = None
    for candidate in candidates:
        inference = quant_inf.algorithms.ExactBruteForce(
            coefficients=coefficients,
            d=d,
            features=features_greedy + [candidate])
        inference.solve()
        # features_greedy + [candidate] has exactly k features, so this uses
        # the whole moment matrix of the candidate feature set.
        bound = _normalized_trace_sigma_log_sigma(inference.moment_matrix, k)
        if bound > best_bound:
            best_bound = bound
            best_feature = candidate
    return best_bound, best_feature


def get_result_experiment_1(
        d: int,
        n_sample: int,
        max_order: int = 5):
    """Compare the quantum and greedy relaxations of experiment 1 for growing k.

    Each graph structure ("Independent variables", "Tree", "Complete graph") is
    processed in turn, with ``n_sample`` random Gaussian coefficient draws per
    graph. For each draw:

    * Exact brute-force inference is run on the list of every subset of
      {1, ..., d} with at most ``max_order`` elements. The list is ordered by
      size, then lexicographically.
    * For every k from d+1 up to n, the length of that list (n = 2^d when
      ``max_order >= d``), two rows are recorded:
      - "quantum" uses the first k features of the list;
      - "quant_greedy" starts from the empty set plus the singletons (k = d+1)
        and, for each larger k, greedily adds the candidate feature that
        maximizes tr(S log S)/k.

    Args:
        d (int): number of variables.
        n_sample (int): number of random coefficient draws for each graph.
        max_order (int): largest subset size included in the full feature
            list. The default of 5 gives all 2^d subsets whenever d <= 5.

    Returns:
        pandas.core.frame.DataFrame: one row per (draw, k, relaxation), with
        columns ['k','relaxation','KL_bound','true_KL','d','graph']. Here
        ``KL_bound`` is tr(S log S)/k for the k x k moment matrix S, and
        ``true_KL`` is ``-exact_inference.entropy`` for the same draw.
    """
    from itertools import combinations

    columns = ['k', 'relaxation', 'KL_bound', 'true_KL', 'd', 'graph']
    # Collect rows in a list and build the DataFrame once at the end;
    # enlarging a DataFrame with `.loc[len(df)] = ...` is quadratic.
    records = []

    indep_features = [{i} for i in range(1, d + 1)]
    tree_features = (
        [{i} for i in range(1, d + 1)]
        + [{i, i + 1} for i in range(1, d)]
    )
    complete_features = (
        [{i} for i in range(1, d + 1)]
        + [{i, j} for i in range(1, d + 1) for j in range(i + 1, d + 1)]
    )
    dict_features = {
        "Independent variables": indep_features,
        "Tree": tree_features,
        "Complete graph": complete_features,
    }

    # Every subset of {1..d} of size <= max_order, ordered by size and then
    # lexicographically (combinations yields tuples in lexicographic order).
    all_features = [
        set(subset)
        for order in range(max_order + 1)
        for subset in combinations(range(1, d + 1), order)
    ]
    n = len(all_features)

    # Read-only singletons used to generate greedy candidates by symmetric difference.
    singletons = [{i} for i in range(1, d + 1)]

    for graph, graph_features in dict_features.items():
        print(f"Computing for graph: {graph}")
        for _ in tqdm(range(n_sample)):
            coefficients = quant_inf.tools.random_coefficients_gaussian(
                graph_features=graph_features
                )

            # The k features already selected by the greedy relaxation.
            features_greedy = [set()] + [{i} for i in range(1, d + 1)]

            exact_inference = quant_inf.algorithms.ExactBruteForce(
                    coefficients=coefficients,
                    d=d,
                    features=all_features)
            exact_inference.solve()
            true_kl = -exact_inference.entropy

            for k in range(d + 1, n + 1):
                quantum_bound = _normalized_trace_sigma_log_sigma(
                    exact_inference.moment_matrix, k)
                records.append([k, "quantum", quantum_bound, true_kl, d, graph])

                # The greedy algorithm of QuantGreedyRelaxation, rewritten here
                # for the KL bound.
                if k == d + 1:
                    # The first d+1 features of all_features are exactly the
                    # initial greedy set, so both relaxations coincide.
                    greedy_bound = quantum_bound
                else:
                    greedy_bound, selected = _greedy_step(
                        coefficients, d, k, features_greedy, singletons)
                    features_greedy.append(selected)
                records.append([k, "quant_greedy", greedy_bound, true_kl, d, graph])

    return pd.DataFrame(records, columns=columns)

In [4]:
def plot_experiment_1(result_experiment,graph:str,title:str="",file:str=""):
    """Plot the normalized error of each relaxation's bound against k.

    Args:
        result_experiment: DataFrame produced by get_result_experiment_1.
        graph: value of the 'graph' column whose rows are plotted.
        title: figure super-title.
        file: if not "", path the figure is saved to before it is shown.
    """
    fig, ax = plt.subplots(figsize=(5, 5))

    # Keep one graph structure, then express the bound's error relative to the
    # exact value, divided by the number of variables d.
    rows = result_experiment[result_experiment["graph"] == graph]
    to_plot = rows.assign(
        normalized_error_in_bound=lambda frame: (frame.KL_bound - frame.true_KL) / frame.d
    )

    sns.lineplot(
        data=to_plot,
        x="k",
        y="normalized_error_in_bound",
        hue="relaxation",
        palette=dict_color,
        errorbar=('sd', 1),
        ax=ax,
    )
    ax.set_ylabel("Normalized error in bound")
    ax.set_xlabel("Number of features")
    ax.axhline(y=0, color="grey")  # zero error reference line
    fig.suptitle(title)

    if file != "":
        fig.savefig(file)
    plt.show()

In [None]:
# Seed NumPy's global RNG so that "Restart & Run All" reproduces the same random
# coefficients. NOTE(review): this assumes
# quant_inf.tools.random_coefficients_gaussian draws from NumPy's global RNG.
# Confirm this; if it uses its own generator, seed that instead.
SEED = 0
np.random.seed(SEED)

d = 5  # number of variables
N_SAMPLE = 10  # random coefficient draws per graph structure
result_experiment_1 = get_result_experiment_1(
    d=d,
    n_sample=N_SAMPLE,
    )

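## Plots

Normalized error in bound, (`KL_bound` − `true_KL`) / `d`, against the number of features $k$. There is one figure per graph structure, and shaded bands show ±1 standard deviation over the samples.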

In [None]:
# One figure per graph structure. The graph name is used as the figure title so
# each plot stands on its own, instead of relying on a printed label.
for graph_name in ["Independent variables", "Tree", "Complete graph"]:
    plot_experiment_1(
        result_experiment=result_experiment_1,
        graph=graph_name,
        title=graph_name,
    )