Skip to content

Commit

Permalink
toy examples from HOM paper and support for geodesics of non-unit length
Browse files Browse the repository at this point in the history
  • Loading branch information
nefan committed Nov 26, 2013
1 parent aacb3f0 commit 1361e1f
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 31 deletions.
10 changes: 7 additions & 3 deletions lddmm/getPointLDDMMTransport.m
Expand Up @@ -30,13 +30,17 @@
scaleweight = lddmmoptions.scaleweight;
energyweight = lddmmoptions.energyweight;

function res = ltransport(x, points, varargin)
function [res,Gt] = ltransport(x, points, varargin)
backwards = false;
if size(varargin,2) > 0
backwards = varargin{1};
end
tend = 1;
if size(varargin,2) > 1
tend = varargin{2};
end

Gt = pointPath(x); % rhot is a solver structure
Gt = pointPath(x,tend); % rhot is a solver structure

if iscell(points)
g0 = points;
Expand All @@ -50,7 +54,7 @@
end
end
if order == 0
g1 = shootgrid(g0,Gt,lddmmoptions,backwards);
g1 = shootgrid(g0,Gt,lddmmoptions,backwards,tend);
else
g1 = shootgridDKernels(x,g0,Gt,lddmmoptions,backwards);
end
Expand Down
15 changes: 10 additions & 5 deletions lddmm/order0/getPointPathEnergyOrder0.m
Expand Up @@ -31,7 +31,11 @@

[Ks D1Ks D2Ks] = gaussianKernels();

function [E v] = lpathEnergy(x,rhot)
function [E v] = lpathEnergy(x,rhot,varargin)
tend = 1;
if size(varargin,2) > 0
tend = varargin{1};
end

function vt = gradEt(tt,y)
t = intTime(tt,true,lddmmoptions);
Expand Down Expand Up @@ -59,13 +63,14 @@
vt = -intResult(vt,true,lddmmoptions); % sign for backwards integration already accounted for
end

function Et = Gc(tt,yt) % wrapper for C version of G
function Et = Gc(tt) % wrapper for C version of G
t = intTime(tt,false,lddmmoptions);
rhott = deval(rhot,t);

Et = fastPointPathEnergyOrder0(rhott,L,R,cdim,scales.^2,scaleweight.^2);

% debug
% [t Et]
if getOption(lddmmoptions,'testC')
Et2 = G(tt);
assert(norm(Et2-Et) < 10e-12);
Expand All @@ -91,14 +96,14 @@
end
end

E = integrate(@Gc,0,1); % fast C version
% E = energyweight(1)*integrate(@G,0,1); % sloow matlab version
E = integrate(@Gc,0,tend); % fast C version
% E = energyweight(1)*integrate(@G,0,tend); % sloow matlab version
assert(E >= 0);

if nargout > 1
v1 = zeros(CSP*L,1);
options = odeset('RelTol',1e-6,'AbsTol',1e-6);
vt = ode45(@gradEt,[0 1],v1,options);
vt = ode45(@gradEt,[0 tend],v1,options);
v = deval(vt,0);
end
end
Expand Down
10 changes: 7 additions & 3 deletions lddmm/order0/getPointPathOrder0.m
Expand Up @@ -81,7 +81,11 @@
drho = reshape(drho,L*cCSP,1);
end

function rhot = pointPathOrder0(x)
function rhot = pointPathOrder0(x,varargin)
tend = 1;
if size(varargin,2) > 0
tend = varargin{1};
end
if dim == cdim
rho0 = [moving; reshape(x,dim*R,L)];
else
Expand All @@ -93,9 +97,9 @@
rho0(cdim+(2:cdim:cdim*R),:) = xx(2:dim:dim*R,:);
end
options = odeset('RelTol',1e-6,'AbsTol',1e-6);
rhot = ode45(@Gc,[0 1],reshape(rho0,L*cCSP,1),options); % C version
rhot = ode45(@Gc,[0 tend],reshape(rho0,L*cCSP,1),options); % C version
% rhot = ode45(@G,[0 1],reshape(rho0,L*cCSP,1),options); % matlab version
assert(rhot.x(end) == 1); % if not, integration failed
assert(rhot.x(end) == tend); % if not, integration failed

end

Expand Down
13 changes: 9 additions & 4 deletions lddmm/order1/getPointPathOrder1.m
Expand Up @@ -31,7 +31,12 @@

ks = dkernelsGaussian(cdim);

function Gt = pointPathOrder1(x)
function Gt = pointPathOrder1(x,varargin)
tend = 1;
if size(varargin,2) > 0
tend = varargin{1};
end

x = reshape(x,(1+dim)*dim,L);
rhoj0 = x(dim+(1:dim^2),:);
if dim ~= cdim
Expand Down Expand Up @@ -123,9 +128,9 @@
initial(1:dim,:) = moving; % position
initial(cdim+(1:cdim^2),:) = repmat(reshape(eye(cdim),cdim^2,1),1,L); % Dphi
initial(cdim+cdim^2+(1:dim),:) = x(1:R*dim,:); % mu/rho0
Gt = ode45(@Gc,[0 1],reshape(initial,L*cCSP,1),options); % fast native version
% Gt = ode45(@G,[0 1],reshape(initial,L*cCSP,1),options); % matlab version
assert(Gt.x(end) == 1); % if not, integration failed
Gt = ode45(@Gc,[0 tend],reshape(initial,L*cCSP,1),options); % fast native version
% Gt = ode45(@G,[0 tend],reshape(initial,L*cCSP,1),options); % matlab version
assert(Gt.x(end) == tend); % if not, integration failed

end

Expand Down
30 changes: 17 additions & 13 deletions lddmm/workers/order1/shootgrid.m
Expand Up @@ -36,24 +36,28 @@
if size(varargin,2) > 0
backwards = varargin{1};
end
% for movies
selectscale = -1;
tend = 1;
if size(varargin,2) > 1
selectscale = varargin{2};
tend = varargin{2};
end
% % for movies
% selectscale = -1;
% if size(varargin,2) > 1
% selectscale = varargin{2};
% end

function dgrid = Gc(tt,gridt) % wrapper for C version
t = intTime(tt,backwards,lddmmoptions);
rhot = deval(rhott,t);

if selectscale > 0
rhot = reshape(rhot,CSP,L);

rhots = zeros(size(rhot));
rhots(1:dim,:) = rhot(1:dim,:); % points
rhots(dim+(selectscale-1)*dim+1:dim+(selectscale-1)*dim+dim,:) = rhot(dim+(selectscale-1)*dim+1:dim+(selectscale-1)*dim+dim,:);
rhot = reshape(rhots,CSP*L,1);
end
% if selectscale > 0
% rhot = reshape(rhot,CSP,L);
%
% rhots = zeros(size(rhot));
% rhots(1:dim,:) = rhot(1:dim,:); % points
% rhots(dim+(selectscale-1)*dim+1:dim+(selectscale-1)*dim+dim,:) = rhot(dim+(selectscale-1)*dim+1:dim+(selectscale-1)*dim+dim,:);
% rhot = reshape(rhots,CSP*L,1);
% end

dgrid = fastPointTransportOrder0(t,gridt,rhot,L,R,cdim,scales.^2,scaleweight.^2);
dgrid = intResult(dgrid,backwards,lddmmoptions);
Expand Down Expand Up @@ -93,8 +97,8 @@

siz = size(grid0{1});
g0 = reshape([reshape(grid0{1},1,Ngrid); reshape(grid0{2},1,Ngrid); reshape(grid0{3},1,Ngrid)],3*Ngrid,1);
gridt = ode45(@Gc,[0 1],g0);
g1 = reshape(deval(gridt,1),3,Ngrid);
gridt = ode45(@Gc,[0 tend],g0);
g1 = reshape(deval(gridt,tend),3,Ngrid);
grid1 = cell(1);
grid1{1} = reshape(g1(1,:),siz);
grid1{2} = reshape(g1(2,:),siz);
Expand Down
11 changes: 8 additions & 3 deletions registration/visualizers/getPointVisualizer.m
Expand Up @@ -37,10 +37,15 @@
end


function pointVisualizer(x)
function pointVisualizer(x,varargin)
tend = 1;
if size(varargin,2) > 0
tend = varargin{1};
end


% determine area
res = transport(x,moving);
res = transport(x,moving,false,tend);
if ~isempty(fixed)
minx = min([moving(1,:) fixed(1,:) res(1,:)]);
maxx = max([moving(1,:) fixed(1,:) res(1,:)]);
Expand Down Expand Up @@ -79,7 +84,7 @@ function pointVisualizer(x)

printStatus('integrating grid');
pgrid0 = [reshape(grid0{1},1,[]); reshape(grid0{2},1,[]); reshape(grid0{3},1,[])];
pgrid1 = transport(x,pgrid0);
pgrid1 = transport(x,pgrid0,false,tend);
grid1{1} = reshape(pgrid1(1,:),size(grid0{1}));
grid1{2} = reshape(pgrid1(2,:),size(grid0{1}));
grid1{3} = zeros(size(grid0{1}));
Expand Down
93 changes: 93 additions & 0 deletions tests/toyexsOrder0.m
@@ -0,0 +1,93 @@
% produce toy figures (originally for higher-order momenta paper)

% LDDMM options
clear lddmmoptions
lddmmoptions.scale = 1.0; % Gaussian kernels
lddmmoptions.energyweight = [1 12]; % weighting between energy of curve and match
lddmmoptions.energyweight = lddmmoptions.energyweight/sum(lddmmoptions.energyweight);
lddmmoptions.order = 0;

% output options
% global globalOptions;
% globalOptions.verbose = true;
visoptions.dim = 2;
visoptions.Ngridpoints = 25;
visoptions.margin = 1.5;
visoptions.skip = 4;
visoptions.fixedDim = true;
visoptions.minx = 3.5;
visoptions.maxx = 6.5;
visoptions.miny = 3.5;
visoptions.maxy = 6.5;
% scale relative to grid:
scaleInGridUnits = lddmmoptions.scale/(2*visoptions.margin/(visoptions.Ngridpoints-1))

options = getDefaultOptions();

% one point examples
dim = 2;
moving = [
5.0 5.0;
]';
% fixed - matching against these
fixed = [
5.0 5.0;
]';
[methods lddmmoptions1] = setupPointLDDMM(moving,fixed,[],lddmmoptions);
visualizer = getPointVisualizer(methods.transport,moving,[],visoptions);
% translation
figure(1)
x = [1.25 0]';
visualizer(x);
% approximated derivatives
figure(8)
moving = [
5.25 5.0;
4.75 5.0;
5.0 5.25;
5.0 4.75;
]';
% fixed
fixed = [
5.2 5.0;
4.8 5.0;
5.0 5.2;
5.0 4.8;
]';
[methods lddmmoptions1] = setupPointLDDMM(moving,fixed,[],lddmmoptions);
visualizer = getPointVisualizer(methods.transport,moving,[],visoptions);
x = [
2.3 0;
-2.3 0;
0 2.3;
0 -2.3;
]';
visualizer(x);


% two points examples
dim = 2;
moving = [
4.5 5.0;
5.5 5.0;
]';
% fixed - matching against these
fixed = [
4.0 5.0;
5.0 5.0;
]';
[methods lddmmoptions1] = setupPointLDDMM(moving,fixed,[],lddmmoptions);
visualizer = getPointVisualizer(methods.transport,moving,[],visoptions);
% rotation
figure(4)
x = 10*[
1 0;
-1 0;
]';
x = reshape(x,[],1);
tend = 1.0;
visualizer(x,tend);
[res,Gt] = methods.transport(x,moving,false,tend);
E = methods.pathEnergy(x,Gt,tend);
% E
% Gt.x, diff(Gt.x)

0 comments on commit 1361e1f

Please sign in to comment.