Skip to content
Permalink
Browse files

Added normal and adv accuracy to the early stopping graphs and consid…

…er reinit_none experiment in the analysis.
  • Loading branch information
federicozaiter committed Jun 20, 2019
1 parent 5542e35 commit 0438bd8c025d32ac03a4e1351cad087e55a73617
Showing with 48 additions and 10 deletions.
  1. +47 −9 analysis/visualize.py
  2. +1 −1 run_analysis.py
@@ -341,8 +341,8 @@ def produce_tables(
def get_early_stop_iteration(value):
losses = value["valid_acc_log"]["loss"]
target_iteration = pd.Series(losses).idxmin()
early_stop_iteration = value["valid_acc_log"]["batch"][target_iteration]
return early_stop_iteration
# early_stop_iteration = value["valid_acc_log"]["batch"][target_iteration]
return target_iteration


def plot_early_stoping(
@@ -354,13 +354,15 @@ def plot_early_stoping(
for experiment, results in exp_results.items():
for trial in trial_iterator(hparams, experiment):
for label in exp_results:
target_iteration = get_early_stop_iteration(results["{}/prune_iter_00".format(trial)])
unpruned_early_stop_iter.append(
(
"100.0",
label,
get_early_stop_iteration(
results["{}/prune_iter_00".format(trial)]
),
results["{}/prune_iter_00".format(trial)]["valid_acc_log"]["batch"][target_iteration],
# Adv accuracy for eaerly stop iteration
results["{}/prune_iter_00".format(trial)]["valid_acc_log"]["adv_acc"][target_iteration],
results["{}/prune_iter_00".format(trial)]["valid_acc_log"]["acc"][target_iteration],
)
)
del results["{}/prune_iter_00".format(trial)]
@@ -374,23 +376,27 @@ def plot_early_stoping(
continue
label = experiment
sparsity = value["sparsity"]
early_stop_iteration = get_early_stop_iteration(value)
early_stop_iter.append((sparsity, label, early_stop_iteration))
target_iteration = get_early_stop_iteration(value)
early_stop_iteration = value["valid_acc_log"]["batch"][target_iteration]
early_stop_adv_acc = value["valid_acc_log"]["adv_acc"][target_iteration]
early_stop_acc = value["valid_acc_log"]["acc"][target_iteration]
early_stop_iter.append((sparsity, label, early_stop_iteration, early_stop_adv_acc, early_stop_acc))

early_stop_iter.extend(unpruned_early_stop_iter)

data_frame = pd.DataFrame(
early_stop_iter, columns=["Sparsity", "Experiment", "Iteration"]
early_stop_iter, columns=["Sparsity", "Experiment", "Iteration", "Adversarial_Accuracy", "Test_Accuracy"]
)
sorted_index = pd.Series.argsort(data_frame["Sparsity"].astype(float))[::-1]
data_frame = data_frame.iloc[sorted_index]

fig, axes = plt.subplots(figsize=(14, 5))
fig, axes = plt.subplots(1, 3, figsize=(14, 5))

left = sns.lineplot(
x="Sparsity",
y="Iteration",
hue="Experiment",
ax=axes[0],
ci="sd",
err_style="bars",
err_kws={"capsize": 3},
@@ -401,6 +407,38 @@ def plot_early_stoping(
xlabel="Percent of Weights Remaining", ylabel="Early Stop Iteration (Val.)"
)

right = sns.lineplot(
x="Sparsity",
y="Adversarial_Accuracy",
hue="Experiment",
ax=axes[1],
ci="sd",
err_style="bars",
err_kws={"capsize": 3},
sort=False,
data=data_frame,
)
right.set(
xlabel="Percent of Weights Remaining", ylabel="Early Stop Adv. Acc. (Val.)"
)
right.set(ylim=YLIMS[hparams["attack"]][hparams["dataset"]])

third = sns.lineplot(
x="Sparsity",
y="Test_Accuracy",
hue="Experiment",
ax=axes[2],
ci="sd",
err_style="bars",
err_kws={"capsize": 3},
sort=False,
data=data_frame,
)
third.set(
xlabel="Percent of Weights Remaining", ylabel="Early Stop Test Acc. (Val.)"
)
third.set(ylim=YLIMS[hparams["attack"]][hparams["dataset"]])

plt.tight_layout()

file_name = "test_{}.svg".format("early_stop_iteration")
@@ -5,7 +5,7 @@

from .analysis import visualize

EXPERIMENTS = ["reinit_rand", "reinit_orig"] #["no_pruning", "reinit_rand", "reinit_orig", "reinit_none"]
EXPERIMENTS = ["reinit_rand", "reinit_orig", "reinit_none"]#["no_pruning", "reinit_rand", "reinit_orig", "reinit_none"]


def init_flags():

0 comments on commit 0438bd8

Please sign in to comment.
You can’t perform that action at this time.