# "Poleval 2021 through wav2vec2"
> "Trying for pronunciation recovery"

- toc: false
- branch: master
- comments: true
- hidden: true
- categories: [wav2vec2, poleval]


In [1]:
%%capture
!pip install gdown



In [3]:
!gdown  https://drive.google.com/uc?id=1b6MyyqgA9D1U7DX3Vtgda7f9ppkxjCXJ

Downloading...
From: https://drive.google.com/uc?id=1b6MyyqgA9D1U7DX3Vtgda7f9ppkxjCXJ
To: /content/poleval_wav.train.tar.gz
2.14GB [00:38, 55.7MB/s]


In [None]:
%%capture
!tar zxvf poleval_wav.train.tar.gz && rm poleval_wav.train.tar.gz

In [6]:
%%capture
!pip install librosa webrtcvad

In [7]:
#collapse-hide
# VAD wrapper is taken from PyTorch Speaker Verification:
# https://github.com/HarryVolek/PyTorch_Speaker_Verification
# Copyright (c) 2019, HarryVolek
# License: BSD-3-Clause
# based on https://github.com/wiseman/py-webrtcvad/blob/master/example.py
# Copyright (c) 2016 John Wiseman
# License: MIT
import collections
import contextlib
import numpy as np
import sys
import librosa
import wave

import webrtcvad

#from hparam import hparam as hp
# Sample rate in Hz. VAD_chunk passes this value to read_wave,
# frame_generator and vad_collector, and uses it to turn segment
# times into sample indices.
sr = 16000

def read_wave(path, sr):
    """Read a mono, 16-bit .wav file in two representations.

    Parameters:
        path: path of the .wav file.
        sr: sample rate passed to ``librosa.load`` for the decoded
            signal; must be one of 8000, 16000, 32000 or 48000.
    Returns:
        (data, pcm_data): ``data`` is the one-dimensional array returned
        by ``librosa.load``; ``pcm_data`` is the raw frame bytes read
        through the ``wave`` module.
    Raises:
        AssertionError: if the file is not mono, is not 2 bytes per
            sample, has an unsupported frame rate, or if ``sr`` is not
            one of the supported rates.

    NOTE: ``pcm_data`` is returned exactly as stored in the file; it is
    not resampled to ``sr``. Callers that treat both values as having
    the same rate should make sure the file's rate equals ``sr``.
    """
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
    # `sr` must be given by keyword: recent librosa releases made every
    # argument after `path` keyword-only.
    data, _ = librosa.load(path, sr=sr)
    assert len(data.shape) == 1
    assert sr in (8000, 16000, 32000, 48000)
    return data, pcm_data
    
class Frame:
    """One fixed-length slice of audio plus its position in the stream.

    Attributes (set verbatim from the constructor arguments):
        bytes: the audio payload of this frame.
        timestamp: offset of the frame from the start of the stream.
        duration: length of the frame, in the same unit as ``timestamp``.
    """

    def __init__(self, bytes, timestamp, duration):
        self.bytes, self.timestamp, self.duration = bytes, timestamp, duration


def frame_generator(frame_duration_ms, audio, sample_rate):
    """Split PCM audio data into consecutive frames of equal length.

    Parameters:
        frame_duration_ms: desired frame length in milliseconds.
        audio: the PCM data, a sliceable byte sequence; the size
            computation assumes 2 bytes per sample.
        sample_rate: sample rate of ``audio`` in Hz.
    Yields:
        Frame objects carrying the frame's bytes, its start time in
        seconds and its duration in seconds. Trailing data shorter than
        one full frame is not yielded.
    Raises:
        ValueError: on first iteration, if the requested duration gives
            a frame of zero (or negative) bytes; without this check the
            loop below would never advance.
    """
    # Frame size in bytes: samples per frame times 2 bytes per sample.
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    if n <= 0:
        raise ValueError(
            'frame_duration_ms {} at sample_rate {} gives an empty frame'.format(
                frame_duration_ms, sample_rate))
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    # `<=` so that a final frame ending exactly at the end of the data
    # is still emitted.
    while offset + n <= len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n


def vad_collector(sample_rate, frame_duration_ms,
                  padding_duration_ms, vad, frames):
    """Find the voiced stretches in a sequence of audio frames.

    Uses a padded, sliding window over the frames. While not triggered,
    the collector triggers as soon as more than 90% of the window
    capacity is filled with frames the VAD reports as speech. While
    triggered, it detriggers as soon as more than 90% of the window
    capacity is filled with frames reported as non-speech. Because the
    window is consulted on both sides, each reported stretch includes a
    little padding before and after the speech itself.

    Arguments:
    sample_rate - The audio sample rate, in Hz (passed to the VAD).
    frame_duration_ms - The frame duration in milliseconds.
    padding_duration_ms - The window length, in milliseconds.
    vad - An object with an ``is_speech(bytes, sample_rate)`` method,
        e.g. an instance of webrtcvad.Vad.
    frames - a source of audio frames (sequence or generator); each
        frame needs ``bytes``, ``timestamp`` and ``duration``.
    Returns: A generator that yields ``(start, end)`` tuples, one per
    voiced stretch. ``start`` is the timestamp of the oldest frame in
    the window when the collector triggered; ``end`` is the end time
    (timestamp + duration) of the frame at which it detriggered, or of
    the last frame if the input ran out while still triggered.
    """
    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
    # We use a deque for our sliding window/ring buffer.
    ring_buffer = collections.deque(maxlen=num_padding_frames)
    # Number of matching frames that must be exceeded to change state.
    threshold = 0.9 * num_padding_frames
    # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
    # NOTTRIGGERED state.
    triggered = False
    start = None
    frame = None

    for frame in frames:
        is_speech = vad.is_speech(frame.bytes, sample_rate)
        ring_buffer.append((frame, is_speech))

        if not triggered:
            num_voiced = sum(1 for _, speech in ring_buffer if speech)
            if num_voiced > threshold:
                triggered = True
                # The stretch begins with the audio already buffered.
                start = ring_buffer[0][0].timestamp
                ring_buffer.clear()
        else:
            num_unvoiced = sum(1 for _, speech in ring_buffer if not speech)
            if num_unvoiced > threshold:
                triggered = False
                yield (start, frame.timestamp + frame.duration)
                ring_buffer.clear()
    # If the input ended in the middle of a voiced stretch, report it.
    # Only the time span is reported, so the frames themselves do not
    # need to be retained.
    if triggered:
        yield (start, frame.timestamp + frame.duration)


def VAD_chunk(aggressiveness, path):
    """Run voice activity detection on a .wav file and cut the voiced
    stretches into short pieces.

    Parameters:
        aggressiveness: value converted with ``int()`` and passed to
            ``webrtcvad.Vad``.
        path: path of the .wav file, read with ``read_wave``.
    Returns:
        (speech_times, speech_segs): parallel lists. Each time is a
        ``(start, end)`` tuple in seconds, rounded to two decimals; each
        segment is the matching slice of the decoded audio. Every voiced
        stretch is cut into consecutive 0.4 second pieces followed by
        one final piece holding whatever remains.
    """
    audio, byte_audio = read_wave(path, sr)
    detector = webrtcvad.Vad(int(aggressiveness))
    frame_list = list(frame_generator(20, byte_audio, sr))

    speech_times = []
    speech_segs = []

    def _emit(begin, finish):
        # Record one piece: its time span and the samples it covers.
        speech_times.append((begin, finish))
        speech_segs.append(audio[int(begin*sr):int(finish*sr)])

    for span in vad_collector(sr, 20, 200, detector, frame_list):
        cursor = np.round(span[0], decimals=2)
        end = np.round(span[1], decimals=2)
        while cursor + .4 < end:
            piece_end = np.round(cursor+.4, decimals=2)
            _emit(cursor, piece_end)
            cursor = piece_end
        # Remainder of the stretch (at most 0.4 seconds long).
        _emit(cursor, end)
    return speech_times, speech_segs

In [8]:
#collapse-hide
# Based on code from PyTorch Speaker Verification:
# https://github.com/HarryVolek/PyTorch_Speaker_Verification
# Copyright (c) 2019, HarryVolek
# Additions Copyright (c) 2021, Jim O'Regan
# License: MIT
import numpy as np

# wav2vec2's max duration is 40 seconds, using 39 by default
# to be a little safer
def vad_concat(times, segs, max_duration=39.0):
    """
    Concatenate continuous times and their segments, where the end time
    of a segment is the same as the start time of the next
        Parameters:
            times: list of tuple (start, end)
            segs: list of segments (audio frames), parallel to `times`
            max_duration: maximum duration of the resulting concatenated
                segments; may not exceed 40 seconds, the limit this
                notebook assumes for wav2vec2, so the default of 39
                leaves a small margin. 0.0 means "don't concatenate".
        Returns:
            concat_times: list of tuple (start, end)
            concat_segs: list of segments (audio frames)
            Both lists are empty when `times` is empty.
        Raises:
            ValueError: if `max_duration` is larger than 40 seconds.
    """
    absolute_maximum = 40.0
    if max_duration > absolute_maximum:
        raise ValueError('`max_duration` {:.2f} larger than kernel size (40 seconds)'.format(max_duration))
    # Nothing was detected (e.g. a silent file): nothing to merge.
    if len(times) == 0:
        return [], []
    # we take 0.0 to mean "don't concatenate"
    do_concat = (max_duration != 0.0)
    concat_seg = []
    concat_times = []
    # The run currently being built; it is flushed whenever the next
    # piece cannot be appended to it.
    seg_concat = segs[0]
    time_concat = times[0]
    for i in range(1, len(times)):
        next_time = times[i]
        # The merged run must stay strictly shorter than max_duration.
        can_concat = (next_time[1] - time_concat[0]) < max_duration
        # Pieces are continuous only if one ends exactly where the next starts.
        if time_concat[1] == next_time[0] and do_concat and can_concat:
            seg_concat = np.concatenate((seg_concat, segs[i]))
            time_concat = (time_concat[0], next_time[1])
        else:
            concat_seg.append(seg_concat)
            concat_times.append(time_concat)
            seg_concat = segs[i]
            time_concat = next_time
    # Flush the last run.
    concat_seg.append(seg_concat)
    concat_times.append(time_concat)
    return concat_times, concat_seg

In [9]:
def make_dataset(concat_times, concat_segs):
  """Arrange segment times and audio as a column-oriented dict.

  Parameters:
      concat_times: sequence of (start, end) pairs.
      concat_segs: the audio segments, stored unchanged.
  Returns:
      dict with the keys 'start', 'end' and 'speech'.
  """
  columns = {'start': [], 'end': [], 'speech': concat_segs}
  for span in concat_times:
    columns['start'].append(span[0])
    columns['end'].append(span[1])
  return columns

In [10]:
%%capture
!pip install datasets

In [11]:
from datasets import Dataset

def vad_to_dataset(path, max_duration):
    """Build a `datasets.Dataset` of voiced segments from one .wav file.

    Parameters:
        path: path of the .wav file, handed to `VAD_chunk` (which is
            called with aggressiveness 3).
        max_duration: when greater than 0.0, adjacent pieces are merged
            with `vad_concat` using this limit; otherwise the pieces
            from `VAD_chunk` are used as they are.
    Returns:
        Dataset with the columns produced by `make_dataset`.
    """
    times, segments = VAD_chunk(3, path)
    if max_duration > 0.0:
        times, segments = vad_concat(times, segments, max_duration)
    return Dataset.from_dict(make_dataset(times, segments))

In [12]:
%%capture
!pip install -q transformers

In [13]:
%%capture
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
# Load the processor and the CTC model from the same checkpoint, then
# move the model to the GPU.
CHECKPOINT = "mbien/wav2vec2-large-xlsr-polish"
processor = Wav2Vec2Processor.from_pretrained(CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(CHECKPOINT)
model.to("cuda")

In [14]:
def speech_file_to_array_fn(batch):
    """Load the audio file named in ``batch["path"]`` into the batch.

    Adds three entries and returns the same ``batch`` object:
        "speech": the first channel of the loaded audio, as returned by
            its ``.numpy()`` method.
        "sampling_rate": the rate reported by ``torchaudio.load``.
        "target_text": a copy of ``batch["sentence"]``.
    """
    import torchaudio
    waveform, rate = torchaudio.load(batch["path"])
    first_channel = waveform[0].numpy()
    batch["speech"] = first_channel
    batch["sampling_rate"] = rate
    batch["target_text"] = batch["sentence"]
    return batch
def evaluate(batch):
  """Transcribe the audio in ``batch["speech"]``.

  The audio is run through the module-level ``processor`` (as 16 kHz
  input, padded, returned as PyTorch tensors) and ``model``; the highest
  scoring token at each position is decoded with
  ``processor.batch_decode``.

  Parameters:
      batch: mapping with a "speech" entry holding the audio.
  Returns:
      The same ``batch`` with a "pred_strings" entry added.
  """
  import torch
  inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

  # Put the inputs wherever the model currently lives rather than
  # assuming a GPU, so this also works when the model is on the CPU.
  device = model.device
  with torch.no_grad():
    logits = model(inputs.input_values.to(device), attention_mask=inputs.attention_mask.to(device)).logits

  pred_ids = torch.argmax(logits, dim=-1)
  batch["pred_strings"] = processor.batch_decode(pred_ids)
  return batch

In [15]:
import json
def process_wave(filename, duration):
    """Transcribe one .wav file and save the result next to it.

    The file is segmented with `vad_to_dataset`, every segment is
    transcribed with `evaluate` (in batches of 16), and the result is
    written as JSON to '<filename>.tlog': a list with one object per
    segment holding 'start', 'end' and 'transcript'.

    Parameters:
        filename: path of the .wav file.
        duration: `max_duration` argument for `vad_to_dataset`.
    """
    import json
    dataset = vad_to_dataset(filename, duration)
    result = dataset.map(evaluate, batched=True, batch_size=16)
    # The audio itself is not needed in the log.
    columns = result.remove_columns(['speech']).to_dict()
    # Walk all rows: stopping one short would lose the last segment of
    # every file.
    tlog = [
        {'start': start, 'end': end, 'transcript': transcript}
        for start, end, transcript in zip(
            columns['start'], columns['end'], columns['pred_strings'])
    ]
    with open('{}.tlog'.format(filename), 'w') as outfile:
        json.dump(tlog, outfile)

In [None]:
import glob
# Transcribe every training recording, merging pieces up to 10 seconds;
# each path is printed first so progress is visible.
wav_paths = glob.glob('/content/poleval_final_dataset_wav/train/*.wav')
for wav_path in wav_paths:
    print(wav_path)
    process_wave(wav_path, 10.0)

/content/poleval_final_dataset_wav/train/wikinews188083.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks0014565.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews228271.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews231565.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews187207.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews184725.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks0011735.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews227763.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews179902.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews231649.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews186002.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews227295.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews226119.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews182654.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews190354.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks0012408.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews180447.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews197423.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews186135.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews186427.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews179671.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews188820.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews183567.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews218231.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews183507.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews226860.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks00514.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks0010107.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks0015533.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews229910.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews228478.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks0017369.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews217449.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews190501.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews188862.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks009596.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks002129.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews199814.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews229327.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks005305.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks0013853.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews190255.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews200233.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews230750.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews229709.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews179110.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews179245.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews229300.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews183797.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews188185.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews183290.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikitalks0012169.wav


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews184315.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews188912.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews209857.wav


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews188998.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews186061.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


/content/poleval_final_dataset_wav/train/wikinews228031.wav


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

In [None]:
!find . -name '*tlog'|zip poleval-train.zip -@