Skip to content

Commit

Permalink
Added option for deformed back-projection based on MLS deformation.
Browse files Browse the repository at this point in the history
Added source point computation based on SimpleMatrix instead of Jama.Matrix.
  • Loading branch information
jenniferMaier committed Nov 30, 2021
1 parent 4b0656e commit 3c34107
Show file tree
Hide file tree
Showing 3 changed files with 768 additions and 23 deletions.
194 changes: 172 additions & 22 deletions src/edu/stanford/rsl/conrad/opencl/OpenCLBackProjector.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@
import com.jogamp.opencl.CLImageFormat.ChannelType;
import com.jogamp.opencl.CLMemory.Mem;

import Jama.Matrix;
import Jama.SingularValueDecomposition;
import edu.stanford.rsl.apps.gui.Citeable;
import edu.stanford.rsl.conrad.data.numeric.Grid2D;
import edu.stanford.rsl.conrad.data.numeric.Grid3D;
import edu.stanford.rsl.conrad.geometry.trajectories.Trajectory;
import edu.stanford.rsl.conrad.io.ImagePlusDataSink;
import edu.stanford.rsl.conrad.numerics.SimpleMatrix;
import edu.stanford.rsl.conrad.numerics.SimpleOperators;
import edu.stanford.rsl.conrad.numerics.SimpleVector;
import edu.stanford.rsl.conrad.numerics.SimpleVector.VectorNormType;
import edu.stanford.rsl.conrad.reconstruction.VOIBasedReconstructionFilter;
import edu.stanford.rsl.conrad.utils.CONRAD;
import edu.stanford.rsl.conrad.utils.Configuration;
Expand Down Expand Up @@ -77,6 +83,13 @@ public class OpenCLBackProjector extends VOIBasedReconstructionFilter implements
*/
protected float h_volume[];

/**
* Deformed 3D coordinates
*/
protected SimpleVector[] p = null;
protected SimpleVector[][] qs = null;
private CLBuffer<FloatBuffer> pBuff;
private CLBuffer<FloatBuffer> qBuff;

/**
* The global variable of the module which stores the
Expand All @@ -95,6 +108,8 @@ public class OpenCLBackProjector extends VOIBasedReconstructionFilter implements

private boolean initialized = false;



public OpenCLBackProjector () {
super();
}
Expand Down Expand Up @@ -187,8 +202,11 @@ protected void init(){
}

// create the computing kernel
kernelFunction = program.createCLKernel("backprojectKernel");

if (this.p == null) {
kernelFunction = program.createCLKernel("backprojectKernel");
} else {
kernelFunction = program.createCLKernel("backprojectKernelDeformed");
}
// create the reconstruction volume;
// createFloatBuffer uses a byteBuffer internally --> h_volume.length cannot be > 2^31/4 = 2^31/2^2 = 2^29
// Thus, 2^29 would already cause a overflow (negative sign) of the integer in the byte buffer! Maximum length is (2^29-1)
Expand Down Expand Up @@ -326,15 +344,15 @@ protected synchronized void initProjectionData(Grid2D projection){
// Create the array that will contain the projection data.
projectionArray = context.createFloatBuffer(projection.getWidth()*projection.getHeight(), Mem.READ_ONLY);
}

// Copy the projection data to the array
projectionArray.getBuffer().put(projection.getBuffer());
projectionArray.getBuffer().rewind();

if(projectionTex != null && !projectionTex.isReleased()){
projectionTex.release();
}

// set the texture
CLImageFormat format = new CLImageFormat(ChannelOrder.INTENSITY, ChannelType.FLOAT);
projectionTex = context.createImage2d(projectionArray.getBuffer(), projection.getWidth(), projection.getHeight(), format, Mem.READ_ONLY);
Expand Down Expand Up @@ -399,6 +417,11 @@ protected synchronized void projectSingleProjection(int projectionNumber, int di
if (!largeVolumeMode) {
projections.remove(projectionNumber);
}

if (this.p != null) {
initDeformedCoordinates(projectionNumber);
}

// backproject for each slice
// OpenCL Grids are only two dimensional!
int reconDimensionZ = dimz;
Expand All @@ -407,28 +430,69 @@ protected synchronized void projectSingleProjection(int projectionNumber, int di
double voxelSpacingZ = getGeometry().getVoxelSpacingZ();

// write kernel parameters
kernelFunction.rewind();
kernelFunction
.putArg(volumePointer)
.putArg(getGeometry().getReconDimensionX())
.putArg(getGeometry().getReconDimensionY())
.putArg(reconDimensionZ)
.putArg((int) lineOffset)
.putArg((float) voxelSpacingX)
.putArg((float) voxelSpacingY)
.putArg((float) voxelSpacingZ)
.putArg((float) offsetX)
.putArg((float) offsetY)
.putArg((float) offsetZ)
.putArg(projectionTex)
.putArg(projectionMatrix)
.putArg(projectionMultiplier);
// deformed coordinates
if (this.pBuff == null) {
kernelFunction.rewind();
kernelFunction
.putArg(volumePointer)
.putArg(getGeometry().getReconDimensionX())
.putArg(getGeometry().getReconDimensionY())
.putArg(reconDimensionZ)
.putArg((int) lineOffset)
.putArg((float) voxelSpacingX)
.putArg((float) voxelSpacingY)
.putArg((float) voxelSpacingZ)
.putArg((float) offsetX)
.putArg((float) offsetY)
.putArg((float) offsetZ)
.putArg(projectionTex)
.putArg(projectionMatrix)
.putArg(projectionMultiplier);
} else {
// System.out.println("pBuff = " + pBuff.getBuffer().get(0) + ", " + pBuff.getBuffer().get(1) + ", " + pBuff.getBuffer().get(2));
// System.out.println("qBuff = " + qBuff.getBuffer().get(0) + ", " + qBuff.getBuffer().get(1) + ", " + qBuff.getBuffer().get(2));
kernelFunction.rewind();
kernelFunction
.putArg(volumePointer)
.putArg(getGeometry().getReconDimensionX())
.putArg(getGeometry().getReconDimensionY())
.putArg(reconDimensionZ)
.putArg((int) lineOffset)
.putArg((float) voxelSpacingX)
.putArg((float) voxelSpacingY)
.putArg((float) voxelSpacingZ)
.putArg((float) offsetX)
.putArg((float) offsetY)
.putArg((float) offsetZ)
.putArg(projectionTex)
.putArg(projectionMatrix)
.putArg(projectionMultiplier)
.putArg((float) p[0].getElement(0))
.putArg((float) p[0].getElement(1))
.putArg((float) p[0].getElement(2))
.putArg((float) p[1].getElement(0))
.putArg((float) p[1].getElement(1))
.putArg((float) p[1].getElement(2))
.putArg((float) p[2].getElement(0))
.putArg((float) p[2].getElement(1))
.putArg((float) p[2].getElement(2))
.putArg((float) qs[projectionNumber][0].getElement(0))
.putArg((float) qs[projectionNumber][0].getElement(1))
.putArg((float) qs[projectionNumber][0].getElement(2))
.putArg((float) qs[projectionNumber][1].getElement(0))
.putArg((float) qs[projectionNumber][1].getElement(1))
.putArg((float) qs[projectionNumber][1].getElement(2))
.putArg((float) qs[projectionNumber][2].getElement(0))
.putArg((float) qs[projectionNumber][2].getElement(1))
.putArg((float) qs[projectionNumber][2].getElement(2));
}




int[] realLocalSize = new int[2];
realLocalSize[0] = Math.min(device.getMaxWorkGroupSize(),bpBlockSize[0]);
realLocalSize[1] = Math.max(1, Math.min(device.getMaxWorkGroupSize()/realLocalSize[0], bpBlockSize[1]));

// rounded up to the nearest multiple of localWorkSize
int[] globalWorkSize = {getGeometry().getReconDimensionX(), getGeometry().getReconDimensionY()};
if ((globalWorkSize[0] % realLocalSize[0] ) != 0){
Expand All @@ -447,6 +511,85 @@ protected synchronized void projectSingleProjection(int projectionNumber, int di
.finish();
}

private void initDeformedCoordinates(int projectionNumber) {

SimpleVector[] q = this.qs[projectionNumber];

float[] p_floats = {
(float)p[0].getElement(0), (float)p[0].getElement(1), (float)p[0].getElement(2),
(float)p[1].getElement(0), (float)p[1].getElement(1), (float)p[1].getElement(2),
(float)p[2].getElement(0), (float)p[2].getElement(1), (float)p[2].getElement(2)};

float[] q_floats = {
(float)q[0].getElement(0), (float)q[0].getElement(1), (float)q[0].getElement(2),
(float)q[1].getElement(0), (float)q[1].getElement(1), (float)q[1].getElement(2),
(float)q[2].getElement(0), (float)q[2].getElement(1), (float)q[2].getElement(2)};

CLBuffer<FloatBuffer> pBuffer = context.createFloatBuffer(p.length*p[0].getLen(), Mem.READ_ONLY);
pBuffer.getBuffer().put(p_floats);
pBuffer.getBuffer().rewind();
this.pBuff = pBuffer;

CLBuffer<FloatBuffer> qBuffer = context.createFloatBuffer(q.length*q[0].getLen(), Mem.READ_ONLY);
qBuffer.getBuffer().put(q_floats);
qBuffer.getBuffer().rewind();
this.qBuff = qBuffer;
}

public static SimpleVector mls_rigid_deformation3D(Grid3D volume, SimpleVector coords, SimpleVector[] pIn, SimpleVector[] qIn, double alpha) {
// Rigid deformation
// Params:
// image - ndarray: original image
// p - ndarray: an array with size [n, 2], original control points
// q - ndarray: an array with size [n, 2], final control points
// alpha - float: parameter used by weights
// density - float: density of the grids
// Return:
// A deformed image.


SimpleVector[] p = pIn.clone();
SimpleVector[] q = qIn.clone();
SimpleVector v = coords.clone();

// wi
double[] w = new double[q.length];
double wSum = 0.0;
for (int i = 0; i < w.length; i++) {
w[i] = 1.0/Math.pow(SimpleOperators.subtract(p[i], v).norm(VectorNormType.VEC_NORM_L2), 2.0*alpha);
wSum += w[i];
}

// q* and p*
SimpleVector qStar = new SimpleVector(q[0].getLen());
SimpleVector pStar = new SimpleVector(p[0].getLen());
for (int i = 0; i < w.length; i++) {
qStar.add(q[i].multipliedBy(w[i]));
pStar.add(p[i].multipliedBy(w[i]));
}
qStar.divideBy(wSum);
pStar.divideBy(wSum);

// PQ'
SimpleMatrix PQtrans = new SimpleMatrix(pIn[0].getLen(), qIn[0].getLen());
for (int i = 0; i < pIn.length; i++) {
SimpleMatrix pi_min_pStar = new SimpleMatrix(new double[][] {SimpleOperators.subtract(p[i], pStar).copyAsDoubleArray()});
SimpleMatrix qi_min_qStar = new SimpleMatrix(new double[][] {SimpleOperators.subtract(q[i], qStar).copyAsDoubleArray()});
PQtrans.add(SimpleOperators.multiplyMatrices(pi_min_pStar.transposed(), qi_min_qStar).multipliedBy(w[i]));
}
SingularValueDecomposition svd = new SingularValueDecomposition(new Matrix(PQtrans.copyAsDoubleArray()));
SimpleMatrix U = new SimpleMatrix(svd.getU().getArrayCopy());
SimpleMatrix V = new SimpleMatrix(svd.getV().getArrayCopy());

// transform input coords
SimpleVector v_min_pStar = SimpleOperators.subtract(v, pStar);
SimpleMatrix VUtrans = SimpleOperators.multiplyMatrices(V,U.transposed());
SimpleVector v_transformed = SimpleOperators.add(SimpleOperators.multiply(VUtrans, v_min_pStar), qStar);

return v_transformed;

}

public void loadInputQueue(Grid3D input) throws IOException {
ImageGridBuffer buf = new ImageGridBuffer();
buf.set(input);
Expand Down Expand Up @@ -566,6 +709,7 @@ public void OpenCLRun() {

private synchronized void workOnProjectionData(){
if (projectionsAvailable.size() > 0){
System.out.println(projectionsAvailable.size());
Integer current = projectionsAvailable.get(0);
projectionsAvailable.remove(0);
projectSingleProjection(current.intValue(),
Expand Down Expand Up @@ -644,6 +788,12 @@ public String getToolName(){
return "OpenCL Backprojector";
}

public void setDeformParameters(SimpleVector[] p, SimpleVector[][] qs) {
this.p = p;
this.qs = qs;
}


}
/*
* Copyright (C) 2010-2014 Martin Berger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ public void computeCanonicalProjectionMatrix(CLBuffer<FloatBuffer> detectorDirec
originShift = getOriginTransform();

// New srcPoint in the Canonical coord sys
SimpleVector srcPtW = proj.computeCameraCenter().negated();//computeSrcPt(projectionMatrix, invARmatrixMat);
SimpleVector srcPtW = proj.computeCameraCenter().negated();//computeSrcPt(proj.computeP(), invARmatrixMat); //
srcPoint.getBuffer().put((float) -(-0.5 * (volumeSize[0] -1.0) + originShift.getElement(0)*invVoxelScale.getElement(0,0) + invVoxelScale.getElement(0,0) * srcPtW.getElement(0)));
srcPoint.getBuffer().put((float) -(-0.5 * (volumeSize[1] -1.0) + originShift.getElement(1)*invVoxelScale.getElement(1,1) + invVoxelScale.getElement(1,1) * srcPtW.getElement(1)));
srcPoint.getBuffer().put((float) -(-0.5 * (volumeSize[2] -1.0) + originShift.getElement(2)*invVoxelScale.getElement(2,2) + invVoxelScale.getElement(2,2) * srcPtW.getElement(2)));
Expand Down Expand Up @@ -281,6 +281,12 @@ public static Jama.Matrix computeSrcPt(Jama.Matrix projectionMatrix, Jama.Matrix
return invertedProjMatrix.times(at);
}

public static SimpleVector computeSrcPt(SimpleMatrix projectionMatrix, SimpleMatrix invertedProjMatrix) {
SimpleVector at = projectionMatrix.getSubCol(0, 3, 3);//.getMatrix(0, 2, 3, 3);
//at = at.times(-1.0);
return SimpleOperators.multiply(invertedProjMatrix, at);
}

protected SimpleVector getOriginTransform(){
SimpleVector currOrigin = new SimpleVector(this.origin);
// compute centered origin as assumed by forward projector
Expand Down
Loading

0 comments on commit 3c34107

Please sign in to comment.