
init analysis

justincosentino committed Jun 11, 2019
1 parent 002d56d commit bbb6c7a4225f46a54f6d7f7aea36ddf165ec30c4
Showing with 382 additions and 1 deletion.
  1. +2 −1 __init__.py
  2. +1 −0 analysis/__init__.py
  3. +213 −0 analysis/visualize.py
  4. +166 −0 run_analysis.py
@@ -1,5 +1,6 @@
__all__ = ["attacks", "data", "experiments", "models"]
__all__ = ["analysis", "attacks", "data", "experiments", "models"]

from .analysis import *
from .attacks import *
from .data import *
from .experiments import *
@@ -0,0 +1 @@
__all__ = ["visualize"]
@@ -0,0 +1,213 @@
import collections
import os


import numpy as np
import pandas as pd

import matplotlib

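# Select the non-interactive "Agg" backend before importing pyplot so figures
# can be rendered without a display (e.g., on a headless server).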
matplotlib.use("Agg")
from matplotlib import pyplot as plt
import seaborn as sns


from ..experiments import path, utils


sns.set_style(
    "ticks",
    {
        "axes.grid": True,
        "font.family": ["serif"],
        "text.usetex": True,
        "legend.frameon": False,
    },
)
sns.set_palette("deep")


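# Y-axis limits for the accuracy plots, keyed by attack and then by dataset.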
YLIMS = {
    "pgd": {"digits": (0.25, 0.65), "fashion": (0.0, 0.40)},
    "fgsm": {"digits": (0.80, 0.99), "fashion": (0.70, 0.9)},
}


def run(hparams):
    """Load all experiment results and generate the test-accuracy plots."""
    exp_results = load_all_exp_results(hparams)
    plot_test_accuracies(hparams, exp_results)


def generate_weight_distributions():
    """Placeholder: weight-distribution plots are not implemented yet."""
    pass


def get_masks_stats(masks):
    total_weights = sum(v.size for v in masks.values())
    active_weights = sum(int(np.sum(v)) for v in masks.values())
    return active_weights, total_weights


def load_all_exp_results(hparams):
    """
    Load the valid/test logs, init kernels, post kernels, and masks for each
    pruning iteration of each trial of each experiment.
    """
    all_results = {}
    for experiment, experiment_dir in hparams["base_dirs"].items():
        results = collections.defaultdict(lambda: collections.defaultdict(int))
        for trial_dir in trial_iterator(hparams, experiment):
            for prune_iter_dir in prune_iter_iterator(hparams, experiment, trial_dir):
                path = os.path.join(experiment_dir, trial_dir, prune_iter_dir)
                masks_path = os.path.join(path, "masks")
                init_kernels_path = os.path.join(path, "init_kernels")
                post_kernels_path = os.path.join(path, "post_kernels")

                # TODO: remove. The try/except lets analysis run on live
                # experiments whose artifacts have not all been written yet.
                try:
                    masks = utils.restore_array(masks_path)
                    init_kernels = utils.restore_array(init_kernels_path)
                    post_kernels = utils.restore_array(post_kernels_path)
                except Exception:
                    continue

                active_weights, total_weights = get_masks_stats(masks)
                key = "{}/{}".format(trial_dir, prune_iter_dir)
                results[key]["sparsity"] = "{:>04.1f}".format(
                    active_weights / total_weights * 100
                )

                valid_acc_log = pd.read_csv(os.path.join(path, "valid.csv"))
                test_acc_log = pd.read_csv(os.path.join(path, "test.csv"))
                results[key]["valid_acc_log"] = valid_acc_log
                results[key]["test_acc_log"] = test_acc_log
                results[key]["masks"] = masks
                results[key]["init_kernels"] = init_kernels
                results[key]["post_kernels"] = post_kernels

        for key, value in sorted(results.items()):
            print(
                "{}: {} -> {:6.3f} | {:6.3f}".format(
                    key,
                    value["sparsity"],
                    value["test_acc_log"]["acc"].iloc[-1],
                    value["test_acc_log"]["adv_acc"].iloc[-1],
                )
            )

        all_results[experiment] = results

    return all_results


def plot_test_accuracies(
    hparams,
    exp_results,
    metrics=[
        {"metric": "acc", "label": "Test Accuracy"},
        {"metric": "adv_acc", "label": "Adversarial Test Accuracy"},
    ],
    filter_ids=["01.8", "03.6", "08.7", "16.9", "51.3", "100"],
):
    # Average unpruned results across trials.
    unpruned_test_acc = []
    for experiment, results in exp_results.items():
        for trial in trial_iterator(hparams, experiment):
            unpruned_test_acc.append(
                results["{}/prune_iter_00".format(trial)]["test_acc_log"]
            )
            del results["{}/prune_iter_00".format(trial)]
    unpruned_test_acc = pd.concat(unpruned_test_acc).groupby(level=0).mean()

    reinit_label = " (reinit)"
    for metric in metrics:
        accs = {}
        for experiment, results in exp_results.items():
            for key, value in sorted(results.items()):
                if filter_ids is not None and value["sparsity"] not in filter_ids:
                    continue
                label = "{}{}".format(
                    value["sparsity"],
                    reinit_label if experiment == "reinit_rand" else "",
                )
                accs[label] = value["test_acc_log"][metric["metric"]]
        accs["100"] = unpruned_test_acc[metric["metric"]]

        current_palette = sns.color_palette()
        palette = {
            k: current_palette[filter_ids.index(k.replace(reinit_label, ""))]
            for k in accs
        }
        dashes = {k: (1, 1) if (reinit_label in k or k == "100") else "" for k in accs}

        accs["iterations"] = value["test_acc_log"]["batch"]
        data_frame = pd.DataFrame.from_dict(accs).set_index("iterations")

        fig, axes = plt.subplots(1, 2, figsize=(14, 5))

        # Left panel: the selected sparsities plus the unpruned baseline ("100").
        left = sns.lineplot(
            data=data_frame[filter_ids], ax=axes[0], dashes=dashes, palette=palette
        )
        left.set(xlim=(0, 30000))
        left.set(ylim=YLIMS[hparams["attack"]][hparams["dataset"]])
        left.set(xlabel="Training Iterations", ylabel=metric["label"])
        left.get_legend().remove()

        # Right panel: all pruned runs (including reinit curves), baseline omitted.
        right = sns.lineplot(
            data=data_frame.loc[:, data_frame.columns != "100"],
            ax=axes[1],
            dashes=dashes,
            palette=palette,
        )
        right.set(xlim=(0, 30000))
        right.set(ylim=YLIMS[hparams["attack"]][hparams["dataset"]])
        right.set(xlabel="Training Iterations", ylabel=metric["label"])
        right.get_legend().remove()

        def parse_legend(x):
            return float(x[0].replace(reinit_label, "")) + (
                100 if reinit_label in x[0] else 0
            )

        # Build a single de-duplicated legend, sorted by sparsity with reinit
        # curves last, shared by both panels.
        left_handles, left_labels = left.get_legend_handles_labels()
        right_handles, right_labels = right.get_legend_handles_labels()
        handles = left_handles + right_handles
        labels = left_labels + right_labels
        by_label = collections.OrderedDict(
            sorted(zip(labels, handles), key=parse_legend)
        )
        fig.legend(
            by_label.values(),
            by_label.keys(),
            bbox_to_anchor=(0, 1.0, 1.0, 0.05),
            loc="lower center",
            ncol=11,
            mode="expand",
            borderaxespad=0.0,
            frameon=False,
        )
        plt.tight_layout()

        file_name = "test_{}.svg".format(metric["metric"])
        file_path = os.path.join(hparams["analysis_dir"], file_name)

        fig.savefig(file_path, format="svg", bbox_inches="tight")
        plt.clf()
        print("Saved figure to", file_path)


def trial_iterator(hparams, experiment):
    with os.scandir(hparams["base_dirs"][experiment]) as it:
        for entry in it:
            if entry.name.startswith(".") or entry.name == "analysis":
                continue
            if entry.is_dir():
                yield entry.name


def prune_iter_iterator(hparams, experiment, trial_dir):
    with os.scandir(os.path.join(hparams["base_dirs"][experiment], trial_dir)) as it:
        for entry in it:
            if entry.name.startswith("."):
                continue
            if entry.is_dir():
                yield entry.name
@@ -0,0 +1,166 @@
"""Run anlysis with CLI"""
import argparse
import os
import shutil

from .analysis import visualize

EXPERIMENTS = ["reinit_rand", "reinit_orig"]
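# Two reinitialization schemes are compared; judging by the directory names,
# "reinit_rand" reinitializes pruned weights randomly while "reinit_orig"
# restores the original initialization.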


def init_flags():
    """Init command line flags used for experiment configuration."""
    parser = argparse.ArgumentParser(
        description="Runs analysis on results generated by run_experiments.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dataset",
        metavar="dataset",
        type=str,
        nargs=1,
        default=["digits"],
        choices=["digits", "fashion"],
        help="source dataset",
    )
    parser.add_argument(
        "--model",
        metavar="model",
        type=str,
        nargs=1,
        default=["dense-300-100"],
        choices=["dense-300-100"],
        help="model type",
    )
    base_dir_default = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "output"
    )
    parser.add_argument(
        "--base_dir",
        metavar="base_dir",
        type=str,
        nargs=1,
        default=[base_dir_default],
        help="base output directory for results and checkpoints",
    )
    parser.add_argument(
        "--attack",
        metavar="attack",
        type=str,
        nargs=1,
        default=["fgsm"],
        choices=["fgsm", "pgd"],
        help="adversarial attack to analyze",
    )
    parser.add_argument(
        "--adv_train",
        action="store_true",
        default=False,
        help="whether or not adversarial training was used for the given attack method",
    )
    parser.add_argument(
        "-lr",
        "--learning_rate",
        metavar="learning_rate",
        type=float,
        nargs=1,
        default=[0.0012],
        help="model's learning rate",
    )
    parser.add_argument(
        "-l1",
        "--l1_reg",
        metavar="l1_reg",
        type=float,
        nargs=1,
        default=[0.0],
        help="l1 regularization penalty",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        default=False,
        help="force analysis, deleting old analysis dirs if they exist.",
    )
    return parser.parse_args()


def parse_args(args):
    """Parse provided args for runtime configuration."""
    hparams = {
        "dataset": args.dataset[0],
        "model": args.model[0],
        "attack": args.attack[0],
        "adv_train": args.adv_train,
        "base_dirs": {},
        "learning_rate": args.learning_rate[0],
        "l1_reg": args.l1_reg[0],
        "force": args.force,
        "experiments": EXPERIMENTS,
    }
    exp_dir = "lr-{}_l1-{}_advtrain-{}".format(
        hparams["learning_rate"], hparams["l1_reg"], str(hparams["adv_train"]).lower()
    )
    for experiment in hparams["experiments"]:
        hparams["base_dirs"][experiment] = os.path.join(
            args.base_dir[0],
            args.dataset[0],
            args.model[0],
            experiment,
            args.attack[0],
            exp_dir,
        )

    hparams["analysis_dir"] = os.path.join(
        args.base_dir[0],
        args.dataset[0],
        args.model[0],
        "analysis",
        args.attack[0],
        exp_dir,
    )
    print("-" * 40, "hparams", "-" * 40)
    print("Beginning analysis for the following experiments:\n")
    for param, value in hparams.items():
        if param == "base_dirs":
            print("\t{:>13}:".format(param))
            for exp, exp_base_dir in value.items():
                print("\t\t{:>13}: {}".format(exp, exp_base_dir))
        else:
            print("\t{:>13}: {}".format(param, value))

    print()
    print("-" * 89)
    return hparams
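
# With the default flags, parse_args resolves directories such as (for reinit_rand):
#   <base_dir>/digits/dense-300-100/reinit_rand/fgsm/lr-0.0012_l1-0.0_advtrain-false
#   <base_dir>/digits/dense-300-100/analysis/fgsm/lr-0.0012_l1-0.0_advtrain-false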


def main():
    """Parses command line arguments and runs the specified analysis."""

    # Init hparams
    hparams = parse_args(init_flags())

    # Check that each experiment's base_dir exists; fail if it does not.
    for experiment in hparams["experiments"]:
        if not os.path.exists(hparams["base_dirs"][experiment]):
            raise Exception(
                "directory '{}' does not exist.".format(
                    hparams["base_dirs"][experiment]
                )
            )
    if os.path.exists(hparams["analysis_dir"]) and not hparams["force"]:
        raise Exception(
            "directory '{}' already exists. ".format(hparams["analysis_dir"])
            + "Run with --force to overwrite."
        )
    if os.path.exists(hparams["analysis_dir"]):
        shutil.rmtree(hparams["analysis_dir"])
    os.makedirs(hparams["analysis_dir"])

    visualize.run(hparams)

    # TODO: we need to run per-trial analysis for network structure (i.e.,
    # weight magnitudes, etc.).


if __name__ == "__main__":
    main()
