In [None]:
!pip install -q transformers torch Pillow tqdm aiohttp aiofiles

import torch
from transformers import AutoTokenizer, AutoModel
from PIL import Image
import aiohttp, asyncio, aiofiles
import os, time
import numpy as np
import pandas as pd
from tqdm import tqdm 

In [None]:
# =========================
# 1️⃣ Load Model
# =========================
# Hugging Face checkpoint that provides both the text and the image encoder.
model_name = "jinaai/jina-clip-v2"

# NOTE: trust_remote_code=True runs Python code shipped inside the model
# repository, so only use it with checkpoints you trust.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
)

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Move the weights to the chosen device and switch to inference mode.
model.to(device)
model.eval()


In [None]:
# =========================
# 0️⃣ Imports & Setup
# =========================
import os, time, asyncio, aiohttp, aiofiles
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch

# Assume `model` and `device` are already defined
# model: JinaCLIPModel or similar
# device: "cuda" or "cpu"

# =========================
# 1️⃣ Load Dataset
# =========================
# Note that the frame is named `train_df` but is read from test.csv.
# Columns used below: image_link, catalog_content.
train_df = pd.read_csv("/kaggle/input/test-csv/test.csv")

# Keep at most the first 75,000 rows and renumber them from 0 so that list
# positions and row labels agree.
sample_df = train_df.head(75000).reset_index(drop=True)
print(f"Using {len(sample_df)} samples for embedding extraction.")

# Parallel lists: position i in both refers to row i of `sample_df`.
# NOTE(review): assumes neither column contains missing values - confirm.
urls = sample_df["image_link"].to_list()
texts = sample_df["catalog_content"].to_list()

# =========================
# 2️⃣ Image Cache + Placeholder
# =========================
# Directory that stores every downloaded image plus the fallback image.
os.makedirs("cache_images", exist_ok=True)

# Fallback image used whenever a download or decode fails: a 224x224
# all-black RGB picture, written to disk only if it is not there yet.
placeholder_path = "cache_images/placeholder.jpg"
if not os.path.exists(placeholder_path):
    Image.new(
        mode="RGB",
        size=(224, 224),
        color=(0, 0, 0),
    ).save(placeholder_path)

async def download_image(session, url, idx, retries=2):
    """Download one image into the local cache, or fall back to the placeholder.

    Args:
        session: Open ``aiohttp.ClientSession`` used to issue the GET request.
        url: Address of the image to fetch.
        idx: Row index of the image; the cache file is
            ``cache_images/{idx}.jpg``.
        retries: Number of additional attempts after the first one fails.

    Returns:
        The cache path when the file already exists or was downloaded with an
        HTTP 200 response, otherwise ``placeholder_path``.
    """
    cache_path = f"cache_images/{idx}.jpg"
    if os.path.exists(cache_path):
        return cache_path

    # Write to a side file and rename at the end, so an interrupted write can
    # never leave a truncated file that a later run would treat as a cache hit.
    partial_path = f"{cache_path}.part"
    # Passing a bare number as `timeout` is deprecated in aiohttp.
    timeout = aiohttp.ClientTimeout(total=10)

    for _ in range(retries + 1):
        try:
            async with session.get(url, timeout=timeout) as response:
                if response.status != 200:
                    continue
                payload = await response.read()
            async with aiofiles.open(partial_path, "wb") as f:
                await f.write(payload)
            os.replace(partial_path, cache_path)
            return cache_path
        except Exception:
            # Best effort: any network, timeout, URL or disk error just
            # triggers the next attempt.  `Exception` (unlike a bare
            # `except:`) lets task cancellation and KeyboardInterrupt
            # propagate.
            continue

    # Every attempt failed: drop a possibly half-written side file.
    try:
        os.remove(partial_path)
    except OSError:
        # Nothing was written, or it is already gone - either is fine.
        pass
    return placeholder_path

async def batch_download(urls, max_concurrency=64):
    """Download all images concurrently and return their local paths.

    Only ``max_concurrency`` downloads are in flight at once.  Starting every
    request simultaneously makes most of them wait for a pooled connection,
    and that waiting time counts against each request's total timeout, so
    large batches would mostly time out and fall back to the placeholder.

    Args:
        urls: Iterable of image URLs; position ``i`` is cached under index
            ``i``.
        max_concurrency: Maximum number of simultaneous downloads (>= 1).

    Returns:
        List of file paths in the same order as ``urls``.

    Raises:
        ValueError: If ``max_concurrency`` is smaller than 1.
    """
    if max_concurrency < 1:
        raise ValueError("max_concurrency must be at least 1")

    gate = asyncio.Semaphore(max_concurrency)

    async def fetch(session, idx, url):
        # Hold a slot for the whole download, including its retries.
        async with gate:
            return await download_image(session, url, idx)

    async with aiohttp.ClientSession() as session:
        jobs = [fetch(session, idx, url) for idx, url in enumerate(urls)]
        # gather() preserves input order, so results line up with `urls`.
        return await asyncio.gather(*jobs)

# =========================
# 3️⃣ Encoding Functions
# =========================
@torch.no_grad()
def encode_texts(texts, batch_size=128):
    """Embed ``texts`` with the global ``model`` in slices of ``batch_size``.

    Each slice is encoded on ``device`` and its result moved to the CPU; the
    per-slice tensors are then concatenated along the first dimension.

    Args:
        texts: Sequence of strings to embed.
        batch_size: Number of texts passed to the model per call.

    Returns:
        One CPU tensor with one row per input text (presumably
        ``[len(texts), text_dim]`` - confirm against the model).
    """
    starts = range(0, len(texts), batch_size)
    chunks = [
        model.encode_text(
            texts[start:start + batch_size],
            convert_to_tensor=True,
            device=device,
        ).cpu()
        for start in starts
    ]
    return torch.cat(chunks)

@torch.no_grad()
def encode_images(img_paths, batch_size=32):
    """Yield ``(indices, embeddings)`` for the images, one batch at a time.

    Args:
        img_paths: Sequence of image file paths.  ``None`` entries are
            skipped and do not appear in any yielded index list.
        batch_size: Maximum number of images encoded per model call.

    Yields:
        Tuples ``(batch_indices, batch_emb)`` where ``batch_indices`` lists
        the positions in ``img_paths`` of the images in the batch and
        ``batch_emb`` is whatever ``model.encode_image`` returns for them.
    """
    imgs, batch_indices = [], []
    last = len(img_paths) - 1
    for i, path in enumerate(img_paths):
        if path is not None:
            try:
                # The context manager closes the file even when decoding
                # fails; convert() returns an independent in-memory copy.
                with Image.open(path) as handle:
                    img = handle.convert("RGB")
            except Exception:
                # Unreadable or corrupted file: substitute the placeholder.
                # `Exception` keeps KeyboardInterrupt propagating.
                with Image.open(placeholder_path) as handle:
                    img = handle.convert("RGB")
            imgs.append(img)
            batch_indices.append(i)

        # Flush when the batch is full or the input is exhausted.
        if imgs and (len(imgs) == batch_size or i == last):
            with torch.autocast(device_type=device, dtype=torch.float16):
                batch_emb = model.encode_image(imgs)
            yield batch_indices, batch_emb
            imgs, batch_indices = [], []

def fuse_embeddings(text_emb, image_emb, method="concat"):
    """Combine text and image embeddings and L2-normalize the result.

    Args:
        text_emb: Tensor of text embeddings.
        image_emb: Tensor of image embeddings with the same leading shape.
        method: ``"concat"`` joins the two along the last dimension; any
            other value adds them element-wise (shapes must then match).

    Returns:
        The fused tensor scaled to unit L2 norm along the last dimension.
        All-zero rows stay zero instead of turning into NaN.
    """
    if method == "concat":
        fused = torch.cat([text_emb, image_emb], dim=-1)
    else:
        fused = text_emb + image_emb
    # normalize() clamps the norm to a tiny epsilon, so a zero vector does
    # not cause a 0/0 division.
    return torch.nn.functional.normalize(fused, p=2, dim=-1)

# =========================
# 4️⃣ Main Embedding Extraction
# =========================
async def main(output_path="/kaggle/working/test_70k_75k.npy"):
    """Run the full pipeline: download, encode, fuse and save embeddings.

    Reads the module-level ``urls`` and ``texts`` lists, downloads the images,
    encodes texts and images, concatenates each text embedding with its image
    embedding (L2-normalized) and writes the result as a ``.npy`` file.

    Args:
        output_path: Destination of the saved NumPy array.

    Raises:
        RuntimeError: If no image batch was produced, so there is nothing to
            save.
    """
    start_time = time.time()

    # -------------------------
    # 1️⃣ Download images
    # -------------------------
    print("📥 Downloading images asynchronously...")
    img_paths = await batch_download(urls)

    # -------------------------
    # 2️⃣ Encode text embeddings
    # -------------------------
    print("🧠 Encoding text embeddings...")
    text_embs = encode_texts(texts, batch_size=128)

    # -------------------------
    # 3️⃣ Encode images & fuse embeddings
    # -------------------------
    print("🖼 Encoding image embeddings and fusing...")

    image_batch_size = 32
    # Ceiling division so the progress bar knows how many batches to expect.
    num_batches = (len(img_paths) + image_batch_size - 1) // image_batch_size
    batches = encode_images(img_paths, batch_size=image_batch_size)

    fused_embs = None
    for batch_indices, img_embs in tqdm(batches, total=num_batches, desc="Processing images"):
        # as_tensor() accepts array or tensor output without a copy warning;
        # .cpu() keeps the result on the same device as `text_embs`.
        img_embs_tensor = torch.as_tensor(img_embs, dtype=torch.float32).cpu()

        # The output width is only known once the first image batch is seen.
        if fused_embs is None:
            fused_dim = text_embs.shape[1] + img_embs_tensor.shape[1]
            fused_embs = torch.zeros((len(texts), fused_dim), dtype=torch.float32)

        # Write each fused row back at its original position.
        fused_embs[batch_indices] = fuse_embeddings(text_embs[batch_indices], img_embs_tensor)

    if fused_embs is None:
        raise RuntimeError("No image batches were produced; nothing to save.")

    # -------------------------
    # 4️⃣ Done
    # -------------------------
    fused_embs_np = fused_embs.numpy()
    total_time = time.time() - start_time
    print(f"\n✅ Done! {len(fused_embs_np)} embeddings created.")
    print(f"⏱ Total time: {total_time:.2f}s (~{total_time/len(fused_embs_np):.3f}s per sample)")

    # -------------------------
    # 5️⃣ Save embeddings
    # -------------------------
    np.save(output_path, fused_embs_np)
    # Report the path that was actually written.
    print(f"💾 Saved fused embeddings → {output_path}")

# Run the pipeline.  Top-level `await` is only valid where an event loop is
# already running (e.g. a Jupyter/IPython cell); in a plain Python script use
# `asyncio.run(main())` instead.
await main()
