Skip to content

Commit

Permalink
Refactored by Sourcery
Browse files Browse the repository at this point in the history
  • Loading branch information
sourcery-ai-bot committed Jun 11, 2020
1 parent 8494339 commit d0f9056
Show file tree
Hide file tree
Showing 18 changed files with 94 additions and 141 deletions.
3 changes: 1 addition & 2 deletions conftest.py
Expand Up @@ -15,8 +15,7 @@
@pytest.fixture
def published_data():
name = "data/precision/precision_figueira_2016.dat"
df = pd.read_csv(name, sep="\t")
return df
return pd.read_csv(name, sep="\t")


@pytest.fixture(
Expand Down
44 changes: 22 additions & 22 deletions eniric/_config.py
Expand Up @@ -48,33 +48,33 @@ def __getitem__(self, key):
return self._config[key]

def __setitem__(self, key, value):
ret = self._config.__setitem__(key, value)
return ret
return self._config.__setitem__(key, value)

def __getattr__(self, key):
if key in self:
if key == "paths":
paths = self["paths"]
for k, value in paths.items():
if isinstance(value, list):
paths[k] = os.path.join(*value)
return paths
elif key == "cache":
cache = self["cache"]
if (cache["location"] is None) or (cache["location"] == "None"):
cache["location"] = None # Disables caching
elif isinstance(cache["location"], list):
cache["location"] = os.path.join(*cache["location"])
return cache
return self[key]
else:
if key not in self:
return super().__getattribute__(key)
if key == "cache":
cache = self["cache"]
if (cache["location"] is None) or (cache["location"] == "None"):
cache["location"] = None # Disables caching
elif isinstance(cache["location"], list):
cache["location"] = os.path.join(*cache["location"])
return cache
elif key == "paths":
paths = self["paths"]
for k, value in paths.items():
if isinstance(value, list):
paths[k] = os.path.join(*value)
return paths
return self[key]

def __setattr__(self, key, value):
if key not in ["_path", "_protect_rewrites", "_config", "pathdir"]:
if key in self:
self.__setitem__(key, value)
self._rewrite()
if (
key not in ["_path", "_protect_rewrites", "_config", "pathdir"]
and key in self
):
self.__setitem__(key, value)
self._rewrite()

super().__setattr__(key, value)

Expand Down
28 changes: 12 additions & 16 deletions eniric/atmosphere.py
Expand Up @@ -85,10 +85,7 @@ def __init__(
), "Wavelength and transmission do not match length."
self.wl = np.asarray(wavelength)
self.transmission = np.asarray(transmission)
if std is None:
self.std = np.zeros_like(wavelength)
else:
self.std = np.asarray(std)
self.std = np.zeros_like(wavelength) if std is None else np.asarray(std)
if mask is None:
self.mask = np.ones_like(wavelength, dtype=bool)
else:
Expand Down Expand Up @@ -305,22 +302,21 @@ def barycenter_broaden(self, rv: float = 30.0, consecutive_test: bool = False):
if consecutive_test:
# Mask value False if 3 or more consecutive zeros in slice.
len_consec_zeros = consecutive_truths(~mask_slice)
if np.all(
~mask_slice
if (
np.all(~mask_slice)
or not np.all(~mask_slice)
and np.max(len_consec_zeros) >= 3
): # All pixels of slice is zeros (shouldn't get here)
this_mask_value = False
elif np.max(len_consec_zeros) >= 3:
this_mask_value = False
else:
this_mask_value = True
if np.sum(~mask_slice) > 3:
if self.verbose:
print(
(
"There were {0}/{1} zeros in this "
"barycentric shift but None were 3 consecutive!"
).format(np.sum(~mask_slice), len(mask_slice))
)
if np.sum(~mask_slice) > 3 and self.verbose:
print(
(
"There were {0}/{1} zeros in this "
"barycentric shift but None were 3 consecutive!"
).format(np.sum(~mask_slice), len(mask_slice))
)

else:
this_mask_value = np.bool(
Expand Down
8 changes: 2 additions & 6 deletions eniric/broaden.py
Expand Up @@ -384,9 +384,7 @@ def unitary_gaussian(
sigma = np.abs(fwhm) / (2 * np.sqrt(2 * np.log(2)))
amp = 1.0 / (sigma * np.sqrt(2 * np.pi))
tau = -((x - center) ** 2) / (2 * (sigma ** 2))
kernel = amp * np.exp(tau)

return kernel
return amp * np.exp(tau)


def rotation_kernel(
Expand Down Expand Up @@ -420,9 +418,7 @@ def rotation_kernel(

c1 = 2.0 * (1.0 - epsilon) / denominator
c2 = 0.5 * np.pi * epsilon / denominator
kernel = c1 * np.sqrt(1.0 - lambda_ratio_sqr) + c2 * (1.0 - lambda_ratio_sqr)

return kernel
return c1 * np.sqrt(1.0 - lambda_ratio_sqr) + c2 * (1.0 - lambda_ratio_sqr)


def oned_circle_kernel(x: ndarray, center: float, fwhm: float):
Expand Down
5 changes: 2 additions & 3 deletions eniric/io_module.py
Expand Up @@ -274,9 +274,8 @@ def pdwrite_cols(filename: str, *data, **kwargs) -> int:
if kwargs: # check for unwanted key words
raise TypeError("Unexpected **kwargs: {!r}".format(kwargs))

if header is not None:
if len(header) != len(data):
raise ValueError("Size of data and header does not match.")
if header is not None and len(header) != len(data):
raise ValueError("Size of data and header does not match.")

data_dict = {}
for i, data_i in enumerate(data):
Expand Down
4 changes: 1 addition & 3 deletions eniric/legacy.py
Expand Up @@ -80,9 +80,7 @@ def RVprec_calc_masked(

# Zeros created from the initial empty array, when skipping single element chunks)
slice_rvs = slice_rvs[np.nonzero(slice_rvs)] # Only use nonzero values.
rv_value = 1.0 / (np.sqrt(np.nansum((1.0 / slice_rvs) ** 2.0)))

return rv_value
return 1.0 / (np.sqrt(np.nansum((1.0 / slice_rvs) ** 2.0)))


def mask_clumping(
Expand Down
13 changes: 4 additions & 9 deletions eniric/precision.py
Expand Up @@ -169,11 +169,7 @@ def sqrt_sum_wis(
pixel_wis = pixel_weights(wavelength, flux, grad=grad)

# Apply masking function
if grad:
masked_wis = pixel_wis * mask
else:
masked_wis = pixel_wis * mask[:-1]

masked_wis = pixel_wis * mask if grad else pixel_wis * mask[:-1]
sqrt_sum = np.sqrt(np.nansum(masked_wis))
if not np.isfinite(sqrt_sum):
warnings.warn("Weight sum is not finite = {}".format(sqrt_sum))
Expand All @@ -189,7 +185,7 @@ def sqrt_sum_wis(
def mask_check(mask):
"""Checks for mask array."""
if isinstance(mask, u.Quantity):
if not (mask.unit == u.dimensionless_unscaled):
if mask.unit != u.dimensionless_unscaled:
raise TypeError(
"Mask should not be a non-dimensionless and unscaled Quantity!"
)
Expand Down Expand Up @@ -218,8 +214,7 @@ def slope(wavelength: Union[ndarray, Quantity], flux: Union[ndarray, Quantity]):
ffd: numpy.ndarray
FFD slope of spectrum with n-1 points.
"""
ffd = np.diff(flux) / np.diff(wavelength)
return ffd
return np.diff(flux) / np.diff(wavelength)


def pixel_weights(
Expand Down Expand Up @@ -252,8 +247,8 @@ def pixel_weights(
else:
flux_variance = flux

dydx_unit = 1
if grad:
dydx_unit = 1
# Hack for quantities with numpy gradient
if isinstance(flux, Quantity):
dydx_unit *= flux.unit
Expand Down
18 changes: 6 additions & 12 deletions eniric/snr_normalization.py
Expand Up @@ -86,12 +86,10 @@ def snr_constant_band(
if not (wav[0] < band_middle < wav[-1]):
raise ValueError("Band center not in wavelength range.")

norm_value = snr_constant_wav(
return snr_constant_wav(
wav, flux, wav_ref=band_middle, snr=snr, sampling=sampling, verbose=verbose
)

return norm_value


def snr_constant_wav(
wav: ndarray,
Expand Down Expand Up @@ -147,8 +145,7 @@ def snr_constant_wav(
snr_estimate, wav_ref
)
)
norm_value = (snr_estimate / snr) ** 2
return norm_value
return (snr_estimate / snr) ** 2


def sampling_index(
Expand Down Expand Up @@ -177,16 +174,13 @@ def sampling_index(
# index values must be integer
indexes = np.arange(index - half_sampling, index + half_sampling, dtype=int)
assert len(indexes) % 2 == 0 # confirm even
assert len(indexes) == sampling
else:
indexes = index + np.arange(-half_sampling, sampling - half_sampling, dtype=int)
assert len(indexes) % 2 != 0 # confirm odd
assert len(indexes) == sampling

if array_length is not None:
if np.any(indexes >= array_length):
# This may need fixed up in the future.
raise ValueError("Indexes has values greater than the length of array.")
assert len(indexes) == sampling
if array_length is not None and np.any(indexes >= array_length):
# This may need fixed up in the future.
raise ValueError("Indexes has values greater than the length of array.")

if np.any(indexes < 0):
raise ValueError("Indexes has values less than 0.")
Expand Down
31 changes: 8 additions & 23 deletions eniric/utilities.py
Expand Up @@ -175,10 +175,7 @@ def res2int(res: Any) -> int:
if isinstance(res, (np.int, np.float, int, float)):
value = res
elif isinstance(res, str):
if res.lower().endswith("k"):
value = float(res[:-1]) * 1000
else:
value = float(res)
value = float(res[:-1]) * 1000 if res.lower().endswith("k") else float(res)
else:
raise TypeError("Resolution name Type error of type {}".format(type(res)))

Expand Down Expand Up @@ -211,10 +208,7 @@ def res2str(res: Any) -> str:
if isinstance(res, (np.int, np.float)):
value = res / 1000
elif isinstance(res, str):
if res.lower().endswith("k"):
value = res[:-1]
else:
value = float(res) / 1000
value = res[:-1] if res.lower().endswith("k") else float(res) / 1000
else:
raise TypeError("Resolution name TypeError of type {}".format(type(res)))

Expand Down Expand Up @@ -247,7 +241,7 @@ def rv_cumulative_full(rv_vector: Union[List, ndarray]) -> ndarray:
"""Function that calculates the cumulative RV vector weighted_error. In both directions."""
assert len(rv_vector) == 5, "This hardcoded solution only meant for 5 bands."

cumulation = np.asarray(
return np.asarray(
[
weighted_error(rv_vector[0]), # First
weighted_error(rv_vector[:2]),
Expand All @@ -261,15 +255,12 @@ def rv_cumulative_full(rv_vector: Union[List, ndarray]) -> ndarray:
],
dtype=float,
)
return cumulation


def weighted_error(rv_vector: Union[List[float], ndarray]) -> float:
"""Function that calculates the average weighted error from a vector of errors."""
rv_vector = np.asarray(rv_vector)
rv_value = 1.0 / (np.sqrt(np.sum((1.0 / rv_vector) ** 2.0)))

return rv_value
return 1.0 / (np.sqrt(np.sum((1.0 / rv_vector) ** 2.0)))


def moving_average(x: ndarray, window_size: Union[int, float]) -> ndarray:
Expand Down Expand Up @@ -386,7 +377,7 @@ def load_btsettl_spectrum(
"""
from Starfish.grid_tools import CIFISTGridInterface as BTSETTL

if (2 < len(params)) and (len(params) <= 4):
if len(params) > 2 and len(params) <= 4:
assert params[2] == 0
assert params[-1] == 0 # Checks index 3 when present.
params = params[0:2] # Only allow 2 params
Expand Down Expand Up @@ -451,8 +442,7 @@ def doppler_shift_wav(wavelength: ndarray, vel: float):
if not np.isfinite(vel):
ValueError("The velocity is not finite.")

shifted_wavelength = wavelength * (1 + (vel / const.c.to("km/s").value))
return shifted_wavelength
return wavelength * (1 + (vel / const.c.to("km/s").value))


def doppler_shift_flux(
Expand Down Expand Up @@ -486,8 +476,7 @@ def doppler_shift_flux(

if new_wav is None:
new_wav = wavelength
new_flux = np.interp(new_wav, shifted_wavelength, flux)
return new_flux
return np.interp(new_wav, shifted_wavelength, flux)


def doppler_limits(rvmax, wmin, wmax):
Expand Down Expand Up @@ -527,8 +516,4 @@ def cpu_minus_one() -> int:
num_cpu: Optional[int] = os.cpu_count()
num_cpu_minus_1: int

if (num_cpu is None) or (num_cpu == 1):
num_cpu_minus_1 = 1
else:
num_cpu_minus_1 = num_cpu - 1
return num_cpu_minus_1
return 1 if (num_cpu is None) or (num_cpu == 1) else num_cpu - 1
10 changes: 3 additions & 7 deletions scripts/barycenter_broaden_atmmodel.py
Expand Up @@ -7,6 +7,7 @@
"""


import argparse
import sys
from os.path import join
Expand All @@ -17,8 +18,7 @@
from eniric import config
from eniric.atmosphere import Atmosphere

choices = [None, "ALL"]
choices.extend(config.bands["all"])
choices = [None, "ALL", *config.bands["all"]]


def _parser():
Expand Down Expand Up @@ -51,11 +51,7 @@ def main(bands: Optional[List[str]] = None, verbose: bool = False) -> None:
Wavelength bands to perform barycenter shifts on. Default is all bands.
"""
if (bands is None) or ("ALL" in bands):
bands_ = config.bands["all"]
else:
bands_ = bands

bands_ = config.bands["all"] if (bands is None) or ("ALL" in bands) else bands
for band in bands_:
unshifted_atmmodel = join(
config.pathdir,
Expand Down
13 changes: 6 additions & 7 deletions scripts/phoenix_precision.py
Expand Up @@ -5,6 +5,7 @@
Script to generate RV precision of synthetic spectra, see :ref:`Calculating-Precisions`.
"""

import argparse
import itertools
import os
Expand Down Expand Up @@ -33,8 +34,7 @@

num_cpu_minus_1 = cpu_minus_one()

ref_choices = ["SELF"]
ref_choices.extend(config.bands["all"])
ref_choices = ["SELF", *config.bands["all"]]


def _parser():
Expand Down Expand Up @@ -482,11 +482,10 @@ def check_model(model: str) -> str:
args = _parser()

# check bt-settl parameters
if args.model == "btsettl":
if (args.metal != [0]) or (args.alpha != [0]):
raise ValueError(
"You cannot vary metallicity and alpha for BT-Settl, remove these flags."
)
if args.model == "btsettl" and ((args.metal != [0]) or (args.alpha != [0])):
raise ValueError(
"You cannot vary metallicity and alpha for BT-Settl, remove these flags."
)
try:
normalize = not args.disable_normalization
except AttributeError:
Expand Down

0 comments on commit d0f9056

Please sign in to comment.