Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/flucoma/algorithms/public/GriffinLim.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ under the European Union’s Horizon 2020 research and innovation programme

#include "STFT.hpp"
#include "../util/AlgorithmUtils.hpp"
#include "../util/EigenRandom.hpp"
#include "../util/FluidEigenMappings.hpp"
#include "../../data/FluidIndex.hpp"
#include "../../data/TensorTypes.hpp"
Expand All @@ -26,7 +27,7 @@ class GriffinLim

public:
void process(ComplexMatrixView in, index nSamples, index nIter, index winSize,
index fftSize, index hopSize)
index fftSize, index hopSize, index seed = -1)
{
using namespace Eigen;
using namespace _impl;
Expand All @@ -36,9 +37,8 @@ class GriffinLim
auto istft = ISTFT(winSize, fftSize, hopSize);
ArrayXd tmp = ArrayXd::Zero(nSamples);
ArrayXXcd magnitude = asEigen<Array>(in).abs();
ArrayXXcd phase =
ArrayXXcd::Random(magnitude.rows(), magnitude.cols()) * 2 * 1i * pi;
phase = phase.exp();
ArrayXXcd phase = EigenRandomPhase<ArrayXXcd>(
magnitude.rows(), magnitude.cols(), RandomSeed{seed});
ArrayXXcd estimate = ArrayXXcd::Zero(magnitude.rows(), magnitude.cols());
ArrayXXcd prev = ArrayXXcd::Zero(magnitude.rows(), magnitude.cols());
for (index i = 0; i < nIter; i++)
Expand Down
7 changes: 4 additions & 3 deletions include/flucoma/algorithms/public/NMFCross.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ under the European Union’s Horizon 2020 research and innovation programme
#pragma once

#include "STFT.hpp"
#include "../util/EigenRandom.hpp"
#include "../util/FluidEigenMappings.hpp"
#include "../../data/FluidIndex.hpp"
#include "../../data/TensorTypes.hpp"
Expand Down Expand Up @@ -57,16 +58,16 @@ class NMFCross
}

void process(const RealMatrixView X, RealMatrixView H1, RealMatrixView W0,
index r, index p, index c) const
index r, index p, index c, index randomSeed = -1) const
{
index nFrames = X.extent(0);
index nBins = X.extent(1);
index rank = W0.extent(0);
nBins = W0.extent(1);
MatrixXd W = asEigen<Matrix>(W0).transpose();
MatrixXd H;
H = MatrixXd::Random(rank, nFrames) * 0.5 +
MatrixXd::Constant(rank, nFrames, 0.5);
H = EigenRandom<MatrixXd>(rank, nFrames, RandomSeed{randomSeed},
Range{0.0, 1.0});
MatrixXd V = asEigen<Matrix>(X).transpose();
multiplicativeUpdates(V, W, H, r, p, c);
MatrixXd HT = H.transpose();
Expand Down
7 changes: 5 additions & 2 deletions include/flucoma/clients/nrt/NMFCrossClient.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum NMFCrossParamIndex {
kPolyphony,
kContinuity,
kIterations,
kRandomSeed,
kFFT
};

Expand All @@ -44,6 +45,7 @@ constexpr auto NMFCrossParams = defineParameters(
FrameSizeUpperLimit<kFFT>()),
LongParam("continuity", "Continuity", 7, Min(1), Odd()),
LongParam("iterations", "Number of Iterations", 50, Min(1)),
LongParam("seed", "Random Seed", -1),
FFTParam("fftSettings", "FFT Settings", 1024, -1, -1));

class NMFCrossClient : public FluidBaseClient,
Expand Down Expand Up @@ -154,7 +156,8 @@ class NMFCrossClient : public FluidBaseClient,
});

nmf.process(tgtMag, outputEnvelopes, W, get<kTimeSparsity>(),
std::min(srcWindows, get<kPolyphony>()), get<kContinuity>());
std::min(srcWindows, get<kPolyphony>()), get<kContinuity>(),
get<kRandomSeed>());

r = checkTask(c, progressCount, progressTotal);
if (!r.ok()) return r;
Expand All @@ -166,7 +169,7 @@ class NMFCrossClient : public FluidBaseClient,

GriffinLim gl;
gl.process(result, tgtFrames, 50, fftParams.winSize(), fftParams.fftSize(),
fftParams.hopSize());
fftParams.hopSize(), get<kRandomSeed>());

r = checkTask(c, ++progressCount, progressTotal);
if (!r.ok()) return r;
Expand Down
4 changes: 4 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ add_test_executable(TestTransientSlice algorithms/public/TestTransientSlice.cpp)

add_test_executable(TestMLP algorithms/public/TestMLP.cpp)
add_test_executable(TestKMeans algorithms/public/TestKMeans.cpp)
add_test_executable(TestNMFCross algorithms/public/TestNMFCross.cpp)
add_test_executable(TestGriffinLim algorithms/public/TestGriffinLim.cpp)
add_test_executable(TestNMF algorithms/public/TestNMF.cpp)
add_test_executable(TestUMAP algorithms/public/TestUMAP.cpp)

Expand Down Expand Up @@ -159,6 +161,8 @@ catch_discover_tests(TestBufferedProcess WORKING_DIRECTORY "${CMAKE_BINARY_DIR}"
catch_discover_tests(TestMLP WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
catch_discover_tests(TestKMeans WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
catch_discover_tests(TestEigenRandom WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
catch_discover_tests(TestNMFCross WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
catch_discover_tests(TestGriffinLim WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
catch_discover_tests(TestNNDSVD WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
catch_discover_tests(TestNMF WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
catch_discover_tests(TestRTPGHI WORKING_DIRECTORY "${CMAKE_BINARY_DIR}")
Expand Down
45 changes: 45 additions & 0 deletions tests/algorithms/public/TestGriffinLim.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#define CATCH_CONFIG_MAIN

#include <catch2/catch_all.hpp>
#include <flucoma/algorithms/public/GriffinLim.hpp>
#include <flucoma/data/FluidIndex.hpp>
#include <flucoma/data/FluidTensor.hpp>
#include <complex>
#include <vector>

namespace fluid {
TEST_CASE("GriffinLim is repeatable with user-supplied random seed")
{

using algorithm::GriffinLim;
using Tensor = FluidTensor<std::complex<double>, 2>;

index win = 64;
index fft = 64;
index hop = 64;
index bins = fft / 2 + 1;

// only actually interested in 1 frame of results, but need padding in algo
Tensor raw_input(2, bins);
raw_input(0, index(bins / 2)) = std::polar(1.0, 0.0);

std::vector<Tensor> inouts(3, raw_input);

GriffinLim algo;

algo.process(inouts[0], win, 1, win, fft, hop, 42);
algo.process(inouts[1], win, 1, win, fft, hop, 42);
algo.process(inouts[2], win, 1, win, fft, hop, 987234);

using Catch::Matchers::RangeEquals;

SECTION("Calls with the same seed have the same output")
{
REQUIRE_THAT(inouts[1], RangeEquals(inouts[0]));
}
SECTION("Calls with different seeds have different outputs")
{
REQUIRE_THAT(inouts[1], !RangeEquals(inouts[2]));
}
}
} // namespace fluid
37 changes: 37 additions & 0 deletions tests/algorithms/public/TestNMFCross.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#define CATCH_CONFIG_MAIN
#include <catch2/catch_all.hpp>
#include <flucoma/algorithms/public/NMFCross.hpp>
#include <flucoma/data/FluidTensor.hpp>
#include <algorithm>
#include <iostream>
#include <vector>

TEST_CASE("NMFCross is repeatable with user-supplied random seed")
{

using fluid::algorithm::NMFCross;
using Tensor = fluid::FluidTensor<double, 2>;
NMFCross algo(3);

Tensor targetMag{{0.5, 0.4}, {0.1, 1.1}, {0.7, 0.8},
{0.3, 0.0}, {1.0, 0.9}, {0.2, 0.6}};
Tensor sourceMag{{0.0, 0.4}, {0.6, 0.7}, {0.8, 0.1},
{1.0, 0.5}, {1.1, 0.2}, {0.9, 0.3}};

std::vector Hs(3, Tensor(6, 6));

algo.process(targetMag, Hs[0], sourceMag, 3, 2, 7, 42);
algo.process(targetMag, Hs[1], sourceMag, 3, 2, 7, 42);
algo.process(targetMag, Hs[2], sourceMag, 3, 2, 7, 5063);

using Catch::Matchers::RangeEquals;

SECTION("Calls with the same seed have the same output")
{
REQUIRE_THAT(Hs[1], RangeEquals(Hs[0]));
}
SECTION("Calls with different seeds have different outputs")
{
REQUIRE_THAT(Hs[1], !RangeEquals(Hs[2]));
}
}