Skip to content

Commit

Permalink
more pep 8 on oidata and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ceyzeriat committed Sep 14, 2016
1 parent c5b78aa commit 1e97d3b
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 40 deletions.
3 changes: 2 additions & 1 deletion soif/core.py
Expand Up @@ -110,7 +110,8 @@
'is_angle': False
}
}
DATAKEYSLOWER = ['vis2', 't3phi', 't3amp', 'visphi', 'visamp']
DATAKEYSLOWER = [item.lower() for item in ATTRDATATYPE.keys()]
DATAKEYSUPPER = [item.upper() for item in ATTRDATATYPE.keys()]


def abs2(ar):
Expand Down
2 changes: 1 addition & 1 deletion soif/oidata.py
Expand Up @@ -51,7 +51,7 @@ def __init__(self, src, hduidx, datatype, hduwlidx, indices=(),
hdu = hdus[self._input_hduidx[-1]]
hduwl = hdus[self._input_hduwlidx[-1]]

if self.datatype not in core.ATTRDATATYPE.keys():
if self.datatype not in core.DATAKEYSUPPER:
if exc.raiseIt(exc.InvalidDataType,
self.raiseError,
datatype=self.datatype):
Expand Down
5 changes: 2 additions & 3 deletions soif/oidataempty.py
Expand Up @@ -36,7 +36,7 @@ class OidataEmpty(object):
def __init__(self, datatype, **kwargs):
self.raiseError = bool(kwargs.pop('raiseError', True))
self.datatype = str(datatype).upper()
if self.datatype not in core.ATTRDATATYPE.keys():
if self.datatype not in core.DATAKEYSUPPER:
if exc.raiseIt(exc.InvalidDataType,
self.raiseError,
datatype=self.datatype):
Expand All @@ -50,8 +50,7 @@ def _info(self):
def __repr__(self):
return self._info()

def __str__(self):
return self._info()
__str__ = __repr__

@property
def useit(self):
Expand Down
47 changes: 23 additions & 24 deletions soif/oigrab.py
Expand Up @@ -165,59 +165,58 @@ def show_specs(self, ret=False, **kwargs):
if ret:
return tgtlist

def show_filtered(self, tgt=None, mjd=(None, None), hduNums=(),
def show_filtered(self, tgt=None, mjd=(None, None), hdus=(),
vis2=True, t3phi=True, t3amp=True, visphi=True,
visamp=True, verbose=False, **kwargs):
"""
Given an oifits file 'src' and filtering parameters on the
target name (OI_TARGET table), the instrument name
(OI_WAVELENGTH table), the array name (OI_ARRAY table),
the observation wavelength (OI_WAVELENGTH table) and the
acquisition time [t_min, t_max] (OI_VIS2, OI_VIS, OI_T3
tables), this function returns the data indices of the data
matching all of these different filters. These lists are used
to load the data within an Oidata object.
Leave input parameter to 'None' to discard filtering on that
Give filtering parameters on the target name (OI_TARGET table),
the observation wavelength (OI_WAVELENGTH table), the
acquisition time [mjd_min, mjd_max], or the hdu index.
Returns the data indices of the data matching all of these
different filters. These lists are used to load the data within
an Oidata object.
Leave input parameter to 'None' to discard filtering on this
particular parameter.
Returns: VIS2, T3, VIS indices as a tuple of 3 lists
"""
hdus = pf.open(self.src)
mjd = [float(mjd[0] if mjd[0] is not None else -np.inf),
float(mjd[1] if mjd[1] is not None else np.inf)]
allhdus = pf.open(self.src)
mjd = (float(mjd[0] if mjd[0] is not None else -np.inf),
float(mjd[1] if mjd[1] is not None else np.inf))
datayouwant = {'data': {'VIS2': bool(vis2),
'T3PHI': bool(t3phi),
'T3AMP': bool(t3amp),
'VISPHI': bool(visphi),
'VISAMP': bool(visamp)
}
}
hduNums = core.aslist(hduNums)
for idx, item in enumerate(hdus):
if len(hduNums) > 0 and idx not in hduNums:
hdus = core.aslist(hdus)
for idx, item in enumerate(allhdus):
# do we want this hdu?
if len(hdus) > 0 and idx not in hdus:
continue
# is this hdu actual data?
if core.hduToDataType(item) is not None:
mjditem = core.gethduMJD(item)[1].ravel()
filt = ((mjditem >= mjd[0]) & (mjditem <= mjd[1]))
if tgt is not None:
filt = (filt & (item.data.field("TARGET_ID") == int(tgt)))
if verbose:
print("{}:\n {}/{}\n".format(
print("hdu {:d}: {}:\n {}/{}\n".format(
idx,
core.hduToDataType(item),
filt.sum(),
item.data["TARGET_ID"].size))
datayouwant[idx] = np.arange(item.data["TARGET_ID"].size)[filt]
hdus.close()
if filt.any():
datayouwant[idx] = np.arange(item.data["TARGET_ID"].size)[filt]
allhdus.close()
return datayouwant

def extract(self, tgt=None, mjd=(None, None), wl=(None, None), hduNums=(),
def extract(self, tgt=None, mjd=(None, None), wl=(None, None), hdus=(),
vis2=True, t3phi=True, t3amp=True, visphi=True, visamp=True,
flatten=False, degree=True, significant_figures=5,
erb_sigma=None, sigma_erb=None, systematic_prior=None,
systematic_bounds=None, verbose=False, **kwargs):
datayouwant = self.show_filtered(
tgt=tgt, mjd=mjd, vis2=vis2, hduNums=hduNums,
tgt=tgt, mjd=mjd, vis2=vis2, hdus=hdus,
t3phi=t3phi, t3amp=t3amp, visphi=visphi, visamp=visamp,
verbose=verbose, **kwargs)
return Oifits(src=self.src, datafilter=datayouwant, flatten=flatten,
Expand Down
3 changes: 2 additions & 1 deletion soif/test/test_oidata.py
Expand Up @@ -35,7 +35,7 @@
from ..oifits import Oifits
from .. import oiexception as exc

FILENAME = os.path.dirname(os.path.abspath(__file__)) + '/MWC361.oifits'
"""FILENAME = os.path.dirname(os.path.abspath(__file__)) + '/MWC361.oifits'
FILENAME_NOTARGET = os.path.dirname(os.path.abspath(__file__)) + '/MWC361_notarget.oifits'
FILENAME_NOWL = os.path.dirname(os.path.abspath(__file__)) + '/MWC361_nowl.oifits'
VALIDHDU = 4
Expand All @@ -50,3 +50,4 @@ def test_oigrab():
@raises(exc.NoTargetTable)
def test_oigrab_NoTargetTable():
oig = Oigrab(FILENAME_NOTARGET)
"""
66 changes: 66 additions & 0 deletions soif/test/test_oidataempty.py
@@ -0,0 +1,66 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

###############################################################################
#
# SOIF - Sofware for Optical Interferometry fitting
# Copyright (C) 2016 Guillaume Schworer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# For any information, bug report, idea, donation, hug, beer, please contact
# guillaume.schworer@obspm.fr
#
###############################################################################


import numpy as np
import os
from nose.tools import raises

from ..oidata import Oidata
from ..oidataempty import OidataEmpty
from .. import oiexception as exc
from .. import core




def test_create():
for typ in core.DATAKEYSUPPER:
oie = OidataEmpty(datatype=typ.lower())
assert not oie.useit
oie.useit = True
assert not oie.useit
assert not bool(oie)
assert not oie
oie._has = True
assert oie.useit
assert bool(oie)
assert oie
oie.useit = False
assert not oie.useit
assert bool(oie)
assert oie
assert str(oie) == repr(oie)


@raises(exc.InvalidDataType)
def test_InvalidDataType():
oie = OidataEmpty(datatype='random')

def test_InvalidDataType_noraise():
oie = OidataEmpty(datatype='random', raiseError=False)
assert not hasattr(oie, '_has')

48 changes: 48 additions & 0 deletions soif/test/test_oifits.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

###############################################################################
#
# SOIF - Sofware for Optical Interferometry fitting
# Copyright (C) 2016 Guillaume Schworer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# For any information, bug report, idea, donation, hug, beer, please contact
# guillaume.schworer@obspm.fr
#
###############################################################################


import numpy as np
import os
from nose.tools import raises

from ..oidata import Oidata
from ..oidataempty import OidataEmpty
from ..oifits import Oifits
from ..oigrab import Oigrab
from .. import oiexception as exc




"""def test_extract():
oig = Oigrab(FILENAME)
ans1 = oig.extract(tgt=VALIDTGT)
filt = np.asarray([item[1] for item in oig.show_specs(ret=True)[VALIDHDU]]) == VALIDTGT
ans2 = Oifits(oig.src, datafilter={VALIDHDU: np.arange(DATASETSIZE)[filt]+1})
assert np.allclose(ans1.vis2.data, ans2.vis2.data)
"""
32 changes: 22 additions & 10 deletions soif/test/test_oigrab.py
Expand Up @@ -40,6 +40,7 @@
FILENAME_NOWL = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'MWC361_nowl.oifits')
FILENAME_FULL = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'MWC361_full.oifits')
VALIDHDU = 4
T3HDU = 6
DATASETSIZE = 12
VALIDTGT = 1

Expand All @@ -53,38 +54,49 @@ def test():
def test_NoTargetTable():
oig = Oigrab(FILENAME_NOTARGET)

def test_NoTargetTable_noraise():
oig = Oigrab(FILENAME_NOTARGET, raiseError=False)
assert not hasattr(oig, '_targets')

@raises(exc.NoWavelengthTable)
def test_NoWavelengthTable():
oig = Oigrab(FILENAME_NOWL)

def test_NoWavelengthTable_noraise():
oig = Oigrab(FILENAME_NOWL, raiseError=False)

@raises(exc.ReadOnly)
def test_NoTargetTable():
oig = Oigrab(FILENAME)
oig.targets = []

def test_show_specs():
oig = Oigrab(FILENAME)
ans = oig.show_specs(ret=False)
ans = oig.show_specs(ret=True)
for item in range(10):
if item != VALIDHDU:
assert ans.get(item) is None
assert len(ans[VALIDHDU]) == DATASETSIZE
assert np.allclose(ans[VALIDHDU][0], (0, 0, 57190.4437, 1, 38))
assert (np.diff([item[2] for item in ans[VALIDHDU]]) >= 0).all()


def test_show_filtered():
oig = Oigrab(FILENAME)
for item in range(10):
if item != VALIDHDU:
assert oig.show_filtered(tgt=VALIDTGT).get(item) is None
assert oig.show_filtered(tgt=VALIDTGT, verbose=True).get(item) is None
else:
assert oig.show_filtered(tgt=VALIDTGT).get(item).tolist() == [ 2, 5, 8, 11]
assert oig.show_filtered(tgt=VALIDTGT, verbose=True).get(item).tolist() == [ 2, 5, 8, 11]
oig = Oigrab(FILENAME_FULL)

def test_extract():
oig = Oigrab(FILENAME)
ans1 = oig.extract(tgt=VALIDTGT)
filt = np.asarray([item[1] for item in oig.show_specs(ret=True)[VALIDHDU]]) == VALIDTGT
ans2 = Oifits(oig.src, datafilter={VALIDHDU: np.arange(DATASETSIZE)[filt]+1})
assert np.allclose(ans1.vis2.data, ans2.vis2.data)
ans = oig.show_filtered(tgt=VALIDTGT, verbose=True)
assert len(ans) == 5
assert len(ans[VALIDHDU]) == 4
assert len(ans[T3HDU]) == 140
assert len(oig.show_filtered(tgt=VALIDTGT, hdus=(VALIDHDU,T3HDU), verbose=True)) == 3
ans = oig.show_filtered(tgt=VALIDTGT, hdus=(T3HDU), t3amp=False, mjd=(55636.3382228, 55636.3396117), verbose=True)
assert len(ans) == 2
assert ans['data']['T3AMP'] == False
assert ans[T3HDU].min() == 70
assert ans[T3HDU].max() == 174
assert len(ans[T3HDU]) == 70

0 comments on commit 1e97d3b

Please sign in to comment.