<a href="https://colab.research.google.com/github/feynlee/whisper2subtitles/blob/main/Whisper2subtitles.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Requirements


The commands below install the Python packages needed to download YouTube videos, transcribe them with Whisper models, and translate the transcriptions.

In [22]:
! pip install -q git+https://github.com/openai/whisper.git
! pip install -q pytube transformers sentencepiece tqdm

In [None]:
#@markdown ### Check Type of GPU and VRAM available.
# Queries the GPU name, total memory and free memory, printed as CSV without a header row.
# The output is only displayed; it is not captured in a variable.
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader

# Transcribe the Video

In [None]:
#@markdown If `video_path` is a YouTube link, the video will be downloaded at the `save_path`.
# A value starting with 'http' is treated as a link by transcribe(); anything else is used as a local file path.
video_path = '' #@param {type: 'string'}
#@markdown Choose a Whisper model. `base` is the fastest and uses the least amount of memory.
model_type = 'small'  #@param ["base", "small", "medium", "large"]
#@markdown Video Language Code
# NOTE: transcribe() reads `video_lang` as a global.
video_lang = 'en'   #@param {type: 'string'}
#@markdown Where to save the video and subtitle.
save_path = 'data'  #@param {type: 'string'}
#@markdown What to name the saved video and subtitle.
# Used without extension: the code appends '.mp4', '-sub.<format>', etc.
filename = '' #@param {type: 'string'}
#@markdown Which format to save the subtitle in.
# NOTE: this name shadows Python's built-in `format()`; convert_to_subtitle() reads it as a global.
format = 'srt' #@param ["srt", "txt"]


import os
from tqdm import tqdm
import whisper
import numpy as np
from pathlib import Path
from pytube import YouTube


def get_video_from_youtube_url(url, save_path=None, filename=None):
    """Download a YouTube video as an mp4 file and return the path it was saved to.

    Args:
        url: URL of the YouTube video.
        save_path: Directory to download into; it is created if missing.
            Defaults to the current working directory.
        filename: Name of the saved file without the `.mp4` extension.
            Defaults to the file name pytube derives from the video title.

    Returns:
        The path of the downloaded video as a string.

    Raises:
        ValueError: If the video offers no mp4 stream.
    """
    yt = YouTube(url)
    # Path(None) raises TypeError, so map the default to the current directory.
    save_dir = Path(save_path) if save_path is not None else Path('.')
    save_dir.mkdir(exist_ok=True, parents=True)

    mp4_streams = yt.streams.filter(file_extension='mp4')
    # Progressive streams carry audio and video in one file; an adaptive mp4
    # stream can be video-only, which would leave no audio to transcribe.
    stream = mp4_streams.filter(progressive=True).first() or mp4_streams.first()
    if stream is None:
        raise ValueError(f"No mp4 stream is available for {url}")

    target_name = f'{filename}.mp4' if filename is not None else stream.default_filename
    stream.download(output_path=str(save_dir), filename=target_name)
    return str(save_dir / target_name)


def transcribe(video, save_path, filename, model_type='small', language=None, model=None):
    """Transcribe a local video file or a YouTube video with Whisper.

    Args:
        video: Path of a local video file, or a URL starting with `http`, in
            which case the video is first downloaded via
            `get_video_from_youtube_url`.
        save_path: Directory used for the downloaded video.
        filename: Name (without extension) used for the downloaded video.
        model_type: Name of the Whisper model to load when `model` is not given.
        language: Language code of the speech. When omitted, the notebook-level
            `video_lang` setting is used.
        model: An already loaded Whisper model to reuse instead of loading
            `model_type` again.

    Returns:
        A `(result, video)` tuple: the value returned by `model.transcribe`
        and the path of the video file that was transcribed.
    """
    if video.startswith('http'):
        print("Downloading Youtube Video\n")
        video = get_video_from_youtube_url(video, save_path=save_path, filename=filename)

    if language is None:
        # Fall back to the form value so existing calls behave as before.
        language = video_lang
    if model is None:
        model = whisper.load_model(model_type)

    # Every DecodingOptions field (including its defaults) is forwarded to
    # model.transcribe as a keyword argument.
    options = whisper.DecodingOptions(fp16=False, language=language)
    result = model.transcribe(video, **options.__dict__, verbose=False)
    return result, video


def _srt_timestamp(seconds):
    """Format a time offset in seconds as an SRT timestamp `HH:MM:SS,mmm`."""
    # Work in whole milliseconds so that rounding carries into the larger
    # units: 59.9996 s becomes 00:01:00,000 rather than 00:00:60,000.
    total_ms = max(0, int(round(seconds * 1000)))
    hours, rest = divmod(total_ms, 3_600_000)
    minutes, rest = divmod(rest, 60_000)
    secs, millis = divmod(rest, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def segments_to_srt(segs):
    """Render segments as the text of an SRT subtitle file.

    Args:
        segs: Iterable of mappings with `start` and `end` (offsets in seconds)
            and `text` entries.

    Returns:
        A string with one numbered entry per segment: the 1-based index, the
        `start --> end` timestamp line, and the stripped text followed by a
        blank line. An empty input yields an empty string.
    """
    entries = []
    for index, seg in enumerate(segs, start=1):
        entries.append(str(index))
        entries.append(f"{_srt_timestamp(seg['start'])} --> {_srt_timestamp(seg['end'])}")
        # The extra newline produces the blank line that separates entries.
        entries.append(seg['text'].strip() + "\n")
    return "\n".join(entries)


def convert_to_subtitle(segs, fmt=None):
    """Render segments as subtitle text in the requested format.

    Args:
        segs: Segments to render; they are handed to `segments_to_srt` or
            `transcribed_text` unchanged.
        fmt: `'srt'` or `'txt'`. When omitted, the notebook-level `format`
            setting is used.

    Returns:
        The subtitle text as a string.

    Raises:
        ValueError: If the format is neither `'srt'` nor `'txt'`.
    """
    if fmt is None:
        # Fall back to the form value so existing one-argument calls still work.
        fmt = format
    if fmt == 'srt':
        return segments_to_srt(segs)
    if fmt == 'txt':
        return transcribed_text(segs)
    raise ValueError(f"format {fmt} is not supported!")
    

def save_subtitle(sub, save_path, filename, format='srt'):
    """Write subtitle text to `<save_path>/<filename>.<format>`.

    Args:
        sub: The subtitle text to write.
        save_path: Target directory; it is created if it does not exist.
        filename: File name without extension.
        format: File extension to use.

    Returns:
        The path of the written file.
    """
    # Create the directory on demand so that saving does not fail when it is
    # missing (os.makedirs rejects an empty path, hence the guard).
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    sub_file = os.path.join(save_path, f'{filename}.{format}')
    # Explicit UTF-8 so non-ASCII subtitles do not depend on the platform default.
    with open(sub_file, 'w', encoding='utf-8') as f:
        f.write(sub)
    return sub_file


def transcribed_text(segs):
    """Return the `text` of every segment, joined with one segment per line."""
    return '\n'.join(seg['text'] for seg in segs)


# transcribe() loads the Whisper model itself, so no separate copy is loaded
# here: a second, unused copy would only take up additional memory.
print("loading model")
result, video = transcribe(video_path, save_path, filename, model_type=model_type)
sub = convert_to_subtitle(result['segments'])
sub_transcribed = save_subtitle(sub, save_path, filename+'-sub', format=format)
print(f"subtitle saved: {sub_transcribed}")

# Translate (Optional)

### Pick ONLY ONE of the following methods. Both store their result in the same variable, so whichever method you run last overwrites the other's translation.

In [None]:
#@markdown ## Method 1: Use Facebook's M2M100 model

# Language code of the translation target; it is passed to translate() as `tr_lang`.
translation_language_code = 'zh'        #@param {type: 'string'}


from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


def batch_text(result, gs=32):
    """Group the segment texts of `result` into consecutive batches of size `gs`.

    The last batch is shorter when the number of segments is not a multiple
    of `gs`.
    """
    segments = result['segments']
    n_full, leftover = divmod(len(segments), gs)
    # Slice boundaries of every full batch, plus a trailing partial one if needed.
    bounds = [(k * gs, (k + 1) * gs) for k in range(n_full)]
    if leftover:
        bounds.append((n_full * gs, len(segments)))
    return [[seg['text'] for seg in segments[lo:hi]] for lo, hi in bounds]


def _translate(text, tokenizer, model_tr, src_lang='en', tr_lang='zh'):
    tokenizer.src_lang = src_lang
    encoded_en = tokenizer(text, return_tensors="pt", padding=True)
    generated_tokens = model_tr.generate(**encoded_en, forced_bos_token_id=tokenizer.get_lang_id(tr_lang))
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)


def batch_translate(texts, tokenizer, model_tr, src_lang='en', tr_lang='zh'):
    """Translate every batch in `texts` and return all translations as one flat list."""
    results = []
    for batch in tqdm(texts):
        results.extend(
            _translate(batch, tokenizer, model_tr, src_lang=src_lang, tr_lang=tr_lang)
        )
    return results


def translate(result, tr_lang = 'zh', ckpt='facebook/m2m100_418M', gs=32):
    """Translate the segment texts of a transcription result with M2M100.

    Args:
        result: Mapping with a `segments` list (each segment having a `text`
            entry) and a `language` entry used as the source language.
        tr_lang: Language code to translate into.
        ckpt: Name of the pretrained M2M100 checkpoint to load.
        gs: Number of segments translated per batch.

    Returns:
        A flat list with one translated string per segment, in segment order.
    """
    if not result['segments']:
        # Nothing to translate: skip loading the model altogether.
        return []

    model_tr = M2M100ForConditionalGeneration.from_pretrained(ckpt)
    tokenizer = M2M100Tokenizer.from_pretrained(ckpt)

    batches = batch_text(result, gs=gs)
    return batch_translate(batches, tokenizer, model_tr,
                           src_lang=result['language'], tr_lang=tr_lang)


# `texts_tr` is the global list that the subtitle-generation cell reads.
texts_tr = translate(result, tr_lang=translation_language_code)

In [36]:
#@markdown ## Method 2: Provide Your Own Translation
#@markdown 1. Execute this cell
#@markdown 2. Copy the transcribed text from the box on the left
#@markdown 3. Translate it using another tool (e.g. Google Translate, DeepL, or Bing Translator).
#@markdown 4. Copy the translated text into the text box on the right.

#@markdown **Note:** The translated text SHOULD match the original text line by line, otherwise the generated subtitle won't be matched up.

#@markdown **Note:** Executing this cell will overwrite the translation generated in Method 1.
import ipywidgets as widgets
from ipywidgets import TwoByTwoLayout, Layout, HBox, VBox
from IPython.display import display


# Content of the left-hand box: one transcribed segment per line.
text = transcribed_text(result['segments'])

# Left box: the transcription to copy from.
txt_ori = widgets.Textarea(
    value=text,
    placeholder='Transcribed Text',
    description='Original:',
    disabled=False,
    layout=widgets.Layout(width='90%', height='500px'),
)

# Right box: where the user pastes the translation.
txt_tr = widgets.Textarea(
    value='',
    placeholder='Put Your Translated Text Here',
    description='Translation:',
    disabled=False,
    layout=widgets.Layout(width='90%', height='500px'),
)

submit_button = widgets.Button(
    description='Submit translation',
    button_style='success',
    layout=widgets.Layout(float='right'),
)

# Container for the submit button.
box_layout = widgets.Layout(
    display='flex',
    flex_flow='column',
    align_items='flex-end',
    width='90%',
)
box = widgets.HBox(children=[submit_button], layout=box_layout)


def on_button_clicked(b):
    """Store the lines of the translation box, stripped, in the global `texts_tr`."""
    global texts_tr
    texts_tr = list(map(str.strip, txt_tr.value.split('\n')))



# Register on_button_clicked as the click handler of the submit button.
submit_button.on_click(on_button_clicked)
# Bare last expression of the cell: the layout object becomes the cell output.
TwoByTwoLayout(bottom_left=txt_ori,
               bottom_right=VBox(children=[txt_tr, box])
               )

TwoByTwoLayout(children=(Textarea(value=" The End\n I've invited all the participants from this year's trick o…

In [None]:
#@markdown ## Generate translated subtitles
#@markdown Position of the translation in the generated subtitles: whether to put the translation above the original. If "translation only" is chosen, only the translation will be kept in the subtitle.
translation_position = "top" #@param ["top", "bottom", "translation only"]

# The original text is kept unless only the translation was requested.
keep_both = translation_position != "translation only"

def combine_translated(segs, text_translated, keep_both=True, tr_pos='top'):
    """Build segments whose `text` contains the translation.

    The input segments are not modified: every returned segment is a shallow
    copy with a new `text` entry, so calling this repeatedly on the same
    segments always starts from the original text.

    Args:
        segs: Segments (mappings with a `text` entry).
        text_translated: Translated strings, matched to `segs` by position.
        keep_both: If false, the new text is the translation only.
        tr_pos: `'top'` puts the translation before the original text; any
            other value puts it after.

    Returns:
        A list of new segment dicts. When the two inputs differ in length,
        the extra items of the longer one are ignored and a warning is issued.
    """
    import warnings

    segs = list(segs)
    text_translated = list(text_translated)
    if len(segs) != len(text_translated):
        warnings.warn(
            f"{len(segs)} segments but {len(text_translated)} translated lines; "
            "the surplus items are ignored."
        )

    # Literal backslash followed by N, placed between the two texts.
    # NOTE(review): presumably the ASS-style hard line break understood by the
    # subtitle renderer — confirm.
    line_break = '\\N'

    combined = []
    for seg, tr in zip(segs, text_translated):
        tr = tr.strip()
        if not keep_both:
            new_text = f"{tr}\n"
        else:
            original = seg['text'].strip()
            parts = (tr, original) if tr_pos == 'top' else (original, tr)
            new_text = line_break.join(parts) + "\n"
        new_seg = dict(seg)
        new_seg['text'] = new_text
        combined.append(new_seg)
    return combined


# Merge the translation into the segments, render them, and save the file.
segs = combine_translated(
    result['segments'],
    texts_tr,
    keep_both=keep_both,
    tr_pos=translation_position,
)
sub_tr = convert_to_subtitle(segs)
sub_translated = save_subtitle(
    sub_tr,
    save_path,
    filename + '-translated_sub',
    format=format,
)
# Show the path of the saved file as the cell output.
sub_translated

# Burn Subtitles Into the Video (Optional)

In [None]:
!apt install -q ffmpeg
!pip install -q ffpb

In [None]:
subtitled = os.path.join(save_path, f'{filename}-subtitled.mp4')
!ffpb -i $video -vf subtitles=$sub_transcribed -y $subtitled

# For the translated subtitles in another language, you may need to get corresponding fonts in order for it to be displayed correctly in video.
# !ffpb -i $video -vf subtitles=$sub_translated -y $subtitled