Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/Examples/SpeechCommands.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private static BatchItem Collate(IEnumerable<SpeechCommandsDatasetItem> items, t
};
}

internal static void TrainingLoop(string dataset, Device device, M5 model, ITransform transform, Dataset<SpeechCommandsDatasetItem> train_data, Dataset<SpeechCommandsDatasetItem> test_data)
internal static void TrainingLoop(string dataset, Device device, M5 model, nn.IModule<Tensor, Tensor> transform, Dataset<SpeechCommandsDatasetItem> train_data, Dataset<SpeechCommandsDatasetItem> test_data)
{
using (var train_loader = new DataLoader<SpeechCommandsDatasetItem, BatchItem>(
train_data, _trainBatchSize, Collate, shuffle: true, device: device))
Expand Down Expand Up @@ -121,7 +121,7 @@ internal static void TrainingLoop(string dataset, Device device, M5 model, ITran

private static void Train(
M5 model,
ITransform transform,
nn.IModule<Tensor, Tensor> transform,
torch.optim.Optimizer optimizer,
Loss<Tensor,Tensor,Tensor> criteria,
DataLoader<SpeechCommandsDatasetItem, BatchItem> dataLoader,
Expand Down Expand Up @@ -159,7 +159,7 @@ private static void Train(

private static void Test(
M5 model,
ITransform transform,
nn.IModule<Tensor, Tensor> transform,
Loss<Tensor, Tensor, Tensor> criteria,
DataLoader<SpeechCommandsDatasetItem, BatchItem> dataLoader,
long size)
Expand Down
50 changes: 0 additions & 50 deletions src/TorchSharp/TorchAudio/Compose.cs

This file was deleted.

255 changes: 46 additions & 209 deletions src/TorchSharp/TorchAudio/Transforms.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
using System;
using System.Linq;
using System.Collections.Generic;
using System.Runtime.InteropServices;

using static TorchSharp.torch;
using TorchSharp.Transforms;

// A number of implementation details in this file have been translated from the Python version of torchaudio,
// largely located in the files found in this folder:
Expand All @@ -22,211 +22,6 @@ public static partial class torchaudio
{
public delegate torch.Tensor WindowFunction(long win_length);

internal class Spectrogram : ITransform
{
private readonly long n_fft;
private readonly long win_length;
private readonly long hop_length;
private readonly long pad;
private readonly Tensor window;
private readonly double? power;
private readonly bool normalized;
private readonly bool center;
private readonly PaddingModes pad_mode;
private readonly bool onesided;

public Spectrogram(
long n_fft = 400,
long? win_length = null,
long? hop_length = null,
long pad = 0,
WindowFunction window_fn = null,
Tensor window = null,
double? power = 2.0,
bool normalized = false,
bool center = true,
PaddingModes pad_mode = PaddingModes.Reflect,
bool onesided = true,
bool? return_complex = null)
{
this.n_fft = n_fft;
if (win_length.HasValue) {
this.win_length = win_length.Value;
} else {
this.win_length = n_fft;
}
if (hop_length.HasValue) {
this.hop_length = hop_length.Value;
} else {
this.hop_length = this.win_length / 2;
}
this.pad = pad;
if (window is not null) {
this.window = window;
} else if (window_fn != null) {
this.window = window_fn(this.win_length);
} else {
this.window = torch.hann_window(this.win_length);
}
this.power = power;
this.normalized = normalized;
this.center = center;
this.pad_mode = pad_mode;
this.onesided = onesided;
if (return_complex.HasValue) {
Console.WriteLine(
"`return_complex` argument is now deprecated and is not effective." +
"`torchaudio.transforms.Spectrogram(power=null)` always returns a tensor with " +
"complex dtype. Please remove the argument in the function call."
);
}
}

public Tensor forward(Tensor input)
{
return torchaudio.functional.spectrogram(
waveform: input,
pad: pad,
window: window,
n_fft: n_fft,
hop_length: hop_length,
win_length: win_length,
power: power,
normalized: normalized,
center: center,
pad_mode: pad_mode,
onesided: onesided);
}
}

internal class InverseSpectrogram : ITransform
{
private readonly long n_fft;
private readonly long win_length;
private readonly long hop_length;
private readonly long pad;
private readonly Tensor window;
private readonly bool normalized;
private readonly bool center;
private readonly PaddingModes pad_mode;
private readonly bool onesided;

public InverseSpectrogram(
long n_fft = 400,
long? win_length = null,
long? hop_length = null,
long pad = 0,
WindowFunction window_fn = null,
Tensor window = null,
bool normalized = false,
bool center = true,
PaddingModes pad_mode = PaddingModes.Reflect,
bool onesided = true)
{
this.n_fft = n_fft;
if (win_length.HasValue) {
this.win_length = win_length.Value;
} else {
this.win_length = n_fft;
}
if (hop_length.HasValue) {
this.hop_length = hop_length.Value;
} else {
this.hop_length = this.win_length / 2;
}
this.pad = pad;
if (window is not null) {
this.window = window;
} else if (window_fn != null) {
this.window = window_fn(this.win_length);
} else {
this.window = torch.hann_window(this.win_length);
}
this.normalized = normalized;
this.center = center;
this.pad_mode = pad_mode;
this.onesided = onesided;
}

public Tensor forward(Tensor input)
{
return forward(input, null);
}

public Tensor forward(Tensor input, long? length = null)
{
return torchaudio.functional.inverse_spectrogram(
spectrogram: input,
length: length,
pad: pad,
window: window,
n_fft: n_fft,
hop_length: hop_length,
win_length: win_length,
normalized: normalized,
center: center,
pad_mode: pad_mode,
onesided: onesided);
}
}

internal sealed class Resample : ITransform
{
private readonly int orig_freq;
private readonly int new_freq;
private readonly int gcd;
private readonly ResamplingMethod resampling_method;
private readonly int lowpass_filter_width;
private readonly double rolloff;
private readonly double? beta;
public readonly torch.Tensor kernel;
private readonly int width;

public Resample(
int orig_freq = 16000,
int new_freq = 16000,
ResamplingMethod resampling_method = ResamplingMethod.sinc_interpolation,
int lowpass_filter_width = 6,
double rolloff = 0.99,
double? beta = null,
torch.Device device = null,
torch.ScalarType? dtype = null)
{
this.orig_freq = orig_freq;
this.new_freq = new_freq;
this.gcd = functional.Gcd(this.orig_freq, this.new_freq);
this.resampling_method = resampling_method;
this.lowpass_filter_width = lowpass_filter_width;
this.rolloff = rolloff;
this.beta = beta;

if (this.orig_freq != this.new_freq) {
(this.kernel, this.width) = functional._get_sinc_resample_kernel(
this.orig_freq,
this.new_freq,
this.gcd,
this.lowpass_filter_width,
this.rolloff,
this.resampling_method,
beta,
device: device,
dtype: dtype);
}
}

public torch.Tensor forward(torch.Tensor waveform)
{
using (var d = torch.NewDisposeScope()) {

if (this.orig_freq == this.new_freq) {
return d.MoveToOuter(waveform.alias());
}
var resampled = functional._apply_sinc_resample_kernel(waveform, this.orig_freq, this.new_freq, this.gcd, this.kernel, this.width);
return d.MoveToOuter(resampled);
}
}
}

public static partial class transforms
{
/// <summary>
Expand All @@ -245,7 +40,7 @@ public static partial class transforms
/// <param name="onesided">Whether the output is onesided or not.</param>
/// <param name="return_complex">Deprecated and not used.</param>
/// <returns>ITransform to compute spectrograms of audio signals</returns>
public static ITransform Spectrogram(
public static Spectrogram Spectrogram(
long n_fft = 400,
long? win_length = null,
long? hop_length = null,
Expand All @@ -260,6 +55,7 @@ public static ITransform Spectrogram(
bool? return_complex = null)
{
return new Spectrogram(
"Spectrogram",
n_fft: n_fft,
hop_length: hop_length,
win_length: win_length,
Expand Down Expand Up @@ -288,7 +84,7 @@ public static ITransform Spectrogram(
/// <param name="pad_mode">The padding mode used when center is true.</param>
/// <param name="onesided">Whether the output is onesided or not.</param>
/// <returns>ITransform to compute inverse of spectrogram</returns>
public static ITransform InverseSpectrogram(
public static InverseSpectrogram InverseSpectrogram(
long n_fft = 400,
long? win_length = null,
long? hop_length = null,
Expand All @@ -301,6 +97,7 @@ public static ITransform InverseSpectrogram(
bool onesided = true)
{
return new InverseSpectrogram(
"InverseSpectrogram",
n_fft: n_fft,
hop_length: hop_length,
win_length: win_length,
Expand All @@ -326,7 +123,7 @@ public static ITransform InverseSpectrogram(
/// <param name="dtype">The scalar type</param>
/// <returns>The resampled waveform</returns>
/// <exception cref="ArgumentOutOfRangeException"></exception>
public static ITransform Resample(
public static Resample Resample(
int orig_freq = 16000,
int new_freq = 16000,
ResamplingMethod resampling_method = ResamplingMethod.sinc_interpolation,
Expand All @@ -337,6 +134,7 @@ public static ITransform Resample(
torch.ScalarType? dtype = null)
{
return new Resample(
"Resample",
orig_freq,
new_freq,
resampling_method,
Expand All @@ -346,6 +144,45 @@ public static ITransform Resample(
device,
dtype);
}

/// <summary>
/// Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
/// </summary>
/// <param name="n_fft">Size of FFT, creates ``n_fft // 2 + 1`` bins.</param>
/// <param name="n_iter">Number of iteration for phase recovery process.</param>
/// <param name="win_length">Window size.</param>
/// <param name="hop_length">Length of hop between STFT windows.</param>
/// <param name="window_fn">A function to create a window tensor
/// that is applied/multiplied to each frame/window.</param>
/// <param name="power">Exponent for the magnitude spectrogram,
/// (must be > 0) e.g., 1 for energy, 2 for power, etc.</param>
/// <param name="momentum">The momentum parameter for fast Griffin-Lim.</param>
/// <param name="length">Array length of the expected output.</param>
/// <param name="rand_init">Initializes phase randomly if True and to zero otherwise.</param>
/// <returns></returns>
public static GriffinLim GriffinLim(
int n_fft = 400,
int n_iter = 32,
long? win_length = null,
long? hop_length = null,
WindowFunction window_fn = null,
double power = 2.0,
double momentum = 0.99,
int? length = null,
bool rand_init = true)
{
return new GriffinLim(
"GriffinLim",
n_fft,
n_iter,
win_length,
hop_length,
window_fn,
power,
momentum,
length,
rand_init);
}
}
}
}
Loading