In [3]:
import torch
import torchaudio
import numpy as np
import sys

from artst.tasks.artst import ArTSTTask
from artst.models.artst import ArTSTTransformerModel
from sentencepiece import SentencePieceProcessor

from fairseq import checkpoint_utils, utils
from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
from fairseq.data import data_utils


# Load the fine-tuned ArTST checkpoint and rebuild the task + model from the
# configuration stored inside it.
# map_location='cpu' keeps the load independent of the device the checkpoint
# was saved from, so the notebook also runs on CPU-only machines.
# NOTE(review): torch.load unpickles arbitrary objects — only load checkpoints
# obtained from a trusted source.
checkpoint = torch.load('.../ArTST-hf/CLARTTS_ArTSTstar_TTS.pt', map_location='cpu')  # path to change
checkpoint['cfg']['task'].t5_task = 't2s' # or "s2t" for asr
checkpoint['cfg']['task'].data = '.../ArTST-hf'  # path to change
task = ArTSTTask.setup_task(checkpoint['cfg']['task'])
task.args.bpe_tokenizer = '.../ArTST-hf/tts_spm.model'  # path to change

model = ArTSTTransformerModel.build_model(checkpoint['cfg']['model'], task)
model.load_state_dict(checkpoint['model'])
# Put the model in inference mode: a freshly built nn.Module is in training
# mode, which would leave its Dropout layers active while generating speech.
# As the last expression of the cell this also displays the module tree.
model.eval()



ArTSTTransformerModel(
  (encoder): TransformerEncoder(
    (dropout_module): FairseqDropout()
    (layers): ModuleList(
      (0-11): 12 x TransformerSentenceEncoderLayer(
        (self_attn): MultiheadAttention(
          (dropout_module): FairseqDropout()
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (dropout1): Dropout(p=0.15, inplace=False)
        (dropout2): Dropout(p=0.15, inplace=False)
        (dropout3): Dropout(p=0.15, inplace=False)
        (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (final_layer_norm): LayerNorm((768,), eps=1e-05, ele

In [53]:
# Tokenizers: the task-level BPE object and a raw SentencePiece processor
# built from the same kind of model file.
bpe = task.build_bpe(task.args)
sp = SentencePieceProcessor(
    model_file='.../ArTST-hf/tts_spm.model',  # path to change
)

# Speaker embedding: read from disk, then converted to a float32 torch tensor.
spkembs = torch.from_numpy(
    get_features_or_waveform('.../CLARTTS_speaker_embedding.npy')  # path to change
).float()


def _collate_frames(
    frames, is_audio_input: bool = False
):
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        out[i, : v.size(0)] = v
    return out

# Encode the input sentence with SentencePiece and wrap the ids in a batch of one.
text = "أدب العلم شرف العلم وفضه"
token_ids = sp.encode(text, out_type=int)
source = torch.Tensor([token_ids]).int()

# Inputs handed to the task's speech generation routine.
net_input = {
    "src_tokens": source,
    "src_lengths": source.size(-1),
    "spkembs": _collate_frames([spkembs], is_audio_input=True),
    "padding_mask": None,
}
sample = {"net_input": net_input}

# generate_speech takes a list of models; the first returned value is what
# gets fed to the vocoder further down.
models = [model]
output, _, attn = task.generate_speech(models, sample["net_input"])

In [66]:
import sys

from vocoders.vocoder_melgan_hf import MelGANGenerator
from vocoders.vocoder_mb_melgan_hf import MBMelGANGenerator
from vocoders.config_other_vocoders import MelGanConfig
from vocoders.vocoder_parallelwavegan_hf import ParallelWaveGANGenerator
from vocoders.config_other_vocoders import ParallelWaveGanConfig
from vocoders.vocoder_style_melgan_hf import StyleMelGANGenerator
from vocoders.config_other_vocoders import StyleMelGanConfig

from transformers import SpeechT5HifiGan

# Reference HiFi-GAN vocoder published under the microsoft/speecht5_hifigan id.
vocoder0 = SpeechT5HifiGan.from_pretrained(
    'microsoft/speecht5_hifigan'
)

# Alternative vocoders, all trained on ClArTTS. Only one vocoder is needed for
# synthesis; the rest are loaded so they can be swapped in for comparison.
vocoder1 = SpeechT5HifiGan.from_pretrained(
    'ArTST/vocoders/hifigan'
)
vocoder2 = MelGANGenerator.from_pretrained(
    'ArTST/vocoders/melgan'
)
vocoder3 = ParallelWaveGANGenerator.from_pretrained(
    'ArTST/vocoders/parallel_wavegan'
)
vocoder4 = StyleMelGANGenerator.from_pretrained(
    'ArTST/vocoders/style_melgan'
)
vocoder5 = MBMelGANGenerator.from_pretrained(
    'ArTST/vocoders/multiband_melgan'
)

In [67]:
# Run the vocoder on the generated features without tracking gradients.
with torch.no_grad():
    gen_audio = vocoder0(output)

# Move the waveform to host memory and scale it to 16-bit integer samples.
# NOTE(review): assumes the vocoder output lies in [-1, 1]; values outside
# that range would overflow the int16 cast — confirm when swapping vocoders.
waveform = gen_audio.cpu().numpy()
speech = (waveform * 32767).astype(np.int16)

In [68]:
from IPython.display import Audio

# Render an inline audio player for the synthesized samples.
# NOTE(review): 16000 is presumably the vocoder's output sampling rate —
# confirm if a different vocoder is used.
Audio(data=speech, rate=16000)