Skip to content

Commit

Permalink
Added method to prtScoreRoc to plot chance values in Pd/Pf
Browse files Browse the repository at this point in the history
  • Loading branch information
kennethmorton committed Jun 28, 2016
1 parent 2192e72 commit 368de48
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 42 deletions.
8 changes: 4 additions & 4 deletions class/prtClassRvm.m
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,10 @@

end

function DataSetOut = runAction(Obj,DataSet)
function DataSet = runAction(Obj,DataSet)

if isempty(Obj.sparseBeta)
DataSetOut = DataSet.setObservations(nan(DataSet.nObservations,DataSet.nFeatures));
DataSet = DataSet.setObservations(nan(DataSet.nObservations,DataSet.nFeatures));
return
end

Expand All @@ -312,7 +312,7 @@
memChunkSize = max(floor(largestMatrixSize/length(Obj.sparseBeta)),1);

OutputMat = zeros(n,1);
for i = 1:memChunkSize:n;
for i = 1:memChunkSize:n
cI = i:min(i+memChunkSize,n);
cDataSet = prtDataSetClass(DataSet.X(cI,:));

Expand All @@ -321,7 +321,7 @@
OutputMat(cI) = prtRvUtilNormCdf(gram.getObservations*Obj.sparseBeta);
end

DataSetOut = prtDataSetClass(OutputMat);
DataSet.X = OutputMat;
end
end

Expand Down
26 changes: 1 addition & 25 deletions cluster/prtClusterMeanShift.m
Original file line number Diff line number Diff line change
Expand Up @@ -37,31 +37,7 @@
% plot(clusterAlgo); % Plot the results
%
% See also prtCluster, prtClusterGmm




% Copyright (c) 2013 New Folder Consulting
%
% Permission is hereby granted, free of charge, to any person obtaining a
% copy of this software and associated documentation files (the
% "Software"), to deal in the Software without restriction, including
% without limitation the rights to use, copy, modify, merge, publish,
% distribute, sublicense, and/or sell copies of the Software, and to permit
% persons to whom the Software is furnished to do so, subject to the
% following conditions:
%
% The above copyright notice and this permission notice shall be included
% in all copies or substantial portions of the Software.
%
% THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
% OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
% MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
% NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
% DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
% OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
% USE OR OTHER DEALINGS IN THE SOFTWARE.



properties (SetAccess=private)
name = 'Mean Shift Clustering'
Expand Down
16 changes: 5 additions & 11 deletions engine/dataset/prtDataSetClass.m
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,6 @@
% prtDataInterfaceCategoricalTargets for more information.
%
%








properties (Hidden = true)
plotOptions = prtDataSetClass.initializePlotOptions();
Expand Down Expand Up @@ -970,14 +963,15 @@


function varargout = plotTimeSeriesDensity(obj,featureIndices,xData,varargin)
% plotTimeSeriesDensity Plot the data set density as time series data
%
% dataSet.plotTimeSeriesDensity() plots the density of the data contained in
% dataSet as if it were a time series.

opts.linewidth = 3;
opts.faceAlpha = 0.2;
opts.quantileValue = 0.1;
opts = prtUtilAssignStringValuePairs(opts,varargin{:});
% plotAsTimeSeries Plot the data set as time series data
%
% dataSet.plotAsTimeSeries() plots the data contained in
% dataSet as if it were a time series.

if ~obj.isLabeled
obj = obj.setTargets(zeros(obj.nObservations,1));
Expand Down
140 changes: 140 additions & 0 deletions score/metrics/prtMetricRoc.m
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
% Undocumented single-output object for prtScoreRoc
%
properties
nTargets
nNonTargets
pd
pf
nfa
Expand Down Expand Up @@ -460,6 +462,116 @@
hold off
end

colors = prtPlotUtilClassColors(size(h,1));
for iCol = 1:size(h,1)
set(h(iCol,:),'Color',colors(iCol,:));
end


if nargout
varargout = {h};
else
varargout = {};
end

end

function varargout = plotRocChanceNan(self,varargin)
% plotH = plotRocFarChanceNan(self)

plotArgs = varargin;

holdState = ishold;

h = gobjects(length(self),2);
for i = 1:length(self)
s = self(i);

nTarget = s.nTargets;
uPd = linspace(0,1,nTarget+1);
uPdPf = s.pfAtPdValues(uPd);

uPd = uPd(:);
uPdPf = uPdPf(:);
uPdPf(isinf(uPdPf)) = 1;

xyDiff = bsxfun(@minus,[1 1],[uPdPf, uPd]);
xyDiff = bsxfun(@rdivide,xyDiff,sqrt(sum(xyDiff.^2,2)));
localRocChanceAngle = atan2(xyDiff(:,2),xyDiff(:,1));
localRocChanceAngle(isnan(localRocChanceAngle)) = 0;

% Find the equation of the chance line at each uPd point
% y = m x + b -> b = y-mx for slopes and uPd uPf points
% x = pf, y = pd;
localChanceSlope = tan(localRocChanceAngle);
localChanceYIntercepts = uPd - uPdPf.*localChanceSlope;

uPdPfChance = repmat(uPdPf(:),[1 length(uPd)]);
for iPd = 1:length(uPd)
% y = mx + b
% x = (y-b)/m
if localChanceSlope(iPd) == 0
uPdPfChance(iPd:end,iPd) = uPdPfChance(iPd-1,iPd);
else
uPdPfChance(iPd:end,iPd) = (uPd(iPd:end) - localChanceYIntercepts(iPd))/localChanceSlope(iPd);
end
end
%
% plot(uPdFar, uPd(:),'k','lineWidth',3)
% hold on, h = plot(uPdPfChance,uPd(:));
% hold off
% cmap = parula(length(h))*0.8;
% for iLine = 1:length(h)
% set(h(iLine),'color',cmap(iLine,:));
% end
% xlabel('FAR (#/m^2)');
% ylabel('PD');
% grid on
% cvrS2('randomChancePRocs');
% axis([0 1 0 1])
% cvrS2('randomChancePRocsZoom');
% %%

curveAreas = zeros(length(uPd),1);
for iPd = 1:length(uPd)
cX = uPdPfChance(:,iPd);
cY = uPd;
% if max(cX) < farCutOff
% cX = cat(1,cX,farCutOff);
% cY = cat(1,cY, cY(end));
% end

curveAreas(iPd) = trapz(cX,cY);
end

[~, stopInd] = max(curveAreas);

uPdFarThresh = uPdPf;
uPdThresh = uPd;
stopInd = max(min(stopInd,length(uPd)),1);
if ~isempty(stopInd)
uPdFarThresh(stopInd:end) = nan;
uPdThresh(stopInd:end) = nan;
end

%plot(uPdFar(1:end-1),curveAreas(2:end)/farCutOff*100), ylabel('% pAUC'), xlabel('FAR of Switching to Random (#/m^2)'), grid on, cvrS2('areaVsSwitchPoint')

h(i,1) = plot(uPdPf(:),uPd(:),plotArgs{:},'LineWidth',1);

if i==1, hold on, end

h(i,2) = plot(uPdFarThresh(:),uPdThresh(:),plotArgs{:},'LineWidth',3);

end
if ~holdState
hold off
end

colors = prtPlotUtilClassColors(size(h,1));
for iCol = 1:size(h,1)
set(h(iCol,:),'Color',colors(iCol,:));
end

if nargout
varargout = {h};
else
Expand Down Expand Up @@ -525,5 +637,33 @@

ds.X = newX;
end
function pauc = aucFar(self, maxFar)

if nargin < 2 || isempty(maxFar)
maxFar = -inf; % This will be ignored
end

pauc = zeros(size(self));
for iRoc = 1:numel(self)

nTarget = self(iRoc).nTargets;
uPd = linspace(0,1,nTarget+1);
uPdFar = self(iRoc).farAtPdValues(uPd);

cX = uPdFar(:);
cY = uPd(:);

if max(cX) < maxFar
cX = cat(1,cX,maxFar);
cY = cat(1,cY, cY(end));
end

keep = cX <= maxFar;
cX = cX(keep);
cY = cY(keep);

pauc(iRoc) = trapz(cX,cY);
end
end
end
end
4 changes: 2 additions & 2 deletions score/prtScoreRoc.m
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,10 @@
thresholds = cat(1,inf,sortedDS(:));
varargout = {pFa,pD,thresholds,auc};
if nargout == 1
varargout{1} = prtMetricRoc('pf',pFa,'pd',pD,'nfa',nFa,'tau',thresholds,'auc',auc);
varargout{1} = prtMetricRoc('pf',pFa,'pd',pD,'nfa',nFa,'tau',thresholds,'auc',auc,'nTargets', nH1,'nNonTargets',nH0);
end
if inputs.outputStructure
varargout{1} = struct('pf',pFa,'pd',pD,'tau',thresholds,'auc',auc);
varargout{1} = struct('pf',pFa,'pd',pD,'tau',thresholds,'auc',auc,'nTargets', nH1,'nNonTargets',nH0);
end
end

Expand Down

0 comments on commit 368de48

Please sign in to comment.