Skip to content

Commit

Permalink
fix(206): Fix ValueError in cwt_coefficients (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
khrapovs committed Apr 29, 2024
1 parent c28c47c commit 260388d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
4 changes: 2 additions & 2 deletions functime/feature_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,10 +594,10 @@ def cwt_coefficients(
for i, width in enumerate(widths):
points = np.min([10 * width, x.len()])
wavelet_x = np.conj(ricker(points, width)[::-1])
convolution[i] = np.convolve(x.to_numpy(zero_copy_only=True), wavelet_x)
convolution[i] = np.convolve(x.to_numpy(zero_copy_only=True), wavelet_x, mode="same")
coeffs = []
for coeff_idx in range(min(n_coefficients, convolution.shape[1])):
coeffs.extend(convolution[widths.index(), coeff_idx] for _ in widths)
coeffs.extend(convolution[widths.index(w), coeff_idx] for w in widths)
return coeffs
else:
logger.info(
Expand Down
12 changes: 12 additions & 0 deletions tests/test_feature_extractors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import numpy as np
import polars as pl
import pytest
from functime.feature_extractors import cwt_coefficients


@pytest.mark.parametrize("length", np.random.random_integers(low=1, high=100, size=5))
@pytest.mark.parametrize("widths", [(2,), (2, 5, 10, 20), (2, 5, 10, 20, 30)])
@pytest.mark.parametrize("n_coefficients", np.random.random_integers(low=1, high=100, size=5))
def test_cwt(length: int, widths: tuple, n_coefficients: int) -> None:
out = cwt_coefficients(pl.Series([1 for _ in range(length)]), widths=widths, n_coefficients=n_coefficients)
assert len(out) == min(n_coefficients, length) * len(widths)

0 comments on commit 260388d

Please sign in to comment.