LogMelSpectogram shifted to FeatureExtractor
susnato committed Mar 20, 2023
1 parent 61c2fbf commit e5d2aec
Showing 6 changed files with 61 additions and 127 deletions.
6 changes: 4 additions & 2 deletions src/transformers/models/pop2piano/__init__.py
@@ -32,7 +32,7 @@
 }

 try:
-    if not is_torch_available() and not is_torchaudio_available():
+    if not is_torch_available():
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
     pass
@@ -52,6 +52,7 @@
         and is_soundfile_availble()
         and is_tf_available()
         and is_torch_available()
+        and is_torchaudio_available()
     ):
         raise OptionalDependencyNotAvailable()
 except OptionalDependencyNotAvailable:
@@ -64,7 +65,7 @@
     from .configuration_pop2piano import POP2PIANO_PRETRAINED_CONFIG_ARCHIVE_MAP, Pop2PianoConfig

     try:
-        if not is_torch_available() and not is_torchaudio_available():
+        if not is_torch_available():
             raise OptionalDependencyNotAvailable()
     except OptionalDependencyNotAvailable:
         pass
@@ -84,6 +85,7 @@
             and is_soundfile_availble()
             and is_tf_available()
             and is_torch_available()
+            and is_torchaudio_available()
         ):
             raise OptionalDependencyNotAvailable()
     except OptionalDependencyNotAvailable:
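
After this change, importing the Pop2Piano model only requires torch, while the feature extractor additionally requires torchaudio (on top of the essentia/librosa/pretty_midi/scipy/soundfile/tf stack checked above). A minimal sketch (not part of the commit) of probing this at runtime with the same availability helpers:

from transformers.utils import is_torch_available, is_torchaudio_available

# Model import path: torch alone is enough after this commit.
can_use_model = is_torch_available()

# Feature-extractor path: torchaudio is now checked as well.
can_use_feature_extractor = is_torch_available() and is_torchaudio_available()

print(can_use_model, can_use_feature_extractor)
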
25 changes: 0 additions & 25 deletions src/transformers/models/pop2piano/configuration_pop2piano.py
@@ -98,18 +98,6 @@ class Pop2PianoConfig(PretrainedConfig):
             Determines `max_length` for transformer `generate` function along with `dataset_n_bars`.
         dataset_n_bars (`int`, *optional*, defaults to 2):
             Determines `max_length` for transformer `generate` function along with `dataset_target_length`.
-        dataset_sampling_rate (`int` *optional*, defaults to 22050):
-            Sample rate of audio signal.
-        dataset_mel_is_conditioned (`bool`, *optional*, defaults to `True`):
-            Whether to use `ConcatEmbeddingToMel` or not.
-        n_fft (`int`, *optional*, defaults to 4096):
-            Size of Fast Fourier Transform, creates n_fft // 2 + 1 bins.
-        hop_length (`int`, *optional*, defaults to 1024):
-            Length of hop between Short-Time Fourier Transform windows.
-        f_min (`float`, *optional*, defaults to 10.0):
-            Minimum frequency.
-        n_mels (`int`, *optional*, defaults to 512):
-            Number of mel filterbanks.
     """

     model_type = "pop2piano"
@@ -138,12 +126,6 @@ def __init__(
         dense_act_fn="relu",
         dataset_target_length=256,
         dataset_n_bars=2,
-        dataset_sampling_rate=22050,
-        dataset_mel_is_conditioned=True,
-        n_fft=4096,
-        hop_length=1024,
-        f_min=10.0,
-        n_mels=512,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -166,15 +148,8 @@
         self.is_gated_act = act_info[0] == "gated"
         self.composer_to_feature_token = COMPOSER_TO_FEATURE_TOKEN

-        self.dataset_mel_is_conditioned = dataset_mel_is_conditioned
         self.dataset_target_length = dataset_target_length
         self.dataset_n_bars = dataset_n_bars
-        self.dataset_sampling_rate = dataset_sampling_rate
-
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.f_min = f_min
-        self.n_mels = n_mels

         super().__init__(
             pad_token_id=pad_token_id,
39 changes: 38 additions & 1 deletion src/transformers/models/pop2piano/feature_extraction_pop2piano.py
@@ -26,6 +26,7 @@
 import scipy
 import soundfile as sf
 import torch
+import torchaudio
 from torch.nn.utils.rnn import pad_sequence

 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
@@ -75,6 +76,14 @@ class also contains postprocessing methods to convert model outputs to midi audi
         vocab_size_time (`int`, *optional*, defaults to 100):
             This represents the number of Beat Shifts. Beat Shift [100 values] Indicates the relative time shift within
             the segment quantized into 8th-note beats(half-beats).
+        n_fft (`int`, *optional*, defaults to 4096):
+            Size of Fast Fourier Transform, creates n_fft // 2 + 1 bins.
+        hop_length (`int`, *optional*, defaults to 1024):
+            Length of hop between Short-Time Fourier Transform windows.
+        f_min (`float`, *optional*, defaults to 10.0):
+            Minimum frequency.
+        n_mels (`int`, *optional*, defaults to 512):
+            Number of mel filterbanks.
     """
     model_input_names = ["input_features"]

@@ -88,6 +97,10 @@ def __init__(
         vocab_size_note: int = 128,
         vocab_size_velocity: int = 2,
         vocab_size_time: int = 100,
+        n_fft: int = 4096,
+        hop_length: int = 1024,
+        f_min: float = 10.0,
+        n_mels: int = 512,
         feature_size=None,
         **kwargs,
     ):
@@ -105,6 +118,27 @@ def __init__(
         self.vocab_size_note = vocab_size_note
         self.vocab_size_velocity = vocab_size_velocity
         self.vocab_size_time = vocab_size_time
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.f_min = f_min
+        self.n_mels = n_mels
+
+    def log_mel_spectogram(self, sequence):
+        """Generates MelSpectrogram then applies log base e."""
+
+        melspectrogram = torchaudio.transforms.MelSpectrogram(
+            sample_rate=self.sampling_rate,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            f_min=self.f_min,
+            n_mels=self.n_mels,
+        )
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                X = melspectrogram(sequence)
+                X = X.clamp(min=1e-6).log()
+
+        return X

     def extract_rhythm(self, raw_audio):
         """
@@ -265,8 +299,11 @@ def __call__(
             beatstep=beatsteps - beatsteps[0],
             n_bars=self.n_bars,
         )
-        batch = batch.cpu().numpy()

+        # Apply LogMelSpectogram
+        batch = self.log_mel_spectogram(batch).transpose(-1, -2)
+
+        batch = batch.cpu().numpy()
         output = BatchFeature(
             {
                 "input_features": batch,
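
For reference, a small sketch (not part of the commit) of the shapes this method produces with the defaults documented above: n_fft=4096 yields 4096 // 2 + 1 = 2049 frequency bins, the filterbank reduces those to 512 mel bins, and with torchaudio's default center=True padding the frame count is len(audio) // hop_length + 1.

import torch
import torchaudio

# Defaults taken from the feature extractor above.
sampling_rate, n_fft, hop_length, f_min, n_mels = 22050, 4096, 1024, 10.0, 512

melspectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sampling_rate, n_fft=n_fft, hop_length=hop_length, f_min=f_min, n_mels=n_mels
)
audio = torch.randn(10, 100000)                        # (batch, num_samples)
log_mel = melspectrogram(audio).clamp(min=1e-6).log()  # (batch, n_mels, frames)

print(log_mel.shape)  # torch.Size([10, 512, 98]), since 100000 // 1024 + 1 == 98
# __call__ then transposes to (batch, frames, n_mels) before building the BatchFeature.
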
51 changes: 9 additions & 42 deletions src/transformers/models/pop2piano/modeling_pop2piano.py
@@ -22,7 +22,6 @@

 import numpy as np
 import torch
-import torchaudio
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from torch.utils.checkpoint import checkpoint
@@ -216,28 +215,6 @@ def _shift_right(self, input_ids):
         return shifted_input_ids


-class LogMelSpectrogram(nn.Module):
-    """Generates MelSpectrogram then applies log base e."""
-
-    def __init__(self, sampling_rate, n_fft, hop_length, f_min, n_mels):
-        super(LogMelSpectrogram, self).__init__()
-        self.melspectrogram = torchaudio.transforms.MelSpectrogram(
-            sample_rate=sampling_rate,
-            n_fft=n_fft,
-            hop_length=hop_length,
-            f_min=f_min,
-            n_mels=n_mels,
-        )
-
-    def forward(self, x):
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                X = self.melspectrogram(x)
-                X = X.clamp(min=1e-6).log()
-
-        return X
-
-
 class ConcatEmbeddingToMel(nn.Module):
     """Embedding Matrix for `composer` tokens."""

@@ -1086,20 +1063,12 @@ def __init__(self, config: Pop2PianoConfig):
         self.config = config
         self.model_dim = config.d_model

-        self.spectrogram = LogMelSpectrogram(
-            sampling_rate=config.dataset_sampling_rate,
-            n_fft=config.n_fft,
-            hop_length=config.hop_length,
-            f_min=config.f_min,
-            n_mels=config.n_mels,
+        n_dim = 512
+        composer_n_vocab = len(config.composer_to_feature_token)
+        embedding_offset = min(config.composer_to_feature_token.values())
+        self.mel_conditioner = ConcatEmbeddingToMel(
+            embedding_offset=embedding_offset, n_vocab=composer_n_vocab, n_dim=n_dim
         )
-        if config.dataset_mel_is_conditioned:
-            n_dim = 512
-            composer_n_vocab = len(config.composer_to_feature_token)
-            embedding_offset = min(config.composer_to_feature_token.values())
-            self.mel_conditioner = ConcatEmbeddingToMel(
-                embedding_offset=embedding_offset, n_vocab=composer_n_vocab, n_dim=n_dim
-            )

         self.shared = nn.Embedding(config.vocab_size, config.d_model)

@@ -1390,12 +1359,10 @@ def generate(
             else max_length
         )

-        inputs_embeds = self.spectrogram(input_features["input_features"]).transpose(-1, -2)
-        if self.config.dataset_mel_is_conditioned:
-            composer_value = composer_to_feature_token[composer]
-            composer_value = torch.tensor(composer_value, device=self.device)
-            composer_value = composer_value.repeat(inputs_embeds.shape[0])
-            inputs_embeds = self.mel_conditioner(inputs_embeds, composer_value)
+        composer_value = composer_to_feature_token[composer]
+        composer_value = torch.tensor(composer_value, device=self.device)
+        composer_value = composer_value.repeat(input_features["input_features"].shape[0])
+        inputs_embeds = self.mel_conditioner(input_features["input_features"], composer_value)

         return super().generate(
             inputs,
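
With LogMelSpectrogram removed from the model, generate() now expects features that the feature extractor has already log-mel-transformed, and only applies the composer embedding itself. A hedged sketch of the new calling convention (shapes and the dev checkpoint name are taken from the tests below; composer is the parameter visible in this diff, and the heavy calls are left commented out):

import torch
from transformers.feature_extraction_utils import BatchFeature

# Precomputed log-mel features, shaped (batch, frames, n_mels) after the
# feature extractor's transpose(-1, -2).
input_features = BatchFeature({"input_features": torch.ones([2, 66, 512])})

# model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
# outputs = model.generate(input_features=input_features, composer="composer1")
# outputs.size(0) == 2  # one generated token sequence per batch item
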
42 changes: 7 additions & 35 deletions tests/models/pop2piano/test_feature_extraction_pop2piano.py
@@ -146,11 +146,14 @@ def test_call(self):
         speech_input = np.zeros(
             [
                 1000000,
-            ]
+            ],
+            dtype=np.float32,
         )

         input_features = feature_extractor(speech_input, audio_sr=16_000, return_tensors="np")
-        self.assertTrue(input_features.input_features.ndim == 2)
+        self.assertTrue(input_features.input_features.ndim == 3)
+        self.assertEqual(input_features.input_features.shape[-1], 512)
+
         self.assertTrue(input_features.beatsteps.ndim == 1)
         self.assertTrue(input_features.ext_beatstep.ndim == 1)

@@ -163,44 +166,13 @@ def _load_datasamples(self, num_samples):

     def test_integration(self):
         EXPECTED_INPUT_FEATURES = torch.tensor(
-            [
-                -4.5434e-05,
-                -1.8900e-04,
-                -2.2150e-04,
-                -2.1844e-04,
-                -2.7647e-04,
-                -2.1334e-04,
-                -1.5305e-04,
-                -2.6124e-04,
-                -2.6863e-04,
-                -1.5969e-04,
-                -1.6224e-04,
-                -1.2900e-04,
-                -9.9139e-06,
-                1.5336e-05,
-                4.7507e-05,
-                9.3454e-05,
-                -2.3652e-05,
-                -1.2942e-04,
-                -1.0804e-04,
-                -1.4267e-04,
-                -1.5102e-04,
-                -6.7488e-05,
-                -9.6527e-05,
-                -9.6909e-05,
-                8.0032e-05,
-                8.1948e-05,
-                -7.3148e-05,
-                3.4405e-05,
-                1.5065e-04,
-                -1.0989e-04,
-            ]
+            [[-7.1493, -6.8701, -4.3214], [-5.9473, -5.7548, -3.8438], [-6.1324, -5.9018, -4.3778]]
         )

         input_speech, sampling_rate = self._load_datasamples(1)
         feaure_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
         input_features = feaure_extractor(input_speech, audio_sr=sampling_rate[0], return_tensors="pt").input_features
-        self.assertTrue(torch.allclose(input_features[0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))
+        self.assertTrue(torch.allclose(input_features[0, :3, :3], EXPECTED_INPUT_FEATURES, atol=1e-4))

     @unittest.skip("Pop2PianoFeatureExtractor does not return attention_mask")
     def test_attention_mask(self):
25 changes: 3 additions & 22 deletions tests/models/pop2piano/test_modeling_pop2piano.py
@@ -637,25 +637,6 @@ def test_generate_with_past_key_values(self):
 @require_torch
 @require_torchaudio
 class Pop2PianoModelIntegrationTests(unittest.TestCase):
-    @slow
-    def test_log_mel_spectrogram_integration(self):
-        model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
-        inputs = torch.ones([10, 100000])
-        output = model.spectrogram(inputs)
-
-        # check shape
-        self.assertEqual(output.size(), torch.Size([10, 512, 98]))
-
-        # check values
-        self.assertEqual(
-            output[0, :3, :3].cpu().numpy().tolist(),
-            [
-                [-13.815510749816895, -13.815510749816895, -13.815510749816895],
-                [-13.815510749816895, -13.815510749816895, -13.815510749816895],
-                [-13.815510749816895, -13.815510749816895, -13.815510749816895],
-            ],
-        )
-
     @slow
     def test_mel_conditioner_integration(self):
         composer = "composer1"
@@ -680,11 +661,11 @@ def test_mel_conditioner_integration(self):
     def test_full_model_integration(self):
         model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
         model.eval()
-        input_features = BatchFeature({"input_features": torch.ones([100, 100000])})
+        input_features = BatchFeature({"input_features": torch.ones([75, 66, 512])})
         outputs = model.generate(input_features=input_features)

         # check for shapes
-        self.assertEqual(outputs.size(0), 100)
+        self.assertEqual(outputs.size(0), 75)

         # check for values
-        self.assertEqual(outputs[0, :3].detach().cpu().numpy().tolist(), [0, 134, 133])
+        self.assertEqual(outputs[0, :3].detach().cpu().numpy().tolist(), [0, 1])
