In [7]:
import numpy as np
import pandas as pd

import poissonlearning as pl
import graphlearning as gl

import storage

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
# Which dataset and which distance metric the tables below are built for.
dataset = "mnist"
dataset_metric = "vae"

# All stored experiment records of the "real_data" runs; later cells filter
# them by `dataset` / `dataset_metric`.
experiments = storage.load_results(folder="../results", name="real_data")

In [9]:
def _is_selected(experiment):
    """Return True if `experiment` was run on the configured `dataset` and `dataset_metric`."""
    return experiment["dataset"] == dataset and experiment["dataset_metric"] == dataset_metric


def _compute_accuracy(experiment):
    """Return the classification accuracy of a single stored experiment.

    The per-class scores are all columns of ``experiment["solution"]`` except
    ``x``, ``y`` and ``true_labels``. The predicted label of each point is the
    column index of its largest score. The accuracy value itself comes from
    ``gl.ssl.ssl_accuracy``; this cell's output shows it as a percentage.
    """
    solution = experiment["solution"]
    scores = solution.drop(columns=["x", "y", "true_labels"]).to_numpy()
    # Predict directly from the raw scores. The former min/max rescaling is a
    # monotone map that cannot change the argmax, and it divided by zero when
    # all scores were equal.
    pred_labels = np.argmax(scores, axis=1)
    # The argmax column index is used as the predicted label, so there is one
    # score column per class. Read the class count from the data instead of
    # hard-coding 10, so other datasets work too.
    num_classes = scores.shape[1]
    # NOTE(review): the third argument is presumably the number of labelled
    # (training) points, i.e. labels_per_class * num_classes. Confirm against
    # the graphlearning version in use.
    return gl.ssl.ssl_accuracy(
        solution["true_labels"], pred_labels, experiment["labels_per_class"] * num_classes
    )


# Filter once on dataset AND metric, so every aggregate below is computed from
# the same experiments. The per-entry filter used to ignore `dataset_metric`.
selected_experiments = [ex for ex in experiments if _is_selected(ex)]
p_all = sorted({ex["p"] for ex in selected_experiments})

# Build one Series (column) per p, then assemble the table in one step instead
# of reindexing a growing DataFrame on every iteration.
columns = {}
for p in p_all:
    experiments_p = [ex for ex in selected_experiments if np.isclose(ex["p"], p)]
    labels_per_class_all = sorted({ex["labels_per_class"] for ex in experiments_p})
    results_p = pd.Series(index=labels_per_class_all, name=p, dtype="object")
    for num_labels in labels_per_class_all:
        accuracy = [
            _compute_accuracy(ex)
            for ex in experiments_p
            if np.all(np.isclose(ex["labels_per_class"], num_labels))
        ]
        # Table entry: "mean (std)" over all matching experiments.
        results_p[num_labels] = f"{np.mean(accuracy):.3f} ({np.std(accuracy):.2f})"
    columns[p] = results_p

# The DataFrame constructor aligns the columns on the union of their
# labels_per_class indices (missing combinations become NaN).
results_table = pd.DataFrame(columns)
results_table

Unnamed: 0,2.0,3.0,4.0,5.0,6.0,8.0
1,80.301 (0.28),71.994 (0.79),66.042 (0.89),62.585 (1.08),60.251 (1.05),57.655 (0.78)
2,88.454 (2.63),83.635 (2.93),79.458 (3.03),76.657 (3.26),74.689 (3.46),72.530 (3.43)
5,92.101 (0.81),89.606 (1.06),87.434 (1.19),85.596 (1.13),84.384 (1.13),83.040 (0.98)
10,92.673 (0.78),90.898 (1.10),89.163 (1.27),87.643 (1.40),86.480 (1.32),85.388 (1.43)


In [10]:
# Transpose so each row is a value of p and each column a labels-per-class
# count. Use print() so the raw LaTeX source is emitted rather than the
# escaped string repr.
latex_source = results_table.T.to_latex()
print(latex_source)

\begin{tabular}{lllll}
\toprule
{} &             1  &             2  &             5  &             10 \\
\midrule
2.0 &  80.301 (0.28) &  88.454 (2.63) &  92.101 (0.81) &  92.673 (0.78) \\
3.0 &  71.994 (0.79) &  83.635 (2.93) &  89.606 (1.06) &  90.898 (1.10) \\
4.0 &  66.042 (0.89) &  79.458 (3.03) &  87.434 (1.19) &  89.163 (1.27) \\
5.0 &  62.585 (1.08) &  76.657 (3.26) &  85.596 (1.13) &  87.643 (1.40) \\
6.0 &  60.251 (1.05) &  74.689 (3.46) &  84.384 (1.13) &  86.480 (1.32) \\
8.0 &  57.655 (0.78) &  72.530 (3.43) &  83.040 (0.98) &  85.388 (1.43) \\
\bottomrule
\end{tabular}

