Skip to content
Permalink
Browse files

Updated graph legends to original, random, continued.

  • Loading branch information
federicozaiter committed Jun 22, 2019
1 parent f6b821d commit cbe07dc7395a2484fb135212d771c792caebc9f3
Showing with 34 additions and 6 deletions.
  1. +34 −6 analysis/visualize.py
@@ -58,7 +58,7 @@

def run(hparams):
exp_results = load_all_exp_results(hparams)
# plot_test_accuracies(hparams, exp_results)
plot_test_accuracies(hparams, exp_results)
# generate_weight_distributions(hparams, exp_results)
# plot_early_stoping(hparams, exp_results)
if hparams["table"]:
@@ -188,6 +188,9 @@ def plot_test_accuracies(
for key, value in sorted(results.items()):
if filter_ids is not None and value["sparsity"] not in filter_ids:
continue
if value["sparsity"] == "100.0":
if not "orig" in experiment:
continue
label = "{} ({})".format(
value["sparsity"],
experiment,
@@ -210,7 +213,7 @@ def plot_test_accuracies(
for metric in metrics:
fig, axes = plt.subplots(1, 3, figsize=(21, 5))

exp_filter = np.logical_or(test_acc['Experiment'].str.contains("orig"), test_acc['Experiment'] == "100.0")
exp_filter = np.logical_or(test_acc['Experiment'].str.contains("orig"), test_acc['Experiment'].str.contains("100"))
left_data = test_acc[exp_filter]
left = sns.lineplot(
x="batch", y=metric["metric"], data=left_data, ax=axes[0], hue="Experiment", style="Experiment", dashes=dashes, palette=palette
@@ -220,7 +223,7 @@ def plot_test_accuracies(
left.set(xlabel="Training Iterations", ylabel=metric["label"])
left.get_legend().remove()

exp_filter = np.logical_and(np.logical_not(test_acc['Experiment'].str.contains("none")), test_acc['Experiment'] != "100.0")
exp_filter = np.logical_and(np.logical_not(test_acc['Experiment'].str.contains("none")), np.logical_not(test_acc['Experiment'].str.contains("100")))
right_data = test_acc[exp_filter]
right = sns.lineplot(
x="batch",
@@ -237,7 +240,7 @@ def plot_test_accuracies(
right.set(xlabel="Training Iterations", ylabel=metric["label"])
right.get_legend().remove()

exp_filter = np.logical_and(np.logical_not(test_acc['Experiment'].str.contains("rand")), test_acc['Experiment'] != "100.0")
exp_filter = np.logical_and(np.logical_not(test_acc['Experiment'].str.contains("rand")), np.logical_not(test_acc['Experiment'].str.contains("100")))
third_data = test_acc[exp_filter]
third = sns.lineplot(
x="batch",
@@ -257,7 +260,7 @@ def plot_test_accuracies(

def parse_legend(x):
return float(x[0].split(" ")[0]) + (
100 if "orig" not in x[0] else 0
0 if "100" in x[0] else 100 if "orig" in x[0] else 200 if "rand" in x[0] else 300 if "none" in x[0] else 400
)

left_handles, left_labels = left.get_legend_handles_labels()
@@ -266,6 +269,21 @@ def parse_legend(x):
handles = left_handles[1:] + right_handles[1:] + third_handles[1:]
legend_labels = left_labels[1:] + right_labels[1:] + third_labels[1:]

def update_label(legend_label):
if "100" in legend_label:
return "100"
label = legend_label.split(" ")
if "orig" in label[1]:
label[1] = "(original)"
elif "rand" in label[1]:
label[1] = "(random)"
elif "none" in label[1]:
label[1] = "(continued)"
legend_label = " ".join(label)
return legend_label

legend_labels = [update_label(label) for label in legend_labels]

by_label = collections.OrderedDict(
sorted(zip(legend_labels, handles), key=parse_legend)
)
@@ -392,7 +410,6 @@ def produce_tables(
def get_early_stop_iteration(value):
losses = value["valid_acc_log"]["loss"]
target_iteration = pd.Series(losses).idxmin()
# early_stop_iteration = value["valid_acc_log"]["batch"][target_iteration]
return target_iteration


@@ -473,6 +490,17 @@ def plot_early_stoping(
handles = left_handles[1:] + right_handles[1:] + third_handles[1:]
legend_labels = left_labels[1:] + right_labels[1:] + third_labels[1:]

def update_label(legend_label):
if "orig" in legend_label:
legend_label = "original"
elif "rand" in legend_label:
legend_label = "random"
elif "none" in legend_label:
legend_label = "continued"
return legend_label

legend_labels = [update_label(label) for label in legend_labels]

by_label = dict(
zip(legend_labels, handles)
)

0 comments on commit cbe07dc

Please sign in to comment.
You can’t perform that action at this time.