Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement melody encoder and support glide input #143

Merged
merged 16 commits into from
Oct 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions configs/templates/config_variance.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ dur_prediction_args:
lambda_wdur_loss: 1.0
lambda_sdur_loss: 3.0

use_melody_encoder: false
melody_encoder_args:
hidden_size: 128
enc_layers: 4
use_glide_embed: false
glide_types: [up, down]
glide_embed_scale: 11.313708498984760 # sqrt(128)

pitch_prediction_args:
pitd_norm_min: -8.0
pitd_norm_max: 8.0
Expand Down
8 changes: 8 additions & 0 deletions configs/variance.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ dur_prediction_args:
lambda_wdur_loss: 1.0
lambda_sdur_loss: 3.0

use_melody_encoder: false
melody_encoder_args:
hidden_size: 128
enc_layers: 4
use_glide_embed: false
glide_types: [up, down]
glide_embed_scale: 11.313708498984760 # sqrt(128)

pitch_prediction_args:
pitd_norm_min: -8.0
pitd_norm_max: 8.0
Expand Down
20 changes: 14 additions & 6 deletions deployment/exporters/variance_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ def _torch_export_model(self):
)

if self.model.predict_pitch:
use_melody_encoder = hparams.get('use_melody_encoder', False)
use_glide_embed = use_melody_encoder and hparams['use_glide_embed']
# Prepare inputs for preprocessor of PitchDiffusion
note_midi = torch.FloatTensor([[60.] * 4]).to(self.device)
note_dur = torch.LongTensor([[2, 6, 3, 4]]).to(self.device)
Expand All @@ -261,10 +263,12 @@ def _torch_export_model(self):
pitch_input_args = (
encoder_out,
ph_dur,
note_midi,
note_dur,
pitch,
{
'note_midi': note_midi,
**({'note_rest': note_midi >= 0} if use_melody_encoder else {}),
'note_dur': note_dur,
**({'note_glide': torch.zeros_like(note_midi, dtype=torch.long)} if use_glide_embed else {}),
'pitch': pitch,
**({'expr': torch.ones_like(pitch)} if self.expose_expr else {}),
'retake': retake,
**({'spk_embed': torch.rand(
Expand All @@ -277,8 +281,10 @@ def _torch_export_model(self):
pitch_input_args,
self.pitch_preprocess_cache_path,
input_names=[
'encoder_out', 'ph_dur',
'note_midi', 'note_dur',
'encoder_out', 'ph_dur', 'note_midi',
*(['note_rest'] if use_melody_encoder else []),
'note_dur',
*(['note_glide'] if use_glide_embed else []),
'pitch',
*(['expr'] if self.expose_expr else []),
'retake',
Expand All @@ -297,13 +303,15 @@ def _torch_export_model(self):
'note_midi': {
1: 'n_notes'
},
**({'note_rest': {1: 'n_notes'}} if use_melody_encoder else {}),
'note_dur': {
1: 'n_notes'
},
**({'expr': {1: 'n_frames'}} if self.expose_expr else {}),
**({'note_glide': {1: 'n_notes'}} if use_glide_embed else {}),
'pitch': {
1: 'n_frames'
},
**({'expr': {1: 'n_frames'}} if self.expose_expr else {}),
'retake': {
1: 'n_frames'
},
Expand Down
38 changes: 32 additions & 6 deletions deployment/modules/toplevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,18 @@ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):
return x_cond

def forward_pitch_preprocess(
self, encoder_out, ph_dur, note_midi, note_dur,
self, encoder_out, ph_dur,
note_midi=None, note_rest=None, note_dur=None, note_glide=None,
pitch=None, expr=None, retake=None, spk_embed=None
):
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
if self.use_melody_encoder:
melody_encoder_out = self.melody_encoder(
note_midi, note_rest, note_dur,
glide=note_glide
)
melody_encoder_out = self.forward_mel2x_gather(melody_encoder_out, note_dur, x_dim=self.hidden_size)
condition += melody_encoder_out
if expr is None:
retake_embed = self.pitch_retake_embed(retake.long())
else:
Expand All @@ -178,8 +186,12 @@ def forward_pitch_preprocess(
pitch_cond = condition + retake_embed
frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None)
base_pitch = self.smooth(frame_midi_pitch)
base_pitch = base_pitch * retake + pitch * ~retake
pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
if self.use_melody_encoder:
delta_pitch = (pitch - base_pitch) * ~retake
pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None])
else:
base_pitch = base_pitch * retake + pitch * ~retake
pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
if hparams['use_spk_id'] and spk_embed is not None:
pitch_cond += spk_embed
return pitch_cond, base_pitch
Expand Down Expand Up @@ -229,6 +241,8 @@ def view_as_linguistic_encoder(self):
model = copy.deepcopy(self)
if self.predict_pitch:
del model.pitch_predictor
if self.use_melody_encoder:
del model.melody_encoder
if self.predict_variances:
del model.variance_predictor
model.fs2 = model.fs2.view_as_encoder()
Expand All @@ -239,12 +253,14 @@ def view_as_linguistic_encoder(self):
return model

def view_as_dur_predictor(self):
    """Return a deep copy of this model specialized for duration prediction.

    The copy keeps only the modules needed to predict phoneme durations:
    the pitch predictor, the melody encoder (which exists only when pitch
    prediction is enabled) and the variance predictor are deleted when
    present, ``fs2`` is narrowed to its duration-predictor view, and
    ``forward`` is rebound to the duration-prediction entry point.

    :raises AssertionError: if this model was not built with ``predict_dur``.
    """
    # Guard first, before paying for the deep copy.
    assert self.predict_dur
    model = copy.deepcopy(self)
    if self.predict_pitch:
        del model.pitch_predictor
        # The melody encoder only exists alongside the pitch predictor.
        if self.use_melody_encoder:
            del model.melody_encoder
    if self.predict_variances:
        del model.variance_predictor
    model.fs2 = model.fs2.view_as_dur_predictor()
    model.forward = model.forward_dur_predictor
    return model
Expand All @@ -260,18 +276,22 @@ def view_as_pitch_preprocess(self):
return model

def view_as_pitch_diffusion(self):
    """Return a deep copy of this model exposing only the pitch diffusion step.

    The encoder (``fs2``), the length regulator (``lr``), the melody encoder
    and the variance predictor are removed when present — they belong to the
    pre-/post-processing views — and ``forward`` is rebound to the pitch
    diffusion entry point.

    :raises AssertionError: if this model was not built with ``predict_pitch``.
    """
    # Guard first, before paying for the deep copy.
    assert self.predict_pitch
    model = copy.deepcopy(self)
    del model.fs2
    del model.lr
    if self.use_melody_encoder:
        del model.melody_encoder
    if self.predict_variances:
        del model.variance_predictor
    model.forward = model.forward_pitch_diffusion
    return model

def view_as_pitch_postprocess(self):
model = copy.deepcopy(self)
del model.fs2
if self.use_melody_encoder:
del model.melody_encoder
if self.predict_variances:
del model.variance_predictor
model.forward = model.forward_pitch_postprocess
Expand All @@ -282,18 +302,22 @@ def view_as_variance_preprocess(self):
del model.fs2
if self.predict_pitch:
del model.pitch_predictor
if self.use_melody_encoder:
del model.melody_encoder
if self.predict_variances:
del model.variance_predictor
model.forward = model.forward_variance_preprocess
return model

def view_as_variance_diffusion(self):
    """Return a deep copy of this model exposing only the variance diffusion step.

    The encoder (``fs2``), the length regulator (``lr``), the pitch predictor
    and the melody encoder are removed when present, and ``forward`` is
    rebound to the variance diffusion entry point.

    :raises AssertionError: if this model was not built with
        ``predict_variances``.
    """
    # Guard first, before paying for the deep copy.
    assert self.predict_variances
    model = copy.deepcopy(self)
    del model.fs2
    del model.lr
    if self.predict_pitch:
        del model.pitch_predictor
        # The melody encoder only exists alongside the pitch predictor.
        if self.use_melody_encoder:
            del model.melody_encoder
    model.forward = model.forward_variance_diffusion
    return model

Expand All @@ -302,5 +326,7 @@ def view_as_variance_postprocess(self):
del model.fs2
if self.predict_pitch:
del model.pitch_predictor
if self.use_melody_encoder:
del model.melody_encoder
model.forward = model.forward_variance_postprocess
return model
34 changes: 31 additions & 3 deletions inference/ds_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def __init__(
smooth_kernel /= smooth_kernel.sum()
self.smooth.weight.data = smooth_kernel[None, None]

glide_types = hparams.get('glide_types', [])
assert 'none' not in glide_types, 'Type name \'none\' is reserved and should not appear in glide_types.'
self.glide_map = {
'none': 0,
**{
typename: idx + 1
for idx, typename in enumerate(glide_types)
}
}

self.auto_completion_mode = len(predictions) == 0
self.global_predict_dur = 'dur' in predictions and hparams['predict_dur']
self.global_predict_pitch = 'pitch' in predictions and hparams['predict_pitch']
Expand Down Expand Up @@ -98,14 +108,15 @@ def preprocess_input(
note_seq = torch.FloatTensor(
[(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in param['note_seq'].split()]
).to(self.device)[None] # [B=1, T_n]
T_n = note_seq.shape[1]
note_dur_sec = torch.from_numpy(np.array([param['note_dur'].split()], np.float32)).to(self.device) # [B=1, T_n]
note_acc = torch.round(torch.cumsum(note_dur_sec, dim=1) / self.timestep + 0.5).long()
note_dur = torch.diff(note_acc, dim=1, prepend=note_acc.new_zeros(1, 1))
mel2note = self.lr(note_dur) # [B=1, T_s]
T_s = mel2note.shape[1]

summary['words'] = T_w
summary['notes'] = note_seq.shape[1]
summary['notes'] = T_n
summary['tokens'] = T_ph
summary['frames'] = T_s
summary['seconds'] = '%.2f' % (T_s * self.timestep)
Expand Down Expand Up @@ -156,6 +167,17 @@ def preprocess_input(
word_dur = mel2ph_to_dur(mel2word, T_w)
batch['word_dur'] = word_dur

batch['note_midi'] = note_seq
batch['note_dur'] = note_dur
batch['note_rest'] = note_seq < 0
if hparams.get('use_glide_embed', False) and param.get('note_glide') is not None:
batch['note_glide'] = torch.LongTensor(
[[self.glide_map.get(x, 0) for x in param['note_glide'].split()]]
).to(self.device)
else:
batch['note_glide'] = torch.zeros(1, T_n, dtype=torch.long, device=self.device)
batch['mel2note'] = mel2note

# Calculate frame-level MIDI pitch, which is a step function curve
frame_midi_pitch = torch.gather(
F.pad(note_seq, [1, 0]), 1, mel2note
Expand Down Expand Up @@ -250,6 +272,11 @@ def forward_model(self, sample):
word_dur = sample['word_dur']
ph_dur = sample['ph_dur']
mel2ph = sample['mel2ph']
note_midi = sample['note_midi']
note_rest = sample['note_rest']
note_dur = sample['note_dur']
note_glide = sample['note_glide']
mel2note = sample['mel2note']
base_pitch = sample['base_pitch']
expr = sample.get('expr')
pitch = sample.get('pitch')
Expand All @@ -271,8 +298,9 @@ def forward_model(self, sample):
ph_spk_mix_embed = spk_mix_embed = None

dur_pred, pitch_pred, variance_pred = self.model(
txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur,
mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch, pitch_expr=expr,
txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, mel2ph=mel2ph,
note_midi=note_midi, note_rest=note_rest, note_dur=note_dur, note_glide=note_glide, mel2note=mel2note,
base_pitch=base_pitch, pitch=pitch, pitch_expr=expr,
ph_spk_mix_embed=ph_spk_mix_embed, spk_mix_embed=spk_mix_embed,
infer=True
)
Expand Down
49 changes: 49 additions & 0 deletions modules/fastspeech/variance_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,52 @@ def forward(self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, spk_emb
return encoder_out, ph_dur_pred
else:
return encoder_out, None


class MelodyEncoder(nn.Module):
    """Encode a note-level melody sequence into hidden features.

    Each note is described by a MIDI pitch value, a rest flag, a duration
    (in frames) and, optionally, a glide ornament label. The output stays
    on the note axis and is projected to the main model's hidden size so
    it can be added to the linguistic condition downstream.
    """

    def __init__(self, enc_hparams: dict):
        """
        :param enc_hparams: encoder-local hyperparameter dict (the
            ``melody_encoder_args`` section of the config); keys missing
            here fall back to the global ``hparams``.
        """
        super().__init__()

        def get_hparam(key):
            # Prefer the encoder-local setting; fall back to the global one.
            return enc_hparams.get(key, hparams.get(key))

        # MIDI inputs
        hidden_size = get_hparam('hidden_size')
        self.note_midi_embed = Linear(1, hidden_size)
        self.note_dur_embed = Linear(1, hidden_size)

        # ornament inputs
        # NOTE(review): glide settings are read from the global hparams
        # rather than through get_hparam() — confirm these keys are hoisted
        # out of melody_encoder_args before this module is constructed.
        self.use_glide_embed = hparams['use_glide_embed']
        self.glide_embed_scale = hparams['glide_embed_scale']
        if self.use_glide_embed:
            # Index 0 is reserved for 'none' (also the padding index);
            # indices 1..N follow the order of hparams['glide_types']
            # (e.g. 1: up, 2: down with the default config).
            self.note_glide_embed = Embedding(len(hparams['glide_types']) + 1, hidden_size, padding_idx=0)

        self.encoder = FastSpeech2Encoder(
            None, hidden_size, num_layers=get_hparam('enc_layers'),
            ffn_kernel_size=get_hparam('enc_ffn_kernel_size'),
            ffn_padding=get_hparam('ffn_padding'), ffn_act=get_hparam('ffn_act'),
            dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'),
            use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos')
        )
        # Project from the (possibly smaller) local hidden size to the main
        # model's hidden size so the output can be added to its condition.
        self.out_proj = Linear(hidden_size, hparams['hidden_size'])

    def forward(self, note_midi, note_rest, note_dur, glide=None):
        """
        :param note_midi: float32 [B, T_n], -1: padding
        :param note_rest: bool [B, T_n]
        :param note_dur: int64 [B, T_n]
        :param glide: int64 [B, T_n]
        :return: [B, T_n, H]
        """
        # Zero out the pitch embedding on rest notes: their MIDI value
        # (-1) carries no melodic information.
        midi_embed = self.note_midi_embed(note_midi[:, :, None]) * ~note_rest[:, :, None]
        dur_embed = self.note_dur_embed(note_dur.float()[:, :, None])
        ornament_embed = 0
        if self.use_glide_embed:
            ornament_embed += self.note_glide_embed(glide) * self.glide_embed_scale
        # Padding positions are marked by negative MIDI values.
        encoder_out = self.encoder(
            midi_embed, dur_embed + ornament_embed,
            padding_mask=note_midi < 0
        )
        encoder_out = self.out_proj(encoder_out)
        return encoder_out
Loading