<a href="https://colab.research.google.com/github/detektor777/colab_list_audio/blob/main/whisper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ##**Install** { display-mode: "form" }
%%capture

!pip install -q git+https://github.com/openai/whisper.git
!pip install -q git+https://github.com/yt-dlp/yt-dlp.git
!pip install -q ipywidgets

!pip install -q torch torchaudio
!pip install -q moviepy

import IPython
IPython.display.clear_output()


In [None]:
#@title ##**Select Audio File** { display-mode: "form" }
import os
import ipywidgets as widgets
from IPython.display import display, clear_output
from google.colab import files
from google.colab import drive

# Where the media file comes from. The trailing #@param comment lists the
# choices offered for this value (Colab form field syntax).
upload_option = "Upload from PC"  #@param ["Upload from PC", "Load from Google Drive Root", "Load from Google Drive"]

# Path of the chosen media file; remains None until a file has been uploaded
# or selected further down in this cell.
audio_file_path = None

def convert_to_wav(input_file, output_file):
    """Extract the audio track of a video file into a WAV file.

    Only files whose extension is one of the known video containers are
    converted; every other file is returned untouched and nothing is printed.

    Args:
        input_file: Path of the uploaded or selected media file.
        output_file: Destination path for the WAV file.

    Returns:
        ``output_file`` when a conversion was performed, otherwise
        ``input_file`` unchanged.
    """
    video_formats = ('.mp4', '.mkv', '.avi', '.mov')
    extension = os.path.splitext(input_file)[1].lower()
    if extension not in video_formats:
        return input_file

    # pydub is not imported at cell level, so bring AudioSegment into scope
    # here, right where it is needed.
    from pydub import AudioSegment

    print(f"Converting {input_file} to WAV...")
    audio = AudioSegment.from_file(input_file)
    audio.export(output_file, format="wav")
    print(f"Converted to {output_file}")
    return output_file

# Both Drive options walk the whole of MyDrive recursively and behave the
# same, so they share a single code path.
drive_options = ("Load from Google Drive Root", "Load from Google Drive")
audio_video_extensions = ['.mp3', '.wav', '.flac', '.aac', '.mp4', '.mkv', '.avi', '.mov']
video_extensions = ['.mp4', '.mkv', '.avi', '.mov']

if upload_option == "Upload from PC":
    print("Please upload an audio or video file.")
    uploaded = files.upload()
    if uploaded:
        # Only the first uploaded file is used.
        file_name = list(uploaded.keys())[0]
        audio_file_path = file_name
        if not audio_file_path.endswith('.wav'):
            audio_file_path = convert_to_wav(file_name, "converted_audio.wav")
    else:
        print("No file uploaded.")
        audio_file_path = None

elif upload_option in drive_options:
    drive.mount('/content/drive')
    root_dir = '/content/drive/MyDrive/'

    files_list = []

    # Recursively walk every folder in Google Drive and collect media files
    # as paths relative to the Drive root.
    for dirpath, _, filenames in os.walk(root_dir):
        for f in filenames:
            if os.path.splitext(f)[1].lower() in audio_video_extensions:
                relative_path = os.path.relpath(os.path.join(dirpath, f), root_dir)
                files_list.append(relative_path)

    if not files_list:
        print("No audio or video files found in Google Drive or its subfolders.")
        audio_file_path = None
    else:
        print("Select a file from Google Drive (including subfolders):")

        output = widgets.Output()
        buttons = []

        def on_button_clicked(b):
            """Store the clicked file as the global audio_file_path.

            The button description carries the path relative to root_dir.
            Video files are first converted to /content/converted_audio.wav.
            """
            global audio_file_path
            with output:
                clear_output()
                selected_file = b.description
                full_path = os.path.join(root_dir, selected_file)
                if os.path.splitext(selected_file)[1].lower() in video_extensions:
                    audio_file_path = convert_to_wav(full_path, "/content/converted_audio.wav")
                else:
                    audio_file_path = full_path
                print(f"Selected file: {audio_file_path}")

        for media_path in files_list:
            # Auto width plus a tooltip keep long relative paths readable;
            # the description itself still holds the full relative path.
            button = widgets.Button(
                description=media_path,
                tooltip=media_path,
                layout=widgets.Layout(width='auto'),
            )
            button.on_click(on_button_clicked)
            buttons.append(button)

        display(widgets.VBox(buttons), output)

# For the Drive options this runs before any button is clicked, so the path
# is still unset at this point and the prompt below is shown.
if audio_file_path:
    print(f"Audio file path set to: {audio_file_path}")
else:
    print("Audio file path not set. Please select a file.")

In [None]:
#@title ##**Run Transcription** { display-mode: "form" }
# User-facing settings; each trailing #@param comment lists the choices
# offered for that value (Colab form field syntax).
# selected_model is the name passed to whisper.load_model().
selected_model = "large"         #@param ["tiny", "base", "small", "medium", "large"]
# "auto" means no language argument is passed to Whisper's transcribe().
selected_language = "auto"       #@param ["auto", "ru", "en", "de", "fr", "es", "it", "uk", "zh", "ja"]
# Output file type; the interactive editor is only shown for "srt".
selected_format = "srt"         #@param ["txt", "srt"]
# Forwarded to Whisper's transcribe() as its `verbose` argument.
show_text = True                #@param {type:"boolean"}

import os
# Configure PyTorch's CUDA allocator via its environment variable. It is set
# here, ahead of the `import torch` below, so it is already in place when
# torch is loaded.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import whisper
import torch
import gc
import datetime
from IPython.display import display, Audio, clear_output
import ipywidgets as widgets

# Stop this cell early when the selection cell has not produced a path yet
# (name missing from the notebook namespace, or still None).
if globals().get('audio_file_path') is None:
    print("Error: Audio file path not set. Please run the 'Select Audio File' cell first and choose a file.")
    raise SystemExit

def seconds_to_srt_time(sec):
    """Format a time offset in seconds as an SRT timestamp.

    The value is rounded to the nearest millisecond once and then split into
    fields. Truncating the fractional part of a float instead would lose a
    millisecond for values such as 2.3 (stored as 2.2999...).

    Args:
        sec: Offset from the start of the media, in seconds (int or float).
            Negative values are clamped to zero.

    Returns:
        A string of the form ``HH:MM:SS,mmm``.
    """
    total_ms = max(0, int(round(float(sec) * 1000)))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    seconds, milliseconds = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

def extract_audio_segment(audio_path, start_time, end_time):
    """Load an audio file and return the slice between two offsets.

    Args:
        audio_path: Path of the media file to load with pydub.
        start_time: Start of the slice, in seconds.
        end_time: End of the slice, in seconds.

    Returns:
        The pydub segment covering ``start_time`` to ``end_time``.
    """
    from pydub import AudioSegment

    print(f"Extracting segment from {start_time} to {end_time}")
    full_track = AudioSegment.from_file(audio_path)
    # Slice positions are expressed in milliseconds.
    clip = full_track[start_time * 1000:end_time * 1000]
    print("Segment extracted")
    return clip

def retranscribe_segment(audio_path, start, end, whisper_model, lang, progress_label):
    """Transcribe a single time range of an audio file again.

    The range is cut out with extract_audio_segment(), written to a temporary
    WAV file in the working directory and passed to the Whisper model.

    Args:
        audio_path: Path of the full media file.
        start: Start of the range, in seconds.
        end: End of the range, in seconds.
        whisper_model: Loaded model object exposing ``transcribe()``.
        lang: Language code, or "auto" to call ``transcribe()`` without a
            language argument.
        progress_label: Widget whose ``value`` is updated with status text.

    Returns:
        The ``"text"`` entry of the transcription result.

    Raises:
        Exception: Any error from extraction, export or transcription is
            printed and re-raised.
    """
    print("Preparing for retranscribe...")
    torch.cuda.empty_cache()
    gc.collect()
    temp_file = "temp_segment.wav"
    try:
        progress_label.value = "Extracting audio..."
        print("Calling extract_audio_segment...")
        audio_segment = extract_audio_segment(audio_path, start, end)
        print("Exporting audio segment to temp file...")
        audio_segment.export(temp_file, format="wav")
        del audio_segment
        progress_label.value = "Transcribing..."
        print("Starting Whisper transcription...")
        # Only pass a language when one was explicitly selected.
        options = {} if lang == "auto" else {"language": lang}
        result = whisper_model.transcribe(temp_file, **options)
        print("Transcription completed, cleaning up...")
        return result["text"]
    except Exception as e:
        print(f"Error in retranscribe_segment: {str(e)}")
        raise
    finally:
        # Remove the temporary file on every exit path so a failed export or
        # transcription does not leave a stale temp_segment.wav behind.
        try:
            os.remove(temp_file)
        except FileNotFoundError:
            pass
        torch.cuda.empty_cache()
        gc.collect()

# Free cached GPU memory and wipe any earlier cell output before starting.
torch.cuda.empty_cache()
gc.collect()
clear_output(wait=True)

print("Starting transcription...")

if not os.path.exists(audio_file_path):
    print("Audio file not found. Check the path:", audio_file_path)
    raise SystemExit

print(f"Loading Whisper model: {selected_model}...")
if globals().get('model') is not None:
    # A model object survived from an earlier run of this cell.
    print("Model already loaded, reusing existing instance.")
else:
    try:
        model = whisper.load_model(selected_model)
    except RuntimeError as e:
        print(f"Failed to load model: {e}")
        print("Try using a smaller model (e.g., 'medium' or 'small') or restarting the runtime.")
        torch.cuda.empty_cache()
        gc.collect()
        raise SystemExit

combined_text = ""
combined_segments = []

try:
    print("Transcribing the entire file...")
    # The language is only passed when one was explicitly selected.
    transcribe_options = {"verbose": show_text}
    if selected_language != "auto":
        transcribe_options["language"] = selected_language
    result = model.transcribe(audio_file_path, **transcribe_options)
    combined_text = result["text"]
    combined_segments = result.get("segments", [])
    # Every segment gets the same placeholder speaker label.
    for segment in combined_segments:
        segment["speaker"] = "Speaker"
except Exception as e:
    print(f"Error during initial transcription: {e}")
    # Drop the notebook-level reference to the model before freeing memory.
    globals().pop('model', None)
    torch.cuda.empty_cache()
    gc.collect()
    raise SystemExit

# Used to give each result file a unique name.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

def create_subtitle_editor(whisper_model):
    """Display an interactive editor with one row per transcribed segment.

    Each row shows the segment's time range, an editable text area, a "Play"
    button and a "Retranscribe" button with a status label. Edits and
    re-transcriptions are written back to the SRT file immediately.

    Reads the notebook-level names ``combined_segments``, ``audio_file_path``,
    ``output_file_path`` and ``selected_language``.

    Args:
        whisper_model: Loaded model object used when a segment is
            re-transcribed.
    """
    output = widgets.Output()

    def update_srt_file():
        """Rewrite the whole SRT file from the current segment list."""
        with open(output_file_path, "w", encoding="utf-8") as f:
            for i, seg in enumerate(combined_segments, start=1):
                start_time = seconds_to_srt_time(seg["start"])
                end_time = seconds_to_srt_time(seg["end"])
                speaker = seg.get("speaker", "Unknown")
                text = seg["text"].strip()
                f.write(f"{i}\n")
                f.write(f"{start_time} --> {end_time}\n")
                f.write(f"[{speaker}] {text}\n\n")

    def on_play_clicked(b, seg_idx):
        """Cut the segment's audio out of the source file and play it."""
        with output:
            output.clear_output()
            audio_seg = extract_audio_segment(audio_file_path,
                                              combined_segments[seg_idx]["start"],
                                              combined_segments[seg_idx]["end"])
            # export() hands back a file object; close it as soon as the
            # bytes have been read instead of leaving the handle open.
            with audio_seg.export(format="wav") as wav_file:
                audio_data = wav_file.read()
            del audio_seg
            display(Audio(audio_data, autoplay=True))

    def retranscribe_segment_wrapper(seg_idx, progress_label, whisper_model):
        """Re-transcribe one segment and store the new text everywhere.

        Errors are reported through the progress label and the output area
        rather than propagated.
        """
        print(f"Starting retranscribe for segment {seg_idx}")
        try:
            new_text = retranscribe_segment(audio_path=audio_file_path,
                                            start=combined_segments[seg_idx]["start"],
                                            end=combined_segments[seg_idx]["end"],
                                            whisper_model=whisper_model,
                                            lang=selected_language,
                                            progress_label=progress_label)
            combined_segments[seg_idx]["text"] = new_text
            text_boxes[seg_idx].value = new_text
            update_srt_file()
            progress_label.value = "Done"
            print(f"Retranscribe completed for segment {seg_idx}")
        except Exception as e:
            progress_label.value = f"Error: {str(e)}"
            print(f"Retranscribe failed for segment {seg_idx}: {str(e)}")

    def on_retranscribe_clicked(b, seg_idx, progress_label):
        """Button handler: run the re-transcription inside the output area."""
        with output:
            output.clear_output()
            progress_label.value = "Starting..."
            print("Button clicked, launching retranscribe...")
            retranscribe_segment_wrapper(seg_idx, progress_label, whisper_model)

    def on_text_changed(change, seg_idx):
        """Persist a manual edit of a segment's text area."""
        combined_segments[seg_idx]["text"] = change.new
        update_srt_file()

    text_boxes = []
    for i, seg in enumerate(combined_segments):
        start_time = seconds_to_srt_time(seg["start"])
        end_time = seconds_to_srt_time(seg["end"])
        time_label = widgets.Label(value=f'{start_time} --> {end_time}',
                                   layout={'width': '250px'})

        text_box = widgets.Textarea(
            value=seg["text"],
            layout={'width': '500px'}
        )
        # Default arguments bind the current index (and label) to each
        # callback; a plain closure would see only the last loop value.
        text_box.observe(lambda change, idx=i: on_text_changed(change, idx), names='value')

        play_button = widgets.Button(description="Play")
        play_button.on_click(lambda b, idx=i: on_play_clicked(b, idx))

        retranscribe_button = widgets.Button(description="Retranscribe")
        progress_label = widgets.Label(value="", layout={'width': '150px'})

        retranscribe_button.on_click(lambda b, idx=i, pl=progress_label: on_retranscribe_clicked(b, idx, pl))

        hbox = widgets.HBox([time_label, text_box, play_button, retranscribe_button, progress_label])
        text_boxes.append(text_box)
        display(hbox)

    display(output)

# Write the result file in the selected format. The interactive editor is
# only offered for SRT output.
if selected_format == "txt":
    output_file_path = f"transcription_{timestamp}.txt"
    with open(output_file_path, "w", encoding="utf-8") as f:
        for segment in combined_segments:
            label = segment.get("speaker", "Unknown")
            body = segment["text"].strip()
            f.write(f"[{label}] {body}\n")
    print(f"Transcription completed. Result in file {output_file_path}")

elif selected_format == "srt":
    output_file_path = f"transcription_{timestamp}.srt"
    with open(output_file_path, "w", encoding="utf-8") as f:
        for number, segment in enumerate(combined_segments, start=1):
            begin = seconds_to_srt_time(segment["start"])
            finish = seconds_to_srt_time(segment["end"])
            label = segment.get("speaker", "Unknown")
            body = segment["text"].strip()
            # One SRT cue: counter, time range, caption, blank separator line.
            f.write(f"{number}\n{begin} --> {finish}\n[{label}] {body}\n\n")
    print(f"Transcription completed. Result in file {output_file_path}")

    print("\nInteractive Subtitle Editor:")
    create_subtitle_editor(model)

else:
    print("Unknown format. Check the selected format settings.")

print("Cleaning up memory at the end...")
# Drop the notebook-level reference to the model, then release cached memory.
globals().pop('model', None)
torch.cuda.empty_cache()
gc.collect()

In [None]:
#@title ##**Download** { display-mode: "form" }

from google.colab import files
import os

# Look the path up by name so the cell also works when the transcription
# cell has not been run and the variable does not exist.
result_path = globals().get("output_file_path")
if not (result_path and os.path.exists(result_path)):
    print("Result file not found. Please complete Step 3 first.")
else:
    files.download(result_path)
    print(f"File {result_path} download to your computer has started.")
