In [None]:
# Clone the unilm repository (hein-nkhh/unilm) and move into its beit3/ folder, so that
# `requirements.txt` and `modeling_finetune` (imported below) resolve from the working directory.
# NOTE(review): not idempotent — both commands are relative to the current directory, which this
# cell itself changes, so re-running it after a successful run does not return to the same state.
!git clone https://github.com/hein-nkhh/unilm.git
%cd unilm/beit3

In [None]:
from IPython.display import clear_output

In [None]:
!pip install -r requirements.txt
clear_output()

In [None]:
import os
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from modeling_finetune import beit3_large_patch16_384_retrieval
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.cuda.amp import autocast
from transformers import XLMRobertaTokenizer
import torch
import json
import cv2
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Text tokenizer for BEiT-3, built from the .spm vocabulary stored next to the checkpoint.
tokenizer = XLMRobertaTokenizer("/kaggle/input/beit3_base_retrieval/pytorch/default/2/beit3.spm")

# BEiT-3 retrieval model (large, 384px) with weights from the COCO retrieval checkpoint.
ckpt = "/kaggle/input/beit3_base_retrieval/pytorch/default/2/beit3_large_patch16_384_coco_retrieval.pth"
model = beit3_large_patch16_384_retrieval(pretrained=False)
state_dict = torch.load(ckpt, map_location=device)
# strict=False silently tolerates key mismatches; keep the report so that a checkpoint for a
# different architecture/size is noticed instead of the model quietly running on random weights.
load_report = model.load_state_dict(state_dict["model"], strict=False)
model = model.to(device)
model.eval()
clear_output()
# Printed after clear_output() so the summary is not wiped.
print(f"BEiT-3 checkpoint loaded: {len(load_report.missing_keys)} missing keys, "
      f"{len(load_report.unexpected_keys)} unexpected keys")
if load_report.missing_keys:
    print("Missing keys (first 10):", load_report.missing_keys[:10])

# Image preprocessing: bicubic resize to 384x384, then map pixel values from [0, 1] to [-1, 1]
# (mean = std = 0.5 per channel).
transform = transforms.Compose([
    transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [None]:
# Vietnamese -> English translation model: queries are written in Vietnamese, while the
# BEiT-3 text encoder is given the English translation.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "VietAI/envit5-translation"
tokenizer_translate = AutoTokenizer.from_pretrained(model_name)
model_translate = AutoModelForSeq2SeqLM.from_pretrained(model_name)
# Use the notebook-wide `device` so this also works on CPU-only machines
# (a hard-coded 'cuda:0' raises when no GPU is available).
model_translate = model_translate.to(device)

In [None]:
def encode_text(query: str, tokenizer, model, max_len=32, device=device):
    """Embed a text query with BEiT-3 and return its L2-normalised CLS vector.

    The query is tokenised and truncated so that BOS + tokens + EOS fit in
    ``max_len``. It is then right-padded with the pad id and passed through
    the model's text branch (``only_infer=True``) without gradients.

    Args:
        query: Text to embed.
        tokenizer: Tokenizer providing ``tokenize``, ``convert_tokens_to_ids``
            and ``bos_token_id`` / ``eos_token_id`` / ``pad_token_id``.
        model: Retrieval model, called with ``text_description``,
            ``padding_mask`` and ``only_infer=True``; its second output is
            used as the text embedding.
        max_len: Total sequence length after special tokens and padding.
        device: Device the input tensors are moved to.

    Returns:
        CPU tensor of shape (1, D): the normalised text embedding.
    """
    # Reserve two slots for the BOS/EOS special tokens.
    body_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(query)[: max_len - 2])
    sequence = [tokenizer.bos_token_id, *body_ids, tokenizer.eos_token_id]
    n_pad = max_len - len(sequence)

    # Mask convention: 0 = real token, 1 = padding.
    ids_tensor = torch.tensor([sequence + [tokenizer.pad_token_id] * n_pad], dtype=torch.long).to(device)
    mask_tensor = torch.tensor([[0] * len(sequence) + [1] * n_pad], dtype=torch.long).to(device)

    with torch.no_grad():
        _, text_cls = model(text_description=ids_tensor,
                            padding_mask=mask_tensor,
                            only_infer=True)
        return F.normalize(text_cls, p=2, dim=-1).cpu()

In [None]:
# Keyframe index files and their source videos; the two lists are paired by position.
json_files = ['/kaggle/input/aic-sample-test/keyframes_index/L21_V003_keyframes_index.json',
              '/kaggle/input/aic-sample-test/keyframes_index/L21_V006_keyframes_index.json',
              '/kaggle/input/aic-sample-test/keyframes_index/L21_V007_keyframes_index.json',
              '/kaggle/input/aic-sample-test/keyframes_index/L21_V011_keyframes_index.json',
             ]

video_paths = ['/kaggle/input/aic-sample-test/videos/L21_V003.mp4',
               '/kaggle/input/aic-sample-test/videos/L21_V006.mp4',
               '/kaggle/input/aic-sample-test/videos/L21_V007.mp4',
               '/kaggle/input/aic-sample-test/videos/L21_V011.mp4',
              ]
assert len(json_files) == len(video_paths), "json_files and video_paths must be paired one-to-one"

output_dir = 'extracted_frames'
os.makedirs(output_dir, exist_ok=True)

image_paths = []
# Problems are collected and printed after clear_output(), which would otherwise erase them.
extraction_warnings = []

for i, (json_file, video_path) in enumerate(zip(json_files, video_paths)):
    with open(json_file, 'r') as f:
        # NOTE(review): assumes the JSON holds an iterable of integer frame numbers — confirm.
        frame_indices = json.load(f)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        extraction_warnings.append(f"Không thể mở video: {video_path}")
        continue

    try:
        # tqdm shows per-video progress.
        for frame_index in tqdm(frame_indices, desc=f"Trích xuất từ video {i+1}/{len(json_files)}"):
            # Seek to the requested frame, then decode it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()
            if not ret:
                extraction_warnings.append(f"⚠️ Không thể đọc frame tại index {frame_index} trong video {i+1}.")
                continue

            # OpenCV returns BGR pixel order; PIL expects RGB.
            frame_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_path = os.path.join(output_dir, f'video{i+1}_frame_{frame_index}.png')
            frame_image.save(frame_path)
            image_paths.append(frame_path)
    finally:
        # Release the capture even if decoding or saving raises.
        cap.release()

clear_output()
for message in extraction_warnings:
    print(message)
print(f"Đã trích xuất tổng cộng {len(image_paths)} frame.")
# print(f"Các đường dẫn frame:", image_paths)

In [None]:
import re


def _natural_sort_key(path: str) -> list:
    """Split ``path`` into alternating text / integer chunks for natural ordering.

    ``re.split`` with a capturing group places the digit runs at odd positions,
    so those are compared as ``int``. This makes ``video1_frame_20.png`` sort
    before ``video1_frame_100.png``, whereas plain string sorting puts it after.
    """
    return [int(chunk) if pos % 2 else chunk
            for pos, chunk in enumerate(re.split(r"(\d+)", path))]


# Order the extracted frames numerically (video number, then frame number) rather than
# lexicographically, so the ids stay in temporal order within each video.
image_paths = sorted(image_paths, key=_natural_sort_key)

In [None]:
# Alternative frame source: pre-extracted keyframes on disk.
# path = "/kaggle/input/aic-small-2024/Keyframes_L21/keyframes/L21_V001"
# image_paths = [os.path.join(path, name) for name in os.listdir(path)]

EMBED_BATCH_SIZE = 16  # frames per forward pass; lower it if the GPU runs out of memory

if not image_paths:
    raise RuntimeError("No frames to embed: image_paths is empty (did frame extraction fail?)")

embeddings = []
ids = []
# Mixed precision only on CUDA, matching the old torch.cuda.amp.autocast behaviour
# (which disables itself when CUDA is unavailable).
use_amp = device.type == "cuda"

with torch.no_grad():
    for start in tqdm(range(0, len(image_paths), EMBED_BATCH_SIZE), desc="🔄 Extracting image embeddings"):
        batch_paths = image_paths[start:start + EMBED_BATCH_SIZE]

        batch_tensors = []
        for img_path in batch_paths:
            # `with` closes the file handle as soon as the pixels are converted.
            with Image.open(img_path) as image:
                batch_tensors.append(transform(image.convert("RGB")))
        image_tensor = torch.stack(batch_tensors).to(device)  # (B, 3, 384, 384)

        with torch.autocast(device_type=device.type, enabled=use_amp):
            vision_cls, _ = model(image=image_tensor, only_infer=True)
            vision_norm = F.normalize(vision_cls, p=2, dim=-1)

        # Keep float32 on CPU so the dtype matches encode_text's output in the similarity matmul.
        embeddings.append(vision_norm.float().cpu())  # (B, D)
        ids.extend(batch_paths)

image_embeddings = torch.cat(embeddings, dim=0)  # (N, D)
torch.save({"embeddings": image_embeddings, "ids": ids}, "image_embeddings.pt")
print("✅ Saved embeddings:", image_embeddings.shape)



In [None]:
def translate_vi_en(text: str, max_length: int = 50) -> str:
    """Translate Vietnamese ``text`` to English with the envit5 model.

    The model takes a language-tagged input (``"vi: ..."``). Its output may
    start with an ``"en:"`` tag, which is stripped.

    Args:
        text: Vietnamese input text.
        max_length: Upper bound on the generated sequence length (tokens).

    Returns:
        The English translation with surrounding whitespace removed.
    """
    prefixed = f"vi: {text}"
    # Run on whatever device the translation model lives on, instead of assuming 'cuda:0'.
    input_ids = tokenizer_translate([prefixed], return_tensors="pt", padding=True).input_ids.to(model_translate.device)
    outputs = model_translate.generate(input_ids, max_length=max_length)
    translated = tokenizer_translate.decode(outputs[0], skip_special_tokens=True).strip()
    # Remove only a *leading* "en:" tag; str.replace would also mangle words like "Golden:".
    if translated.startswith("en:"):
        translated = translated[len("en:"):]
    return translated.strip()

In [None]:
text_query = "Băng tan ở Nam Cực"  # Vietnamese for "Ice melting in Antarctica"
TOP_K = 10  # number of best-matching frames to show

text_query_en = translate_vi_en(text_query)
print("Translated query:", text_query_en)
lang_vec = encode_text(text_query_en, tokenizer, model)

# Both sides are L2-normalised, so the dot product is cosine similarity.
similarities = torch.matmul(lang_vec, image_embeddings.t()).squeeze(0)  # (N,)

# Clamp k so the cell still works when fewer than TOP_K frames were indexed
# (torch.topk raises if k exceeds the number of elements).
topk = torch.topk(similarities, k=min(TOP_K, similarities.numel()))
print("Top matches:")
for idx, score in zip(topk.indices.tolist(), topk.values.tolist()):
    img_path = ids[idx]
    print(img_path, score)

    # Load a detached RGB copy so the file handle is closed before plotting.
    with Image.open(img_path) as img:
        frame = img.convert("RGB")
    fig, ax = plt.subplots()
    ax.imshow(frame)
    ax.axis("off")
    ax.set_title(f"Sim: {score:.4f}")
    plt.show()
    plt.close(fig)