In [None]:
from PIL import Image
from transformers import AutoModel, AutoConfig
from transformers import CLIPImageProcessor, pipeline, CLIPTokenizer
import torch
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
import pickle
from torch.nn.functional import cosine_similarity
import torch
import pickle
from tqdm import tqdm
import json
import cv2
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output
import os

In [None]:
# EVA-CLIP-8B checkpoint. trust_remote_code=True means the model class (which provides the
# encode_image / encode_text methods used later) comes from the checkpoint repository.
model_name_or_path = "BAAI/EVA-CLIP-8B"
image_size = 224  # NOTE(review): not referenced by any other visible cell

# Load the weights in half precision, then switch the model to inference mode.
model = AutoModel.from_pretrained(
    model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16
)
model.eval()

# Pixel preprocessing reuses the OpenAI CLIP ViT-L/14 processor; text uses the EVA-CLIP tokenizer.
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained(model_name_or_path)

In [None]:
# Translation model: Vietnamese queries are translated to English before CLIP text encoding.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


model_name = "VietAI/envit5-translation"
tokenizer_translate = AutoTokenizer.from_pretrained(model_name)
# Load the seq2seq model and place it on cuda:0 in a single step (.to returns the same model).
model_translate = AutoModelForSeq2SeqLM.from_pretrained(model_name).to('cuda:0')

In [None]:
# Split EVA-CLIP across two GPUs.
# Text tower and its projection -> cuda:0 (where tokenised text is sent).
text_model = model.text_model.to("cuda:0")
model.text_projection = model.text_projection.to("cuda:0")
# Vision tower and its projection -> cuda:1 (where preprocessed pixels are sent).
vision_model = model.vision_model.to("cuda:1")
model.visual_projection = model.visual_projection.to("cuda:1")

In [None]:
def preprocess_image(image_path):
    """Load an image file and turn it into EVA-CLIP pixel values on cuda:1.

    The image is converted to RGB, run through the CLIP image processor, and the
    processed batch is moved to cuda:1, the device the vision tower is placed on.

    Args:
        image_path: Path of the image to load.

    Returns:
        The ``pixel_values`` entry of the processor output, on cuda:1.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    processed = processor(images=rgb_image, return_tensors="pt")
    on_device = processed.to('cuda:1')
    return on_device['pixel_values']

def preprocess_text(text):
    """Tokenize ``text`` with the EVA-CLIP tokenizer and move the batch to cuda:0.

    Padding and truncation are enabled so a list of strings becomes one rectangular
    batch; cuda:0 is the device the text tower is placed on.

    Args:
        text: A string or list of strings to tokenize.

    Returns:
        The tokenizer output (``input_ids`` etc.) moved to cuda:0.
    """
    tokenizer_kwargs = dict(return_tensors="pt", padding=True, truncation=True)
    encoded = tokenizer(text, **tokenizer_kwargs)
    return encoded.to('cuda:0')

In [None]:
def get_image_features(image_path):
    """Encode one image with EVA-CLIP and return its L2-normalised embedding on cuda:0.

    Args:
        image_path: Path of the image to encode (loaded by ``preprocess_image``).

    Returns:
        The output of ``model.encode_image`` moved to cuda:0 and divided by its L2 norm
        along the last dimension. NOTE(review): presumably shape (1, D) for one image — confirm.
    """
    pixel_values = preprocess_image(image_path)  # already on cuda:1, next to the vision tower
    # torch.cuda.amp.autocast() is deprecated (FutureWarning since PyTorch 2.4);
    # torch.autocast(device_type="cuda") is the equivalent supported spelling.
    with torch.no_grad(), torch.autocast(device_type="cuda"):
        # Move to cuda:0 so image embeddings live on the same device as text embeddings.
        image_features = model.encode_image(pixel_values).to('cuda:0')
    # Unit-normalise so a plain dot product equals cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features


def get_text_features(text):
    """Encode text with EVA-CLIP and return L2-normalised embedding(s) on cuda:0.

    Args:
        text: A string or list of strings (tokenised by ``preprocess_text``).

    Returns:
        The output of ``model.encode_text`` moved to cuda:0 and divided by its L2 norm
        along the last dimension (one row per input string).
    """
    text_inputs = preprocess_text(text)  # tokenised batch, already on cuda:0
    # torch.cuda.amp.autocast() is deprecated (FutureWarning since PyTorch 2.4);
    # torch.autocast(device_type="cuda") is the equivalent supported spelling.
    with torch.no_grad(), torch.autocast(device_type="cuda"):
        # Only input_ids are passed. NOTE(review): with several strings, padding=True adds pad
        # tokens that are not masked here — confirm whether encode_text accepts an attention mask.
        text_features = model.encode_text(text_inputs['input_ids']).to('cuda:0')
    # Unit-normalise so a plain dot product equals cosine similarity.
    text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features

In [None]:
def translate_vi_en(text: str, max_length: int = 50) -> str:
    """Translate Vietnamese text to English with the VietAI envit5 translation model.

    The input is tagged with the source-language prefix ``"vi: "``. A leading ``"en:"`` tag
    in the decoded output, if present, is removed.

    Args:
        text: Vietnamese text to translate.
        max_length: Maximum length passed to ``generate`` (defaults to the previous fixed 50).

    Returns:
        The English translation with a leading ``"en:"`` tag and surrounding whitespace removed.
    """
    # Add the "vi:" prefix to the input.
    prefixed = f"vi: {text}"
    # Send the ids to whatever device the translation model lives on.
    input_ids = tokenizer_translate([prefixed], return_tensors="pt", padding=True).input_ids.to(model_translate.device)
    outputs = model_translate.generate(input_ids, max_length=max_length)
    translated = tokenizer_translate.decode(outputs[0], skip_special_tokens=True).strip()
    # Strip only a *leading* "en:" tag. str.replace("en:", "") would also corrupt words
    # that happen to end in "en" before a colon (e.g. "When:" -> "Wh", "screen:" -> "scre").
    if translated.startswith("en:"):
        translated = translated[len("en:"):]
    return translated.strip()

In [None]:
# from torch.nn.functional import cosine_similarity
# import re

# # Hàm tìm kiếm ảnh phù hợp với văn bản
# def text_to_image_retrieval(text, image_paths):
#     text_features = get_text_features(text)
    
#     similarities = []
#     for image_path in image_paths:
#         image_features = get_image_features(image_path)
#         similarity = cosine_similarity(text_features, image_features)
#         similarities.append((image_path, similarity.item()))
    
#     # Sắp xếp theo độ tương đồng giảm dần
#     similarities.sort(key=lambda x: x[1], reverse=True)
    
#     return similarities

# # Thử với một văn bản
# text_query = "Cháy rừng"
# text_query_en = translate_vi_en(text_query)
# retrieved_images = text_to_image_retrieval(text_query_en, image_paths)

# # In ảnh có độ tương đồng cao nhất
# print("Best matching images:")

# for i, (image_path, similarity) in enumerate(retrieved_images[:5]):
#     print(f"{image_path} with similarity: {similarity:.4f}")
    
#     # Hiển thị ảnh
#     img = Image.open(image_path)
#     plt.figure(figsize=(8, 6))  # Kích thước hiển thị
#     plt.imshow(np.array(img))   # Chuyển sang array để imshow xử lý
#     plt.title(f"Image: {image_path}\nSimilarity: {similarity:.4f}")
#     plt.axis('off')  # Tắt trục tọa độ
#     plt.show()


In [None]:
# Keyframe index files and their source videos, paired by position (index i of each list).
json_files = ['/kaggle/input/aic-sample-test/keyframes_index/L21_V003_keyframes_index.json',
             '/kaggle/input/aic-sample-test/keyframes_index/L21_V006_keyframes_index.json',
             '/kaggle/input/aic-sample-test/keyframes_index/L21_V007_keyframes_index.json',
             '/kaggle/input/aic-sample-test/keyframes_index/L21_V011_keyframes_index.json'
            ]


video_paths = ['/kaggle/input/aic-sample-test/videos/L21_V003.mp4', 
               '/kaggle/input/aic-sample-test/videos/L21_V006.mp4',
               '/kaggle/input/aic-sample-test/videos/L21_V007.mp4',
               '/kaggle/input/aic-sample-test/videos/L21_V011.mp4'
              ]

output_dir = 'extracted_frames'
os.makedirs(output_dir, exist_ok=True)

# The lists are matched by position; a length mismatch would otherwise raise IndexError
# part-way through or silently skip the extra videos.
if len(json_files) != len(video_paths):
    raise ValueError(
        f"json_files ({len(json_files)}) and video_paths ({len(video_paths)}) must have the same length"
    )

image_paths = []

for i, (json_file, video_path) in enumerate(zip(json_files, video_paths)):
    with open(json_file, 'r') as f:
        # NOTE(review): assumed to be a JSON list of integer frame numbers — confirm the index schema.
        frame_indices = json.load(f)

    cap = cv2.VideoCapture(video_path)

    if not cap.isOpened():
        print(f"Không thể mở video: {video_path}")  # "Cannot open video"
        continue

    try:
        # tqdm shows extraction progress per video
        for frame_index in tqdm(frame_indices, desc=f"Trích xuất từ video {i+1}/{len(json_files)}"):
            # Seek to the requested frame, then decode it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = cap.read()
            if not ret:
                print(f"⚠️ Không thể đọc frame tại index {frame_index} trong video {i+1}.")  # "Cannot read frame"
                continue

            # OpenCV decodes frames as BGR; PIL expects RGB.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_image = Image.fromarray(frame_rgb)

            frame_path = os.path.join(output_dir, f'video{i+1}_frame_{frame_index}.png')
            frame_image.save(frame_path)
            image_paths.append(frame_path)
    finally:
        # Release the capture even if extraction is interrupted (error or KeyboardInterrupt).
        cap.release()

clear_output()
print(f"Đã trích xuất tổng cộng {len(image_paths)} frame.")  # total number of extracted frames
# Show only a preview: printing every path floods the notebook output.
print("Các đường dẫn frame (tối đa 5):", image_paths[:5])  # frame paths (up to 5)

In [None]:
def build_image_embeddings(image_paths, save_path="image_embeddings.pkl"):
    """Encode every image with ``get_image_features`` and collect the embeddings.

    ``get_image_features`` already L2-normalises its output, so no extra normalisation
    is applied here.

    Args:
        image_paths: Iterable of image file paths to encode.
        save_path: File to pickle the ``{image_path: embedding}`` dict to; pass ``None``
            to skip writing to disk.

    Returns:
        The ``{image_path: embedding}`` dict, with every embedding moved to CPU.
    """
    image_embeddings = {}
    for img_path in tqdm(image_paths, desc="Extracting image embeddings"):
        # .cpu() so the stored tensors (and the pickle) do not reference the GPUs.
        image_embeddings[img_path] = get_image_features(img_path).cpu()

    if save_path is not None:
        with open(save_path, "wb") as f:
            pickle.dump(image_embeddings, f)
    return image_embeddings

# Assign the result so Jupyter does not echo the whole embedding dict as the cell output.
image_embeddings = build_image_embeddings(image_paths)

In [None]:
def text_to_image_retrieval(text, embeddings_path="image_embeddings.pkl", image_embeddings=None, top_k=None):
    """Rank stored image embeddings by cosine similarity to a text query.

    Args:
        text: Query text, encoded with ``get_text_features``.
        embeddings_path: Pickle of a ``{image_path: embedding_tensor}`` dict (as written by
            ``build_image_embeddings``). Only read when ``image_embeddings`` is not given.
        image_embeddings: Optional already-loaded ``{image_path: tensor}`` dict, so repeated
            queries do not re-read the pickle from disk.
        top_k: If given, return only the ``top_k`` best matches.

    Returns:
        List of ``(image_path, similarity)`` tuples sorted by similarity, highest first;
        an empty list when there are no stored embeddings.
    """
    if image_embeddings is None:
        # pickle.load can execute arbitrary code: only load embedding files this notebook wrote.
        with open(embeddings_path, "rb") as f:
            image_embeddings = pickle.load(f)
    if not image_embeddings:
        return []

    # Compare in float32 on the CPU: stored embeddings may be float16, which loses precision
    # in the reductions and may not be supported by every CPU kernel in older PyTorch builds.
    text_features = get_text_features(text).float().cpu()
    img_paths = list(image_embeddings)
    # Flatten each stored embedding to 1-D and stack into an (N, D) matrix.
    image_matrix = torch.stack([image_embeddings[p].float().reshape(-1) for p in img_paths])

    # One broadcasted call scores every image at once: (1, D) vs (N, D) -> (N,).
    scores = cosine_similarity(text_features, image_matrix, dim=-1)
    similarities = list(zip(img_paths, scores.tolist()))

    similarities.sort(key=lambda x: x[1], reverse=True)
    return similarities if top_k is None else similarities[:top_k]

# Example: a Vietnamese query is translated to English, then used to rank the stored keyframes.
text_query = "Hái chuối"
text_query_en = translate_vi_en(text_query)
retrieved_images = text_to_image_retrieval(text_query_en)

# Print and display the five best matches.
top_matches = retrieved_images[:5]
for i, (image_path, similarity) in enumerate(top_matches):
    score_text = f"{similarity:.4f}"
    print(f"{image_path} with similarity: {score_text}")

    img = Image.open(image_path)
    fig, ax = plt.subplots(figsize=(8, 6))  # display size
    ax.imshow(np.array(img))  # convert the PIL image to an array for imshow
    ax.set_title(f"Image: {image_path}\nSimilarity: {score_text}")
    ax.axis('off')  # hide the axes
    plt.show()