-
Notifications
You must be signed in to change notification settings - Fork 15
/
snac.py
97 lines (88 loc) · 3.06 KB
/
snac.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import json
import math
import numpy as np
import torch
from torch import nn
from .layers import Encoder, Decoder
from .vq import ResidualVectorQuantize
class SNAC(nn.Module):
    """Multi-scale neural audio codec.

    Encodes an audio tensor (time on the last axis) into several levels of
    residual vector-quantized codes (one level per entry in ``vq_strides``)
    and decodes them back to audio.  The encoder downsamples by
    ``prod(encoder_rates)`` samples per latent frame.
    """

    def __init__(
        self,
        sampling_rate=44100,
        encoder_dim=64,
        encoder_rates=None,
        latent_dim=None,
        decoder_dim=1536,
        decoder_rates=None,
        attn_window_size=32,
        codebook_size=4096,
        codebook_dim=8,
        vq_strides=None,
        noise=True,
        depthwise=True,
    ):
        """Build the encoder, residual quantizer and decoder.

        ``encoder_rates``/``decoder_rates``/``vq_strides`` default to
        ``[3, 3, 7, 7]`` / ``[7, 7, 3, 3]`` / ``[8, 4, 2, 1]``.
        ``latent_dim`` defaults to ``encoder_dim * 2 ** len(encoder_rates)``
        (the encoder doubles its channel count at each downsampling stage).
        """
        super().__init__()
        # None sentinels replace mutable list defaults so instances never
        # share (and can never mutate) a single class-level default list.
        encoder_rates = [3, 3, 7, 7] if encoder_rates is None else list(encoder_rates)
        decoder_rates = [7, 7, 3, 3] if decoder_rates is None else list(decoder_rates)
        vq_strides = [8, 4, 2, 1] if vq_strides is None else list(vq_strides)
        self.sampling_rate = sampling_rate
        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))
        self.latent_dim = latent_dim
        # Audio samples per latent frame (total downsampling factor).
        # math.prod returns a plain int where np.prod would leak np.int64.
        self.hop_length = math.prod(encoder_rates)
        self.encoder = Encoder(
            encoder_dim,
            encoder_rates,
            depthwise=depthwise,
            attn_window_size=attn_window_size,
        )
        self.n_codebooks = len(vq_strides)
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.vq_strides = vq_strides
        self.attn_window_size = attn_window_size
        self.quantizer = ResidualVectorQuantize(
            input_dim=latent_dim,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            vq_strides=vq_strides,
        )
        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
            noise,
            depthwise=depthwise,
            attn_window_size=attn_window_size,
        )

    def preprocess(self, audio_data):
        """Right-pad ``audio_data`` along the last axis to a model-valid length.

        The padded length is a multiple of
        ``hop_length * lcm(vq_strides[0], attn_window_size or 1)`` so that
        both the coarsest VQ stride and the attention window (when present)
        divide the latent sequence evenly.
        """
        length = audio_data.shape[-1]
        lcm = math.lcm(self.vq_strides[0], self.attn_window_size or 1)
        pad_to = self.hop_length * lcm
        right_pad = math.ceil(length / pad_to) * pad_to - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))
        return audio_data

    def forward(self, audio_data):
        """Encode, quantize and reconstruct ``audio_data``.

        Returns ``(audio_hat, z, codes, commitment_loss, codebook_loss)``,
        with the reconstruction trimmed back to the original input length.
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data)
        z = self.encoder(audio_data)
        z, codes, commitment_loss, codebook_loss = self.quantizer(z)
        x = self.decoder(z)
        return x[..., :length], z, codes, commitment_loss, codebook_loss

    @classmethod
    def from_config(cls, config_path):
        """Instantiate a model from a JSON file of constructor kwargs."""
        with open(config_path, "r") as f:
            config = json.load(f)
        model = cls(**config)
        return model

    @classmethod
    def from_pretrained(cls, repo_id, **kwargs):
        """Download config and weights from the Hugging Face Hub and load them."""
        from huggingface_hub import hf_hub_download

        config_path = hf_hub_download(repo_id=repo_id, filename="config.json", **kwargs)
        model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", **kwargs)
        model = cls.from_config(config_path)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted repos (or pass weights_only=True on
        # torch >= 1.13).
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.eval()
        return model