Skip to content

Commit

Permalink
Extended OverlapTracker to work in 3D
Browse files Browse the repository at this point in the history
  • Loading branch information
sjcross committed May 31, 2024
1 parent a9efd64 commit 53c4c93
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import io.github.mianalysis.mia.module.Category;
import io.github.mianalysis.mia.module.Module;
import io.github.mianalysis.mia.module.Modules;
import io.github.mianalysis.mia.module.objects.relate.trackmate.tracking.OverlapTracker3DFactory;
import io.github.mianalysis.mia.object.Obj;
import io.github.mianalysis.mia.object.Objs;
import io.github.mianalysis.mia.object.Workspace;
Expand Down Expand Up @@ -476,36 +477,14 @@ public Status process(Workspace workspace) {

SpotCollection spotCollection = createSpotCollection(inputObjects, true);

SparseLAPTrackerFactory trackerFactory = new SparseLAPTrackerFactory();
Map<String, Object> trackerSettings = trackerFactory.getDefaultSettings();
trackerSettings.put(TrackerKeys.KEY_ALLOW_TRACK_SPLITTING, true);
trackerSettings.put(TrackerKeys.KEY_ALLOW_TRACK_MERGING, true);

// SparseLAPTrackerFactory trackerFactory = new SparseLAPTrackerFactory();
// Map<String, Object> trackerSettings = trackerFactory.getDefaultSettings();
// // trackerSettings.entrySet().forEach(MIA.log::writeDebug);
// trackerSettings.put(TrackerKeys.KEY_ALLOW_TRACK_SPLITTING, true);
// trackerSettings.put(TrackerKeys.KEY_ALLOW_TRACK_MERGING, true);

// Map<String, Object> settings = new SparseLAPTrackerFactory().getDefaultSettings();
// final Map< String, Object > ftfSettings = new HashMap<>();
// ftfSettings.put( TrackerKeys.KEY_LINKING_MAX_DISTANCE, 1000 );
// ftfSettings.put( TrackerKeys.KEY_ALTERNATIVE_LINKING_COST_FACTOR, settings.get( TrackerKeys.KEY_ALTERNATIVE_LINKING_COST_FACTOR ) );
// ftfSettings.put( TrackerKeys.KEY_LINKING_FEATURE_PENALTIES, settings.get( TrackerKeys.KEY_LINKING_FEATURE_PENALTIES ) );

// final SparseLAPFrameToFrameTracker frameToFrameLinker = new SparseLAPFrameToFrameTracker( spotCollection, ftfSettings );
// Model model = new Model();
// model.setSpots(spotCollection, false);
// frameToFrameLinker.process();
// SimpleWeightedGraph<Spot, DefaultWeightedEdge> result = frameToFrameLinker.getResult();
// model.setTracks(frameToFrameLinker.getResult(), false);
// // AdvancedKalmanTrackerFactory factory = new AdvancedKalmanTrackerFactory();
// OverlapTrackerFactory trackerFactory = new OverlapTrackerFactory();

// Map<String, Object> trackerSettings = trackerFactory.getDefaultSettings();
// // trackerSettings.entrySet().forEach(MIA.log::writeDebug);
// trackerSettings.put(TrackerKeys.KEY_ALLOW_TRACK_SPLITTING, true);
// trackerSettings.put(TrackerKeys.KEY_ALLOW_TRACK_MERGING, true);
// trackerSettings.put(OverlapTrackerFactory.KEY_MIN_IOU,0.001);
OverlapTracker3DFactory trackerFactory = new OverlapTracker3DFactory();
Map<String, Object> trackerSettings = trackerFactory.getDefaultSettings();
trackerSettings.put(OverlapTracker3DFactory.KEY_MIA_OBJECTS, inputObjects);

SpotTracker spotTracker = trackerFactory.create(spotCollection, trackerSettings);

Expand All @@ -520,20 +499,6 @@ public Status process(Workspace workspace) {
model.setTracks(spotTracker.getResult(), false);
SimpleWeightedGraph<Spot, DefaultWeightedEdge> result = spotTracker.getResult();

// SparseLAPTrackerFactory lapTrackerFactory = new SparseLAPTrackerFactory();
// Map<String,Object> lapTrackerSettings = lapTrackerFactory.getDefaultSettings();
// lapTrackerSettings.remove(TrackerKeys.KEY_LINKING_FEATURE_PENALTIES);
// lapTrackerSettings.remove(TrackerKeys.KEY_LINKING_MAX_DISTANCE);
// lapTrackerSettings.remove(TrackerKeys.KEY_BLOCKING_VALUE);
// final SegmentTracker segmentLinker = new SegmentTracker(result1, lapTrackerSettings);
// if (!segmentLinker.checkInput() || !segmentLinker.process()) {
// MIA.log.writeError(segmentLinker.getErrorMessage());
// return Status.FAIL;
// }

// model.setTracks(segmentLinker.getResult(), false);
// SimpleWeightedGraph<Spot, DefaultWeightedEdge> result = segmentLinker.getResult();

for (DefaultWeightedEdge edge : result.edgeSet()) {
Spot sourceSpot = result.getEdgeSource(edge);
Spot targetSpot = result.getEdgeTarget(edge);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,11 @@
import fiji.plugin.trackmate.Logger;
import fiji.plugin.trackmate.Spot;
import fiji.plugin.trackmate.SpotCollection;
import fiji.plugin.trackmate.SpotRoi;
import fiji.plugin.trackmate.tracking.SpotTracker;
import fiji.plugin.trackmate.util.Threads;
import math.geom2d.AffineTransform2D;
import math.geom2d.Point2D;
import math.geom2d.conic.Circle2D;
import math.geom2d.polygon.Polygon2D;
import math.geom2d.polygon.Polygons2D;
import math.geom2d.polygon.SimplePolygon2D;
import io.github.mianalysis.mia.MIA;
import io.github.mianalysis.mia.object.Objs;
import io.github.mianalysis.mia.object.coordinates.volume.Volume;
import net.imglib2.algorithm.MultiThreadedBenchmarkAlgorithm;

public class OverlapTracker3D extends MultiThreadedBenchmarkAlgorithm implements SpotTracker, Cancelable {
Expand All @@ -61,7 +57,7 @@ public class OverlapTracker3D extends MultiThreadedBenchmarkAlgorithm implements

private final SpotCollection spots;

private final double enlargeFactor;
private final Objs objs;

private final double minIoU;

Expand All @@ -73,11 +69,10 @@ public class OverlapTracker3D extends MultiThreadedBenchmarkAlgorithm implements
* CONSTRUCTOR
*/

public OverlapTracker3D(final SpotCollection spots, final double minIoU,
final double enlargeFactor) {
public OverlapTracker3D(final SpotCollection spots, final Objs objs, final double minIoU) {
this.spots = spots;
this.objs = objs;
this.minIoU = minIoU;
this.enlargeFactor = enlargeFactor;
}

/*
Expand Down Expand Up @@ -115,12 +110,6 @@ public boolean process() {
return false;
}

if (enlargeFactor <= 0) {
errorMessage = BASE_ERROR_MESSAGE + "The enlargement factor must be strictly positive, was "
+ enlargeFactor;
return false;
}

// Check that at least one inner collection contains an object.
boolean empty = true;
for (final int frame : spots.keySet()) {
Expand Down Expand Up @@ -151,7 +140,7 @@ public boolean process() {

// First frame.
final int sourceFrame = frameIterator.next();
Map<Spot, Polygon2D> sourceGeometries = createGeometry(spots.iterable(sourceFrame, true), enlargeFactor);
Map<Spot, Volume> sourceGeometries = createGeometry(spots.iterable(sourceFrame, true), objs);

logger.setStatus("Frame to frame linking...");
int progress = 0;
Expand All @@ -160,8 +149,7 @@ public boolean process() {
break;

final int targetFrame = frameIterator.next();
final Map<Spot, Polygon2D> targetGeometries = createGeometry(spots.iterable(targetFrame, true),
enlargeFactor);
final Map<Spot, Volume> targetGeometries = createGeometry(spots.iterable(targetFrame, true), objs);

if (sourceGeometries.isEmpty() || targetGeometries.isEmpty())
continue;
Expand All @@ -171,8 +159,8 @@ public boolean process() {

// Submit work.
for (final Spot target : targetGeometries.keySet()) {
final Polygon2D targetPoly = targetGeometries.get(target);
futures.add(executors.submit(new FindBestSourceTask(target, targetPoly, sourceGeometries, minIoU)));
final Volume targetVolume = targetGeometries.get(target);
futures.add(executors.submit(new FindBestSourceTask(target, targetVolume, sourceGeometries, minIoU)));
}

// Get results.
Expand Down Expand Up @@ -223,70 +211,56 @@ protected boolean checkSettingsValidity(final Map<String, Object> settings, fina

final boolean ok = true;
return ok;

}

private static Map<Spot, Polygon2D> createGeometry(final Iterable<Spot> spots, final double scale) {
final Map<Spot, Polygon2D> geometries = new HashMap<>();
private static Map<Spot, Volume> createGeometry(final Iterable<Spot> spots, final Objs objs) {
final Map<Spot, Volume> geometries = new HashMap<>();

for (final Spot spot : spots)
geometries.put(spot, toPolygon(spot, scale));
geometries.put(spot, objs.get(spot.getFeature("MIA_ID").intValue()));

return Collections.unmodifiableMap(geometries);
}

private static SimplePolygon2D toPolygon(final Spot spot, final double scale) {
final double xc = spot.getDoublePosition(0);
final double yc = spot.getDoublePosition(1);
final SpotRoi roi = spot.getRoi();
final SimplePolygon2D poly;
if (roi == null) {
final double radius = spot.getFeature(Spot.RADIUS).doubleValue();
poly = new SimplePolygon2D(new Circle2D(xc, yc, radius).asPolyline(32));
} else {
final double[] xcoords = roi.toPolygonX(1., 0., xc, 1.);
final double[] ycoords = roi.toPolygonY(1., 0., yc, 1.);
poly = new SimplePolygon2D(xcoords, ycoords);
}
return poly.transform(AffineTransform2D.createScaling(new Point2D(xc, yc), scale, scale));
}

private static final class FindBestSourceTask implements Callable<IoULink> {

private final Spot target;

private final Polygon2D targetPoly;
private final Volume targetVolume;

private final Map<Spot, Polygon2D> sourceGeometries;
private final Map<Spot, Volume> sourceGeometries;

private final double minIoU;

public FindBestSourceTask(final Spot target, final Polygon2D targetPoly,
final Map<Spot, Polygon2D> sourceGeometries, final double minIoU) {
public FindBestSourceTask(final Spot target, final Volume targetVolume, final Map<Spot, Volume> sourceGeometries, final double minIoU) {
this.target = target;
this.targetPoly = targetPoly;
this.targetVolume = targetVolume;
this.sourceGeometries = sourceGeometries;
this.minIoU = minIoU;
}

@Override
public IoULink call() throws Exception {
final double targetArea = Math.abs(targetPoly.area());
double maxIoU = minIoU;
Spot bestSpot = null;
for (final Spot spot : sourceGeometries.keySet()) {
final Polygon2D sourcePoly = sourceGeometries.get(spot);
final double intersection = Math.abs(Polygons2D.intersection(targetPoly, sourcePoly).area());
final Volume sourceVolume = sourceGeometries.get(spot);
final double intersection = targetVolume.getOverlap(sourceVolume);
if (intersection == 0.)
continue;

final double union = Math.abs(sourcePoly.area()) + targetArea - intersection;
final double union = sourceVolume.size() + targetVolume.size() - intersection;
final double iou = intersection / union;
if (iou > maxIoU) {
maxIoU = iou;
bestSpot = spot;
}
}

return new IoULink(bestSpot, target, maxIoU);

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,53 +39,28 @@
import fiji.plugin.trackmate.gui.components.tracker.OverlapTrackerSettingsPanel;
import fiji.plugin.trackmate.tracking.SpotTracker;
import fiji.plugin.trackmate.tracking.SpotTrackerFactory;
import io.github.mianalysis.mia.object.Objs;

@Plugin(type = SpotTrackerFactory.class)
public class OverlapTracker3DFactory implements SpotTrackerFactory {

final static String BASE_ERROR_MESSAGE = "[IoUTracker] ";

/**
* The key to the parameter that stores scale factor parameter. The scale
* factor allows for enlarging (&gt;1) or shrinking (&lt;1) the spot shapes
* before computing their IoU. Values are strictly positive {@link Double}s.
*/
public static final String KEY_SCALE_FACTOR = "SCALE_FACTOR";

public static final Double DEFAULT_SCALE_FACTOR = Double.valueOf(1.);

/**
* The key to the parameter that stores the minimal IoU below which links
* are not created. Values are strictly positive {@link Double}s.
*/
public static final String KEY_MIN_IOU = "MIN_IOU";

public static final String KEY_MIA_OBJECTS = "MIA_OBJECTS";

public static final Double DEFAULT_MIN_IOU = Double.valueOf(0.3);

public static final String TRACKER_KEY = "OVERLAP_TRACKER";

public static final String TRACKER_NAME = "Overlap tracker";

public static final String TRACKER_INFO_TEXT = "<html> "
+ "This tracker is a simple extension of the Intersection - over - Union (IoU) tracker. "
+ "<p> "
+ "<p> "
+ "It generates links between spots whose shapes overlap between consecutive frames. "
+ "When several spots are eligible as a source for a target, the one with the largest IoU "
+ "is chosen."
+ "<p> "
+ "<p> "
+ "The minimal IoU parameter sets a threshold below which links won't be created. The scale "
+ "factor allows for enlarging (&gt;1) or shrinking (&lt;1) the spot shapes before computing "
+ "their IoU. Two methods can be used to compute IoU: The <it>Fast</it> one approximates "
+ "the spot shapes by their rectangular bounding-box. The <it>Precise</it> one uses the actual "
+ "spot polygon. "
+ "<p> "
+ "<p> "
+ "This tracker works in 2D and 3D. However in 3D, the IoU is computed from the "
+ "bounding-boxes regardless of the choice of the IoU computation method. "
+ "The <it>Precise</it> method is not implemented."
+ "</html>";
public static final String TRACKER_INFO_TEXT = "";

private String errorMessage;

Expand All @@ -111,9 +86,9 @@ public ImageIcon getIcon() {

@Override
public SpotTracker create(final SpotCollection spots, final Map<String, Object> settings) {
final double pixelSize = (Double) settings.get(KEY_SCALE_FACTOR);
final double minIoU = (Double) settings.get(KEY_MIN_IOU);
return new OverlapTracker3D(spots, minIoU, pixelSize);
final Objs objs = (Objs) settings.get(KEY_MIA_OBJECTS);
return new OverlapTracker3D(spots, objs, minIoU);
}

@Override
Expand All @@ -126,7 +101,6 @@ public boolean marshall(final Map<String, Object> settings, final Element elemen
boolean ok = true;
final StringBuilder str = new StringBuilder();

ok = ok & writeAttribute(settings, element, KEY_SCALE_FACTOR, Double.class, str);
ok = ok & writeAttribute(settings, element, KEY_MIN_IOU, Double.class, str);
return ok;
}
Expand All @@ -137,7 +111,6 @@ public boolean unmarshall(final Element element, final Map<String, Object> setti
final StringBuilder errorHolder = new StringBuilder();
boolean ok = true;

ok = ok & readDoubleAttribute(element, settings, KEY_SCALE_FACTOR, errorHolder);
ok = ok & readDoubleAttribute(element, settings, KEY_MIN_IOU, errorHolder);
return ok;
}
Expand All @@ -147,20 +120,17 @@ public String toString(final Map<String, Object> settings) {
if (!checkSettingsValidity(settings))
return errorMessage;

final double scale = (Double) settings.get(KEY_SCALE_FACTOR);
final double minIoU = (Double) settings.get(KEY_MIN_IOU);

final StringBuilder str = new StringBuilder();

str.append(String.format(" - scale factor: %.2f\n", scale));
str.append(String.format(" - min. IoU: %.2f\n", minIoU));
return str.toString();
}

@Override
public Map<String, Object> getDefaultSettings() {
final Map<String, Object> settings = new HashMap<>();
settings.put(KEY_SCALE_FACTOR, DEFAULT_SCALE_FACTOR);
settings.put(KEY_MIN_IOU, DEFAULT_MIN_IOU);
return settings;
}
Expand All @@ -174,18 +144,12 @@ public boolean checkSettingsValidity(final Map<String, Object> settings) {

boolean ok = true;
final StringBuilder str = new StringBuilder();
ok = ok & checkParameter(settings, KEY_SCALE_FACTOR, Double.class, str);
ok = ok & checkParameter(settings, KEY_MIN_IOU, Double.class, str);
if (!ok) {
errorMessage = str.toString();
return false;

}
final double scale = ((Double) settings.get(KEY_SCALE_FACTOR)).doubleValue();
if (scale <= 0) {
errorMessage = BASE_ERROR_MESSAGE + "Scale factor must be strictly positive, was " + scale;
return false;
}

return true;
}
Expand Down

0 comments on commit 53c4c93

Please sign in to comment.