Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions examples/40_paper/2018_ida_strang_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,10 @@
# adds column that indicates the difference between the two classifiers
evaluations['diff'] = evaluations[flow_ids[0]] - evaluations[flow_ids[1]]


##############################################################################
# makes the s-plot

fig_splot, ax_splot = plt.subplots()
ax_splot.plot(range(len(evaluations)), sorted(evaluations['diff']))
ax_splot.set_title(classifier_family)
Expand All @@ -71,7 +74,10 @@
plt.show()


# adds column that indicates the difference between the two classifiers
##############################################################################
# adds column that indicates the difference between the two classifiers,
# needed for the scatter plot

def determine_class(val_lin, val_nonlin):
if val_lin < val_nonlin:
return class_values[0]
Expand All @@ -84,7 +90,7 @@ def determine_class(val_lin, val_nonlin):
evaluations['class'] = evaluations.apply(
lambda row: determine_class(row[flow_ids[0]], row[flow_ids[1]]), axis=1)

# makes the scatter plot
# does the plotting and formatting
fig_scatter, ax_scatter = plt.subplots()
for class_val in class_values:
df_class = evaluations[evaluations['class'] == class_val]
Expand All @@ -98,3 +104,17 @@ def determine_class(val_lin, val_nonlin):
ax_scatter.set_xscale('log')
ax_scatter.set_yscale('log')
plt.show()

##############################################################################
# makes a scatter plot where each data point represents the performance of the
# two algorithms on various axis (not in the paper)

fig_diagplot, ax_diagplot = plt.subplots()
ax_diagplot.grid(linestyle='--')
ax_diagplot.plot([0, 1], ls="-", color="black")
ax_diagplot.plot([0.2, 1.2], ls="--", color="black")
ax_diagplot.plot([-0.2, 0.8], ls="--", color="black")
ax_diagplot.scatter(evaluations[flow_ids[0]], evaluations[flow_ids[1]])
ax_diagplot.set_xlabel(measure)
ax_diagplot.set_ylabel(measure)
plt.show()