In [None]:
import os
import pathlib
import pickle

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

In [None]:
# Move from the notebook's directory to its parent (presumably the project root)
# so the relative DATA_DIR path in the next cell resolves; keep "data/generated"
# here in sync with DATA_DIR. The guard makes re-running this cell a no-op
# instead of climbing one more directory on every run.
if not pathlib.Path("data/generated").is_dir() and pathlib.Path("../data/generated").is_dir():
    os.chdir("..")
# Show where we ended up (the `cd` magic used to print this).
pathlib.Path.cwd()

In [None]:
DATA_DIR = pathlib.Path("data/generated")
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order, so sort
# to make the plotting order reproducible. The is_file() check skips directories
# that merely end in ".pkl" (open() would fail on them later).
dataset_files = sorted(f for f in DATA_DIR.iterdir() if f.suffix == ".pkl" and f.is_file())

In [None]:
def load_dataset(dataset_file):
    """Unpickle and return the object stored in ``dataset_file``.

    Args:
        dataset_file: path (or anything ``open`` accepts) of a pickle file.

    Returns:
        The first object deserialized from the file.
    """
    # pickle.load can execute arbitrary code: only point this at trusted,
    # locally generated files.
    with open(dataset_file, mode="rb") as handle:
        return pickle.load(handle)

def get_plot_args(dataset):
    """Return the 11 fields of a generated dataset dict as a tuple, in stored order.

    Fields are taken *positionally* from ``dataset.values()`` (dict insertion
    order) and are consumed by callers in this order:
    X, true_labels_train, X_test, true_labels_test, labels, label_errors_mask,
    ps, py, noise_matrix, m, n.
    NOTE(review): the stored key names are never checked against these names;
    only the field count is validated.

    Args:
        dataset: mapping loaded from a dataset pickle.

    Returns:
        tuple of the 11 values in insertion order.

    Raises:
        ValueError: if the mapping does not hold exactly 11 entries. The message
            lists the stored keys so a schema change is easy to diagnose.
    """
    expected_fields = 11
    values = tuple(dataset.values())
    if len(values) != expected_fields:
        raise ValueError(
            f"expected {expected_fields} fields in the dataset, got {len(values)}; "
            f"keys in stored order: {list(dataset)}"
        )
    return values

def plot_dataset(
    *,
    X,
    labels,
    true_labels,
    noise_matrix,
    label_errors_mask,
    unique_labels,
    label_to_index,
    x_axis=0,
    y_axis=1,
):
    """Draw a 2x2 diagnostic figure for one noisy-label dataset.

    Panels: (1) features coloured by the noisy ``labels``; (2) features coloured
    by ``true_labels``; (3) heatmap of ``noise_matrix`` with columns labelled
    "True label" and rows "Noisy label"; (4) features coloured by
    ``label_errors_mask``.

    Args:
        X: feature array indexed as ``X[:, column]``; only columns ``x_axis``
            and ``y_axis`` are plotted.
        labels: noisy label rows; each row is turned into a tuple and looked up
            in ``label_to_index`` to get its colour index.
        true_labels: true label rows, looked up the same way.
        noise_matrix: 2D array-like drawn with ``sns.heatmap``.
        label_errors_mask: per-sample colour values for panel 4 (presumably a
            boolean mask — TODO confirm against the dataset generator).
        unique_labels: tick labels for both heatmap axes.
        label_to_index: mapping from label tuple to integer colour index.
        x_axis, y_axis: feature columns used for the scatter panels.

    Returns:
        The created matplotlib Figure. It is also the current pyplot figure, so
        ``plt.show()`` works as before.

    Raises:
        KeyError: if a label row is missing from ``label_to_index``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    ax_noisy, ax_true, ax_matrix, ax_errors = axes.ravel()
    xs, ys = X[:, x_axis], X[:, y_axis]

    # Pin the colour range to the full index range so a class gets the same
    # colour in both label panels; otherwise each scatter autoscales to the
    # indices it happens to contain.
    index_values = list(label_to_index.values())
    color_range = {"vmin": min(index_values, default=0), "vmax": max(index_values, default=0)}

    ax_noisy.set_title("Features with label noise")
    ax_noisy.scatter(xs, ys, c=[label_to_index[tuple(label)] for label in labels], **color_range)

    ax_true.set_title("Features without label noise")
    ax_true.scatter(xs, ys, c=[label_to_index[tuple(label)] for label in true_labels], **color_range)

    ax_matrix.set_title("Label noise matrix")
    sns.heatmap(
        noise_matrix,
        annot=True,
        fmt=".2f",
        cmap="Blues",
        xticklabels=unique_labels,
        yticklabels=unique_labels,
        ax=ax_matrix,
    )
    ax_matrix.set_xlabel("True label")
    ax_matrix.set_ylabel("Noisy label")

    ax_errors.set_title("Label error mask")
    ax_errors.scatter(xs, ys, c=label_errors_mask)

    for scatter_ax in (ax_noisy, ax_true, ax_errors):
        scatter_ax.set_xlabel(f"Feature {x_axis}")
        scatter_ax.set_ylabel(f"Feature {y_axis}")

    return fig


def main(dataset_files):
    """Load each pickled dataset and show its 2x2 diagnostic figure.

    Args:
        dataset_files: iterable of paths to pickled dataset dicts.
    """
    for dataset_file in dataset_files:
        dataset = load_dataset(dataset_file)
        (X, true_labels_train, X_test, true_labels_test, labels, label_errors_mask,
         ps, py, noise_matrix, m, n) = get_plot_args(dataset)
        # Binarized label rows -> integer colour indices. The noisy `labels` are
        # part of the vocabulary too: plot_dataset looks them up in
        # label_to_index, so a row seen only among them would raise KeyError.
        unique_labels = np.unique(np.concatenate([true_labels_train, true_labels_test, labels]), axis=0)
        label_to_index = {tuple(label): i for i, label in enumerate(unique_labels)}
        plot_dataset(
            X=X,
            labels=labels,
            true_labels=true_labels_train,
            noise_matrix=noise_matrix,
            label_errors_mask=label_errors_mask,
            unique_labels=unique_labels,
            label_to_index=label_to_index,
        )
        # Title each figure with its source file so multiple datasets can be told apart.
        plt.suptitle(pathlib.Path(dataset_file).name)
        plt.show()

In [None]:
# Render one diagnostic figure per dataset file collected in the config cell.
main(dataset_files=dataset_files)