Skip to content

Commit

Permalink
Plotting ROC envelopes
Browse files Browse the repository at this point in the history
  • Loading branch information
Stanislav Nikolov committed Aug 11, 2012
1 parent 26f41e1 commit 7caee0e
Showing 1 changed file with 87 additions and 21 deletions.
108 changes: 87 additions & 21 deletions roc.py
Expand Up @@ -4,7 +4,10 @@
import pickle
import pprint
import string
import matplotlib.patches as patches

from math import exp
from matplotlib.path import Path
from operator import attrgetter
from params import *

Expand All @@ -28,13 +31,13 @@ def roc(res_path):
paramsets = sorted(paramsets)
stats = [ stats[sorted_indices[i]] for i in range(len(stats)) ]
"""
save_fig = True
save_fig = False
plot = True
pnt = False
if plot:
plt.close('all')
plt.ion()
plt.figure()
fig = plt.figure()
plt.show()

if save_fig:
Expand Down Expand Up @@ -66,9 +69,14 @@ def roc(res_path):
indices_sorted = [ elt[0] for elt in enum_sorted ]
statsets_sorted = [ statsets[indices_sorted[i]] for i in range(len(statsets)) ]

# Initialize variables for each ROC curve. These will be reset when a new
# curve is ready to be drawn.
var_attr_count = 0
var_attr_values = []
anything_plotted = False
mean_fprs = [0,1]
mean_tprs = [0,1]
std_fprs = []
std_tprs = []

for psi in xrange(len(paramsets_sorted)):
# Take the difference between the numerical values.
Expand Down Expand Up @@ -97,18 +105,82 @@ def roc(res_path):
print 'New experiment!'

if plot:
plt.hold(False)
if save_fig:
if anything_plotted:
if len(mean_fprs) > 2:
# There were Points other than the manually added (0,0) and (1,1).
plt.savefig(os.path.join('fig', var_attr,
const_attr_str) + '.png')
else:
plt.draw()
raw_input()

# +-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~
# | PLOT SCATTER
# +-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~
plot_scatter = False
if plot_scatter and len(mean_fprs) > 2:
plt.scatter(mean_fprs, mean_tprs,
s = 20 * (var_attr_count + 0.5), c = 'k')
plt.hold(True)
# Don't take the first and last if we've put a dummy 0 and 1 at
# each end of the means lists.
plt.errorbar(mean_fprs[1:-1], mean_tprs[1:-1], xerr = std_fprs,
yerr = std_tprs, color = 'k', linestyle = 'None')
plt.title(const_attr_str + '\n' + var_attr + '=' + \
str(var_attr_values),
fontsize = 11)
plt.xlim([-0.1,1.1])
plt.ylim([-0.1,1.1])
plt.draw()
#raw_input()

#plt.hold(True)
# Sort points from left to right.

# +-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~
# | PLOT LINES
# +-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~-~
plot_curves = True
if plot_curves:
mean_fprs_ltor_enum = sorted(enumerate(mean_fprs),
key = lambda x:x[1])
mean_fprs_ltor = [ mean_fprs[i]
for (i,v) in mean_fprs_ltor_enum ]
mean_tprs_ltor = [ mean_tprs[i]
for (i,v) in mean_fprs_ltor_enum ]
num_unique = len(set([ (mean_fprs_ltor[i], mean_tprs_ltor[i])
for i in range(len(mean_fprs_ltor)) ]))
if num_unique > 3:
# Plot bezier curves.
plot_bezier = False
if plot_bezier:
verts = [ (mean_fprs_ltor[i], mean_tprs_ltor[i])
for i in range(len(mean_fprs_ltor)) ]
codes = [ Path.CURVE4 ] * (len(verts) - 1)
codes.insert(0, Path.MOVETO)
path = Path(verts, codes)
ax = plt.gca()
patch = patches.PathPatch(path, facecolor='none', lw=2)
# Manually clear axes. This isn't a plotting command, so hold
# = False has no effect.
if not plt.ishold():
plt.cla()
ax.add_patch(patch)
plot_lines = True
if plot_lines:
# Plot lines.
plt.plot(mean_fprs_ltor, mean_tprs_ltor, color = 'k',
linewidth = 1)
plt.draw()
plt.draw()
plt.xlim([-0.1,1.1])
plt.ylim([-0.1,1.1])
plt.hold(True)

# Reset variables for next ROC curve.
mean_fprs = [1,0]
mean_tprs = [1,0]
std_fprs = []
std_tprs = []
var_attr_count = 0
var_attr_values = []
anything_plotted = False

var_attr_values.append(curr_params._asdict()[var_attr])
if plot:
Expand All @@ -121,22 +193,16 @@ def roc(res_path):
if pnt:
print fprs, tprs
if fprs and tprs:
anything_plotted = True
mfprs = np.mean(fprs)
mtprs = np.mean(tprs)
sfprs = np.std(fprs)
stprs = np.std(tprs)
plt.scatter(mfprs, mtprs, s = 20 * (var_attr_count + 0.5))
plt.hold(True)
plt.errorbar(mfprs, mtprs, xerr = sfprs, yerr = stprs)
else:
# Make an empty plot, so we know there was no data.
plt.scatter([],[])
plt.xlim([-0.1,1.1])
plt.ylim([-0.1,1.1])
plt.hold(True)
plt.title(const_attr_str + '\n' + var_attr + '=' + str(var_attr_values),
fontsize = 11)

# Record these so we plot them before moving on to the next ROC curve.
mean_fprs.insert(1, mfprs)
mean_tprs.insert(1, mtprs)
std_fprs.append(sfprs)
std_tprs.append(stprs)

var_attr_count += 1

Expand Down

0 comments on commit 7caee0e

Please sign in to comment.