Skip to content

Commit

Permalink
CHG: fixing undecimated wavelet
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelcarcamov committed Jul 4, 2023
1 parent 9fa24a2 commit 93fe834
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions src/csromer/dictionaries/undecimated.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def __post_init__(self):
@staticmethod
def calculate_max_level(x):
n = len(x)
return pywt.swt_max_level(n)
max_level = min(pywt.swt_max_level(n), 8)
return max_level

def decompose(self, x):
if self.wavelet_level is not None:
Expand All @@ -64,7 +65,7 @@ def decompose(self, x):
"Your signal length is not multiple of 2**" + str(self.wavelet_level) +
". Padding array..."
)
padded_size = next_power_2(signal_size)
padded_size = 2**(np.ceil(np.log2(abs(signal_size))))
self.pad_width = padded_size - signal_size

if self.mode is None:
Expand Down Expand Up @@ -146,14 +147,14 @@ def reconstruct(self, input_coeffs):
arr=input_coeffs[self.n:len(input_coeffs)],
coeff_slices=self.coeff_slices,
coeff_shapes=self.coeff_shapes,
output_format="wavedec",
output_format="swt",
)
else:
coeffs = pywt.unravel_coeffs(
arr=input_coeffs,
coeff_slices=self.coeff_slices,
coeff_shapes=self.coeff_shapes,
output_format="wavedec",
output_format="swt",
)

signal_from_coeffs = pywt.iswt(coeffs, self.wavelet, self.norm)
Expand All @@ -175,19 +176,15 @@ def reconstruct_complex(self, input_coeffs):
if self.append_signal:
signal = input_coeffs[0:self.n].copy()
coeffs = input_coeffs[self.n:len(input_coeffs)]
coeffs_re = pywt.array_to_coeffs(
coeffs.real, self.coeff_slices[0], output_format="wavedec"
)
coeffs_im = pywt.array_to_coeffs(
coeffs.imag, self.coeff_slices[1], output_format="wavedec"
)
coeffs_re = pywt.array_to_coeffs(coeffs.real, self.coeff_slices[0], output_format="swt")
coeffs_im = pywt.array_to_coeffs(coeffs.imag, self.coeff_slices[1], output_format="swt")

else:
coeffs_re = pywt.array_to_coeffs(
input_coeffs.real, self.coeff_slices[0], output_format="wavedec"
input_coeffs.real, self.coeff_slices[0], output_format="swt"
)
coeffs_im = pywt.array_to_coeffs(
input_coeffs.imag, self.coeff_slices[1], output_format="wavedec"
input_coeffs.imag, self.coeff_slices[1], output_format="swt"
)

signal_re = pywt.iswt(coeffs_re, self.wavelet, self.norm)
Expand Down

0 comments on commit 93fe834

Please sign in to comment.