-
Notifications
You must be signed in to change notification settings - Fork 3
/
fad_gen.py
124 lines (104 loc) · 4.96 KB
/
fad_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_fad_gen.ipynb.
# %% auto 0
__all__ = ['gen', 'main']
# %% ../nbs/01_fad_gen.ipynb 6
import os
import argparse
from accelerate import Accelerator
import warnings
import torch
from aeiou.core import get_device, load_audio, get_audio_filenames, makedir
from aeiou.datasets import get_wds_loader, AudioDataset
from aeiou.hpc import HostPrinter
from pathlib import Path
#from audio_algebra.given_models import StackedDiffAEWrapper
import ast
import torchaudio
from tqdm.auto import tqdm
import math
# %% ../nbs/01_fad_gen.ipynb 7
def _build_dataloader(args, batch_size, hprint):
    """Build the loader of real audio described by ``args.data_sources``.

    Returns ``(dl, local_data)``. ``local_data`` is False when the sources contain
    ``s3://`` (read with aeiou's ``get_wds_loader``). It is True when they are
    local directories (read with aeiou's ``AudioDataset`` in a torch ``DataLoader``).
    """
    # Split on any whitespace so repeated spaces don't produce empty source names.
    names = args.data_sources.split()
    if 's3://' in args.data_sources:
        hprint("Data sources are on S3")
        profiles = args.profiles
        if isinstance(profiles, str):
            # CLI passes a dict literal as a string; '' (the CLI default) means "none given".
            # ast.literal_eval only accepts Python literals, so this is safe on user input.
            profiles = ast.literal_eval(profiles) if profiles.strip() else None
        profiles = profiles or {}  # pass an empty mapping rather than None/''
        hprint(f"profiles = {profiles}")
        dl = get_wds_loader(
            batch_size=batch_size,
            s3_url_prefix=None,
            sample_size=args.sample_size,
            names=names,
            sample_rate=args.sample_rate,
            num_workers=args.num_workers,
            recursive=True,
            random_crop=True,
            epoch_steps=10000,
            profiles=profiles,
        )
        return dl, False
    hprint("Data sources are local")
    dataset = AudioDataset(names, sample_rate=args.sample_rate, sample_size=args.sample_size)
    dl = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=args.num_workers)
    return dl, True


def _load_model(model_ckpt, hprint):
    """Load a TorchScript model, run its optional ``setup()`` hook, and set eval mode.

    Raises:
        ValueError: if ``model_ckpt`` does not end in ``.ts``.
    """
    hprint(f"loading {model_ckpt}....")
    if not model_ckpt.endswith('.ts'):
        # The former StackedDiffAEWrapper fallback is disabled, so only TorchScript
        # checkpoints can be loaded; fail clearly instead of with a NameError later.
        raise ValueError(f"Unsupported checkpoint '{model_ckpt}': expected a TorchScript (.ts) file")
    model = torch.jit.load(model_ckpt)
    # Only some exported models define setup(). Call it when present, but let a
    # failing setup() propagate instead of silently swallowing it.
    setup = getattr(model, 'setup', None)
    if callable(setup):
        setup()
    model.eval()
    return model


def gen(args):
    """Write paired real / model-output ("fake") audio clips, e.g. for FAD evaluation.

    Reads real audio batches from ``args.data_sources`` (S3 resources or local
    directories) and runs each batch through the TorchScript model at
    ``args.model_ckpt``. Each real clip and the corresponding model output are
    saved as ``<args.name>_reals/<id>.wav`` and ``<args.name>_fakes/<id>.wav``,
    where ``<id>`` is ``<batch>_<item>``. With more than one process, ``<id>``
    gets a ``<process_index>_`` prefix so ranks don't overwrite each other.

    Each process stops after ``ceil(args.n / args.batch_size)`` batches, or
    earlier if its data runs out.

    Args:
        args: namespace with ``name``, ``model_ckpt``, ``data_sources``,
            ``profiles``, ``n``, ``batch_size``, ``num_workers``, ``sample_rate``,
            ``sample_size`` and ``debug`` (as produced by ``main``'s parser).

    Raises:
        ValueError: if ``args.model_ckpt`` is not a ``.ts`` TorchScript file.
    """
    # HPC / parallel setup
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddps = f"[{local_rank}/{world_size}]"  # string for distributed computing info, e.g. "[1/8]"
    accelerator = Accelerator()
    hprint = HostPrinter(accelerator)  # hprint only prints on head node
    device = accelerator.device
    hprint(f"gen: args = {args}")
    hprint(f'{ddps} Using device: {device}')

    # The CLI may hand us a string if --batch_size was parsed without type=int.
    batch_size = int(args.batch_size)

    dl, local_data = _build_dataloader(args, batch_size, hprint)
    model = _load_model(args.model_ckpt, hprint).to(device)
    model, dl = accelerator.prepare(model, dl)  # prepare handles distributing things among GPUs

    reals_path, fakes_path = f"{args.name}_reals", f"{args.name}_fakes"
    makedir(reals_path)
    makedir(fakes_path)

    # Every rank enumerates its own shard starting at i=0, so without a per-process
    # prefix, multi-process runs would write the same file names into shared dirs.
    prefix = f"{accelerator.process_index}_" if accelerator.num_processes > 1 else ""
    n_batches = math.ceil(args.n / batch_size)
    inference_model = accelerator.unwrap_model(model)  # loop-invariant; unwrap once
    progress_bar = tqdm(total=n_batches, disable=not accelerator.is_local_main_process)
    n_saved = 0
    try:
        for i, data in enumerate(dl):
            # NOTE(review): the S3 (webdataset) batches are nested; data[0][0] is assumed
            # to be the (batch, channels, samples) audio tensor. Confirm against aeiou.
            reals = data if local_data else data[0][0]
            if args.debug:
                hprint(f"{ddps} i = {i}, reals.shape = {reals.shape}")
            with torch.no_grad():
                fakes = inference_model.forward(reals.to(device)).cpu()
            for b in range(reals.shape[0]):
                torchaudio.save(f"{reals_path}/{prefix}{i}_{b}.wav", reals[b].cpu(), args.sample_rate)
                torchaudio.save(f"{fakes_path}/{prefix}{i}_{b}.wav", fakes[b].cpu(), args.sample_rate)
            n_saved += reals.shape[0]
            progress_bar.update(1)
            # (i+1) >= ceil(n/bs)  <=>  (i+1)*bs >= n : stop once n pairs are covered,
            # without fetching/writing an extra batch when bs divides n.
            if i + 1 >= n_batches:
                hprint(f"\nGot all the data we needed: {n_saved}. Stopping")
                break
    finally:
        progress_bar.close()
    hprint("Success!")
# %% ../nbs/01_fad_gen.ipynb 8
def main():
    """Parse command-line arguments and run :func:`gen` with them."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('name', help='Name prefix for output directories: <name>_reals/ & <name>_fakes/')
    parser.add_argument('model_ckpt', help='TorchScript (.ts) (Generative) Model checkpoint file')
    parser.add_argument('data_sources', help='Space-separated string listing either S3 resources or local directories (but not a mix of both!) for real data')
    parser.add_argument('-d', '--debug', action="store_true", help='Enable extra debugging messages')
    # type=int is required: without it any value given on the command line is a str,
    # which breaks the batch arithmetic and the DataLoader in gen().
    parser.add_argument('-b', "--batch_size", type=int, default=2, help='batch size')
    parser.add_argument('--n', type=int, default=256, help='Number of real/fake samples to grab/generate, respectively')
    parser.add_argument('--num_workers', type=int, default=12, help='Number of pytorch workers to use in DataLoader')
    parser.add_argument('-p', "--profiles", default='', help='String representation of dict {resource:profile} of AWS credentials')
    parser.add_argument('--sample_rate', type=int, default=48000, help='sample rate (will resample inputs at this rate)')
    parser.add_argument('-s', '--sample_size', type=int, default=2**18, help='Number of samples per clip')
    args = parser.parse_args()
    gen(args)
# %% ../nbs/01_fad_gen.ipynb 9
if __name__ == '__main__':
    # Only act as a CLI when run as a script. If `get_ipython` is visible in the
    # module namespace (an IPython/Jupyter session), don't parse sys.argv.
    if "get_ipython" not in dir():
        main()