Skip to content

Commit

Permalink
import solved
Browse files Browse the repository at this point in the history
  • Loading branch information
susnato committed Mar 23, 2023
1 parent a82ada7 commit 4d9fcc3
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 299 deletions.
25 changes: 23 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
is_bitsandbytes_available,
is_flax_available,
is_keras_nlp_available,
is_music_available,
is_sentencepiece_available,
is_speech_available,
is_tensorflow_text_available,
Expand Down Expand Up @@ -415,7 +416,6 @@
"models.pop2piano": [
"POP2PIANO_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Pop2PianoConfig",
"Pop2PianoFeatureExtractor",
],
"models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"],
"models.qdqbert": ["QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "QDQBertConfig"],
Expand Down Expand Up @@ -3431,6 +3431,20 @@
_import_structure["trainer_tf"] = ["TFTrainer"]


# music-backed objects
# Pop2PianoFeatureExtractor is only registered for lazy import when `is_music_available()` returns True.
# Otherwise every public name of `utils.dummy_music_objects` is registered in its place.
# NOTE(review): which packages `is_music_available()` checks is not visible here - confirm in `utils`.
try:
if not is_music_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_music_objects

# Expose every non-underscore name defined by the dummy module.
_import_structure["utils.dummy_music_objects"] = [
name for name in dir(dummy_music_objects) if not name.startswith("_")
]
else:
# Backend available: add the real feature extractor to the existing pop2piano entry.
_import_structure["models.pop2piano"].append("Pop2PianoFeatureExtractor")


# FLAX-backed objects
try:
if not is_flax_available():
Expand Down Expand Up @@ -4055,7 +4069,6 @@
from .models.pop2piano import (
POP2PIANO_PRETRAINED_CONFIG_ARCHIVE_MAP,
Pop2PianoConfig,
Pop2PianoFeatureExtractor,
)
from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer
from .models.qdqbert import QDQBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, QDQBertConfig
Expand Down Expand Up @@ -6535,6 +6548,14 @@
# Trainer
from .trainer_tf import TFTrainer

# music-backed objects: import the real Pop2PianoFeatureExtractor only when `is_music_available()` is True;
# otherwise import everything exported by `utils.dummy_music_objects` instead.
try:
if not is_music_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_music_objects import *
else:
from .models.pop2piano import Pop2PianoFeatureExtractor

try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
Expand Down
298 changes: 1 addition & 297 deletions src/transformers/models/pop2piano/feature_extraction_pop2piano.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,14 @@
# limitations under the License.
""" Feature extractor class for Pop2Piano"""

import os
import warnings
from typing import List, Optional, Union

import essentia
import essentia.standard
import librosa
import numpy as np
import pretty_midi
import scipy
import soundfile as sf
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
Expand All @@ -36,17 +33,6 @@

logger = logging.get_logger(__name__)

# Token-type tags: `decode` returns one of these as the first element of its result and
# `relative_tokens_to_notes` dispatches on them.
TOKEN_SPECIAL: int = 0
TOKEN_NOTE: int = 1
TOKEN_VELOCITY: int = 2
TOKEN_TIME: int = 3

# Velocity written into every note produced by `relative_tokens_to_notes`.
DEFAULT_VELOCITY: int = 77

# Values of special tokens. `EOS` stops decoding in `relative_tokens_to_notes`.
# NOTE(review): `TIE` and `PAD` are not referenced by the code visible here - confirm their users.
TIE: int = 2
EOS: int = 1
PAD: int = 0


class Pop2PianoFeatureExtractor(SequenceFeatureExtractor):
r"""
Expand All @@ -56,8 +42,7 @@ class Pop2PianoFeatureExtractor(SequenceFeatureExtractor):
most of the main methods. Users should refer to this superclass for more information regarding those methods.
Args:
This class loads audio, extracts rhythm and does preprocesses before being passed through `LogMelSpectrogram`. This:
class also contains postprocessing methods to convert model outputs to midi audio and stereo-mix.
This class extracts rhythm and does preprocesses before being passed through the transformer model.
n_bars (`int`, *optional*, defaults to 2):
Determines `n_steps` in method `preprocess_mel`.
sampling_rate (`int`, *optional*, defaults to 22050):
Expand Down Expand Up @@ -316,284 +301,3 @@ def __call__(
output = output.convert_to_tensors(return_tensors)

return output

def decode(self, token, time_idx_offset):
    """
    Map a single token id produced by the transformer to a `[token_type, value]` pair.

    The id space is laid out as consecutive ranges: special tokens, then note tokens, then velocity tokens, then
    time tokens. The returned value is the offset of `token` inside its range; for time tokens
    `time_idx_offset` is added on top and the value is not cast to `int`.

    Args:
        token: Token id to decode.
        time_idx_offset: Offset added to the value of time tokens only.

    Returns:
        `list`: `[token_type, value]` where `token_type` is one of `TOKEN_SPECIAL`, `TOKEN_NOTE`,
        `TOKEN_VELOCITY` or `TOKEN_TIME`.
    """
    # First id of each range, in ascending order.
    note_start = self.vocab_size_special
    velocity_start = self.vocab_size_special + self.vocab_size_note
    time_start = self.vocab_size_special + self.vocab_size_note + self.vocab_size_velocity

    # Test from the highest range down so the first match wins.
    if token >= time_start:
        return [TOKEN_TIME, (token - time_start) + time_idx_offset]
    if token >= velocity_start:
        return [TOKEN_VELOCITY, int(token - velocity_start)]
    if token >= note_start:
        return [TOKEN_NOTE, int(token - note_start)]
    return [TOKEN_SPECIAL, int(token)]

def relative_batch_tokens_to_midi(
    self,
    tokens,
    beatstep,
    beat_offset_idx=None,
    bars_per_batch=None,
    cutoff_time_idx=None,
):
    """
    Convert a batch of relative token sequences to notes and a midi object.

    Each sequence in `tokens` is decoded with `relative_tokens_to_notes`, using a start index that advances by
    `bars_per_batch * 4` per sequence. The decoded notes of all sequences are concatenated and passed to
    `notes_to_midi`.

    Args:
        tokens: Sequence of token sequences, one per batch element.
        beatstep: Indexable collection of beat times; `beatstep[beat_offset_idx]` is used as the time offset
            of the generated midi.
        beat_offset_idx (`int`, *optional*): Start index of the first sequence. Defaults to 0.
        bars_per_batch (`int`, *optional*): Number of bars covered by one sequence. Defaults to 2.
        cutoff_time_idx (`int`, *optional*): Cutoff index relative to the start of each sequence. When `None`
            no cutoff is forwarded to `relative_tokens_to_notes`.

    Returns:
        `tuple`: `(midi, notes)` where `midi` is the object returned by `notes_to_midi` and `notes` is the
        concatenated note array, or an empty list when no sequence produced any note.
    """
    beat_offset_idx = 0 if beat_offset_idx is None else beat_offset_idx
    bars_per_batch = 2 if bars_per_batch is None else bars_per_batch

    note_chunks = []
    for batch_idx, batch_tokens in enumerate(tokens):
        # NOTE(review): the factor 4 is presumably the number of beats per bar - confirm.
        start_idx = beat_offset_idx + batch_idx * bars_per_batch * 4
        # `cutoff_time_idx` is optional: only shift it when it was actually given, otherwise
        # `None + int` would raise a TypeError for the documented default.
        batch_cutoff = None if cutoff_time_idx is None else cutoff_time_idx + start_idx
        batch_notes = self.relative_tokens_to_notes(
            batch_tokens,
            start_idx=start_idx,
            cutoff_time_idx=batch_cutoff,
        )
        # Sequences that decode to no notes are skipped.
        if len(batch_notes) > 0:
            note_chunks.append(batch_notes)

    # Concatenate once instead of re-allocating the growing array on every iteration.
    notes = np.concatenate(note_chunks, axis=0) if note_chunks else []
    midi = self.notes_to_midi(notes, beatstep, offset_sec=beatstep[beat_offset_idx])
    return midi, notes

def relative_tokens_to_notes(self, tokens, start_idx, cutoff_time_idx=None):
"""
Decode one sequence of relative tokens into notes.

Tokens are decoded with `self.decode` and replayed in order: time tokens advance the current index
(clamped to `cutoff_time_idx` when given), velocity tokens set the current velocity, and note tokens open
a note (non-zero velocity) or close it (zero velocity). Decoding stops at the `EOS` special token. Notes
still open at the end are closed automatically.

Args:
    tokens: Token ids of a single sequence.
    start_idx: Index at which the sequence starts; `.item()` is called on it when available.
    cutoff_time_idx (*optional*): Upper bound for the current index.

Returns:
    An empty list when no note was decoded, otherwise an `np.ndarray` whose rows are
    `[onset_idx, offset_idx, pitch, DEFAULT_VELOCITY]`, ordered by `onset_idx * 128 + offset_idx`.
"""
# decoding If the first token is an arranger
# i.e. drop the first token when its id lies at or beyond the end of the time-token range
if tokens[0] >= (
self.vocab_size_special + self.vocab_size_note + self.vocab_size_velocity + self.vocab_size_time
):
tokens = tokens[1:]

# Each word is a `[token_type, value]` pair.
words = [self.decode(token, time_idx_offset=0) for token in tokens]

if hasattr(start_idx, "item"):
"""if numpy or torch tensor"""
start_idx = start_idx.item()

current_idx = start_idx
current_velocity = 0
# One slot per pitch holding the onset index of the currently open note, or None.
note_onsets_ready = [None for i in range(self.vocab_size_note + 1)]
notes = []
for type, number in words:
if type == TOKEN_SPECIAL:
# Only EOS has an effect here; other special tokens are ignored.
if number == EOS:
break
elif type == TOKEN_TIME:
current_idx += number
if cutoff_time_idx is not None:
current_idx = min(current_idx, cutoff_time_idx)

elif type == TOKEN_VELOCITY:
current_velocity = number
elif type == TOKEN_NOTE:
pitch = number
# A velocity of 0 marks a note-off, any other velocity a note-on.
if current_velocity == 0:
# note_offset
if note_onsets_ready[pitch] is None:
# offset without onset
pass
else:
onset_idx = note_onsets_ready[pitch]
if onset_idx >= current_idx:
# No time shift after previous note_on
pass
else:
offset_idx = current_idx
notes.append([onset_idx, offset_idx, pitch, DEFAULT_VELOCITY])
note_onsets_ready[pitch] = None
else:
# note_on
if note_onsets_ready[pitch] is None:
note_onsets_ready[pitch] = current_idx
else:
# note-on already exists
onset_idx = note_onsets_ready[pitch]
if onset_idx >= current_idx:
# No time shift after previous note_on
pass
else:
# Close the already open note at the current index and reopen it.
offset_idx = current_idx
notes.append([onset_idx, offset_idx, pitch, DEFAULT_VELOCITY])
note_onsets_ready[pitch] = current_idx
else:
# Unknown token type returned by `decode`.
raise ValueError

for pitch, note_on in enumerate(note_onsets_ready):
# force offset if no offset for each pitch
if note_on is not None:
# The forced note lasts at least one step past its onset.
if cutoff_time_idx is None:
cutoff = note_on + 1
else:
cutoff = max(cutoff_time_idx, note_on + 1)

offset_idx = max(current_idx, cutoff)
notes.append([note_on, offset_idx, pitch, DEFAULT_VELOCITY])

if len(notes) == 0:
return []
else:
notes = np.array(notes)
# Sort key: onset index first, offset index as tie-breaker.
note_order = notes[:, 0] * 128 + notes[:, 1]
notes = notes[note_order.argsort()]
return notes

def notes_to_midi(self, notes, beatstep, offset_sec=None):
    """
    Build a `pretty_midi.PrettyMIDI` object from decoded notes.

    Args:
        notes: Iterable of `(onset_idx, offset_idx, pitch, velocity)` entries.
        beatstep: Indexable collection of times; `beatstep[onset_idx]` and `beatstep[offset_idx]` give the
            start and end of each note.
        offset_sec (`float`, *optional*): Value subtracted from every start and end time. Defaults to 0.0.

    Returns:
        `pretty_midi.PrettyMIDI`: midi object with a single instrument (program 0) holding all notes, after
        `remove_invalid_notes()` has been applied.
    """
    shift = 0.0 if offset_sec is None else offset_sec

    midi = pretty_midi.PrettyMIDI(resolution=384, initial_tempo=120.0)
    instrument = pretty_midi.Instrument(program=0)
    # Index-based notes become time-based notes by looking up their beat times.
    instrument.notes = [
        pretty_midi.Note(
            velocity=velocity,
            pitch=pitch,
            start=beatstep[onset_idx] - shift,
            end=beatstep[offset_idx] - shift,
        )
        for onset_idx, offset_idx, pitch, velocity in notes
    ]
    midi.instruments.append(instrument)
    midi.remove_invalid_notes()
    return midi

def get_stereo(self, pop_y, midi_y, pop_scale=0.99):
    """
    Stack the generated midi audio (`midi_y`) and the pop audio (`pop_y`) into one two-row array.

    The shorter signal is zero-padded at its end so both have the same length. Row 0 of the result is
    `midi_y`, row 1 is `pop_y` multiplied by `pop_scale`.

    Args:
        pop_y: Pop audio samples.
        midi_y: Generated midi audio samples.
        pop_scale (`float`, *optional*, defaults to 0.99): Factor applied to the pop audio.

    Returns:
        `np.ndarray`: The stacked signals.
    """
    pop_len, midi_len = len(pop_y), len(midi_y)
    target_len = max(pop_len, midi_len)

    # At most one of the two signals is shorter than the target length.
    if midi_len < target_len:
        midi_y = np.pad(midi_y, (0, target_len - midi_len))
    elif pop_len < target_len:
        pop_y = np.pad(pop_y, (0, target_len - pop_len))

    return np.stack((midi_y, pop_y * pop_scale))

def _to_np(self, tensor):
"""Converts pytorch tensor to np.ndarray."""
if isinstance(tensor, np.ndarray):
return tensor
elif isinstance(tensor, torch.Tensor):
return tensor.cpu().numpy()
else:
raise ValueError("dtype not understood! Please use wither torch.Tensor or np.ndarray")

def postprocess(
    self,
    relative_tokens: Union[np.ndarray, torch.Tensor],
    beatsteps: Union[np.ndarray, torch.Tensor],
    ext_beatstep: Union[np.ndarray, torch.Tensor],
    raw_audio: Union[np.ndarray, List[float], List[np.ndarray]],
    sampling_rate: int,
    mix_sampling_rate=None,
    save_path: str = None,
    audio_file_name: str = None,
    save_midi: bool = False,
    save_mix: bool = False,
    click_amp: float = 0.2,
    stereo_amp: float = 0.5,
    add_click: bool = False,
):
    r"""
    Postprocess step: converts the generated tokens to a midi object and optionally saves the
    `"generated midi audio"` and the `"stereo-mix"` to disk.

    Args:
        relative_tokens ([`~utils.TensorType`]):
            Output of `Pop2PianoConditionalGeneration` model.
        beatsteps ([`~utils.TensorType`]):
            beatsteps returned by `Pop2PianoFeatureExtractor.__call__`
        ext_beatstep ([`~utils.TensorType`]):
            ext_beatstep returned by `Pop2PianoFeatureExtractor.__call__`
        raw_audio (`np.ndarray`, `List`):
            Denotes the raw_audio.
        sampling_rate (`int`):
            Denotes the Sampling Rate of `raw_audio`.
        mix_sampling_rate (`int`, *optional*):
            Denotes the Sampling Rate for `stereo-mix`. Defaults to `sampling_rate`.
        save_path (`str`, *optional*):
            Existing directory where the `stereo-mix` and `midi-audio` is to be saved.
        audio_file_name (`str`, *optional*):
            Name inserted into the saved file names (`midi_output_{audio_file_name}.mid` and
            `mix_output_{audio_file_name}.wav`).
        save_midi (`bool`, *optional*, defaults to `False`):
            Whether to save `midi-audio` or not.
        save_mix (`bool`, *optional*, defaults to `False`):
            Whether to save `stereo-mix` or not.
        click_amp (`float`, *optional*, defaults to 0.2):
            Amplitude for `"click track"`.
        stereo_amp (`float`, *optional*, defaults to 0.5):
            Scale applied to the raw audio in the `stereo-mix` (passed to `get_stereo` as `pop_scale`).
        add_click (`bool`, *optional*, defaults to `False`):
            Constructs a `"click track"` and adds it to the raw audio of the `stereo-mix`.

    Returns:
        `pretty_midi.pretty_midi.PrettyMIDI` : returns pretty_midi object.

    Raises:
        ValueError: If saving is requested without `save_path`, if `save_path` is given without anything to
            save, or if `save_path` is not a directory.
    """

    relative_tokens = self._to_np(relative_tokens)
    beatsteps = self._to_np(beatsteps)
    ext_beatstep = self._to_np(ext_beatstep)

    # Saving flags and `save_path` must be given together.
    wants_output = save_midi or save_mix
    if wants_output and save_path is None:
        raise ValueError("If you want to save any mix or midi file then you must define save_path.")
    if save_path and not wants_output:
        raise ValueError(
            "You are setting save_path but not saving anything, use save_midi=True to "
            "save the midi file and use save_mix to save the mix file or do both!"
        )

    if mix_sampling_rate is None:
        mix_sampling_rate = sampling_rate

    if save_path is not None:
        if not os.path.isdir(save_path):
            raise ValueError(f"Is {save_path} a directory?")
        midi_path = os.path.join(save_path, f"midi_output_{audio_file_name}.mid")
        mix_path = os.path.join(save_path, f"mix_output_{audio_file_name}.wav")

    pm, _ = self.relative_batch_tokens_to_midi(
        tokens=relative_tokens,
        beatstep=ext_beatstep,
        bars_per_batch=self.n_bars,
        cutoff_time_idx=(self.n_bars + 1) * 4,
    )
    # Shift every generated note by the time of the first beat.
    for note in pm.instruments[0].notes:
        note.start += beatsteps[0]
        note.end += beatsteps[0]

    if save_midi:
        pm.write(midi_path)
        print(f"midi file saved at {midi_path}!")

    if save_mix:
        if mix_sampling_rate != sampling_rate:
            raw_audio = librosa.core.resample(raw_audio, orig_sr=sampling_rate, target_sr=mix_sampling_rate)
            sampling_rate = mix_sampling_rate
        if add_click:
            click_track = librosa.clicks(times=beatsteps, sr=sampling_rate, length=len(raw_audio))
            raw_audio = raw_audio + click_track * click_amp

        # Synthesize the midi and pair it with the (scaled) raw audio.
        synthesized = pm.fluidsynth(sampling_rate)
        stereo = self.get_stereo(raw_audio, synthesized, pop_scale=stereo_amp)

        sf.write(
            file=mix_path,
            data=stereo.T,
            samplerate=sampling_rate,
            format="wav",
        )
        print(f"stereo-mix file saved at {mix_path}!")

    return pm
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
is_kenlm_available,
is_keras_nlp_available,
is_librosa_available,
is_music_available,
is_natten_available,
is_ninja_available,
is_onnx_available,
Expand Down
Loading

0 comments on commit 4d9fcc3

Please sign in to comment.