In [None]:
import logging
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import torchaudio
import librosa
import torch
import numpy as np

from IPython import display as display_


def viz(wav, sr, normalize: bool = False, title: str = "Waveform"):
    """Plot a waveform against time and render an inline audio player.

    Args:
        wav: Array-like of audio samples (plotted along its first axis).
        sr: Sample rate in Hz; used for the time axis and for playback.
        normalize: Forwarded to ``IPython.display.Audio``. With the default
            ``False`` the player does not rescale, and IPython raises
            ``ValueError`` if any sample lies outside [-1, 1]; pass ``True``
            for signals of arbitrary scale.
        title: Title of the waveform figure.
    """
    wav = np.asarray(wav)
    # Same x positions as plotting by sample index, expressed in seconds.
    times = np.arange(len(wav)) / sr

    _, ax = plt.subplots(figsize=(20, 5))
    ax.plot(times, wav)
    ax.set(title=title, xlabel="Time (s)", ylabel="Amplitude")
    plt.show()

    display_.display(display_.Audio(wav, rate=sr, normalize=normalize))


class GriffinLimVocoder:
    """Reconstruct waveforms from natural-log mel spectrograms with Griffin-Lim.

    The mel spectrogram is mapped back to a linear-frequency magnitude
    spectrogram by right-multiplying with the librosa mel filterbank (i.e.
    using the filterbank's transpose as an approximate, not exact, inverse).
    The missing phase is then estimated with repeated STFT/ISTFT round trips.

    Args:
        sr: Sample rate in Hz used to build the mel filterbank.
        n_fft: FFT size; reconstructed spectrograms have ``n_fft // 2 + 1`` bins.
        fmax: Highest filterbank frequency in Hz.
        n_mels: Number of mel bands expected on the input's last axis.
        power: Exponent applied to the reconstructed magnitudes before phase
            estimation.
        n_iters: Number of Griffin-Lim refinement iterations.
        hop_length: STFT hop in samples, used for both ``stft`` and ``istft``.
            ``None`` keeps librosa's default of ``n_fft // 4``. It should match
            the hop the spectrogram was computed with.
        seed: Seed for the random initial phase. ``None`` draws from the global
            ``np.random`` state, so ``np.random.seed`` still controls it.
    """

    def __init__(
        self,
        sr: int,
        n_fft: int,
        fmax: int,
        n_mels: int = 80,
        power: float = 1.2,
        n_iters: int = 50,
        hop_length: Optional[int] = None,
        seed: Optional[int] = None,
    ):
        self.sr = sr
        self.n_fft = n_fft
        self.fmax = fmax
        self.n_mels = n_mels
        self.power = power
        self.n_iters = n_iters
        # librosa's implicit hop is win_length // 4 == n_fft // 4; stating it
        # explicitly guarantees stft and istft agree.
        self.hop_length = n_fft // 4 if hop_length is None else hop_length
        # NOTE(review): presumably undoes a 1/n_fft amplitude normalisation
        # applied when the spectrogram was computed — TODO confirm.
        self.mag_scale = n_fft
        self.seed = seed
        # A private Generator only when seeded; otherwise the legacy global RNG
        # is used at call time (keeps the object picklable either way).
        self._rng = None if seed is None else np.random.default_rng(seed)

        # Shape (n_mels, n_fft // 2 + 1).
        self.filterbank = librosa.filters.mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmax=fmax)

    def _griffin_lim(self, magnitudes):
        """Estimate a signal whose STFT magnitude approximates ``magnitudes``.

        Args:
            magnitudes (np.ndarray): linear magnitude spectrogram,
                (n_fft // 2 + 1, n_frames).

        Returns:
            np.ndarray: the reconstructed signal, or ``np.array([0])`` (with a
            logged warning) if any intermediate signal contains NaN/inf.
        """
        # Uniform samples in [0, 1) for a random initial phase.
        if self._rng is None:
            uniform = np.random.rand(*magnitudes.shape)
        else:
            uniform = self._rng.random(magnitudes.shape)
        phase = np.exp(2j * np.pi * uniform)

        signal = librosa.istft(magnitudes * phase, hop_length=self.hop_length)
        if not np.isfinite(signal).all():
            logging.warning("audio was not finite")
            return np.array([0])

        for _ in range(self.n_iters):
            # Keep the target magnitudes; take the phase of the current estimate.
            _, phase = librosa.magphase(
                librosa.stft(signal, n_fft=self.n_fft, hop_length=self.hop_length)
            )
            signal = librosa.istft(magnitudes * phase, hop_length=self.hop_length)
            # Divergence can also appear mid-iteration, not only at the start.
            if not np.isfinite(signal).all():
                logging.warning("audio was not finite")
                return np.array([0])
        return signal

    def apply_vocoder(self, log_mel_spec):
        """Convert a natural-log mel spectrogram to audio with Griffin-Lim.

        Args:
            log_mel_spec (np.ndarray): (n_frames, n_mels) natural-log mel
                spectrogram.

        Returns:
            np.ndarray: reconstructed signal, or ``np.array([0])`` if the
            reconstruction produced non-finite values.

        Raises:
            ValueError: if the last axis does not have ``self.n_mels`` entries.
        """
        log_mel_spec = np.asarray(log_mel_spec)
        if log_mel_spec.ndim == 0 or log_mel_spec.shape[-1] != self.n_mels:
            raise ValueError(
                f"expected log_mel_spec with shape (n_frames, {self.n_mels}), "
                f"got {log_mel_spec.shape}"
            )
        mel = np.exp(log_mel_spec)
        # (n_frames, n_mels) @ (n_mels, n_fft // 2 + 1) -> (n_frames, n_freq)
        magnitude = np.dot(mel, self.filterbank) * self.mag_scale
        # Griffin-Lim works on (n_freq, n_frames).
        return self._griffin_lim(magnitude.T ** self.power)


In [None]:
# TODO: point ROOT_PATH at the directory holding texts.txt and the saved spectrograms.
ROOT_PATH = Path("...")


with open(ROOT_PATH / "texts.txt", "r", encoding="utf-8") as file:
    text_lines = [line.strip() for line in file]

spectrograms = {}
# Sorted for a deterministic order. Sub-directories and hidden files
# (e.g. .ipynb_checkpoints, .DS_Store) are skipped because torch.load cannot read them.
for file_path in sorted(ROOT_PATH.iterdir()):
    if (
        file_path.name == "texts.txt"
        or file_path.name.startswith(".")
        or not file_path.is_file()
    ):
        continue

    # map_location="cpu" lets later `.numpy()` calls work even for tensors
    # saved from a GPU. NOTE(review): torch.load unpickles arbitrary objects;
    # only use it on trusted files.
    spec = torch.load(file_path, map_location="cpu")
    spectrograms[file_path.stem] = spec

In [None]:
# Vocoder configuration. NOTE(review): these presumably match the settings the
# spectrograms were computed with; confirm before trusting the audio.
SAMPLE_RATE = 22050  # Hz
N_FFT = 1024
FMAX = 8000  # Hz, upper edge of the mel filterbank

vocoder = GriffinLimVocoder(sr=SAMPLE_RATE, n_fft=N_FFT, fmax=FMAX)

In [None]:
spectrogram = "spec_5_speaker_1"

# squeeze() drops singleton axes and transpose(0, 1) swaps the first two, giving
# the (n_frames, n_mels) layout apply_vocoder expects. The stored tensor is
# presumably (1, n_mels, n_frames); TODO confirm.
# detach().cpu() lets .numpy() work for grad-tracking or GPU tensors.
spec = spectrograms[spectrogram].squeeze().transpose(0, 1).detach().cpu().numpy()
audio = vocoder.apply_vocoder(spec)
# Reuse the vocoder's sample rate rather than repeating the literal.
viz(audio, vocoder.sr)