Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions specparam/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from specparam.data import SpectrumMetaData, ModelChecks
from specparam.utils.spectral import trim_spectrum
from specparam.utils.checks import check_input_options
from specparam.reports.strings import gen_data_str
from specparam.modutils.errors import DataError, InconsistentDataError
from specparam.modutils.docs import docs_get_section, replace_docstring_sections
from specparam.plts.settings import PLT_COLORS
Expand Down Expand Up @@ -77,6 +78,17 @@ def has_data(self):
return bool(np.any(self.power_spectrum))


@property
def n_freqs(self):
"""Indicator for the number of frequency values."""

n_freqs = None
if self.has_data:
n_freqs = len(self.freqs)

return n_freqs


def add_data(self, freqs, power_spectrum, freq_range=None):
"""Add data (frequencies, and power spectrum values) to the current object.

Expand Down Expand Up @@ -151,6 +163,18 @@ def plot(self, plt_log=False, **plt_kwargs):
log_powers=False, **data_kwargs)


def print(self, concise=False):
"""Print out a data summary.

Parameters
----------
concise : bool, optional, default: False
Whether to print the report in a concise mode, or not.
"""

print(gen_data_str(self, concise))


def set_checks(self, check_freqs=None, check_data=None):
"""Set check statuses, which control if an error is raised based on check on the inputs.

Expand Down Expand Up @@ -330,6 +354,17 @@ def has_data(self):
return bool(np.any(self.power_spectra))


@property
def n_spectra(self):
"""Indicator for the number of power spectra."""

n_spectra = None
if self.has_data:
n_spectra = len(self.power_spectra)

return n_spectra


def add_data(self, freqs, power_spectra, freq_range=None):
"""Add data (frequencies and power spectrum values) to the current object.

Expand Down Expand Up @@ -515,6 +550,17 @@ def n_events(self):
return len(self.spectrograms)


@property
def n_spectra(self):
"""Redefine n_spectra marker to reflect the total number of spectra."""

n_spectra = None
if self.has_data:
n_spectra = self.n_events * self.n_time_windows

return n_spectra


def add_data(self, freqs, spectrograms, freq_range=None):
"""Add data (frequencies and spectrograms) to the current object.

Expand Down
50 changes: 50 additions & 0 deletions specparam/reports/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,56 @@ def gen_version_str(concise=False):
return output


def gen_data_str(data, concise=False):
"""Generate a string representation summarizing current data.

Parameters
----------
data : Data
Data object to summarize data for.
Can also be any derived data object (e.g. Data2D).
concise : bool, optional, default: False
Whether to print the report in concise mode.

Returns
-------
output : str
Formatted string of data summary.
"""

if not data.has_data:

no_data_str = "No data currently loaded in the object."
str_lst = [DIVIDER,'', no_data_str, '', DIVIDER]

else:

# Get number of spectra, checking attributes for {Data3D, Data2DT, Data2D, Data}
if getattr(data, 'n_events', None):
n_spectra_str = '{} spectrograms with {} windows each'.format(data.n_events, data.n_time_windows)
elif getattr(data, 'n_time_windows', None):
n_spectra_str = '1 spectrogram with {} windows'.format(data.n_time_windows)
elif getattr(data, 'n_spectra', None):
n_spectra_str = '{} power spectra'.format(data.n_spectra)
else:
n_spectra_str = '1 power spectrum'

str_lst = [

DIVIDER,
'',
'The data object contains {}'.format(n_spectra_str),
'with a frequency range of {} Hz'.format(data.freq_range),
'and a frequency resolution of {} Hz.'.format(data.freq_res),
'',
DIVIDER,
]

output = _format(str_lst, concise)

return output


def gen_modes_str(modes, description=False, concise=False):
"""Generate a string representation of fit modes.

Expand Down
14 changes: 11 additions & 3 deletions specparam/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from specparam.modutils.dependencies import safe_import

from specparam.tests.tdata import (get_tdata, get_tdata2d, get_tfm, get_tfm2, get_tfg, get_tfg2,
get_tft, get_tfe, get_tbands, get_tresults, get_tmodes,
get_tdocstring)
from specparam.tests.tdata import (get_tdata, get_tdata2d, get_tdata2dt, get_tdata3d,
get_tfm, get_tfm2, get_tfg, get_tfg2, get_tft, get_tfe,
get_tbands, get_tresults, get_tmodes, get_tdocstring)
from specparam.tests.tsettings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH,
TEST_REPORTS_PATH, TEST_PLOTS_PATH)

Expand Down Expand Up @@ -67,6 +67,14 @@ def tdata():
def tdata2d():
yield get_tdata2d()

@pytest.fixture(scope='session')
def tdata2dt():
yield get_tdata2dt()

@pytest.fixture(scope='session')
def tdata3d():
yield get_tdata3d()

@pytest.fixture(scope='session')
def tfm():
yield get_tfm()
Expand Down
10 changes: 9 additions & 1 deletion specparam/tests/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ def test_data():

tdata = Data()
assert tdata
assert not tdata.has_data
assert not tdata.n_freqs

def test_data_add_data():

tdata = Data()
freqs, pows = np.array([1, 2, 3]), np.array([10, 10, 10])
tdata.add_data(freqs, pows)
assert tdata.has_data
assert tdata.n_freqs == len(freqs)

def test_data_meta_data():

Expand Down Expand Up @@ -66,13 +69,15 @@ def test_data2d():
assert tdata2d
assert isinstance(tdata2d, Data)
assert isinstance(tdata2d, Data2D)
assert not tdata2d.has_data

def test_data2d_add_data():

tdata2d = Data2D()
freqs, pows = np.array([1, 2, 3]), np.array([[10, 10, 10], [20, 20, 20]])
tdata2d.add_data(freqs, pows)
assert tdata2d.has_data
assert tdata2d.n_spectra == len(pows)

@plot_test
def test_data2d_plot(tdata2d, skip_if_no_mpl):
Expand All @@ -88,6 +93,7 @@ def test_data2dt():
assert isinstance(tdata2dt, Data)
assert isinstance(tdata2dt, Data2D)
assert isinstance(tdata2dt, Data2DT)
assert not tdata2dt.has_data

def test_data2dt_add_data():

Expand All @@ -96,7 +102,7 @@ def test_data2dt_add_data():
tdata2dt.add_data(freqs, pows)
assert tdata2dt.has_data
assert np.all(tdata2dt.spectrogram)
assert tdata2dt.n_time_windows
assert tdata2dt.n_spectra == tdata2dt.n_time_windows == len(pows.T)

## 3D Data Object

Expand All @@ -108,6 +114,7 @@ def test_data3d():
assert isinstance(tdata3d, Data2D)
assert isinstance(tdata3d, Data2DT)
assert isinstance(tdata3d, Data3D)
assert not tdata3d.has_data

def test_data3d_add_data():

Expand All @@ -117,3 +124,4 @@ def test_data3d_add_data():
assert tdata3d.has_data
assert np.all(tdata3d.spectrograms)
assert tdata3d.n_events
assert tdata3d.n_spectra == 2 * len(pows.T)
7 changes: 7 additions & 0 deletions specparam/tests/reports/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ def test_gen_version_str():

assert gen_version_str()

def test_gen_data_str(tdata, tdata2d, tdata2dt, tdata3d):

assert gen_data_str(tdata)
assert gen_data_str(tdata2d)
assert gen_data_str(tdata2dt)
assert gen_data_str(tdata3d)

def test_gen_modes_str(tfm):

assert gen_modes_str(tfm.modes)
Expand Down
24 changes: 23 additions & 1 deletion specparam/tests/tdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from specparam.bands import Bands
from specparam.modes.modes import Modes
from specparam.data.data import Data, Data2D
from specparam.data.data import Data, Data2D, Data2DT, Data3D
from specparam.data.stores import FitResults
from specparam.models import (SpectralModel, SpectralGroupModel,
SpectralTimeModel, SpectralTimeEventModel)
Expand Down Expand Up @@ -50,6 +50,26 @@ def get_tdata2d():

return tdata2d

def get_tdata2dt():

n_spectra = 3
tdata2dt = Data2DT()
tdata2dt.add_data(*sim_spectrogram(n_spectra, *default_group_params()))

return tdata2dt

def get_tdata3d():

n_events = 2
n_spectra = 3
tdata3d = Data3D()
freqs, spectrogram = sim_spectrogram(n_spectra, *default_group_params())
tdata3d.add_data(freqs, [spectrogram] * n_events)

return tdata3d

## TEST MODEL OBJECTS

def get_tfm():
"""Get a model object, with a fit power spectrum, for testing."""

Expand Down Expand Up @@ -117,6 +137,8 @@ def get_tfe():

return tfe

## TEST OTHER OBJECTS

def get_tbands():
"""Get a bands object, for testing."""

Expand Down