In [1]:
import os
import io
import sys
import matplotlib.pyplot as plt
import IPython.display as ipd
import pandas as pd
import re
import subprocess
import numpy as np
from tqdm import tqdm

%load_ext autoreload
%autoreload 2
%matplotlib inline

sys.path.append('../src')

In [None]:
import torch
from transformers import logging
logging.set_verbosity_error()
from soft.models.models import AudioFeatureExtractor

vad_model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                  model='silero_vad',
                                  force_reload=True,
                                  onnx=False)

(get_speech_timestamps, _, read_audio, _, _) = utils


def convert_video_to_audio(file_path: str, sr: int = 16000) -> str:
    path_save = file_path.split('.')[0] + ".wav"
    if not os.path.exists(path_save):
        ffmpeg_command = f"ffmpeg -y -i {file_path} -async 1 -vn -acodec pcm_s16le -ar {sr} {path_save}"
        subprocess.call(ffmpeg_command, shell=True)

    return path_save


def readetect_speech(file_path: str, 
                     sr: int = 16000) -> list[dict]:
    wav = read_audio(file_path, sampling_rate=sr)
    # get speech timestamps from full audio file
    speech_timestamps = get_speech_timestamps(wav, vad_model, sampling_rate=sr)
    
    return wav, speech_timestamps


def find_intersections(x: list[dict], y: list[dict], min_length: float = 0) -> list[dict]:
    """Find intersections of two lists of dicts with intervals, preserving structure of `x` and adding intersection info.

    Args:
        x (list[dict]): First list of intervals
        y (list[dict]): Second list of intervals
        min_length (float, optional): Minimum length of intersection. Defaults to 0.

    Returns:
        list[dict]: Windows with intersections, maintaining structure of `x`, and indicating intersection presence.
    """
    timings = []
    j = 0

    for interval_x in x:
        original_start = interval_x['start']
        original_end = interval_x['end']
        intersections_found = False

        while j < len(y) and y[j]['end'] < original_start:
            j += 1  # Skip any intervals in `y` that end before the current interval in `x` starts

        # Check for all overlapping intervals in `y`
        temp_j = j  # Temporary pointer to check intersections within `y` for current `x`
        while temp_j < len(y) and y[temp_j]['start'] <= original_end:
            # Calculate the intersection between `x[i]` and `y[j]`
            intersection_start = max(original_start, y[temp_j]['start'])
            intersection_end = min(original_end, y[temp_j]['end'])

            if intersection_start < intersection_end and (intersection_end - intersection_start) >= min_length:
                timings.append({
                    'original_start': original_start,
                    'original_end': original_end,
                    'start': intersection_start,
                    'end': intersection_end,
                    'speech': True
                })
                intersections_found = True

            temp_j += 1  # Move to the next interval in `y` for further intersections

        # If no intersections were found, add the interval with `intersected` set to False
        if not intersections_found:
            timings.append({
                'original_start': original_start,
                'original_end': original_end,
                'start': None,
                'end': None,
                'speech': False
            })

    return timings


def slice_audio(start_time: float, end_time: float, 
                win_max_length: float, win_shift: float, win_min_length: float) -> list[dict]:
    """Slices audio on windows

    Args:
        start_time (float): Start time of audio
        end_time (float): End time of audio
        win_max_length (float): Window max length
        win_shift (float): Window shift
        win_min_length (float): Window min length

    Returns:
        list[dict]: List of dict with timings, f.e.: {'start': 0, 'end': 12}
    """    

    if end_time < start_time:
        return []
    elif (end_time - start_time) > win_max_length:
        timings = []
        while start_time < end_time:
            end_time_chunk = start_time + win_max_length
            if end_time_chunk < end_time:
                timings.append({'start': start_time, 'end': end_time_chunk})
            elif end_time_chunk == end_time: # if tail exact `win_max_length` seconds
                timings.append({'start': start_time, 'end': end_time_chunk})
                break
            else: # if tail less then `win_max_length` seconds
                if end_time - start_time < win_min_length: # if tail less then `win_min_length` seconds
                    break
                
                timings.append({'start': start_time, 'end': end_time})
                break

            start_time += win_shift
        return timings
    else:
        return [{'start': start_time, 'end': end_time}]

  from .autonotebook import tqdm as notebook_tqdm
Downloading: "https://github.com/snakers4/silero-vad/zipball/master" to /home/maxim/.cache/torch/hub/master.zip


In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
checkpoint_path = 'models/w-AudioModelWT_weights.pth'
afe = AudioFeatureExtractor(checkpoint_path, device)
file_path = 'example_file.mp4'
sr = 16000

new_file_path = convert_video_to_audio(file_path=file_path, sr=sr)
wav, vad_info = readetect_speech(file_path=new_file_path, sr=sr)
audio_windows = slice_audio(start_time=0, end_time=int(len(wav)),
                            win_max_length=int(4 * sr), 
                            win_shift=int(2 * sr), win_min_length=int(2 * sr))

intersections = find_intersections(x=audio_windows, y=vad_info, min_length=int(2 * sr))

res = []
for w_idx, window in enumerate(intersections):
    if not window['speech']:
        res.append({
            'emo': None,
            'sen': None,
            'fea': None,
        })
        continue
        
    wave = wav[window['start']: window['end']].clone()
    predicts, features = afe(wave)

    res.append({
        'emo': predicts['emo'], # ['neutral', 'happy', 'sad', 'anger', 'surprise', 'disgust', 'fear']
        'sen': predicts['sen'], # ['negative', 'neutral', 'positive']
        'fea': features,
    })



In [4]:
res

[{'emo': tensor([0.0063, 0.8051, 0.0654, 0.0577, 0.0274, 0.0291, 0.0091]),
  'sen': tensor([0.0211, 0.9728, 0.0061]),
  'fea': None},
 {'emo': tensor([1.2339e-07, 9.8426e-01, 1.6798e-04, 4.9909e-04, 1.4139e-02, 8.7509e-05,
          8.4590e-04]),
  'sen': tensor([1.5948e-08, 1.3211e-08, 1.0000e+00]),
  'fea': None},
 {'emo': tensor([1.3111e-08, 9.9434e-01, 4.1480e-05, 3.1078e-05, 5.3832e-03, 6.9627e-06,
          1.9564e-04]),
  'sen': tensor([4.3964e-09, 4.1938e-09, 1.0000e+00]),
  'fea': None},
 {'emo': tensor([7.0702e-08, 9.8921e-01, 3.9427e-05, 6.2164e-04, 9.7759e-03, 4.3741e-05,
          3.0579e-04]),
  'sen': tensor([3.6061e-09, 6.8011e-09, 1.0000e+00]),
  'fea': None},
 {'emo': tensor([9.3258e-08, 9.9803e-01, 8.9725e-05, 8.4152e-05, 1.5301e-03, 4.3461e-06,
          2.6527e-04]),
  'sen': tensor([3.5798e-11, 2.6737e-10, 1.0000e+00]),
  'fea': None},
 {'emo': tensor([2.3990e-06, 9.7891e-01, 1.1962e-02, 3.2124e-03, 2.4369e-03, 3.9619e-04,
          3.0830e-03]),
  'sen': tensor([