# WhisperX48

<font size="3">将 WhisperX 部署在 Google Colab 云端上，其目标是减少视频字幕制作过程中听译和打轴的繁重工作。详细信息和帮助文档可查阅 [README](https://github.com/ifeimi/Whisper48/blob/main/README.md) 文件和[我的主页](https://ifeimi.github.io/whisper48/)。  
This IPython notebook runs WhisperX on Google Colab. It aims to reduce the heavy, tedious work of transcription and timestamping in video subtitling. Detailed information and help can be found in the [README](https://github.com/ifeimi/Whisper48/blob/main/README.md) and on [my website](https://ifeimi.github.io/whisper48/).   
\
请按提示依次执行以下单元格，建议在开始前先将需要转录的音频文件上传到谷歌网盘中。  
Please run the following cells in order, following the instructions in each cell. It is recommended to upload your audio file to Google Drive before you start.  
\
联系作者/Contact me: yfwu0202@gmail.com.</font>

In [None]:
#@markdown **1.1 挂载你的谷歌网盘/Mount Google Drive (approx. 0.5 min)** 
#@markdown <br/><font size="2">**【重要】:** 务必在"修改"->"笔记本设置"->"硬件加速器"中选择GPU！
#@markdown <br/>**【IMPORTANT】:** Make sure you select GPU as the hardware accelerator in "Runtime" -> "Change runtime type".</font><br/>
from google.colab import drive
from google.colab import files
import os
import logging
from IPython.display import clear_output 

clear_output()
print('Please allow the connection to Google Drive in the pop-up window')
print('请在弹出窗口中选择同意挂载谷歌云盘')
drive.mount('/drive')
print('Google Drive mounted, please proceed to the next step')
print('谷歌云盘挂载完毕，请执行下一步')
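
The GPU requirement stated in this cell can be verified before moving on. The following is a minimal sketch and not part of the original notebook: the helper name `gpu_available` is hypothetical, and it assumes the NVIDIA driver exposes the `nvidia-smi` command, as Colab GPU runtimes normally do.

```python
import shutil
import subprocess


def gpu_available() -> bool:
    """Return True if nvidia-smi is installed and lists at least one GPU."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        # No NVIDIA tooling on PATH: most likely a CPU-only runtime
        return False
    try:
        # `nvidia-smi -L` prints one line per device, e.g. "GPU 0: ..."
        proc = subprocess.run([exe, "-L"], capture_output=True, text=True, timeout=10)
    except (OSError, subprocess.SubprocessError):
        return False
    return proc.returncode == 0 and "GPU" in proc.stdout


if not gpu_available():
    print('No GPU detected. Select GPU under "Runtime" -> "Change runtime type".')
```

If the message is printed, switch the runtime type first; changing it restarts the runtime, so the cells have to be run again from the top.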

In [None]:
#@markdown **1.2 配置运行环境/Setup environment (approx. 3 min)**
# @markdown <br/><font size="2">目前在Colab上安装WhisperX时似乎会出现一些和PyTorch版本相关的[问题](https://github.com/m-bain/whisperX/issues/165)，解决方案已经包括在这里。如果本单元格正常运行结束，只要忽略过程中产生的报错信息即可。我会继续关注这个问题。
# @markdown <br/>Installing WhisperX on Colab currently seems to trigger [a problem](https://github.com/m-bain/whisperX/issues/165) related to PyTorch versions. A workaround is already included here. If this cell runs to the end, the error messages printed along the way can be ignored. I will keep monitoring this issue. </font><br/>

! pip install geemap -q
import geemap
! pip install git+https://github.com/m-bain/whisperx.git -q
! pip install light-the-torch -q
! ltt install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1  torchtext==0.14.1
clear_output()
print('Environment is ready, please proceed to the next step')
print('运行环境配置完毕，请执行下一步')
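
Because the install output is cleared, it can be useful to confirm that the pinned PyTorch versions actually took effect. The sketch below is an illustration only: `EXPECTED`, `installed_version` and `check_versions` are hypothetical names, and the version numbers are simply copied from the `ltt install` line in this cell.

```python
from importlib import metadata

# Versions pinned by the `ltt install` line in this cell
EXPECTED = {
    "torch": "1.13.1",
    "torchvision": "0.14.1",
    "torchaudio": "0.13.1",
    "torchtext": "0.14.1",
}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_versions(expected):
    """Return {package: (expected, found)} for every package that does not match."""
    mismatches = {}
    for package, wanted in expected.items():
        found = installed_version(package)
        # CUDA builds report e.g. '1.13.1+cu117', so compare the part before '+'
        if found is None or found.split("+")[0] != wanted:
            mismatches[package] = (wanted, found)
    return mismatches


for package, (wanted, found) in check_versions(EXPECTED).items():
    print(f"{package}: expected {wanted}, found {found}")
```

No output means all four packages match the pinned versions.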

In [None]:
#@markdown **1.3 从谷歌云盘选择文件/Select File From Google Drive (approx. 0 min)**
# @markdown <br/><font size="2">从网盘目录中选择要转换的文件(视频/音频），单击选中文件，点击'Select'按钮以确认。
# @markdown <br/>Navigate to the file you want to transcribe, left-click to highlight it, then click the 'Select' button to confirm.
# @markdown <br/>若到这一步才上传文件到谷歌盘，则重复执行本单元格以刷新文件列表。
# @markdown <br/>If you uploaded the file to Google Drive only after reaching this step, run this cell again to refresh the file list.</font>

from ipytree import Tree, Node
import ipywidgets as widgets
from ipywidgets import interactive
import os
from google.colab import output 
output.enable_custom_widget_manager()
use_drive = True
# Full path of the file picked in the browser below; set by the 'Select' button callback
drive_dir = ''

def file_tree():
    # create widgets as a simple file browser
    full_widget = widgets.HBox()
    left_widget = widgets.VBox()
    right_widget = widgets.VBox()

    path_widget = widgets.Text()
    path_widget.layout.min_width = '300px'
    select_widget = widgets.Button(
      description='Select', button_style='primary', tooltip='Select current media file.'
      )
    drive_url = widgets.Output()

    right_widget.children = [select_widget]
    full_widget.children = [left_widget]

    tree_widget = widgets.Output()
    tree_widget.layout.max_width = '300px'
    tree_widget.layout.overflow = 'auto'

    left_widget.children = [path_widget,tree_widget]

    # init file tree
    my_tree = Tree(multiple_selection=False)
    my_tree_dict = {}
    media_names = []

    def select_file(b):
        global drive_dir 
        drive_dir = path_widget.value
        # full_widget.disabled = True
        clear_output()
        print('File selected, please execute the next cell')
        print('已选择文件，请执行下个单元格')
    #     if (out_file not in my_tree_dict.keys()) and (out_dir in my_tree_dict.keys()):
    #         node = Node(os.path.basename(out_file))
    #         my_tree_dict[out_file] = node
    #         parent_node = my_tree_dict[out_dir]
    #         parent_node.add_node(node)

    select_widget.on_click(select_file)

    def handle_file_click(event):
        if event['new']:
            cur_node = event['owner']
            for key in my_tree_dict.keys():
                if (cur_node is my_tree_dict[key]) and (os.path.isfile(key)):
                    try:
                        with open(key) as f:
                            path_widget.value = key
                            path_widget.disabled = False
                            select_widget.disabled = False
                            full_widget.children = [left_widget, right_widget]
                    except Exception as e:
                        path_widget.value = key
                        path_widget.disabled = True
                        select_widget.disabled = True

                        return

    def handle_folder_click(event):
        if event['new']:
            full_widget.children = [left_widget]

    # redirect cwd to default drive root path and add nodes
    my_dir = '/drive/MyDrive'
    my_root_name = my_dir.split('/')[-1]
    my_root_node = Node(my_root_name)
    my_tree_dict[my_dir] = my_root_node
    my_tree.add_node(my_root_node)
    my_root_node.observe(handle_folder_click, 'selected')

    for root, d_names, f_names in os.walk(my_dir):
        # prune hidden folders in place so os.walk does not descend into them
        d_names[:] = [d_name for d_name in d_names if not d_name.startswith('.')]
        for f_name in f_names:
            # only add media files; match the real extension, case-insensitively
            if f_name.lower().endswith(('.mp3','.m4a','.flac','.aac','.wav','.mp4','.mkv','.ts','.flv')):
                media_names.append(f_name)

        d_names.sort()
        f_names.sort()
        media_names.sort()
        keys = my_tree_dict.keys()

        if root not in my_tree_dict.keys():
          # print(f'root name is {root}') # folder path
          name = root.split('/')[-1] # folder name
          # print(f'folder name is {name}')
          dir_name = os.path.dirname(root) # parent path of folder
          # print(f'dir name is {dir_name}')
          parent_node = my_tree_dict[dir_name]
          node = Node(name)
          my_tree_dict[root] = node
          parent_node.add_node(node)
          node.observe(handle_folder_click, 'selected')

        if len(media_names) > 0:
              parent_node = my_tree_dict[root] # parent folders
              # print(parent_node)
              parent_node.opened = False
              for f_name in media_names:
                  node = Node(f_name)
                  node.icon = 'file' 
                  full_path = os.path.join(root, f_name)
                  # print(full_path)
                  my_tree_dict[full_path] = node
                  parent_node.add_node(node)
                  node.observe(handle_file_click, 'selected')
        media_names.clear()

    with tree_widget:
      tree_widget.clear_output()
      display(my_tree)

    return full_widget


tree= file_tree()
tree
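
The directory scan behind the file browser can be tried on its own, without the widgets. The following is a minimal sketch under stated assumptions: `find_media_files` and `MEDIA_EXTENSIONS` are hypothetical names, and the extension list is the one used in this cell. It relies on the documented `os.walk` behaviour that editing the directory list in place controls which folders are visited.

```python
import os

# Same media types as the file browser in this cell, written as real extensions
MEDIA_EXTENSIONS = ('.mp3', '.m4a', '.flac', '.aac', '.wav', '.mp4', '.mkv', '.ts', '.flv')


def find_media_files(root_dir):
    """Return full paths of media files under root_dir, skipping hidden folders."""
    media_paths = []
    for root, d_names, f_names in os.walk(root_dir):
        # Slice assignment edits the list os.walk is using, so hidden
        # folders are never entered; sorting fixes the visiting order
        d_names[:] = sorted(d for d in d_names if not d.startswith('.'))
        for f_name in sorted(f_names):
            if f_name.lower().endswith(MEDIA_EXTENSIONS):
                media_paths.append(os.path.join(root, f_name))
    return media_paths
```

Calling `find_media_files('/drive/MyDrive')` after mounting would list the same files the tree shows. Matching on `.ts` rather than `ts` keeps files such as `notes` out of the result.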


In [None]:
# @markdown **2.1 参数设置/Parameter setting (approx. 0 min)**
# @markdown <br/><br/><font size="3">**2.1.1 选择上传的文件类型(视频-video/音频-audio）/ Select the type of the uploaded file (video or audio).**</font><br/>
file_type = "audio"  # @param ["audio","video"]

# @markdown <font size="3">**2.1.2 选择模型和语言 / Model size and language.**</font><br/>
model_size = "large-v2"  # @param ["base","small","medium", "large-v1","large-v2"]
language = "ja"  # @param ["ja","zh","en","fr", "de","es","it","pt","ru"]

# @markdown <font size="3">**2.1.3 分割行 / Split line**</font>
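# @markdown <br/><font size="2">Split a single line of text into several lines at spaces. All resulting lines temporarily share the same timestamp and are labelled 'adjust_required' so that the timestamps can be adjusted manually to avoid overlapping subtitles.
# @markdown <br/>将存在空格的单行文本分割为多行（多句）。分割后的若干行均临时采用相同时间戳，且添加了adjust_required标记提示调整时间戳避免叠轴
# @markdown <br/>普通分割（Modest): 当空格后的文本长度超过5个字符，则另起一行 / Start a new line only when the text after a space is longer than 5 characters.
# @markdown <br/>全部分割（Aggressive): 只要遇到空格即另起一行 / Start a new line at every space.</font>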
is_split = "No"  # @param ["No","Yes"]
split_method = "Modest"  # @param ["Modest","Aggressive"]

# @markdown <font size="3">**2.1.4 高级设置 / Advanced settings**（尚不可用/Under development）</font>

compression_ratio_threshold = 2.4 # @param {type:"number"}
no_speech_threshold = 0.6 # @param {type:"number"}
logprob_threshold = -1.0 # @param {type:"number"}
condition_on_previous_text = "True" # @param ["True", "False"]
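
The two split modes described in 2.1.3 can be illustrated with a few lines of Python. This is only one possible reading of the rule as worded in this cell, not the notebook's own implementation; `split_line` and `min_length` are hypothetical names.

```python
def split_line(text, method="Modest", min_length=5):
    """Split one subtitle line at whitespace.

    Aggressive: every whitespace-separated part becomes its own line.
    Modest: a part starts a new line only if it is longer than
    min_length characters; shorter parts stay on the current line.
    """
    # str.split() with no argument also splits on the full-width space (U+3000)
    parts = text.split()
    if not parts:
        return []
    if method == "Aggressive":
        return parts
    lines = [parts[0]]
    for part in parts[1:]:
        if len(part) > min_length:
            lines.append(part)
        else:
            lines[-1] += " " + part
    return lines
```

Each returned line would reuse the timestamp of the original line, which is why the description asks for manual adjustment afterwards.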

In [None]:
#@markdown **2.2 运行WhisperX/Run WhisperX (approx. ? min)**
#@markdown <br/><font size="2">完成后srt文件将自动下载到谷歌云盘中
#@markdown <br/>When finished, the SRT file is saved automatically to Google Drive, in the same folder as the source file.</font><br/>

import os
import subprocess
import torch
import whisperx
import whisper
import time
import pandas as pd
from urllib.parse import quote_plus
from pathlib import Path
import sys
# assert file_name != ""
# assert language != ""

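# drive_dir is the full path of the file selected in step 1.3
assert drive_dir != "", "No file selected. Run cell 1.3 and click 'Select' first."
file_name = drive_dir
# strip only the final extension, so dots elsewhere in the path are preserved
file_basename = os.path.splitext(file_name)[0]
output_dir = os.path.dirname(file_name)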


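if file_type == "video":
  print('提取音频中 Extracting audio from video file...')
  audio_file = f'{file_basename}.wav'
  # pass the arguments as a list so paths with spaces are handled safely;
  # -y overwrites a .wav left over from an earlier run instead of prompting
  subprocess.run(['ffmpeg', '-y', '-i', file_name, '-ar', '16000', '-ac', '1',
                  '-c:a', 'pcm_s16le', audio_file], check=True)
  print('提取完毕 Done.')
else:
  audio_file = file_name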
device = "cuda"
torch.cuda.empty_cache()
print('加载模型 Loading model...')
model = whisper.load_model(model_size, device)

# Original whisper transcribe
tic = time.time()
print('识别中 Transcribe in progress...')
result = model.transcribe(audio_file, language= language)

# Load alignment model and metadata
print('加载调整模型 Load alignment model...')
# model_id = "jonatasgrosman/wav2vec2-large-xlsr-53-japanese"
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)

# Align whisper output
print('调整识别结果 Align whisper output...')
result_aligned = whisperx.align(result["segments"], model_a, metadata, audio_file, device)

toc = time.time()
print('识别完毕 Done')
print(f'Time consumption: {toc-tic:.1f} s')

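# Write the SRT file from the aligned segments, so the subtitles use the refined timestamps
from whisperx.utils import write_srt
srt_path = Path(output_dir) / (Path(file_basename).name + ".srt")
with open(srt_path, "w", encoding="utf-8") as srt:
    write_srt(result_aligned["segments"], file=srt)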

'''#Convert SRT to ASS
from srt2ass import srt2ass
assSub = srt2ass(file_basename + ".srt", sub_style, is_split,split_method)
print('ASS subtitle saved as: ' + assSub)
files.download(assSub)
# os.remove(file_basename + ".srt")'''

print('字幕生成完毕 Subtitle generated!')

torch.cuda.empty_cache()
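
For readers unfamiliar with the SRT format written by this cell, the conversion from segments to subtitle blocks can be sketched in plain Python. This is an illustration, not the code of `whisperx.utils.write_srt`; `format_timestamp` and `segments_to_srt` are hypothetical names, and each segment is assumed to be a dict with `start`, `end` (seconds) and `text`, as in Whisper's transcription result.

```python
def format_timestamp(seconds):
    """Convert a time in seconds to the SRT form HH:MM:SS,mmm."""
    milliseconds = round(seconds * 1000)
    hours, milliseconds = divmod(milliseconds, 3_600_000)
    minutes, milliseconds = divmod(milliseconds, 60_000)
    secs, milliseconds = divmod(milliseconds, 1_000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{milliseconds:03d}"


def segments_to_srt(segments):
    """Render segments as SRT text: index, time range, text, blank line."""
    blocks = []
    for index, segment in enumerate(segments, start=1):
        start = format_timestamp(segment["start"])
        end = format_timestamp(segment["end"])
        blocks.append(f"{index}\n{start} --> {end}\n{segment['text'].strip()}\n")
    return "\n".join(blocks)
```

SRT uses a comma before the milliseconds and numbers the blocks from 1, which is what the two helpers reproduce.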

In [None]:
#@title **【实验功能】Experimental Features**

# @markdown **3.1 AI文本翻译/AI Translation (approx. ? min)**
# @markdown <br/><font size="2">此功能允许用户使用AI翻译服务对识别的字幕文件做逐行翻译，并以相同的格式生成双语对照字幕。
# @markdown <br/>This feature lets users translate the previously transcribed subtitle text line by line using an AI translation service,
# @markdown <br/>then generates bilingual subtitle files in the same subtitle style. Read the documentation to learn more.</font>

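# @markdown <font size="2">我无意在这里扩展太多字幕翻译的相关功能，因为这并不需要用到Google Colab提供的显存资源，在这里运行是一种浪费，反而还容易因为运行时间限制而崩溃。目前已经有很多调用API实现AI翻译的方案，并且这些工具还在迅速发展中。我后续会试验和整理一些放在我的主页上。
# @markdown <br/>希望在本地使用字幕翻译功能的用户，推荐尝试 [subtitle-translator-electron](https://github.com/gnehs/subtitle-translator-electron)
# @markdown <br/>I do not intend to add many subtitle-translation features here: translation does not need the GPU memory that Google Colab provides, so running it here is wasteful and prone to crashing because of runtime limits. Many tools already call APIs for AI translation, and they are developing quickly. I will try some of them and post my notes on my website.
# @markdown <br/>For users who want to translate subtitles locally, [subtitle-translator-electron](https://github.com/gnehs/subtitle-translator-electron) is recommended.</font>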

# @markdown **ChatGPT:**
# @markdown <br/><font size="2">要使用ChatGPT翻译，请填入你自己的OpenAI API Key，目标语言，输出类型，然后执行单元格。
# @markdown <br/>To translate with ChatGPT, enter your own OpenAI API key, the target language and the output format, then execute this cell.</font>
openai_key = '' # @param {type:"string"}
target_language = 'zh-hans'# @param ["zh-hans","english"]
output_format = "ass"  # @param ["ass","srt"]

import sys
import os
import re
import time
import codecs
from pathlib import Path
from tqdm import tqdm
from google.colab import files
from IPython.display import clear_output 

# openai.ChatCompletion (used below) was removed in openai 1.0, so install an older release
!pip install "openai<1.0"
!pip install pysubs2
import openai
import pysubs2

clear_output()

sub_source = "use_transcribed" 
if sub_source == 'upload_new':
  uploaded = files.upload()
  sub_name = list(uploaded.keys())[0]
  sub_basename = Path(sub_name).stem
elif sub_source == 'use_transcribed':
  # cell 2.2 defines file_basename and writes <file_basename>.srt (its ASS conversion is disabled)
  sub_name = file_basename + '.srt'
  sub_basename = file_basename

# original code
class ChatGPTAPI():
    def __init__(self, key, language):
        self.key = key
        # self.keys = itertools.cycle(key.split(","))
        self.language = language
        self.key_len = len(key.split(","))


    # def rotate_key(self):
    #     openai.api_key = next(self.keys)

    def _request(self, text):
        # send one chat completion request and return only the reply text
        openai.api_key = self.key
        completion = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {
                    "role": "user",
                    # english prompt here to save tokens
                    "content": f"Please help me to translate,`{text}` to {self.language}, please return only translated content not include the origin text",
                }
            ],
        )
        return completion["choices"][0]["message"]["content"]

    def translate(self, text):
        try:
            t_text = self._request(text)
        except Exception as e:
            # usually the OpenAI API rate limit (paid plans wait less): sleep, then retry once
            sleep_time = int(60 / self.key_len)
            print(e, f"will sleep {sleep_time} seconds")
            time.sleep(sleep_time)
            t_text = self._request(text)
        return t_text

class SubtitleTranslator():

    def __init__(self, sub_src, model, key, language):
        self.sub_src = sub_src
        self.translate_model = model(key, language)

    def translate_by_line(self):
        sub_trans = pysubs2.load(self.sub_src)
        total_lines = len(sub_trans)
        for line in tqdm(sub_trans,total = total_lines):
            line_trans = self.translate_model.translate(line.text)
            line.text += (r'\N'+ line_trans)
            print(line_trans)

        return sub_trans


clear_output()

translate_model = ChatGPTAPI

assert translate_model is not None, "unsupported model"
OPENAI_API_KEY = openai_key

if not OPENAI_API_KEY:
    raise Exception(
        "OpenAI API key is not provided."
    )

t = SubtitleTranslator(
    sub_src=sub_name,
    model= translate_model,
    key = OPENAI_API_KEY,
    language=target_language)

translation = t.translate_by_line()

# Save and download the bilingual subtitle file in the chosen format

if output_format == 'ass':
  translation.save(sub_basename + '_translation.ass')
  files.download(sub_basename + '_translation.ass')
elif output_format == 'srt':
  translation.save(sub_basename + '_translation.srt')
  files.download(sub_basename + '_translation.srt')

print('双语字幕生成完毕 Translated subtitles generated!')
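
The `translate` method above waits once and retries once when a request fails. For long subtitle files a more general retry with growing delays may be more robust. The following is a minimal sketch under stated assumptions: `call_with_retry` is a hypothetical helper that is not part of the notebook or of the `openai` package, and the delay values are arbitrary defaults.

```python
import time


def call_with_retry(func, *args, retries=3, base_delay=1.0, sleep=time.sleep):
    """Call func(*args); on failure wait base_delay * 2**attempt and try again.

    The last error is re-raised once `retries` extra attempts have failed.
    `sleep` is injectable so the waiting can be replaced in tests.
    """
    for attempt in range(retries + 1):
        try:
            return func(*args)
        except Exception as error:
            if attempt == retries:
                raise
            delay = base_delay * (2 ** attempt)
            print(f"{error}; retrying in {delay:.0f} s")
            sleep(delay)
```

With the classes in this cell, a line could then be translated as `call_with_retry(translator.translate, line.text)`, where `translator` stands for a `ChatGPTAPI` instance.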