Speaker embeddings computation for metavoice. (#1800)
* Speaker embeddings computation for metavoice.

* Compute the speaker embeddings.
LaurentMazare committed Mar 4, 2024
1 parent 6530932 commit 8cc0a18
Showing 2 changed files with 109 additions and 23 deletions.
130 changes: 108 additions & 22 deletions candle-transformers/src/models/metavoice.rs
@@ -1,4 +1,4 @@
-use candle::{DType, Error as E, IndexOp, Module, Result, Tensor, D};
+use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
 use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
 
 // Equivalent to torch.repeat_interleave
@@ -13,22 +13,41 @@ pub mod speaker_encoder {
 
     #[derive(Debug, Clone, serde::Deserialize)]
     pub struct Config {
-        pub mel_window_step: usize,
-        pub mel_n_channels: usize,
         pub sampling_rate: usize,
         pub partial_n_frames: usize,
         pub model_hidden_size: usize,
         pub model_embedding_size: usize,
         pub model_num_layers: usize,
+        pub mel_window_length: usize,
+        pub mel_window_step: usize,
+        pub mel_n_channels: usize,
     }
 
+    impl Config {
+        pub fn cfg() -> Self {
+            Self {
+                sampling_rate: 16_000,
+                partial_n_frames: 160,
+                model_hidden_size: 256,
+                model_embedding_size: 256,
+                model_num_layers: 3,
+                mel_window_length: 25,
+                mel_window_step: 10,
+                mel_n_channels: 40,
+            }
+        }
+    }
+
     pub struct Model {
         lstms: Vec<candle_nn::LSTM>,
         linear: Linear,
+        cfg: Config,
     }
 
+    type Slice = (usize, usize);
+
     impl Model {
-        pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
             let mut lstms = Vec::with_capacity(cfg.model_num_layers);
             let vb_l = vb.pp("lstm");
             for layer_idx in 0..cfg.model_num_layers {
@@ -50,36 +69,103 @@ pub mod speaker_encoder {
                 true,
                 vb.pp("linear"),
             )?;
-            Ok(Self { lstms, linear })
+            Ok(Self { lstms, linear, cfg })
         }
 
         fn compute_partial_slices(
-            _n_samples: usize,
-            _rate: f64,
-            _min_coverage: f64,
-        ) -> Result<(Tensor, Tensor)> {
-            todo!()
-        }
-
-        pub fn embed_utterance(&self, wav: &[f32], rate: f64, min_coverage: f64) -> Result<Tensor> {
-            let (_wav_slices, _mel_slices) =
-                Self::compute_partial_slices(wav.len(), rate, min_coverage)?;
-            todo!()
-        }
+            &self,
+            n_samples: usize,
+            rate: f64,
+            min_coverage: f64,
+        ) -> (Vec<Slice>, Vec<Slice>) {
+            let c = &self.cfg;
+            // Compute how many frames separate two partial utterances.
+            let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
+            let n_frames = n_samples / samples_per_frame + 1;
+            let frame_step =
+                (c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
+            let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
+            // Compute the slices.
+            let mut wav_slices = vec![];
+            let mut mel_slices = vec![];
+            for i in (0..steps).step_by(frame_step) {
+                let mel_range = (i, i + c.partial_n_frames);
+                let wav_range = (
+                    i * samples_per_frame,
+                    (i + c.partial_n_frames) * samples_per_frame,
+                );
+                mel_slices.push(mel_range);
+                wav_slices.push(wav_range);
+            }
+            // Evaluate whether extra padding is warranted or not.
+            let last_wav_range = match wav_slices.last() {
+                None => return (wav_slices, mel_slices),
+                Some(l) => *l,
+            };
+            let coverage = (n_samples - last_wav_range.0) as f64
+                / (last_wav_range.1 - last_wav_range.0) as f64;
+            // Drop the last partial when it covers too little of the signal.
+            if coverage < min_coverage && mel_slices.len() > 1 {
+                mel_slices.pop();
+                wav_slices.pop();
+            }
+            (wav_slices, mel_slices)
+        }
+
+        pub fn embed_utterance(
+            &self,
+            wav: &[f32],
+            mel_filters: &[f32],
+            rate: f64,
+            min_c: f64,
+            device: &Device,
+        ) -> Result<Tensor> {
+            let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
+            let max_wave_length = match wav_slices.last() {
+                Some(v) => v.1,
+                None => candle::bail!("empty wav slices"),
+            };
+            let wav = if max_wave_length > wav.len() {
+                let mut wav = wav.to_vec();
+                // Zero-pad the waveform up to the end of the last slice.
+                wav.resize(max_wave_length, 0.0);
+                std::borrow::Cow::Owned(wav)
+            } else {
+                std::borrow::Cow::Borrowed(wav)
+            };
+            // The mel window length/step are in milliseconds, convert them to samples.
+            let mel = crate::models::whisper::audio::log_mel_spectrogram_(
+                wav.as_ref(),
+                mel_filters,
+                /* fft_size */ self.cfg.sampling_rate * self.cfg.mel_window_length / 1000,
+                /* fft_step */ self.cfg.sampling_rate * self.cfg.mel_window_step / 1000,
+                self.cfg.mel_n_channels,
+                false,
+            );
+            // log_mel_spectrogram_ lays its output out as (n_channels, n_frames),
+            // so gather each slice's frame range across all channels.
+            let n_frames = mel.len() / self.cfg.mel_n_channels;
+            let mels = mel_slices
+                .iter()
+                .flat_map(|&(s, e)| {
+                    (0..self.cfg.mel_n_channels)
+                        .flat_map(move |c| c * n_frames + s..c * n_frames + e)
+                })
+                .map(|i| mel[i])
+                .collect::<Vec<_>>();
+            let mels = Tensor::from_vec(
+                mels,
+                (mel_slices.len(), self.cfg.mel_n_channels, self.cfg.partial_n_frames),
+                device,
+            )?;
+            let partial_embeds = self.forward(&mels)?;
+            let raw_embed = partial_embeds.mean(0)?;
+            let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
+            raw_embed.broadcast_div(&norm)
+        }
     }
 
     impl Module for Model {
         fn forward(&self, xs: &Tensor) -> Result<Tensor> {
             use candle_nn::RNN;
 
             // This is different from the Python transformers version as candle LSTM is batch first.
             let xs = xs.t()?;
             let mut xs = xs.clone();
-            for lstm in self.lstms.iter() {
-                let res = lstm.seq(&xs)?;
-                let res: Vec<_> = res.into_iter().map(|s| s.h().clone()).collect();
-                xs = Tensor::stack(&res, 1)?;
+            for layer in self.lstms.iter() {
+                let states = layer.seq(&xs)?;
+                xs = layer.states_to_tensor(&states)?;
             }
-            let xs = xs.t()?;
             let embeds_raw = xs.apply(&self.linear)?.relu()?;
-            // TODO: normalize.
-            Ok(embeds_raw)
+            let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
+            embeds_raw.broadcast_div(&norm)
         }
     }
 }
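For context, a minimal usage sketch of the new API (not part of the commit). It assumes the speaker-encoder weights live in a safetensors file whose tensor names match the lstm.*/linear.* prefixes above, and that 40-channel mel filters are available as a flat Vec<f32>; the file path and helper name are illustrative. With the default config (16 kHz audio, 10 ms hop, hence 160 samples per mel frame), rate = 1.3 partials per second gives frame_step = round(16000 / 1.3 / 160) = 77 frames between consecutive partials.

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::metavoice::speaker_encoder;

// Hypothetical helper: 16 kHz mono samples in, L2-normalized speaker embedding out.
fn speaker_embedding(wav: &[f32], mel_filters: &[f32]) -> Result<Tensor> {
    let device = Device::Cpu;
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["speaker_encoder.safetensors"], DType::F32, &device)?
    };
    let model = speaker_encoder::Model::new(speaker_encoder::Config::cfg(), vb)?;
    // 1.3 partial utterances per second; drop the last partial if it covers
    // less than 75% of a full 1.6 s window (the resemblyzer defaults).
    model.embed_utterance(wav, mel_filters, /* rate */ 1.3, /* min_coverage */ 0.75, &device)
}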
2 changes: 1 addition & 1 deletion candle-transformers/src/models/whisper/audio.rs
@@ -167,7 +167,7 @@ fn log_mel_spectrogram_w<T: Float>(
     mel
 }
 
-fn log_mel_spectrogram_<T: Float>(
+pub fn log_mel_spectrogram_<T: Float>(
     samples: &[T],
     filters: &[T],
     fft_size: usize,
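With the visibility bump, the whisper mel helper can be reused from other modules. A minimal sketch, assuming mel_filters holds n_mel * (1 + fft_size / 2) coefficients (the layout log_mel_spectrogram_w indexes); 400 and 160 are the speaker encoder's 25 ms window and 10 ms hop converted to samples at 16 kHz:

use candle_transformers::models::whisper::audio::log_mel_spectrogram_;

// Hypothetical wrapper: with fft_size = 400 and 40 mel channels,
// mel_filters must hold 40 * 201 coefficients.
fn speaker_mel(samples: &[f32], mel_filters: &[f32]) -> Vec<f32> {
    log_mel_spectrogram_(
        samples,
        mel_filters,
        /* fft_size: 25 ms at 16 kHz */ 400,
        /* fft_step: 10 ms at 16 kHz */ 160,
        /* n_mel */ 40,
        /* speed_up */ false,
    )
}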
