Skip to content

Commit

Permalink
made oidata optional argumant in oimodel
Browse files Browse the repository at this point in the history
  • Loading branch information
ceyzeriat committed Sep 10, 2016
1 parent 4e1a615 commit 46f0528
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 61 deletions.
15 changes: 9 additions & 6 deletions soif/oidata.py
Expand Up @@ -68,7 +68,7 @@ def _init(self, src, **kwargs):
break
else:
hdus.close()
if exc.raiseIt(exc.NoTargetTable, self.raiseError, src=self.src): return False
if exc.raiseIt(exc.NoTargetTable, self.raiseError, src=self.src): return
self._targets = {}
for ind, tgt in zip(hdutgt.data["TARGET_ID"], hdutgt.data["TARGET"]):
self._targets[ind] = tgt
Expand Down Expand Up @@ -165,7 +165,7 @@ def __init__(self, datatype, **kwargs):
self.raiseError = bool(kwargs.pop('raiseError', True))
self.datatype = str(datatype).upper()
if self.datatype not in core.ATTRDATATYPE.keys():
if exc.raiseIt(exc.InvalidDataType, self.raiseError, datatype=self.datatype): return False
if exc.raiseIt(exc.InvalidDataType, self.raiseError, datatype=self.datatype): return
self._has = False
self._useit = False

Expand All @@ -178,6 +178,9 @@ def __str__(self):

@property
def useit(self):
"""
Turn this True or False whether you wan't the fitting to include this datatype
"""
return (self._useit and self._has)
@useit.setter
def useit(self, value):
Expand All @@ -202,9 +205,9 @@ def __init__(self, src, hduidx, datatype, hduwlidx, indices=[], wlindices=[], de
hduwl = hdus[self._input_hduwlidx[-1]]

if self.datatype not in core.ATTRDATATYPE.keys():
if exc.raiseIt(exc.InvalidDataType, self.raiseError, datatype=self.datatype): return False
if exc.raiseIt(exc.InvalidDataType, self.raiseError, datatype=self.datatype): return
if core.DATAKEYSDATATYPE[self.datatype]['data'] not in core.hduToColNames(hdu):
if exc.raiseIt(exc.HduDatatypeMismatch, self.raiseError, hduhead=core.hduToDataType(hdu), datatype=self.datatype): return False
if exc.raiseIt(exc.HduDatatypeMismatch, self.raiseError, hduhead=core.hduToDataType(hdu), datatype=self.datatype): return

self._input_degree = [bool(degree)]
self._input_flatten = [bool(flatten)]
Expand Down Expand Up @@ -442,10 +445,10 @@ def __init__(self, src, datafilter, wl=[None, None], erb_sigma=None, sigma_erb=N

self.erb_sigma = core.ident if erb_sigma is None else erb_sigma
if not callable(self.erb_sigma):
if exc.raiseIt(exc.NotCallable, self.raiseError, fct="erb_sigma"): return False
if exc.raiseIt(exc.NotCallable, self.raiseError, fct="erb_sigma"): return
self.sigma_erb = core.ident if sigma_erb is None else sigma_erb
if not callable(self.sigma_erb):
if exc.raiseIt(exc.NotCallable, self.raiseError, fct="sigma_erb"): return False
if exc.raiseIt(exc.NotCallable, self.raiseError, fct="sigma_erb"): return
self.systematic_bounds = None if systematic_bounds is None else list(map(float, list(systematic_bounds)[:2]))
self.systematic_prior = None if systematic_prior is None else float(systematic_prior)
self._systematic_prior = self.systematic_prior
Expand Down
30 changes: 19 additions & 11 deletions soif/oiexception.py
Expand Up @@ -53,84 +53,92 @@ class NoTargetTable(OIException):
"""
def __init__(self, src="", *args, **kwargs):
super(NoTargetTable, self).__init__(src, *args, **kwargs)
self.message = "There's no OI_TARGET table in the oifits file '%s'! Go get some coffee!" % (src)
self.message = "There's no OI_TARGET table in the oifits file '%s'! Better go get some coffee before tackling that one" % (src)

class NoDataModel(OIException):
"""
If there is no Oidata provided with the model
"""
def __init__(self, *args, **kwargs):
super(NoTargetTable, self).__init__(src, *args, **kwargs)
self.message = "There is no data in this model. You can't do that you are specifically authorized"

class NoWavelengthTable(OIException):
"""
If the file has no OITARGET table
"""
def __init__(self, src="", *args, **kwargs):
super(NoWavelengthTable, self).__init__(src, *args, **kwargs)
self.message = "There's no OI_WAVELENGTH table in the oifits file '%s'! You're pretty much screwed!" % (src)
self.message = "There's no OI_WAVELENGTH table in the oifits file '%s'! Haha, you're screwed!" % (src)

class ReadOnly(OIException):
"""
If the parameter is read-only
"""
def __init__(self, attr, *args, **kwargs):
super(ReadOnly, self).__init__(attr, *args, **kwargs)
self.message = "Attribute '%s' is read-only" % (attr)
self.message = "Attribute '%s' is read-only. What did you think?" % (attr)

class InvalidDataType(OIException):
"""
If the data type provided does not exist
"""
def __init__(self, datatype, *args, **kwargs):
super(InvalidDataType, self).__init__(datatype, *args, **kwargs)
self.message = "Data type '%s' does not exist" % (datatype)
self.message = "Surprise! Data type '%s' does not exist" % (datatype)

class HduDatatypeMismatch(OIException):
"""
If the data type and the hdu provided do not match
"""
def __init__(self, hduhead, datatype, *args, **kwargs):
super(HduDatatypeMismatch, self).__init__(hduhead, datatype, *args, **kwargs)
self.message = "Data type '%s' and hdu with '%s' data do not match" % (datatype, hduhead)
self.message = "Data type '%s' and hdu with '%s' data do not match. Don't mix apples with camembert" % (datatype, hduhead)

class BadMaskShape(OIException):
"""
If the mask shape does not match the data shape
"""
def __init__(self, shape, *args, **kwargs):
super(BadMaskShape, self).__init__(shape, *args, **kwargs)
self.message = "Bad mask shape. Should be '%s'" % shape
self.message = "Bad mask shape. Should be '%s'. Nudge it and try again" % shape

class WrongData(OIException):
"""
If the data provided has the wrong data type
"""
def __init__(self, typ, *args, **kwargs):
super(WrongData, self).__init__(typ, *args, **kwargs)
self.message = "Wrong data given, should be '%s'" % typ
self.message = "Wrong data given, should be '%s'. Already told you, don't mix apples and camembert" % typ

class IncompatibleData(OIException):
"""
If the data type and the hdu provided do not match
"""
def __init__(self, typ1, typ2, *args, **kwargs):
super(IncompatibleData, self).__init__(typ1, typ2, *args, **kwargs)
self.message = "Can't merge '%s' and '%s'" % (typ1, typ2)
self.message = "Can't merge '%s' and '%s'. You looked very optimistic until now" % (typ1, typ2)

class NotADataHdu(OIException):
"""
If the hdu provided does not contain data
"""
def __init__(self, idx, src, *args, **kwargs):
super(NotADataHdu, self).__init__(idx, src, *args, **kwargs)
self.message = "Hdu index '%s' in file '%s' does not contain data" % (idx, src)
self.message = "Hdu index '%s' in file '%s' does not contain data. Yep. Too bad" % (idx, src)

class NoSystematicsFit(OIException):
"""
If the user did not set on the fit of systematics
"""
def __init__(self, *args, **kwargs):
super(NoSystematicsFit, self).__init__(*args, **kwargs)
self.message = "You are not fitting systematics"
self.message = "You are not fitting systematics. But maybe you should"

class NotCallable(OIException):
"""
If the function is callable
"""
def __init__(self, fct, *args, **kwargs):
super(NotCallable, self).__init__(fct, *args, **kwargs)
self.message = "'%s' should be callable" % fct
self.message = "'%s' should be callable. Get up and go get it" % fct
3 changes: 3 additions & 0 deletions soif/oifiting.py
Expand Up @@ -45,6 +45,9 @@

class Oifiting(MCres):
def __init__(self, model, nwalkers=100, niters=500, burnInIts=100, threads=1, customlike=None, **kwargs):
self.raiseError = bool(kwargs.pop('raiseError', True))
if not model._hasdata:
if exc.raiseIt(exc.NoDataModel, self.raiseError, src=src): return
if 'sampler' in kwargs.keys():
self._init(sampler=kwargs.pop('sampler'), paramstr=self.model.paramstr, nwalkers=self.nwalkers, niters=self.niters, burnInIts=self.burnInIts)
elif EMCEE:
Expand Down
2 changes: 1 addition & 1 deletion soif/oimainobject.py
Expand Up @@ -213,7 +213,7 @@ def getP0(self):

def compVis(self, oidata, params=None, flat=False):
"""
Does the paperwork before calculating the complex visibilities of the object
Calculates the complex visibilities of the object
"""
if params is not None: self.setParams(params)
if flat:
Expand Down
104 changes: 64 additions & 40 deletions soif/oimodel.py
Expand Up @@ -45,10 +45,13 @@


class Oimodel(object):
def __init__(self, oidata, objs=[], tweakparams=None, **kwargs):
def __init__(self, oidata=None, objs=[], tweakparams=None, **kwargs):
self._objs = []
self.raiseError = bool(kwargs.pop('raiseError', True))
self.oidata = oidata
if isinstance(oidata, Oifits):
self.oidata = oidata
else:
self.oidata = None
if tweakparams is not None and not callable(tweakparams):
if exc.raiseIt(exc.NotCallable, self.raiseError, fct="tweakparams"): return False
self._tweakparams = tweakparams
Expand All @@ -61,16 +64,26 @@ def __init__(self, oidata, objs=[], tweakparams=None, **kwargs):
self.add_obj(item)

def _info(self):
return core.font.blue+"<SOIF Model>%s\n %s objects:\n %s\n%s"%(core.font.normal, self.nobj, "\n ".join(map(str, self._objs)), str(self.oidata))
txt = core.font.blue+"<SOIF Model>%s\n %s objects:\n %s" % (core.font.normal, self.nobj, "\n ".join(map(str, self._objs)))
if self._hasdata:
return "%s\n%s" % (txt, str(self.oidata))
else:
return txt
def __repr__(self):
return self._info()
def __str__(self):
return self._info()


@property
def _hasdata(self):
return self.oidata is not None

@property
def nparams(self):
return self._nparamsObj + self.oidata.systematic_fit
if self._hasdata:
return self._nparamsObj + int(self.oidata.systematic_fit)
else:
return self._nparamsObj
@nparams.setter
def nparams(self, value):
exc.raiseIt(exc.ReadOnly, self.raiseError, attr="nparams")
Expand Down Expand Up @@ -105,7 +118,8 @@ def add_obj(self, typ, name=None, params={}, prior={}):
print(core.font.red+"ERROR: Could not find the object name for '%s', name given: %s%s" % (typ, name, core.font.normal))
return
setattr(self, "o_"+name, self._objs[-1]) # quick access as a class property
if hasattr(self._objs[-1], '_prepare'): self._objs[-1]._prepare(oidata=self.oidata)
if hasattr(self._objs[-1], 'prepare') and self._hasdata:
self._objs[-1].prepare(oidata=self.oidata)
self.nobj += 1
dum = self.nparamsObjs

Expand Down Expand Up @@ -134,7 +148,7 @@ def getP0(self):
ret = []
for item in self._objs:
ret += item.getP0()
if self.oidata.systematic_fit: ret += [self.oidata.systematic_p0()]
if getattr(self.oidata, "systematic_fit", False): ret += [self.oidata.systematic_p0()]
return ret


Expand All @@ -144,7 +158,7 @@ def paramstr(self):
for item in self._objs:
for arg in item._pkeys:
ret.append(item.name+"_"+arg)
if self.oidata.systematic_fit: ret += ["sys"]
if getattr(self.oidata, "systematic_fit", False): ret += ["sys"]
return ret
@paramstr.setter
def paramstr(self, value):
Expand All @@ -159,7 +173,7 @@ def params(self):
ret = []
for item in self._objs:
ret += getattr(item, "params", [])
if self.oidata.systematic_fit: ret += [self.oidata.systematic_prior if self.oidata.systematic_prior is not None else self.oidata.systematic_p0()]
if getattr(self.oidata, "systematic_fit", False): ret += [self.oidata.systematic_prior if self.oidata.systematic_prior is not None else self.oidata.systematic_p0()]
return ret
@params.setter
def params(self, value):
Expand All @@ -186,52 +200,58 @@ def setParams(self, params, priors=False):
for item in self._objs:
item.setParams(params=params[parampos:parampos+item._nparams], priors=priors)
parampos += item._nparams
if self.oidata.systematic_fit:
if getattr(self.oidata, "systematic_fit", False):
self.oidata.systematic_prior = params[parampos]
self.oidata._systematic_prior = self.oidata.systematic_prior


def compVis(self, params=None):
def compVis(self, params=None, u=None, v=None, wl=None):
"""
Calculate the complex visibility of the model from each separate object
Calculates the complex visibility of the model from all unitary models
"""
if params is None: params = self.getP0() # initialize at p0 in case no params is given
if self._tweakparams is not None: self._tweakparams(self, params)
parampos = 0
totflx = 0.
totviscomp = np.zeros(self.oidata.uvwl['u'].shape, dtype=complex) # initialize array
for item in self._objs:
viscomp, flx = item.compVis(oidata=self.oidata, params=params[parampos:parampos+item._nparams], flat=True)
totviscomp += viscomp*flx
totflx += flx
parampos += item._nparams
totviscomp /= totflx
return self.oidata.remorph(totviscomp)
if self._hasdata:
totviscomp = self._compVis(u=oidata.uvwl['u'], v=oidata.uvwl['v'], wl=oidata.uvwl['wl'], blwl=oidata.uvwl['blwl'], params=params)
return self.oidata.remorph(totviscomp)
else:
if u is None or v is None or wl is None:
if exc.raiseIt(exc.NoDataModel, self.raiseError, src=src): return
else:
return self._compVis(u=u, v=v, wl=wl, params=params)

def compuvimage(self, blmax, wl=None, params=None, nbpts=101):
parampos = 0
totFluxvis2 = 0.
if wl is None: wl = self.oidata._wlmin+self.oidata._wlspan*0.5
compvis = np.zeros((nbpts, nbpts), dtype=complex) # initialize array

def calcUVImage(self, blmax, wl, params=None, nbpts=101):
"""
Outputs the complex visibility for a grid a (u,v) in [-blmax,blmax], with nbpts in each dimension
"""
u, v = np.meshgrid(np.linspace(-blmax, blmax, nbpts), np.linspace(-blmax, blmax, nbpts))
return self._compVis(u=u, v=v, wl=wl, blwl=np.hypot(u, v)/wl, params=params)


def _compVis(self, u, v, wl, blwl, params=None):
parampos = 0
totFluxvis = 0.
totviscomp = np.zeros(u.shape, dtype=complex) # initialize array
if params is None:
params = self.getP0() # initialize at p0 in case no params is given
else:
self.setParams(params)
if self._tweakparams is not None: self._tweakparams(self, params)
if params is not None: self.setParams(params)
for item in self._objs:
try:
vis2 = item._calcCompVis(u=u, v=v, wl=wl, blwl=np.hypot(u, v)/wl)
except AttributeError:
raise AttributeError("it looks like some parameters are not initialized. Input your parameters after params=[...]")
compvis += vis2[0]*vis2[1]
totFluxvis2 += vis2[1]
compvis, flx = item._calcCompVis(u=u, v=v, wl=wl, blwl=blwl)
compvis *= flx
totviscomp += compvis
totFluxvis += flx
parampos += item._nparams
return compvis/totFluxvis2
totviscomp /= totFluxvis
return totviscomp


def compimage(self, params=None, sepmax=None, wl=None, masperpx=None, nbpts=101, psfConvolve=None, **kwargs):
def calcImage(self, params=None, sepmax=None, wl=None, masperpx=None, nbpts=101, psfConvolve=None, **kwargs):
"""
psfConvolve in mas (lambda/D)
"""
parampos = 0
totFluxvis2 = 0.
totFluxvis = 0.
# check for set-resolution objects
masperpxfixed = None
nbptscheck = nbpts
Expand Down Expand Up @@ -320,6 +340,8 @@ def uvimage(self, params=None, blmax=None, wl=None, typ='vis2', nbpts=101, cmap=
if ret: return toplot

def residual(self, params, c=None, cmap='jet', cm_min=None, cm_max=None, datatype='All'):
if not self._hasdata:
if exc.raiseIt(exc.NoDataModel, self.raiseError, src=src): return
calcindex = {'vis2':0, 't3phi':1, 't3amp':2, 'visphi':3, 'visamp':4}
fullmodel = self.compVis(params=params)
cm_min_orig = cm_min
Expand Down Expand Up @@ -381,6 +403,8 @@ def residual(self, params, c=None, cmap='jet', cm_min=None, cm_max=None, datatyp
# pass

def likelihood(self, params, customlike=None, chi2=False, **kwargs):
if not self._hasdata:
if exc.raiseIt(exc.NoDataModel, self.raiseError, src=src): return
kwargs['chi2'] = chi2
return standardLikelihood(params=params, model=self, customlike=customlike, kwargs=kwargs)

Expand Down Expand Up @@ -412,7 +436,7 @@ def save(self, filename, clobber=False):
for item in self._objs:
item.save(filename, append=True, clobber=clobber)

self.oidata.save(filename, append=True, clobber=clobber)
if self._hasdata: self.oidata.save(filename, append=True, clobber=clobber)

return filename

Expand Down
6 changes: 3 additions & 3 deletions soif/oiunitmodels.py
Expand Up @@ -140,19 +140,19 @@ def __init__(self, name, img=None, masperpx=None, priors={}, bounds={}, negRA=Fa
self.totFlux = img.sum() if totFlux is None else float(totFlux)
self._img = img/self.totFlux
self._prepared = False
if kwargs.get('oidata') is not None: self._prepare(oidata=kwargs['oidata'])
if kwargs.get('oidata') is not None: self.prepare(oidata=kwargs['oidata'])

def _calcCompVis(self, *args, **kwargs):
return self._compvis, 1./self.cr

def _prepare(self, oidata):
def prepare(self, oidata):
if self._prepared: return
self._compvis = core.calcImgVis(img=self.img, masperpx=self.masperpx, u=oidata.uvwl['u'], v=oidata.uvwl['v'], wl=oidata.uvwl['wl'])
self._prepared = True

def prepare(self, oidata, force=False):
if force: self._prepared = False
self._prepare(oidata=oidata)
self.prepare(oidata=oidata)

@property
def img(self):
Expand Down

0 comments on commit 46f0528

Please sign in to comment.