In [3]:
import numpy as np
import pandas as pd

import poissonlearning as pl
import graphlearning as gl

import storage

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Selection used by the table below: which feature-space metric and which dataset to report.
dataset_metric = "vae"
dataset = "fashionmnist"
# All stored runs of the "real_data_1" experiment; filtered by dataset/metric later on.
experiments = storage.load_results(folder="../results", name="real_data_1")

In [10]:
def _matches(ex, p=None, num_labels=None):
    """Return True if experiment ``ex`` belongs to the selected ``dataset`` / ``dataset_metric``.

    When ``p`` (and ``num_labels``) are given, additionally require that ``ex["p"]``
    (and ``ex["labels_per_class"]``) match them up to ``np.isclose``. Using this one
    predicate for every selection below keeps the filters consistent with each other.
    """
    if ex["dataset"] != dataset or ex["dataset_metric"] != dataset_metric:
        return False
    if p is not None and not np.isclose(ex["p"], p):
        return False
    if num_labels is not None and not np.all(np.isclose(ex["labels_per_class"], num_labels)):
        return False
    return True


def _compute_accuracy(experiment):
    """Return the SSL accuracy of one experiment, as computed by ``gl.ssl.ssl_accuracy``.

    Every column of ``experiment["solution"]`` other than ``x``, ``y`` and ``true_labels``
    is treated as the score of one class (assumes one column per class, in label order);
    the predicted label is the row-wise argmax.
    """
    solution = experiment["solution"]
    scores = solution.drop(columns=["x", "y", "true_labels"]).to_numpy()
    # argmax is invariant to shifting / positively rescaling the scores, so no normalisation is needed.
    pred_labels = np.argmax(scores, axis=1)
    # Number of labelled (training) points: labels_per_class for each class column.
    num_train = experiment["labels_per_class"] * scores.shape[1]
    # Documented argument order is (predicted labels, true labels, training set).
    return gl.ssl.ssl_accuracy(pred_labels, solution["true_labels"].to_numpy(), num_train)


# Exponents p available for the selected dataset/metric. Sorted so the table columns come
# out in a deterministic order (set iteration order is not guaranteed).
p_all = sorted({ex["p"] for ex in experiments if _matches(ex)})

# One column per p; each cell is "mean (std)" of the accuracy over the repeated runs.
# The spread is np.std's default population standard deviation (ddof=0).
columns = {}
for p in p_all:
    labels_per_class_all = sorted({ex["labels_per_class"] for ex in experiments if _matches(ex, p=p)})
    print(f"p={p}: {len(labels_per_class_all)} labels-per-class settings")
    results_p = pd.Series(index=labels_per_class_all, name=p, dtype="object")
    for num_labels in labels_per_class_all:
        accuracy = [
            _compute_accuracy(ex)
            for ex in experiments
            if _matches(ex, p=p, num_labels=num_labels)
        ]
        results_p[num_labels] = f"{np.mean(accuracy):.3f} ({np.std(accuracy):.2f})"
    columns[p] = results_p

# Build the table in one step. The DataFrame aligns the Series on the union of their indices
# (NaN where a (p, labels_per_class) pair has no runs); rows are sorted by label count.
results_table = pd.DataFrame(columns).sort_index()
results_table

Num experiments:  4
Num experiments:  4
Num experiments:  4
Num experiments:  4
Num experiments:  4
Num experiments:  4


Unnamed: 0,2.0,3.0,4.0,5.0,6.0,8.0
1,59.780 (1.06),57.365 (1.71),55.040 (2.21),53.537 (2.58),52.475 (2.82),51.323 (3.19)
2,65.733 (2.64),63.303 (2.26),61.707 (1.87),60.633 (1.92),59.759 (1.91),58.956 (1.91)
5,70.859 (4.11),69.343 (4.47),67.889 (4.25),66.798 (3.73),66.192 (3.48),64.909 (3.21)
10,75.643 (0.21),74.592 (0.35),73.704 (0.19),72.755 (0.41),72.163 (0.41),71.378 (0.72)


In [11]:
# Print the table transposed (one row per exponent p, one column per labels-per-class) as LaTeX.
# Combinations without runs are NaN after index alignment; render them as "--" instead of "NaN".
print(results_table.T.to_latex(na_rep="--"))

\begin{tabular}{lllll}
\toprule
{} &             1  &             2  &             5  &             10 \\
\midrule
2.0 &  59.780 (1.06) &  65.733 (2.64) &  70.859 (4.11) &  75.643 (0.21) \\
3.0 &  57.365 (1.71) &  63.303 (2.26) &  69.343 (4.47) &  74.592 (0.35) \\
4.0 &  55.040 (2.21) &  61.707 (1.87) &  67.889 (4.25) &  73.704 (0.19) \\
5.0 &  53.537 (2.58) &  60.633 (1.92) &  66.798 (3.73) &  72.755 (0.41) \\
6.0 &  52.475 (2.82) &  59.759 (1.91) &  66.192 (3.48) &  72.163 (0.41) \\
8.0 &  51.323 (3.19) &  58.956 (1.91) &  64.909 (3.21) &  71.378 (0.72) \\
\bottomrule
\end{tabular}

