Skip to content

Commit

Permalink
Merge 317b5f4 into daeca5d
Browse files Browse the repository at this point in the history
  • Loading branch information
prisae committed Jul 23, 2022
2 parents daeca5d + 317b5f4 commit a2e5f11
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 15 deletions.
3 changes: 2 additions & 1 deletion empymod/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

import numpy as np
import numba as nb
from scipy import special # Only used for halfspace solution

__all__ = ['wavenumber', 'angle_factor', 'fullspace', 'greenfct',
'reflections', 'fields', 'halfspace']
Expand Down Expand Up @@ -943,6 +942,8 @@ def halfspace(off, angle, zsrc, zrec, etaH, etaV, freqtime, ab, signal,
the input and solution parameters.
"""
from scipy import special # Lazy for faster CLI load

xco = np.cos(angle)*off
yco = np.sin(angle)*off
res = np.real(1/etaH[0, 0])
Expand Down
37 changes: 25 additions & 12 deletions empymod/scripts/fdesign.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,6 @@ def rhs(r):
import numpy as np
from copy import deepcopy as dc
from scipy.constants import mu_0
from scipy.optimize import brute, fmin_powell

# Optional imports
try:
import matplotlib.pyplot as plt
except ImportError:
plt = False
plt_msg = "* WARNING :: `matplotlib` is not installed, no figures shown."

from empymod.filters import DigitalFilter
from empymod.model import dipole, dipole_k
Expand Down Expand Up @@ -372,15 +364,15 @@ def design(n, spacing, shift, fI, fC=False, r=None, r_def=(1, 1, 2), reim=None,
`full_output` is True.)
"""
from scipy.optimize import brute, fmin_powell # Lazy for faster CLI load

# === 1. LET'S START ============
t0 = printstartfinish(verb)

# Check plot with matplotlib (soft dependency)
plt = _get_matplotlib(plot*verb)
if plot > 0 and not plt:
plot = 0
if verb > 0:
print(plt_msg)

# Ensure fI, fC are lists
def check_f(f):
Expand Down Expand Up @@ -598,8 +590,8 @@ def plot_result(filt, full, prntres=True):
"""
# Check matplotlib (soft dependency)
plt = _get_matplotlib(1)
if not plt:
print(plt_msg)
return

if prntres:
Expand Down Expand Up @@ -713,6 +705,7 @@ def print_result(filt, full=None):

def _call_qc_transform_pairs(n, ispacing, ishift, fI, fC, r, r_def, reim):
r"""QC the input transform pairs."""
plt = _get_matplotlib()
print("* QC: Input transform-pairs:")
print(" fC: x-range defined through `n`, `spacing`, `shift`, and "
"`r`-parameters; b-range defined through `r`-parameter.")
Expand Down Expand Up @@ -760,6 +753,7 @@ def _call_qc_transform_pairs(n, ispacing, ishift, fI, fC, r, r_def, reim):

def _plot_transform_pairs(fCI, r, k, axes, tit):
r"""Plot the input transform pairs."""
plt = _get_matplotlib()

# Plot lhs
plt.sca(axes[0])
Expand Down Expand Up @@ -814,8 +808,8 @@ def _plot_inversion(f, rhs, r, k, imin, spacing, shift, cvar):
r"""QC the resulting filter."""

# Check matplotlib (soft dependency)
plt = _get_matplotlib(1)
if not plt:
print(plt_msg)
return

plt.figure("Inversion result "+f.name, figsize=(9.5, 4))
Expand Down Expand Up @@ -1449,3 +1443,22 @@ def _print_count(log):
log['cnt1'] = cp

return log


# Load matplotlib, if available.
def _get_matplotlib(verb=0):
"""Lazy load of matplotlib.
Matplotlib is a soft dependency of empymod, and only used in fdesign.
However, if it is installed we want to avoid loading straight away, as
this slows down the start of the CLI significantly.
"""
try:
import matplotlib.pyplot as plt # Lazy for faster CLI load
except ImportError:
if verb > 0:
print(
"* WARNING :: `matplotlib` is not installed, no figures shown."
)
plt = False
return plt
11 changes: 10 additions & 1 deletion empymod/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@


import numpy as np
from scipy import special, fftpack, integrate
from scipy.interpolate import InterpolatedUnivariateSpline as iuSpline

from empymod import kernel
Expand Down Expand Up @@ -164,6 +163,8 @@ def hankel_qwe(zsrc, zrec, lsrc, lrec, off, ang_fact, depth, ab, etaH, etaV,
If true, QWE/QUAD converged. If not, `htarg` might have to be adjusted.
"""
from scipy import special # Lazy for faster CLI load

# Input params have an additional dimension for frequency, reduce here
etaH = etaH[0, :]
etaV = etaV[0, :]
Expand Down Expand Up @@ -559,6 +560,8 @@ def fourier_qwe(fEM, time, freq, ftarg):
If true, QWE/QUAD converged. If not, `ftarg` might have to be adjusted.
"""
from scipy import special, integrate # Lazy for faster CLI load

# Get rtol, atol, nquad, maxint, diff_quad, a, b, and limit
rtol = ftarg['rtol']
atol = ftarg['atol']
Expand Down Expand Up @@ -680,6 +683,8 @@ def fourier_fftlog(fEM, time, freq, ftarg):
Only relevant for QWE/QUAD.
"""
from scipy import fftpack, special # Lazy for faster CLI load

# Get tcalc, dlnr, kr, rk, q; a and n
q = ftarg['q']
mu = ftarg['mu']
Expand Down Expand Up @@ -804,6 +809,8 @@ def fourier_fft(fEM, time, freq, ftarg):
Only relevant for QWE/QUAD.
"""
from scipy import fftpack # Lazy for faster CLI load

# Get ftarg values
dfreq = ftarg['dfreq']
nfreq = ftarg['nfreq']
Expand Down Expand Up @@ -1133,6 +1140,7 @@ def quad(sPJ0r, sPJ0i, sPJ1r, sPJ1i, sPJ0br, sPJ0bi, ab, off, ang_fact, iinp):
suited for QWE).
"""
from scipy import special, integrate # Lazy for faster CLI load

# Define the quadrature kernels
def quad_PJ0(klambd, sPJ0, koff):
Expand Down Expand Up @@ -1250,6 +1258,7 @@ def get_dlf_points(filt, inp, nr_per_dec):

def get_fftlog_input(rmin, rmax, n, q, mu):
r"""Return parameters required for FFTLog."""
from scipy import special # Lazy for faster CLI load

# Central point log10(r_c) of periodic interval
logrc = (rmin + rmax)/2
Expand Down
4 changes: 3 additions & 1 deletion empymod/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
# Mandatory imports
import copy
import numpy as np
from scipy import special
from timeit import default_timer
from datetime import timedelta, datetime

Expand Down Expand Up @@ -1007,6 +1006,7 @@ def check_time(time, signal, ft, ftarg, verb):
freq = np.squeeze(omega/2/np.pi)

elif ft == 'qwe': # QWE (using sine and imag-part)
from scipy import special # Lazy for faster CLI load

# If switch-off is required, use cosine, else sine
args.pop('sincos', None)
Expand Down Expand Up @@ -1663,6 +1663,8 @@ def get_azm_dip(inp, iz, ninpz, intpts, isdipole, strength, name, verb):

# Gauss quadrature if intpts > 2; else set to center of tinp
if intpts > 2: # Calculate the dipole positions
from scipy import special # Lazy for faster CLI load

# Get integration positions and weights
g_x, g_w = special.roots_legendre(intpts)
g_x = np.outer(g_x, dl/2.0) # Adjust to tinp length
Expand Down
15 changes: 15 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import subprocess
import numpy as np
from numpy.testing import assert_allclose

Expand Down Expand Up @@ -1196,3 +1197,17 @@ def test_report(capsys):
_ = utils.Report()
out, _ = capsys.readouterr() # Empty capsys
assert 'WARNING :: `empymod.Report` requires `scooby`' in out


def test_import_time():
# Relevant for the CLI: How long does it take to import?
cmd = ["time", "-f", "%U", "python", "-c", "import empymod"]
# Run it twice, just in case.
subprocess.run(cmd)
subprocess.run(cmd)
# Capture it
out = subprocess.run(cmd, capture_output=True)

# Currently we check t < 1.2s.
# => That should come down to t < 0.5s in the future!
assert float(out.stderr.decode("utf-8")[:-1]) < 1.2

0 comments on commit a2e5f11

Please sign in to comment.