In [1]:
from datasets import load_dataset, Audio, concatenate_datasets
import pandas as pd
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
from datasets import Dataset, load_dataset
import soundfile as sf
import torch
import re
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\*]'
# chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'
def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
    return batch


def prepare_dataset(batch):
    audio = batch["audio"]
    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    with processor.as_target_processor():
        batch["labels"] = processor(batch["text"]).input_ids
    return batch


def load_dataset_from_files(data_dir_list:list[str], csv_export_dir:str, split_ratio=0.1, csv_export=True):
    frames = []
    for path in data_dir_list:
        source = os.path.basename(os.path.dirname(path))
        wavfile_data = []
        textfile_data = []
        for (root, dirs, files) in os.walk(path, topdown=True):
            if source == "Rundkast":  # to modify depending on Rundkast cuts folder name
                for fn in files:
                    if fn.endswith(".wav"):
                        wav_id = source + "_" + os.path.splitext(fn)[0]
                        path = os.path.join(root, fn)
                        wavfile_data.append((wav_id, fn, path, source))
                    elif fn.endswith(".txt"):
                        text_id = source + "_" + os.path.splitext(fn)[0]
                        with open(os.path.join(root, fn), encoding="utf-8") as text_file:
                            text = text_file.read()
                        textfile_data.append((text_id, text))
            else:
                for fn in files:
                    if fn.endswith(".wav"):
                        wav_id = source + "_" + os.path.splitext(fn)[0]
                        path = os.path.join(root, fn)
                        wavfile_data.append((wav_id, fn, path, source))
                    elif fn.endswith(".txt-utf8"):
                        text_id = source + "_" + os.path.splitext(fn)[0]
                        with open(os.path.join(root, fn), encoding="utf-8-sig") as text_file:
                            text = text_file.read()
                        textfile_data.append((text_id, text))
        df_wav = pd.DataFrame(wavfile_data, columns=["segment_id", "wav_file", "path", "source"])
        df_wav = df_wav.set_index("segment_id")
        df_text = pd.DataFrame(textfile_data, columns=["segment_id", "text"])
        df_text = df_text.set_index("segment_id")
        dataset_df = df_wav.merge(df_text, left_index=True, right_index=True)
        frames.append(dataset_df)
    # concat to full dataframe and convert to Dataset with special characters removed
    full_dataset_df = pd.concat(frames)
    raw_dataset = Dataset.from_pandas(full_dataset_df)
    raw_dataset = raw_dataset.map(remove_special_characters)
    # split dataset
    raw_dataset = raw_dataset.train_test_split(test_size=split_ratio)
    # save copy of dataset
    if csv_export is True:
        df_train = pd.DataFrame(raw_dataset["train"])
        df_train.to_csv(os.path.join(csv_export_dir, "train_set.csv"))
        df_dev = pd.DataFrame(raw_dataset["test"])
        df_dev.to_csv(os.path.join(csv_export_dir, "dev_set.csv"))
    # loading audio
    dataset = raw_dataset.cast_column("path", Audio())
    dataset = dataset.rename_column("path", "audio")
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
    # preprocess dataset
    # dataset = dataset.map(prepare_dataset,
    #                       remove_columns=dataset.column_names["train"],
    #                       num_proc=4)
    return raw_dataset, dataset

In [3]:
data_dir_list = ["../../datasets/NordTrans_TUL/train_small/Stortinget/",
                 "../../datasets/NordTrans_TUL/train_small/NRK/",
                 "../../datasets/NordTrans_TUL/train_small/Rundkast/"]

# data_dir_list = ["../../datasets/NordTrans_TUL/train_small/Stortinget/"]

In [4]:
csv_export_dir = "./code_trial/"

raw_dataset, dataset = load_dataset_from_files(data_dir_list, csv_export_dir, split_ratio=0.1, csv_export=True)

100%|██████████| 24300/24300 [00:01<00:00, 21744.88ex/s]


In [5]:
dataset.map(remove_special_characters)

100%|██████████| 21870/21870 [00:01<00:00, 15073.81ex/s]
100%|██████████| 2430/2430 [00:00<00:00, 14730.58ex/s]


DatasetDict({
    train: Dataset({
        features: ['wav_file', 'audio', 'source', 'text', 'segment_id'],
        num_rows: 21870
    })
    test: Dataset({
        features: ['wav_file', 'audio', 'source', 'text', 'segment_id'],
        num_rows: 2430
    })
})

In [35]:
sample = dataset["train"][1]
sample

{'wav_file': '12232017_001.wav',
 'audio': {'path': '../../datasets/NordTrans_TUL/train_small/NRK/12232017_001.wav',
  'array': array([0.02038574, 0.01919556, 0.01751709, ..., 0.00527954, 0.00469971,
         0.00375366], dtype=float32),
  'sampling_rate': 16000},
 'source': 'NRK',
 'text': 'gjøre mine vurderinger ',
 'segment_id': 'NRK_12232017_001'}

In [84]:
model_name = "NbAiLab/nb-wav2vec2-300m-bokmaal"
# model_name = "KBLab/wav2vec2-large-voxrex"
# processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name)
processor = Wav2Vec2Processor.from_pretrained(model_name)

# dataset = dataset.map(prepare_dataset,
#                         remove_columns=dataset.column_names["train"],
#                         num_proc=4)
inputs = processor(sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt")
label = sample["text"]

In [85]:
processor.tokenizer.pad_token_id

31

In [79]:
model = Wav2Vec2ForCTC.from_pretrained(
    "KBLab/wav2vec2-large-voxrex",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
)

Some weights of the model checkpoint at KBLab/wav2vec2-large-voxrex were not used when initializing Wav2Vec2ForCTC: ['project_hid.weight', 'project_hid.bias', 'project_q.bias', 'quantizer.weight_proj.weight', 'project_q.weight', 'quantizer.weight_proj.bias', 'quantizer.codevectors']
- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at KBLab/wav2vec2-large-voxrex and are newly initialized: ['lm_head.weight', 'lm_head.bias']
You should probably TRAIN this model on a down-stream task to be able to use it fo

In [80]:
with torch.no_grad():
  logits = model(**inputs).logits

In [81]:
logits.shape

torch.Size([1, 60, 32])

In [82]:
logits

tensor([[[ 0.1931, -0.2910, -0.3430,  ...,  0.3020,  0.2800, -0.1646],
         [ 0.1936, -0.2956, -0.3336,  ...,  0.2900,  0.2811, -0.1608],
         [ 0.1902, -0.2981, -0.3339,  ...,  0.2827,  0.2760, -0.1666],
         ...,
         [ 0.1883, -0.3031, -0.3290,  ...,  0.2887,  0.2919, -0.1822],
         [ 0.1988, -0.3122, -0.3347,  ...,  0.2893,  0.2938, -0.1748],
         [ 0.2029, -0.3056, -0.3337,  ...,  0.2898,  0.2984, -0.1819]]])

In [83]:
transcription = processor.batch_decode(logits.numpy()).text
transcription[0].lower()

ValueError: Input logits of size 32, but vocabulary is size 34

In [62]:
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)

print(predicted_ids)
print(transcription)
print(label)

tensor([[24, 10, 10, 10, 10, 10, 10, 10, 24, 10, 10, 10, 10, 10, 24, 10, 10, 10,
         10, 10, 10, 24, 24, 24, 24, 24, 24, 24, 10, 10, 29, 10, 24, 10, 10, 10,
         10, 10, 10, 10, 10, 10, 24, 10, 29, 29, 24, 10, 29, 24, 24, 24, 24, 24,
         10, 10, 10, 10, 10, 10]])
['xjxjxjxjøjxjxjøxjøxj']
gjøre mine vurderinger 


In [6]:
def extract_all_chars(batch):
  all_text = " ".join(batch["text"])
  vocab = list(set(all_text))
  return {"vocab": [vocab], "all_text": [all_text]}

In [7]:
vocabs = dataset.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=dataset.column_names["train"])

100%|██████████| 1/1 [00:00<00:00,  2.61ba/s]
100%|██████████| 1/1 [00:00<00:00, 26.59ba/s]


In [8]:
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'2': 0,
 'b': 1,
 '1': 2,
 'k': 3,
 'è': 4,
 "'": 5,
 '3': 6,
 'ó': 7,
 'r': 8,
 'v': 9,
 't': 10,
 'u': 11,
 'l': 12,
 'å': 13,
 'p': 14,
 'm': 15,
 'æ': 16,
 'f': 17,
 'w': 18,
 'c': 19,
 'o': 20,
 'í': 21,
 'n': 22,
 ' ': 23,
 '–': 24,
 '4': 25,
 'ü': 26,
 'd': 27,
 'g': 28,
 'z': 29,
 'i': 30,
 'e': 31,
 '`': 32,
 'q': 33,
 'ö': 34,
 'y': 35,
 'ä': 36,
 'a': 37,
 'x': 38,
 '6': 39,
 'á': 40,
 'h': 41,
 'ø': 42,
 'é': 43,
 'j': 44,
 '9': 45,
 'ò': 46,
 's': 47}

In [61]:
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [62]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
print(len(vocab_dict))

52


In [14]:
processor.tokenizer.vocab_size

32

In [86]:
for i in range(32):
    print(i, processor.tokenizer.convert_ids_to_tokens(i))

0 |
1 a
2 b
3 c
4 d
5 e
6 f
7 g
8 h
9 i
10 j
11 k
12 l
13 m
14 n
15 o
16 p
17 q
18 r
19 s
20 t
21 u
22 v
23 w
24 x
25 y
26 z
27 å
28 æ
29 ø
30 [UNK]
31 [PAD]
