Skip to content

Commit

Permalink
add factorized (per-image) optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
dstndstn committed Feb 28, 2024
1 parent cdb8200 commit 20d28a7
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 9 deletions.
17 changes: 12 additions & 5 deletions tractor/dense_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ class ConstrainedDenseOptimizer(ConstrainedOptimizer):

def getUpdateDirection(self, tractor, allderivs, damp=0., priors=True,
scale_columns=True, scales_only=False,
chiImages=None, variance=False,
chiImages=None,
variance=False,
shared_params=True,
get_A_matrix=False):

Expand All @@ -25,12 +26,12 @@ def getUpdateDirection(self, tractor, allderivs, damp=0., priors=True,
# I don't want to deal with this right now!
assert(shared_params == False)
assert(scales_only == False)
assert(damp == 0.)
assert(variance == False)
assert(damp == 0.)

# Returns: numpy array containing update direction.
# If *variance* is True, return (update,variance)
# If *get_A_matrix* is True, returns the sparse matrix of derivatives.
# If *get_A_matrix* is True, returns the matrix of derivatives.
# If *scale_only* is True, return column scalings
# In cases of an empty matrix, returns the list []
#
Expand Down Expand Up @@ -231,8 +232,9 @@ def getUpdateDirection(self, tractor, allderivs, damp=0., priors=True,
plt.ylabel('Relative change in New matrix element')
plt.ylim(-mx, mx)
ps.savefig()

del A

if not get_A_matrix:
del A
del B

if scale_columns:
Expand All @@ -250,6 +252,11 @@ def getUpdateDirection(self, tractor, allderivs, damp=0., priors=True,
if not np.all(np.isfinite(X)):
return None

if get_A_matrix:
if scale_columns:
A *= colscales[np.newaxis,:]
return X,A

return X


119 changes: 119 additions & 0 deletions tractor/factored_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from tractor.dense_optimizer import ConstrainedDenseOptimizer
import numpy as np

'''
A mixin class for LsqrOptimizer that does the linear update direction step
by factorizing over images -- it solves the linear problem for each image
independently, and then combines those results (via their covariances) into
the overall result.
'''
class FactoredOptimizer(object):

def getSingleImageUpdateDirection(self, tr, **kwargs):
#print('getSingleImageUpdateDirection( kwargs=', kwargs, ')')
allderivs = tr.getDerivs()
x,A = self.getUpdateDirection(tr, allderivs, get_A_matrix=True, **kwargs)
icov = np.matmul(A.T, A)
del A
return x, icov

def getLinearUpdateDirection(self, tr, **kwargs):
#print('getLinearUpdateDirection( kwargs=', kwargs, ')')
img_opts = []
from tractor import Images

imgs = tr.images
for i,img in enumerate(imgs):
tr.images = Images(img)
x,x_icov = self.getSingleImageUpdateDirection(tr, **kwargs)
# print('Opt for img', i, ':')
# print(x)
# print('And icov')
# print(x_icov)
img_opts.append((x,x_icov))
tr.images = imgs

# ~ inverse-covariance-weighted sum of img_opts...
xicsum = 0
icsum = 0
for x,ic in img_opts:
xicsum = xicsum + np.dot(ic, x)
icsum = icsum + ic
C = np.linalg.inv(icsum)
x = np.dot(C, xicsum)
# print('Total opt:')
# print(x)
return x


class FactoredDenseOptimizer(FactoredOptimizer, ConstrainedDenseOptimizer):
pass


if __name__ == '__main__':

import pylab as plt
from tractor import Image, PixPos, Flux, Tractor, NullWCS, NCircularGaussianPSF, PointSource

n_ims = 2
sig1s = [3., 10.]
H,W = 50,50
cx,cy = 23,27
psf_sigmas = [2., 1.]
fluxes = [500., 500.]

tims = []
for i in range(n_ims):
x = np.arange(W)
y = np.arange(H)
data = np.exp(-0.5 * ((x[np.newaxis,:] - cx)**2 + (y[:,np.newaxis] - cy)**2) /
psf_sigmas[i]**2)
data *= fluxes[i] / (2. * np.pi * psf_sigmas[i]**2)
data += np.random.normal(size=(50,50)) * sig1s[i]

tims.append(Image(data=data, inverr=np.ones_like(data) / sig1s[i],
psf=NCircularGaussianPSF([psf_sigmas[i]], [1.]),
wcs=NullWCS()))
src = PointSource(PixPos(W//2, H//2), Flux(100.))

opt = FactoredDenseOptimizer()

opt2 = ConstrainedDenseOptimizer()

tr = Tractor(tims, [src], optimizer=opt)
tr2 = Tractor(tims, [src], optimizer=opt2)
tr.freezeParam('images')
tr2.freezeParam('images')

mods = list(tr.getModelImages())
plt.clf()
for i in range(n_ims):
ima = dict(interpolation='nearest', origin='lower', vmin=-3.*sig1s[i],
vmax=5.*sig1s[i])
plt.subplot(2,2, i*2 + 1)
plt.imshow(tims[i].data, **ima)
plt.subplot(2,2, i*2 + 2)
plt.imshow(mods[i], **ima)
plt.savefig('1.png')

fit_kwargs = dict(shared_params=False, priors=False)
up1 = tr.optimizer.getLinearUpdateDirection(tr, **fit_kwargs)
up2 = tr2.optimizer.getLinearUpdateDirection(tr2, **fit_kwargs)

print('Update directions:')
print(up1)
print(up2)

tr.optimize_loop(**fit_kwargs)

mods = list(tr.getModelImages())
plt.clf()
for i in range(n_ims):
ima = dict(interpolation='nearest', origin='lower', vmin=-3.*sig1s[i],
vmax=5.*sig1s[i])
plt.subplot(2,2, i*2 + 1)
plt.imshow(tims[i].data, **ima)
plt.subplot(2,2, i*2 + 2)
plt.imshow(mods[i], **ima)
plt.savefig('2.png')

19 changes: 15 additions & 4 deletions tractor/lsqr_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,12 @@ def _lnp_for_update(self, tractor, mod0, imgs, umodels, X, alpha, p0, rois,
lnp += -0.5 * chisq
return lnp, chis, ims

def optimize(self, tractor, alphas=None, damp=0, priors=True,
scale_columns=True,
shared_params=True, variance=False, just_variance=False,
**nil):
def getLinearUpdateDirection(self, tractor,
damp=0,
priors=True,
scale_columns=True,
shared_params=True,
variance=False):
#logverb(tractor.getName() + ': Finding derivs...')
#t0 = Time()
allderivs = tractor.getDerivs()
Expand All @@ -270,6 +272,15 @@ def optimize(self, tractor, alphas=None, damp=0, priors=True,
scale_columns=scale_columns,
shared_params=shared_params,
variance=variance)
return X

def optimize(self, tractor, alphas=None, damp=0, priors=True,
scale_columns=True,
shared_params=True, variance=False, just_variance=False,
**nil):
kwa = dict(damp=damp, priors=priors,
scale_columns=scale_columns, shared_params=shared_params)
X = self.getLinearUpdateDirection(tractor, **kwa)
#print('Update:', X)
if X is None:
# Failure
Expand Down

0 comments on commit 20d28a7

Please sign in to comment.