In [None]:
!pip install wget
!apt-get install sox libsndfile1 ffmpeg
!pip install matplotlib Cython packaging

!python -m pip install git+https://github.com/NVIDIA/NeMo.git@main#egg=nemo_toolkit[asr]
!python -m pip install pyannote.audio==3.2.0
!pip install pydub librosa soundfile

In [30]:
import torch
import torchaudio
from nemo.collections.asr.models import EncDecRNNTBPEModel
from nemo.collections.asr.modules.audio_preprocessing import (
    AudioToMelSpectrogramPreprocessor as NeMoAudioToMelSpectrogramPreprocessor,
)
from nemo.collections.asr.parts.preprocessing.features import (
    FilterbankFeaturesTA as NeMoFilterbankFeaturesTA,
)
import numpy as np
from pyannote.audio import Pipeline
from pydub import AudioSegment
import os
import librosa
import soundfile as sf
from typing import List, Tuple
from io import BytesIO
import time

In [31]:
# Define constants
BATCH_SIZE = 10

# Custom class for filterbank features
class FilterbankFeaturesTA(NeMoFilterbankFeaturesTA):
    def __init__(self, mel_scale: str = "htk", wkwargs=None, **kwargs):
        if "window_size" in kwargs:
            del kwargs["window_size"]
        if "window_stride" in kwargs:
            del kwargs["window_stride"]

        super().__init__(**kwargs)

        self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = (
            torchaudio.transforms.MelSpectrogram(
                sample_rate=self._sample_rate,
                win_length=self.win_length,
                hop_length=self.hop_length,
                n_mels=kwargs["nfilt"],
                window_fn=self.torch_windows[kwargs["window"]],
                mel_scale=mel_scale,
                norm=kwargs["mel_norm"],
                n_fft=kwargs["n_fft"],
                f_max=kwargs.get("highfreq", None),
                f_min=kwargs.get("lowfreq", 0),
                wkwargs=wkwargs,
            )
        )

# Custom class for audio preprocessing
class AudioToMelSpectrogramPreprocessor(NeMoAudioToMelSpectrogramPreprocessor):
    def __init__(self, mel_scale: str = "htk", **kwargs):
        super().__init__(**kwargs)
        kwargs["nfilt"] = kwargs["features"]
        del kwargs["features"]
        self.featurizer = (
            FilterbankFeaturesTA(  # Deprecated arguments; kept for config compatibility
                mel_scale=mel_scale,
                **kwargs,
            )
        )

class Diarization:
    def __init__(self, start, stop, speaker) -> None:
        self.start = start
        self.stop = stop
        self.speaker = speaker


In [54]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load ASR model
model = EncDecRNNTBPEModel.from_config_file("./rnnt_model_config.yaml")
ckpt = torch.load("./rnnt_model_weights.ckpt", map_location="cpu")
model.load_state_dict(ckpt, strict=False)
model.eval()
model = model.to(device)

# Load voice activity detection pipeline
pipeline_vad = Pipeline.from_pretrained(
        "pyannote/voice-activity-detection", use_auth_token=os.getenv('HF_TOKEN')
    )
pipeline_vad = pipeline_vad.to(torch.device(device))

# Load diarization model
pipeline_diarization = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    use_auth_token=os.getenv('HF_TOKEN'))
pipeline_diarization.to(torch.device(device))

[NeMo I 2024-09-08 07:14:46 nemo_logging:381] Tokenizer SentencePieceTokenizer initialized with 512 tokens


[NeMo W 2024-09-08 07:14:46 nemo_logging:393] Could not load dataset as `manifest_filepath` was None. Provided config : {'shuffle': False, 'manifest_filepath': None}


[NeMo I 2024-09-08 07:14:46 nemo_logging:381] PADDING: 0
[NeMo I 2024-09-08 07:14:47 nemo_logging:381] Using RNNT Loss : warprnnt_numba
    Loss warprnnt_numba_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0}
[NeMo I 2024-09-08 07:14:47 nemo_logging:381] Using RNNT Loss : warprnnt_numba
    Loss warprnnt_numba_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0}
[NeMo I 2024-09-08 07:14:47 nemo_logging:381] Using RNNT Loss : warprnnt_numba
    Loss warprnnt_numba_kwargs: {'fastemit_lambda': 0.0, 'clamp': -1.0}


Lightning automatically upgraded your loaded checkpoint from v1.1.3 to v2.4.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../.cache/torch/pyannote/models--pyannote--segmentation/snapshots/059e96f964841d40f1a5e755bb7223f76666bba4/pytorch_model.bin`


Model was trained with pyannote.audio 0.0.1, yours is 3.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.
Model was trained with torch 1.7.1, yours is 2.4.1. Bad things might happen unless you revert torch to 1.x.


<pyannote.audio.pipelines.speaker_diarization.SpeakerDiarization at 0x7018047b7650>

In [56]:
# Function to convert AudioSegment to numpy array
def audiosegment_to_numpy(audiosegment: AudioSegment) -> np.ndarray:
    """Convert AudioSegment to numpy array."""
    samples = np.array(audiosegment.get_array_of_samples())
    if audiosegment.channels == 2:
        samples = samples.reshape((-1, 2))

    samples = samples.astype(np.float32, order="C") / 32768.0
    return samples

# Function to convert audio to WAV format with 16kHz sampling rate
def convert_to_wav_16k(input_file: str, output_file: str):
  """
  Converts any audio file to WAV format with a sampling rate of 16000 Hz.

  Args:
    input_file: Path to the input audio file (any format supported by librosa).
    output_file: Path to save the output WAV file.
  """
  try:
    # Load the audio file using librosa
    y, sr = librosa.load(input_file, sr=None)  # Load with original sampling rate

    # Resample to 16kHz if necessary
    if sr != 16000:
      y = librosa.resample(y, orig_sr=sr, target_sr=16000)

    # Save as WAV file with 16kHz sampling rate
    sf.write(output_file, y, 16000, subtype='PCM_16') 

  except Exception as e:
    print(f"Error converting file: {e}")

def segment_audio(
    audio_path: str,
    pipeline_: Pipeline,
    max_duration: float = 22.0,
    min_duration: float = 10.0,
    new_chunk_threshold: float = 0.2,
) -> Tuple[List[np.ndarray], List[List[float]]]:
    audio = AudioSegment.from_wav(audio_path)
    audio_bytes = BytesIO()
    audio.export(audio_bytes, format='wav')
    audio_bytes.seek(0)

    # Process audio with pipeline to obtain segments with speech activity
    sad_segments = pipeline_({"uri": "filename", "audio": audio_bytes})

    segments = []
    curr_duration = 0
    curr_start = 0
    curr_end = 0
    boundaries = []

    # Concat segments from pipeline into chunks for asr according to max/min duration
    for segment in sad_segments.get_timeline().support():
        start = max(0, segment.start)
        end = min(len(audio) / 1000, segment.end)
        if (
            curr_duration > min_duration and start - curr_end > new_chunk_threshold
        ) or (curr_duration + (end - curr_end) > max_duration):
            audio_segment = audiosegment_to_numpy(
                audio[curr_start * 1000 : curr_end * 1000]
            )
            segments.append(audio_segment)
            boundaries.append([curr_start, curr_end])
            curr_start = start

        curr_end = end
        curr_duration = curr_end - curr_start

    if curr_duration != 0:
        audio_segment = audiosegment_to_numpy(
            audio[curr_start * 1000 : curr_end * 1000]
        )
        segments.append(audio_segment)
        boundaries.append([curr_start, curr_end])

    return segments, boundaries

# Function to format time in HH:MM:SS format
def format_time(seconds):
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = seconds % 60
    full_seconds = int(seconds)

    if hours > 0:
        return f"{hours:02}:{minutes:02}:{full_seconds:02}"
    else:
        return f"{minutes:02}:{full_seconds:02}"


In [57]:
from transformers import pipeline
from transformers import AutoTokenizer
import torch
import re

pt = "RUPunct/RUPunct_big"
device = "cuda" if torch.cuda.is_available() else "cpu"
tk = AutoTokenizer.from_pretrained(pt, strip_accents=False, add_prefix_space=True)
classifier = pipeline("ner", model=pt, tokenizer=tk, aggregation_strategy="first", device=device)


def process_token(token, label):
    if label == "LOWER_O":
        return token
    if label == "LOWER_PERIOD":
        return token + "."
    if label == "LOWER_COMMA":
        return token + ","
    if label == "LOWER_QUESTION":
        return token + "?"
    if label == "LOWER_TIRE":
        return token + "—"
    if label == "LOWER_DVOETOCHIE":
        return token + ":"
    if label == "LOWER_VOSKL":
        return token + "!"
    if label == "LOWER_PERIODCOMMA":
        return token + ";"
    if label == "LOWER_DEFIS":
        return token + "-"
    if label == "LOWER_MNOGOTOCHIE":
        return token + "..."
    if label == "LOWER_QUESTIONVOSKL":
        return token + "?!"
    if label == "UPPER_O":
        return token.capitalize()
    if label == "UPPER_PERIOD":
        return token.capitalize() + "."
    if label == "UPPER_COMMA":
        return token.capitalize() + ","
    if label == "UPPER_QUESTION":
        return token.capitalize() + "?"
    if label == "UPPER_TIRE":
        return token.capitalize() + " —"
    if label == "UPPER_DVOETOCHIE":
        return token.capitalize() + ":"
    if label == "UPPER_VOSKL":
        return token.capitalize() + "!"
    if label == "UPPER_PERIODCOMMA":
        return token.capitalize() + ";"
    if label == "UPPER_DEFIS":
        return token.capitalize() + "-"
    if label == "UPPER_MNOGOTOCHIE":
        return token.capitalize() + "..."
    if label == "UPPER_QUESTIONVOSKL":
        return token.capitalize() + "?!"
    if label == "UPPER_TOTAL_O":
        return token.upper()
    if label == "UPPER_TOTAL_PERIOD":
        return token.upper() + "."
    if label == "UPPER_TOTAL_COMMA":
        return token.upper() + ","
    if label == "UPPER_TOTAL_QUESTION":
        return token.upper() + "?"
    if label == "UPPER_TOTAL_TIRE":
        return token.upper() + " —"
    if label == "UPPER_TOTAL_DVOETOCHIE":
        return token.upper() + ":"
    if label == "UPPER_TOTAL_VOSKL":
        return token.upper() + "!"
    if label == "UPPER_TOTAL_PERIODCOMMA":
        return token.upper() + ";"
    if label == "UPPER_TOTAL_DEFIS":
        return token.upper() + "-"
    if label == "UPPER_TOTAL_MNOGOTOCHIE":
        return token.upper() + "..."
    if label == "UPPER_TOTAL_QUESTIONVOSKL":
        return token.upper() + "?!"

def update_punctuation(hyp: str):
    full_text = []
    tmp = re.split(" |\n", hyp)
    part = ''
    for el in tmp:
        if len(part.split()) < 250:
            part += el + ' '
        else:
            full_text.append(part)
            part = ''
    if part != '':
        full_text.append(part)
    full_answer = []
    for el in full_text:
        output = ""
        preds = classifier(el)
        for item in preds:
            output += " " + process_token(item['word'].strip(), item['entity_group'])
        full_answer.append(output.lstrip())
    return ' '.join(full_answer)


In [58]:
# Function to perform ASR inference on the audio file
def audio_inference(file_path: str) -> str:
    preprocessed_audio_path = 'preprocessed_audio.wav'
    convert_to_wav_16k(file_path, preprocessed_audio_path)
    waveform, sample_rate = torchaudio.load(preprocessed_audio_path)
    diarization = pipeline_diarization({"waveform": waveform, "sample_rate": sample_rate})

    # run the diarization pipeline on an audio file
    diarization = pipeline_diarization(preprocessed_audio_path)
    check_list = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        if len(check_list) == 0:
            check_list.append(Diarization(turn.start, turn.end, speaker))
        elif speaker != check_list[-1].speaker:
            check_list.append(Diarization(turn.start, turn.end, speaker))
        else:
            check_list[-1].stop = turn.end
    postprocessing_list = []
    for el in check_list:
        if el.stop - el.start > 0.5:
            if len(postprocessing_list) == 0:
                postprocessing_list.append(el)
            elif el.speaker == postprocessing_list[-1].speaker:
                postprocessing_list[-1].stop = el.stop
            else:
                postprocessing_list.append(el)
    speakers = set()
    for el in postprocessing_list:
        speakers.add(el.speaker)
     # run the ASR pipeline on an audio file
    segments, boundaries = segment_audio(preprocessed_audio_path, pipeline_vad)
    transcriptions = model.transcribe(segments, batch_size=BATCH_SIZE)[0]
    full_text = f'{postprocessing_list[0].speaker}:({format_time(postprocessing_list[0].start)}-{format_time(postprocessing_list[0].stop)})\n'
    asr_text = ''
    for transcription, boundary in zip(transcriptions, boundaries):
        try:
            if (boundary[0] >= postprocessing_list[0].start - 0.5) and (boundary[1] <= postprocessing_list[0].stop + 0.5):
                full_text += transcription + '\n'
            else:
                transcription_list = transcription.split(' ')
                if (postprocessing_list[0].stop - boundary[0]) < 0:
                    postprocessing_list.pop(0)
                    full_text += f'\n{postprocessing_list[0].speaker}:({format_time(postprocessing_list[0].start)}-{format_time(postprocessing_list[0].stop)})\n'
                    full_text += transcription
                else: 
                    formula = round( len(transcription_list) * ((postprocessing_list[0].stop - boundary[0]) / (boundary[1] - boundary[0])))
                    part1 = ' '.join(transcription_list[:formula])
                    part2 = ' '.join(transcription_list[formula:])
                    full_text += part1
                    postprocessing_list.pop(0)
                    full_text += f'\n{postprocessing_list[0].speaker}:({format_time(postprocessing_list[0].start)}-{format_time(postprocessing_list[0].stop)})\n'
                    full_text += part2
                asr_text +=f'{transcription}\n'
        except:
            asr_text +=f'{transcription}\n'
        
    return asr_text, full_text, len(speakers)


In [77]:
start_full_time = time.time()
all_texts = []
speakers = []
audio_names = []
for filename in os.listdir('test_dataset_pravitelstvo_test'):
    start = time.time()
    asr, _, count_speakers = audio_inference(os.path.join('test_dataset_pravitelstvo_test', filename))
    asr = update_punctuation(asr)
    all_texts.append(asr)
    speakers.append(count_speakers)
    audio_names.append(filename.split('.')[0])
    print(count_speakers)
    end = time.time()
    print(f'Время на транскрипцию и подсчет спикеров: {end - start}')
    os.remove('preprocessed_audio.wav')
end_full_time = time.time()
print(f'Время: {end_full_time - start_full_time}')

Transcribing: 100%|██████████| 9/9 [00:01<00:00,  5.17it/s]


2
Время на транскрипцию и подсчет спикеров: 93.18853306770325


Transcribing: 100%|██████████| 8/8 [00:02<00:00,  2.93it/s]


3
Время на транскрипцию и подсчет спикеров: 88.49120712280273


Transcribing: 100%|██████████| 6/6 [00:01<00:00,  4.63it/s]


2
Время на транскрипцию и подсчет спикеров: 55.46796894073486


Transcribing: 100%|██████████| 4/4 [00:00<00:00,  6.27it/s]

2
Время на транскрипцию и подсчет спикеров: 20.241917848587036
Время: 257.40371346473694





## Общее время обработки транскрипций: 4 минуты 17 секунды

In [83]:
all_texts

['Уважаемый алексей Геннадьевич, уважаемые коллеги! Мы сегодня продолжаем встречи с фракциями Государственной Думы, которые традиционно проводим перед отчетом правительства в Государственной Думе, и мы обсудили уже с парламентариями целый комплекс вопросов, предложений. Абсолютно уверен, что идеи и наработки, которые будут представлены фракцией НОВЫЕ ЛЮДИ, также будут способствовать конструктивному взаимодействию, нахождению новых эффективных решений. Приоритеты, озвученные депутатами, во многом и определяют наши решения, ведь высказанные здесь точки зрения основаны на запросах людей и, как отмечал президент, важно использовать эту обратную связь для совместной работы и для достижения поставленных целей. Один из весомых примеров подобного сотрудничества— это активное участие в подготовке федерального бюджета на текущие, на два последующих года. Ваша фракция вела активный диалог с Министерством финансов, с Министерством экономического развития, вы продвигали различные предложения по под

In [84]:
speakers


[2, 3, 2, 2]

In [85]:
audio_names

['Встреча 6', 'Встреча 1', 'Встреча 3', 'Встреча 7']

In [86]:
import pandas as pd

df = pd.DataFrame({'Наименование аудиозаписи': audio_names, 
                   'Транскрибированный текст': all_texts,
                   'Число спикеров': speakers,})

df = pd.concat([pd.DataFrame([], columns=df.columns), df], ignore_index=True)

In [87]:
df.to_csv('data_v4.csv', sep=';', index=False)