In [None]:
import cv2
import torch
import numpy as np
from transformers import MarianMTModel, MarianTokenizer
import pytesseract
from tqdm import tqdm
import concurrent.futures
import os
import json
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sacrebleu import corpus_bleu

In [None]:

class SubtitleDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of samples.

    The wrapped collection is stored as-is on ``self.data``; length and
    item lookup are delegated to it without any transformation.
    """

    def __init__(self, data):
        """Keep a reference to ``data`` (no copy is made)."""
        self.data = data

    def __len__(self):
        """Return the number of samples in the wrapped collection."""
        sample_count = len(self.data)
        return sample_count

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample

def collate_fn(batch):
    """Split a batch of ``(source, target)`` pairs into two parallel lists.

    Returns:
        A ``(sources, targets)`` tuple where each element is a list holding
        the respective component of every pair, in batch order.
    """
    # Transpose the batch: rows of pairs -> one column per component.
    sources, targets = map(list, zip(*batch))
    return sources, targets

class VideoSubtitleTranslator:
    def __init__(self, src_lang='de', tgt_lang='en', model_path=None):
        """Set up the translation model, its tokenizer and the compute device.

        Args:
            src_lang: Source language code; used to select the
                ``Helsinki-NLP/opus-mt-{src}-{tgt}`` checkpoint and as the
                OCR language.
            tgt_lang: Target language code for the checkpoint.
            model_path: Optional path to a whole model object saved with
                ``torch.save``. When it is not given or does not exist, the
                pre-trained checkpoint is loaded instead.
        """
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        # Fallback frame rate used to turn frame indices into timestamps.
        self.fps = 30.0
        # Largest gap, in seconds, between two subtitles that are still merged.
        self.sync_threshold = 0.02

        # Resolve the device first so a saved model can be mapped onto it
        # (use GPU if available).
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        model_name = f'Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}'

        # Load the fine-tuned model if one was saved, else the pre-trained one.
        if model_path and os.path.exists(model_path):
            # map_location lets a model that was saved on a GPU be loaded on a
            # CPU-only machine instead of failing while restoring CUDA tensors.
            # NOTE: torch.load unpickles arbitrary objects -- only load files
            # from a trusted source.
            # NOTE(review): recent PyTorch releases default to
            # weights_only=True, which rejects whole pickled modules --
            # confirm against the installed version.
            self.model = torch.load(model_path, map_location=self.device)
        else:
            self.model = MarianMTModel.from_pretrained(model_name)

        self.tokenizer = MarianTokenizer.from_pretrained(model_name)
        self.model.to(self.device)

    def preprocess_video(self, video_path):
        """Decode every frame of a video and compute a timestamp for each one.

        The frame rate reported by the video replaces ``self.fps`` (the
        existing value is kept only when no positive rate is reported), so
        the timestamps and the per-frame duration used elsewhere follow the
        actual video instead of a fixed assumption.

        Args:
            video_path: Path of the video file to read.

        Returns:
            A ``(frames, timestamps)`` tuple: the decoded frames in order and
            the start time of each frame in seconds.

        Raises:
            IOError: If the video cannot be opened.
        """
        video = cv2.VideoCapture(video_path)
        try:
            if not video.isOpened():
                raise IOError(f"Could not open video: {video_path}")

            reported_fps = video.get(cv2.CAP_PROP_FPS)
            if reported_fps and reported_fps > 0:
                self.fps = float(reported_fps)

            # The header frame count is only used to size the progress bar;
            # reading continues until the decoder reports no more frames.
            total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
            progress_total = total_frames if total_frames > 0 else None

            frames = []
            timestamps = []
            with tqdm(total=progress_total, desc="Extracting frames") as progress:
                while True:
                    ret, frame = video.read()
                    if not ret:
                        break
                    timestamps.append(len(frames) / self.fps)
                    frames.append(frame)
                    progress.update(1)
        finally:
            # Release the capture even when opening or decoding fails.
            video.release()

        return frames, timestamps

    def extract_text_from_frame(self, frame):
        """Run OCR on a single BGR frame and return the recognized text.

        The frame is converted to grayscale and binarized with Otsu's
        threshold before being handed to Tesseract.

        Args:
            frame: Image in BGR channel order.

        Returns:
            The recognized text with surrounding whitespace removed (an empty
            string when nothing was recognized).
        """
        # Tesseract names its language packs with three-letter codes, whereas
        # self.src_lang holds the two-letter code used for the translation
        # checkpoint. Codes without an entry are passed through unchanged.
        tesseract_codes = {
            'ar': 'ara', 'cs': 'ces', 'da': 'dan', 'de': 'deu', 'el': 'ell',
            'en': 'eng', 'es': 'spa', 'fi': 'fin', 'fr': 'fra', 'hu': 'hun',
            'it': 'ita', 'ja': 'jpn', 'ko': 'kor', 'nl': 'nld', 'pl': 'pol',
            'pt': 'por', 'ru': 'rus', 'sv': 'swe', 'tr': 'tur', 'uk': 'ukr',
            'zh': 'chi_sim',
        }
        ocr_lang = tesseract_codes.get(self.src_lang, self.src_lang)

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
        text = pytesseract.image_to_string(thresh, lang=ocr_lang)
        return text.strip()

    def translate_text(self, text):
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        inputs = inputs.to(self.device)
        
        with torch.no_grad():
            translated = self.model.generate(**inputs)
        
        return self.tokenizer.decode(translated[0], skip_special_tokens=True)

    def process_frame(self, frame, timestamp):
        text = self.extract_text_from_frame(frame)
        if text:
            translated_text = self.translate_text(text)
            return Subtitle(translated_text, timestamp, timestamp + 1.0 / self.fps)
        return None

    def translate_video(self, video_path):
        frames, timestamps = self.preprocess_video(video_path)
        subtitles = []
        
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(self.process_frame, frame, timestamp) 
                       for frame, timestamp in zip(frames, timestamps)]
            
            for future in concurrent.futures.as_completed(futures):
                subtitle = future.result()
                if subtitle:
                    subtitles.append(subtitle)
        
        return self.post_process_subtitles(subtitles)

    def post_process_subtitles(self, subtitles):
        subtitles.sort(key=lambda x: x.start_time)
        
        processed_subtitles = []
        current_subtitle = None
        
        for subtitle in subtitles:
            if current_subtitle is None:
                current_subtitle = subtitle
            elif subtitle.start_time - current_subtitle.end_time <= self.sync_threshold:
                current_subtitle.text += " " + subtitle.text
                current_subtitle.end_time = subtitle.end_time
            else:
                processed_subtitles.append(current_subtitle)
                current_subtitle = subtitle
        
        if current_subtitle:
            processed_subtitles.append(current_subtitle)
        
        return processed_subtitles

    def export_subtitles(self, subtitles, output_path):
        with open(output_path, 'w', encoding='utf-8') as f:
            for i, subtitle in enumerate(subtitles, 1):
                f.write(f"{i}\n")
                f.write(f"{self.format_time(subtitle.start_time)} --> {self.format_time(subtitle.end_time)}\n")
                f.write(f"{subtitle.text}\n\n")

   

In [None]:
 @staticmethod
    def format_time(seconds):
        """Format a time offset in seconds as ``HH:MM:SS,mmm``.

        The value is rounded to the nearest millisecond; negative values are
        clamped to zero.

        Args:
            seconds: Offset from the start of the video, in seconds.

        Returns:
            The formatted timestamp string.
        """
        # Convert to whole milliseconds once. Splitting the float into parts
        # and truncating the fraction loses a millisecond to binary rounding
        # (1.001 s would come out as ",000").
        total_ms = max(0, int(round(seconds * 1000)))
        hours, remainder = divmod(total_ms, 3_600_000)
        minutes, remainder = divmod(remainder, 60_000)
        whole_seconds, milliseconds = divmod(remainder, 1000)
        return f"{hours:02d}:{minutes:02d}:{whole_seconds:02d},{milliseconds:03d}"

    def train(self, train_data, val_data, epochs=5, batch_size=32, learning_rate=1e-3):
        """Fine-tune the translation model on parallel sentence pairs.

        After every epoch the average training loss is printed, followed by
        the average validation loss (skipped when ``val_data`` yields no
        batches). The model is left in evaluation mode.

        Args:
            train_data: Indexable collection of ``(source, target)`` text pairs.
            val_data: Indexable collection of ``(source, target)`` text pairs
                used for validation only.
            epochs: Number of passes over ``train_data``.
            batch_size: Number of pairs per batch.
            learning_rate: Step size for the Adam optimizer. The default is
                the value Adam uses when none is given, so existing callers
                see no change. NOTE(review): fine-tuning a pre-trained model
                usually calls for a much smaller value (e.g. 5e-5) --
                consider passing one.
        """
        train_loader = DataLoader(
            SubtitleDataset(train_data), batch_size=batch_size, shuffle=True, collate_fn=collate_fn
        )
        val_loader = DataLoader(
            SubtitleDataset(val_data), batch_size=batch_size, collate_fn=collate_fn
        )

        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0.0

            for src_texts, tgt_texts in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
                loss = self._batch_loss(src_texts, tgt_texts)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            # max(..., 1) avoids a division by zero when there are no batches.
            avg_loss = total_loss / max(len(train_loader), 1)
            print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

            # Validation
            self.model.eval()
            if len(val_loader) == 0:
                continue

            val_loss = 0.0
            with torch.no_grad():
                for src_texts, tgt_texts in val_loader:
                    val_loss += self._batch_loss(src_texts, tgt_texts).item()

            avg_val_loss = val_loss / len(val_loader)
            print(f"Validation Loss: {avg_val_loss:.4f}")

    def _batch_loss(self, src_texts, tgt_texts):
        """Tokenize one batch of sentence pairs and return the model's loss tensor."""
        inputs = self.tokenizer(
            src_texts, return_tensors="pt", padding=True, truncation=True, max_length=512
        ).to(self.device)

        # text_target makes the tokenizer encode the sentences as targets
        # rather than as source text.
        targets = self.tokenizer(
            text_target=tgt_texts, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        labels = targets['input_ids'].to(self.device)
        # -100 is the label value the model's loss skips; without this the
        # padding positions of shorter sentences would count toward the loss.
        labels[labels == self.tokenizer.pad_token_id] = -100

        return self.model(**inputs, labels=labels).loss

    def evaluate(self, test_data):
        """Translate every source sentence and score the output with BLEU.

        Args:
            test_data: Iterable of ``(source, target)`` text pairs; the target
                is used as the single reference for its source.

        Returns:
            The corpus-level BLEU score as a float, or ``0.0`` when
            ``test_data`` is empty.
        """
        self.model.eval()
        hypotheses = []
        references = []

        for src_text, tgt_text in tqdm(test_data, desc="Evaluating"):
            hypotheses.append(self.translate_text(src_text))
            references.append(tgt_text)

        if not hypotheses:
            return 0.0

        # corpus_bleu takes a list of reference streams, each one aligned
        # sentence-by-sentence with the hypotheses. With one reference per
        # sentence that is a single stream, not one list per sentence.
        bleu_score = corpus_bleu(hypotheses, [references])
        return bleu_score.score

class Subtitle:
    """A piece of subtitle text together with its display interval.

    Attributes are plain and mutable: ``text``, ``start_time`` and
    ``end_time`` are stored exactly as passed in.
    """

    def __init__(self, text, start_time, end_time):
        """Store the text and the start/end times unchanged."""
        self.text, self.start_time, self.end_time = text, start_time, end_time

def load_parallel_data(file_path):
    """Read a UTF-8 JSON file of sentence pairs.

    The file must contain a sequence of objects, each with a ``'source'``
    and a ``'target'`` key.

    Returns:
        A list of ``(source, target)`` tuples in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as corpus_file:
        records = json.load(corpus_file)

    pairs = []
    for record in records:
        pairs.append((record['source'], record['target']))
    return pairs



In [None]:

# End-to-end run: fine-tune the translator on a parallel corpus, report BLEU,
# save the model, then write translated subtitles for one video.
if __name__ == "__main__":
    # Load the (source, target) pairs and split them: 20% is held out for
    # testing, then 10% of the remainder is set aside for validation.
    # random_state is fixed so the splits are the same on every run.
    data = load_parallel_data('parallel_corpus.json')
    train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
    train_data, val_data = train_test_split(train_data, test_size=0.1, random_state=42)

    # Initialize translator (no model_path given, so the pre-trained
    # German -> English checkpoint is used)
    translator = VideoSubtitleTranslator(src_lang='de', tgt_lang='en')

    # Train the model; per-epoch training and validation losses are printed
    translator.train(train_data, val_data, epochs=5, batch_size=32)

    # Evaluate the model on the held-out test pairs
    bleu_score = translator.evaluate(test_data)
    print(f"BLEU Score: {bleu_score:.2f}")

    # Save the trained model as a whole pickled object, which is the form
    # VideoSubtitleTranslator loads when it is given a model_path
    torch.save(translator.model, 'trained_translator_model.pth')

    # Translate a video
    # NOTE(review): this path looks like a placeholder -- point it at a real file
    video_path = "path/to/your/video.mp4"
    translated_subtitles = translator.translate_video(video_path)
    translator.export_subtitles(translated_subtitles, "translated_subtitles.srt")

    print("Video translation complete. Subtitles saved to 'translated_subtitles.srt'")