In [1]:
from pyannote.audio import Pipeline
from transformers import AutoProcessor, WhisperForConditionalGeneration
import pandas as pd
import torchaudio
import os
import torch
import openai
import os
import pandas as pd
import cv2
from moviepy.editor import VideoFileClip
from pydub import AudioSegment

import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Credentials: prefer values from the environment so real keys never have to be
# pasted into (and committed with) the notebook. The placeholders are only a
# fallback and must be replaced (or the env vars set) before running.
openai_api_key = os.environ.get("OPENAI_API_KEY", "<YOUR_OPENAI_API_KEY>")
huggingface_token = os.environ.get("HF_TOKEN", "<YOUR_HUGGINGFACE_TOKEN>")

# The OpenAI client reads the key from the environment and from `openai.api_key`;
# set both so either lookup path finds it.
os.environ["OPENAI_API_KEY"] = openai_api_key
openai.api_key = openai_api_key

In [3]:
class DialogueExtractor:
    """Speaker-attributed transcription of a single audio file.

    The audio is diarized with the pyannote ``speaker-diarization-3.1``
    pipeline. Each resulting (segment, speaker) pair is then cut out of the
    waveform and transcribed, either through the OpenAI Whisper API
    (``transcription_type='api'``) or with a local ``openai/whisper-large-v3``
    model (``transcription_type='local'``).

    Args:
        audio_clip_path: Path to an audio file readable by ``torchaudio.load``.
        transcription_type: ``'api'`` or ``'local'``.
        min_speakers: Lower bound on the number of speakers passed to the
            diarization pipeline (default 2, the value previously hard-coded).
        hf_token: Hugging Face access token for the gated pyannote model. When
            ``None``, the notebook-level ``huggingface_token`` is used.

    Raises:
        ValueError: If ``transcription_type`` is not ``'api'`` or ``'local'``.
        RuntimeError: If the diarization pipeline could not be loaded.
    """

    _TRANSCRIPTION_TYPES = ('api', 'local')

    def __init__(self, audio_clip_path, transcription_type='api', min_speakers=2, hf_token=None):
        # Validate before loading any model so a typo fails immediately instead
        # of after the (large) pipelines have been downloaded and initialised.
        if transcription_type not in self._TRANSCRIPTION_TYPES:
            raise ValueError("Invalid transcription type. Choose either 'api' or 'local'")

        self.audio_path = audio_clip_path
        self.min_speakers = min_speakers
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        token = hf_token if hf_token is not None else huggingface_token
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=token,
        )
        if pipeline is None:
            # NOTE(review): pyannote returns None (after printing a hint) when the
            # gated model cannot be fetched, e.g. a bad token; confirm for your version.
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization-3.1'; "
                "check the Hugging Face token and model access."
            )
        self.speaker_diarization_model = pipeline.to(self.device)

        if transcription_type == 'local':
            self.processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
            self.local_transcriber = WhisperForConditionalGeneration.from_pretrained(
                "openai/whisper-large-v3"
            ).to(self.device)
        else:
            # API transcription needs no local model; `_transcribe` dispatches on
            # `local_transcriber` being None.
            self.processor = None
            self.local_transcriber = None

        self.waveform, self.sample_rate = torchaudio.load(self.audio_path)

    def _diarize(self):
        """Run speaker diarization on the loaded waveform.

        Returns:
            DataFrame with columns ``start``, ``end`` (segment bounds as reported
            by the pipeline) and ``speaker``. A segment labelled with several
            speakers yields one row per speaker, ordered by label.
        """
        annotation = self.speaker_diarization_model(
            {'waveform': self.waveform, 'sample_rate': self.sample_rate},
            min_speakers=self.min_speakers,
        )
        # Build all rows first and construct the frame once (appending via
        # `.loc[len(df)]` is quadratic). Labels are sorted because `get_labels`
        # may return a set, whose iteration order is not deterministic.
        records = [
            (segment.start, segment.end, speaker)
            for segment in annotation.itersegments()
            for speaker in sorted(annotation.get_labels(segment))
        ]
        return pd.DataFrame(records, columns=['start', 'end', 'speaker'])

    def _transcribe_api(self, clip_path):
        """Transcribe the audio file at ``clip_path`` with OpenAI ``whisper-1`` (English)."""
        with open(clip_path, "rb") as audio_file:
            transcription = openai.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                language="en"
            )

        return transcription.text

    def _transcribe_local(self, audio_path):
        """Transcribe the audio file at ``audio_path`` with the local Whisper model.

        Channels are averaged to mono. Audio is resampled to the rate the
        feature extractor expects, because the Whisper processor rejects other
        rates.
        """
        target_sr = self.processor.feature_extractor.sampling_rate
        wav, sr = torchaudio.load(audio_path)
        mono = wav.mean(dim=0)
        if sr != target_sr:
            mono = torchaudio.functional.resample(mono, orig_freq=sr, new_freq=target_sr)

        inputs = self.processor(
            mono.numpy(), return_tensors="pt", sampling_rate=target_sr
        ).to(self.device)
        seq = self.local_transcriber.generate(inputs=inputs.input_features)

        return self.processor.batch_decode(seq, skip_special_tokens=True)[0]

    def _transcribe(self, audio_path):
        """Dispatch to the local model if one was loaded, otherwise to the API."""
        if self.local_transcriber is not None:
            return self._transcribe_local(audio_path)
        return self._transcribe_api(audio_path)

    def _save_clip(self, path, clip):
        """Write ``clip`` (a slice of ``self.waveform``) to ``path`` at ``self.sample_rate``."""
        torchaudio.save(path, clip, self.sample_rate)

    def extract_dialogue(self):
        """Diarize the audio and transcribe every (segment, speaker) row.

        Returns:
            A copy of the diarization DataFrame (``start``, ``end``, ``speaker``)
            with an added ``transcription`` column. Segments too short to contain
            a single sample get an empty transcription.
        """
        import tempfile  # function-scope: keeps this cell self-contained

        diarization_df = self._diarize()
        transcriptions = []
        # Clips go to a private temporary directory that is deleted afterwards,
        # instead of accumulating in ./tmp/ across runs.
        with tempfile.TemporaryDirectory() as tmp_dir:
            for i, row in enumerate(diarization_df.itertuples(index=False)):
                first = int(row.start * self.sample_rate)
                last = int(row.end * self.sample_rate)
                clip = self.waveform[:, first:last].clone()
                if clip.shape[-1] == 0:
                    # Nothing to transcribe; avoid writing/uploading an empty file.
                    transcriptions.append("")
                    continue
                clip_path = os.path.join(tmp_dir, f"temp_{i}.wav")
                self._save_clip(clip_path, clip)
                transcriptions.append(self._transcribe(clip_path))

        transcription_df = diarization_df.copy()
        transcription_df['transcription'] = transcriptions
        return transcription_df

In [4]:
# Diarize and transcribe the first sample clip via the OpenAI Whisper API.
audio_clip_path = '../sample_audio_clips/clip1.wav'
extractor = DialogueExtractor(
    audio_clip_path,
    transcription_type='api',
)
extractor.extract_dialogue()

Unnamed: 0,start,end,speaker,transcription
0,0.030969,3.692844,SPEAKER_00,get in touch with them and ask them how they d...
1,3.152844,5.684094,SPEAKER_01,"Yeah, it would be great. They might get a rude..."
2,4.114719,4.435344,SPEAKER_00,great thing.
3,6.797844,9.970344,SPEAKER_01,They might have to prepare themselves or I'm s...


In [5]:
# Same pipeline on the second sample clip (API transcription backend).
audio_clip_path = '../sample_audio_clips/clip2.wav'
extractor = DialogueExtractor(
    audio_clip_path,
    transcription_type='api',
)
extractor.extract_dialogue()

Unnamed: 0,start,end,speaker,transcription
0,0.115344,1.414719,SPEAKER_00,We don't like to jinx it.
1,1.870344,3.473469,SPEAKER_01,"What, the interview?"
2,2.832219,4.148469,SPEAKER_00,"Years ago? No, but well..."
3,5.025969,6.814719,SPEAKER_01,"Oh, that's what he said. I hadn't thought of t..."
4,6.021594,6.477219,SPEAKER_00,"Hey, got it."
5,7.995969,9.970344,SPEAKER_00,"No, years ago, when couples were"
