**Xem hướng dẫn tại [đây](https://docs.google.com/document/d/1d3Fp33kRVa50WXEYbYSRs05z2uSFCCyI_pbpnYB0PgU/edit?usp=sharing)**

In [None]:
#@markdown **Kiểm tra GPU** (1 trong 3 GPU này: V100, P100 hoặc T4 thì đã được liên kết GPU từ Colab)
# List the GPUs visible to this runtime, then print full device status.
# Later cells load models with device="cuda" / .cuda(), so an NVIDIA GPU
# must show up here before continuing.
!nvidia-smi -L
!nvidia-smi

In [None]:
#@markdown **Liên kết Google Drive** (bỏ qua nếu up file trực tiếp vào Google Colab)
# Mount Google Drive at /content/drive so files stored there can be given to
# the transcription cell as a plain path (audio_path). Skippable when the
# audio file is uploaded straight into the Colab session instead.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
#@markdown **Cài đặt Whisper**
# Install (quietly) the third-party packages used by the following cells:
# deepl (DeepL API client), srt (building/composing .srt subtitles),
# demucs (vocal separation) and faster-whisper (speech recognition).
%pip install -q deepl srt demucs faster-whisper

In [None]:
#@markdown **Load Whisper**
#@markdown <br>
#@markdown Model càng lớn thì sub càng chuẩn nhưng thời gian lâu hơn
#@markdown <br>
#@markdown Riêng model mặc định *large-v2* (khuyên dùng) thời gian nhanh hơn *large-v1*
# Whisper checkpoint name handed to faster_whisper.WhisperModel below.
# The "# @param [...]" annotation makes this a Colab form field offering the
# listed choices; bigger models are more accurate but slower (per the note above).
model_size = "large-v2"  # @param ["tiny","base","small","medium", "large-v1", "large-v2"]

import torch, torchaudio, os, srt, datetime, json, deepl, urllib.request, faster_whisper
from tqdm import tqdm
from google.colab import files as g_files, drive as g_drive
from demucs.pretrained import get_model as demucs_get_model
from demucs.separate import load_track as demucs_load_track
from demucs.apply import apply_model as demucs_apply_model

# Load both models onto the GPU once so the transcription cell can reuse them
# on every run. The Demucs "htdemucs" model is loaded unconditionally, even
# though it is only used when vocals_extraction is enabled.
DEMUCS_MODEL = demucs_get_model("htdemucs").to("cuda")
WHISPER_MODEL = faster_whisper.WhisperModel(
	model_size,
	device="cuda",
)

# Characters accepted as a line ending. Before DeepL translation, a subtitle
# line that does not already end with one of these gets a period appended.
PUNCT_MATCH = list("。、,.〜！!？?-")
# str.translate() table that deletes straight, curly and full-width double
# quotes plus the Japanese corner brackets the DeepL step wraps lines in.
REMOVE_QUOTES = {ord(quote_char): None for quote_char in '"„“‟”＂「」'}
# Filler words / interjections. A subtitle whose clean_text() form consists
# only of these is dropped from the final SRT.
GARBAGE_LIST = (
	"a aa ah ahh h ha haa hah haha hahaha "
	"hm hmm huh m mh mm mmh mmm o oh"
).split()
# Phrases dropped only when the previous subtitle was itself garbage. They are
# spelled exactly as clean_text() emits them: it first joins the phrase
# ("good bye" -> "goodbye") and then collapses doubled letters ("oo" -> "o"),
# which is why "goodbye" appears here as "godbye".
NEED_CONTEXT_LINES = "feelsgod godbye godnight thankyou".split()

# Ordered (old, new) substitutions clean_text() applies BEFORE lower-casing:
# drop punctuation, turn hyphens into spaces, then squeeze double spaces in
# three passes.
_CLEAN_TEXT_BEFORE_LOWER = (
	(".", ""),
	(",", ""),
	(":", ""),
	(";", ""),
	("!", ""),
	("?", ""),
	("-", " "),
	("  ", " "),
	("  ", " "),
	("  ", " "),
)
# Ordered substitutions applied AFTER lower-casing: merge a few stock phrases
# into single tokens, then collapse repeated a/m/h/o letters. The collapse also
# rewrites the merged tokens ("feelsgood" -> "feelsgod", "goodbye" -> "godbye").
_CLEAN_TEXT_AFTER_LOWER = (
	("that feels", "feels"),
	("it feels", "feels"),
	("feels good", "feelsgood"),
	("good bye", "goodbye"),
	("good night", "goodnight"),
	("thank you", "thankyou"),
	("aaaaaa", "a"),
	("aaaa", "a"),
	("aa", "a"),
	("aa", "a"),
	("mmmmmm", "m"),
	("mmmm", "m"),
	("mm", "m"),
	("mm", "m"),
	("hhhhhh", "h"),
	("hhhh", "h"),
	("hh", "h"),
	("hh", "h"),
	("oooooo", "o"),
	("oooo", "o"),
	("oo", "o"),
	("oo", "o"),
)


def clean_text(text):
	"""Normalise a subtitle line for the garbage-line filter.

	Strips . , : ; ! ?, maps hyphens to spaces, squeezes double spaces,
	lower-cases, joins a few stock phrases into single tokens (e.g.
	"thank you" -> "thankyou") and collapses repeated a/m/h/o letters
	("hmmm" -> "hm"). The result is only compared against GARBAGE_LIST and
	NEED_CONTEXT_LINES; the subtitle text itself is left untouched.
	"""
	for old, new in _CLEAN_TEXT_BEFORE_LOWER:
		text = text.replace(old, new)
	text = text.lower()
	for old, new in _CLEAN_TEXT_AFTER_LOWER:
		text = text.replace(old, new)
	return text

# Maps lower-case language names, including aliases such as "castilian",
# "flemish" and "valencian", to Whisper language codes. The transcription cell
# lower-cases the user's `language` setting, requires it to be a key here, and
# passes the mapped code to WHISPER_MODEL.transcribe(language=...).
TO_LANGUAGE_CODE = { # from https://github.com/openai/whisper/blob/main/whisper/tokenizer.py
	"afrikaans": "af",
	"albanian": "sq",
	"amharic": "am",
	"arabic": "ar",
	"armenian": "hy",
	"assamese": "as",
	"azerbaijani": "az",
	"bashkir": "ba",
	"basque": "eu",
	"belarusian": "be",
	"bengali": "bn",
	"bosnian": "bs",
	"breton": "br",
	"bulgarian": "bg",
	"burmese": "my",
	"castilian": "es",
	"catalan": "ca",
	"chinese": "zh",
	"croatian": "hr",
	"czech": "cs",
	"danish": "da",
	"dutch": "nl",
	"english": "en",
	"estonian": "et",
	"faroese": "fo",
	"finnish": "fi",
	"flemish": "nl",
	"french": "fr",
	"galician": "gl",
	"georgian": "ka",
	"german": "de",
	"greek": "el",
	"gujarati": "gu",
	"haitian creole": "ht",
	"haitian": "ht",
	"hausa": "ha",
	"hawaiian": "haw",
	"hebrew": "he",
	"hindi": "hi",
	"hungarian": "hu",
	"icelandic": "is",
	"indonesian": "id",
	"italian": "it",
	"japanese": "ja",
	"javanese": "jw",
	"kannada": "kn",
	"kazakh": "kk",
	"khmer": "km",
	"korean": "ko",
	"lao": "lo",
	"latin": "la",
	"latvian": "lv",
	"letzeburgesch": "lb",
	"lingala": "ln",
	"lithuanian": "lt",
	"luxembourgish": "lb",
	"macedonian": "mk",
	"malagasy": "mg",
	"malay": "ms",
	"malayalam": "ml",
	"maltese": "mt",
	"maori": "mi",
	"marathi": "mr",
	"moldavian": "ro",
	"moldovan": "ro",
	"mongolian": "mn",
	"myanmar": "my",
	"nepali": "ne",
	"norwegian": "no",
	"nynorsk": "nn",
	"occitan": "oc",
	"panjabi": "pa",
	"pashto": "ps",
	"persian": "fa",
	"polish": "pl",
	"portuguese": "pt",
	"punjabi": "pa",
	"pushto": "ps",
	"romanian": "ro",
	"russian": "ru",
	"sanskrit": "sa",
	"serbian": "sr",
	"shona": "sn",
	"sindhi": "sd",
	"sinhala": "si",
	"sinhalese": "si",
	"slovak": "sk",
	"slovenian": "sl",
	"somali": "so",
	"spanish": "es",
	"sundanese": "su",
	"swahili": "sw",
	"swedish": "sv",
	"tagalog": "tl",
	"tajik": "tg",
	"tamil": "ta",
	"tatar": "tt",
	"telugu": "te",
	"thai": "th",
	"tibetan": "bo",
	"turkish": "tr",
	"turkmen": "tk",
	"ukrainian": "uk",
	"urdu": "ur",
	"uzbek": "uz",
	"valencian": "ca",
	"vietnamese": "vi",
	"welsh": "cy",
	"yiddish": "yi",
	"yoruba": "yo",
}

In [None]:
#@markdown **Chạy Whisper**<br>
#@markdown *Lưu ý: Sao chép đường dẫn của file và dán vào "audio_path" trước khi Chạy Whisper*
#@markdown <br><br>
#@markdown **CÀI ĐẶT TỐI THIỂU:**
# --- Basic settings (Colab form fields) ---
# audio_path: path of the input file, or an http(s) URL that is downloaded first.
# language: language name, lower-cased and looked up in TO_LANGUAGE_CODE.
# translation_mode:
#   "transcription only"                     -> task="transcribe", no DeepL
#   "transcription + translation"            -> Whisper task="translate"
#   "transcription + translation with DeepL" -> task="transcribe", then DeepL
audio_path = ""  # @param {type:"string"}
language = "japanese"  # @param {type:"string"}
translation_mode = "transcription + translation"  # @param ["transcription only", "transcription + translation", "transcription + translation with DeepL"]
#@markdown Nếu không dùng DeepL thì bỏ qua 2 mục sau:
# DeepL API key and target-language code; only used in the DeepL mode.
deepl_authkey = ""  # @param {type:"string"}
deepl_target_lang = "EN-US"  # @param {type:"string"}
#@markdown <br><br/>
#@markdown ***CÀI ĐẶT NÂNG CAO*** <br>
#@markdown <br>
#@markdown Cài đặt SileroVAD:
# --- Advanced settings ---
# vad_threshold and chunk_duration are passed to faster-whisper's VAD filter
# as `threshold` and `max_speech_duration_s` (seconds).
vad_threshold = 0.4  # @param {type:"number"}
chunk_duration = 3.0  # @param {type:"number"}
#@markdown Bật "vocals_extraction" cho file có thời lượng >1h sẽ tiêu hao nhiều VRAM và dễ gây sập Colab
# vocals_extraction: separate the vocals with Demucs first and transcribe only
# that track (the note above warns it needs a lot of VRAM on files over ~1h).
vocals_extraction = False  # @param {type:"boolean"}
#@markdown 2 mục dưới đây giữ nguyên trừ khi cần tinh chỉnh
# Passed straight through to WHISPER_MODEL.transcribe(); a blank
# initial_prompt is converted to None further down.
condition_on_previous_text = True  # @param {type:"boolean"}
initial_prompt = ""  # @param {type:"string"}

# Validate the form inputs up front. Explicit ValueErrors (instead of asserts)
# stay active under `python -O` and name the offending setting.
if not 0.01 <= vad_threshold <= 1:
	# The VAD threshold is a speech-probability cut-off, so values above 1 would
	# silently classify everything as non-speech.
	raise ValueError("vad_threshold must be between 0.01 and 1")
if chunk_duration < 0.1:
	raise ValueError("chunk_duration must be at least 0.1 seconds")
if audio_path == "":
	raise ValueError("audio_path is empty")

# Downstream code needs a lower-case key of TO_LANGUAGE_CODE (and compares it
# against "japanese" for DeepL). Tolerate stray whitespace / capitals, and also
# accept a bare Whisper code such as "ja" by mapping it back to a name.
language = language.strip().lower()
if not language:
	raise ValueError("language is empty")
if language not in TO_LANGUAGE_CODE:
	language = next(
		(name for name, code in TO_LANGUAGE_CODE.items() if code == language),
		language,
	)
if language not in TO_LANGUAGE_CODE:
	raise ValueError(f"Unsupported language: {language!r}")

# translation_mode -> (Whisper task, whether to translate with DeepL afterwards)
_MODE_SETTINGS = {
	"transcription + translation": ("translate", False),
	"transcription + translation with DeepL": ("transcribe", True),
	"transcription only": ("transcribe", False),
}
if translation_mode not in _MODE_SETTINGS:
	raise ValueError("Invalid translation mode")
task, run_deepl = _MODE_SETTINGS[translation_mode]

# A blank prompt means "no prompt": pass None rather than an empty string.
if initial_prompt.strip() == "":
	initial_prompt = None

# Resolve audio_path to a local file. URLs are downloaded to "input_file" in
# the working directory; a path that does not exist falls back to the
# `uploaded_file` variable.
# NOTE(review): `uploaded_file` is not defined in any cell visible here —
# presumably set by an upload cell; confirm.
if "http://" in audio_path or "https://" in audio_path:
	print("Downloading audio …")
	urllib.request.urlretrieve(audio_path, "input_file")
	audio_path = "input_file"
elif not os.path.exists(audio_path):
	try:
		audio_path = uploaded_file
	except NameError:
		# `from None` hides the internal NameError so the user only sees the hint.
		raise ValueError("Không tìm thấy file. Bạn đã upload file chưa?") from None
	if not os.path.exists(audio_path):
		raise ValueError("Không tìm thấy file. audio_path của bạn đã đúng chưa?")

# Output names are the input path minus its extension, plus a suffix:
# "<base>.srt" for the final subtitles and "<base>_Untranslated.srt" for the
# copy written before DeepL translation.
audiofilebasename = os.path.splitext(audio_path)[0]
out_path = f"{audiofilebasename}.srt"
out_path_pre = f"{audiofilebasename}_Untranslated.srt"

if vocals_extraction:
	print("Separating vocals …")
	raw_audio = demucs_load_track(audio_path, DEMUCS_MODEL.audio_channels, DEMUCS_MODEL.samplerate)
	# The input tensor stays on the CPU on purpose (a long track may not fit in
	# VRAM); apply_model is only told to run the model on "cuda".
	#
	# apply_model needs a (batch, channels, time) tensor. Fix the rank step by
	# step so every input layout ends up there, including a mono (1, time)
	# tensor, which must gain a batch dimension as well as extra channels.
	if raw_audio.dim() == 1:
		raw_audio = raw_audio[None]  # (time,) -> (1, time)
	if raw_audio.dim() == 2:
		raw_audio = raw_audio[None]  # (channels, time) -> (1, channels, time)
	if raw_audio.shape[-2] == 1:
		# Upmix mono to the channel count the model was built for.
		raw_audio = raw_audio.repeat_interleave(DEMUCS_MODEL.audio_channels, -2)
	demucs_extract = demucs_apply_model(DEMUCS_MODEL, raw_audio, device="cuda", split=True, overlap=.25)
	torch.cuda.empty_cache()
	# Take the "vocals" stem of the single batch item and average its channels
	# into a mono (1, time) waveform for torchaudio.save.
	demucs_res = demucs_extract[0, DEMUCS_MODEL.sources.index("vocals")].mean(0)[None]
	audio_path = audiofilebasename + ".vocals.wav"
	torchaudio.save(audio_path, demucs_res, DEMUCS_MODEL.samplerate)

print("Đang chạy Whisper … VUI LÒNG ĐỢI MỘT LÁT")
# Silero VAD settings from the form: the speech-probability threshold and the
# longest speech chunk (in seconds) passed to Whisper in one piece.
vad_options = {"threshold": vad_threshold, "max_speech_duration_s": chunk_duration}
# faster-whisper returns `segments` as a lazy generator: the actual decoding
# happens while it is iterated in the next step, not during this call.
segments, info = WHISPER_MODEL.transcribe(
	audio_path,
	language=TO_LANGUAGE_CODE[language],
	task=task,
	initial_prompt=initial_prompt,
	condition_on_previous_text=condition_on_previous_text,
	vad_filter=True,
	vad_parameters=vad_options,
)

# Drain the segment generator, turning each Whisper segment into a 1-based
# SRT entry and advancing the progress bar by the audio time covered since
# the previous segment ended.
subs = []
segment_info = []  # raw segments, kept for the debug JSON dump
timestamps = 0.0  # end time of the last processed segment (progress bar)

with tqdm(total=info.duration, unit=" audio seconds") as pbar:
	for sub_index, seg in enumerate(segments, start=1):
		segment_info.append(seg)
		start_time = datetime.timedelta(seconds=seg.start)
		end_time = datetime.timedelta(seconds=seg.end)
		subs.append(
			srt.Subtitle(index=sub_index, start=start_time, end=end_time, content=seg.text.lstrip())
		)
		elapsed = seg.end - timestamps
		timestamps = seg.end
		pbar.update(elapsed)

# Dump the raw Whisper segments for debugging. Older faster-whisper releases
# define Segment as a NamedTuple (json writes it as a key-less array); newer
# ones use a dataclass, which json.dump rejects with TypeError. Convert each
# segment to a plain dict first so both work and the output keeps field names.
import dataclasses


def _segment_to_dict(segment):
	"""Return *segment* as a JSON-serialisable dict (dataclasses recursively)."""
	# Check for a dataclass first: newer dataclass Segments also expose a
	# deprecated _asdict() that emits a warning.
	if dataclasses.is_dataclass(segment) and not isinstance(segment, type):
		return dataclasses.asdict(segment)
	if hasattr(segment, "_asdict"):
		return segment._asdict()
	return segment


with open("segment_info.json", mode="w", encoding="utf8") as f:
	# ensure_ascii=False keeps non-Latin text (e.g. Japanese) readable; the
	# file is already opened as UTF-8.
	json.dump([_segment_to_dict(s) for s in segment_info], f, indent=4, ensure_ascii=False)

# DeepL translation
translate_error = False
if run_deepl:
	print("Translating …")
	# Save the untranslated subtitles first; this file is what gets downloaded
	# if any part of the DeepL step below fails.
	with open(out_path_pre, "w", encoding="utf8") as f:
		f.write(srt.compose(subs))
	print("(Untranslated subs saved to", out_path_pre, ")")

	# End every line with punctuation and wrap it in quotes before sending it
	# to DeepL (presumably so each line is translated as its own sentence); the
	# quotes are removed again via REMOVE_QUOTES after translation.
	# str.endswith(tuple) replaces content[-1], which raised IndexError on an
	# empty Whisper segment and aborted the whole cell.
	if language == "japanese":
		sentence_end, open_quote, close_quote = "。", "「", "」"
	else:
		sentence_end, open_quote, close_quote = ".", '"', '"'
	punct_endings = tuple(PUNCT_MATCH)
	lines = []
	for sub in subs:
		if not sub.content.endswith(punct_endings):
			sub.content += sentence_end
		sub.content = open_quote + sub.content + close_quote
		lines.append(sub.content)

	# Send at most 30 new lines per request to avoid HTTP 413 (payload too
	# large). Each request after the first is prefixed with the last 3 lines of
	# the previous one for context; the translations of those are discarded.
	deepl_batch_size = 30
	deepl_context_size = 3
	grouped_lines = []
	english_lines = []
	for i, l in enumerate(lines):
		if i % deepl_batch_size == 0:
			grouped_lines.append([])
			if i != 0:
				grouped_lines[-1].extend(grouped_lines[-2][-deepl_context_size:])
		grouped_lines[-1].append(l.strip())

	try:
		translator = deepl.Translator(deepl_authkey)
		# Japanese is pinned as the source language; for anything else DeepL
		# auto-detects (source_lang=None is the client's default).
		source_lang = "JA" if language == "japanese" else None
		for i, n in enumerate(tqdm(grouped_lines)):
			x = ["\n".join(n).strip()]
			result = translator.translate_text(x, source_lang=source_lang, target_lang=deepl_target_lang)
			english_tl = result[0].text.strip().splitlines()
			# Translations are mapped back onto subtitles line-by-line, so the
			# counts must match. Raised explicitly (an assert vanishes under -O);
			# the except below turns it into the untranslated fallback.
			if len(english_tl) != len(n):
				raise ValueError(f"Invalid translation line count ({len(english_tl)} vs {len(n)})")
			if i != 0:
				english_tl = english_tl[deepl_context_size:]
			for e in english_tl:
				english_lines.append(
					e.strip().translate(REMOVE_QUOTES).replace("’", "'")
				)
		for sub, e in zip(subs, english_lines):
			sub.content = e
	except Exception as e:
		# Deliberate best-effort: any DeepL failure falls back to the
		# untranslated SRT saved above.
		print("DeepL translation error:", e)
		print("(downloading untranslated version instead)")
		translate_error = True

# Write SRT file
if translate_error:
	# DeepL failed: hand back the untranslated file saved before translating.
	g_files.download(out_path_pre)
else:
	# Drop "garbage" subtitles. A line is garbage when every word of its
	# clean_text() form is in GARBAGE_LIST, or is in NEED_CONTEXT_LINES while
	# the previous line was garbage too. A line with no words counts as garbage.
	clean_subs = []
	last_line_garbage = False
	for sub in subs:
		words = [w.strip() for w in clean_text(sub.content).split(" ")]
		is_garbage = all(
			w in GARBAGE_LIST or (last_line_garbage and w in NEED_CONTEXT_LINES)
			for w in words
			if w
		)
		if not is_garbage:
			clean_subs.append(sub)
		last_line_garbage = is_garbage
	with open(out_path, mode="w", encoding="utf8") as f:
		f.write(srt.compose(clean_subs))
	print("\nHoàn tất! File SRT đã được lưu ở", out_path)
	print("Đang tải xuống file SRT về máy tính:")
	g_files.download(out_path)