In [19]:
import math
from typing import List

import torch
import torchaudio
import numpy as np

print(torch.__version__)
print(torchaudio.__version__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

import ctc_segmentation
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, Wav2Vec2CTCTokenizer

1.12.1+cu116
0.12.1+cu113
cuda


In [20]:
# load model, processor and tokenizer
model_name = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
processor = Wav2Vec2Processor.from_pretrained(model_name)
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)

In [28]:
# load dummy dataset and read soundfiles
SAMPLERATE = 44100
waveform, srate = torchaudio.load("./cli-test/Nicki Minaj - Anaconda/vocals.wav")
audio = waveform[0, :int(math.ceil(srate*3.5))]
transcripts = ["MY ANACONDA DONT", "MY ANACONDA DONT"]

In [24]:
torch.cuda.empty_cache()
import gc
gc.collect()

2687

In [29]:
CHUNKS = 10

def align_with_transcript(
    audio : np.ndarray,
    transcripts : List[str],
    samplerate : int = SAMPLERATE,
    model : Wav2Vec2ForCTC = model,
    processor : Wav2Vec2Processor = processor,
    tokenizer : Wav2Vec2CTCTokenizer = tokenizer
):
    assert audio.ndim == 1

    w_len = audio.shape[0]
    chunk_len = int(math.ceil(w_len/CHUNKS))
    audio_chunks = [audio[i:min(w_len,i+chunk_len)] for i in range(CHUNKS)]
    logits_chunks = []
    probs_chunks = []
    for i in range(CHUNKS):
        # Run prediction, get logits and probabilities
        inputs = processor(audio_chunks[i], return_tensors="pt", padding="longest")
        with torch.no_grad():
            logits = model(inputs.input_values.to(device)).logits.cpu()[0]
            probs = torch.nn.functional.softmax(logits,dim=-1)
            logits_chunks.append(logits)
            probs_chunks.append(probs)
    logits = torch.cat(logits_chunks)
    probs = torch.cat(probs_chunks)

    # Tokenize transcripts
    vocab = tokenizer.get_vocab()
    inv_vocab = {v:k for k,v in vocab.items()}
    unk_id = vocab["<unk>"]
    
    tokens = []
    for transcript in transcripts:
        assert len(transcript) > 0
        tok_ids = tokenizer(transcript.replace("\n"," ").lower())['input_ids']
        tok_ids = np.array(tok_ids,dtype=np.int)
        tokens.append(tok_ids[tok_ids != unk_id])
    
    # Align
    char_list = [inv_vocab[i] for i in range(len(inv_vocab))]
    config = ctc_segmentation.CtcSegmentationParameters(char_list=char_list)
    config.index_duration = audio.shape[0] / probs.size()[0] / samplerate
    
    ground_truth_mat, utt_begin_indices = ctc_segmentation.prepare_token_list(config, tokens)
    timings, char_probs, state_list = ctc_segmentation.ctc_segmentation(config, probs.numpy(), ground_truth_mat)
    segments = ctc_segmentation.determine_utterance_segments(config, utt_begin_indices, char_probs, timings, transcripts)
    return [{"text" : t, "start" : p[0], "end" : p[1], "conf" : p[2]} for t,p in zip(transcripts, segments)]

print(align_with_transcript(audio, transcripts, srate))

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_ra

[{'text': 'MY ANACONDA DONT', 'start': 0.003723404255319149, 'end': 2.1781914893617023, 'conf': 0.0}, {'text': 'MY ANACONDA DONT', 'start': 2.1781914893617023, 'end': 3.146276595744681, 'conf': 0.0}]


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  tok_ids = np.array(tok_ids,dtype=np.int)


In [35]:
# load dummy dataset and read soundfiles
TARGET_SONG = "Jimi Hendrix - All Along the Watchtower"
waveform, srate = torchaudio.load("./cli-test/"+TARGET_SONG+"/vocals.wav")
audio = waveform[0]

LYRICS_PATH = "./cli-test/"+TARGET_SONG+".txt"
lyrics_file = open(LYRICS_PATH, 'r')
transcript  = lyrics_file.read()
transcript_lines = transcript.split("\n")
transcripts = ["".join(filter(lambda chr: chr.isalpha() or chr==' ', line)).upper() for line in transcript_lines][:-1] # last one is blank

In [36]:
ctc_output = align_with_transcript(audio, transcripts, srate)

It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
It is strongly recommended to pass the ``sampling_ra

In [37]:
print(ctc_output)

[{'text': 'THERE MUST BE SOME KIND OF WAY OUT OF HERE', 'start': 10.02404758259145, 'end': 19.5057778470445, 'conf': 0.0}, {'text': 'SAID THE JOKER TO THE THIEF', 'start': 19.5057778470445, 'end': 21.320268809560268, 'conf': 0.0}, {'text': 'THERES TOO MUCH CONFUSION', 'start': 21.320268809560268, 'end': 36.03216153363811, 'conf': 0.0}, {'text': 'I CANT GET NO RELIEF', 'start': 36.03216153363811, 'end': 44.54575312976209, 'conf': 0.0}, {'text': 'BUSINESSMEN THEY DRINK MY WINE', 'start': 44.54575312976209, 'end': 48.63924474119766, 'conf': 0.0}, {'text': 'PLOWMEN DIG MY EARTH', 'start': 48.63924474119766, 'end': 60.745528443102856, 'conf': 0.0}, {'text': 'NONE WILL LEVEL ON THE LINE', 'start': 60.745528443102856, 'end': 69.76717750873125, 'conf': 0.0}, {'text': 'NOBODY OFFERED HIS WORD', 'start': 69.76717750873125, 'end': 71.16796453179343, 'conf': 0.0}, {'text': 'HEY', 'start': 71.16796453179343, 'end': 71.50183086889633, 'conf': 0.0}, {'text': 'NO REASON TO GET EXCITED', 'start': 71.50

In [34]:
import json
def export_transcript(merged_lines, outfile):
    script = {}
    script['fragments'] = []
    l_id = 0
    for line in merged_lines:
        fragment = {}
        fragment['lines'] = [line['text']]
        fragment['begin'] = line['start']
        fragment['end'] = line['end']
        fragment['language'] = 'eng'
        fragment['children'] = []
        fragment['id'] = l_id
        l_id+=1
        script['fragments'].append(fragment)
    
    json.dump(script, outfile)

with open("ctc_seg_alignment_aatw.json", 'w') as f:
    export_transcript(ctc_output, f)