In [None]:
!pip install audiolm-pytorch x-clip

In [None]:
import os
import urllib
import urllib.request

from audiolm_pytorch.data import SoundDataset
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
from torchaudio import datasets

from spear_tts_pytorch.spear_tts_pytorch import TextToSemantic
from spear_tts_pytorch.trainer import SpeechSpeechPretrainer


In [None]:
# Fetch the pretrained HuBERT checkpoint and its k-means quantizer from
# dl.fbaipublicfiles.com (only if they are not already on disk), then wrap them
# in HubertWithKmeans, which the rest of the notebook uses as `wav2vec`.
hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = 'hubert/hubert_base_ls960_L9_km500.bin'


def _download_if_missing(relative_path):
    """Download ``https://dl.fbaipublicfiles.com/<relative_path>`` to ``./<relative_path>``.

    Does nothing when the file already exists. The payload is first written to
    ``<relative_path>.part`` and only renamed into place after ``urlretrieve``
    returns, so an interrupted download never leaves a truncated file that the
    ``isfile`` check would later mistake for a complete one.
    """
    if os.path.isfile(relative_path):
        return
    target_dir = os.path.dirname(relative_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    url = f"https://dl.fbaipublicfiles.com/{relative_path}"
    partial_path = f"{relative_path}.part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, relative_path)


os.makedirs("hubert", exist_ok=True)
_download_if_missing(hubert_ckpt)
_download_if_missing(hubert_quantizer)

wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}',
    # Rate of the audio fed in; the dataset cell below loads clips at
    # wav2vec.target_sample_hz, i.e. this same value.
    target_sample_hz = 24_000,
)


In [None]:
# Download LibriTTS (dev-clean) if needed and wrap it in a SoundDataset of
# clips capped at `data_max_length_seconds`, at wav2vec's sample rate.
data_root = 'data'
libritts_subset = 'dev-clean'

os.makedirs(data_root, exist_ok=True)

# torchaudio's LIBRITTS performs the download/extraction into `data_root`.
ds = datasets.LIBRITTS(
    root = data_root,
    url = libritts_subset,
    download = True
)

# Extracted audio lives under <root>/LibriTTS/<subset>; derive the path from the
# same constants used above so changing the subset cannot silently point the
# SoundDataset at a different (stale or missing) folder.
dataset_folder = os.path.join(data_root, 'LibriTTS', libritts_subset)
assert os.path.isdir(dataset_folder), f"expected extracted LibriTTS audio under {dataset_folder!r}"

data_max_length_seconds = 3

dataset = SoundDataset(
    dataset_folder,
    max_length = int(data_max_length_seconds * wav2vec.target_sample_hz),
    target_sample_hz = wav2vec.target_sample_hz,
    seq_len_multiple_of = wav2vec.seq_len_multiple_of
)

In [None]:
# Text-to-semantic seq2seq transformer, sized as "T5-small from Table 9 in the paper".
# NOTE(review): published T5-small uses d_model=512 and d_ff=2048 (dim=512,
# ff_mult=4); dim=256 / ff_mult=2 below are smaller — confirm against Table 9.
text_to_semantic_model = TextToSemantic(
    # vocabularies: semantic tokens come from the HuBERT k-means codebook;
    # 32100 is presumably the T5 tokenizer vocab (32000 + 100 sentinels) — confirm
    wav2vec = wav2vec,
    num_semantic_token_ids = wav2vec.codebook_size,
    num_text_token_ids = 32100,
    # transformer shape
    dim = 256,
    heads = 8,
    dim_head = 64,
    ff_mult = 2,
    source_depth = 6,
    target_depth = 6,
    # regularization
    attn_dropout = 0.5,
    ff_dropout = 0.5,
)

# Speech-to-speech pretraining on the SoundDataset built above.
# grad_accum_every presumably gives an effective batch of 16 * 4 = 64 clips
# per optimizer step — confirm in SpeechSpeechPretrainer.
trainer = SpeechSpeechPretrainer(
    model = text_to_semantic_model,
    wav2vec = wav2vec,
    dataset = dataset,
    num_train_steps = 100_000,
    lr = 2e-4,
    batch_size = 16,
    grad_accum_every = 4,
)

trainer.train()
