In [1]:
import os

# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import numpy as np
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaDecoderLayer
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers import AddedToken

IGNORE_INDEX = -100

In [3]:
def lengths_to_padding_mask(lens):
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask


def _uniform_assignment(src_lens, tgt_lens):
    tgt_indices = torch.arange(torch.max(tgt_lens)).expand(len(tgt_lens), -1).to(tgt_lens.device)
    ratio = tgt_lens / src_lens
    index_t = (tgt_indices / ratio.view(-1, 1)).long()
    return index_t

class SpeechGeneratorCTC(torch.nn.Module):
    def __init__(self, config, ctc_upsample_factor = 26, unit_vocab_size = 1024):
        super().__init__()
        n_layers, n_dims, n_heads, n_inter_dims = 2,4096,32,11008
        _config = copy.deepcopy(config)
        _config.hidden_size = n_dims
        _config.num_hidden_layers = n_layers
        _config.num_attention_heads = n_heads
        _config.num_key_value_heads = n_heads
        _config.intermediate_size = n_inter_dims
        _config._attn_implementation = "flash_attention_2"
        self.upsample_factor = ctc_upsample_factor
        self.input_proj = nn.Linear(config.hidden_size, n_dims)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)]
        )
        self.unit_vocab_size = unit_vocab_size
        self.output_proj = nn.Linear(n_dims, self.unit_vocab_size + 1)
    
    def upsample(self, reps, tgt_units=None):
        src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device)
        up_lens = src_lens * self.upsample_factor
        if tgt_units is not None:
            tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
            up_lens = torch.max(up_lens, tgt_lens)
        reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True)
        padding_mask = lengths_to_padding_mask(up_lens)
        mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill(
            padding_mask, 0
        )
        copied_reps = torch.gather(
            reps,
            1,
            mapped_inputs.unsqueeze(-1).expand(
                *mapped_inputs.size(), reps.size(-1)
            ),
        )
        copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0)
        position_ids = torch.arange(0, max(up_lens)).unsqueeze(0).expand(len(reps), -1).to(device=copied_reps.device)
        return copied_reps, ~padding_mask, position_ids

    def forward(self, tgt_reps, labels, tgt_units):
        tgt_label_reps = []
        for tgt_rep, label in zip(tgt_reps, labels):
            tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX])
        hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units)
        hidden_states = self.input_proj(hidden_states)
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        ctc_logits = self.output_proj(hidden_states)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        ctc_lens = attention_mask.long().sum(dim=-1)
        ctc_tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1)
        ctc_tgt_mask = ~lengths_to_padding_mask(ctc_tgt_lens)
        ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask)
        ctc_loss = F.ctc_loss(
            ctc_lprobs.transpose(0, 1),
            ctc_tgt_flat,
            ctc_lens,
            ctc_tgt_lens,
            reduction="sum",
            zero_infinity=True,
            blank=self.unit_vocab_size
        )
        ctc_loss /= ctc_tgt_lens.sum().item()
        return ctc_loss
    
    def predict(self, tgt_reps):
        hidden_states, attention_mask, position_ids = self.upsample([tgt_reps])
        hidden_states = self.input_proj(hidden_states)
        for layer in self.layers:
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
            hidden_states = layer_outputs[0]
        ctc_logits = self.output_proj(hidden_states)
        ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32)
        ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size)
        return ctc_pred

In [4]:
class LlamaTTS(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.speech_generator = SpeechGeneratorCTC(self.config)
        
    def forward(self, tgt_units = None, **kwargs):
        return super().forward(**kwargs)

In [5]:
model = LlamaTTS.from_pretrained('HuggingFaceTB/SmolLM2-135M-Instruct',
                                torch_dtype = torch.bfloat16)
_ = model.cuda()

Some weights of LlamaTTS were not initialized from the model checkpoint at HuggingFaceTB/SmolLM2-135M-Instruct and are newly initialized: ['speech_generator.input_proj.bias', 'speech_generator.input_proj.weight', 'speech_generator.layers.0.input_layernorm.weight', 'speech_generator.layers.0.mlp.down_proj.weight', 'speech_generator.layers.0.mlp.gate_proj.weight', 'speech_generator.layers.0.mlp.up_proj.weight', 'speech_generator.layers.0.post_attention_layernorm.weight', 'speech_generator.layers.0.self_attn.k_proj.weight', 'speech_generator.layers.0.self_attn.o_proj.weight', 'speech_generator.layers.0.self_attn.q_proj.weight', 'speech_generator.layers.0.self_attn.v_proj.weight', 'speech_generator.layers.1.input_layernorm.weight', 'speech_generator.layers.1.mlp.down_proj.weight', 'speech_generator.layers.1.mlp.gate_proj.weight', 'speech_generator.layers.1.mlp.up_proj.weight', 'speech_generator.layers.1.post_attention_layernorm.weight', 'speech_generator.layers.1.self_attn.k_proj.weight', 

In [6]:
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceTB/SmolLM2-135M-Instruct')

In [7]:
new = ['<|speaker|>']
new = [AddedToken(t) for t in new]
tokenizer.add_tokens(new)

1

In [8]:
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)

Embedding(49153, 576, padding_idx=2)

In [9]:
import pandas as pd

df = pd.read_parquet('data/train-00000-of-00001.parquet').to_dict(orient = 'records')
len(df)

360298

In [10]:
row = df[0]
row

{'transcription': 'Sedangkan dalam bahasa Perancis , frira hanya bererti menggoreng di dalam minyak goreng yang banyak hingga terendam .',
 'speaker': 'Osman',
 'speaker_id': 1,
 'gender': 'male',
 'utterance_pitch_mean': 140.82264709472656,
 'utterance_pitch_std': 37.72042465209961,
 'snr': 69.54813385009766,
 'c50': 55.92512130737305,
 'speech_duration': 6.648750000000001,
 'stoi': 0.9943549633026123,
 'si-sdr': 16.59736442565918,
 'pesq': 3.5911829471588135,
 'pitch': 'slightly high pitch',
 'speaking_rate': 'very slowly',
 'noise': 'very clear',
 'reverberation': 'very confined sounding',
 'speech_monotony': 'very monotone',
 'prompt': 'Osman, a male speaker with a moderately high-pitched voice delivers an animated and expressive speech in a confined room with very clear recording. His voice is very monotone, and he speaks very slowly.',
 'audio_filename': 'combine-audio/0.mp3'}

In [16]:
speaker = f"<|speaker|>{row['speaker']}<|speaker|>"
len_speaker_token = len(tokenizer.tokenize(speaker))
prompt = f"{speaker}{row['transcription']}"
input_ids = tokenizer(prompt, add_special_tokens = False, return_tensors = 'pt').to('cuda')
input_ids

{'input_ids': tensor([[49152, 11062,  1483, 49152,    67,   277,   604, 24184,   287, 20462,
           278,  1287, 14852,  3017,  1148,   271,  3297,  1669,   317,   294,
         37430, 39802,  9573,    89,  1800,  1057,   390,   863,   801,   287,
         20462,  1079,   105,   494,   310,   390,   863, 33856,   278,  1111,
           494, 27427,  8662,   252,   518,   268,   332,  1673]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')}

In [12]:
splitted = row['audio_filename'].split('/')
new_f = '/'.join([splitted[0] + '_vqgan'] + splitted[1:]).replace('.mp3', '.npy')
speech_token = np.load(new_f)
speech_token.shape

(1224,)

In [17]:
tgt_units = torch.tensor([speech_token]).to('cuda')

In [18]:
o = model(**input_ids, output_hidden_states = True, tgt_units = tgt_units)

In [19]:
o

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 15.6250,   6.7812,  10.1875,  ...,  11.5625,   3.0938,  -0.6523],
         [ 18.3750,  11.1875,  15.3750,  ...,  15.6250,   8.5000,  -0.2578],
         [  4.6250, -18.0000, -15.7500,  ...,  -3.1406, -14.8125,   0.6172],
         ...,
         [ 19.3750,  -0.9219,  -0.0552,  ...,  12.3125,   9.8125,  -0.8047],
         [  9.8125,  -5.8125,  -5.3750,  ...,   8.1875,  -2.8594,  -1.0000],
         [  8.1250,   5.7188,  10.0000,  ...,   7.8750,  -1.6172,   1.0859]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[ 0.4355, -0.5117, -0.0317,  ..., -0.2715, -0.0140,  0.4023],
          [-0.9375,  0.1650, -0.4316,  ..., -0.6445,  2.7656,  0.1089],
          [ 1.8750,  0.3926, -0.4883,  ..., -0.3770,  2.1250, -0.0923],
          ...,
          [-0.6406,  0.2910,  0.2988,  ..., -0.5352,  2.2969,  0.1279],
          [ 1.0312,  0.2490,  0.3945,  ...,  0.1216,  1.6641,  0.2891],
          

In [21]:
hidden_states, attention_mask, position_ids = model.speech_generator.upsample(o.hidden_states[-1], tgt_units = tgt_units)

In [26]:
input_ids['input_ids']

tensor([[49152, 11062,  1483, 49152,    67,   277,   604, 24184,   287, 20462,
           278,  1287, 14852,  3017,  1148,   271,  3297,  1669,   317,   294,
         37430, 39802,  9573,    89,  1800,  1057,   390,   863,   801,   287,
         20462,  1079,   105,   494,   310,   390,   863, 33856,   278,  1111,
           494, 27427,  8662,   252,   518,   268,   332,  1673]],
       device='cuda:0')

In [28]:
model.speech_generator(tgt_reps = o.hidden_states[-1], labels = input_ids['input_ids'], tgt_units = tgt_units)

tensor(252.9116, device='cuda:0', grad_fn=<DivBackward0>)

In [None]:
ctc_loss