replicate
chenxwh committed Jun 24, 2023
1 parent f586b78 commit bfa9383
Showing 3 changed files with 156 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
@@ -1,7 +1,8 @@
# TANGO: Text to Audio using iNstruction-Guided diffusiOn
<!-- ![cover](img/tango-neurips.png) -->

- [Paper](https://arxiv.org/pdf/2304.13731.pdf) | [Model](https://huggingface.co/declare-lab/tango) | [Website and Examples](https://tango-web.github.io/) | [More Examples](https://github.com/declare-lab/tango/blob/master/samples/README.md) | [Demo](https://huggingface.co/spaces/declare-lab/tango)
+ [Paper](https://arxiv.org/pdf/2304.13731.pdf) | [Model](https://huggingface.co/declare-lab/tango) | [Website and Examples](https://tango-web.github.io/) | [More Examples](https://github.com/declare-lab/tango/blob/master/samples/README.md) | [Demo](https://huggingface.co/spaces/declare-lab/tango) | [Replicate demo and API](https://replicate.com/cjwbw/tango)


:fire: The demo of **TANGO** is live on [Huggingface Space](https://huggingface.co/spaces/declare-lab/tango)
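
With the model now hosted on Replicate, it can also be called programmatically. Below is a minimal sketch, assuming the official `replicate` Python client (`pip install replicate`) with a `REPLICATE_API_TOKEN` set in the environment; the input names mirror those defined in `predict.py` below:

```python
import replicate

# Run the hosted TANGO model; pin a specific version hash from the
# model page (https://replicate.com/cjwbw/tango) if your client requires one.
output = replicate.run(
    "cjwbw/tango",
    input={
        "prompt": "An audience cheering and clapping",
        "steps": 100,
        "guidance": 3,
    },
)
print(output)  # URL of the generated audio
```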

38 changes: 38 additions & 0 deletions cog.yaml
@@ -0,0 +1,38 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  gpu: true
  cuda: "11.7"
  python_version: "3.10"
  python_packages:
    - "torch==1.13.1"
    - "torchaudio==0.13.1"
    - "torchvision==0.14.1"
    - "transformers==4.27.0"
    - "accelerate==0.18.0"
    - "datasets==2.1.0"
    - "einops==0.6.1"
    - "h5py==3.8.0"
    - "huggingface_hub==0.13.3"
    - "importlib_metadata==6.3.0"
    - "librosa==0.9.2"
    - "matplotlib==3.5.2"
    - "numpy==1.23.0"
    - "omegaconf==2.3.0"
    - "packaging==23.1"
    - "pandas==1.4.1"
    - "progressbar33==2.4"
    - "protobuf==3.20.*"
    - "resampy==0.4.2"
    - "scikit_image==0.19.3"
    - "scikit_learn==1.2.2"
    - "scipy==1.8.0"
    - "soundfile==0.12.1"
    - "ssr_eval==0.0.6"
    - "torchlibrosa==0.1.0"
    - "tqdm==4.63.1"
    - "wandb==0.12.14"
    - "ipython==8.12.0"
    - "diffusers==0.17.1"
predict: "predict.py:Predictor"
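
With this configuration in place, the container can be built and exercised locally with the Cog CLI before pushing to Replicate, e.g. `cog predict -i prompt="An audience cheering and clapping"`; the `predict` key above points Cog at the `Predictor` class in `predict.py`.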
116 changes: 116 additions & 0 deletions predict.py
@@ -0,0 +1,116 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import json
import torch
from tqdm import tqdm
import soundfile as sf
from models import AudioDiffusion, DDPMScheduler
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL
from cog import BasePredictor, Input, Path


class Predictor(BasePredictor):
    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.tango = Tango()

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt", default="An audience cheering and clapping"
        ),
        steps: int = Input(description="Number of inference steps", default=100),
        guidance: float = Input(description="Guidance scale", default=3),
    ) -> Path:
        """Run a single prediction on the model"""

        audio = self.tango.generate(prompt, steps, guidance)
        out = "/tmp/output.wav"
        sf.write(out, audio, samplerate=16000)
        return Path(out)


class Tango:
    def __init__(self, path="tango_weights", device="cuda:0"):
        # weights are downloaded from https://huggingface.co/declare-lab/tango-full/tree/main
        # and saved to ./tango_weights
        vae_config = json.load(open("{}/vae_config.json".format(path)))
        stft_config = json.load(open("{}/stft_config.json".format(path)))
        main_config = json.load(open("{}/main_config.json".format(path)))

        # latent VAE, mel-spectrogram front-end, and latent diffusion model
        self.vae = AutoencoderKL(**vae_config).to(device)
        self.stft = TacotronSTFT(**stft_config).to(device)
        self.model = AudioDiffusion(**main_config).to(device)

        vae_weights = torch.load(
            "{}/pytorch_model_vae.bin".format(path), map_location=device
        )
        stft_weights = torch.load(
            "{}/pytorch_model_stft.bin".format(path), map_location=device
        )
        main_weights = torch.load(
            "{}/pytorch_model_main.bin".format(path), map_location=device
        )

        self.vae.load_state_dict(vae_weights)
        self.stft.load_state_dict(stft_weights)
        self.model.load_state_dict(main_weights)

        self.vae.eval()
        self.stft.eval()
        self.model.eval()

        self.scheduler = DDPMScheduler.from_pretrained(
            main_config["scheduler_name"], subfolder="scheduler"
        )

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from a list."""
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """Generate audio for a single prompt string."""
        with torch.no_grad():
            latents = self.model.inference(
                [prompt],
                self.scheduler,
                steps,
                guidance,
                samples,
                disable_progress=disable_progress,
            )
            # decode latents to mel spectrograms, then mels to waveforms
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave[0]

    def generate_for_batch(
        self,
        prompts,
        steps=100,
        guidance=3,
        samples=1,
        batch_size=8,
        disable_progress=True,
    ):
        """Generate audio for a list of prompt strings."""
        outputs = []
        for k in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[k : k + batch_size]
            with torch.no_grad():
                latents = self.model.inference(
                    batch,
                    self.scheduler,
                    steps,
                    guidance,
                    samples,
                    disable_progress=disable_progress,
                )
                mel = self.vae.decode_first_stage(latents)
                wave = self.vae.decode_to_waveform(mel)
            outputs += [item for item in wave]
        if samples == 1:
            return outputs
        else:
            return list(self.chunks(outputs, samples))
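
For reference, a minimal sketch of driving the `Tango` wrapper directly, outside of Cog (this assumes the `tango_weights` directory and a CUDA device are available, as in `Predictor.setup` above; the prompts are illustrative):

```python
import soundfile as sf

tango = Tango(path="tango_weights", device="cuda:0")

# Single prompt
wav = tango.generate("An audience cheering and clapping", steps=100, guidance=3)
sf.write("output.wav", wav, samplerate=16000)

# Several prompts, processed in batches
waves = tango.generate_for_batch(
    ["A dog barking", "Rain falling on a tin roof"], batch_size=2
)
```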
