Skip to content

Commit

Permalink
Replicate API
Browse files Browse the repository at this point in the history
  • Loading branch information
Luis committed Jan 17, 2024
1 parent 2a5c5e9 commit bfbd08f
Show file tree
Hide file tree
Showing 11 changed files with 219 additions and 0 deletions.
Binary file not shown.
21 changes: 21 additions & 0 deletions .cog/tmp/build4097781755/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
av
einops
flashy>=0.0.1
hydra-core>=1.1
hydra_colorlog
julius
num2words
numpy
sentencepiece
spacy>=3.6.1
torch==2.1.0
torchaudio>=2.0.0
huggingface_hub
tqdm
transformers>=4.31.0
demucs
librosa
torchmetrics
encodec
protobuf
xformers --index-url https://download.pytorch.org/whl/cu118
Binary file not shown.
11 changes: 11 additions & 0 deletions .cog/tmp/build4114755976/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
av
einops
flashy
gradio==3.50.2
julius
num2words
omegaconf
sentencepiece
torchmetrics
torch==2.1.0
torchaudio==2.1.0
Binary file not shown.
21 changes: 21 additions & 0 deletions .cog/tmp/build678619104/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
av
einops
flashy>=0.0.1
hydra-core>=1.1
hydra_colorlog
julius
num2words
numpy
sentencepiece
spacy>=3.6.1
torch==2.1.0
torchaudio>=2.0.0
huggingface_hub
tqdm
transformers>=4.31.0
demucs
librosa
gradio
torchmetrics
encodec
protobuf
Binary file not shown.
21 changes: 21 additions & 0 deletions .cog/tmp/build953554414/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
av
einops
flashy>=0.0.1
hydra-core>=1.1
hydra_colorlog
julius
num2words
numpy
sentencepiece
spacy>=3.6.1
torch==2.1.0
torchaudio>=2.0.0
huggingface_hub
tqdm
transformers>=4.31.0
demucs
librosa
torchmetrics
encodec
protobuf
xformers --index-url https://download.pytorch.org/whl/cu118
17 changes: 17 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file

# Exclude Git files
.git
.github
.gitignore

# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache

# Exclude Python virtual environment
/venv
35 changes: 35 additions & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
gpu: true
cuda: "11.8"
system_packages:
- "ffmpeg"
- "aria2"
python_version: "3.10"
python_packages:
- "av"
- "einops"
- "flashy>=0.0.1"
- "hydra-core>=1.1"
- "hydra_colorlog"
- "julius"
- "num2words"
- "numpy"
- "sentencepiece"
- "spacy>=3.6.1"
- "torch==2.1.0"
- "torchaudio>=2.0.0"
- "huggingface_hub"
- "tqdm"
- "transformers>=4.31.0"
- "demucs"
- "librosa"
- "torchmetrics"
- "encodec"
- "protobuf"
- "xformers --index-url https://download.pytorch.org/whl/cu118"

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
93 changes: 93 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md
from cog import BasePredictor, Input, Path
import torch
import torchaudio
from typing import List
from audiocraft.models import MAGNeT
from audiocraft.data.audio import audio_write

class Predictor(BasePredictor):
    """Cog predictor wrapping Meta's MAGNeT text-to-music / text-to-audio models.

    Loads a default checkpoint at startup and generates one or more audio
    variations from a text prompt, writing loudness-normalized WAV files to
    /tmp and returning their paths.
    """

    def setup(self) -> None:
        """Load the default model into memory to make running multiple predictions efficient."""
        # Remember which checkpoint is resident so predict() can skip
        # reloading when the caller asks for the same one.
        self.current_model_name = "facebook/magnet-small-10secs"
        self.model = MAGNeT.get_pretrained(self.current_model_name)

    @torch.inference_mode()
    def predict(
        self,
        prompt: str = Input(
            description="Input Text",
            default="80s electronic track with melodic synthesizers, catchy beat and groovy bass"
        ),
        model: str = Input(
            description="Model to use",
            default="facebook/magnet-small-10secs",
            choices=[
                'facebook/magnet-small-10secs',
                'facebook/magnet-medium-10secs',
                'facebook/magnet-small-30secs',
                'facebook/magnet-medium-30secs',
                'facebook/audio-magnet-small',
                'facebook/audio-magnet-medium']
        ),
        variations: int = Input(
            description="Number of variations to generate",
            default=3, ge=1, le=4,
        ),
        span_score: str = Input(
            default="prod-stride1",
            choices=["max-nonoverlap", "prod-stride1"],
        ),
        temperature: float = Input(
            default=3.0,
            description="Temperature for sampling",
        ),
        top_p: float = Input(
            default=0.9, ge=0.0, le=1.0,
            description="Top p for sampling",
        ),
        max_cfg: float = Input(
            default=10.0,
            description="Max CFG coefficient",
        ),
        min_cfg: float = Input(
            default=1.0,
            description="Min CFG coefficient",
        ),
        decoding_steps_stage_1: int = Input(
            default=20,
            description="Number of decoding steps for stage 1",
        ),
        decoding_steps_stage_2: int = Input(
            default=10,
            description="Number of decoding steps for stage 2",
        ),
        decoding_steps_stage_3: int = Input(
            default=10,
            description="Number of decoding steps for stage 3",
        ),
        decoding_steps_stage_4: int = Input(
            default=10,
            description="Number of decoding steps for stage 4",
        ),
    ) -> List[Path]:
        """Run a single prediction on the model.

        Returns:
            A list of `variations` paths to generated WAV files
            (/tmp/0.wav, /tmp/1.wav, ...).
        """
        # Only reload weights when the requested checkpoint differs from the
        # resident one. The original code called get_pretrained() on every
        # prediction, which discarded the model loaded in setup() and paid
        # the full load cost each call.
        if model != getattr(self, "current_model_name", None):
            self.model = MAGNeT.get_pretrained(model)
            self.current_model_name = model

        # Same prompt repeated: MAGNeT produces a distinct sample per entry.
        descriptions = [prompt] * variations

        self.model.set_generation_params(
            temperature=temperature,
            top_p=top_p,
            max_cfg_coef=max_cfg, min_cfg_coef=min_cfg,
            decoding_steps=[decoding_steps_stage_1, decoding_steps_stage_2, decoding_steps_stage_3, decoding_steps_stage_4],
            span_arrangement='stride1' if (span_score == 'prod-stride1') else 'nonoverlap',)
        wav = self.model.generate(descriptions)

        # Write each variation and collect its output path in a single pass.
        # audio_write appends the ".wav" suffix itself.
        output_paths: List[Path] = []
        for idx, one_wav in enumerate(wav):
            audio_write(f'/tmp/{idx}', one_wav.cpu(), self.model.sample_rate,
                        strategy="loudness", loudness_compressor=True)
            output_paths.append(Path(f'/tmp/{idx}.wav'))

        return output_paths

0 comments on commit bfbd08f

Please sign in to comment.