Skip to content

Commit

Permalink
fix Bug in Biquad assignment
Browse files Browse the repository at this point in the history
add UnitTest to compare IIRJ results with given Scipy results

relates to #27

add UnitTest to compare IIRJ results with given Scipy results
fix Bug in Biquad assignment
  • Loading branch information
Andy Leuckert committed Oct 5, 2023
1 parent d349a3c commit 54e8e38
Show file tree
Hide file tree
Showing 65 changed files with 122,420 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/main/java/uk/me/berndporr/iirj/Biquad.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ public void setOnePole(Complex pole, Complex zero) {
double a0 = 1;
double a1 = -pole.getReal();
double a2 = 0;
double b0 = -zero.getReal();
double b1 = 1;
double b0 = 1;
double b1 = -zero.getReal();
double b2 = 0;
setCoefficients(a0, a1, a2, b0, b1, b2);
}
Expand Down
43 changes: 43 additions & 0 deletions src/test/java/uk/me/berndporr/iirj/DoubleSignal.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package uk.me.berndporr.iirj;

class DoubleSignal
{
private final double[] xValues;
private final double[] values;
private final String name;

DoubleSignal(int length, String name) {
xValues = new double[length];
values = new double[length];
this.name = name;
}

void setValue(int index, double time, double value) {
xValues[index] = time;
values[index] = value;
}

double getValue(int index) {
return values[index];
}

double getXValue(int index) {
return xValues[index];
}

double[] getxValues() {
return xValues;
}

double[] getValues() {
return values;
}

int getSize() {
return values.length;
}

String getName() {
return this.name;
}
}
199 changes: 199 additions & 0 deletions src/test/java/uk/me/berndporr/iirj/IIRJScipyCompare.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package uk.me.berndporr.iirj;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;

@RunWith(Parameterized.class)
public class IIRJScipyCompare
{
private final FilterType filterType;
private final FilterBType filterBType;
private final int order;

public IIRJScipyCompare(FilterType filterType, FilterBType filterBType, int order) {
this.filterType = filterType;
this.filterBType = filterBType;
this.order = order;
}

@Parameters(name = "Test:{index}, Filtertype: {0}, FilterBType: {1}, Order:{2}")
public static Collection<Object[]> data() {
Collection<Object[]> entries = new ArrayList<>();
int[] orders = {1, 2, 4, 5, 10};
for (FilterType filterType : FilterType.values()) {
for (FilterBType filterBType : FilterBType.values()) {
for (int order : orders) {
entries.add(new Object[]{filterType, filterBType, order});
}
}
}
return entries;
}

@Override
public String toString() {
return "IIRJScipyCompare{" +
"filterType=" + filterType +
", filterBType=" + filterBType +
", order=" + order +
'}';
}

private enum FilterType
{
Butterworth("Butterworth"),
Chebychev1("Chebychev1"),
Chebychev2("Chebychev2");

private final String scipyString;

FilterType(String scipyString) {
this.scipyString = scipyString;
}
}

private enum FilterBType
{
Lowpass("lowpass"),
Highpass("highpass"),
Bandpass("bandpass"),
Bandstop("bandstop");

private final String scipyString;

FilterBType(String scipyString) {
this.scipyString = scipyString;
}
}

@Test
public void compareValues() {

ClassLoader classloader = Thread.currentThread().getContextClassLoader();

InputStream resource = classloader.getResourceAsStream("scipy/signal.txt");
DoubleSignal inputSignal = readSignalFromCSV(resource, "Input");

double[] time = inputSignal.getxValues();
double sampleRate = 1.0 / (time[1] - time[0]);
double lc = 200;
double hc = 300;

Cascade filter = createIIRJFilter(filterType, filterBType, order, lc, hc, sampleRate);
String filterFileName = generateFileNameFromConfig(filterType.scipyString, filterBType.scipyString, lc, hc, order);
DoubleSignal scipyResult = readSignalFromCSV(classloader.getResourceAsStream("scipy/" + filterFileName), "Scipy#" + filterFileName);

DoubleSignal filteredSignal = new DoubleSignal(inputSignal.getSize(), "IIRJ#" + filterFileName);
int size = inputSignal.getSize();
for (int index = 0; index < size; index++) {
filteredSignal.setValue(index, inputSignal.getXValue(index), filter.filter(inputSignal.getValue(index)));
}

compareSignals(scipyResult, filteredSignal);
}

private Cascade createIIRJFilter(FilterType filterType, FilterBType filterBType, int order, double lc, double hc, double sampleRate) {
double centerFrequency = (hc + lc) / 2;
double widthFrequency = hc - lc;
double rippleInDb = 1;

switch (filterType) {
case Butterworth:
Butterworth butterworth = new Butterworth();
switch (filterBType) {
case Lowpass:
butterworth.lowPass(order, sampleRate, hc);
break;
case Highpass:
butterworth.highPass(order, sampleRate, lc);
break;
case Bandpass:
butterworth.bandPass(order, sampleRate, centerFrequency, widthFrequency);
break;
case Bandstop:
butterworth.bandStop(order, sampleRate, centerFrequency, widthFrequency);
break;
}
return butterworth;
case Chebychev1:
ChebyshevI chebyshevI = new ChebyshevI();

switch (filterBType) {
case Lowpass:
chebyshevI.lowPass(order, sampleRate, hc, rippleInDb);
break;
case Highpass:
chebyshevI.highPass(order, sampleRate, lc, rippleInDb);
break;
case Bandpass:
chebyshevI.bandPass(order, sampleRate, centerFrequency, widthFrequency, rippleInDb);
break;
case Bandstop:
chebyshevI.bandStop(order, sampleRate, centerFrequency, widthFrequency, rippleInDb);
break;
}
return chebyshevI;
case Chebychev2:
ChebyshevII chebyshevII = new ChebyshevII();
switch (filterBType) {
case Lowpass:
chebyshevII.lowPass(order, sampleRate, hc, rippleInDb);
break;
case Highpass:
chebyshevII.highPass(order, sampleRate, lc, rippleInDb);
break;
case Bandpass:
chebyshevII.bandPass(order, sampleRate, centerFrequency, widthFrequency, rippleInDb);
break;
case Bandstop:
chebyshevII.bandStop(order, sampleRate, centerFrequency, widthFrequency, rippleInDb);
break;
}
return chebyshevII;
}

throw new IllegalArgumentException(
"Unknown filter configuration: "
+ "Filter Type: " + filterType
+ ", Filter B Type: " + filterBType);
}

private void compareSignals(DoubleSignal scipySignal, DoubleSignal iirjSignal) {
int signal1Size = scipySignal.getSize();
int signal2Size = iirjSignal.getSize();
Assert.assertEquals("Different signal1Size of signals", signal1Size, signal2Size);

for (int index = 0; index < signal1Size; index++) {
double scipySignalValue = scipySignal.getValue(index);
double iirjSignalValue = iirjSignal.getValue(index);
Assert.assertEquals("Different values at index " + index, scipySignalValue, iirjSignalValue, 1E-5);
}
}

private DoubleSignal readSignalFromCSV(InputStream inputStream, String signalName) {
List<String> lines = new BufferedReader(new InputStreamReader(inputStream,
StandardCharsets.UTF_8)).lines().collect(Collectors.toList());
DoubleSignal doubleSignal = new DoubleSignal(lines.size() - 1, signalName);
for (int i = 1; i < lines.size(); i++) {
String[] parts = lines.get(i).split(";");
doubleSignal.setValue(i - 1, Double.parseDouble(parts[0]), Double.parseDouble(parts[1]));
}
return doubleSignal;
}

private String generateFileNameFromConfig(String filterType, String filterBTye, double lc, double hc, int order) {
return filterType.toLowerCase() + "-" + filterBTye + "-LC_" + (int) lc + "-HC_" + (int) hc + "-Order_" + order + ".csv";
}
}
115 changes: 115 additions & 0 deletions src/test/resources/scipy/AnalogFilterTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import numpy as np
from scipy.signal import butter, lfilter, bessel, cheby1, cheby2, sosfilt
import matplotlib.pyplot as plt


def showSignals(time, signal, filtered_signal):
plt.plot(time, signal, label='Original Signal')
plt.plot(time, filtered_signal, label='Filtered Signal (Butterworth, 350 Hz lowpass)')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.legend()
plt.show()


def show_fft(time, signal, filtered_signal, sample_rate):
N = len(signal)
fft_original = np.fft.fft(signal)
fft_filtered = np.fft.fft(filtered_signal)
freqs = np.fft.fftfreq(N, 1 / sample_rate)

# Only take the positive frequencies (up to the Nyquist frequency)
positive_freqs = freqs[:N // 2]
fft_original_positive = np.abs(fft_original[:N // 2])
fft_filtered_positive = np.abs(fft_filtered[:N // 2])

plt.figure()
plt.plot(positive_freqs, fft_original_positive, label='Original Signal')
plt.plot(positive_freqs, fft_filtered_positive, label='Filtered Signal')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude')
plt.title('FFT of Original and Filtered Signals')
plt.legend()
plt.show()


def export_to_csv(time, filtered_signal, path):
with open(path, 'w') as file:
# Write the header row
file.write('time;signal\n')
# Write the data rows
for t, s in zip(time, filtered_signal):
file.write(f'{t:.4f};{s:.5f}\n')


def readData(filename):
# Read the data
data = np.loadtxt(fname=filename, delimiter=';')
time = data[:, 0]
signal = data[:, 1]
return time, signal


def filterAndExport(sos, signal, exportName):
# Apply the filter to the signal (using lfilter for one-pass filtering)
filtered_signal = sosfilt(sos, signal)
# Export the result to a CSV file
export_to_csv(time, filtered_signal, exportName)
print(f"Filtered signal has been saved to {exportName}")

def createFileNameFromFilter(filtertype, btype, lc, hc, order):
return str(filtertype + "-" + btype + "-LC_" + str(lc) + "-HC_" + str(hc) + "-Order_" + str(order) + ".csv")

time, signal = readData('signal.txt')

# Calculate the sample rate
sample_rate = 1 / (time[1] - time[0])
lc = 200
hc = 300

filtertypes = ['Butterworth', 'Bessel', 'Chebychev1', 'Chebychev2']
btypes = ['lowpass', 'highpass', 'bandpass', 'bandstop']
orders = [1,2,4,5,10]

###########################################################################

for filtertype in filtertypes:
for btype in btypes:
for order in orders:
# Handle different critical frequencies for lowpass/highpass vs bandpass/bandstop
if btype in ['lowpass', 'highpass']:
Wn = hc if btype == 'lowpass' else lc
else:
Wn = [lc, hc]

if filtertype == "Butterworth":
sos = butter(N=order, Wn=Wn, fs=sample_rate, btype=btype, output='sos')
elif filtertype == "Bessel":
sos = bessel(N=order, Wn=Wn, fs=sample_rate, btype=btype, output='sos')
elif filtertype == "Chebychev1":
sos = cheby1(N=order, rp=1, Wn=Wn, fs=sample_rate, btype=btype, output='sos')
elif filtertype == "Chebychev2":
sos = cheby2(N=order, rs=1, Wn=Wn, fs=sample_rate, btype=btype, output='sos')

filename = createFileNameFromFilter(filtertype.lower(), btype, lc, hc, order)
filterAndExport(sos, signal, filename)


############################################################################


#b, a = butter(N=2, Wn=lc, fs=sample_rate, btype='lowpass')
#filterAndExport(b, a, signal, "Butterworth-Lowcut-350Hz-2.Order.txt")

#b, a = butter(N=1, Wn=hc, fs=sample_rate, btype='highpass')#
#filterAndExport(b, a, signal, "Butterworth-Highcut-450Hz-1.Order.txt")

#b, a = butter(N=2, Wn=hc, fs=sample_rate, btype='highpass')
#filterAndExport(b, a, signal, "Butterworth-Highcut-450Hz-2.Order.txt")


# Plot the original and filtered signals
# showSignals(time, signal, filtered_signal)

# Show the FFT of both signals
# show_fft(time, signal, filtered_signal, sample_rate)

0 comments on commit 54e8e38

Please sign in to comment.