In [None]:
!pip install audiolm-pytorch x-clip tqdm grapheme

In [None]:
import os
import re
import urllib
import urllib.request
from pathlib import Path

from tqdm import tqdm

from audiolm_pytorch.data import SoundDataset
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans

from spear_tts_pytorch.spear_tts_pytorch import TextToSemantic
from spear_tts_pytorch.data import SemanticDataset, SemanticPhonemeDataset
from spear_tts_pytorch.trainer import SpeechSpeechPretrainer, SemanticToTextTrainer

import grapheme

import torch
import torchaudio

# Sample rate (Hz) handed to HubertWithKmeans and to both datasets below; it is also
# used to convert data_max_length_seconds into a number of samples.
target_sample_hz = 24_000

In [None]:
# Fetch the HuBERT checkpoint and its k-means quantizer (only if missing) and wrap them in
# HubertWithKmeans, which the next cells use as `wav2vec` to write `.semantic.pt` files.
hubert_dir = 'hubert'
hubert_ckpt = f'{hubert_dir}/hubert_base_ls960.pt'
hubert_quantizer = f'{hubert_dir}/hubert_base_ls960_L9_km500.bin'
hubert_base_url = 'https://dl.fbaipublicfiles.com'

os.makedirs(hubert_dir, exist_ok = True)

for relative_path in (hubert_ckpt, hubert_quantizer):
    if os.path.isfile(relative_path):
        continue
    # Download under a temporary name and rename only once it completes, so an interrupted
    # download never leaves a truncated file that the isfile() check would then accept.
    partial_path = f'{relative_path}.part'
    urllib.request.urlretrieve(f'{hubert_base_url}/{relative_path}', partial_path)
    os.replace(partial_path, relative_path)

# NOTE(review): HuBERT base (ls960) is, as far as I know, a 16 kHz model, while
# target_sample_hz is 24 kHz here — confirm HubertWithKmeans handles that rate correctly.
wav2vec = HubertWithKmeans(
    checkpoint_path = hubert_ckpt,
    kmeans_path = hubert_quantizer,
    target_sample_hz = target_sample_hz,
)

In [None]:
# Download the LibriTTS subset if needed and build the semantic dataset over its files.

libritts_root = 'data'
libritts_subset = 'dev-clean'

os.makedirs(libritts_root, exist_ok = True)

# Constructing the torchaudio dataset with download = True fetches the archive when the
# subset is not already present under libritts_root.
ds = torchaudio.datasets.LIBRITTS(
    root = libritts_root,
    url = libritts_subset,
    download = True
)

# The subset is extracted to <root>/LibriTTS/<subset>; derive it instead of repeating literals.
dataset_folder = os.path.join(libritts_root, 'LibriTTS', libritts_subset)
data_max_length_seconds = 3

# NOTE(review): the `.semantic.pt` files are only written by the next cell; confirm that
# SemanticDataset reads them lazily rather than scanning for them at construction time.
dataset = SemanticDataset(
    dataset_folder,
    max_length = int(data_max_length_seconds * target_sample_hz),
    target_sample_hz = target_sample_hz,
    seq_len_multiple_of = 320
)

In [None]:
# preprocess with wav2vec
#
# Run every utterance through wav2vec (HubertWithKmeans) once and save the result next to
# its audio as `<name>.semantic.pt`.

wav_files = sorted(str(p) for p in Path(dataset_folder).glob('**/*.wav'))
print(f'Found {len(wav_files)} files to preprocess.')

with torch.inference_mode():
    for wav_file in tqdm(wav_files):
        audio_input, input_sample_hz = torchaudio.load(wav_file)
        # wav2vec was constructed for target_sample_hz; resample any file stored at a
        # different rate instead of silently feeding it audio at the wrong rate.
        if input_sample_hz != target_sample_hz:
            audio_input = torchaudio.functional.resample(audio_input, input_sample_hz, target_sample_hz)
        processed_input = wav2vec(audio_input)
        # with_suffix swaps only the trailing `.wav`; str.replace would also rewrite any
        # `.wav` appearing earlier in the path.
        output_file = str(Path(wav_file).with_suffix('.semantic.pt'))
        torch.save(processed_input, output_file)

In [None]:
# Stage 1: speech-to-speech pretraining of the TextToSemantic encoder-decoder.
# Model sizes: "T5-small from Table 9 in the paper" (SPEAR-TTS).

# NOTE(review): semantic_pad_id = 0 with num_semantic_token_ids = 500 — if the km500
# quantizer emits cluster ids 0..499, id 0 is both a real token and padding; confirm
# this matches how SemanticDataset pads.
# NOTE(review): num_text_token_ids = 70 must exceed the largest grapheme id assigned later
# (ids start at 1) and must match the finetuning model that loads this checkpoint.
pretrain_model_config = dict(
    dim = 256,
    num_text_token_ids = 70,
    text_pad_id = 0,
    num_semantic_token_ids = 500,
    semantic_pad_id = 0,
    source_depth = 6,
    target_depth = 6,
    heads = 6,
    dim_head = 64,
    attn_dropout = 0.5,
    ff_mult = 2,
    ff_dropout = 0.5,
)

pretrain_trainer_config = dict(
    batch_size = 256,
    grad_accum_every = 1,
    initial_lr = 1e-5,
    lr = 1e-3,
    num_train_steps = 100_000,
    num_warmup_steps = 1000,
)

text_to_semantic_model = TextToSemantic(**pretrain_model_config)

trainer = SpeechSpeechPretrainer(
    model = text_to_semantic_model,
    wav2vec = wav2vec,
    dataset = dataset,
    **pretrain_trainer_config,
)
trainer.train()

In [None]:
# preprocess the graphemes
#
# Build a grapheme vocabulary from the normalized transcripts, then save each transcript as
# a LongTensor of grapheme ids next to its audio as `<name>.graphemes.pt`.

def load_clean_caption(wav_file):
    """Return the punctuation-stripped normalized transcript that belongs to `wav_file`.

    Reads `<name>.normalized.txt` as UTF-8 (not the platform default encoding), strips
    trailing whitespace, then removes every character that is neither a word character
    nor whitespace.
    """
    caption_file = Path(wav_file).with_suffix('.normalized.txt')
    caption = caption_file.read_text(encoding = 'utf-8').rstrip()
    return re.sub(r'[^\w\s]', '', caption)

# Read and split each transcript once; both the vocabulary pass and the encoding pass
# below reuse this instead of opening every file twice.
graphemes_per_file = {
    wav_file: list(grapheme.graphemes(load_clean_caption(wav_file)))
    for wav_file in tqdm(wav_files)
}

all_graphemes = set()
for graphemes_list in graphemes_per_file.values():
    all_graphemes.update(graphemes_list)

grapheme_tokens = sorted(all_graphemes)
# Ids start at 1, leaving 0 free for the text_pad_id = 0 used by the TextToSemantic models.
grapheme_token_to_id = {token: i + 1 for i, token in enumerate(grapheme_tokens)}
print(f"Found {len(grapheme_tokens)} grapheme tokens: {grapheme_token_to_id}")

for wav_file, graphemes_list in tqdm(graphemes_per_file.items()):
    grapheme_token_ids = torch.LongTensor([grapheme_token_to_id[token] for token in graphemes_list])
    grapheme_file = str(Path(wav_file).with_suffix('.graphemes.pt'))
    torch.save(grapheme_token_ids, grapheme_file)

In [None]:
# semantic-to-text finetuning for backtranslation
#
# Build a TextToSemantic model with the pretraining sizes (encoder frozen), load the
# speech-to-speech pretrained weights into it, and train it with SemanticToTextTrainer on
# SemanticPhonemeDataset (presumably pairing the `.semantic.pt` / `.graphemes.pt` files above).

# Presumably must equal the pretraining model's value, since its checkpoint is loaded below.
num_text_token_ids = 70

# Grapheme ids run from 1 to len(grapheme_token_to_id), so they must all be < num_text_token_ids.
assert len(grapheme_token_to_id) < num_text_token_ids, (
    f'found {len(grapheme_token_to_id)} grapheme tokens but num_text_token_ids is '
    f'{num_text_token_ids}; raise it here and in the pretraining model'
)

text_to_semantic_model = TextToSemantic(
    dim = 256,
    num_text_token_ids = num_text_token_ids,
    text_pad_id = 0,
    num_semantic_token_ids = 500,
    semantic_pad_id = 0,
    source_depth = 6,
    target_depth = 6,
    heads = 6,
    dim_head = 64,
    attn_dropout = 0.1,
    ff_mult = 2,
    ff_dropout = 0.1,
    freeze_encoder = True
)

dataset = SemanticPhonemeDataset(
    dataset_folder,
    max_length = int(data_max_length_seconds * target_sample_hz),
    target_sample_hz = target_sample_hz,
    seq_len_multiple_of = 320
)

trainer = SemanticToTextTrainer(
    model = text_to_semantic_model,
    dataset = dataset,
    batch_size = 48,
    grad_accum_every = 1,
    initial_lr = 1e-5,
    lr = 1e-3,
    num_train_steps = 10_000,
    num_warmup_steps = 10,
    # keep results/ so the pretraining checkpoints are still there to load
    force_clear_prev_results = False
)

# load the pretrained checkpoint
#
# Set pretrained_checkpoint to an explicit path to pin one; when None, the checkpoint with
# the highest step number is picked automatically (the old `XXXX` placeholder never existed).
pretrained_checkpoint = None

if pretrained_checkpoint is None:
    # NOTE(review): assumes pretraining checkpoints are named speech.speech.<step>.pt inside
    # results/, as the original placeholder path suggests — confirm against the trainer.
    checkpoint_name_pattern = re.compile(r'speech\.speech\.(\d+)\.pt')
    candidates = []
    for checkpoint_path in Path('results').glob('speech.speech.*.pt'):
        name_match = checkpoint_name_pattern.fullmatch(checkpoint_path.name)
        if name_match:
            candidates.append((int(name_match.group(1)), checkpoint_path))
    if not candidates:
        raise FileNotFoundError('no speech.speech.<step>.pt checkpoint found in results/; run pretraining first')
    # compare by step number, not lexicographically (so 10000 beats 9000)
    pretrained_checkpoint = str(max(candidates, key = lambda candidate: candidate[0])[1])

# NOTE(review): the original passed the misspelled keyword `restore_optimzier`; the trainer's
# load() parameter is expected to be `restore_optimizer` — confirm against the installed version.
# False is intended to load the weights without the pretraining optimizer state.
trainer.load(pretrained_checkpoint, restore_optimizer = False)
trainer.train()