forked from facebookresearch/audiocraft
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Luis
committed
Jan 17, 2024
1 parent
2a5c5e9
commit bfbd08f
Showing
11 changed files
with
219 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
av | ||
einops | ||
flashy>=0.0.1 | ||
hydra-core>=1.1 | ||
hydra_colorlog | ||
julius | ||
num2words | ||
numpy | ||
sentencepiece | ||
spacy>=3.6.1 | ||
torch==2.1.0 | ||
torchaudio>=2.0.0 | ||
huggingface_hub | ||
tqdm | ||
transformers>=4.31.0 | ||
demucs | ||
librosa | ||
torchmetrics | ||
encodec | ||
protobuf | ||
xformers --index-url https://download.pytorch.org/whl/cu118 |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
av | ||
einops | ||
flashy | ||
gradio==3.50.2 | ||
julius | ||
num2words | ||
omegaconf | ||
sentencepiece | ||
torchmetrics | ||
torch==2.1.0 | ||
torchaudio==2.1.0 |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
av | ||
einops | ||
flashy>=0.0.1 | ||
hydra-core>=1.1 | ||
hydra_colorlog | ||
julius | ||
num2words | ||
numpy | ||
sentencepiece | ||
spacy>=3.6.1 | ||
torch==2.1.0 | ||
torchaudio>=2.0.0 | ||
huggingface_hub | ||
tqdm | ||
transformers>=4.31.0 | ||
demucs | ||
librosa | ||
gradio | ||
torchmetrics | ||
encodec | ||
protobuf |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
av | ||
einops | ||
flashy>=0.0.1 | ||
hydra-core>=1.1 | ||
hydra_colorlog | ||
julius | ||
num2words | ||
numpy | ||
sentencepiece | ||
spacy>=3.6.1 | ||
torch==2.1.0 | ||
torchaudio>=2.0.0 | ||
huggingface_hub | ||
tqdm | ||
transformers>=4.31.0 | ||
demucs | ||
librosa | ||
torchmetrics | ||
encodec | ||
protobuf | ||
xformers --index-url https://download.pytorch.org/whl/cu118 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# The .dockerignore file excludes files from the container build process. | ||
# | ||
# https://docs.docker.com/engine/reference/builder/#dockerignore-file | ||
|
||
# Exclude Git files | ||
.git | ||
.github | ||
.gitignore | ||
|
||
# Exclude Python cache files | ||
__pycache__ | ||
.mypy_cache | ||
.pytest_cache | ||
.ruff_cache | ||
|
||
# Exclude Python virtual environment | ||
/venv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Configuration for Cog ⚙️ | ||
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md | ||
|
||
build: | ||
gpu: true | ||
cuda: "11.8" | ||
system_packages: | ||
- "ffmpeg" | ||
- "aria2" | ||
python_version: "3.10" | ||
python_packages: | ||
- "av" | ||
- "einops" | ||
- "flashy>=0.0.1" | ||
- "hydra-core>=1.1" | ||
- "hydra_colorlog" | ||
- "julius" | ||
- "num2words" | ||
- "numpy" | ||
- "sentencepiece" | ||
- "spacy>=3.6.1" | ||
- "torch==2.1.0" | ||
- "torchaudio>=2.0.0" | ||
- "huggingface_hub" | ||
- "tqdm" | ||
- "transformers>=4.31.0" | ||
- "demucs" | ||
- "librosa" | ||
- "torchmetrics" | ||
- "encodec" | ||
- "protobuf" | ||
- "xformers --index-url https://download.pytorch.org/whl/cu118" | ||
|
||
# predict.py defines how predictions are run on your model | ||
predict: "predict.py:Predictor" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Prediction interface for Cog ⚙️ | ||
# https://github.com/replicate/cog/blob/main/docs/python.md | ||
from cog import BasePredictor, Input, Path | ||
import torch | ||
import torchaudio | ||
from typing import List | ||
from audiocraft.models import MAGNeT | ||
from audiocraft.data.audio import audio_write | ||
|
||
class Predictor(BasePredictor): | ||
def setup(self) -> None: | ||
"""Load the model into memory to make running multiple predictions efficient""" | ||
self.model = MAGNeT.get_pretrained("facebook/magnet-small-10secs") | ||
|
||
@torch.inference_mode() | ||
def predict( | ||
self, | ||
prompt: str = Input( | ||
description="Input Text", | ||
default="80s electronic track with melodic synthesizers, catchy beat and groovy bass" | ||
), | ||
model: str = Input( | ||
description="Model to use", | ||
default="facebook/magnet-small-10secs", | ||
choices=[ | ||
'facebook/magnet-small-10secs', | ||
'facebook/magnet-medium-10secs', | ||
'facebook/magnet-small-30secs', | ||
'facebook/magnet-medium-30secs', | ||
'facebook/audio-magnet-small', | ||
'facebook/audio-magnet-medium'] | ||
), | ||
variations: int = Input( | ||
description="Number of variations to generate", | ||
default=3, ge=1, le=4, | ||
), | ||
span_score: str = Input( | ||
default="prod-stride1", | ||
choices=["max-nonoverlap", "prod-stride1"], | ||
), | ||
temperature: float = Input( | ||
default=3.0, | ||
description="Temperature for sampling", | ||
), | ||
top_p: float = Input( | ||
default=0.9, ge=0.0, le=1.0, | ||
description="Top p for sampling", | ||
), | ||
max_cfg: float = Input( | ||
default=10.0, | ||
description="Max CFG coefficient", | ||
), | ||
min_cfg: float = Input( | ||
default=1.0, | ||
description="Min CFG coefficient", | ||
), | ||
decoding_steps_stage_1: int = Input( | ||
default=20, | ||
description="Number of decoding steps for stage 1", | ||
), | ||
decoding_steps_stage_2: int = Input( | ||
default=10, | ||
description="Number of decoding steps for stage 2", | ||
), | ||
decoding_steps_stage_3: int = Input( | ||
default=10, | ||
description="Number of decoding steps for stage 3", | ||
), | ||
decoding_steps_stage_4: int = Input( | ||
default=10, | ||
description="Number of decoding steps for stage 4", | ||
), | ||
) -> List[Path]: | ||
"""Run a single prediction on the model""" | ||
descriptions = [prompt for _ in range(variations)] | ||
|
||
self.model = MAGNeT.get_pretrained(model) | ||
self.model.set_generation_params( | ||
temperature=temperature, | ||
top_p=top_p, | ||
max_cfg_coef=max_cfg, min_cfg_coef=min_cfg, | ||
decoding_steps=[decoding_steps_stage_1, decoding_steps_stage_2, decoding_steps_stage_3, decoding_steps_stage_4], | ||
span_arrangement='stride1' if (span_score == 'prod-stride1') else 'nonoverlap',) | ||
wav = self.model.generate(descriptions) | ||
|
||
for idx, one_wav in enumerate(wav): | ||
audio_write(f'/tmp/{idx}', one_wav.cpu(), self.model.sample_rate, strategy="loudness", loudness_compressor=True) | ||
|
||
output_paths = [] | ||
for idx in range(variations): | ||
output_paths.append(Path(f'/tmp/{idx}.wav')) | ||
|
||
return output_paths |