From 600a0e8265e6dc9845e0dd98688ed420d09981ce Mon Sep 17 00:00:00 2001 From: Katsuya Iida Date: Sun, 9 Oct 2022 02:47:55 +0900 Subject: [PATCH 1/2] Add `torchaudio.transforms.GriffinLim()`. Replace `torchaudio.ITransform` with `IModule`. --- src/Examples/SpeechCommands.cs | 6 +- src/TorchSharp/TorchAudio/Compose.cs | 50 ---- src/TorchSharp/TorchAudio/ITransform.cs | 14 - src/TorchSharp/TorchAudio/Transforms.cs | 255 ++++-------------- .../TorchAudio/Transforms/GriffinLim.cs | 80 ++++++ .../Transforms/InverseSpectrogram.cs | 92 +++++++ .../TorchAudio/Transforms/Resample.cs | 77 ++++++ .../TorchAudio/Transforms/Spectrogram.cs | 97 +++++++ test/TorchSharpTest/TestTorchAudio.cs | 28 ++ 9 files changed, 423 insertions(+), 276 deletions(-) delete mode 100644 src/TorchSharp/TorchAudio/Compose.cs delete mode 100644 src/TorchSharp/TorchAudio/ITransform.cs create mode 100644 src/TorchSharp/TorchAudio/Transforms/GriffinLim.cs create mode 100644 src/TorchSharp/TorchAudio/Transforms/InverseSpectrogram.cs create mode 100644 src/TorchSharp/TorchAudio/Transforms/Resample.cs create mode 100644 src/TorchSharp/TorchAudio/Transforms/Spectrogram.cs diff --git a/src/Examples/SpeechCommands.cs b/src/Examples/SpeechCommands.cs index ebe39831e..6e0cc01d1 100644 --- a/src/Examples/SpeechCommands.cs +++ b/src/Examples/SpeechCommands.cs @@ -87,7 +87,7 @@ private static BatchItem Collate(IEnumerable items, t }; } - internal static void TrainingLoop(string dataset, Device device, M5 model, ITransform transform, Dataset train_data, Dataset test_data) + internal static void TrainingLoop(string dataset, Device device, M5 model, nn.IModule transform, Dataset train_data, Dataset test_data) { using (var train_loader = new DataLoader( train_data, _trainBatchSize, Collate, shuffle: true, device: device)) @@ -121,7 +121,7 @@ internal static void TrainingLoop(string dataset, Device device, M5 model, ITran private static void Train( M5 model, - ITransform transform, + nn.IModule transform, torch.optim.Optimizer optimizer, Loss criteria, DataLoader dataLoader, @@ -159,7 +159,7 @@ private static void Train( private static void Test( M5 model, - ITransform transform, + nn.IModule transform, Loss criteria, DataLoader dataLoader, long size) diff --git a/src/TorchSharp/TorchAudio/Compose.cs b/src/TorchSharp/TorchAudio/Compose.cs deleted file mode 100644 index 5b63ce06d..000000000 --- a/src/TorchSharp/TorchAudio/Compose.cs +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. - -using System; -using static TorchSharp.torch; - -namespace TorchSharp -{ - public static partial class torchaudio - { - internal class ComposedTransforms : IDisposable, ITransform - { - public ComposedTransforms(ITransform[] transforms) - { - this.transforms = transforms; - } - - public void Dispose() - { - foreach (var t in transforms) { - if (t is IDisposable) { - ((IDisposable)t).Dispose(); - } - } - } - - public Tensor forward(Tensor input) - { - foreach (var t in transforms) { - input = t.forward(input); - } - return input; - } - - private ITransform[] transforms; - } - - public static partial class transforms - { - /// - /// Composes several transforms together. - /// - /// A list of transforms to compose serially. - /// - static public ITransform Compose(params ITransform[] transforms) - { - return new ComposedTransforms(transforms); - } - } - } -} \ No newline at end of file diff --git a/src/TorchSharp/TorchAudio/ITransform.cs b/src/TorchSharp/TorchAudio/ITransform.cs deleted file mode 100644 index 49e4214da..000000000 --- a/src/TorchSharp/TorchAudio/ITransform.cs +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. - -using static TorchSharp.torch; - -namespace TorchSharp -{ - public partial class torchaudio - { - public interface ITransform - { - Tensor forward(Tensor input); - } - } -} diff --git a/src/TorchSharp/TorchAudio/Transforms.cs b/src/TorchSharp/TorchAudio/Transforms.cs index 6ed5b3284..2e7cb6f58 100644 --- a/src/TorchSharp/TorchAudio/Transforms.cs +++ b/src/TorchSharp/TorchAudio/Transforms.cs @@ -2,9 +2,9 @@ using System; using System.Linq; using System.Collections.Generic; -using System.Runtime.InteropServices; using static TorchSharp.torch; +using TorchSharp.Transforms; // A number of implementation details in this file have been translated from the Python version of torchaudio, // largely located in the files found in this folder: @@ -22,211 +22,6 @@ public static partial class torchaudio { public delegate torch.Tensor WindowFunction(long win_length); - internal class Spectrogram : ITransform - { - private readonly long n_fft; - private readonly long win_length; - private readonly long hop_length; - private readonly long pad; - private readonly Tensor window; - private readonly double? power; - private readonly bool normalized; - private readonly bool center; - private readonly PaddingModes pad_mode; - private readonly bool onesided; - - public Spectrogram( - long n_fft = 400, - long? win_length = null, - long? hop_length = null, - long pad = 0, - WindowFunction window_fn = null, - Tensor window = null, - double? power = 2.0, - bool normalized = false, - bool center = true, - PaddingModes pad_mode = PaddingModes.Reflect, - bool onesided = true, - bool? return_complex = null) - { - this.n_fft = n_fft; - if (win_length.HasValue) { - this.win_length = win_length.Value; - } else { - this.win_length = n_fft; - } - if (hop_length.HasValue) { - this.hop_length = hop_length.Value; - } else { - this.hop_length = this.win_length / 2; - } - this.pad = pad; - if (window is not null) { - this.window = window; - } else if (window_fn != null) { - this.window = window_fn(this.win_length); - } else { - this.window = torch.hann_window(this.win_length); - } - this.power = power; - this.normalized = normalized; - this.center = center; - this.pad_mode = pad_mode; - this.onesided = onesided; - if (return_complex.HasValue) { - Console.WriteLine( - "`return_complex` argument is now deprecated and is not effective." + - "`torchaudio.transforms.Spectrogram(power=null)` always returns a tensor with " + - "complex dtype. Please remove the argument in the function call." - ); - } - } - - public Tensor forward(Tensor input) - { - return torchaudio.functional.spectrogram( - waveform: input, - pad: pad, - window: window, - n_fft: n_fft, - hop_length: hop_length, - win_length: win_length, - power: power, - normalized: normalized, - center: center, - pad_mode: pad_mode, - onesided: onesided); - } - } - - internal class InverseSpectrogram : ITransform - { - private readonly long n_fft; - private readonly long win_length; - private readonly long hop_length; - private readonly long pad; - private readonly Tensor window; - private readonly bool normalized; - private readonly bool center; - private readonly PaddingModes pad_mode; - private readonly bool onesided; - - public InverseSpectrogram( - long n_fft = 400, - long? win_length = null, - long? hop_length = null, - long pad = 0, - WindowFunction window_fn = null, - Tensor window = null, - bool normalized = false, - bool center = true, - PaddingModes pad_mode = PaddingModes.Reflect, - bool onesided = true) - { - this.n_fft = n_fft; - if (win_length.HasValue) { - this.win_length = win_length.Value; - } else { - this.win_length = n_fft; - } - if (hop_length.HasValue) { - this.hop_length = hop_length.Value; - } else { - this.hop_length = this.win_length / 2; - } - this.pad = pad; - if (window is not null) { - this.window = window; - } else if (window_fn != null) { - this.window = window_fn(this.win_length); - } else { - this.window = torch.hann_window(this.win_length); - } - this.normalized = normalized; - this.center = center; - this.pad_mode = pad_mode; - this.onesided = onesided; - } - - public Tensor forward(Tensor input) - { - return forward(input, null); - } - - public Tensor forward(Tensor input, long? length = null) - { - return torchaudio.functional.inverse_spectrogram( - spectrogram: input, - length: length, - pad: pad, - window: window, - n_fft: n_fft, - hop_length: hop_length, - win_length: win_length, - normalized: normalized, - center: center, - pad_mode: pad_mode, - onesided: onesided); - } - } - - internal sealed class Resample : ITransform - { - private readonly int orig_freq; - private readonly int new_freq; - private readonly int gcd; - private readonly ResamplingMethod resampling_method; - private readonly int lowpass_filter_width; - private readonly double rolloff; - private readonly double? beta; - public readonly torch.Tensor kernel; - private readonly int width; - - public Resample( - int orig_freq = 16000, - int new_freq = 16000, - ResamplingMethod resampling_method = ResamplingMethod.sinc_interpolation, - int lowpass_filter_width = 6, - double rolloff = 0.99, - double? beta = null, - torch.Device device = null, - torch.ScalarType? dtype = null) - { - this.orig_freq = orig_freq; - this.new_freq = new_freq; - this.gcd = functional.Gcd(this.orig_freq, this.new_freq); - this.resampling_method = resampling_method; - this.lowpass_filter_width = lowpass_filter_width; - this.rolloff = rolloff; - this.beta = beta; - - if (this.orig_freq != this.new_freq) { - (this.kernel, this.width) = functional._get_sinc_resample_kernel( - this.orig_freq, - this.new_freq, - this.gcd, - this.lowpass_filter_width, - this.rolloff, - this.resampling_method, - beta, - device: device, - dtype: dtype); - } - } - - public torch.Tensor forward(torch.Tensor waveform) - { - using (var d = torch.NewDisposeScope()) { - - if (this.orig_freq == this.new_freq) { - return d.MoveToOuter(waveform.alias()); - } - var resampled = functional._apply_sinc_resample_kernel(waveform, this.orig_freq, this.new_freq, this.gcd, this.kernel, this.width); - return d.MoveToOuter(resampled); - } - } - } - public static partial class transforms { /// @@ -245,7 +40,7 @@ public static partial class transforms /// Whether the output is onesided or not. /// Deprecated and not used. /// ITransform to compute spectrograms of audio signals - public static ITransform Spectrogram( + public static Spectrogram Spectrogram( long n_fft = 400, long? win_length = null, long? hop_length = null, @@ -260,6 +55,7 @@ public static ITransform Spectrogram( bool? return_complex = null) { return new Spectrogram( + "Spectrogram", n_fft: n_fft, hop_length: hop_length, win_length: win_length, @@ -288,7 +84,7 @@ public static ITransform Spectrogram( /// The padding mode used when center is true. /// Whether the output is onesided or not. /// ITransform to compute inverse of spectrogram - public static ITransform InverseSpectrogram( + public static InverseSpectrogram InverseSpectrogram( long n_fft = 400, long? win_length = null, long? hop_length = null, @@ -301,6 +97,7 @@ public static ITransform InverseSpectrogram( bool onesided = true) { return new InverseSpectrogram( + "InverseSpectrogram", n_fft: n_fft, hop_length: hop_length, win_length: win_length, @@ -326,7 +123,7 @@ public static ITransform InverseSpectrogram( /// The scalar type /// The resampled waveform /// - public static ITransform Resample( + public static Resample Resample( int orig_freq = 16000, int new_freq = 16000, ResamplingMethod resampling_method = ResamplingMethod.sinc_interpolation, @@ -337,6 +134,7 @@ public static ITransform Resample( torch.ScalarType? dtype = null) { return new Resample( + "Resample", orig_freq, new_freq, resampling_method, @@ -346,6 +144,45 @@ public static ITransform Resample( device, dtype); } + + /// + /// Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation. + /// + /// Size of FFT, creates ``n_fft // 2 + 1`` bins. + /// Number of iteration for phase recovery process. + /// Window size. + /// Length of hop between STFT windows. + /// A function to create a window tensor + /// that is applied/multiplied to each frame/window. + /// Exponent for the magnitude spectrogram, + /// (must be > 0) e.g., 1 for energy, 2 for power, etc. + /// The momentum parameter for fast Griffin-Lim. + /// Array length of the expected output. + /// Initializes phase randomly if True and to zero otherwise. + /// + public static GriffinLim GriffinLim( + int n_fft = 400, + int n_iter = 32, + long? win_length = null, + long? hop_length = null, + WindowFunction window_fn = null, + double power = 2.0, + double momentum = 0.99, + int? length = null, + bool rand_init = true) + { + return new GriffinLim( + "GriffinLim", + n_fft, + n_iter, + win_length, + hop_length, + window_fn, + power, + momentum, + length, + rand_init); + } } } } \ No newline at end of file diff --git a/src/TorchSharp/TorchAudio/Transforms/GriffinLim.cs b/src/TorchSharp/TorchAudio/Transforms/GriffinLim.cs new file mode 100644 index 000000000..0dd8880c2 --- /dev/null +++ b/src/TorchSharp/TorchAudio/Transforms/GriffinLim.cs @@ -0,0 +1,80 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.Text; +using static TorchSharp.torch; +using static TorchSharp.torchaudio; + +// A number of implementation details in this file have been translated from the Python version of torchaudio, +// largely located in the files found in this folder: +// +// https://github.com/pytorch/audio/blob/bb77cbebb620a46fdc0dc7e6dae2253eef3f37e2/torchaudio/transforms/_transforms.py +// +// The origin has the following copyright notice and license: +// +// https://github.com/pytorch/audio/blob/main/LICENSE +// + +namespace TorchSharp.Transforms +{ + public class GriffinLim : nn.Module + { + public readonly long n_fft; + public readonly int n_iter; + public readonly long win_length; + public readonly long hop_length; + public readonly Tensor window; + public readonly long? length; + public readonly double power; + public readonly double momentum; + public readonly bool rand_init; + + internal GriffinLim( + string name, + long n_fft = 400, + int n_iter = 32, + long? win_length = null, + long? hop_length = null, + WindowFunction window_fn = null, + double power = 2.0, + double momentum = 0.99, + long? length = null, + bool rand_init = true) : base(name) + { + if (momentum < 0 || 1 <= momentum) { + throw new ArgumentOutOfRangeException($"momentum must be in the range [0, 1). Found: {momentum}"); + } + + this.n_fft = n_fft; + this.n_iter = n_iter; + this.win_length = win_length ?? n_fft; + this.hop_length = hop_length ?? this.win_length / 2; + if (window_fn != null) { + this.window = window_fn(this.win_length); + } else { + this.window = torch.hann_window(this.win_length); + } + this.register_buffer("window", this.window); + this.length = length; + this.power = power; + this.momentum = momentum; + this.rand_init = rand_init; + } + + public override Tensor forward(Tensor specgram) + { + return torchaudio.functional.griffinlim( + specgram, + this.window, + this.n_fft, + this.hop_length, + this.win_length, + this.power, + this.n_iter, + this.momentum, + this.length, + this.rand_init + ); + } + } +} diff --git a/src/TorchSharp/TorchAudio/Transforms/InverseSpectrogram.cs b/src/TorchSharp/TorchAudio/Transforms/InverseSpectrogram.cs new file mode 100644 index 000000000..896bd5c59 --- /dev/null +++ b/src/TorchSharp/TorchAudio/Transforms/InverseSpectrogram.cs @@ -0,0 +1,92 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.Net; +using System.Text; +using static TorchSharp.torch; +using static TorchSharp.torchaudio; + +// A number of implementation details in this file have been translated from the Python version of torchaudio, +// largely located in the files found in this folder: +// +// https://github.com/pytorch/audio/blob/bb77cbebb620a46fdc0dc7e6dae2253eef3f37e2/torchaudio/transforms/_transforms.py +// +// The origin has the following copyright notice and license: +// +// https://github.com/pytorch/audio/blob/main/LICENSE +// + +namespace TorchSharp.Transforms +{ + public sealed class InverseSpectrogram : nn.Module, nn.IModule + { + private readonly long n_fft; + private readonly long win_length; + private readonly long hop_length; + private readonly long pad; + private readonly Tensor window; + private readonly bool normalized; + private readonly bool center; + private readonly PaddingModes pad_mode; + private readonly bool onesided; + + internal InverseSpectrogram( + string name, + long n_fft = 400, + long? win_length = null, + long? hop_length = null, + long pad = 0, + WindowFunction window_fn = null, + Tensor window = null, + bool normalized = false, + bool center = true, + PaddingModes pad_mode = PaddingModes.Reflect, + bool onesided = true) : base(name) + { + this.n_fft = n_fft; + if (win_length.HasValue) { + this.win_length = win_length.Value; + } else { + this.win_length = n_fft; + } + if (hop_length.HasValue) { + this.hop_length = hop_length.Value; + } else { + this.hop_length = this.win_length / 2; + } + this.pad = pad; + if (window is not null) { + this.window = window; + } else if (window_fn != null) { + this.window = window_fn(this.win_length); + } else { + this.window = torch.hann_window(this.win_length); + } + this.normalized = normalized; + this.center = center; + this.pad_mode = pad_mode; + this.onesided = onesided; + } + + public override Tensor forward(Tensor input) + { + return forward(input, null); + } + + public Tensor forward(Tensor input, long? length = null) + { + return torchaudio.functional.inverse_spectrogram( + spectrogram: input, + length: length, + pad: pad, + window: window, + n_fft: n_fft, + hop_length: hop_length, + win_length: win_length, + normalized: normalized, + center: center, + pad_mode: pad_mode, + onesided: onesided); + } + } +} diff --git a/src/TorchSharp/TorchAudio/Transforms/Resample.cs b/src/TorchSharp/TorchAudio/Transforms/Resample.cs new file mode 100644 index 000000000..988344a03 --- /dev/null +++ b/src/TorchSharp/TorchAudio/Transforms/Resample.cs @@ -0,0 +1,77 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.Text; +using static TorchSharp.torch; +using static TorchSharp.torchaudio; + +// A number of implementation details in this file have been translated from the Python version of torchaudio, +// largely located in the files found in this folder: +// +// https://github.com/pytorch/audio/blob/bb77cbebb620a46fdc0dc7e6dae2253eef3f37e2/torchaudio/transforms/_transforms.py +// +// The origin has the following copyright notice and license: +// +// https://github.com/pytorch/audio/blob/main/LICENSE +// + +namespace TorchSharp.Transforms +{ + public sealed class Resample : nn.Module + { + private readonly int orig_freq; + private readonly int new_freq; + private readonly int gcd; + private readonly ResamplingMethod resampling_method; + private readonly int lowpass_filter_width; + private readonly double rolloff; + private readonly double? beta; + public readonly torch.Tensor kernel; + private readonly int width; + + internal Resample( + string name, + int orig_freq = 16000, + int new_freq = 16000, + ResamplingMethod resampling_method = ResamplingMethod.sinc_interpolation, + int lowpass_filter_width = 6, + double rolloff = 0.99, + double? beta = null, + torch.Device device = null, + torch.ScalarType? dtype = null) : base(name) + { + this.orig_freq = orig_freq; + this.new_freq = new_freq; + this.gcd = functional.Gcd(this.orig_freq, this.new_freq); + this.resampling_method = resampling_method; + this.lowpass_filter_width = lowpass_filter_width; + this.rolloff = rolloff; + this.beta = beta; + + if (this.orig_freq != this.new_freq) { + (this.kernel, this.width) = functional._get_sinc_resample_kernel( + this.orig_freq, + this.new_freq, + this.gcd, + this.lowpass_filter_width, + this.rolloff, + this.resampling_method, + beta, + device: device, + dtype: dtype); + } + } + + public override Tensor forward(Tensor waveform) + { + using (var d = torch.NewDisposeScope()) { + + if (this.orig_freq == this.new_freq) { + return d.MoveToOuter(waveform.alias()); + } + var resampled = functional._apply_sinc_resample_kernel(waveform, this.orig_freq, this.new_freq, this.gcd, this.kernel, this.width); + return d.MoveToOuter(resampled); + } + } + } +} diff --git a/src/TorchSharp/TorchAudio/Transforms/Spectrogram.cs b/src/TorchSharp/TorchAudio/Transforms/Spectrogram.cs new file mode 100644 index 000000000..a106de53c --- /dev/null +++ b/src/TorchSharp/TorchAudio/Transforms/Spectrogram.cs @@ -0,0 +1,97 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.Text; +using static TorchSharp.torch; +using static TorchSharp.torchaudio; + +// A number of implementation details in this file have been translated from the Python version of torchaudio, +// largely located in the files found in this folder: +// +// https://github.com/pytorch/audio/blob/bb77cbebb620a46fdc0dc7e6dae2253eef3f37e2/torchaudio/transforms/_transforms.py +// +// The origin has the following copyright notice and license: +// +// https://github.com/pytorch/audio/blob/main/LICENSE +// + +namespace TorchSharp.Transforms +{ + public sealed class Spectrogram : nn.Module + { + private readonly long n_fft; + private readonly long win_length; + private readonly long hop_length; + private readonly long pad; + private readonly Tensor window; + private readonly double? power; + private readonly bool normalized; + private readonly bool center; + private readonly PaddingModes pad_mode; + private readonly bool onesided; + + public Spectrogram( + string name, + long n_fft = 400, + long? win_length = null, + long? hop_length = null, + long pad = 0, + WindowFunction window_fn = null, + Tensor window = null, + double? power = 2.0, + bool normalized = false, + bool center = true, + PaddingModes pad_mode = PaddingModes.Reflect, + bool onesided = true, + bool? return_complex = null) : base(name) + { + this.n_fft = n_fft; + if (win_length.HasValue) { + this.win_length = win_length.Value; + } else { + this.win_length = n_fft; + } + if (hop_length.HasValue) { + this.hop_length = hop_length.Value; + } else { + this.hop_length = this.win_length / 2; + } + this.pad = pad; + if (window is not null) { + this.window = window; + } else if (window_fn != null) { + this.window = window_fn(this.win_length); + } else { + this.window = torch.hann_window(this.win_length); + } + this.power = power; + this.normalized = normalized; + this.center = center; + this.pad_mode = pad_mode; + this.onesided = onesided; + if (return_complex.HasValue) { + Console.WriteLine( + "`return_complex` argument is now deprecated and is not effective." + + "`torchaudio.transforms.Spectrogram(power=null)` always returns a tensor with " + + "complex dtype. Please remove the argument in the function call." + ); + } + } + + public override Tensor forward(Tensor input) + { + return torchaudio.functional.spectrogram( + waveform: input, + pad: pad, + window: window, + n_fft: n_fft, + hop_length: hop_length, + win_length: win_length, + power: power, + normalized: normalized, + center: center, + pad_mode: pad_mode, + onesided: onesided); + } + } +} diff --git a/test/TorchSharpTest/TestTorchAudio.cs b/test/TorchSharpTest/TestTorchAudio.cs index c0b7d86e4..601972b99 100644 --- a/test/TorchSharpTest/TestTorchAudio.cs +++ b/test/TorchSharpTest/TestTorchAudio.cs @@ -161,6 +161,34 @@ public void TestGriffinLim() Assert.Equal(new long[] { 1, 80320 }, recovered_waveform.shape); } + [Fact] + public void TestTransformsGriffinLim() + { + var transform = torchaudio.transforms.Spectrogram( + pad: 200, + n_fft: 512, + hop_length: 160, + win_length: 400, + window_fn: win_length => torch.hann_window(400), + power: 2.0, + normalized: false); + var inverse_transform = torchaudio.transforms.GriffinLim( + n_fft: 512, + hop_length: 160, + win_length: 400, + window_fn: win_length => torch.hann_window(400), + power: 2.0, + n_iter: 32, + momentum: 0.99, + length: null, + rand_init: true); + var waveform = make_waveform(); + var specgram = transform.forward(waveform); + var recovered_waveform = inverse_transform.forward(specgram); + + Assert.Equal(new long[] { 1, 80320 }, recovered_waveform.shape); + } + [Fact] public void TestMelscaleFbanks() { From 8ba1587760633abe776970089bddc5ff75ade827 Mon Sep 17 00:00:00 2001 From: Katsuya Iida Date: Sun, 9 Oct 2022 12:53:34 +0900 Subject: [PATCH 2/2] Add `torchaudio.ITransform` back --- src/TorchSharp/TorchAudio/Transforms/GriffinLim.cs | 2 +- src/TorchSharp/TorchAudio/Transforms/ITransform.cs | 14 ++++++++++++++ .../TorchAudio/Transforms/InverseSpectrogram.cs | 2 +- src/TorchSharp/TorchAudio/Transforms/Resample.cs | 2 +- .../TorchAudio/Transforms/Spectrogram.cs | 2 +- 5 files changed, 18 insertions(+), 4 deletions(-) create mode 100644 src/TorchSharp/TorchAudio/Transforms/ITransform.cs diff --git a/src/TorchSharp/TorchAudio/Transforms/GriffinLim.cs b/src/TorchSharp/TorchAudio/Transforms/GriffinLim.cs index 0dd8880c2..29118d7c8 100644 --- a/src/TorchSharp/TorchAudio/Transforms/GriffinLim.cs +++ b/src/TorchSharp/TorchAudio/Transforms/GriffinLim.cs @@ -17,7 +17,7 @@ namespace TorchSharp.Transforms { - public class GriffinLim : nn.Module + public sealed class GriffinLim : nn.Module, ITransform { public readonly long n_fft; public readonly int n_iter; diff --git a/src/TorchSharp/TorchAudio/Transforms/ITransform.cs b/src/TorchSharp/TorchAudio/Transforms/ITransform.cs new file mode 100644 index 000000000..4505451f6 --- /dev/null +++ b/src/TorchSharp/TorchAudio/Transforms/ITransform.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. + +using static TorchSharp.torch; +using static TorchSharp.torch.nn; + +namespace TorchSharp +{ + public partial class torchaudio + { + public interface ITransform : IModule + { + } + } +} \ No newline at end of file diff --git a/src/TorchSharp/TorchAudio/Transforms/InverseSpectrogram.cs b/src/TorchSharp/TorchAudio/Transforms/InverseSpectrogram.cs index 896bd5c59..36a0c08bf 100644 --- a/src/TorchSharp/TorchAudio/Transforms/InverseSpectrogram.cs +++ b/src/TorchSharp/TorchAudio/Transforms/InverseSpectrogram.cs @@ -18,7 +18,7 @@ namespace TorchSharp.Transforms { - public sealed class InverseSpectrogram : nn.Module, nn.IModule + public sealed class InverseSpectrogram : nn.Module, nn.IModule, ITransform { private readonly long n_fft; private readonly long win_length; diff --git a/src/TorchSharp/TorchAudio/Transforms/Resample.cs b/src/TorchSharp/TorchAudio/Transforms/Resample.cs index 988344a03..79acadedf 100644 --- a/src/TorchSharp/TorchAudio/Transforms/Resample.cs +++ b/src/TorchSharp/TorchAudio/Transforms/Resample.cs @@ -17,7 +17,7 @@ namespace TorchSharp.Transforms { - public sealed class Resample : nn.Module + public sealed class Resample : nn.Module, ITransform { private readonly int orig_freq; private readonly int new_freq; diff --git a/src/TorchSharp/TorchAudio/Transforms/Spectrogram.cs b/src/TorchSharp/TorchAudio/Transforms/Spectrogram.cs index a106de53c..d5e002134 100644 --- a/src/TorchSharp/TorchAudio/Transforms/Spectrogram.cs +++ b/src/TorchSharp/TorchAudio/Transforms/Spectrogram.cs @@ -17,7 +17,7 @@ namespace TorchSharp.Transforms { - public sealed class Spectrogram : nn.Module + public sealed class Spectrogram : nn.Module, ITransform { private readonly long n_fft; private readonly long win_length;