# Setup

In [None]:
# Parameters
# NOTE(review): presumably a papermill "parameters" cell — values injected at run
# time override these; anything left as None/falsy is defaulted in the next cell.
bs = None  # batch_size passed to ocr_and_save_results; next cell defaults it to 1
keyframes_dir = None  # root folder with one sub-folder of *.jpg keyframes per video
save_dir = None  # output folder for per-video OCR JSON files; next cell defaults it to './ocr'

In [None]:
import os

dir_path = os.getcwd()

# Choose a default keyframes location only when none was supplied as a parameter.
if not keyframes_dir:
    ipython_shell = str(get_ipython())
    if 'google.colab' in ipython_shell or 'kaggle' in ipython_shell:
        # Hosted notebooks (Colab / Kaggle): keyframes sit beside the notebook.
        # Update this path as necessary.
        keyframes_dir = f'{dir_path}/keyframes'
    else:
        # Local run: keyframes are looked up under ../transnet/keyframes.
        parent_dir_path = os.path.dirname(dir_path)
        keyframes_dir = f'{parent_dir_path}/transnet/keyframes'

# Defaults for any parameter left unset (None, 0 or empty).
bs = bs or 1
save_dir = save_dir or './ocr'

In [None]:
! pip install aiohttp
! pip install aiofiles
! pip install git+https://github.com/JaidedAI/EasyOCR.git
! pip install transformers



In [None]:
import os
import json
import asyncio
import glob
from tqdm import tqdm
import easyocr
import aiofiles
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Collect keyframe paths per video

In [None]:
def parse_keyframe_info(keyframes_dir, pattern='*.jpg'):
    """Map each video sub-folder of ``keyframes_dir`` to its sorted keyframe paths.

    Args:
        keyframes_dir: Root folder containing one sub-folder per video.
        pattern: Glob pattern selecting keyframe files inside each sub-folder
            (defaults to ``'*.jpg'``).

    Returns:
        Dict ``{subfolder_name: sorted list of matching file paths}``, with keys
        inserted in sorted order. Sub-folders with no matching files map to an
        empty list; plain files directly under ``keyframes_dir`` (for example
        ``.DS_Store``) are skipped instead of being treated as videos.
    """
    all_keyframe_paths = {}
    for part in sorted(os.listdir(keyframes_dir)):
        data_part_path = os.path.join(keyframes_dir, part)
        if not os.path.isdir(data_part_path):
            continue
        # Escape the directory so folder names containing glob metacharacters
        # ("[", "*", "?") are matched literally; only ``pattern`` is a wildcard.
        image_glob = os.path.join(glob.escape(data_part_path), pattern)
        all_keyframe_paths[part] = sorted(glob.glob(image_glob))
    return all_keyframe_paths

# Inference: OCR and translation

In [None]:
async def create_directory(path):
    """Ensure ``path`` exists as a directory, creating any missing parents.

    Already-existing directories are left untouched; an existing non-directory
    at ``path`` raises ``FileExistsError``.
    """
    if not os.path.isdir(path):
        os.makedirs(name=path, exist_ok=True)

async def save_ocr_results(save_dir, video_id, ocr_results):
    """Write one video's OCR results to ``<save_dir>/<video_id>.json``.

    The file is written as UTF-8 with non-ASCII characters kept literal
    (``ensure_ascii=False``) rather than ``\\u``-escaped.
    """
    output_path = os.path.join(save_dir, f"{video_id}.json")
    async with aiofiles.open(output_path, "w", encoding='utf-8') as out_file:
        serialized = json.dumps(ocr_results, ensure_ascii=False)
        await out_file.write(serialized)

async def process_image(reader, image_path, min_confidence=0.6):
    """OCR one keyframe and return its text grouped into paragraphs.

    Args:
        reader: An ``easyocr.Reader``. Its blocking ``readtext`` call runs in a
            worker thread via ``asyncio.to_thread``.
        image_path: Path of the image to read.
        min_confidence: Detections whose confidence (third element of each
            ``readtext`` result) is not strictly greater than this value are
            dropped. Defaults to 0.6, the previously hard-coded threshold.

    Returns:
        List of paragraph strings; empty when nothing passes the threshold.
    """
    detections = await asyncio.to_thread(reader.readtext, image_path)
    confident = [det for det in detections if det[2] > min_confidence]
    if not confident:
        return []
    # get_paragraph merges nearby boxes; each merged entry is read as [box, text].
    paragraphs = easyocr.utils.get_paragraph(confident)
    return [paragraph[1] for paragraph in paragraphs]

async def translate_text(text, tokenizer, model, device):
    """Translate Vietnamese text to English with VinAI's vi2en model.

    Args:
        text: A single string, or a list of strings translated as one padded batch.
        tokenizer: Tokenizer of the translation model (the main cell loads
            ``vinai/vinai-translate-vi2en-v2`` with ``src_lang="vi_VN"``).
        model: The matching seq2seq model, already moved to ``device``.
        device: Torch device the encoded inputs are moved to.

    Returns:
        The translation as a ``str`` when ``text`` is a string; otherwise a list
        of translations in input order (``[]`` for an empty list).

    Note:
        Generation runs synchronously, so the event loop is blocked while it runs.
    """
    single = isinstance(text, str)
    texts = [text] if single else list(text)
    if not texts:
        return []

    # truncation=True clips inputs longer than the tokenizer's model_max_length
    # instead of failing inside the model; padding aligns a multi-string batch.
    encoded = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
    output_ids = model.generate(
        **encoded,
        # Look up the target-language token directly rather than via the
        # tokenizer-specific lang_code_to_id mapping.
        decoder_start_token_id=tokenizer.convert_tokens_to_ids("en_XX"),
        num_return_sequences=1,
        num_beams=5,
        early_stopping=True,
    )
    translations = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return translations[0] if single else translations
    
async def process_video_keyframes(reader, tokenizer, model, device, video_keyframe_paths, batch_size=16):
    """OCR and translate every keyframe of one video.

    At most ``batch_size`` OCR calls run concurrently: each batch is awaited in
    full before the next one starts. (Previously tasks for *all* frames were
    created up front, so ``batch_size`` had no effect and every frame hit the
    GPU at once.)

    Args:
        reader: ``easyocr.Reader`` passed to ``process_image``.
        tokenizer, model, device: Translation components passed to ``translate_text``.
        video_keyframe_paths: Ordered list of keyframe image paths.
        batch_size: Maximum number of concurrent OCR calls; must be >= 1.

    Returns:
        Dict mapping frame basename -> list of translated paragraphs, in input
        order. Frames where no text was detected are omitted.

    Raises:
        ValueError: if ``batch_size`` < 1 (a negative value used to silently
            produce an empty result).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    video_ocr_results = {}
    for start in range(0, len(video_keyframe_paths), batch_size):
        batch = video_keyframe_paths[start:start + batch_size]
        # Only this batch's OCR calls are in flight together, which bounds GPU
        # memory use by batch_size.
        batch_texts = await asyncio.gather(
            *(process_image(reader, image_path) for image_path in batch)
        )
        for image_path, text_detected in zip(batch, batch_texts):
            if not text_detected:
                continue
            # Translate each paragraph on its own: a joined " ||| " delimiter
            # may be altered or dropped by the model, and one long joined input
            # can exceed the model's maximum sequence length.
            translated_items = []
            for paragraph in text_detected:
                translated_items.append(
                    await translate_text(paragraph, tokenizer, model, device)
                )
            video_ocr_results[os.path.basename(image_path)] = translated_items

    return video_ocr_results

async def ocr_and_save_results(reader, tokenizer, model, device, all_keyframe_paths, save_dir, batch_size=16):
    """Run OCR + translation for every video and write one JSON file per video.

    ``save_dir`` is created first. Videos are then handled one at a time in
    sorted order of their ids (the keys of ``all_keyframe_paths``), and each
    video's results are saved as soon as that video has been processed.
    """
    await create_directory(save_dir)

    for video_id in tqdm(sorted(all_keyframe_paths), desc="Processing keys"):
        frame_paths = all_keyframe_paths[video_id]
        frame_results = await process_video_keyframes(
            reader, tokenizer, model, device, frame_paths, batch_size
        )
        await save_ocr_results(save_dir, video_id, frame_results)

In [None]:
# Main execution
all_keyframe_paths = parse_keyframe_info(keyframes_dir)
reader = easyocr.Reader(['vi'], gpu=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("vinai/vinai-translate-vi2en-v2", src_lang="vi_VN")
model = AutoModelForSeq2SeqLM.from_pretrained("vinai/vinai-translate-vi2en-v2").to(device)

await ocr_and_save_results(reader, tokenizer, model, device, all_keyframe_paths, save_dir, bs)

Processing keys:   0%|          | 0/4 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 230.00 MiB. GPU 