## Imports

In [None]:
import logging
import torch
from unspoken import exceptions
from unspoken.services import db
from unspoken.core.loader import prepare_models
from unspoken.enitites.transcription import TranscriptionResult
from unspoken.enitites.speach_to_text import SpeachToTextResult
from unspoken.services.ml.transcriber import Transcriber
from unspoken.services.audio.converter import convert_to_wav, preprocess_audio
import matplotlib.pyplot as plt
import librosa
import librosa.display
import numpy as np
import io
from IPython.display import Audio, display as ip_display
import pandas as pd
import seaborn as sns

logger = logging.getLogger(__name__)

# Setup

In [None]:
# Run the project's model preparation step once before any transcription cell.
# NOTE(review): presumably this downloads/loads the ML models — confirm in
# unspoken.core.loader.
prepare_models()
print('Models initialized.')

# Helper Functions

In [None]:
def clear_cuda_cache(func):
    """Decorate *func* so the CUDA caching allocator is emptied around each call.

    The cache is emptied before the call and again after it, once pending GPU
    work has been synchronized. The cleanup also runs when *func* raises, so a
    failed call does not skip it. On machines without CUDA the wrapper only
    calls *func*, which keeps the notebook usable on CPU-only hosts.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same call signature, name and docstring as *func*
        that returns whatever *func* returns.
    """
    import functools  # local import keeps this cell self-contained

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # torch.cuda.synchronize() raises when no CUDA device/driver exists,
        # so every CUDA call is guarded by a single availability check.
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            torch.cuda.empty_cache()
        try:
            return func(*args, **kwargs)
        finally:
            if cuda_available:
                # Wait for pending GPU work before releasing cached blocks.
                torch.cuda.synchronize()
                torch.cuda.empty_cache()

    return wrapper


def _convert_audio(source_file_data: bytes):
    """Pass the raw source-file bytes to ``convert_to_wav`` and return its result."""
    return convert_to_wav(source_file_data)


@clear_cuda_cache
def _transcribe_audio(wav_data: bytes) -> SpeachToTextResult:
    """Transcribe *wav_data* with a freshly constructed ``Transcriber``.

    The ``clear_cuda_cache`` decorator runs its CUDA cache cleanup around
    the call.
    """
    transcriber = Transcriber()
    return transcriber.transcribe(wav_data)


def visualize_audio(wav_data: bytes, title: str):
    """Plot the waveform of an in-memory audio file.

    Args:
        wav_data: Encoded audio bytes readable by ``librosa.load``.
        title: Title shown above the plot.
    """
    # Decode from an in-memory buffer; sr=None is passed so librosa does not
    # resample to its default rate.
    buffer = io.BytesIO(wav_data)
    samples, sample_rate = librosa.load(buffer, sr=None)

    plt.figure(figsize=(12, 4))
    librosa.display.waveshow(samples, sr=sample_rate)
    plt.xlabel('Time')
    plt.ylabel('Amplitude')
    plt.title(title)
    plt.show()


def visualize_spectrogram(wav_data: bytes, title: str):
    """Plot a dB-scaled STFT spectrogram of an in-memory audio file.

    Args:
        wav_data: Encoded audio bytes readable by ``librosa.load``.
        title: Title shown above the plot.
    """
    # Decode from an in-memory buffer; sr=None is passed so librosa does not
    # resample to its default rate.
    buffer = io.BytesIO(wav_data)
    samples, sample_rate = librosa.load(buffer, sr=None)

    # Magnitude of the short-time Fourier transform, converted to decibels
    # relative to the maximum magnitude.
    magnitude = np.abs(librosa.stft(samples))
    spectrogram_db = librosa.amplitude_to_db(magnitude, ref=np.max)

    plt.figure(figsize=(12, 4))
    librosa.display.specshow(spectrogram_db, sr=sample_rate, x_axis='time', y_axis='hz')
    plt.colorbar(format='%+2.0f dB')
    plt.title(title)
    plt.show()


def convert_transcription_result_to_pd(transcription_result) -> pd.DataFrame:
    """Build a DataFrame with one row per transcription segment.

    Args:
        transcription_result: Object exposing a ``segments`` iterable whose
            items have ``text``, ``start`` and ``end`` attributes.

    Returns:
        A DataFrame with columns ``text``, ``start``, ``end`` and ``length``
        (``end - start``). The columns are present even when there are no
        segments, so downstream column access does not fail on empty input.
    """
    columns = ['text', 'start', 'end', 'length']
    rows = [
        {
            'text': segment.text,
            'start': segment.start,
            'end': segment.end,
            'length': segment.end - segment.start,
        }
        for segment in transcription_result.segments
    ]
    # Passing explicit columns keeps the schema stable for an empty row list.
    return pd.DataFrame(rows, columns=columns)


def analyze_segment_lengths(transcription_result: SpeachToTextResult):
    """Plot and summarize the duration of each transcription segment.

    Shows a histogram (with KDE) of segment lengths, marks the mean with a
    dashed red line, and prints ``describe()`` statistics for the lengths.

    Args:
        transcription_result: Result whose ``segments`` expose ``start`` and
            ``end`` attributes.

    Returns:
        A DataFrame with columns ``segment_id`` and ``length``.
    """
    # Duration of every segment, in the order the segments appear.
    durations = []
    for seg in transcription_result.segments:
        durations.append(seg.end - seg.start)

    lengths_df = pd.DataFrame({'segment_id': range(len(durations)), 'length': durations})
    mean_length = lengths_df['length'].mean()

    # Histogram with the mean marked as a vertical line.
    plt.figure(figsize=(12, 6))
    sns.histplot(data=lengths_df, x='length', kde=True)
    mean_label = f'Average: {mean_length:.2f}s'
    plt.axvline(mean_length, color='r', linestyle='--', label=mean_label)
    plt.title('Distribution of Segment Lengths')
    plt.xlabel('Segment Length (seconds)')
    plt.ylabel('Count')
    plt.legend()
    plt.show()

    # Summary statistics for the lengths column.
    summary = lengths_df['length'].describe()
    print(summary)

    return lengths_df

# Database Operations

In [None]:
def _save_to_database(
    task_id: int,
    annotated_transcription: TranscriptionResult,
) -> None:
    """Persist a transcription, its speakers and its messages for a task.

    Inside one ``db.Session`` the function stores the dumped transcription
    under the task's ``transcript_id``, creates one speaker row per entry in
    ``annotated_transcription.speakers``, and saves one ``db.Message`` per
    entry in ``annotated_transcription.messages``. A message whose speaker
    was not created gets ``speaker_id=None``.

    Args:
        task_id: Identifier passed to ``db.get_task``.
        annotated_transcription: The transcription result to store.

    Raises:
        exceptions.TaskNotFoundError: If ``db.get_task`` returns a falsy value.
    """
    with db.Session() as session:
        task = db.get_task(task_id, session)
        if not task:
            raise exceptions.TaskNotFoundError(f'Task with id: {task_id} was not found.')

        transcription_payload = annotated_transcription.model_dump()
        db.save_transcription_result(task.transcript_id, transcription_payload, session=session)

        logger.info('Saving speakers to database.')
        # Maps speaker name -> id of the speaker row created for it.
        speaker_ids = {}
        for speaker_name in annotated_transcription.speakers:
            speaker_row = db.create_speaker(name=speaker_name, task_id=task.id, session=session)
            speaker_ids[speaker_name] = speaker_row.id
        logger.info('Saved %s speakers.', len(speaker_ids))

        logger.info('Saving messages to database.')
        messages_to_save = [
            db.Message(
                # None when the message's speaker has no created row.
                speaker_id=speaker_ids.get(msg.speaker),
                task_id=task.id,
                text=msg.text,
                start_time=msg.start,
                end_time=msg.end,
            )
            for msg in annotated_transcription.messages
        ]
        db.save_messages(messages_to_save, session=session)
    logger.info('Saved %s messages.', len(messages_to_save))

# Pipeline

In [None]:
# Read the whole source media file into memory as raw bytes.
with open('data/interview.mp4', 'rb') as f:
    file_data = f.read()

# Convert the raw bytes with the project's converter; later cells use wav_data.
wav_data = convert_to_wav(source_data=file_data)

In [None]:
# Waveform of the converted audio, before any preprocessing.
visualize_audio(wav_data, 'Interview audio after conversion')

In [None]:
# Spectrogram of the converted audio, before any preprocessing.
visualize_spectrogram(wav_data, 'Interview spectrogram after conversion')

In [None]:
# Embed an audio player for the converted audio.
# NOTE(review): IPython's Audio presumably only uses `rate` for array/list
# input and ignores it for encoded bytes — confirm the 16000 is needed.
ip_display(Audio(wav_data, rate=16000))

In [None]:
# Run the project's audio preprocessing on the converted audio; the result is
# what the following cells plot, play and transcribe.
preprocessed_audio = preprocess_audio(wav_data=wav_data)

In [None]:
# Waveform after preprocessing, for comparison with the post-conversion plot.
visualize_audio(preprocessed_audio, 'Interview audio after preprocessing')

In [None]:
# Spectrogram after preprocessing, for comparison with the post-conversion plot.
visualize_spectrogram(preprocessed_audio, 'Interview spectrogram after preprocessing')

In [None]:
# Embed an audio player for the preprocessed audio.
# NOTE(review): IPython's Audio presumably only uses `rate` for array/list
# input and ignores it for encoded bytes — confirm the 16000 is needed.
ip_display(Audio(preprocessed_audio, rate=16000))

In [None]:
# Transcribe the preprocessed audio; the helper is decorated with
# clear_cuda_cache, so CUDA cache cleanup runs around the call.
transcription = _transcribe_audio(preprocessed_audio)

In [None]:
# Plot the segment-length histogram and print summary statistics.
analyze_segment_lengths(transcription)

In [None]:
# Tabulate the segments (text, start, end, length); the bare `df` on the last
# line is left as the cell's final expression so the notebook displays it.
df = convert_transcription_result_to_pd(transcription)
df