Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
*
* @author Matthias Arzt
*/
@Deprecated
public abstract class AbstractMultiThreadedConvolution< T > implements Convolution< T >
{

Expand All @@ -61,6 +62,7 @@ abstract protected void process( RandomAccessible< ? extends T > source,
ExecutorService executorService,
int numThreads );

@Deprecated
@Override
public void setExecutor( final ExecutorService executor )
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,12 @@ class Concatenation< T > implements Convolution< T >
this.steps = new ArrayList<>( steps );
}

@Deprecated
@Override
public void setExecutor( final ExecutorService executor )
public void setExecutor( ExecutorService executor )
{
steps.forEach( step -> step.setExecutor( executor ) );
for ( Convolution<T> step : steps )
step.setExecutor( executor );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public interface Convolution< T >
/**
* Set the {@link ExecutorService} to be used for convolution.
*/
@Deprecated
default void setExecutor( final ExecutorService executor )
{}

Expand Down
132 changes: 36 additions & 96 deletions src/main/java/net/imglib2/algorithm/convolution/LineConvolution.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,44 +33,50 @@
*/
package net.imglib2.algorithm.convolution;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.function.Supplier;

import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.Localizable;
import net.imglib2.Point;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.util.IntervalIndexer;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.parallel.Parallelization;
import net.imglib2.parallel.TaskExecutor;
import net.imglib2.parallel.TaskExecutors;
import net.imglib2.util.Cast;
import net.imglib2.util.Intervals;
import net.imglib2.util.Localizables;
import net.imglib2.view.Views;

import java.util.concurrent.ExecutorService;

/**
* This class can be used to implement a separable convolution. It applies a
* {@link LineConvolverFactory} on the given images.
*
* @author Matthias Arzt
*/
public class LineConvolution< T > extends AbstractMultiThreadedConvolution< T >
public class LineConvolution< T > implements Convolution<T>
{
private final LineConvolverFactory< ? super T > factory;

private final int direction;

private ExecutorService executor;

public LineConvolution( final LineConvolverFactory< ? super T > factory, final int direction )
{
this.factory = factory;
this.direction = direction;
}

@Deprecated
@Override
public void setExecutor( ExecutorService executor )
{
this.executor = executor;
}

@Override
public Interval requiredSourceInterval( final Interval targetInterval )
{
Expand All @@ -84,104 +90,38 @@ public Interval requiredSourceInterval( final Interval targetInterval )
@Override
public T preferredSourceType( final T targetType )
{
return (T) factory.preferredSourceType( targetType );
return Cast.unchecked( factory.preferredSourceType( targetType ) );
}

@Override
protected void process( final RandomAccessible< ? extends T > source, final RandomAccessibleInterval< ? extends T > target, final ExecutorService executorService, final int numThreads )
public void process( RandomAccessible< ? extends T > source, RandomAccessibleInterval< ? extends T > target )
{
final RandomAccessibleInterval< ? extends T > sourceInterval = Views.interval( source, requiredSourceInterval( target ) );
final long[] sourceMin = Intervals.minAsLongArray( sourceInterval );
final long[] targetMin = Intervals.minAsLongArray( target );

final Supplier< Consumer< Localizable > > actionFactory = () -> {

final RandomAccess< ? extends T > in = sourceInterval.randomAccess();
final RandomAccess< ? extends T > out = target.randomAccess();
final Runnable convolver = factory.getConvolver( in, out, direction, target.dimension( direction ) );

return position -> {
in.setPosition( sourceMin );
out.setPosition( targetMin );
in.move( position );
out.move( position );
convolver.run();
};
};

final long[] dim = Intervals.dimensionsAsLongArray( target );
dim[ direction ] = 1;

final int numTasks = numThreads > 1 ? timesFourAvoidOverflow(numThreads) : 1;
LineConvolution.forEachIntervalElementInParallel( executorService, numTasks, new FinalInterval( dim ), actionFactory );
}
RandomAccessibleInterval< Localizable > positions = Localizables.randomAccessibleInterval( new FinalInterval( dim ) );
TaskExecutor taskExecutor = executor == null ? Parallelization.getTaskExecutor() : TaskExecutors.forExecutorService( executor );
LoopBuilder.setImages( positions ).multiThreaded(taskExecutor).forEachChunk(
chunk -> {

private int timesFourAvoidOverflow( int x )
{
return (int) Math.min((long) x * 4, Integer.MAX_VALUE);
}
final RandomAccess< ? extends T > in = sourceInterval.randomAccess();
final RandomAccess< ? extends T > out = target.randomAccess();
final Runnable convolver = factory.getConvolver( in, out, direction, target.dimension( direction ) );

/**
* {@link #forEachIntervalElementInParallel(ExecutorService, int, Interval, Supplier)}
* executes a given action for each position in a given interval. Therefor
* it starts the specified number of tasks. Each tasks calls the action
* factory once, to get an instance of the action that should be executed.
* The action is then called multiple times by the task.
*
* @param service
* {@link ExecutorService} used to create the tasks.
* @param numTasks
* number of tasks to use.
* @param interval
* interval to iterate over.
* @param actionFactory
* factory that returns the action to be executed.
*/
// TODO: move to a better place
public static void forEachIntervalElementInParallel( final ExecutorService service, final int numTasks, final Interval interval,
final Supplier< Consumer< Localizable > > actionFactory )
{
final long[] min = Intervals.minAsLongArray( interval );
final long[] dim = Intervals.dimensionsAsLongArray( interval );
final long size = Intervals.numElements( dim );
final int boundedNumTasks = (int) Math.max( 1, Math.min(size, numTasks ));
final long taskSize = ( size - 1 ) / boundedNumTasks + 1; // taskSize = roundUp(size / boundedNumTasks);
final ArrayList< Callable< Void > > callables = new ArrayList<>();
chunk.forEachPixel( position -> {
in.setPosition( sourceMin );
out.setPosition( targetMin );
in.move( position );
out.move( position );
convolver.run();
} );

for ( int taskNum = 0; taskNum < boundedNumTasks; ++taskNum )
{
final long myStartIndex = taskNum * taskSize;
final long myEndIndex = Math.min( size, myStartIndex + taskSize );
final Callable< Void > r = () -> {
final Consumer< Localizable > action = actionFactory.get();
final long[] position = new long[ dim.length ];
final Localizable localizable = Point.wrap( position );
for ( long index = myStartIndex; index < myEndIndex; ++index )
{
IntervalIndexer.indexToPositionWithOffset( index, dim, min, position );
action.accept( localizable );
return null;
}
return null;
};
callables.add( r );
}
execute( service, callables );
}

private static void execute( final ExecutorService service, final ArrayList< Callable< Void > > callables )
{
try
{
final List< Future< Void > > futures = service.invokeAll( callables );
for ( final Future< Void > future : futures )
future.get();
}
catch ( final InterruptedException | ExecutionException e )
{
final Throwable cause = e.getCause();
if ( cause instanceof RuntimeException )
throw ( RuntimeException ) cause;
throw new RuntimeException( e );
}
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public class MultiDimensionConvolution< T > implements Convolution< T >
{
private ExecutorService executor;

@Deprecated
@Override
public void setExecutor( final ExecutorService executor )
{
Expand Down
71 changes: 55 additions & 16 deletions src/main/java/net/imglib2/algorithm/gauss3/Gauss3.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,18 @@

package net.imglib2.algorithm.gauss3;

import java.util.Arrays;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;

import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.algorithm.convolution.Convolution;
import net.imglib2.algorithm.convolution.kernel.Kernel1D;
import net.imglib2.algorithm.convolution.kernel.SeparableKernelConvolution;
import net.imglib2.exception.IncompatibleTypeException;
import net.imglib2.parallel.Parallelization;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.DoubleType;
Expand All @@ -56,7 +59,7 @@
public final class Gauss3
{
/**
* Apply Gaussian convolution to source and write the result to output.
* Apply Gaussian convolution to source and write the result to target.
* In-place operation (source==target) is supported.
*
* <p>
Expand All @@ -66,6 +69,11 @@ public final class Gauss3
* in their own precision. The source type S and target type T are either
* both {@link RealType RealTypes} or both the same type.
*
* <p>
* Computation may be multi-threaded, according to the current
* {@link Parallelization} context. (By default, it will use the
* {@link ForkJoinPool#commonPool() common ForkJoinPool})
*
* @param sigma
* standard deviation of isotropic Gaussian.
* @param source
Expand Down Expand Up @@ -93,7 +101,7 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
}

/**
* Apply Gaussian convolution to source and write the result to output.
* Apply Gaussian convolution to source and write the result to target.
* In-place operation (source==target) is supported.
*
* <p>
Expand All @@ -104,9 +112,10 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
* both {@link RealType RealTypes} or both the same type.
*
* <p>
* Computation is multi-threaded with as many threads as processors
* available.
*
* Computation may be multi-threaded, according to the current
* {@link Parallelization} context. (By default, it will use the
* {@link ForkJoinPool#commonPool() common ForkJoinPool})
*
* @param sigma
* standard deviation in every dimension.
* @param source
Expand All @@ -126,13 +135,27 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
*/
public static < S extends NumericType< S >, T extends NumericType< T > > void gauss( final double[] sigma, final RandomAccessible< S > source, final RandomAccessibleInterval< T > target ) throws IncompatibleTypeException
{
final int numthreads = Runtime.getRuntime().availableProcessors();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the javadoc:
"Computation is multi-threaded with as many threads as processors available."
should be something like
"...Computation may be multi-threaded, according to the current {@link Parallelization} context. (By default, runs on the common ForkJoinPool ...)"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I updated the javadoc.

final ExecutorService service = Executors.newFixedThreadPool( numthreads );
gauss( sigma, source, target, service );
service.shutdown();
final double[][] halfkernels = halfkernels( sigma );
final Convolution< NumericType< ? > > convolution = SeparableKernelConvolution.convolution( Kernel1D.symmetric( halfkernels ) );
convolution.process( source, target );
}

/**
* @deprecated
* Deprecated. Please use
* {@link Gauss3#gauss(double, RandomAccessible, RandomAccessibleInterval)
* gauss(sigma, source, target)} instead. The number of threads used to
* calculate the Gaussion convolution may by set with the
* {@link Parallelization} context, as show in this example:
* <pre>
* {@code
* Parallelization.runWithNumThreads( numThreads,
* () -> gauss( sigma, source, target )
* );
* }
* </pre>
*
* <p>
* Apply Gaussian convolution to source and write the result to output.
* In-place operation (source==target) is supported.
*
Expand Down Expand Up @@ -162,14 +185,30 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
* if source and target type are not compatible (they must be
* either both {@link RealType RealTypes} or the same type).
*/
@Deprecated
public static < S extends NumericType< S >, T extends NumericType< T > > void gauss( final double[] sigma, final RandomAccessible< S > source, final RandomAccessibleInterval< T > target, final int numThreads ) throws IncompatibleTypeException
{
final ExecutorService service = Executors.newFixedThreadPool( numThreads );
gauss( sigma, source, target, service );
service.shutdown();
Parallelization.runWithNumThreads( numThreads,
() -> gauss( sigma, source, target )
);
}

/**
* @deprecated
* Deprecated. Please use
* {@link Gauss3#gauss(double, RandomAccessible, RandomAccessibleInterval)
* gauss(sigma, source, target)} instead. The ExecutorService used to
* calculate the Gaussion convolution may by set with the
* {@link Parallelization} context, as show in this example:
* <pre>
* {@code
* Parallelization.runWithExecutor( executorService,
* () -> gauss( sigma, source, target )
* );
* }
* </pre>
*
* <p>
* Apply Gaussian convolution to source and write the result to output.
* In-place operation (source==target) is supported.
*
Expand Down Expand Up @@ -199,12 +238,12 @@ public static < S extends NumericType< S >, T extends NumericType< T > > void ga
* if source and target type are not compatible (they must be
* either both {@link RealType RealTypes} or the same type).
*/
@Deprecated
public static < S extends NumericType< S >, T extends NumericType< T > > void gauss( final double[] sigma, final RandomAccessible< S > source, final RandomAccessibleInterval< T > target, final ExecutorService service ) throws IncompatibleTypeException
{
final double[][] halfkernels = halfkernels( sigma );
final Convolution< NumericType< ? > > convolution = SeparableKernelConvolution.convolution( Kernel1D.symmetric( halfkernels ) );
convolution.setExecutor( service );
convolution.process( source, target );
Parallelization.runWithExecutor( service,
() -> gauss( sigma, source, target )
);
}

public static double[][] halfkernels( final double[] sigma )
Expand Down
Loading