Skip to content

Commit

Permalink
Fixing plot function
Browse files Browse the repository at this point in the history
  • Loading branch information
MerlinDumeur committed May 28, 2024
1 parent ecd4f5c commit 5a6f9aa
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 30 deletions.
13 changes: 12 additions & 1 deletion pymultifracs/multiresquantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,7 @@ class WaveletLeader(WaveletDec):
"""
p_exp: float
interval_size: int = 1
eta_p: np.ndarray = field(init=False, repr=False)
eta_p: np.ndarray = field(init=False, repr=False, default=None)
ZPJCorr: np.ndarray = field(init=False, default=None)

def bootstrap(self, R, min_scale=1, idx_reject=None):
Expand Down Expand Up @@ -712,6 +712,17 @@ def check_regularity(self, scaling_ranges, weighted=None,

self.correct_pleaders()

def plot(self, j1, j2, ax=None, vmin=None, vmax=None, cbar=True,
figsize=(4.5, 1.5), gamma=.3, nan_idx=None, signal_idx=0,
cbar_kw=None, cmap='magma'):

if self.eta_p is None and not np.isinf(self.p_exp):
self.check_regularity([(j1, j2)], None, None)

super().plot(j1, j2, ax, vmin, vmax, cbar, figsize, gamma,
nan_idx, signal_idx, cbar_kw, cmap)


@dataclass(kw_only=True)
class Wtwse(WaveletDec):
r"""
Expand Down
101 changes: 72 additions & 29 deletions pymultifracs/robust/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from dataclasses import dataclass, field
from typing import Any, Callable

Expand Down Expand Up @@ -26,57 +27,99 @@ def get_grid(param_grid):
return out


def get_fname(param):

fname = Path('.')

for param in param.values():

if isinstance(param, float):
fname /= f'{param:.2f}'
else:
fname /= str(param)

fname /= 'signal'

return fname.with_suffix('.npy')


@dataclass
class Benchmark:
signal_param_grid: dict[str, np.ndarray]
noise_param_grid: dict[str, np.ndarray]
signal_gen_func: Callable
noise_gen_func: Callable
# noise_param_grid: dict[str, np.ndarray]
signal_func: Callable
# noise_gen_func: Callable
estimation_grid: dict[str, Callable]
WT_params: dict[str, Any]
# parameters_df: pd.DataFrame = field(init=False, default=None, repr=False)
results: pd.DataFrame = field(init=False, repr=False)

def run(self, n_rep):

results = {}
def get_df_fnames(self):

return Path('results.pkl')

def generate_grids(self):

signal_grid = get_grid(self.signal_param_grid)
noise_grid = get_grid(self.noise_param_grid)
# noise_grid = get_grid(self.noise_param_grid)

for signal_params in signal_grid.itertuples(index=False):

signal_names = [*signal_params._fields]
signal_params = signal_params._asdict()
return signal_grid#, noise_grid

X = np.c_[
*[self.signal_gen_func(**signal_params)
for i in range(n_rep)]]

# for repetition in range(n_rep):
def load_df(self):

for noise_params in tqdm(noise_grid.itertuples(index=False)):
results_fname = self.get_df_fnames()
# self.parameters_df = self.generate_grids()

noise_names = [*noise_params._fields]
noise_params = noise_params._asdict()
if results_fname.exists():
self.results = pd.read_pickle(results_fname)

def compute_benchmark(self, n_jobs=1, save=False):

X_noisy = self.noise_gen_func(X, **noise_params)
WT = wavelet_analysis(X_noisy, **self.WT_params)
results = {}

signal_grid = get_grid(self.signal_param_grid)
signal_names = signal_grid.columns
print(signal_names)
# signals, signal_names = self.load_generate_signals()

def estimate_mf(signal, signal_params):
res = []
WT = wavelet_analysis(signal, **self.WT_params)
for method, est_fun in self.estimation_grid.items():
res.append((method, est_fun(WT)))

return res, signal_params

results = Parallel(n_jobs=n_jobs)(
delayed(estimate_mf)(*s)
for s in tqdm(self.signal_func(signal_grid), total=signal_grid.shape[0]))

results[(method, *signal_params.values(), *noise_params.values())] = [est_fun(WT)]
results = {
(method, *signal_params.values()): [estimate]
for res_list, signal_params in results
for method, estimate in res_list
}

self.results = pd.DataFrame.from_dict(results).transpose()

for i, name in enumerate(signal_names):
if name in noise_names:
self.results.index.names = ['method', *signal_names]
self.results.columns = ['cumulants']

signal_names[i] = name + '_signal'
noise_names[noise_names.index(name)] = name + '_noise'
results_fname = self.get_df_fnames()
results_fname.parent.mkdir(parents=True, exist_ok=True)

self.results.index.names = [
'method', *signal_names, *noise_names]
self.results.columns = ['cumulants']
# df_c2 = self.results.cumulants.apply(lambda x: x.c2[0, :, 0]).explode()
# df_c1 = self.results.cumulants.apply(lambda x: x.c1[0, :, 0]).explode()

# df_c2.name = 'c2'
# df_c1.name = 'c1'

# self.results = pd.concat([df_c1, df_c2], axis=1)
# self.results.loc[:, 'repetition'] = self.results.groupby(
# self.results.index.names).transform('cumcount')
# self.results.set_index('repetition', append=True, inplace=True)

self.results.to_pickle(results_fname)

def plot(self):
pass
Expand Down

0 comments on commit 5a6f9aa

Please sign in to comment.