Skip to content

Commit

Permalink
add cuda support to hifi-synth
Browse files Browse the repository at this point in the history
  • Loading branch information
brentspell committed Jun 3, 2022
1 parent 584f286 commit 63579ac
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
10 changes: 8 additions & 2 deletions hifi_gan_bwe/scripts/synth.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,17 @@ def main() -> None:
type=Path,
help="output audio file path",
)
parser.add_argument(
"--device",
default="cpu",
help="torch device to use for synthesis (ex: cpu, cuda, cuda:1, etc.)",
)

args = parser.parse_args()

# load the model
torch.set_grad_enabled(False)
model = models.BandwidthExtender.from_pretrained(args.model)
model = models.BandwidthExtender.from_pretrained(args.model).to(args.device)

# load the source audio file
with audioread.audio_open(str(args.source_path)) as input_:
Expand All @@ -52,7 +57,8 @@ def main() -> None:
)

# run the bandwidth extender on each audio channel
audio = np.stack([model(torch.from_numpy(x), sample_rate) for x in audio.T]).T
inputs = torch.from_numpy(audio).to(args.device)
audio = torch.stack([model(x, sample_rate) for x in inputs.T]).T.cpu().numpy()

# save the output file
soundfile.write(args.target_path, audio, samplerate=int(model.sample_rate))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="hifi-gan-bwe",
version="0.1.12",
version="0.1.13",
description=(
"Unofficial implementation of the HiFi-GAN+ model "
"for audio bandwidth extension"
Expand Down

0 comments on commit 63579ac

Please sign in to comment.