Skip to content

Commit

Permalink
Hf/top confidence fix (#1712)
Browse files Browse the repository at this point in the history
* Use stable ordering when getting top confidence detections.
  • Loading branch information
brosenberg42 committed Oct 26, 2023
1 parent aadfac9 commit 003bbaf
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.mitre.mpf.wfm.enums.MpfConstants;
import org.mitre.mpf.wfm.util.AggregateJobPropertiesUtil;
import org.mitre.mpf.wfm.util.JsonUtils;
import org.mitre.mpf.wfm.util.TopConfidenceUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
Expand Down Expand Up @@ -320,16 +321,13 @@ private SortedSet<JsonDetectionOutputObject> processExtractionsInTrack(BatchJob
topConfidenceCount);
}
if (topConfidenceCount > 0) {
// Sort the detections by confidence, then by frame number, if two detections have equal
// confidence. The sort by confidence is reversed so that the N highest confidence
// detections are at the start of the list.
sortedDetections.sort(
Comparator.comparing(Detection::getConfidence).reversed().thenComparing(Comparator.naturalOrder()));
int extractCount = Math.min(topConfidenceCount, sortedDetections.size());
for (int i = 0; i < extractCount; i++) {
LOG.debug("Will extract frame #{} with confidence = {}", sortedDetections.get(i).getMediaOffsetFrame(),
sortedDetections.get(i).getConfidence());
framesToExtract.add(sortedDetections.get(i).getMediaOffsetFrame());
var topConfidenceDetections = TopConfidenceUtil.getTopConfidenceDetections(
sortedDetections, topConfidenceCount);
for (var detection : topConfidenceDetections) {
LOG.debug("Will extract frame #{} with confidence = {}",
detection.getMediaOffsetFrame(),
detection.getConfidence());
framesToExtract.add(detection.getMediaOffsetFrame());
}
}
// For each frame to be extracted, set the artifact extraction status in the original detection and convert it to a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.mitre.mpf.wfm.enums.StreamingEndpoints;
import org.mitre.mpf.wfm.enums.StreamingJobStatusType;
import org.mitre.mpf.wfm.util.ProtobufDataFormatFactory;
import org.mitre.mpf.wfm.util.TopConfidenceUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;
Expand Down Expand Up @@ -152,9 +153,8 @@ private static JsonStreamingTrackOutputObject convertProtobufTrack(
.map(StreamingJobRoutesBuilder::convertDetection)
.collect(toCollection(TreeSet::new));

JsonStreamingDetectionOutputObject exemplar = detections.stream()
.max(Comparator.comparingDouble(JsonStreamingDetectionOutputObject::getConfidence))
.orElse(null);
JsonStreamingDetectionOutputObject exemplar = TopConfidenceUtil.getTopConfidenceItem(
detections, JsonStreamingDetectionOutputObject::getConfidence);

return new JsonStreamingTrackOutputObject(
Integer.toString(id),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.mitre.mpf.wfm.data.entities.transients.Detection;
import org.mitre.mpf.wfm.data.entities.transients.Track;
import org.mitre.mpf.wfm.util.MediaRange;
import org.mitre.mpf.wfm.util.TopConfidenceUtil;
import org.mitre.mpf.wfm.util.UserSpecifiedRangesUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -153,16 +154,15 @@ private static VideoRequest createFeedForwardVideoRequest(Track track, int topCo
stopFrame = track.getEndOffsetFrameInclusive();
}
else {
includedDetections = getTopConfidenceDetections(track.getDetections(),
topConfidenceCount);
includedDetections = TopConfidenceUtil.getTopConfidenceDetections(
track.getDetections(), topConfidenceCount);
var frameSummaryStats = includedDetections.stream()
.mapToInt(Detection::getMediaOffsetFrame)
.summaryStatistics();
startFrame = frameSummaryStats.getMin();
stopFrame = frameSummaryStats.getMax();
}


var protobufTrackBuilder = DetectionProtobuf.VideoTrack.newBuilder()
.setStartFrame(startFrame)
.setStopFrame(stopFrame)
Expand All @@ -187,37 +187,6 @@ private static VideoRequest createFeedForwardVideoRequest(Track track, int topCo
.build();
}



private static Collection<Detection> getTopConfidenceDetections(Collection<Detection> allDetections,
int topConfidenceCount) {
if (topConfidenceCount <= 0 || topConfidenceCount >= allDetections.size()) {
return allDetections;
}

Comparator<Detection> confidenceComparator = Comparator
.comparingDouble(Detection::getConfidence)
.thenComparing(Comparator.naturalOrder());

PriorityQueue<Detection> topDetections = new PriorityQueue<>(topConfidenceCount, confidenceComparator);

Iterator<Detection> allDetectionsIter = allDetections.iterator();
for (int i = 0; i < topConfidenceCount; i++) {
topDetections.add(allDetectionsIter.next());
}

while (allDetectionsIter.hasNext()) {
Detection detection = allDetectionsIter.next();
// Check if current detection is less than the minimum top detection so far
if (confidenceComparator.compare(detection, topDetections.peek()) > 0) {
topDetections.poll();
topDetections.add(detection);
}
}
return topDetections;
}


private static int getTopConfidenceCount(DetectionContext context) {
return context.getAlgorithmProperties()
.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import org.mitre.mpf.wfm.data.entities.transients.Detection;

import java.util.Comparator;
import java.util.SortedSet;

public class ExemplarPolicyUtil {
Expand All @@ -54,9 +53,7 @@ else if ("MIDDLE".equalsIgnoreCase(policy)) {
return findMiddle(begin, end, detections);
}
else {
return detections.stream()
.max(Comparator.comparingDouble(Detection::getConfidence))
.orElse(null);
return TopConfidenceUtil.getTopConfidenceItem(detections, Detection::getConfidence);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/******************************************************************************
* NOTICE *
* *
* This software (or technical data) was produced for the U.S. Government *
* under contract, and is subject to the Rights in Data-General Clause *
* 52.227-14, Alt. IV (DEC 2007). *
* *
* Copyright 2023 The MITRE Corporation. All Rights Reserved. *
******************************************************************************/

/******************************************************************************
* Copyright 2023 The MITRE Corporation *
* *
* Licensed under the Apache License, Version 2.0 (the "License"); *
* you may not use this file except in compliance with the License. *
* You may obtain a copy of the License at *
* *
* http://www.apache.org/licenses/LICENSE-2.0 *
* *
* Unless required by applicable law or agreed to in writing, software *
* distributed under the License is distributed on an "AS IS" BASIS, *
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
* See the License for the specific language governing permissions and *
* limitations under the License. *
******************************************************************************/

package org.mitre.mpf.wfm.util;

import java.util.Collection;
import java.util.Comparator;
import java.util.PriorityQueue;
import java.util.function.ToDoubleFunction;

import org.mitre.mpf.wfm.data.entities.transients.Detection;

public class TopConfidenceUtil {

private TopConfidenceUtil() {
}


public static <T extends Comparable<T>> T getTopConfidenceItem(
Collection<T> items, ToDoubleFunction<T> confidenceGetter) {
return items.stream()
.max(getMaxConfidenceComparator(confidenceGetter))
.orElse(null);
}


public static Collection<Detection> getTopConfidenceDetections(
Collection<Detection> allDetections, int topConfidenceCount) {
if (topConfidenceCount <= 0 || topConfidenceCount >= allDetections.size()) {
return allDetections;
}

var confidenceComparator = getMaxConfidenceComparator(Detection::getConfidence);
var topDetections = new PriorityQueue<>(topConfidenceCount, confidenceComparator);

var allDetectionsIter = allDetections.iterator();
for (int i = 0; i < topConfidenceCount; i++) {
topDetections.add(allDetectionsIter.next());
}

while (allDetectionsIter.hasNext()) {
Detection detection = allDetectionsIter.next();
// Check if current detection is less than the minimum top detection so far
if (confidenceComparator.compare(detection, topDetections.peek()) > 0) {
topDetections.poll();
topDetections.add(detection);
}
}
return topDetections;
}


private static <T extends Comparable<T>>
Comparator<T> getMaxConfidenceComparator(ToDoubleFunction<T> confidenceGetter) {
return Comparator.comparingDouble(confidenceGetter)
.thenComparing(Comparator.reverseOrder());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,14 @@ public void canGetTopConfidenceCount() {
5, 0.5f,
9, 0.0f,
10, 1.0f,
11, 0.9f,
14, 0.9f);

runTest(ArtifactExtractionPolicy.ALL_TYPES,
extractionProps,
10,
detectionFramesAndConfidences,
Arrays.asList(10, 14));
Arrays.asList(10, 11));


detectionFramesAndConfidences = ImmutableMap.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.mitre.mpf.wfm.data.entities.transients.Detection;
import org.mitre.mpf.wfm.data.entities.transients.Track;
import org.mitre.mpf.wfm.util.MediaRange;
import org.mitre.mpf.wfm.util.TopConfidenceUtil;

import java.util.*;

Expand Down Expand Up @@ -454,9 +455,8 @@ protected static Track createTrack(Detection... detections) {
.max()
.getAsInt();

Detection exemplar = detectionList.stream()
.max(Comparator.comparing(Detection::getConfidence))
.get();
Detection exemplar = TopConfidenceUtil.getTopConfidenceItem(
detectionList, Detection::getConfidence);

Track track = new Track(1, 1, 1, 0, start, stop, 0, 0, "type",
exemplar.getConfidence(), detectionList, Collections.emptyMap(), "");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,11 @@ public void canCreateNonFeedForwardMessages() {

List<DetectionRequest> detectionRequests = runSegmenter(media, context);

// range 2 -> 40
assertEquals(2, detectionRequests.size());
// range 2 -> 50
assertEquals(3, detectionRequests.size());
assertContainsSegment(2, 21, detectionRequests);
assertContainsSegment(22, 40, detectionRequests);
assertContainsSegment(22, 41, detectionRequests);
assertContainsSegment(42, 50, detectionRequests);

assertContainsExpectedMediaMetadata(detectionRequests);

Expand Down Expand Up @@ -221,12 +222,12 @@ public void canCreateFeedForwardMessages() {
longTrack = detectionRequests.get(0).getVideoRequest().getFeedForwardTrack();
}

assertEquals(3, longTrack.getFrameLocationsCount());
assertEquals(4, longTrack.getFrameLocationsCount());
assertContainsFrameLocation(2, longTrack);
assertContainsFrameLocation(20, longTrack);
assertContainsFrameLocation(40, longTrack);
assertEquals(2, longTrack.getStartFrame());
assertEquals(40, longTrack.getStopFrame());
assertEquals(50, longTrack.getStopFrame());

assertEquals(1, shortTrack.getFrameLocationsCount());
assertContainsFrameLocation(5, shortTrack);
Expand Down Expand Up @@ -364,7 +365,8 @@ private static Set<Track> createTestTracks() {
Track longTrack = createTrack(
createDetection(2, 2),
createDetection(20, 20),
createDetection(40, 40));
createDetection(40, 40),
createDetection(50, 20));

return ImmutableSet.of(shortTrack, longTrack);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ public void testMiddleIsMiddleElement() {
assertSame(_d52, ExemplarPolicyUtil.getExemplar("MIDDLE", 0, 104, _detections));
}

@Test
public void testEqualMaxConfidence() {
var d61 = createDetection(61, _d60.getConfidence());
var detections = ImmutableSortedSet.of(_d50, _d52, _d54, _d60, d61);
assertSame(_d60, ExemplarPolicyUtil.getExemplar("CONFIDENCE", 0, 100, detections));
}

private static Detection createDetection(int frame, double confidence) {
return new Detection(1, 1, 1, 1, (float) confidence, frame, 1, Map.of());
Expand Down

0 comments on commit 003bbaf

Please sign in to comment.