<a href="https://colab.research.google.com/github/detektor777/colab_list_audio/blob/main/translator_srt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1.  Go to the website [https://console.mistral.ai/api-keys](https://console.mistral.ai/api-keys) and log in.
2.  If you haven't chosen a pricing plan yet, click "Choose a plan." You can select either a free or a paid plan.
3.  On the page [https://console.mistral.ai/api-keys](https://console.mistral.ai/api-keys), create an API key and paste it into the field below.


In [None]:
#@title ##**API key** { display-mode: "form" }
# IPython shell escape: installs the `mistralai` package that the translation
# cell imports (`from mistralai import Mistral`).
! pip install mistralai

# Mistral API key typed into the Colab form field (empty by default).
api_key = "" #@param {type:"string"}

In [None]:
#@title ##**Upload Subtitles** { display-mode: "form" }
from google.colab import files
from IPython.display import display, clear_output
import re

clear_output(wait=True)

# Ask the user for a file through the Colab upload dialog; stop the cell
# if the dialog was dismissed without choosing anything.
uploaded = files.upload()
if not uploaded:
    print("Error: No file uploaded.")
    raise SystemExit

# Only the first uploaded file is used. Its lower-cased extension decides
# whether plain-text blocks are accepted by the parser below.
subtitle_file = next(iter(uploaded))
subtitle_format = subtitle_file.rsplit('.', 1)[-1].lower()

# Parsed segments: dicts with "start", "end" (seconds) and "text".
combined_segments = []

def parse_srt_time(time_str):
    """Convert an ``HH:MM:SS,mmm`` (or ``HH:MM:SS.mmm``) timestamp to seconds.

    Any parsing problem is reported on stdout and ``0`` is returned instead
    of raising, so a single bad timestamp does not abort the whole upload.
    """
    try:
        # Treat both millisecond separators as ':' so one split yields all fields.
        time_str = time_str.replace(',', ':').replace('.', ':')
        fields = time_str.split(':')
        if len(fields) != 4:
            raise ValueError(f"Invalid time format: {time_str}")
        hh, mm, ss, ms = (int(field) for field in fields)
        return hh * 3600 + mm * 60 + ss + ms / 1000
    except Exception as e:
        print(f"Error parsing time '{time_str}': {e}")
        return 0

# 'utf-8-sig' reads plain UTF-8 as well, but also drops a leading byte-order
# mark so it cannot leak into the first subtitle's text.
with open(subtitle_file, 'r', encoding='utf-8-sig') as f:
    content = f.read().strip()

# Blocks are separated by one or more blank lines; a "blank" line may still
# contain stray whitespace, and runs of several blank lines count as one
# separator, so no empty blocks are produced.
for block in re.split(r'\n\s*\n', content):
    block = block.strip()
    if not block:
        continue
    lines = block.split('\n')
    if len(lines) >= 3 and '-->' in lines[1]:
        # SRT block: index line, timing line, then one or more text lines.
        try:
            # Tolerate missing or extra whitespace around the arrow.
            start_time, end_time = re.split(r'\s*-->\s*', lines[1].strip())
        except ValueError as e:
            print(f"Skipping invalid block: {block}\nError: {e}")
            continue
        combined_segments.append({
            "start": parse_srt_time(start_time),
            "end": parse_srt_time(end_time),
            # Multi-line subtitle text is flattened to a single line.
            "text": ' '.join(lines[2:]).strip()
        })
    elif subtitle_format == 'txt':
        # Plain-text input has no timing; each paragraph becomes a segment.
        combined_segments.append({
            "start": 0,
            "end": 0,
            "text": block
        })

if not combined_segments:
    print("Error: Could not parse subtitles. Ensure the file is in SRT or TXT format.")
    raise SystemExit

print(f"Subtitles loaded successfully from {subtitle_file}. {len(combined_segments)} segments detected.")

In [None]:
#@title ##**Translate Subtitles with Mistral** { display-mode: "form" }
# Colab form parameters for the translation cell. The trailing `#@param`
# comments are read by Colab to render the form widgets.

# Target language code; interpolated into the prompts sent to the model.
translate_to = "uk"                 #@param ["ar", "az", "bg", "bn", "ca", "cs", "da", "de", "el", "en", "eo", "es", "et", "eu", "fa", "fi", "fr", "gl", "he", "hi", "hu", "id", "it", "ja", "ko", "lt", "lv", "ms", "nb", "nl", "pb", "pl", "pt", "ro", "ru", "sk", "sl", "sq", "sr", "sv", "th", "tl", "tr", "uk", "ur", "vi", "zh", "zt"]
# Output file format: "srt" keeps numbering and timestamps, "txt" writes text only.
translate_format = "srt"            #@param ["txt", "srt"]
# Number of subtitle blocks sent to the model per request.
sentence_size = 10                  #@param {type:"integer"}
# NOTE(review): not referenced anywhere else in the visible code.
show_translated_text = True         #@param {type:"boolean"}
# When True, consecutive segments are merged until one ends with '.', '!' or '?'.
preprocess_subtitles = False        #@param {type:"boolean"}
# Pause (seconds) between attempts when a single subtitle is retranslated.
retry_delay = 5                     #@param {type:"integer", description:"Delay between translation retry attempts (seconds)"}
# Pause (seconds) between attempts inside one batch translation.
batch_delay = 2                     #@param {type:"integer", description:"Delay between batch translation attempts (seconds)"}

import os
import datetime
from IPython.display import display, clear_output, HTML
import ipywidgets as widgets
import gc
import time
import re
import base64
from mistralai import Mistral
from mistralai.models.sdkerror import SDKError

clear_output(wait=True)
widgets.Widget.close_all()

if 'combined_segments' not in globals() or not combined_segments:
    print("Error: Subtitles not found. Please run the 'Upload Subtitles' cell first.")
    raise SystemExit

# Use the key the user entered in the form cell. A credential must never be
# embedded in the notebook: it would be published with it and would silently
# override whatever the user typed.
api_key = (globals().get('api_key') or "").strip()
if not api_key:
    print("Error: API key is empty. Enter your Mistral API key in the first cell and run it.")
    raise SystemExit

# Reuse the client between runs, but rebuild it when the key has changed so
# that a corrected key actually takes effect.
if 'client' not in globals() or globals().get('_client_api_key') != api_key:
    print("Initializing Mistral client...")
    client = Mistral(api_key=api_key)
    _client_api_key = api_key
    print("Mistral client initialized.")
else:
    print("Using pre-initialized Mistral client.")

def seconds_to_srt_time(sec):
    """Format a duration in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to whole milliseconds before it is split into
    fields. Truncating the fractional part instead loses a millisecond to
    binary floating-point error (1.001 s would be rendered as ``,000``).
    """
    total_ms = int(round(sec * 1000))
    total_seconds, milliseconds = divmod(total_ms, 1000)
    total_minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}"

def fix_spacing(text):
    """Normalise spacing around sentence punctuation in a subtitle line.

    * inserts a space between ``.``, ``!``, ``?`` or ``…`` and a letter that
      directly follows it;
    * removes whitespace in front of those punctuation marks;
    * collapses every run of whitespace (including newlines) to one space.

    Returns the cleaned, stripped text.
    """
    # [^\W\d_] matches a letter of any alphabet. A fixed Latin/Russian range
    # would miss e.g. the Ukrainian letters і, ї, є, ґ. Digits are excluded
    # on purpose so decimals such as "1.5" are left alone.
    text = re.sub(r'([.!?…])([^\W\d_])', r'\1 \2', text)
    text = re.sub(r'\s+([.!?…])', r'\1', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def preprocess_segments(segments):
    """Merge consecutive segments until each one ends a sentence.

    When the module-level ``preprocess_subtitles`` flag is off, the input
    list is returned untouched. Otherwise a new list of new dicts is built:
    a segment is appended to the previous one (texts joined by a space, end
    time extended) as long as the previous text does not end with
    ``.``, ``!`` or ``?``.
    """
    if not preprocess_subtitles:
        return segments

    sentence_endings = ('.', '!', '?', '...')
    merged = []
    for seg in segments:
        piece = seg["text"].strip()
        if merged and not merged[-1]["text"].endswith(sentence_endings):
            # Previous sentence is still open: glue this piece onto it.
            merged[-1]["end"] = seg["end"]
            merged[-1]["text"] += " " + piece
        else:
            merged.append({
                "start": seg["start"],
                "end": seg["end"],
                "text": piece
            })
    return merged

# Result of this cell: one dict per segment, filled in after translation and
# read by the file writer and the interactive editor further down.
translated_segments = []
# Percentage progress bar (0-100) shown while batches are being translated.
progress_bar = widgets.IntProgress(min=0, max=100, description='Translating:', bar_style='info')
display(progress_bar)

# One entry per model reply (batch sent, reply text, attempt number), appended
# by translate_with_mistral; not read anywhere else in the visible code.
mistral_responses_log = []

def _parse_translated_response(response_text):
    """Split a model reply into translated texts, in reply order and keyed by subtitle number."""
    cleaned = response_text.replace("\r\n", "\n").strip()
    # Drop a Markdown code fence wrapped around the whole reply, if present.
    cleaned = re.sub(r'^```[^\n]*\n|\n?```$', '', cleaned).strip()

    ordered = []
    by_number = {}
    # Split on blank lines; a "blank" line may still contain stray whitespace.
    # Empty fragments are skipped so they cannot shift the alignment.
    for block in re.split(r'\n\s*\n', cleaned):
        block = block.strip()
        if not block:
            continue
        lines = block.split("\n")
        if len(lines) >= 3 and '-->' in lines[1]:
            candidate = "\n".join(lines[2:])
            by_number.setdefault(lines[0].strip(), candidate)
        else:
            # Reply without SRT framing: treat the first line as the text.
            candidate = lines[0].strip()
        ordered.append(candidate)
    return by_number, ordered


def translate_with_mistral(text, target_lang):
    """Translate a batch of SRT blocks and return them as SRT text.

    Args:
        text: SRT blocks (number, timing line, text) separated by blank lines.
        target_lang: language code interpolated into the prompt.

    Returns:
        The same blocks in the same order, joined by blank lines. Numbers and
        timing lines are always taken from the input; only the text is
        replaced. Blocks that could not be translated after all attempts are
        returned unchanged.

    Raises:
        SDKError: for any API error other than HTTP 429, or when the final
            attempt is still rate-limited.

    Every reply is appended to the module-level ``mistral_responses_log``.
    """
    original_blocks = text.strip().split("\n\n")
    srt_blocks = {i: None for i in range(len(original_blocks))}

    # Only well-formed blocks are sent to the model. Malformed ones are
    # returned as-is; resending them could never succeed and would keep the
    # retry loop running for all attempts.
    valid = {}
    for i, orig_block in enumerate(original_blocks):
        orig_lines = orig_block.strip().split("\n")
        if len(orig_lines) >= 3 and '-->' in orig_lines[1]:
            valid[i] = (orig_lines[0], orig_lines[1], "\n".join(orig_lines[2:]))

    max_attempts = 5
    for attempt in range(max_attempts):
        pending = [i for i in valid if srt_blocks[i] is None]
        if not pending:
            break
        to_translate = "\n\n".join(original_blocks[i] for i in pending)

        messages = [
            {"role": "system", "content": f"You are a translator who translates SRT format text to {target_lang} language. Return only the translated text in SRT format (number, timestamps, translated text), without any extra words or comments."},
            {"role": "user", "content": f"Translate the following text to {target_lang} in SRT format:\n\n{to_translate}"}
        ]

        try:
            # NOTE(review): max_tokens=512 may truncate replies for large
            # batches; truncated blocks are simply retried below.
            chat_response = client.chat.complete(
                model="mistral-large-latest",
                messages=messages,
                max_tokens=512,
                temperature=0.5,
                top_p=0.9
            )
        except SDKError as e:
            rate_limited = getattr(e, "status_code", None) == 429 or "Status 429" in str(e)
            # Give up immediately on other errors, and do not sleep when there
            # is no attempt left to wait for.
            if not rate_limited or attempt == max_attempts - 1:
                raise
            wait_time = 2 ** attempt
            print(f"Request limit exceeded (429), waiting {wait_time} seconds before attempt {attempt + 2}...")
            time.sleep(wait_time)
            continue

        translated_text = chat_response.choices[0].message.content.strip()

        mistral_responses_log.append({
            "batch": to_translate,
            "response": translated_text,
            "attempt": attempt + 1
        })

        by_number, ordered = _parse_translated_response(translated_text)

        # Prefer matching by subtitle number: if the model drops or merges a
        # block, positional matching would attach every following translation
        # to the wrong timestamp. Fall back to position only when none of the
        # numbers we sent came back (e.g. the model renumbered from 1).
        use_numbers = any(valid[i][0].strip() in by_number for i in pending)

        for position, i in enumerate(pending):
            orig_number, orig_timecode, orig_text = valid[i]
            if use_numbers:
                translated_line = by_number.get(orig_number.strip())
            elif position < len(ordered):
                translated_line = ordered[position]
            else:
                translated_line = None

            # Reject replies that are empty, echo the source, or are clearly
            # structural residue (separator rules, bare numbers, timing lines).
            if (translated_line and
                translated_line != orig_text and
                not re.search(r'-{10,}', translated_line) and
                not re.match(r'^\d+$', translated_line) and
                "-->" not in translated_line):
                srt_blocks[i] = f"{orig_number}\n{orig_timecode}\n{translated_line}"

        still_missing = any(srt_blocks[i] is None for i in pending)
        if still_missing and attempt < max_attempts - 1:
            time.sleep(batch_delay)

    # Anything still untranslated keeps its original text.
    for i, orig_block in enumerate(original_blocks):
        if srt_blocks[i] is None:
            srt_blocks[i] = orig_block

    return "\n\n".join(srt_blocks.values())

# Translate all segments in batches, retranslate suspicious results one by
# one, then fill `translated_segments`. Any unexpected error stops the cell.
try:
    processed_segments = preprocess_segments(combined_segments)

    # Source text flattened to one line, so a segment can never contain the
    # blank line that separates SRT blocks.
    source_texts = [" ".join(seg["text"].split()) for seg in processed_segments]

    # One SRT block per segment, kept in a list so that index i always refers
    # to processed_segments[i]. Joining everything into one string and
    # splitting it again can lose that alignment.
    srt_blocks = [
        f"{number}\n{seconds_to_srt_time(seg['start'])} --> {seconds_to_srt_time(seg['end'])}\n{source}"
        for number, (seg, source) in enumerate(zip(processed_segments, source_texts), 1)
    ]

    # Start from the original text; entries are overwritten as translations arrive.
    translated_texts = [seg["text"] for seg in processed_segments]
    # Segments without text are not sent to the model at all.
    translatable = [idx for idx, source in enumerate(source_texts) if source]

    batch_size = max(1, sentence_size)  # guard against 0 or negative form input
    for offset in range(0, len(translatable), batch_size):
        batch_indices = translatable[offset:offset + batch_size]
        batch_text = "\n\n".join(srt_blocks[idx] for idx in batch_indices)
        translated_batch_text = translate_with_mistral(batch_text, translate_to)
        translated_blocks = translated_batch_text.split("\n\n")
        for idx, block in zip(batch_indices, translated_blocks):
            lines = block.strip().split("\n")
            if len(lines) >= 3:
                # Keep every text line, not only the first one.
                translated_texts[idx] = " ".join(lines[2:])
            else:
                translated_texts[idx] = block
        progress_bar.value = (offset + len(batch_indices)) * 100 // len(translatable)

    max_retry_attempts = 5
    for i, seg in enumerate(processed_segments):
        orig_text = seg["text"]
        if not orig_text.strip():
            continue  # nothing to translate

        raw_text = translated_texts[i]
        stripped = raw_text.strip()
        # Heuristics for a broken result: separator rule, bare number, empty
        # text, or much shorter than the source without ending a sentence.
        looks_broken = (
            re.search(r'-{10,}', raw_text) or
            re.match(r'^\d+$', stripped) or
            not stripped or
            (len(stripped) < 0.7 * len(orig_text.strip()) and not stripped.endswith(('...', '.', '!', '?')))
        )
        if not looks_broken:
            continue

        original_block = srt_blocks[i]
        print(f"Retranslating subtitle {i + 1}: {raw_text} (original: {orig_text})")
        for attempt in range(max_retry_attempts):
            try:
                retry_translated_text = translate_with_mistral(original_block, translate_to)
            except SDKError as e:
                if "Status 429" not in str(e):
                    raise
                if attempt == max_retry_attempts - 1:
                    # No attempt left: do not wait, fall back right away.
                    translated_texts[i] = orig_text
                    print(f"Could not translate subtitle {i + 1} after 5 attempts due to limit, using original: {orig_text}")
                    break
                wait_time = retry_delay * (attempt + 1)
                print(f"Request limit exceeded (429) while retranslating subtitle {i + 1}, waiting {wait_time} seconds...")
                time.sleep(wait_time)
                continue

            retry_lines = retry_translated_text.strip().split("\n")
            retry_text = " ".join(retry_lines[2:]).strip() if len(retry_lines) >= 3 else ""
            if (retry_text and
                not re.search(r'-{10,}', retry_text) and
                not re.match(r'^\d+$', retry_text) and
                len(retry_text) >= 0.7 * len(orig_text.strip())):
                translated_texts[i] = retry_text
                print(f"Successful translation for subtitle {i + 1} on attempt {attempt + 1}: {retry_text}")
                break
            if attempt < max_retry_attempts - 1:
                time.sleep(retry_delay)
        else:
            # Loop finished without `break`: every attempt was rejected.
            translated_texts[i] = orig_text
            print(f"Could not translate subtitle {i + 1} after 5 attempts, using original: {orig_text}")

    translated_texts = [fix_spacing(text) for text in translated_texts]

    for i, seg in enumerate(processed_segments):
        translated_segments.append({
            "start": seg["start"],
            "end": seg["end"],
            "text": translated_texts[i],
            "original_text": seg["text"]
        })
    progress_bar.value = 100

except Exception as e:
    print(f"Unexpected error during translation: {e}")
    raise SystemExit

else:
    # Reported only on success; a `finally` clause would also print this
    # after a failure.
    print("Translation completed")

# Timestamp embedded in the result file name.
timestamp = f"{datetime.datetime.now():%Y%m%d_%H%M%S}"


def create_translated_subtitle_editor():
    """Display one editable row per translated segment.

    Each row shows the timing, the read-only original text and an editable
    translation. Every edit updates ``translated_segments`` and rewrites the
    file at ``translated_file_path`` in the format chosen by
    ``translate_format``.
    """
    output = widgets.Output()

    def fixed_layout(width):
        # Fixed-width cell, 80px high.
        return {'width': width, 'min_width': width, 'max_width': width, 'height': '80px'}

    def update_translated_srt_file():
        # Rewrite the whole result file from the current segment list.
        with open(translated_file_path, "w", encoding="utf-8") as result_file:
            for number, entry in enumerate(translated_segments, start=1):
                begin = seconds_to_srt_time(entry["start"])
                finish = seconds_to_srt_time(entry["end"])
                body = entry["text"]
                if translate_format == "srt":
                    result_file.write(f"{number}\n{begin} --> {finish}\n{body}\n\n")
                else:
                    result_file.write(f"{body}\n")

    def on_text_changed(change, seg_idx):
        translated_segments[seg_idx]["text"] = change.new
        update_translated_srt_file()

    text_boxes = []
    for row, entry in enumerate(translated_segments):
        begin = seconds_to_srt_time(entry["start"])
        finish = seconds_to_srt_time(entry["end"])

        time_label = widgets.Label(value=f'{begin} --> {finish}', layout=fixed_layout('250px'))
        original_text = widgets.Textarea(
            value=entry["original_text"],
            layout=fixed_layout('300px'),
            disabled=True
        )
        translated_text_box = widgets.Textarea(value=entry["text"], layout=fixed_layout('300px'))
        # Bind the row index as a default argument so each callback keeps its own row.
        translated_text_box.observe(lambda change, idx=row: on_text_changed(change, idx), names='value')

        text_boxes.append(translated_text_box)
        display(widgets.HBox([time_label, original_text, translated_text_box]))

    display(output)

# Write the translated segments to a timestamped file in the chosen format.
# For SRT output the interactive editor is opened afterwards.
if translate_format not in ("txt", "srt"):
    print("Unknown format. Check format settings.")
else:
    translated_file_path = f"translated_transcription_{timestamp}.{translate_format}"
    with open(translated_file_path, "w", encoding="utf-8") as result_file:
        if translate_format == "txt":
            # Plain text: one translated line per segment, no timing.
            for entry in translated_segments:
                result_file.write(f"{entry['text']}\n")
        else:
            # SRT: running number, timing line, text, blank separator line.
            for number, entry in enumerate(translated_segments, start=1):
                begin = seconds_to_srt_time(entry["start"])
                finish = seconds_to_srt_time(entry["end"])
                result_file.write(f"{number}\n{begin} --> {finish}\n{entry['text']}\n\n")
    print(f"Translation completed. Result in file {translated_file_path}")

    if translate_format == "srt":
        print("\nInteractive editor for translated subtitles:")
        create_translated_subtitle_editor()

In [None]:
#@title ##**Download Translated Subtitles**

from google.colab import files
import os
import shutil

# Both earlier cells must have run: the upload cell provides the source file
# name, the translation cell provides the result file.
if 'subtitle_file' not in globals():
    print("Error: source subtitle file not uploaded. First run the 'Upload Subtitles' cell.")
    raise SystemExit

if 'translated_file_path' not in globals():
    print("Error: translated file not found. First run the 'Translate Subtitles with Mistral' cell.")
    raise SystemExit

if not os.path.exists(translated_file_path):
    print(f"Error: file {translated_file_path} not found. Make sure the translation completed successfully.")
    raise SystemExit

# Download name: <source name without extension>.<language>.<format>
base_name, ext = os.path.splitext(subtitle_file)
output_filename = f"{base_name}.{translate_to}.{translate_format}"

# Copy rather than rename: the interactive editor keeps writing to
# translated_file_path, so the working file must stay in place. A fresh copy
# on every run also lets this cell be re-run after further edits.
if translated_file_path != output_filename:
    shutil.copyfile(translated_file_path, output_filename)

files.download(output_filename)
print(f"File {output_filename} downloaded automatically.")