Speech encoder -> [BS, T, D]
Transcripción -> Tokenizer -> [BS, lentrans] -> Embedding -> [BS, lentrans, D]

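concat(axis=1, eje temporal) -> [BS, T+lentrans, D]
loss(T:) -> la pérdida solo se calcula sobre las posiciones de la transcripción (a partir de T)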

Encoder (audio) -> Decoder (texto)

Solo Decoder -> Audio + Texto

x[:-1] -> LLM -> x[1:]

Agregar un <eos> al final de la transcripción

Dataset -> {'wav': numpy, 'transcription': numpy (ids)}
Dataloader -> [{}, {}] -> collate_fn -> {'wav': tensor con padding, 'wav_lens': tensor lens, 'transcription': tensor con padding, 'transcription_lens': tensor lens}


In [8]:
from transformers import AutoModelForCausalLM, AutoTokenizer, WavLMModel
import pytorch_lightning as pl
import torch
from abc import abstractmethod

In [9]:
class HFLLMModel(torch.nn.Module):
    """Wrap a Hugging Face causal LM and its tokenizer.

    The model is driven with precomputed input embeddings (``inputs_embeds``)
    rather than token ids, so callers can concatenate non-text features
    (e.g. speech frames) with token embeddings before the decoder.
    """

    def __init__(self, hf_path):
        """Load the causal LM and its tokenizer.

        Args:
            hf_path: Hugging Face Hub id or local path of the checkpoint.
        """
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(hf_path)
        self.tokenizer = AutoTokenizer.from_pretrained(hf_path)

    def forward(self, x, attention_mask, **kwargs):
        """Run the LM on input embeddings.

        Args:
            x: input embeddings fed as ``inputs_embeds``.
            attention_mask: mask of valid positions, forwarded unchanged.
            **kwargs: extra keyword arguments for the underlying HF model.

        Returns:
            Whatever the wrapped HF model returns.
        """
        return self.model(inputs_embeds=x, attention_mask=attention_mask, **kwargs)

    def get_lut(self):
        """Return the token-embedding lookup table of the wrapped model.

        Uses the generic ``get_input_embeddings()`` accessor, so it works for
        any HF model without a subclass. Subclasses may still override it.
        Without an ABC metaclass an ``@abstractmethod`` marker would not be
        enforced and would silently yield ``None`` here.
        """
        return self.model.get_input_embeddings()

class WavLM(torch.nn.Module):
    """Speech encoder that returns the hidden states of a single WavLM layer."""

    def __init__(self, hf_path, layer=12):
        """Load a pretrained ``WavLMModel``.

        Args:
            hf_path: Hugging Face Hub id or local path of the checkpoint.
            layer: index into the tuple of hidden states returned by the model
                (0 is the embedding output); negative indices count from the end.
        """
        super().__init__()
        self.model = WavLMModel.from_pretrained(hf_path)
        # Waveform samples per output frame, used by callers to turn sample
        # lengths into frame lengths.
        # NOTE(review): this is an approximation of the conv front end; the
        # real frame count can be off by one — confirm against the model.
        self.downsampling = 320
        self.layer = layer

    def forward(self, x):
        """Encode waveforms ``x`` and return the hidden states of ``self.layer``."""
        hidden_states = self.model(x, output_hidden_states=True)['hidden_states']
        # Index the tuple directly: same values as stacking all layers first,
        # without materialising every layer in one big tensor.
        return hidden_states[self.layer]

class GPTModel(HFLLMModel):
    """HF causal-LM wrapper whose token embeddings live at ``transformer.wte``."""

    def get_lut(self):
        """Return the word-token embedding module of the wrapped model."""
        backbone = self.model.transformer
        return backbone.wte

class LLMASR(pl.LightningModule):
    """Decoder-only ASR model.

    Speech features from ``wav_model`` are prepended to the transcription
    token embeddings, and the LLM is trained with teacher forcing to predict
    the transcription tokens that follow the speech frames.

    Args:
        llm_model: wrapper exposing ``get_lut()`` (token embedding table) and
            ``forward(x, attention_mask)`` over input embeddings.
        wav_model: speech encoder exposing a ``downsampling`` attribute used to
            convert waveform sample counts into frame counts.
    """

    # Label value skipped by torch.nn.functional.cross_entropy.
    IGNORE_INDEX = -100

    def __init__(self, llm_model, wav_model):
        super().__init__()
        self.llm_model = llm_model
        self.llm_model_lut = self.llm_model.get_lut()
        self.wav_model = wav_model

    def prepare_input(self, speech, transcription, speech_lens, transcription_lens):
        """Build the right-padded [speech | transcription] embedding sequences.

        Args:
            speech: padded waveforms, passed as-is to ``wav_model``.
            transcription: padded token ids (assumed [BS, lentrans]).
            speech_lens: number of valid waveform samples per item.
            transcription_lens: number of valid tokens per item.

        Returns:
            ``(x, padding_mask, response_mask)``:
            - ``x`` is [BS, maxlen, D], zero-padded on the right;
            - ``padding_mask`` (bool, [BS, maxlen]) marks valid positions;
            - ``response_mask`` (bool, [BS, maxlen]) marks transcription positions.
        """
        speech_feats = self.wav_model(speech)
        text_embs = self.llm_model_lut(transcription)
        device = speech_feats.device

        # Sample counts -> frame counts. The division only approximates the
        # encoder's output length, so clamp to the frames actually produced;
        # otherwise s[:n] would keep fewer frames than n and response_mask
        # would start after the real beginning of the text.
        frame_lens = torch.as_tensor(speech_lens, device=device) // self.wav_model.downsampling
        frame_lens = frame_lens.clamp(max=speech_feats.shape[1])

        sequences = [
            torch.cat([feats[:n_frames], embs[:n_tokens]], dim=0)
            for feats, n_frames, embs, n_tokens in zip(
                speech_feats, frame_lens, text_embs, transcription_lens
            )
        ]
        xlens = torch.tensor([seq.shape[0] for seq in sequences], device=device)
        maxlen = int(xlens.max())
        x = torch.stack([
            torch.nn.functional.pad(seq, (0, 0, 0, maxlen - seq.shape[0]))
            for seq in sequences
        ])

        positions = torch.arange(maxlen, device=device)[None, :]
        padding_mask = positions < xlens[:, None]
        response_mask = (positions >= frame_lens[:, None]) & padding_mask
        return x, padding_mask, response_mask

    def forward(self, speech, transcription, speech_lens, transcription_lens):
        """Run the LLM with teacher forcing on the concatenated sequence.

        Returns:
            ``(llm_output, target_mask, target_embeddings)``, where:
            - ``llm_output`` is the LLM output for inputs ``x[:, :-1]``;
            - ``target_mask[:, p]`` is True when ``x[:, p + 1]`` is a
              transcription position;
            - ``target_embeddings`` is ``x[:, 1:]``.
        """
        xin, padding_mask, response_mask = self.prepare_input(
            speech, transcription, speech_lens, transcription_lens
        )
        # The input drops the last position, so its attention mask must too.
        llm_output = self.llm_model(xin[:, :-1], attention_mask=padding_mask[:, :-1])
        return llm_output, response_mask[:, 1:], xin[:, 1:]

    @classmethod
    def _build_targets(cls, transcription, transcription_lens, target_mask):
        """Return token-id targets at ``target_mask`` positions, IGNORE_INDEX elsewhere."""
        targets = torch.full(
            target_mask.shape, cls.IGNORE_INDEX, dtype=torch.long, device=target_mask.device
        )
        for row, (tokens, n_tokens, mask_row) in enumerate(
            zip(transcription, transcription_lens, target_mask)
        ):
            n_tokens = int(n_tokens)
            n_targets = int(mask_row.sum())
            # Normally n_targets == n_tokens. If an item has zero speech frames,
            # its first token sits at position 0 and has no predicting position.
            targets[row, mask_row] = tokens[n_tokens - n_targets:n_tokens].to(
                device=targets.device, dtype=targets.dtype
            )
        return targets

    def training_step(self, batch, batch_idx):
        """Next-token cross-entropy computed over transcription tokens only.

        Assumes ``batch`` is the collate_fn dict with keys ``'wav'``,
        ``'wav_lens'``, ``'transcription'`` (token ids) and
        ``'transcription_lens'``.
        """
        llm_output, target_mask, _ = self(
            batch['wav'], batch['transcription'], batch['wav_lens'], batch['transcription_lens']
        )
        targets = self._build_targets(
            batch['transcription'], batch['transcription_lens'], target_mask
        )
        logits = llm_output.logits
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            targets.reshape(-1),
            ignore_index=self.IGNORE_INDEX,
        )
        self.log('train_loss', loss)
        return loss
        

In [10]:
# Text decoder, speech encoder (default layer spelled out), and the ASR model combining them.
llm_model = GPTModel('DeepESP/gpt2-spanish')
wav_model = WavLM('microsoft/wavlm-base-plus', layer=12)
asr_model = LLMASR(llm_model=llm_model, wav_model=wav_model)

Some weights of the model checkpoint at microsoft/wavlm-base-plus were not used when initializing WavLMModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing WavLMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing WavLMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of WavLMModel were not initialized from the model checkpoint at microsoft/wavlm-base-plus and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictio

In [22]:
# Tokenize two sentences of different length to inspect ids and padding.
llm_model.tokenizer(
    ['hola, como andas?.', 'hola'],
    padding=True,
    truncation=False,
    add_special_tokens=True,
)

{'input_ids': [[1468, 334, 21, 420, 50127, 40, 23], [1468, 334, 50256, 50256, 50256, 50256, 50256]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 0, 0, 0, 0, 0]]}

In [16]:
# Id the tokenizer uses for padding (the padded batch in the tokenizer cell is filled with it).
llm_model.tokenizer.pad_token_id

50256

In [162]:
# Smoke test on a random batch: 4 waveforms padded to 32000 samples and
# 4 random token-id sequences padded to 30, with per-item valid lengths.
speech = torch.randn(4, 32000)
transcription = torch.randint(0, 10000, (4, 30))
speech_lens = torch.tensor([32000, 16000, 24000, 12000])
transcription_lens = torch.tensor([30, 10, 15, 20])

out = asr_model(
    speech=speech,
    transcription=transcription,
    speech_lens=speech_lens,
    transcription_lens=transcription_lens,
)