Skip to content

Commit

Permalink
implementing stft
Browse files Browse the repository at this point in the history
  • Loading branch information
daniellga committed Jul 25, 2024
1 parent 4a61d32 commit 54610de
Show file tree
Hide file tree
Showing 10 changed files with 1,052 additions and 126 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ members = [
"harmonium-fft",
"harmonium-resample",
"harmonium-window",
#"harmonium-stft",
"harmonium-stft",
]

[workspace.dependencies]
Expand All @@ -24,7 +24,7 @@ harmonium-io = { path = "harmonium-io", default-features = false }
harmonium-fft = { path = "harmonium-fft", default-features = false }
harmonium-resample = { path = "harmonium-resample", default-features = false }
harmonium-window = { path = "harmonium-window", default-features = false }
#harmonium-stft = { path = "harmonium-stft", default-features = false }
harmonium-stft = { path = "harmonium-stft", default-features = false }

[profile.release]
opt-level = 3
Expand Down
5 changes: 5 additions & 0 deletions harmonium-core/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ where
pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
self.0.as_slice_mut()
}

/// Returns `true` if the `HArray` shares the inner arc with another one.
pub fn is_shared(&self) -> bool {
todo!()
}
}

#[cfg(test)]
Expand Down
74 changes: 31 additions & 43 deletions harmonium-fft/src/fft.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::sync::Arc;

use harmonium_core::{array::HArray, errors::HError, errors::HResult};
use ndarray::{ArcArray1, ArcArray2, Axis, Dimension, Ix1, Ix2, IxDyn, Zip};
use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
use rustfft::{
num_complex::Complex,
num_traits::{Float, FloatConst},
num_traits::{ConstZero, Float, FloatConst},
FftNum, FftPlanner,
};
use std::sync::Arc;

#[derive(Clone)]
pub struct Fft<T> {
Expand All @@ -17,23 +16,22 @@ pub struct Fft<T> {

#[derive(Clone)]
pub struct RealFftForward<T> {
fft: Arc<dyn RealToComplex<T>>,
scratch_buffer: Arc<[Complex<T>]>,
pub fft: Arc<dyn RealToComplex<T>>,
pub scratch_buffer: Arc<[Complex<T>]>,
}

#[derive(Clone)]
pub struct RealFftInverse<T> {
fft: Arc<dyn ComplexToReal<T>>,
scratch_buffer: Arc<[Complex<T>]>,
pub fft: Arc<dyn ComplexToReal<T>>,
pub scratch_buffer: Arc<[Complex<T>]>,
}

impl<T: FftNum + Float + FloatConst> Fft<T> {
impl<T: FftNum + Float + FloatConst + ConstZero> Fft<T> {
pub fn new_fft_forward(length: usize) -> Self {
let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(length);
let scratch_len = fft.get_inplace_scratch_len();
let zero = T::zero();
let scratch_buffer = vec![Complex::new(zero, zero); scratch_len];
let scratch_buffer = vec![Complex::<T>::ZERO; scratch_len];
let scratch_buffer: Arc<[Complex<T>]> = Arc::from(scratch_buffer);

Self {
Expand All @@ -46,8 +44,7 @@ impl<T: FftNum + Float + FloatConst> Fft<T> {
let mut planner = FftPlanner::new();
let fft = planner.plan_fft_inverse(length);
let scratch_len = fft.get_inplace_scratch_len();
let zero = T::zero();
let scratch_buffer = vec![Complex::new(zero, zero); scratch_len];
let scratch_buffer = vec![Complex::<T>::ZERO; scratch_len];
let scratch_buffer: Arc<[Complex<T>]> = Arc::from(scratch_buffer);

Self {
Expand All @@ -57,13 +54,12 @@ impl<T: FftNum + Float + FloatConst> Fft<T> {
}
}

impl<T: FftNum + Float + FloatConst> RealFftForward<T> {
impl<T: FftNum + Float + FloatConst + ConstZero> RealFftForward<T> {
pub fn new_real_fft_forward(length: usize) -> Self {
let mut planner = RealFftPlanner::new();
let fft = planner.plan_fft_forward(length);
let zero = T::zero();
let scratch_len = fft.get_scratch_len();
let scratch_buffer = vec![Complex::new(zero, zero); scratch_len];
let scratch_buffer = vec![Complex::<T>::ZERO; scratch_len];
let scratch_buffer: Arc<[Complex<T>]> = Arc::from(scratch_buffer);

Self {
Expand All @@ -73,13 +69,12 @@ impl<T: FftNum + Float + FloatConst> RealFftForward<T> {
}
}

impl<T: FftNum + Float + FloatConst> RealFftInverse<T> {
impl<T: FftNum + Float + FloatConst + ConstZero> RealFftInverse<T> {
pub fn new_real_fft_inverse(length: usize) -> Self {
let mut planner = RealFftPlanner::new();
let fft = planner.plan_fft_inverse(length);
let zero = T::zero();
let scratch_len = fft.get_scratch_len();
let scratch_buffer = vec![Complex::new(zero, zero); scratch_len];
let scratch_buffer = vec![Complex::<T>::ZERO; scratch_len];
let scratch_buffer: Arc<[Complex<T>]> = Arc::from(scratch_buffer);

Self {
Expand Down Expand Up @@ -127,12 +122,11 @@ where

impl<T> ProcessRealFftForward<T, Ix1> for RealFftForward<T>
where
T: FftNum + Float + FloatConst,
T: FftNum + Float + FloatConst + ConstZero,
{
fn process(&mut self, harray: &mut HArray<T, Ix1>) -> HResult<HArray<Complex<T>, Ix1>> {
let zero = T::zero();
let length = harray.len();
let mut ndarray = ArcArray1::from_elem(length / 2 + 1, Complex::new(zero, zero));
let mut ndarray = ArcArray1::from_elem(length / 2 + 1, Complex::<T>::ZERO);
let scratch_buffer = make_mut_slice(&mut self.scratch_buffer);
self.fft
.process_with_scratch(
Expand All @@ -147,13 +141,12 @@ where

impl<T> ProcessRealFftInverse<T, Ix1> for RealFftInverse<T>
where
T: FftNum + Float + FloatConst,
T: FftNum + Float + FloatConst + ConstZero,
{
fn process(&mut self, harray: &mut HArray<Complex<T>, Ix1>) -> HResult<HArray<T, Ix1>> {
let zero = T::zero();
let length = self.fft.len();
let scratch_buffer = make_mut_slice(&mut self.scratch_buffer);
let mut ndarray = ArcArray1::from_elem(length, zero);
let mut ndarray = ArcArray1::from_elem(length, T::ZERO);
self.fft
.process_with_scratch(
harray.as_slice_mut().unwrap(),
Expand All @@ -172,24 +165,23 @@ where
fn process(&mut self, harray: &mut HArray<Complex<T>, Ix2>) -> HResult<()> {
let scratch_buffer = make_mut_slice(&mut self.scratch_buffer);

Zip::from(harray.0.lanes_mut(Axis(1))).for_each(|mut row| {
for mut row in harray.0.lanes_mut(Axis(1)) {
self.fft
.process_with_scratch(row.as_slice_mut().unwrap(), scratch_buffer);
});
}
Ok(())
}
}

impl<T> ProcessRealFftForward<T, Ix2> for RealFftForward<T>
where
T: FftNum + Float + FloatConst,
T: FftNum + Float + FloatConst + ConstZero,
{
fn process(&mut self, harray: &mut HArray<T, Ix2>) -> HResult<HArray<Complex<T>, Ix2>> {
let zero = T::zero();
let nrows = harray.0.nrows();
let ncols = harray.0.ncols();
let scratch_buffer = make_mut_slice(&mut self.scratch_buffer);
let mut ndarray = ArcArray2::from_elem((nrows, ncols / 2 + 1), Complex::new(zero, zero));
let mut ndarray = ArcArray2::from_elem((nrows, ncols / 2 + 1), Complex::<T>::ZERO);

Zip::from(ndarray.lanes_mut(Axis(1)))
.and(harray.0.lanes_mut(Axis(1)))
Expand All @@ -209,14 +201,13 @@ where

impl<T> ProcessRealFftInverse<T, Ix2> for RealFftInverse<T>
where
T: FftNum + Float + FloatConst,
T: FftNum + Float + FloatConst + ConstZero,
{
fn process(&mut self, harray: &mut HArray<Complex<T>, Ix2>) -> HResult<HArray<T, Ix2>> {
let zero = T::zero();
let length = self.fft.len();
let nrows = harray.0.nrows();
let scratch_buffer = make_mut_slice(&mut self.scratch_buffer);
let mut ndarray = ArcArray2::from_elem((nrows, length), zero);
let mut ndarray = ArcArray2::from_elem((nrows, length), T::ZERO);

Zip::from(ndarray.lanes_mut(Axis(1)))
.and(harray.0.lanes_mut(Axis(1)))
Expand Down Expand Up @@ -262,15 +253,14 @@ where

impl<T> ProcessRealFftForward<T, IxDyn> for RealFftForward<T>
where
T: FftNum + Float + FloatConst,
T: FftNum + Float + FloatConst + ConstZero,
{
fn process(&mut self, harray: &mut HArray<T, IxDyn>) -> HResult<HArray<Complex<T>, IxDyn>> {
let zero = T::zero();
let scratch_buffer = make_mut_slice(&mut self.scratch_buffer);
match harray.ndim() {
1 => {
let length = harray.len();
let mut ndarray = ArcArray1::from_elem(length / 2 + 1, Complex::new(zero, zero));
let mut ndarray = ArcArray1::from_elem(length / 2 + 1, Complex::<T>::ZERO);
self.fft
.process_with_scratch(
harray.as_slice_mut().unwrap(),
Expand All @@ -284,8 +274,7 @@ where
let nrows = harray.0.len_of(Axis(0));
let ncols = harray.0.len_of(Axis(1));
let mut ndarray =
ArcArray2::from_elem((nrows, ncols / 2 + 1), Complex::new(zero, zero))
.into_dyn();
ArcArray2::from_elem((nrows, ncols / 2 + 1), Complex::<T>::ZERO).into_dyn();

Zip::from(ndarray.lanes_mut(Axis(1)))
.and(harray.0.lanes_mut(Axis(1)))
Expand All @@ -310,15 +299,14 @@ where

impl<T> ProcessRealFftInverse<T, IxDyn> for RealFftInverse<T>
where
T: FftNum + Float + FloatConst,
T: FftNum + Float + FloatConst + ConstZero,
{
fn process(&mut self, harray: &mut HArray<Complex<T>, IxDyn>) -> HResult<HArray<T, IxDyn>> {
let zero = T::zero();
let length = self.fft.len();
let scratch_buffer = make_mut_slice(&mut self.scratch_buffer);
match harray.ndim() {
1 => {
let mut ndarray = ArcArray1::from_elem(length, zero);
let mut ndarray = ArcArray1::from_elem(length, T::ZERO);
self.fft
.process_with_scratch(
harray.as_slice_mut().unwrap(),
Expand All @@ -330,7 +318,7 @@ where
}
2 => {
let nrows = harray.0.len_of(Axis(0));
let mut ndarray = ArcArray2::from_elem((nrows, length), zero).into_dyn();
let mut ndarray = ArcArray2::from_elem((nrows, length), T::ZERO).into_dyn();

Zip::from(ndarray.lanes_mut(Axis(1)))
.and(harray.0.lanes_mut(Axis(1)))
Expand All @@ -354,11 +342,11 @@ where
}

// replace this function by make_mut when in stable (it is currently, but doesn't work for slices.)
fn make_mut_slice<T: Clone>(arc: &mut Arc<[T]>) -> &mut [T] {
pub fn make_mut_slice<T: Clone>(arc: &mut Arc<[T]>) -> &mut [T] {
if Arc::get_mut(arc).is_none() {
*arc = Arc::from(&arc[..]);
}
// Replace by get_mut_unchecked when available in stable. This can't failed since get_mut was
// Replace by get_mut_unchecked when available in stable. This can't fail since get_mut was
// checked above.
unsafe { Arc::get_mut(arc).unwrap_unchecked() }
}
Expand Down
3 changes: 3 additions & 0 deletions harmonium-stft/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@ edition = "2021"

[dependencies]
harmonium-core = { workspace = true }
harmonium-fft = { workspace = true }
rustfft = { workspace = true }
realfft = { workspace = true }
ndarray = { workspace = true }

2 changes: 1 addition & 1 deletion harmonium-stft/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
mod stft;
pub mod stft;
Loading

0 comments on commit 54610de

Please sign in to comment.