tokenizer changed to only return midi_object and other changes
susnato committed Apr 10, 2023
1 parent d86224d commit 941255a
Showing 8 changed files with 244 additions and 350 deletions.
63 changes: 39 additions & 24 deletions docs/source/en/model_doc/pop2piano.mdx
@@ -43,32 +43,47 @@ Tips:
This model was contributed by [Susnato Dhar](https://huggingface.co/susnato).
The original code can be found [here](https://github.com/sweetcocoa/pop2piano).

-Example:
-```
-import librosa
-from transformers import Pop2PianoFeatureExtractor, Pop2PianoForConditionalGeneration, Pop2PianoTokenizer
-
-raw_audio, sr = librosa.load("audio.mp3", sr=44100)
-model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
-feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
-tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
-
-model.eval()
-
-feature_extractor_outputs = feature_extractor(raw_audio=raw_audio, audio_sr=sr, return_tensors="pt")
-model_outputs = model.generate(feature_extractor_outputs, composer="composer1")
-
-opt_postprocess = tokenizer(relative_tokens=model_outputs,
-                            beatsteps=feature_extractor_outputs["beatsteps"],
-                            ext_beatstep=feature_extractor_outputs["ext_beatstep"],
-                            raw_audio=raw_audio,
-                            sampling_rate=sr,
-                            save_path="./Music/Outputs/",
-                            audio_file_name="filename",
-                            save_midi=True
-)
-```
+### Example using HuggingFace Dataset
+```python
+>>> from datasets import load_dataset
+>>> from transformers import Pop2PianoForConditionalGeneration, Pop2PianoTokenizer, Pop2PianoFeatureExtractor
+>>> model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev").to("cuda")
+>>> model.eval()
+>>> feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
+>>> tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
+>>> ds = load_dataset("sweetcocoa/pop2piano_ci", split="test")
+>>> fe_output = feature_extractor(ds["audio"][0]["array"], audio_sr=ds["audio"][0]["sampling_rate"]).to("cuda")
+>>> model_output = model.generate(fe_output, composer="composer1")
+>>> tokenizer_output = tokenizer(
+...     relative_tokens=model_output.cpu(),
+...     beatsteps=fe_output["beatsteps"].cpu(),
+...     ext_beatstep=fe_output["ext_beatstep"].cpu(),
+... )
+>>> tokenizer_output.write("./Outputs/midi_output.mid")
+```
+### Example using your own audio
+```python
+>>> import librosa
+>>> from transformers import Pop2PianoFeatureExtractor, Pop2PianoForConditionalGeneration, Pop2PianoTokenizer
+>>> raw_audio, sr = librosa.load("<your_audio_file_here>", sr=44100)
+>>> model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev").to("cuda")
+>>> model.eval()
+>>> feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
+>>> tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
+>>> fe_output = feature_extractor(raw_audio, audio_sr=sr).to("cuda")
+>>> model_output = model.generate(fe_output, composer="composer1")
+>>> tokenizer_output = tokenizer(
+...     relative_tokens=model_output.cpu(),
+...     beatsteps=fe_output["beatsteps"].cpu(),
+...     ext_beatstep=fe_output["ext_beatstep"].cpu(),
+... )
+>>> tokenizer_output.write("./Outputs/midi_output.mid")
+```
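
Both examples end by writing a MIDI file. Since the tokenizer now returns the MIDI object itself, it can also be rendered to audio directly. A minimal sketch (not part of this commit), assuming `tokenizer_output` behaves like a `pretty_midi.PrettyMIDI` object, as its `.write()` call above suggests:

```python
>>> import soundfile as sf
>>> audio = tokenizer_output.synthesize(fs=44100)  # pretty_midi's built-in sine-wave rendering
>>> sf.write("./Outputs/midi_output.wav", audio, samplerate=44100)
```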

## Pop2PianoConfig

42 changes: 22 additions & 20 deletions src/transformers/models/pop2piano/__init__.py
@@ -20,16 +20,15 @@
is_librosa_available,
is_pretty_midi_available,
is_scipy_available,
-    is_soundfile_availble,
is_torch_available,
-    is_torchaudio_available,
)


_import_structure = {
"configuration_pop2piano": ["POP2PIANO_PRETRAINED_CONFIG_ARCHIVE_MAP", "Pop2PianoConfig"],
}

+# Model
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
@@ -42,27 +41,29 @@
"Pop2PianoPreTrainedModel",
]

+# Feature Extractor
try:
-    if not (
-        is_librosa_available()
-        and is_essentia_available()
-        and is_scipy_available()
-        and is_pretty_midi_available()
-        and is_soundfile_availble()
-        and is_torch_available()
-        and is_torchaudio_available()
-    ):
+    if not (is_librosa_available() and is_essentia_available() and is_scipy_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["feature_extraction_pop2piano"] = ["Pop2PianoFeatureExtractor"]

+# Tokenizer
+try:
+    if not (is_pretty_midi_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    pass
+else:
+    _import_structure["tokenization_pop2piano"] = ["Pop2PianoTokenizer"]


if TYPE_CHECKING:
from .configuration_pop2piano import POP2PIANO_PRETRAINED_CONFIG_ARCHIVE_MAP, Pop2PianoConfig

+    # Model
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
@@ -75,21 +76,22 @@
Pop2PianoPreTrainedModel,
)

+    # Feature Extractor
try:
-        if not (
-            is_librosa_available()
-            and is_essentia_available()
-            and is_scipy_available()
-            and is_pretty_midi_available()
-            and is_soundfile_availble()
-            and is_torch_available()
-            and is_torchaudio_available()
-        ):
+        if not (is_librosa_available() and is_essentia_available() and is_scipy_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .feature_extraction_pop2piano import Pop2PianoFeatureExtractor

+    # Tokenizer
+    try:
+        if not (is_pretty_midi_available() and is_torch_available()):
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        pass
+    else:
+        from .tokenization_pop2piano import Pop2PianoTokenizer

else:
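
The guarded imports above mirror a check callers can do themselves. A minimal sketch (not part of this commit; the top-level `Pop2PianoTokenizer` export is an assumption about the merged PR):

```python
# Hypothetical availability probe mirroring the guards in __init__.py above.
from transformers.utils import is_pretty_midi_available, is_torch_available

if is_pretty_midi_available() and is_torch_available():
    from transformers import Pop2PianoTokenizer  # assumed public export
else:
    raise ImportError("Pop2PianoTokenizer requires both torch and pretty_midi.")
```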
41 changes: 29 additions & 12 deletions src/transformers/models/pop2piano/feature_extraction_pop2piano.py
Expand Up @@ -23,9 +23,9 @@
import numpy as np
import scipy
import torch
-import torchaudio
from torch.nn.utils.rnn import pad_sequence

+from ...audio_utils import fram_wave, get_mel_filter_banks, stft
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...utils import TensorType, logging
@@ -93,19 +93,34 @@ def __init__(
def log_mel_spectogram(self, sequence):
"""Generates MelSpectrogram then applies log base e."""

-        melspectrogram = torchaudio.transforms.MelSpectrogram(
+        mel_fb = get_mel_filter_banks(
+            nb_frequency_bins=(self.n_fft // 2) + 1,
+            nb_mel_filters=self.n_mels,
+            frequency_min=self.f_min,
+            frequency_max=float(self.sampling_rate // 2),
             sample_rate=self.sampling_rate,
-            n_fft=self.n_fft,
-            hop_length=self.hop_length,
-            f_min=self.f_min,
-            n_mels=self.n_mels,
             norm=None,
             mel_scale="htk",
+        ).astype(np.float32)

+        spectrogram = []
+        for seq in sequence:
+            window = np.hanning(self.n_fft + 1)[:-1]
+            framed_audio = fram_wave(seq, self.hop_length, self.n_fft)
+            spec = stft(framed_audio, window, fft_window_size=self.n_fft)
+            spec = np.abs(spec) ** 2.0
+            spectrogram.append(spec)

+        spec_shape = spec.shape
+        spectrogram = torch.Tensor(spectrogram).view(-1, *spec_shape)
+        log_melspec = (
+            torch.matmul(spectrogram.transpose(-1, -2), torch.from_numpy(mel_fb))
+            .transpose(-1, -2)
+            .clamp(min=1e-6)
+            .log()
+        )
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                X = melspectrogram(sequence)
-                X = X.clamp(min=1e-6).log()
-        return X
+        return log_melspec
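
The rewrite is meant to reproduce what the old `torchaudio.transforms.MelSpectrogram` path computed, just with in-repo utilities. A minimal sketch of that equivalence (not part of the commit; the hyperparameter values are illustrative, not Pop2Piano's actual config):

```python
import torch
import torchaudio

sr, n_fft, hop, n_mels = 22050, 2048, 512, 128  # illustrative values
wave = torch.randn(sr)

# Reference: the transform the old code used.
ref = torchaudio.transforms.MelSpectrogram(
    sample_rate=sr, n_fft=n_fft, hop_length=hop, n_mels=n_mels, norm=None, mel_scale="htk"
)(wave)

# By hand: power spectrogram projected onto an htk, norm-free mel filter bank.
window = torch.hann_window(n_fft)
spec = torch.stft(wave, n_fft, hop_length=hop, window=window, return_complex=True).abs() ** 2
fb = torchaudio.functional.melscale_fbanks(n_fft // 2 + 1, 0.0, sr / 2, n_mels, sr, norm=None, mel_scale="htk")
manual = fb.transpose(0, 1) @ spec

print(torch.allclose(ref, manual, rtol=1e-4, atol=1e-6))  # expected: True
```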

def extract_rhythm(self, raw_audio):
"""
@@ -238,6 +253,7 @@ def __call__(
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
"""
warnings.warn("Please make sure to have the audio sampling_rate as 44100, to get the optimal performence!")
warnings.warn(
"Pop2PianoFeatureExtractor only takes one raw_audio at a time, if you want to extract features from more than a single audio then you might need to call it multiple times."
)
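
Given that constraint, batching happens on the caller's side. A minimal sketch (not from the commit), reusing the `ds` and `feature_extractor` names from the doc examples above:

```python
# One feature-extractor call per audio clip; each call returns its own feature set.
features = [
    feature_extractor(example["array"], audio_sr=example["sampling_rate"])
    for example in ds["audio"]
]
```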
@@ -250,10 +266,11 @@
beatsteps = self.interpolate_beat_times(beat_times, steps_per_beat, extend=True)

if self.sampling_rate != audio_sr and self.sampling_rate is not None:
-            # Change `raw_audio_sr` to `self.sampling_rate`
+            # Change audio_sr to self.sampling_rate
raw_audio = librosa.core.resample(
raw_audio, orig_sr=audio_sr, target_sr=self.sampling_rate, res_type="kaiser_best"
)

audio_sr = self.sampling_rate
start_sample = int(beatsteps[0] * audio_sr)
end_sample = int(beatsteps[-1] * audio_sr)
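
Callers can avoid this resampling branch entirely by loading audio at the extractor's rate up front. A minimal sketch (an assumption about usage, not commit code; the file name is hypothetical):

```python
# Load at 44100 Hz directly so the extractor's internal resample is a no-op.
import librosa

raw_audio, sr = librosa.load("song.mp3", sr=44100)
fe_output = feature_extractor(raw_audio, audio_sr=sr)
```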
