In [3]:
!pip install transformers datasets jiwer soundfile

Collecting soundfile
  Downloading soundfile-0.11.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: soundfile
Successfully installed soundfile-0.11.0


In [15]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset
import torch
import numpy as np
import soundfile as sf

# Load Acoustic Model

In [16]:
model_name = "elgeish/wav2vec2-base-timit-asr"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)
model.eval()

dataset = load_dataset("Siyong/speech_timit", split="test").select(range(10))
char_translations = str.maketrans({"-": " ", ",": "", ".": "", "?": ""})

def prepare_example(example):
    example["speech"] = example["audio"]["array"]
    example["text"] = example["sentence"].translate(char_translations)
    example["text"] = " ".join(example["text"].split())  # clean up whitespaces
    example["text"] = example["text"].lower()
    return example

dataset = dataset.map(prepare_example)
inputs = processor(dataset["speech"], sampling_rate=16000, return_tensors="pt", padding="longest")

with torch.no_grad():
    predicted_ids = torch.argmax(model(inputs.input_values).logits, dim=-1)
predicted_ids[predicted_ids == -100] = processor.tokenizer.pad_token_id
predicted_transcripts = processor.tokenizer.batch_decode(predicted_ids)

for reference, predicted in zip(dataset["text"], predicted_transcripts):
    print("reference:", reference)
    print("predicted:", predicted)
    print("--")

Using custom data configuration Siyong--speech_timit-e00a2cf2b9b45bd5
Found cached dataset parquet (/csehome/heatz123/.cache/huggingface/datasets/Siyong___parquet/Siyong--speech_timit-e00a2cf2b9b45bd5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached processed dataset at /csehome/heatz123/.cache/huggingface/datasets/Siyong___parquet/Siyong--speech_timit-e00a2cf2b9b45bd5/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-6fb23cde0307bff6.arrow


reference: she had your dark suit in greasy wash water all year
predicted: she had your dark suit in greasy wash water all year
--
reference: she had your dark suit in greasy wash water all year
predicted: pshe had your dark suit in greasy wash water all year
--
reference: there are more obvious nymphomaniacs on any privateeye series
predicted: ther are more obvious nimfom many acts on any privite eye series
--
reference: only the best players enjoy popularity
predicted: only the best players enjoy popularity
--
reference: december and january are nice months to spend in miami
predicted: tdisember and january ar nice months tospend in my amy
--
reference: keep the thermometer under your tongue
predicted: keep the thermometer under your tongue
--
reference: you're a taxpayer householder landlord
predicted: you'r a taxpayer householder landlord
--
reference: does creole cooking use curry
predicted: does creole coking use curry
--
reference: scholastic aptitude is judged by standardized t

In [17]:
_ = model.to("cuda")

In [20]:
def map_to_pred(batch):
    input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest", sampling_rate=16_000).input_values
    with torch.no_grad():
        logits = model(input_values.to("cuda")).logits

    transcription = processor.decode(logits[0].argmax(dim=-1).cpu().numpy())
    batch["transcription"] = transcription
    return batch

result = dataset.map(map_to_pred, batched=False, batch_size=1, remove_columns=["audio"])
for label, pred in zip(result["sentence"], result["transcription"]):
    print('label:', label)
    print('pred:', pred)
    print('-----------')

100%|█████████████████████████████████████████████████| 10/10 [00:00<00:00, 20.26ex/s]

label: she had your dark suit in greasy wash water all year 
pred: she had your dark suit in greasy wash water all year
-----------
label: she had your dark suit in greasy wash water all year 
pred: she had your dark suit in greasy wash water all year
-----------
label: there are more obvious nymphomaniacs on any privateeye series 
pred: ther are more obvious nimpfom many acts on any privite eye series
-----------
label: only the best players enjoy popularity 
pred: only the best players enjoy popularity
-----------
label: december and january are nice months to spend in miami 
pred: tdisember and january ar nice months toespend in my amy
-----------
label: keep the thermometer under your tongue 
pred: keep the thermometer under your tongue
-----------
label: you're a taxpayer householder landlord 
pred: you'r a taxpayer householder landlord
-----------
label: does creole cooking use curry 
pred: does creole coking use curry
-----------
label: scholastic aptitude is judged by standardi




# Load GPT as Language Model

In [7]:
processor.tokenizer.word_delimiter_token_id
processor.tokenizer.pad_token_id

0

In [51]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

gpt_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt = GPT2LMHeadModel.from_pretrained('gpt2').to("cuda").eval()

In [9]:
from transformers import GPTNeoForCausalLM

In [52]:
text = "Replace me by any text you'd like."
encoded_input = gpt_tokenizer(text, return_tensors='pt').to("cuda")
output = gpt(**encoded_input)
gpt_tokenizer.batch_decode(output.logits.argmax(dim=-1))

['. the with a means, want like.\n']

In [53]:
if not gpt_tokenizer.pad_token:
  gpt_tokenizer.pad_token = gpt_tokenizer.eos_token

Using pad_token, but it is not set yet.


# Transformer Decoder

In [105]:
class TransformerDecoder:
    def __init__(self, processor, gpt, gpt_tokenizer, beam_width=50, num_topk=10, alpha=0.1, beta=0., score_interval=50):
        self.processor = processor
        self.gpt = gpt
        self.gpt_tokenizer = gpt_tokenizer
        self.B = beam_width
        self.VK = num_topk
        self.alpha = alpha
        self.beta = beta
        self.score_interval = score_interval

        assert gpt_tokenizer.pad_token is not None

        self.processor_vocab = processor.tokenizer.get_vocab()
        self.i2c = {v: k for k, v in self.processor_vocab.items()} # index to char

        self.pad_token = processor.tokenizer.pad_token_id
        self.word_delimiter_token = processor.tokenizer.word_delimiter_token_id
        self.i2c[self.word_delimiter_token] = ' '

    def decode(self, logits, **kwargs):
        return self._beam_search(torch.from_numpy(logits))

    def _beam_search(self, audio_logit):
        B, VK, alpha, beta = self.B, self.VK, self.alpha, self.beta

        beam_scores = np.ones(B) * (-np.log(B))
        
        beam_infos = [
            {
                'text': '', # char level vocab doesn't need starting char (convert this to <s> when using kogpt2)
                'last_token': self.pad_token,
                'last_lm_score': 0,
            } for i in range(B)
        ]
        
        logits = audio_logit - audio_logit.logsumexp(dim=-1, keepdim=True)

        logits_topk, tokens_topk = torch.topk(logits, k=VK, dim=-1)
        tokens_topk = tokens_topk.numpy()
        logits_topk = logits_topk.numpy()
        chars_topk = [[self.i2c[i] for i in tokens] for tokens in tokens_topk]
        N, V = audio_logit.shape
        for n in range(N):
            
            logits, tokens = logits_topk[n], tokens_topk[n]
            chars = chars_topk[n]
            
            # get next scores first
            beam_hyp_scores = beam_scores[:, None].repeat(VK, axis=1)
            beam_hyp_scores += logits[None, :] # (B, VK) + (1, VK)
            beam_hyp_scores = beam_hyp_scores.reshape(-1)
            
            # get next beams by score
            indices_topk = np.argsort(beam_hyp_scores, axis=-1)[::-1][:B+5]
            beam_scores = beam_hyp_scores[indices_topk] # B*2, 1
            # print(beam_scores)
            # get next beam_info by scores
            new_beam_infos = [] # B, VK
            cache = {}
            for i, idx in enumerate(indices_topk):
                b, vk = idx // VK, idx % VK
                beam_info = beam_infos[b]
                
                new_last_token = tokens[vk]
                new_text = beam_info['text'] + chars[vk] if new_last_token != self.pad_token and beam_info['last_token'] != new_last_token \
                    else beam_info['text']
                last_lm_score = beam_info['last_lm_score']
                
                if (new_text, new_last_token) in cache:
                    target_i = cache[(new_text, new_last_token)]
                    beam_scores[target_i] = np.log(np.exp(beam_scores[target_i]) + np.exp(beam_scores[i]))
                    
                    beam_scores[i] = -1e9
                else:
                    cache[(new_text, new_last_token)] = i # beam index
                
                beam_hyp = {
                    'text': new_text,
                    'last_token': new_last_token,
                    'last_lm_score': last_lm_score,
                }
                
                new_beam_infos.append(beam_hyp)

            # now that we got top B*x scores and indices,
            # let's get final beams after merging scores
            
            indices_topk = np.argsort(beam_scores, axis=0)[::-1][:B]
            beam_scores = beam_scores[indices_topk]
            beam_infos = [new_beam_infos[i] for i in indices_topk]
            
            if n % self.score_interval == self.score_interval - 1:
                beam_scores = self._add_lm_scores(beam_infos, beam_scores, alpha, beta)

        beam_scores = self._add_lm_scores(beam_infos, beam_scores, alpha, beta)
        selected_beam_idx = beam_scores.argmax(axis=-1)


        return beam_infos[selected_beam_idx]['text'].strip()

    def _query_lm(self, texts, beta=0.):
        inputs = self.gpt_tokenizer(texts, return_tensors='pt', padding=True)
       
        B, N = inputs.input_ids.shape
        if N < 2:
            return torch.zeros(B)
        inputs = inputs.to(self.gpt.device)
        text_input_ids = inputs.input_ids
        B, N = text_input_ids.shape
        with torch.no_grad():
            labels = text_input_ids[:, :]
            text_logits = self.gpt(**inputs).logits
            
            shift_logits = text_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            
            loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            
            res = -lm_loss.reshape(B, -1)
            res = (res * inputs.attention_mask[:, 1:])
    
            scores = res.sum(dim=1) # N, 1
        
        scores = scores.cpu()
        
        input_lengths = np.array(list(map(len, texts)), dtype=np.float32)
        scores += input_lengths * beta

        return scores

    def _add_lm_scores(self, beam_infos, beam_scores, alpha, beta):
        beam_last_lm_scores = np.array(list(map(lambda x: x['last_lm_score'], beam_infos)), dtype=np.float32)
        beam_scores -= beam_last_lm_scores        
        
        texts = list(map(lambda x: x['text'], beam_infos))
        lm_scores = self._query_lm(texts, beta).numpy() * alpha
        
        for i, x in enumerate(beam_infos):
             x['last_lm_score'] = lm_scores[i]
        
        final_beam_scores = beam_scores + lm_scores
        return final_beam_scores

In [109]:
transformer_decoder = TransformerDecoder(
    processor=processor,
    gpt=gpt,
    gpt_tokenizer=gpt_tokenizer,
    beam_width=200,
    num_topk=10,
    alpha=1.0,
    beta=2.0,
    score_interval=50
)

In [110]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [111]:
def map_to_pred(batch):
    input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest", sampling_rate=16_000).input_values
    with torch.no_grad():
        logits = model(input_values.to("cuda")).logits

    batch["pred_raw"] = processor.decode(logits[0].argmax(dim=-1).cpu().numpy()).lower()
    batch["pred_transformer_decoder"] = transformer_decoder.decode(logits[0].cpu().numpy()).lower()
    return batch

result = dataset.map(map_to_pred, batched=False, batch_size=1, remove_columns=["audio"])
for label, p, p1 in zip(result["sentence"], result["pred_raw"], result["pred_transformer_decoder"]):
    print('label   :\t', label)
    print('pred_raw:\t', p)
    print('pred_gpt:\t', p1)
    print('-----------')

100%|█████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.25ex/s]

label   :	 she had your dark suit in greasy wash water all year 
pred_raw:	 she had your dark suit in greasy wash water all year
pred_gpt:	 she had your dark suit in greasy wash water all year
-----------
label   :	 she had your dark suit in greasy wash water all year 
pred_raw:	 she had your dark suit in greasy wash water all year
pred_gpt:	 she had your dark suit in greasy wash water all year
-----------
label   :	 there are more obvious nymphomaniacs on any privateeye series 
pred_raw:	 ther are more obvious nimpfom many acts on any privite eye series
pred_gpt:	 there are more obvious nimfom many acts on any private eye series
-----------
label   :	 only the best players enjoy popularity 
pred_raw:	 only the best players enjoy popularity
pred_gpt:	 only the best players enjoy popularity
-----------
label   :	 december and january are nice months to spend in miami 
pred_raw:	 tdisember and january ar nice months toespend in my amy
pred_gpt:	 disember and january are nice months to sp


