Skip to content

Commit

Permalink
changing tmpdir when running in slurm + not depending anymore on torc…
Browse files Browse the repository at this point in the history
…haudio for writing audio files. (#306)

* changing tmpdir when runnign in slurm

* fixing typing in dadam

* limiting dependency on torchaudio for writing files

* not using torchaudio for reading anymore

* trying desperatly to get those unit tests to pass

* plop

* fixing tests once more

* linter

* plop
  • Loading branch information
adefossez committed Oct 12, 2023
1 parent a2b9675 commit 5d8752d
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 63 deletions.
2 changes: 2 additions & 0 deletions .github/actions/audiocraft_build/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ runs:
python3 -m venv env
. env/bin/activate
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install --pre xformers
pip install -e '.[dev]'
- name: System Dependencies
shell: bash
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).


## [1.0.1] - TBD

Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg. Also not using it anymore for reading audio files, for similar reasons.

## [1.0.0] - 2023-09-07

Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
Expand Down
45 changes: 29 additions & 16 deletions audiocraft/data/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import soundfile
import torch
from torch.nn import functional as F
import torchaudio as ta

import av
import subprocess as sp

from .audio_utils import f32_pcm, i16_pcm, normalize_audio
from .audio_utils import f32_pcm, normalize_audio


_av_initialized = False
Expand Down Expand Up @@ -136,12 +136,6 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
wav = torch.from_numpy(wav).t().contiguous()
if len(wav.shape) == 1:
wav = torch.unsqueeze(wav, 0)
elif (
fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
and duration <= 0 and seek_time == 0
):
# Torchaudio is faster if we load an entire file at once.
wav, sr = ta.load(fp)
else:
wav, sr = _av_read(filepath, seek_time, duration)
if pad and duration > 0:
Expand All @@ -150,10 +144,22 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
return wav, sr


def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
# ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
assert wav.dim() == 2, wav.shape
command = [
'ffmpeg',
'-loglevel', 'error',
'-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
'-i', '-'] + flags + [str(out_path)]
input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
sp.run(command, input=input_, check=True)


def audio_write(stem_name: tp.Union[str, Path],
wav: torch.Tensor, sample_rate: int,
format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
strategy: str = 'peak', peak_clip_headroom_db: float = 1,
format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
loudness_compressor: bool = False,
log_clipping: bool = True, make_parent_dir: bool = True,
Expand All @@ -164,8 +170,9 @@ def audio_write(stem_name: tp.Union[str, Path],
stem_name (str or Path): Filename without extension which will be added automatically.
wav (torch.Tensor): Audio data to save.
sample_rate (int): Sample rate of audio data.
format (str): Either "wav" or "mp3".
format (str): Either "wav", "mp3", "ogg", or "flac".
mp3_rate (int): kbps when using mp3s.
ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
normalize (bool): if `True` (default), normalizes according to the prescribed
strategy (see after). If `False`, the strategy is only used in case clipping
would happen.
Expand Down Expand Up @@ -193,14 +200,20 @@ def audio_write(stem_name: tp.Union[str, Path],
rms_headroom_db, loudness_headroom_db, loudness_compressor,
log_clipping=log_clipping, sample_rate=sample_rate,
stem_name=str(stem_name))
kwargs: dict = {}
if format == 'mp3':
suffix = '.mp3'
kwargs.update({"compression": mp3_rate})
flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
elif format == 'wav':
wav = i16_pcm(wav)
suffix = '.wav'
kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
elif format == 'ogg':
suffix = '.ogg'
flags = ['-f', 'ogg', '-c:a', 'libvorbis']
if ogg_rate is not None:
flags += ['-b:a', f'{ogg_rate}k']
elif format == 'flac':
suffix = '.flac'
flags = ['-f', 'flac']
else:
raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
if not add_suffix:
Expand All @@ -209,7 +222,7 @@ def audio_write(stem_name: tp.Union[str, Path],
if make_parent_dir:
path.parent.mkdir(exist_ok=True, parents=True)
try:
ta.save(path, wav, sample_rate, **kwargs)
_piping_to_ffmpeg(path, wav, sample_rate, flags)
except Exception:
if path.exists():
# we do not want to leave half written files around.
Expand Down
24 changes: 14 additions & 10 deletions audiocraft/modules/rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,16 @@ def get_rotation(self, start: int, end: int):
self.rotation = torch.polar(torch.ones_like(angles), angles)
return self.rotation[start:end]

def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
"""Apply rope rotation to query or key tensor."""
T = x.shape[1]
rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
T = x.shape[time_dim]
target_shape = [1] * x.dim()
target_shape[time_dim] = T
target_shape[-1] = -1
rotation = self.get_rotation(start, start + T).view(target_shape)

if self.xpos:
decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
decay = self.xpos.get_decay(start, start + T).view(target_shape)
else:
decay = 1.0

Expand All @@ -96,11 +99,11 @@ def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):

x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)

return x_out.type_as(x)

def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
""" Apply rope rotation to both query and key tensors.
Supports streaming mode, in which query and key are not expected to have the same shape.
In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
Expand All @@ -110,12 +113,13 @@ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
query (torch.Tensor): Query to rotate.
key (torch.Tensor): Key to rotate.
start (int): Start index of the sequence for time offset.
time_dim (int): which dimension represent the time steps.
"""
query_timesteps = query.shape[1]
key_timesteps = key.shape[1]
query_timesteps = query.shape[time_dim]
key_timesteps = key.shape[time_dim]
streaming_offset = key_timesteps - query_timesteps

query_out = self.rotate(query, start + streaming_offset)
key_out = self.rotate(key, start, invert_decay=True)
query_out = self.rotate(query, start + streaming_offset, time_dim)
key_out = self.rotate(key, start, time_dim, invert_decay=True)

return query_out, key_out
23 changes: 11 additions & 12 deletions audiocraft/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def set_efficient_attention_backend(backend: str = 'torch'):
_efficient_attention_backend = backend


def _get_attention_time_dimension() -> int:
if _efficient_attention_backend == 'torch':
def _get_attention_time_dimension(memory_efficient: bool) -> int:
if _efficient_attention_backend == 'torch' and memory_efficient:
return 2
else:
return 1
Expand Down Expand Up @@ -89,11 +89,11 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)


def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
if n_rep == 1:
return x
if _efficient_attention_backend == 'torch':
if _efficient_attention_backend == 'torch' and memory_efficient:
bs, n_kv_heads, slen, head_dim = x.shape
return (
x[:, :, None, :, :]
Expand Down Expand Up @@ -234,7 +234,7 @@ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype
# Return a causal mask, accounting for potentially stored past keys/values
# We actually return a bias for the attention score, as this has the same
# convention both in the builtin MHA in Pytorch, and Xformers functions.
time_dim = _get_attention_time_dimension()
time_dim = _get_attention_time_dimension(self.memory_efficient)
if self.memory_efficient:
from xformers.ops import LowerTriangularMask
if current_steps == 1:
Expand Down Expand Up @@ -264,7 +264,7 @@ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype
torch.full([], float('-inf'), device=device, dtype=dtype))

def _complete_kv(self, k, v):
time_dim = _get_attention_time_dimension()
time_dim = _get_attention_time_dimension(self.memory_efficient)
if self.cross_attention:
# With cross attention we assume all keys and values
# are already available, and streaming is with respect
Expand Down Expand Up @@ -298,8 +298,7 @@ def _complete_kv(self, k, v):
return nk, nv

def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
# TODO: fix and verify layout.
assert _efficient_attention_backend == 'xformers', "Rope not supported with torch attn."
time_dim = _get_attention_time_dimension(self.memory_efficient)
# Apply rope embeddings to query and key tensors.
assert self.rope is not None
if 'past_keys' in self._streaming_state:
Expand All @@ -311,7 +310,7 @@ def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
else:
past_context_offset = 0
streaming_offset = past_context_offset + past_keys_offset
return self.rope.rotate_qk(query, key, start=streaming_offset)
return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)

def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
key_padding_mask=None, need_weights=False, attn_mask=None,
Expand All @@ -320,7 +319,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
assert not is_causal, ("New param added in torch 2.0.1 not supported, "
"use the causal args in the constructor.")

time_dim = _get_attention_time_dimension()
time_dim = _get_attention_time_dimension(self.memory_efficient)
if time_dim == 2:
layout = "b h t d"
else:
Expand Down Expand Up @@ -394,8 +393,8 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
q, k = self._apply_rope(q, k)
k, v = self._complete_kv(k, v)
if self.kv_repeat > 1:
k = expand_repeated_kv(k, self.kv_repeat)
v = expand_repeated_kv(v, self.kv_repeat)
k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
if self.attention_as_float32:
q, k, v = [x.float() for x in [q, k, v]]
if self.memory_efficient:
Expand Down
8 changes: 2 additions & 6 deletions audiocraft/optim/dadam.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,15 @@
# LICENSE file in the root directory of this source tree.

import logging
from typing import TYPE_CHECKING, Any
from typing import Any

import torch
import torch.optim
import torch.distributed as dist

if TYPE_CHECKING:
from torch.optim.optimizer import _params_t
else:
_params_t = Any


logger = logging.getLogger(__name__)
_params_t = Any


def to_real(x):
Expand Down
6 changes: 6 additions & 0 deletions audiocraft/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import multiprocessing
import os
from pathlib import Path
import sys
import typing as tp

Expand Down Expand Up @@ -119,6 +120,11 @@ def init_seed_and_system(cfg):
logger.debug('Setting num threads to %d', cfg.num_threads)
set_efficient_attention_backend(cfg.efficient_attention_backend)
logger.debug('Setting efficient attention backend to %s', cfg.efficient_attention_backend)
if 'SLURM_JOB_ID' in os.environ:
tmpdir = Path('/scratch/slurm_tmpdir/' + os.environ['SLURM_JOB_ID'])
if tmpdir.exists():
logger.info("Changing tmpdir to %s", tmpdir)
os.environ['TMPDIR'] = str(tmpdir)


@hydra_main(config_path='../config', config_name='config', version_base='1.1')
Expand Down
15 changes: 6 additions & 9 deletions tests/common_utils/wav_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
# LICENSE file in the root directory of this source tree.

from pathlib import Path
import typing as tp

import torch
import torchaudio

from audiocraft.data.audio import audio_write


def get_white_noise(chs: int = 1, num_frames: int = 1):
Expand All @@ -22,11 +22,8 @@ def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):


def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
assert wav.dim() == 2, wav.shape
fp = Path(path)
kwargs: tp.Dict[str, tp.Any] = {}
if fp.suffix == '.wav':
kwargs['encoding'] = 'PCM_S'
kwargs['bits_per_sample'] = 16
elif fp.suffix == '.mp3':
kwargs['compression'] = 320
torchaudio.save(str(fp), wav, sample_rate, **kwargs)
assert fp.suffix in ['.mp3', '.ogg', '.wav', '.flac'], fp
audio_write(fp.parent / fp.stem, wav, sample_rate, fp.suffix[1:],
normalize=False, strategy='clip', peak_clip_headroom_db=0)
16 changes: 8 additions & 8 deletions tests/modules/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


def test_rope():
set_efficient_attention_backend('xformers')
set_efficient_attention_backend('torch')
B, T, H, C = 8, 75, 16, 128

rope = RotaryEmbedding(dim=C)
Expand All @@ -24,7 +24,7 @@ def test_rope():


def test_rope_io_dtypes():
set_efficient_attention_backend('xformers')
set_efficient_attention_backend('torch')
B, T, H, C = 8, 75, 16, 128

rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
Expand All @@ -48,7 +48,7 @@ def test_rope_io_dtypes():


def test_transformer_with_rope():
set_efficient_attention_backend('xformers')
set_efficient_attention_backend('torch')
torch.manual_seed(1234)
for pos in ['rope', 'sin_rope']:
tr = StreamingTransformer(
Expand All @@ -64,7 +64,7 @@ def test_transformer_with_rope():

@torch.no_grad()
def test_rope_streaming():
set_efficient_attention_backend('xformers')
set_efficient_attention_backend('torch')
torch.manual_seed(1234)
tr = StreamingTransformer(
16, 4, 2, causal=True, dropout=0.,
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_rope_streaming():

@torch.no_grad()
def test_rope_streaming_past_context():
set_efficient_attention_backend('xformers')
set_efficient_attention_backend('torch')
torch.manual_seed(1234)

for context in [None, 10]:
Expand Down Expand Up @@ -122,7 +122,7 @@ def test_rope_streaming_past_context():


def test_rope_memory_efficient():
set_efficient_attention_backend('xformers')
set_efficient_attention_backend('torch')
torch.manual_seed(1234)
tr = StreamingTransformer(
16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
Expand All @@ -143,7 +143,7 @@ def test_rope_memory_efficient():


def test_rope_with_xpos():
set_efficient_attention_backend('xformers')
set_efficient_attention_backend('torch')
B, T, H, C = 8, 75, 16, 128

rope = RotaryEmbedding(dim=C, xpos=True)
Expand All @@ -156,7 +156,7 @@ def test_rope_with_xpos():


def test_positional_scale():
set_efficient_attention_backend('xformers')
set_efficient_attention_backend('torch')
B, T, H, C = 8, 75, 16, 128

rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
Expand Down
Loading

0 comments on commit 5d8752d

Please sign in to comment.