# WhisperX48

<font size="3">This notebook deploys WhisperX on Google Colab to reduce the tedious work of transcribing and timestamping when producing video subtitles. Detailed information and help can be found in the [README](https://github.com/ifeimi/Whisper48/blob/main/README.md) and on [my website](https://ifeimi.github.io/whisper48/).  
\
Please run the following cells in order, following the help text in each cell. It is recommended to upload the audio file you want to transcribe to Google Drive before you start.  
\
Contact me: ifeimi48@gmail.com.</font>  

In [None]:
#@markdown **1.1 Mount Google Drive (approx. 0.5 min)**
#@markdown <br/><font size="2">**【IMPORTANT】:** Make sure GPU is selected as the hardware accelerator under "Runtime" -> "Change runtime type".</font><br/>
from google.colab import drive
from google.colab import files
import os
import logging
from IPython.display import clear_output 

clear_output()
print('Please allow access to Google Drive in the pop-up window')
drive.mount('/drive')
print('Google Drive mounted, please proceed to the next step')
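
The GPU requirement noted at the top of this cell can be checked before any model is loaded. The following is a minimal sketch and not part of the original notebook: `gpu_runtime_available` is a hypothetical helper, and it assumes that the NVIDIA tool `nvidia-smi` is on the PATH whenever a GPU runtime is attached, which is typically the case on Colab.

```python
import shutil
import subprocess


def gpu_runtime_available() -> bool:
    """Return True if nvidia-smi exists and lists at least one GPU (hypothetical helper)."""
    smi = shutil.which("nvidia-smi")
    if smi is None:
        return False
    try:
        # `nvidia-smi -L` prints one line per visible GPU
        proc = subprocess.run([smi, "-L"], capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.TimeoutExpired):
        return False
    return proc.returncode == 0 and "GPU" in proc.stdout


if not gpu_runtime_available():
    print('No GPU detected: select one under "Runtime" -> "Change runtime type" first.')
```

Running a check like this right after mounting the drive gives an early warning, instead of a failure when the model is moved to `cuda` in step 2.2.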

In [None]:
#@markdown **1.2 Set up environment (approx. 2 min)**

! pip3 install ipytree -q
! pip3 install git+https://github.com/ifeimi/whisperx.git -q
! pip3 install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121 -q
! apt-get update
! apt-get install -y libcudnn8=8.9.2.26-1+cuda12.1
! apt-get install -y libcudnn8-dev=8.9.2.26-1+cuda12.1
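# Enable TF32 inside this kernel; flags set in a separate `! python -c ...` process do not reach the notebook session
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True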
! ln -sf /usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so /usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8  # -f keeps the cell re-runnable; reference: https://github.com/m-bain/whisperX/issues/1027
! pip3 install ctranslate2==4.4.0  # fix the problem caused by new version of ctranslate2 library, reference: https://github.com/m-bain/whisperX/issues/901
clear_output()
! pip3 show torch
! pip3 show torchvision
print('Environment is ready, please proceed to the next step')
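
Because this cell pins several package versions, it can help to confirm that the pins actually took effect. The following is a minimal sketch using only the standard library; `PINS` and `check_pins` are hypothetical names, and the version numbers are copied from the install commands above. It assumes that a local build tag such as `+cu121` may be appended to the installed version and ignores it when comparing.

```python
from importlib import metadata

# Pins copied from the install commands in this cell
PINS = {"torch": "2.5.1", "torchvision": "0.20.1", "torchaudio": "2.5.1", "ctranslate2": "4.4.0"}


def check_pins(pins):
    """Map each package to (expected, installed, matches); installed is None if missing."""
    report = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Ignore local build tags such as "+cu121" when comparing
        matches = installed is not None and installed.split("+")[0] == expected
        report[name] = (expected, installed, matches)
    return report


for name, (expected, installed, ok) in check_pins(PINS).items():
    print(f"{name}: expected {expected}, installed {installed}, {'OK' if ok else 'MISMATCH'}")
```

A mismatch here usually means a later `pip3 install` replaced one of the pinned packages, in which case re-running the cell is the simplest remedy.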

In [None]:
#@markdown **1.3 Select file from Google Drive (approx. 0 min)**
# @markdown <br/><font size="2">Navigate to the video or audio file you want to transcribe, left-click to highlight it, then click the 'Select' button to confirm.
# @markdown <br/>If you uploaded the file to Google Drive only after reaching this step, run this cell again to refresh the file list.</font>

from ipytree import Tree, Node
import ipywidgets as widgets
from ipywidgets import interactive
import os
from google.colab import output 
output.enable_custom_widget_manager()
use_drive = True
# Path of the selected media file, set by the 'Select' button callback
drive_dir = ''

def file_tree():
    # create widgets as a simple file browser
    full_widget = widgets.HBox()
    left_widget = widgets.VBox()
    right_widget = widgets.VBox()

    path_widget = widgets.Text()
    path_widget.layout.min_width = '300px'
    select_widget = widgets.Button(
      description='Select', button_style='primary', tooltip='Select current media file.'
      )
    drive_url = widgets.Output()

    right_widget.children = [select_widget]
    full_widget.children = [left_widget]

    tree_widget = widgets.Output()
    tree_widget.layout.max_width = '300px'
    tree_widget.layout.overflow = 'auto'

    left_widget.children = [path_widget,tree_widget]

    # init file tree
    my_tree = Tree(multiple_selection=False)
    my_tree_dict = {}
    media_names = []

    def select_file(b):
        global drive_dir 
        drive_dir = path_widget.value
        # full_widget.disabled = True
        clear_output()
        print('File selected, please execute the next cell')
    #     if (out_file not in my_tree_dict.keys()) and (out_dir in my_tree_dict.keys()):
    #         node = Node(os.path.basename(out_file))
    #         my_tree_dict[out_file] = node
    #         parent_node = my_tree_dict[out_dir]
    #         parent_node.add_node(node)

    select_widget.on_click(select_file)

    def handle_file_click(event):
        if event['new']:
            cur_node = event['owner']
            for key in my_tree_dict.keys():
                if (cur_node is my_tree_dict[key]) and (os.path.isfile(key)):
                    try:
                        with open(key) as f:
                            path_widget.value = key
                            path_widget.disabled = False
                            select_widget.disabled = False
                            full_widget.children = [left_widget, right_widget]
                    except Exception as e:
                        path_widget.value = key
                        path_widget.disabled = True
                        select_widget.disabled = True

                        return

    def handle_folder_click(event):
        if event['new']:
            full_widget.children = [left_widget]

    # redirect cwd to default drive root path and add nodes
    my_dir = '/drive/MyDrive'
    my_root_name = my_dir.split('/')[-1]
    my_root_node = Node(my_root_name)
    my_tree_dict[my_dir] = my_root_node
    my_tree.add_node(my_root_node)
    my_root_node.observe(handle_folder_click, 'selected')

    # Real extensions with leading dots, so names that merely end in 'ts' or 'aac' are not matched
    media_exts = ('.mp3', '.m4a', '.flac', '.aac', '.wav', '.mp4', '.mkv', '.ts', '.flv')

    for root, d_names, f_names in os.walk(my_dir):
        # Prune hidden folders in place so os.walk never descends into them
        d_names[:] = sorted(d for d in d_names if not d.startswith('.'))
        # Only add media files (case-insensitive extension match)
        for f_name in sorted(f_names):
            if f_name.lower().endswith(media_exts):
                media_names.append(f_name)

        if root not in my_tree_dict.keys():
          # print(f'root name is {root}') # folder path
          name = root.split('/')[-1] # folder name
          # print(f'folder name is {name}')
          dir_name = os.path.dirname(root) # parent path of folder
          # print(f'dir name is {dir_name}')
          parent_node = my_tree_dict[dir_name]
          node = Node(name)
          my_tree_dict[root] = node
          parent_node.add_node(node)
          node.observe(handle_folder_click, 'selected')

        if len(media_names) > 0:
              parent_node = my_tree_dict[root] # parent folders
              # print(parent_node)
              parent_node.opened = False
              for f_name in media_names:
                  node = Node(f_name)
                  node.icon = 'file' 
                  full_path = os.path.join(root, f_name)
                  # print(full_path)
                  my_tree_dict[full_path] = node
                  parent_node.add_node(node)
                  node.observe(handle_file_click, 'selected')
        media_names.clear()

    with tree_widget:
      tree_widget.clear_output()
      display(my_tree)

    return full_widget


tree= file_tree()
tree
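
The directory walk behind the file browser can be tried out in isolation, without widgets or a mounted drive. The following is a minimal sketch under stated assumptions: `find_media` and `MEDIA_EXTS` are hypothetical names, the extension list is the one used in this cell written with leading dots, and matching is case-insensitive. It relies on the documented behavior of `os.walk` in top-down mode, where editing the directory list in place controls which folders are visited.

```python
import os

# Extension list taken from this cell, written with leading dots
MEDIA_EXTS = ('.mp3', '.m4a', '.flac', '.aac', '.wav', '.mp4', '.mkv', '.ts', '.flv')


def find_media(top):
    """Return sorted paths of media files under top, skipping hidden folders."""
    found = []
    for root, d_names, f_names in os.walk(top):
        # Slice assignment edits the list os.walk holds, so pruned folders are never visited
        d_names[:] = [d for d in d_names if not d.startswith('.')]
        for f_name in f_names:
            if f_name.lower().endswith(MEDIA_EXTS):
                found.append(os.path.join(root, f_name))
    return sorted(found)
```

For example, `find_media('/drive/MyDrive')` would list the same files the tree shows, which is handy for checking why a file does not appear in the browser.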


In [None]:
# @markdown **2.1 Parameter settings (approx. 0 min)**
# @markdown <br/><br/><font size="3">**2.1.1 Select the type of the uploaded file (video or audio).**</font><br/>
file_type = "audio"  # @param ["audio","video"]

# @markdown <font size="3">**2.1.2 Model size and language.**</font><br/>
model_size = "large-v3"  # @param ["base","small","medium", "large-v1","large-v2","large-v3"]
language = "ja"  # @param ["ja","zh","en","fr", "de","es","it","pt","ru"]

# @markdown <font size="3">**2.1.3 Advanced settings** (if you are not sure what these do, keep the defaults)</font>
max_line_width = "None"  # @param {type:"string"}
max_line_count = "None"  # @param {type:"string"}
highlight_words = False  # @param {type:"boolean"}
chunk_size = 5  # @param {type:"integer"}

def parse_optional_int(value, name, fallback):
    """Parse a form field as an int; the text 'None' (any case) means unset."""
    if value is None:
        return None
    text = str(value).strip()
    if text.lower() == "none":
        return None
    try:
        return int(text)
    except ValueError:
        print(f"The {name} you entered is not a valid integer value or None; using {fallback}.")
        return fallback

# Fallbacks match the original behavior: None for the width, 1 for the line count
max_line_width = parse_optional_int(max_line_width, "max_line_width", None)
max_line_count = parse_optional_int(max_line_count, "max_line_count", 1)

In [None]:
#@markdown **2.2 Run WhisperX (approx. ? min)**
#@markdown <br/><font size="2">When finished, the SRT file is saved next to the source file on Google Drive and downloaded automatically through the browser.</font><br/>

import os
import subprocess
import torch
import whisperx
import time
import pandas as pd
from urllib.parse import quote_plus
from pathlib import Path
import sys
import gc
# assert file_name != ""
# assert language != ""

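# drive_dir is set by the 'Select' button in step 1.3
if not drive_dir:
    raise ValueError('No file selected. Please run step 1.3 and click "Select" first.')
file_name = drive_dir
# splitext strips only the final extension, so dots elsewhere in the path are preserved
file_basename = os.path.splitext(file_name)[0]
output_dir = os.path.dirname(file_name)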


if file_type == "video":
    print('Extracting audio from video file...')
    audio_file = f'{file_basename}.wav'
    # Pass arguments as a list so paths containing spaces work; -y overwrites an existing .wav
    subprocess.run(
        ['ffmpeg', '-y', '-i', file_name, '-ar', '16000', '-ac', '1', '-c:a', 'pcm_s16le', audio_file],
        check=True,
    )
    print('Done.')
else:
    audio_file = file_name

device = "cuda"
batch_size = 16 # reduce if low on GPU mem
compute_type = "float16" # change to "int8" if low on GPU mem (may reduce accuracy)

print('Loading model...')
model = whisperx.load_model(model_size, device, compute_type= compute_type, language= language)
audio = whisperx.load_audio(audio_file)

# Original whisper transcribe
tic = time.time()
print('Transcription in progress...')
result = model.transcribe(audio, batch_size= batch_size, chunk_size= chunk_size)
print('Transcription completed')
# Drop the reference first so the collector can actually release the GPU memory
del model; gc.collect(); torch.cuda.empty_cache()

#Write SRT file
options = {"max_line_width":max_line_width,"max_line_count":max_line_count,"highlight_words":highlight_words}
from whisperx.utils import WriteSRT
filename_srt = file_basename + "_transcribe.srt"
with open(filename_srt, "w", encoding="utf-8") as srt:
    srt_writer = WriteSRT(filename_srt)
    srt_writer.write_result(result, srt, options)
# files.download(filename_srt)

# Load alignment model and metadata
print('Loading alignment model...')
# model_id = "jonatasgrosman/wav2vec2-large-xlsr-53-japanese"
alignment_model, metadata = whisperx.load_align_model(language_code= language, device= device)

# Align whisper output
print('Aligning Whisper output...')
result_aligned = whisperx.align(result["segments"], alignment_model, metadata, audio, device, return_char_alignments=False)
result_aligned["language"] = result["language"]

toc = time.time()
print('Alignment done')
print(f'Time elapsed: {toc-tic:.1f} s')
# Drop the reference first so the collector can actually release the GPU memory
del alignment_model; gc.collect(); torch.cuda.empty_cache()

#Write SRT file
filename_srt = file_basename + ".srt"
with open(filename_srt, "w", encoding="utf-8") as srt:
    srt_writer = WriteSRT(filename_srt)
    srt_writer.write_result(result_aligned, srt, options)
files.download(filename_srt)

# #Write JSON file
# from whisperx.utils import WriteJSON
# filename_json = file_basename + "_align.json"
# with open(filename_json, "w", encoding="utf-8") as json:
#     json_writer = WriteJSON(filename_json)
#     json_writer.write_result(result, json, options)
# files.download(filename_json)

print('Subtitle generated!')
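
For readers unfamiliar with the SRT format this cell produces, the sketch below shows how timed segments turn into subtitle cues. It is an illustration only and not the `WriteSRT` implementation from WhisperX: `srt_timestamp` and `segments_to_srt` are hypothetical helpers, and the input is assumed to be a list of dictionaries with `start`, `end` (in seconds) and `text` keys.

```python
def srt_timestamp(seconds: float) -> str:
    """Format seconds as HH:MM:SS,mmm, the timestamp form used in SRT files."""
    total_ms = int(round(seconds * 1000))
    hours, rest = divmod(total_ms, 3_600_000)
    minutes, rest = divmod(rest, 60_000)
    secs, ms = divmod(rest, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{ms:03d}"


def segments_to_srt(segments) -> str:
    """Render dicts with 'start', 'end' and 'text' keys as numbered SRT cues."""
    cues = []
    for index, seg in enumerate(segments, start=1):
        cues.append(
            f"{index}\n"
            f"{srt_timestamp(seg['start'])} --> {srt_timestamp(seg['end'])}\n"
            f"{seg['text'].strip()}\n"
        )
    # A blank line separates consecutive cues
    return "\n".join(cues)


demo = [
    {"start": 0.0, "end": 2.5, "text": " Hello."},
    {"start": 2.5, "end": 3661.5, "text": " Goodbye."},
]
print(segments_to_srt(demo))
```

Each cue consists of a running index, a `start --> end` line, and the subtitle text, which is the structure a subtitle editor expects when the generated file is opened for manual correction.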