Skip to content

Commit

Permalink
export to script module for version independence
Browse files Browse the repository at this point in the history
  • Loading branch information
brentspell committed Apr 26, 2022
1 parent cee26bf commit e8cf1a2
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 17 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ The example below uses a pretrained HiFi-GAN+ model to upsample a 1 second
import torch
from hifi_gan_bwe import BandwidthExtender

model = BandwidthExtender.from_pretrained("hifi-gan-bwe-05-d3abf04-vctk-48kHz")
model = BandwidthExtender.from_pretrained("hifi-gan-bwe-05-cd9f4ca-vctk-48kHz")

fs = 24000
x = torch.full([fs], 261.63 / fs).cumsum(-1) % 1.0 - 0.5
Expand All @@ -54,7 +54,7 @@ the output can be in any format supported by

```shell
pipx run --python=python3.9 hifi-gan-bwe \
hifi-gan-bwe-05-d3abf04-vctk-48kHz \
hifi-gan-bwe-05-cd9f4ca-vctk-48kHz \
input.mp3 \
output.wav
```
Expand All @@ -66,7 +66,7 @@ the HiFi-GAN+ library into it and run synthesis, training, etc. using it.
```shell
pip install hifi-gan-bwe

hifi-synth hifi-gan-bwe-05-d3abf04-vctk-48kHz input.mp3 output.wav
hifi-synth hifi-gan-bwe-05-cd9f4ca-vctk-48kHz input.mp3 output.wav
```

## Pretrained Models
Expand All @@ -76,7 +76,7 @@ the link and use it offline.

|Name|Sample Rate|Parameters|Wandb Metrics|Notes|
|-|-|-|-|-|
|[hifi-gan-bwe-05-d3abf04-vctk-48kHz](https://cdn.brentspell.com/models/hifi-gan-bwe/hifi-gan-bwe-05-d3abf04-vctk-48kHz.pt)|48kHz|1M|[bwe-05-d3abf04](https://wandb.ai/brentspell/hifi-gan-bwe/runs/bwe-05-d3abf04?workspace=user-brentspell)|Trained for 200K iterations on the VCTK speech dataset with noise agumentation from the DNS Challenge dataset.|
|[hifi-gan-bwe-05-cd9f4ca-vctk-48kHz](https://cdn.brentspell.com/models/hifi-gan-bwe/hifi-gan-bwe-05-cd9f4ca-vctk-48kHz.pt)|48kHz|1M|[bwe-05-cd9f4ca](https://wandb.ai/brentspell/hifi-gan-bwe/runs/bwe-05-cd9f4ca?workspace=user-brentspell)|Trained for 200K iterations on the VCTK speech dataset with noise agumentation from the DNS Challenge dataset.|

## Training
If you want to train your own model, you can use any of the methods above
Expand Down
16 changes: 8 additions & 8 deletions hifi_gan_bwe/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import typing as T
from pathlib import Path

import numpy as np
import requests
import torch
import torchaudio
Expand All @@ -48,7 +47,7 @@ def __init__(self) -> None:
# store the training sample rate in the state dict, so that
# we can run inference on a model trained for a different rate
self.sample_rate: torch.Tensor
self.register_buffer("sample_rate", torch.as_tensor(SAMPLE_RATE))
self.register_buffer("sample_rate", torch.tensor(SAMPLE_RATE))

self._wavenet = WaveNet(
stacks=2,
Expand All @@ -60,16 +59,17 @@ def __init__(self) -> None:
dilation_base=3,
)

def save(self, path: str) -> None:
torch.jit.save(torch.jit.script(self), path)

@staticmethod
def from_pretrained(path: str) -> "BandwidthExtender":
# first see if this is a hosted pretrained model, download it if so
if not path.endswith(".pt"):
path = _download(path)

# load the pretrained model's weights from the path
state = torch.load(path)
model = BandwidthExtender()
model.load_state_dict(state)
# load the local model file as a script module
model = torch.jit.load(path)
return model

@staticmethod
Expand Down Expand Up @@ -189,7 +189,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
for n in self._layers:
x, h = n(x)
s += h
x = s * np.sqrt(1.0 / len(self._layers))
x = s * torch.tensor(1.0 / len(self._layers)).sqrt()

# apply the output projection
x = self._conv_out(x)
Expand Down Expand Up @@ -247,7 +247,7 @@ def forward(self, x: torch.Tensor) -> T.Tuple[torch.Tensor, torch.Tensor]:
x = self._conv_out(x)

# add residual and apply a normalizing gain
x = (x + r) * np.sqrt(0.5)
x = (x + r) * torch.tensor(0.5).sqrt()

return x, s

Expand Down
3 changes: 1 addition & 2 deletions hifi_gan_bwe/scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pathlib import Path

import git
import torch

from hifi_gan_bwe import datasets, models

Expand Down Expand Up @@ -65,7 +64,7 @@ def main() -> None:
target_path.parent.mkdir(parents=True, exist_ok=True)

# save the model
torch.save(model.state_dict(), target_path)
model.save(target_path)

print(f"exported {source_path.name} to {target_path}")

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="hifi-gan-bwe",
version="0.1.9",
version="0.1.10",
description=(
"Unofficial implementation of the HiFi-GAN+ model "
"for audio bandwidth extension"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ def test_save_load(tmpdir: Path) -> None:
model.remove_weightnorm()

model_path = tmpdir / "pretrained.pt"
torch.save(model.state_dict(), str(model_path))
model.save(str(model_path))
loaded = models.BandwidthExtender.from_pretrained(str(model_path))
assert_params(model, loaded)


@torch.no_grad()
def test_hosted() -> None:
model_name = "hifi-gan-bwe-05-d3abf04-vctk-48kHz"
model_name = "hifi-gan-bwe-05-cd9f4ca-vctk-48kHz"
model = models.BandwidthExtender.from_pretrained(model_name)
y = model(torch.zeros([80]), 8000)
assert list(y.shape) == [480]
Expand Down

0 comments on commit e8cf1a2

Please sign in to comment.