In [1]:
import torch
import kenlm
import pickle
import numpy as np

from datasets import load_from_disk, load_dataset, load_metric, Dataset, concatenate_datasets

In [2]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor

# CTC tokenizer built from the local vocabulary file; "|" separates words.
tokenizer = Wav2Vec2CTCTokenizer(
    vocab_file="./vocab_new.json",
    word_delimiter_token="|",
    pad_token="[PAD]",
    unk_token="[UNK]",
)

# Not used below; kept for reference:
# feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
# processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
# model = Wav2Vec2ForCTC.from_pretrained("./wav2vec2-large-xlsr-ger-chris-piece/checkpoint-54000")

In [3]:
vocab_dict = tokenizer.get_vocab()

# (id, token) pairs in ascending id order; with contiguous ids (0-33 here) the list
# position equals the token id.
sort_vocab = sorted((token_id, token) for token, token_id in vocab_dict.items())

# Decoder labels: the word delimiter "|" is shown as a space and every special token
# is replaced by the "_" placeholder.
special_tokens = tokenizer.all_special_tokens
vocab = [
    "_" if token in special_tokens else token.replace("|", " ")
    for token_id, token in sort_vocab
]
print(vocab)

['o', 'm', 'ö', 'v', 'p', 'y', 'z', 'f', 'd', 'j', 'i', 't', 'r', 'ä', 'n', 'w', 'h', 'l', 'u', 'a', 'x', 's', 'b', 'c', 'ß', 'ü', 'e', 'g', 'q', ' ', 'k', '_', '_', '-']


In [4]:
from ctcdecode import CTCBeamDecoder

# Take the decoder labels from the tokenizer-derived `vocab` (printed above) instead of a
# hand-copied list, so label order can never drift from the tokenizer's id order.
labels = list(vocab)
# The CTC blank is the tokenizer's pad token ("[PAD]", id 32 for vocab_new.json); its label
# must be the "_" placeholder used for special tokens.
assert labels[tokenizer.pad_token_id] == "_", "blank id does not point at a special-token label"

lm_decoder = CTCBeamDecoder(
    labels,
    model_path='lm_data/train_piece.arpa',
    alpha=2.5,  # weight of the LM score (per ctcdecode docs)
    beta=0,     # weight of the word-count bonus (per ctcdecode docs)
    cutoff_top_n=300,  # larger than the 34 labels, so every label is kept per frame
    # NOTE(review): ctcdecode documents cutoff_prob as a probability (1.0 = no pruning);
    # np.log(1e-8) ~= -18.4 is not a valid probability even with log_probs_input=True —
    # confirm how the installed version treats it before tuning.
    cutoff_prob=np.log(0.00000001),
    beam_width=256,
    num_processes=4,
    blank_id=tokenizer.pad_token_id,
    log_probs_input=True,
)

In [5]:
# Previously saved test-set results; provides the `lm_raw`, `pred_str` and `target_text` columns used below.
res_54_log = load_from_disk(dataset_path="lm_res51/results54_piece_log")

In [6]:
# id -> label lookup used to turn beam-search ids back into text. It is derived from the
# tokenizer (with the "|" word delimiter rendered as a space) rather than hard-coded, so it
# cannot drift from the vocabulary the decoder labels were built from.
vocab_dict = {token.replace("|", " "): token_id for token, token_id in tokenizer.get_vocab().items()}
inv_map = {token_id: token for token, token_id in vocab_dict.items()}

In [7]:
# Side-by-side preview for the first 10 examples: the LM beam-search transcript,
# followed by the stored `pred_str` of the same example.
for i in range(10):
    frame_scores = torch.tensor(res_54_log[i]["lm_raw"])
    beam_results, beam_scores, timesteps, out_len = lm_decoder.decode(frame_scores)
    best_len = out_len[0][0]
    res = "".join(inv_map[int(n)] for n in beam_results[0][0][:best_len])
    print(res)
    print(res_54_log[i]["pred_str"])

-sie -durch -bitte -draußen -die -schuhe -aus 
-zie durch -bittet -draußen -die -schuhe -aus
-dies -grund -zu -schon -bedeuten -ster 
-dies -grund -zun -schon bedauten -ster
-ihre -foto strecken -erschienen -im -mode magazin -wieder -volk -ab -das -basar -war -re cla ir 
-ihre -fote strecken -erschienen -im -mode magazin -wieder -bok -ab -das -basar -wa -ri clar
-ver -lipp ert -eine -auch -für -monarchen -ungewöhnlich -lange -titel liste 
-verlippert -eine -auch -für -monachen -ungewöhnlich -lange -titel liste
-er -wurde -zu -ehren -des -reichskanzler s -otto -von -bismarck -errichtet 
-er -wurde -zu -ehren -des -reichs kanzler s -otto -von -bismark -errichtet
-was -soll s -ich -bin -bereit 
-was -soll s -ich -bin -bereit
-das -internet -besteht -aus -vielen -computer n -die -miteinander -verbunden -sind 
-das -internet -besteht -aus -vielen -computer n -die -miteinander -verbunden -sind
-der -uran us -ist -der -sieben te -planet -in -unsere m -sonnensystem 
-der -uran us -ist -der -si

In [9]:
def infer_lm(batch, decoder=None, id_to_char=None):
    """Beam-search decode one example's stored CTC outputs and attach the top transcript.

    Args:
        batch: a single dataset row (as passed by ``Dataset.map`` without ``batched=True``).
            ``batch["lm_raw"]`` is converted with ``torch.tensor`` and handed to the decoder;
            presumably shaped (1, time, n_labels) log-probs — TODO confirm.
        decoder: object with a ctcdecode-style ``decode`` method returning
            ``(beam_results, beam_scores, timesteps, out_lens)``. Defaults to the
            notebook-level ``lm_decoder``.
        id_to_char: mapping from label id to output string. Defaults to the
            notebook-level ``inv_map``.

    Returns:
        The same ``batch`` dict with ``batch["lm_str"]`` set to the decoded text.
    """
    # Resolve defaults lazily so the function can be used/tested without the globals.
    if decoder is None:
        decoder = lm_decoder
    if id_to_char is None:
        id_to_char = inv_map

    beam_results, _beam_scores, _timesteps, out_lens = decoder.decode(torch.tensor(batch["lm_raw"]))
    # [0][0] = first batch item, first (best-scoring) beam; entries past its length are padding.
    best_len = int(out_lens[0][0])
    best_beam = beam_results[0][0][:best_len]
    # join instead of repeated `+=` avoids quadratic string building.
    batch["lm_str"] = "".join(id_to_char[int(token_id)] for token_id in best_beam)
    return batch

In [10]:
# Beam-search decode every example in 8 worker processes; adds the `lm_str` column.
# NOTE(review): the decoder itself is configured with num_processes=4 — check for CPU oversubscription.
res_54_log = res_54_log.map(function=infer_lm, num_proc=8)



HBox(children=(FloatProgress(value=0.0, description=' #0', max=1949.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description=' #1', max=1949.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description=' #3', max=1949.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description=' #4', max=1948.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description=' #2', max=1949.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description=' #6', max=1948.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description=' #7', max=1948.0, style=ProgressStyle(description_width='…

HBox(children=(FloatProgress(value=0.0, description=' #5', max=1948.0, style=ProgressStyle(description_width='…











In [11]:
import sentencepiece as spm
s = spm.SentencePieceProcessor(model_file='one.model')

def decode_pieces(batch):
    """Turn one example's LM output string back into plain text.

    ``batch["lm_str"]`` holds space-separated sentencepiece pieces in which "-" stands in
    for sentencepiece's "▁" word marker. Each "-" is mapped back to "▁", the pieces are
    decoded with the processor ``s``, and any "▁" left in the result is rendered as a
    space. The text is stored in ``batch["detokenized"]`` and the same batch is returned.
    """
    pieces = batch["lm_str"].replace("-", "▁").split(" ")
    batch["detokenized"] = s.decode(pieces).replace("▁", " ")
    return batch

In [12]:
# Detokenize every example's LM output into the `detokenized` column.
res_54_log = res_54_log.map(function=decode_pieces)

HBox(children=(FloatProgress(value=0.0, max=15588.0), HTML(value='')))




In [13]:
# Persist the results, including the new `lm_str` and `detokenized` columns.
res_54_log.save_to_disk(dataset_path="lm_kenlm_pred/lm_piece_05")

In [15]:
from datasets import load_metric

wer_metric = load_metric(path="wer")

# WER of the detokenized LM output against the reference transcripts.
print(
    "Test WER: {:.3f}".format(
        wer_metric.compute(
            references=res_54_log["target_text"],
            predictions=res_54_log["detokenized"],
        )
    )
)

Test WER: 0.107


In [9]:
# NOTE(review): `res_51_full_log_2` is never created in this notebook (execution counts are
# also out of order), so this cell only works with leftover kernel state — TODO load it explicitly.
wer_metric = load_metric(path="wer")

# WER of the `pred_str` column against the references.
print(
    "Full Test WER: {:.3f}".format(
        wer_metric.compute(
            references=res_51_full_log_2["target_text"],
            predictions=res_51_full_log_2["pred_str"],
        )
    )
)

Full Test WER: 0.147


In [10]:
# NOTE(review): `res_51_full_log_2` is never created in this notebook, so this cell only
# works with leftover kernel state — TODO load it explicitly.
# WER of the `lm_2_str` column against the references.
print(
    "Full Test WER: {:.3f}".format(
        wer_metric.compute(
            references=res_51_full_log_2["target_text"],
            predictions=res_51_full_log_2["lm_2_str"],
        )
    )
)

Full Test WER: 0.126


In [18]:
# NOTE(review): `res_51_full_full` is never created in this notebook, so this cell only
# works with leftover kernel state — TODO load it explicitly.
# WER of the `lm_str` column of `res_51_full_full` against its references.
print(
    "Full Test WER: {:.3f}".format(
        wer_metric.compute(
            references=res_51_full_full["target_text"],
            predictions=res_51_full_full["lm_str"],
        )
    )
)

Full Test WER: 0.132


In [11]:
# NOTE(review): `res_51_full_log_2` is never created in this notebook — TODO load it explicitly.
res_51_full_log_2.save_to_disk(dataset_path="lm_full_full_3")

In [19]:
import spacy

nlp = spacy.load("de_dep_news_trf")
# Raise spaCy's per-text character limit (default 1,000,000) to 17 million.
nlp.max_length = 17 * 10**6

In [24]:
from tqdm import tqdm

# NOTE(review): `res_51_full_log_2` is never created in this notebook; it must already be
# in the kernel — TODO load it explicitly.
# Run the German pipeline (ner/parser disabled) over every `lm_2_str` transcript.
# Reading the column once and streaming it through `nlp.pipe` avoids per-row dataset
# indexing and lets spaCy batch the transformer forward passes; the previous per-row
# `nlp(...)` loop ran at ~9 texts/s (~29 min for 15,588 rows) and was interrupted.
lm_2_texts = res_51_full_log_2["lm_2_str"]
results51_full_full_2_trf = list(
    tqdm(nlp.pipe(lm_2_texts, disable=["ner", "parser"]), total=len(lm_2_texts))
)

  1%|          | 168/15588 [00:18<28:49,  8.92it/s]


KeyboardInterrupt: 