-
Notifications
You must be signed in to change notification settings - Fork 25.4k
/
configuration_jukebox.py
614 lines (552 loc) 路 26.4 KB
/
configuration_jukebox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
# coding=utf-8
# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Jukebox configuration"""
import os
from typing import List, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
JUKEBOX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"openai/jukebox-5b-lyrics": "https://huggingface.co/openai/jukebox-5b-lyrics/blob/main/config.json",
"openai/jukebox-1b-lyrics": "https://huggingface.co/openai/jukebox-1b-lyrics/blob/main/config.json",
}
_LARGE_ATTENTION = [
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"block_attn",
"transpose_block_attn",
"prev_block_attn",
"cross_attention",
]
_RawColumnPreviousRowAttention = ["block_attn", "transpose_block_attn", "prev_block_attn"]
_FullDenseAttention = ["dense_attention"]
_PrimePrimeDenseAttention = ["prime_attn", "prime_attn", "dense_attn"]
def full_dense_attention(layer):
return _FullDenseAttention[0]
def raw_column_previous_row_attention(layer):
return _RawColumnPreviousRowAttention[layer % 3]
def large_separated_enc_dec_w_lyrics(layer):
return _LARGE_ATTENTION[layer % 79]
def enc_dec_with_lyrics(layer):
if layer % 16 == 15:
return _PrimePrimeDenseAttention[layer % 3]
return _RawColumnPreviousRowAttention[layer % 3]
ATTENTION_PATTERNS = {
"full_dense_attention": full_dense_attention,
"raw_column_previous_row_attention": raw_column_previous_row_attention, # Alternate row, column and previous row attn
"large_separated_enc_dec_w_lyrics": large_separated_enc_dec_w_lyrics, # Used by large separated_enc_dec model with lyrics
"enc_dec_with_lyrics": enc_dec_with_lyrics, # Used by encoder_decoder model with lyrics
}
class JukeboxPriorConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`JukeboxPrior`]. It is used to instantiate a
`JukeboxPrior` according to the specified arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar configuration to that of the top level prior from the
[openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox
-1b-lyrics) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
act_fn (`str`, *optional*, defaults to `"quick_gelu"`):
Activation function.
alignment_head (`int`, *optional*, defaults to 2):
Head that is responsible of the alignment between lyrics and music. Only used to compute the lyric to audio
alignment
alignment_layer (`int`, *optional*, defaults to 68):
Index of the layer that is responsible of the alignment between lyrics and music. Only used to compute the
lyric to audio alignment
attention_multiplier (`float`, *optional*, defaults to 0.25):
Multiplier coefficient used to define the hidden dimension of the attention layers. 0.25 means that
0.25*width of the model will be used.
attention_pattern (`str`, *optional*, defaults to `"enc_dec_with_lyrics"`):
Which attention pattern to use for the decoder/
attn_dropout (`int`, *optional*, defaults to 0):
Dropout probability for the post-attention layer dropout in the decoder.
attn_res_scale (`bool`, *optional*, defaults to `False`):
Whether or not to scale the residuals in the attention conditioner block.
blocks (`int`, *optional*, defaults to 64):
Number of blocks used in the `block_attn`. A sequence of length seq_len is factored as `[blocks, seq_len //
blocks]` in the `JukeboxAttention` layer.
conv_res_scale (`int`, *optional*):
Whether or not to scale the residuals in the conditioner block. Since the top level prior does not have a
conditioner, the default value is to None and should not be modified.
num_layers (`int`, *optional*, defaults to 72):
Number of layers of the transformer architecture.
emb_dropout (`int`, *optional*, defaults to 0):
Embedding dropout used in the lyric decoder.
encoder_config (`JukeboxPriorConfig`, *optional*) :
Configuration of the encoder which models the prior on the lyrics.
encoder_loss_fraction (`float`, *optional*, defaults to 0.4):
Multiplication factor used in front of the lyric encoder loss.
hidden_size (`int`, *optional*, defaults to 2048):
Hidden dimension of the attention layers.
init_scale (`float`, *optional*, defaults to 0.2):
Initialization scales for the prior modules.
is_encoder_decoder (`bool`, *optional*, defaults to `True`):
Whether or not the prior is an encoder-decoder model. In case it is not, and `nb_relevant_lyric_tokens` is
greater than 0, the `encoder` args should be specified for the lyric encoding.
mask (`bool`, *optional*, defaults to `False`):
Whether or not to mask the previous positions in the attention.
max_duration (`int`, *optional*, defaults to 600):
Maximum supported duration of the generated song in seconds.
max_nb_genres (`int`, *optional*, defaults to 1):
Maximum number of genres that can be used to condition the model.
merged_decoder (`bool`, *optional*, defaults to `True`):
Whether or not the decoder and the encoder inputs are merged. This is used for the separated
encoder-decoder architecture
metadata_conditioning (`bool`, *optional*, defaults to `True)`:
Whether or not to condition on the artist and genre metadata.
metadata_dims (`List[int]`, *optional*, defaults to `[604, 7898]`):
Number of genres and the number of artists that were used to train the embedding layers of the prior
models.
min_duration (`int`, *optional*, defaults to 0):
Minimum duration of the generated audio on which the model was trained.
mlp_multiplier (`float`, *optional*, defaults to 1.0):
Multiplier coefficient used to define the hidden dimension of the MLP layers. 0.25 means that 0.25*width of
the model will be used.
music_vocab_size (`int`, *optional*, defaults to 2048):
Number of different music tokens. Should be similar to the `JukeboxVQVAEConfig.nb_discrete_codes`.
n_ctx (`int`, *optional*, defaults to 6144):
Number of context tokens for each prior. The context tokens are the music tokens that are attended to when
generating music tokens.
n_heads (`int`, *optional*, defaults to 2):
Number of attention heads.
nb_relevant_lyric_tokens (`int`, *optional*, defaults to 384):
Number of lyric tokens that are used when sampling a single window of length `n_ctx`
res_conv_depth (`int`, *optional*, defaults to 3):
Depth of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the
`JukeboxMusicTokenConditioner`.
res_conv_width (`int`, *optional*, defaults to 128):
Width of the `JukeboxDecoderConvBock` used to upsample the previously sampled audio in the
`JukeboxMusicTokenConditioner`.
res_convolution_multiplier (`int`, *optional*, defaults to 1):
Multiplier used to scale the `hidden_dim` of the `JukeboxResConv1DBlock`.
res_dilation_cycle (`int`, *optional*):
Dilation cycle used to define the `JukeboxMusicTokenConditioner`. Usually similar to the ones used in the
corresponding level of the VQVAE. The first prior does not use it as it is not conditioned on upper level
tokens.
res_dilation_growth_rate (`int`, *optional*, defaults to 1):
Dilation grow rate used between each convolutionnal block of the `JukeboxMusicTokenConditioner`
res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
Downsampling rates used in the audio conditioning network
res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
Striding used in the audio conditioning network
resid_dropout (`int`, *optional*, defaults to 0):
Residual dropout used in the attention pattern.
sampling_rate (`int`, *optional*, defaults to 44100):
Sampling rate used for training.
spread (`int`, *optional*):
Spread used in the `summary_spread_attention` pattern
timing_dims (`int`, *optional*, defaults to 64):
Dimension of the timing embedding.
zero_out (`bool`, *optional*, defaults to `False`):
Whether or not to zero out convolution weights when initializing.
"""
model_type = "jukebox_prior"
attribute_map = {
"max_position_embeddings": "n_positions",
"num_attention_heads": "n_head",
}
def __init__(
self,
act_fn="quick_gelu",
level=0,
alignment_head=2,
alignment_layer=68,
attention_multiplier=0.25,
attention_pattern="enc_dec_with_lyrics",
attn_dropout=0,
attn_res_scale=False,
blocks=64,
conv_res_scale=None,
num_layers=72,
emb_dropout=0,
encoder_config=None,
encoder_loss_fraction=0.4,
hidden_size=2048,
init_scale=0.2,
is_encoder_decoder=True,
lyric_vocab_size=80,
mask=False,
max_duration=600,
max_nb_genres=1,
merged_decoder=True,
metadata_conditioning=True,
metadata_dims=[604, 7898],
min_duration=0,
mlp_multiplier=1.0,
music_vocab_size=2048,
n_ctx=6144,
n_heads=2,
nb_relevant_lyric_tokens=384,
res_conv_depth=3,
res_conv_width=128,
res_convolution_multiplier=1,
res_dilation_cycle=None,
res_dilation_growth_rate=1,
res_downs_t=[3, 2, 2],
res_strides_t=[2, 2, 2],
resid_dropout=0,
sampling_rate=44100,
spread=None,
timing_dims=64,
zero_out=False,
**kwargs,
):
self.act_fn = act_fn
self.alignment_head = alignment_head
self.alignment_layer = alignment_layer
self.attention_multiplier = attention_multiplier
self.attention_pattern = attention_pattern
self.attn_dropout = attn_dropout
self.attn_res_scale = attn_res_scale
self.blocks = blocks
self.conv_res_scale = conv_res_scale
self.num_layers = num_layers
self.emb_dropout = emb_dropout
self.music_vocab_size = music_vocab_size
if encoder_config is not None:
self.encoder_config = JukeboxPriorConfig(**encoder_config)
else:
self.encoder_config = None
self.encoder_loss_fraction = encoder_loss_fraction
self.init_scale = init_scale
self.is_encoder_decoder = is_encoder_decoder
self.lyric_vocab_size = lyric_vocab_size
self.level = level
self.mask = mask
self.max_duration = max_duration
self.max_nb_genres = max_nb_genres
self.merged_decoder = merged_decoder
self.metadata_conditioning = metadata_conditioning
self.metadata_dims = metadata_dims
self.min_duration = min_duration
self.mlp_multiplier = mlp_multiplier
self.n_ctx = n_ctx
self.n_heads = n_heads
self.nb_relevant_lyric_tokens = nb_relevant_lyric_tokens
self.res_conv_depth = res_conv_depth
self.res_conv_width = res_conv_width
self.res_convolution_multiplier = res_convolution_multiplier
self.res_dilation_cycle = res_dilation_cycle
self.res_dilation_growth_rate = res_dilation_growth_rate
self.res_downs_t = res_downs_t
self.res_strides_t = res_strides_t
self.resid_dropout = resid_dropout
self.sampling_rate = sampling_rate
self.spread = spread
self.timing_dims = timing_dims
self.hidden_size = hidden_size
self.zero_out = zero_out
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], level=0, **kwargs
) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the prior config dict if we are loading from JukeboxConfig
if config_dict.get("model_type") == "jukebox":
config_dict = config_dict[f"prior_{level}"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class JukeboxVQVAEConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`JukeboxVQVAE`]. It is used to instantiate a
`JukeboxVQVAE` according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the VQVAE from
[openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
act_fn (`str`, *optional*, defaults to `"relu"`):
Activation function of the model.
nb_discrete_codes (`int`, *optional*, defaults to 2048):
Number of codes of the VQVAE.
commit (`float`, *optional*, defaults to 0.02):
Commit loss multiplier.
conv_input_shape (`int`, *optional*, defaults to 1):
Number of audio channels.
conv_res_scale (`bool`, *optional*, defaults to `False`):
Whether or not to scale the residuals of the `JukeboxResConv1DBlock`.
embed_dim (`int`, *optional*, defaults to 64):
Embedding dimension of the codebook vectors.
hop_fraction (`List[int]`, *optional*, defaults to `[0.125, 0.5, 0.5]`):
Fraction of non-intersecting window used when continuing the sampling process.
levels (`int`, *optional*, defaults to 3):
Number of hierarchical levels that used in the VQVAE.
lmu (`float`, *optional*, defaults to 0.99):
Used in the codebook update, exponential moving average coefficient. For more detail refer to Appendix A.1
of the original [VQVAE paper](https://arxiv.org/pdf/1711.00937v2.pdf)
multipliers (`List[int]`, *optional*, defaults to `[2, 1, 1]`):
Depth and width multipliers used for each level. Used on the `res_conv_width` and `res_conv_depth`
res_conv_depth (`int`, *optional*, defaults to 4):
Depth of the encoder and decoder block. If no `multipliers` are used, this is the same for each level.
res_conv_width (`int`, *optional*, defaults to 32):
Width of the encoder and decoder block. If no `multipliers` are used, this is the same for each level.
res_convolution_multiplier (`int`, *optional*, defaults to 1):
Scaling factor of the hidden dimension used in the `JukeboxResConv1DBlock`.
res_dilation_cycle (`int`, *optional*):
Dilation cycle value used in the `JukeboxResnet`. If an int is used, each new Conv1 block will have a depth
reduced by a power of `res_dilation_cycle`.
res_dilation_growth_rate (`int`, *optional*, defaults to 3):
Resnet dilation growth rate used in the VQVAE (dilation_growth_rate ** depth)
res_downs_t (`List[int]`, *optional*, defaults to `[3, 2, 2]`):
Downsampling rate for each level of the hierarchical VQ-VAE.
res_strides_t (`List[int]`, *optional*, defaults to `[2, 2, 2]`):
Stride used for each level of the hierarchical VQ-VAE.
sample_length (`int`, *optional*, defaults to 1058304):
Provides the max input shape of the VQVAE. Is used to compute the input shape of each level.
init_scale (`float`, *optional*, defaults to 0.2):
Initialization scale.
zero_out (`bool`, *optional*, defaults to `False`):
Whether or not to zero out convolution weights when initializing.
"""
model_type = "jukebox_vqvae"
def __init__(
self,
act_fn="relu",
nb_discrete_codes=2048,
commit=0.02,
conv_input_shape=1,
conv_res_scale=False,
embed_dim=64,
hop_fraction=[0.125, 0.5, 0.5],
levels=3,
lmu=0.99,
multipliers=[2, 1, 1],
res_conv_depth=4,
res_conv_width=32,
res_convolution_multiplier=1,
res_dilation_cycle=None,
res_dilation_growth_rate=3,
res_downs_t=[3, 2, 2],
res_strides_t=[2, 2, 2],
sample_length=1058304,
init_scale=0.2,
zero_out=False,
**kwargs,
):
self.hop_fraction = hop_fraction
self.conv_input_shape = conv_input_shape
self.sample_length = sample_length
# VQVAE parameters (all used)
self.levels = levels
self.embed_dim = embed_dim
self.nb_discrete_codes = nb_discrete_codes
self.res_conv_width = res_conv_width
self.res_conv_depth = res_conv_depth
self.res_convolution_multiplier = res_convolution_multiplier
self.res_dilation_growth_rate = res_dilation_growth_rate
self.res_dilation_cycle = res_dilation_cycle
self.multipliers = multipliers
self.res_downs_t = res_downs_t
self.res_strides_t = res_strides_t
self.lmu = lmu
self.commit = commit
self.conv_res_scale = conv_res_scale
self.act_fn = act_fn
self.init_scale = init_scale
self.zero_out = zero_out
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the text config dict if we are loading from CLIPConfig
if config_dict.get("model_type") == "jukebox":
config_dict = config_dict["vqvae_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class JukeboxConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`JukeboxModel`].
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information. Instantiating a configuration with the defaults will
yield a similar configuration to that of
[openai/jukebox-1b-lyrics](https://huggingface.co/openai/jukebox-1b-lyrics) architecture.
The downsampling and stride are used to determine downsampling of the input sequence. For example, downsampling =
(5,3), and strides = (2, 2) will downsample the audio by 2^5 = 32 to get the first level of codes, and 2**8 = 256
to get the second level codes. This is mostly true for training the top level prior and the upsamplers.
Args:
vqvae_config (`JukeboxVQVAEConfig`, *optional*):
Configuration for the `JukeboxVQVAE` model.
prior_config_list (`List[JukeboxPriorConfig]`, *optional*):
List of the configs for each of the `JukeboxPrior` of the model. The original architecture uses 3 priors.
nb_priors (`int`, *optional*, defaults to 3):
Number of prior models that will sequentially sample tokens. Each prior is conditional auto regressive
(decoder) model, apart from the top prior, which can include a lyric encoder. The available models were
trained using a top prior and 2 upsampler priors.
sampling_rate (`int`, *optional*, defaults to 44100):
Sampling rate of the raw audio.
timing_dims (`int`, *optional*, defaults to 64):
Dimensions of the JukeboxRangeEmbedding layer which is equivalent to traditional positional embedding
layer. The timing embedding layer converts the absolute and relative position in the currently sampled
audio to a tensor of length `timing_dims` that will be added to the music tokens.
min_duration (`int`, *optional*, defaults to 0):
Minimum duration of the audios to generate
max_duration (`float`, *optional*, defaults to 600.0):
Maximum duration of the audios to generate
max_nb_genres (`int`, *optional*, defaults to 5):
Maximum number of genres that can be used to condition a single sample.
metadata_conditioning (`bool`, *optional*, defaults to `True`):
Whether or not to use metadata conditioning, corresponding to the artist, the genre and the min/maximum
duration.
Example:
```python
>>> from transformers import JukeboxModel, JukeboxConfig
>>> # Initializing a Jukebox configuration
>>> configuration = JukeboxConfig()
>>> # Initializing a model from the configuration
>>> model = JukeboxModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type = "jukebox"
def __init__(
self,
vqvae_config=None,
prior_config_list=None,
nb_priors=3,
sampling_rate=44100,
timing_dims=64,
min_duration=0,
max_duration=600.0,
max_nb_genres=5,
metadata_conditioning=True,
**kwargs,
):
if vqvae_config is None:
vqvae_config = {}
logger.info("vqvae_config is None. initializing the JukeboxVQVAE with default values.")
self.vqvae_config = JukeboxVQVAEConfig(**vqvae_config)
if prior_config_list is not None:
self.prior_configs = [JukeboxPriorConfig(**prior_config) for prior_config in prior_config_list]
else:
self.prior_configs = []
for prior_idx in range(nb_priors):
prior_config = kwargs.pop(f"prior_{prior_idx}", None)
if prior_config is None:
prior_config = {}
logger.info(
f"prior_{prior_idx}'s config is None. Initializing the JukeboxPriorConfig list with default"
" values."
)
self.prior_configs.append(JukeboxPriorConfig(**prior_config))
self.hop_fraction = self.vqvae_config.hop_fraction
self.nb_priors = nb_priors
# Metadata conditioning
self.max_nb_genres = max_nb_genres
self.sampling_rate = sampling_rate
self.timing_dims = timing_dims
self.min_duration = min_duration
self.max_duration = max_duration
self.metadata_conditioning = metadata_conditioning
super().__init__(**kwargs)
@classmethod
def from_configs(cls, prior_configs: List[JukeboxPriorConfig], vqvae_config: JukeboxVQVAEConfig, **kwargs):
r"""
Instantiate a [`JukeboxConfig`] (or a derived class) from clip text model configuration and clip vision model
configuration.
Returns:
[`JukeboxConfig`]: An instance of a configuration object
"""
prior_config_list = [config.to_dict() for config in prior_configs]
return cls(prior_config_list=prior_config_list, vqvae_config_dict=vqvae_config.to_dict(), **kwargs)
def to_dict(self):
# Override the default to_dict to apply to_dict to the list of prior configs.
result = super().to_dict()
result["prior_config_list"] = [config.to_dict() for config in result.pop("prior_configs")]
return result