In [None]:
# stdlib
import warnings
from typing import Union

# third party
import numpy as np
import pandas as pd
import pytest
from scipy import stats
from scipy.stats import multivariate_normal
from sdv.tabular import TVAE
from sklearn.datasets import fetch_california_housing, fetch_covtype, load_digits
from sklearn.preprocessing import StandardScaler

# domias absolute
from domias.evaluator import evaluate_performance
from domias.models.ctgan import CTGAN
from domias.models.generator import GeneratorInterface

# NOTE(review): this silences *every* warning for the rest of the kernel
# session, including pandas/seaborn deprecation notices emitted by the cells
# below; consider narrowing it with `category=` / `module=` once the noisy
# source is known.
warnings.filterwarnings("ignore")


def get_dataset(seed: Union[int, None] = None) -> np.ndarray:
    """Load the California housing features, shuffle the rows and standardise them.

    Args:
        seed: Optional seed for the row shuffle. When ``None`` (the default)
            the shuffle draws from numpy's global RNG, exactly as before, so
            seeding ``np.random`` beforehand still controls it. Pass an int to
            make the row order reproducible independently of global state.

    Returns:
        The feature matrix (one row per sample) transformed by a
        ``StandardScaler`` to zero mean and unit variance per column.
    """
    features = fetch_california_housing().data
    if seed is None:
        np.random.shuffle(features)  # in-place shuffle along rows
    else:
        np.random.default_rng(seed).shuffle(features)  # in-place, axis 0
    # NOTE(review): the scaler is fit on all rows, so rows later used as
    # members and as reference data share the same scaling statistics.
    return StandardScaler().fit_transform(features)

In [None]:
def get_generator(
    gan_method: str = "TVAE",
    epochs: int = 100,
    seed: int = 0,
) -> GeneratorInterface:
    """Build an unfitted synthetic-data generator for the DOMIAS evaluation.

    Args:
        gan_method: Backend to wrap: "TVAE" (sdv), "CTGAN" (domias) or "KDE"
            (scipy Gaussian kernel density estimate).
        epochs: Training epochs forwarded to the TVAE/CTGAN constructor;
            unused for "KDE".
        seed: Seed of the private RNG used when sampling from the KDE, so a
            given seed reproduces the same KDE samples. NOTE(review): TVAE and
            CTGAN are not seeded here; their sampling depends on their own /
            global RNG state.

    Returns:
        A ``GeneratorInterface`` implementation; call ``fit`` before
        ``generate``.

    Raises:
        RuntimeError: If ``gan_method`` is not one of the supported backends.
    """
    supported_methods = ("TVAE", "CTGAN", "KDE")

    class LocalGenerator(GeneratorInterface):
        def __init__(self) -> None:
            if gan_method == "TVAE":
                syn_model = TVAE(epochs=epochs)
            elif gan_method == "CTGAN":
                syn_model = CTGAN(epochs=epochs)
            elif gan_method == "KDE":
                # The KDE needs the training data, so it is built in fit().
                syn_model = None
            else:
                raise RuntimeError(
                    f"Unsupported gan_method {gan_method!r}; "
                    f"expected one of {supported_methods}"
                )
            self.method = gan_method
            self.model = syn_model
            # Private RNG: makes KDE samples reproducible for a given seed
            # without touching numpy's global random state.
            self._rng = np.random.default_rng(seed)

        def fit(self, data: pd.DataFrame) -> "LocalGenerator":
            """Fit the wrapped model on ``data`` (one row per sample)."""
            if self.method == "KDE":
                # gaussian_kde expects shape (n_features, n_samples).
                self.model = stats.gaussian_kde(np.transpose(data))
            else:
                self.model.fit(data)

            return self

        def generate(self, count: int) -> pd.DataFrame:
            """Draw ``count`` synthetic rows from the fitted model."""
            # Dispatch on self.method, consistently with fit().
            if self.method == "KDE":
                # resample returns (n_features, count); transpose to rows.
                samples = pd.DataFrame(
                    self.model.resample(count, seed=self._rng).T
                )
            elif self.method == "TVAE":
                samples = self.model.sample(count)
            elif self.method == "CTGAN":
                samples = self.model.generate(count)
            else:
                raise RuntimeError(
                    f"Unsupported gan_method {self.method!r}; "
                    f"expected one of {supported_methods}"
                )

            return samples

    return LocalGenerator()

In [None]:
# Seed numpy's global RNG first so the dataset shuffle (and anything else
# drawing from np.random) is reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)

dataset = get_dataset()

gen_size = 10000  # number of synthetic samples evaluated per run
held_out_size = 10000
training_epochs = [100, 500, 1000, 2000, 3000]
training_sizes = [100, 500, 1000, 3000]

method = "TVAE"
density_estimator = "prior"  # prior, bnaf, kde

# results[training_size][training_epoch] -> evaluation output for gen_size.
results = {}
for training_size in training_sizes:
    results[training_size] = {}
    for training_epoch in training_epochs:
        generator = get_generator(
            gan_method=method,
            epochs=training_epoch,
            seed=SEED,
        )
        perf = evaluate_performance(
            generator,
            dataset,
            training_size,
            held_out_size,
            training_epoch,
            synthetic_sizes=[gen_size],
            # Previously defined above but never passed, so the config
            # value had no effect.
            density_estimator=density_estimator,
        )

        print(
            f"""
                SIZE_PARAM = {training_size} ADDITION_SIZE  = {held_out_size} TRAINING_EPOCH = {training_epoch}
                    metrics = {perf[gen_size]["MIA_performance"]}
            """
        )

        results[training_size][training_epoch] = perf[gen_size]

In [None]:
# Nested dict from the sweep cell: results[training_size][training_epoch]
# holds the evaluate_performance output for the `gen_size` synthetic-set size;
# its "MIA_performance" entry is what the tables and plots below summarise.
# NOTE(review): this displays the whole structure, which may be large.
results

In [None]:
# third party
import cloudpickle

# Persist the sweep results to disk (cloudpickle) so they can be reloaded
# without re-running the training sweep.
with open("experiment_1_results.bkp", mode="wb") as results_file:
    cloudpickle.dump(results, results_file)

## MIA AUC-ROC by number of training epochs (at the largest training-set size)

In [None]:
# One row per (epoch, MIA source) at the largest training-set size. The frame
# is built once from records: growing it with pd.concat in a loop is
# quadratic, repeats index 0 on every row and can leave object dtypes.
training_size = training_sizes[-1]

output = pd.DataFrame(
    [
        [training_epoch, src, metrics["aucroc"]]
        for training_epoch in training_epochs
        for src, metrics in results[training_size][training_epoch][
            "MIA_performance"
        ].items()
    ],
    columns=["epoch", "src", "aucroc"],
)

output

In [None]:
# third party
import seaborn as sns

# Pass `data=` by keyword: seaborn's positional argument order changed in
# 0.12, so a positional frame is ambiguous across versions.
ax = sns.lineplot(data=output, x="epoch", y="aucroc", hue="src")
ax.set(
    title=f"MIA AUC-ROC by training epochs (training size = {training_size})",
    xlabel="training epochs",
    ylabel="AUC-ROC",
);

## MIA accuracy by number of training epochs (at the largest training-set size)

In [None]:
# One row per (epoch, MIA source) at the largest training-set size. The frame
# is built once from records: growing it with pd.concat in a loop is
# quadratic, repeats index 0 on every row and can leave object dtypes.
training_size = training_sizes[-1]

output = pd.DataFrame(
    [
        [training_epoch, src, metrics["accuracy"]]
        for training_epoch in training_epochs
        for src, metrics in results[training_size][training_epoch][
            "MIA_performance"
        ].items()
    ],
    columns=["epoch", "src", "accuracy"],
)

output

In [None]:
# third party
import seaborn as sns

# Pass `data=` by keyword: seaborn's positional argument order changed in
# 0.12, so a positional frame is ambiguous across versions.
ax = sns.lineplot(data=output, x="epoch", y="accuracy", hue="src")
ax.set(
    title=f"MIA accuracy by training epochs (training size = {training_size})",
    xlabel="training epochs",
    ylabel="accuracy",
);

## MIA AUC-ROC by training-set size (at the largest number of training epochs)

In [None]:
# One row per (training size, MIA source) at the largest number of training
# epochs. The frame is built once from records: growing it with pd.concat in
# a loop is quadratic, repeats index 0 on every row and can leave object
# dtypes.
training_epoch = training_epochs[-1]

output = pd.DataFrame(
    [
        [training_size, src, metrics["aucroc"]]
        for training_size in training_sizes
        for src, metrics in results[training_size][training_epoch][
            "MIA_performance"
        ].items()
    ],
    columns=["training_size", "src", "aucroc"],
)

output

In [None]:
# third party
import seaborn as sns

# Pass `data=` by keyword: seaborn's positional argument order changed in
# 0.12, so a positional frame is ambiguous across versions.
ax = sns.lineplot(data=output, x="training_size", y="aucroc", hue="src")
ax.set(
    title=f"MIA AUC-ROC by training-set size (epochs = {training_epoch})",
    xlabel="training-set size",
    ylabel="AUC-ROC",
);

## MIA accuracy by training-set size (at the largest number of training epochs)

In [None]:
# One row per (training size, MIA source) at the largest number of training
# epochs. The frame is built once from records: growing it with pd.concat in
# a loop is quadratic, repeats index 0 on every row and can leave object
# dtypes.
training_epoch = training_epochs[-1]

output = pd.DataFrame(
    [
        [training_size, src, metrics["accuracy"]]
        for training_size in training_sizes
        for src, metrics in results[training_size][training_epoch][
            "MIA_performance"
        ].items()
    ],
    columns=["training_size", "src", "accuracy"],
)

output

In [None]:
# third party
import seaborn as sns

# Pass `data=` by keyword: seaborn's positional argument order changed in
# 0.12, so a positional frame is ambiguous across versions.
ax = sns.lineplot(data=output, x="training_size", y="accuracy", hue="src")
ax.set(
    title=f"MIA accuracy by training-set size (epochs = {training_epoch})",
    xlabel="training-set size",
    ylabel="accuracy",
);