
Commit

add cross-fading to hifi-synth
brentspell committed Jun 7, 2022
1 parent 63579ac commit 6fd6323
Showing 2 changed files with 58 additions and 2 deletions.
58 changes: 57 additions & 1 deletion hifi_gan_bwe/scripts/synth.py
@@ -14,6 +14,7 @@
import numpy as np
import soundfile
import torch
from tqdm import tqdm

from hifi_gan_bwe import models

@@ -39,6 +40,18 @@ def main() -> None:
default="cpu",
help="torch device to use for synthesis (ex: cpu, cuda, cuda:1, etc.)",
)
parser.add_argument(
"--fade_stride",
type=float,
default=30,
help="streaming chunk length, in seconds",
)
parser.add_argument(
"--fade_length",
type=float,
default=0.025,
help="cross-fading overlap, in seconds",
)

args = parser.parse_args()
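
(Not part of the commit.) As a rough illustration of what the new flags mean in practice, the defaults translate into sample counts as follows; the 16 kHz source rate is only an assumed example, since the actual rate is read from the input file at synthesis time:

# illustration only: default flag values converted to sample counts,
# mirroring the arithmetic in _stream below (16 kHz source rate assumed)
sample_rate = 16_000
fade_stride = 30.0   # --fade_stride, seconds per streaming chunk
fade_length = 0.025  # --fade_length, seconds of cross-fade overlap
stride_samples = int(fade_stride) * sample_rate  # 480_000 samples per chunk
fade_samples = int(fade_length * sample_rate)    # 400 samples of overlap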

@@ -58,11 +71,54 @@ def main() -> None:

    # run the bandwidth extender on each audio channel
    inputs = torch.from_numpy(audio).to(args.device)
    audio = torch.stack([model(x, sample_rate) for x in inputs.T]).T.cpu().numpy()
    audio = (
        torch.stack([_stream(args, model, x, sample_rate) for x in inputs.T])
        .T.cpu()
        .numpy()
    )

    # save the output file
    soundfile.write(args.target_path, audio, samplerate=int(model.sample_rate))


def _stream(
    args: argparse.Namespace,
    model: torch.nn.Module,
    x: torch.Tensor,
    sample_rate: int,
) -> torch.Tensor:
    stride_samples = int(args.fade_stride) * sample_rate
    fade_samples = int(args.fade_length * sample_rate)

    # create a linear cross-fader
    fade_in = torch.linspace(0, 1, fade_samples).to(x.device)
    fade_ou = fade_in.flip(0)

    # window the audio into overlapping frames
    frames = x.unfold(
        dimension=0,
        size=stride_samples + fade_samples,
        step=stride_samples,
    )
    prev = torch.zeros_like(fade_ou)
    output = []
    for frame in tqdm(frames):
        # run the bandwidth extender on the current frame
        y = model(frame, sample_rate)

        # fade out the previous frame, fade in the current
        y[:fade_samples] = prev * fade_ou + y[:fade_samples] * fade_in

        # save off the previous frame for fading into the next
        # and add the current frame to the output
        prev = y[-fade_samples:]
        output.append(y[:-fade_samples])

    # tack on the fade out of the last frame
    output.append(prev)

    return torch.cat(output)


if __name__ == "__main__":
    main()
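
(Not part of the commit.) For intuition, here is a minimal self-contained sketch of the same overlap-add cross-fade that _stream performs, with an identity function standing in for the bandwidth extender; the names crossfade_stream, chunk, and overlap are illustrative stand-ins for _stream, the stride, and the fade length. Because the linear fade-in and fade-out sum to one across each overlap, the chunked output matches the input everywhere except for the brief fade-in at the very start:

import torch

def crossfade_stream(x: torch.Tensor, chunk: int, overlap: int) -> torch.Tensor:
    # linear cross-fader, same idea as fade_in/fade_ou above
    fade_in = torch.linspace(0, 1, overlap)
    fade_out = fade_in.flip(0)

    # window the signal into overlapping frames
    frames = x.unfold(dimension=0, size=chunk + overlap, step=chunk)
    prev = torch.zeros(overlap)
    out = []
    for frame in frames:
        # clone() stands in for y = model(frame, sample_rate), which returns a new tensor
        y = frame.clone()
        # fade out the previous frame's tail, fade in the current frame's head
        y[:overlap] = prev * fade_out + y[:overlap] * fade_in
        prev = y[-overlap:]
        out.append(y[:-overlap])
    out.append(prev)
    return torch.cat(out)

# with the identity "model", everything after the initial fade-in is reconstructed exactly
x = torch.randn(10 * 480 + 80)
y = crossfade_stream(x, chunk=480, overlap=80)
print(torch.allclose(y[80:], x[80:]))  # True

The trade-off behind --fade_stride is memory versus number of seams: longer chunks mean fewer cross-fades but a larger tensor per model call, while --fade_length only needs to be long enough to hide discontinuities at chunk boundaries.
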
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@

setup(
    name="hifi-gan-bwe",
    version="0.1.13",
    version="0.1.14",
    description=(
        "Unofficial implementation of the HiFi-GAN+ model "
        "for audio bandwidth extension"
