# Install Libraries

In [2]:
from google.colab import drive

# Mount Google Drive at /content/drive so checkpoints and data paths under
# /content/drive/MyDrive resolve inside the Colab VM.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [3]:
!pip install pyngrok
!pip install fastapi uvicorn
!pip install pinecone[grpc]
!pip install uvicorn torch torchvision transformers openai python_dotenv unidecode pandas pillow requests datasets python-multipart

Collecting pyngrok
  Downloading pyngrok-7.2.8-py3-none-any.whl.metadata (10 kB)
Downloading pyngrok-7.2.8-py3-none-any.whl (25 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.2.8
Collecting fastapi
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn
  Downloading uvicorn-0.34.2-py3-none-any.whl.metadata (6.5 kB)
Collecting starlette<0.47.0,>=0.40.0 (from fastapi)
  Downloading starlette-0.46.2-py3-none-any.whl.metadata (6.2 kB)
Downloading fastapi-0.115.12-py3-none-any.whl (95 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m95.2/95.2 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading uvicorn-0.34.2-py3-none-any.whl (62 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m62.5/62.5 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m

# Model Library

## Utils

In [4]:

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import models, transforms
from torch.optim import Adam
import requests
import numpy as np
import os
import json
import torch

# Global torch device used by the model classes below: GPU when PyTorch sees one, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
def cross_entropy(preds, targets, reduction='none', epsilon=1e-8):
    """Soft-target cross-entropy between logits and a target distribution.

    Args:
        preds: Logits; log-softmax is taken over the last dimension.
        targets: Target probabilities with the same shape as ``preds``.
        reduction: ``'none'`` (per-row losses), ``'mean'`` or ``'sum'``.
        epsilon: Kept only for backward compatibility and intentionally unused.
            The original added it to every logit, which cannot change a
            log-softmax (it is shift-invariant), so it never had an effect.

    Returns:
        Per-row losses for ``'none'``, otherwise a scalar tensor.

    Raises:
        ValueError: For an unknown ``reduction`` (previously ``None`` was
            returned silently, which only failed later at ``.backward()``).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Unsupported reduction: {reduction!r}")
class AvgMeter:
    """Count-weighted running average of a scalar metric (e.g. a per-batch loss)."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Forget every value accumulated so far."""
        self.avg, self.sum, self.count = 0.0, 0.0, 0

    def update(self, val, count=1):
        """Accumulate ``val`` weighted by ``count`` (typically the batch size)."""
        self.count += count
        self.sum += val * count
        # Guard against division by zero when nothing (or zero weight) was added.
        self.avg = 0.0 if self.count <= 0 else self.sum / self.count

    def __repr__(self):
        return "{}: {:.4f}".format(self.name, self.avg)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group, or ``None`` if there is none."""
    return next((group["lr"] for group in optimizer.param_groups), None)

## Image Encoder

In [6]:
import timm
import torch
import torch.nn as nn

# Registry of supported timm image backbones. Each entry records the timm
# architecture name, whether to load pretrained weights, the trainable flag,
# the pooled feature width ("img_dims") and the expected input size.
ImageConfigs = {
    registry_key: {
        "model_name": timm_name,
        "pretrained": True,
        "trainable": True,
        "img_dims": feature_dim,
        "img_size": input_size,
    }
    # (registry key, timm architecture name, pooled feature width, input H x W)
    for registry_key, timm_name, feature_dim, input_size in (
        ("vit_base_patch16_224", "vit_base_patch16_224", 768, (224, 224)),
        ("swin_large_patch4_window12_384", "swin_large_patch4_window12_384", 1536, (384, 384)),
        ("resnet50", "resnet50", 2048, (224, 224)),
        # NOTE: this registry key uses hyphens while the timm name uses underscores.
        ("swin-base-patch4-window7-224", "swin_base_patch4_window7_224", 1024, (224, 224)),
    )
}

class ImageEncoder(nn.Module):
    """timm image backbone that returns one pooled feature vector per image.

    Args:
        model_name: Key into ``ImageConfigs``.
        freeze_ratio: Fraction of the backbone's parameter *tensors* (in
            ``parameters()`` order) to freeze. This counts tensors, not layers.
            Values <= 0 freeze nothing, and values >= 1 freeze everything. It is
            ignored when the config sets ``"trainable": False``, which freezes
            the whole backbone.

    Raises:
        ValueError: If ``model_name`` is not a key of ``ImageConfigs``.
    """

    def __init__(self, model_name, freeze_ratio=0.5):
        super().__init__()
        if model_name not in ImageConfigs:
            raise ValueError(f"Model '{model_name}' is not defined in ImageConfigs.")

        self.configs = ImageConfigs[model_name]

        # num_classes=0 removes the classification head; global_pool="avg"
        # makes the model return a pooled feature vector.
        self.model = timm.create_model(
            self.configs["model_name"],
            pretrained=self.configs.get("pretrained", True),
            num_classes=0,
            global_pool="avg",
        )

        parameters = list(self.model.parameters())
        if not self.configs.get("trainable", True):
            # Honour the config's trainable flag, which was previously ignored.
            num_frozen = len(parameters)
        elif freeze_ratio > 0:
            num_frozen = int(len(parameters) * freeze_ratio)
        else:
            num_frozen = 0
        for param in parameters[:num_frozen]:
            param.requires_grad = False

    def forward(self, x):
        """Return pooled backbone features for ``x``.

        ``x`` is presumably a ``[batch, 3, H, W]`` tensor at ``configs["img_size"]``
        (TODO confirm against the training transform).
        """
        return self.model(x)


## Projection layer

In [7]:
# Default width of the shared image/text embedding space (BaselineCLIPModel's
# default projection_dim).
PROJECTION_DIM = 256
# Default dropout probability for ProjectionHead; it is bound as that class's
# default argument when the class is defined.
dropout = 0.2

class ProjectionHead(nn.Module):
    """Map encoder features into the shared embedding space.

    The input goes through a linear projection. A residual branch
    (GELU -> Linear -> Dropout) is added to that projection, and the sum is
    layer-normalised.

    Args:
        embedding_dim: Width of the incoming encoder features.
        projection_dim: Width of the shared embedding space.
        dropout: Dropout probability on the residual branch. The default is
            the module-level ``dropout`` value at class-definition time.
    """

    def __init__(self, embedding_dim, projection_dim, dropout=dropout):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

        # Xavier-uniform weights and zero biases for both linear layers.
        for linear in (self.projection, self.fc):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        projected = self.projection(x)
        residual = self.dropout(self.fc(self.gelu(projected)))
        return self.layer_norm(residual + projected)

## Text Encoder

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import openai
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from google.colab import userdata
import warnings

# Credentials read from Colab's secret store (google.colab.userdata). They are
# used as TextEncoder's default openai_api_key / hf_access_token arguments.
OPENAI_API_KEY = userdata.get("OPENAI_API_KEY")
HF_ACCESS_TOKEN = userdata.get("HF_TOKEN")

class SimilarVietEcom(nn.Module):
    """SentenceTransformer backbone plus a linear head, producing L2-normalised embeddings.

    Args:
        base_model_name: Hugging Face id of the SentenceTransformer backbone.
        output_dim: Output width of the linear head.
        max_seq_length: Token limit assigned to the backbone.
    """

    def __init__(self, base_model_name="dangvantuan/vietnamese-document-embedding", output_dim=512, max_seq_length=1024):
        super().__init__()
        backbone = SentenceTransformer(base_model_name, trust_remote_code=True)
        backbone.max_seq_length = max_seq_length
        self.base_model = backbone
        self.embedding_dim = backbone.get_sentence_embedding_dimension()

        head = nn.Linear(self.embedding_dim, output_dim)
        # Xavier weights and a zero bias keep the initial outputs small.
        nn.init.xavier_uniform_(head.weight)
        nn.init.zeros_(head.bias)
        self.linear = head

    def forward(self, input_ids, attention_mask, **kwargs):
        """Embed tokenised text. Extra keyword arguments are accepted and ignored."""
        features = {'input_ids': input_ids, 'attention_mask': attention_mask}
        pooled = self.base_model(features)['sentence_embedding']
        return nn.functional.normalize(self.linear(pooled), p=2, dim=-1)

class TextEncoder(nn.Module):
    """Frozen sentence-embedding backend followed by a trainable linear layer.

    The backend can be the OpenAI API, a SentenceTransformer, or
    ``SimilarVietEcom``. It always runs without gradients, so only
    ``self.linear`` can learn. ``forward`` returns L2-normalised
    ``[batch, embedding_dims]`` tensors on ``self.device``.

    Args:
        backend: ``"openai"``, ``"viet-ecom"``, or a key of
            ``_SENTENCE_TRANSFORMER_BACKENDS``.
        task_prefix: If truthy, each text is sent as ``f"{task_prefix}: {text}"``.
        trainable: Whether the linear layer receives gradients.
        hf_access_token: Hugging Face token. Some backends require it.
        openai_api_key: Required for the ``"openai"`` backend.
        model_path: Optional checkpoint for the ``vietnamese-document-embedding``
            and ``viet-ecom`` backends. NOTE: it is loaded with ``torch.load``
            without ``weights_only``, which unpickles arbitrary objects, so only
            load trusted files.
        viet_ecom_output_dim: Output width of ``SimilarVietEcom``.
        viet_ecom_max_seq_length: Token limit of ``SimilarVietEcom``.
        freeze_ratio: Accepted for API compatibility but unused, because the
            backend is always fully frozen.

    Raises:
        ValueError: For an unknown backend or missing credentials.
    """

    # Maps backend -> (Hugging Face repo id, embedding width, error message when
    # the HF token is missing). The message is None when the token is optional.
    _SENTENCE_TRANSFORMER_BACKENDS = {
        "bge": ("BAAI/bge-large-en-v1.5", 1024, "Hugging Face token is required for BGE."),
        "bge-m3": ("BAAI/bge-m3", 1024, "Hugging Face token is required for BGE."),
        "sup-SimCSE-VietNamese-phobert-base": (
            "VoVanPhuc/sup-SimCSE-VietNamese-phobert-base", 768, "Hugging Face token is required."),
        "nomic": ("nomic-ai/nomic-embed-text-v2-moe", 768, None),
        "aligte": ("Alibaba-NLP/gte-multilingual-base", 768, None),
        "labse": ("sentence-transformers/LaBSE", 768, None),
        "vietnamese-document-embedding": (
            "dangvantuan/vietnamese-document-embedding", 768,
            "Hugging Face token is required for vietnamese-embedding."),
    }

    def __init__(
        self,
        backend="bge-m3",
        task_prefix="search_document",
        trainable=True,
        hf_access_token=HF_ACCESS_TOKEN,
        openai_api_key=OPENAI_API_KEY,
        model_path=None,
        viet_ecom_output_dim=512,
        viet_ecom_max_seq_length=1024,
        freeze_ratio=0.5,
    ):
        super().__init__()
        self.backend = backend
        self.task_prefix = task_prefix
        self.trainable = trainable
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Character cap used by get_hf_embedding; None means "no cap". Defining it
        # for every backend fixes the AttributeError that bge/bge-m3/nomic/...
        # previously hit on their first forward pass.
        self.max_seq_length = None

        if backend == "openai":
            if not openai_api_key:
                raise ValueError("OpenAI API key is required.")
            self.api_key = openai_api_key
            self.model_name = "text-embedding-ada-002"
            self.embedding_dims = 1536
        elif backend == "viet-ecom":
            if not hf_access_token:
                raise ValueError("Hugging Face token is required for viet-ecom.")
            self.model = SimilarVietEcom(
                base_model_name="dangvantuan/vietnamese-document-embedding",
                output_dim=viet_ecom_output_dim,
                max_seq_length=viet_ecom_max_seq_length,
            )
            if model_path:
                self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            self.tokenizer = AutoTokenizer.from_pretrained("dangvantuan/vietnamese-document-embedding")
            self.model_name = "viet-ecom"
            self.embedding_dims = viet_ecom_output_dim
            self.max_seq_length = viet_ecom_max_seq_length
        elif backend in self._SENTENCE_TRANSFORMER_BACKENDS:
            self._load_sentence_transformer(backend, hf_access_token, model_path)
        else:
            raise ValueError(f"Invalid backend: {backend}")

        # Trainable projection on top of the frozen backend output.
        self.linear = nn.Linear(self.embedding_dims, self.embedding_dims)
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        # The backend is always frozen. The "openai" backend has no local model.
        if hasattr(self, "model"):
            for param in self.model.parameters():
                param.requires_grad = False

        self.linear.requires_grad_(trainable)
        self.to(self.device)

    def _load_sentence_transformer(self, backend, hf_access_token, model_path):
        """Load one of the table-driven SentenceTransformer backends into ``self.model``."""
        repo_id, dims, missing_token_message = self._SENTENCE_TRANSFORMER_BACKENDS[backend]
        if missing_token_message is not None and not hf_access_token:
            raise ValueError(missing_token_message)
        self.model = SentenceTransformer(
            repo_id,
            trust_remote_code=True,
            use_auth_token=hf_access_token,
        )
        if backend == "vietnamese-document-embedding":
            # Match the model config's sequence length.
            self.model.max_seq_length = 512
            if model_path:
                self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            self.max_seq_length = 512
        self.model_name = backend
        self.embedding_dims = dims

    def get_openai_embedding(self, texts):
        """Embed ``texts`` with the OpenAI embeddings API; returns float32 ``[len(texts), dim]`` on ``self.device``."""
        import openai
        openai.api_key = self.api_key
        response = openai.embeddings.create(input=texts, model=self.model_name)
        # Stack once. Building a tensor from a list of ndarrays is a slow path in torch.
        embeddings = np.stack([np.asarray(item.embedding, dtype=np.float32) for item in response.data])
        return torch.from_numpy(embeddings).to(self.device)

    def get_hf_embedding(self, texts):
        """Embed ``texts`` with the SentenceTransformer backend, without gradients.

        Non-string or blank entries are replaced with ``"placeholder"`` and a
        warning. When ``self.max_seq_length`` is set, texts are cut to that
        many characters. This is only a coarse guard; the SentenceTransformer
        tokenizer still applies its own token limit.

        Raises:
            ValueError: If the backend produced NaN or Inf values.
        """
        limit = self.max_seq_length
        cleaned_texts = []
        for text in texts:
            if not isinstance(text, str) or not text.strip():
                warnings.warn(f"Invalid text input: {text}. Replacing with placeholder.")
                text = "placeholder"
            cleaned_texts.append(text if limit is None else text[:limit])

        with torch.no_grad():
            embeddings = self.model.encode(cleaned_texts)
        embeddings = torch.tensor(embeddings, dtype=torch.float32).to(self.device)

        if torch.isnan(embeddings).any() or torch.isinf(embeddings).any():
            print(f"NaN/Inf detected in embeddings: {embeddings}")
            raise ValueError("Invalid embeddings produced")

        return embeddings

    def get_viet_ecom_embedding(self, texts):
        """Tokenise ``texts`` and embed them with ``SimilarVietEcom``, without gradients; returns float32."""
        encodings = self.tokenizer(
            texts,
            max_length=self.model.base_model.max_seq_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        input_ids = encodings['input_ids'].to(self.device)
        attention_mask = encodings['attention_mask'].to(self.device)

        with torch.no_grad():
            # Enable CUDA autocast only on a CUDA device. On a CPU-only host,
            # requesting it just warns and disables itself.
            with torch.amp.autocast('cuda', enabled=self.device == "cuda"):
                embeddings = self.model(input_ids, attention_mask)
        # Autocast may produce float16, but the float32 linear layer in forward() needs float32.
        return embeddings.float()

    def forward(self, texts):
        """Embed one string or a list/tuple of strings; returns L2-normalised ``[batch, embedding_dims]``."""
        # Accept tuples too (e.g. collated captions). Previously a tuple was
        # wrapped as one item and stringified.
        texts = list(texts) if isinstance(texts, (list, tuple)) else [texts]
        if self.task_prefix:
            texts = [f"{self.task_prefix}: {text}" for text in texts]

        # Backend embeddings are computed without gradients.
        if self.backend == "openai":
            embeddings = self.get_openai_embedding(texts)
        elif self.backend == "viet-ecom":
            embeddings = self.get_viet_ecom_embedding(texts)
        else:
            embeddings = self.get_hf_embedding(texts)

        # The only part that can learn (when trainable=True).
        embeddings = self.linear(embeddings)
        return F.normalize(embeddings, p=2, dim=-1)

## Baseline CLIP

In [9]:
class BaselineCLIPModel(nn.Module):
    """CLIP-style dual encoder trained with a soft-target contrastive loss.

    Args:
        image_encoder_name: Key into ``ImageConfigs``.
        text_encoder_name: ``TextEncoder`` backend name.
        temperature: Logit temperature, also used to soften the targets.
        projection_dim: Width of the shared embedding space.
    """

    def __init__(
        self,
        image_encoder_name,
        text_encoder_name,
        temperature,
        projection_dim=PROJECTION_DIM,
    ):
        super().__init__()
        self.image_encoder_name = image_encoder_name
        self.text_encoder_name = text_encoder_name
        self.projection_dim = projection_dim
        self.temperature = temperature
        self._build_components()

    def _build_components(self):
        """(Re)create the encoders and projection heads from the stored settings."""
        self.image_encoder = ImageEncoder(self.image_encoder_name).to(device)
        self.text_encoder = TextEncoder(backend=self.text_encoder_name).to(device)
        self.image_projection = ProjectionHead(
            embedding_dim=self.image_encoder.configs["img_dims"],
            projection_dim=self.projection_dim,
        ).to(device)
        self.text_projection = ProjectionHead(
            embedding_dim=self.text_encoder.embedding_dims,
            projection_dim=self.projection_dim,
        ).to(device)

        if self.text_encoder_name == "nomic":
            # Unfreeze the Nomic backbone.
            # NOTE(review): TextEncoder runs its SentenceTransformer backend under
            # torch.no_grad(), so this likely has no training effect -- confirm.
            for param in self.text_encoder.model.parameters():
                param.requires_grad = True

    def reinitialize(self):
        """Rebuild every component from scratch, even after training.

        This uses the same construction path as ``__init__``, so the nomic
        unfreeze is re-applied (the previous version skipped it).
        """
        self._build_components()

    def forward(self, batch):
        """Return the mean symmetric contrastive loss for ``batch``.

        ``batch`` must provide ``"image"`` (input for the image encoder) and
        ``"caption"`` (text(s) for the text encoder).
        """
        image_features = F.normalize(self.image_encoder(batch["image"]), p=2, dim=-1)
        text_features = F.normalize(self.text_encoder(batch["caption"]), p=2, dim=-1)

        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        # Soft targets come from the averaged intra-modal similarities. This
        # presumably avoids pushing apart near-duplicate pairs in a batch.
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / (2 * self.temperature), dim=-1
        )

        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')

        loss = (images_loss + texts_loss) / 2.0
        return loss.mean()

    def __str__(self):
        return str(self.image_encoder.configs.get('model_name', 'unknown')) + "_" + str(self.text_encoder.backend) + "_Open"


## Attention CLIP

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionCLIPModel(nn.Module):
    """CLIP-style model that mixes image and text embeddings with a shared cross-attention layer.

    Args:
        image_encoder_name: Key into ``ImageConfigs``.
        text_encoder_name: ``TextEncoder`` backend name.
        temperature: Logit temperature, also used to soften the targets.
        alpha: Weight of the image-side loss; ``1 - alpha`` weights the text
            side. 0.5 gives equal weighting; 0.6-0.8 favours the image side.
        projection_dim: Width of the shared embedding space.
        num_heads: Number of cross-attention heads; must divide ``projection_dim``.
        dropout: Dropout probability inside the cross-attention layer.

    NOTE(review): each query attends to exactly one key, the paired item of the
    other modality. Outside attention dropout, the attention weight is therefore
    always 1 and the output is a linear function of the partner embedding. Image
    i's embedding thus depends on caption i and vice versa, which leaks the
    pairing into the logits. It also cannot be reproduced at retrieval time,
    when only one modality is available. Confirm this is intended.
    """

    def __init__(
        self,
        image_encoder_name,
        text_encoder_name,
        temperature,
        alpha=0.5,
        projection_dim=512,
        num_heads=8,
        dropout=0.1,
    ):
        super().__init__()
        self.image_encoder_name = image_encoder_name
        self.text_encoder_name = text_encoder_name
        self.projection_dim = projection_dim
        self.temperature = temperature
        self.alpha = alpha
        # Stored so reinitialize() rebuilds the same attention layer. It
        # previously hard-coded num_heads=8 and dropout=0.1.
        self.num_heads = num_heads
        self.attention_dropout = dropout
        self._build_components()

    def _build_components(self):
        """(Re)create the encoders, projection heads, cross-attention and LayerNorm from stored settings."""
        self.image_encoder = ImageEncoder(self.image_encoder_name).to(device)
        self.text_encoder = TextEncoder(backend=self.text_encoder_name).to(device)
        self.image_projection = ProjectionHead(
            embedding_dim=self.image_encoder.configs["img_dims"],
            projection_dim=self.projection_dim,
        ).to(device)
        self.text_projection = ProjectionHead(
            embedding_dim=self.text_encoder.embedding_dims,
            projection_dim=self.projection_dim,
        ).to(device)

        # Cross-attention shared by both directions (image->text and text->image).
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.projection_dim,
            num_heads=self.num_heads,
            dropout=self.attention_dropout,
            batch_first=True,
        ).to(device)
        self.norm = nn.LayerNorm(self.projection_dim).to(device)

        if self.text_encoder_name == "nomic":
            # Make the Nomic text backbone trainable.
            # NOTE(review): TextEncoder runs its SentenceTransformer backend under
            # torch.no_grad(), so this likely has no training effect -- confirm.
            for param in self.text_encoder.model.parameters():
                param.requires_grad = True

    def show_configs(self):
        """Print the model's hyper-parameters."""
        print(f"Image Encoder: {self.image_encoder_name}")
        print(f"Text Encoder: {self.text_encoder_name}")
        print(f"Temperature: {self.temperature}")
        print(f"Alpha: {self.alpha}")
        print(f"Projection Dimension: {self.projection_dim}")
        print(f"Cross-Attention Heads: {self.cross_attention.num_heads}")
        print(f"Cross-Attention Dropout: {self.cross_attention.dropout}")

    def reinitialize(self):
        """Rebuild every component from scratch with the constructor's settings, even after training."""
        self._build_components()

    def forward(self, batch):
        """Return the alpha-weighted contrastive loss for ``batch``.

        ``batch`` must provide ``"image"`` and ``"caption"`` for the encoders.
        """
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(batch["caption"])

        # Project once. The raw projections are also the residual below.
        # Previously each head ran twice, drawing a second dropout mask in training.
        image_projected = self.image_projection(image_features)  # [batch, projection_dim]
        text_projected = self.text_projection(text_features)

        # Length-1 token sequences: [batch, 1, projection_dim].
        image_tokens = F.normalize(image_projected, p=2, dim=-1).unsqueeze(1)
        text_tokens = F.normalize(text_projected, p=2, dim=-1).unsqueeze(1)

        image_attn, _ = self.cross_attention(query=image_tokens, key=text_tokens, value=text_tokens)
        text_attn, _ = self.cross_attention(query=text_tokens, key=image_tokens, value=image_tokens)

        image_embeddings = F.normalize(self.norm(image_attn.squeeze(1) + image_projected), p=2, dim=-1)
        text_embeddings = F.normalize(self.norm(text_attn.squeeze(1) + text_projected), p=2, dim=-1)

        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / (2 * self.temperature),
            dim=-1,
        )

        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')

        loss = images_loss * self.alpha + texts_loss * (1 - self.alpha)
        return loss.mean()

    def __str__(self):
        return f"{self.image_encoder.configs.get('model_name', 'unknown')}_{self.text_encoder.backend}_CrossAttention"

def cross_entropy(preds, targets, reduction='none', epsilon=1e-8):
    """Soft-target cross-entropy between logits and a target distribution.

    When this cell runs, this definition replaces the ``cross_entropy`` from
    the Utils cell. It keeps that function's signature (including
    ``epsilon``) so callers written against either definition keep working.

    Args:
        preds: Logits; log-softmax is taken over the last dimension.
        targets: Target probabilities with the same shape as ``preds``.
        reduction: ``'none'`` (per-row losses), ``'mean'`` or ``'sum'``.
        epsilon: Accepted for compatibility and intentionally unused. Adding a
            constant to every logit cannot change a log-softmax.

    Returns:
        Per-row losses for ``'none'``, otherwise a scalar tensor.

    Raises:
        ValueError: For an unknown ``reduction`` (previously ``None`` was returned silently).
    """
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(dim=-1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Unsupported reduction: {reduction!r}")

## Triplet Attention CLIP

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionCLIPTripletModel(nn.Module):
    """Cross-attention CLIP variant trained with a bidirectional in-batch triplet loss.

    Args:
        image_encoder_name: Key into ``ImageConfigs``.
        text_encoder_name: ``TextEncoder`` backend name.
        margin: Triplet-loss margin on cosine distances.
        alpha: Weight of the image-anchor loss; ``1 - alpha`` weights the
            text-anchor loss.
        projection_dim: Width of the shared embedding space.
        num_heads: Number of cross-attention heads; must divide ``projection_dim``.
        dropout: Dropout probability inside the cross-attention layer.

    NOTE(review): as in AttentionCLIPModel, each modality attends only to its
    own paired item, so item i's embeddings depend on both halves of pair i.
    Confirm that this matches how retrieval embeddings are produced.
    """

    def __init__(
        self,
        image_encoder_name,
        text_encoder_name,
        margin=0.2,
        alpha=0.5,
        projection_dim=512,
        num_heads=8,
        dropout=0.1,
    ):
        super().__init__()
        self.image_encoder_name = image_encoder_name
        self.text_encoder_name = text_encoder_name
        self.projection_dim = projection_dim
        self.margin = margin
        self.alpha = alpha
        self.num_heads = num_heads
        # Stored so reinitialize() reuses it. It previously picked up the
        # unrelated module-level `dropout` (0.2) instead.
        self.attention_dropout = dropout
        self._build_components()

    def _build_components(self):
        """(Re)create the encoders, projection heads, cross-attention and LayerNorm from stored settings."""
        self.image_encoder = ImageEncoder(self.image_encoder_name).to(device)
        self.text_encoder = TextEncoder(backend=self.text_encoder_name).to(device)
        self.image_projection = ProjectionHead(
            embedding_dim=self.image_encoder.configs["img_dims"],
            projection_dim=self.projection_dim,
        ).to(device)
        self.text_projection = ProjectionHead(
            embedding_dim=self.text_encoder.embedding_dims,
            projection_dim=self.projection_dim,
        ).to(device)

        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.projection_dim,
            num_heads=self.num_heads,
            dropout=self.attention_dropout,
            batch_first=True,
        ).to(device)
        self.norm = nn.LayerNorm(self.projection_dim).to(device)

        if self.text_encoder_name == "nomic":
            # NOTE(review): TextEncoder runs its SentenceTransformer backend under
            # torch.no_grad(), so this likely has no training effect -- confirm.
            for param in self.text_encoder.model.parameters():
                param.requires_grad = True

    def show_configs(self):
        """Print the model's hyper-parameters."""
        print(f"Image Encoder: {self.image_encoder_name}")
        print(f"Text Encoder: {self.text_encoder_name}")
        print(f"Alpha: {self.alpha}")
        print(f"Projection Dimension: {self.projection_dim}")
        print(f"Cross-Attention Heads: {self.cross_attention.num_heads}")
        print(f"Cross-Attention Dropout: {self.cross_attention.dropout}")
        print(f"Margin: {self.margin}")

    def reinitialize(self):
        """Rebuild every component from scratch with the constructor's settings, even after training."""
        self._build_components()

    @staticmethod
    def _triplet_loss(distance, margin):
        """Mean hinge loss over all (anchor i, negative j != i) pairs of a square distance matrix.

        Rows are anchors and ``distance[i, i]`` is the positive. Each anchor is
        compared only with its own B-1 negatives. (The previous version
        broadcast every positive against all anchors' negatives.)
        """
        batch_size = distance.size(0)
        if batch_size < 2:
            # There are no in-batch negatives. Return a zero that stays attached
            # to the graph instead of the NaN a mean over nothing would give.
            return distance.sum() * 0.0
        positive = distance.diagonal().unsqueeze(1)  # [B, 1]
        off_diagonal = ~torch.eye(batch_size, dtype=torch.bool, device=distance.device)
        negative = distance[off_diagonal].view(batch_size, batch_size - 1)  # row i: negatives of anchor i
        return F.relu(positive - negative + margin).mean()

    def forward(self, batch):
        """Return ``(loss, image_triplet_loss, text_triplet_loss)`` for ``batch``.

        ``batch`` must provide ``"image"`` and ``"caption"`` for the encoders.
        Distances are cosine distances (``1 - cosine similarity``) between the
        final unit-norm embeddings. Every other item in the batch serves as a
        negative.
        """
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(batch["caption"])

        # Project once. The raw projections are also the residual below.
        image_projected = self.image_projection(image_features)
        text_projected = self.text_projection(text_features)

        # Length-1 token sequences: [batch, 1, projection_dim].
        image_tokens = F.normalize(image_projected, p=2, dim=-1).unsqueeze(1)
        text_tokens = F.normalize(text_projected, p=2, dim=-1).unsqueeze(1)

        image_attn, _ = self.cross_attention(query=image_tokens, key=text_tokens, value=text_tokens)
        text_attn, _ = self.cross_attention(query=text_tokens, key=image_tokens, value=image_tokens)

        image_embeddings = F.normalize(self.norm(image_attn.squeeze(1) + image_projected), p=2, dim=-1)
        text_embeddings = F.normalize(self.norm(text_attn.squeeze(1) + text_projected), p=2, dim=-1)

        # distance[i, j] = cosine distance between image i and text j.
        distance = 1 - image_embeddings @ text_embeddings.T
        image_triplet_loss = self._triplet_loss(distance, self.margin)
        # Transposed: rows become text anchors and columns candidate images.
        text_triplet_loss = self._triplet_loss(distance.T, self.margin)

        loss = self.alpha * image_triplet_loss + (1 - self.alpha) * text_triplet_loss
        return loss, image_triplet_loss, text_triplet_loss

    def __str__(self):
        return f"{self.image_encoder.configs.get('model_name', 'unknown')}_{self.text_encoder.backend}_CrossAttention_Triplet"


## Search Module

In [12]:
import torch
import os
import numpy as np
import random
from pathlib import Path
import textwrap
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
from pinecone import Pinecone
from sklearn.metrics import average_precision_score
from scipy.stats import rankdata
import requests
from PIL import Image
from io import BytesIO
import pandas as pd

from google.colab import userdata
# Pinecone API key from Colab's secret store; CLIPSearchModule's default pinecone_api_key.
PINECONE_API_KEY=userdata.get('PINECONE_API_KEY')
# NOTE(review): absolute Google Drive path, which requires the Drive mount from
# the first cell. It presumably holds cropped product images -- confirm.
cropped_dir="/content/drive/MyDrive/Training Drive/CLIPv8/cropped_dir"

class CLIPSearchModule:
    """Retrieve products from a Pinecone index with a trained CLIP-style model.

    The wrapped model must expose ``text_encoder``/``text_projection`` and
    ``image_encoder``/``image_projection``. Query embeddings are the projected
    features, L2-normalised. Vectors are looked up in a Pinecone namespace built
    from ``namespace_type`` and the checkpoint's parent directory name.

    The class also offers offline evaluation (``evaluate``, ``eval_test``) and
    matplotlib visualisations of search results.
    """

    # Pinecone metadata filters for each query direction: text queries search
    # vectors tagged type "1", image queries vectors tagged type "0".
    # NOTE(review): confirm which modality each tag denotes in the index.
    TEXT_QUERY_FILTER = {"type": {"$eq": "1"}}
    IMAGE_QUERY_FILTER = {"type": {"$eq": "0"}}

    def __init__(self, model, model_path, pinecone_api_key=PINECONE_API_KEY, index_name='clipv8-mobile',
                 namespace_type='yolo-clip', device="cuda" if torch.cuda.is_available() else "cpu"):
        """Load weights, switch the model to eval mode and connect to Pinecone.

        Args:
            model: CLIP-style ``nn.Module`` (see class docstring).
            model_path: Checkpoint (``state_dict``) path. Its parent directory
                name becomes part of the Pinecone namespace.
            pinecone_api_key: Pinecone API key; defaults to the Colab secret.
            index_name: Pinecone index to query.
            namespace_type: Prefix of the namespace.
            device: Torch device used for inference.
        """
        self.device = device
        self.model = model.to(self.device)
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered checkpoint cannot run arbitrary code. This requires a
        # torch version that supports the argument.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # ".../<run_dir>/best.pt" -> "<namespace_type>-<run_dir>", with "_" -> "-", lower-cased.
        run_dir = str(model_path).split('/')[-2]
        self.namespace = f"{namespace_type}-{run_dir}".replace('_', '-').lower()
        print("namespace:", self.namespace)
        self.index_name = index_name
        self.pinecone = Pinecone(api_key=pinecone_api_key)
        self.index = self.pinecone.Index(index_name)

        # 224x224 resize followed by the OpenAI CLIP normalisation statistics.
        # NOTE(review): confirm this resolution matches the image encoder's input size.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])
        ])

    # ------------------------------------------------------------------ #
    # Embedding
    # ------------------------------------------------------------------ #
    def generate_text_embedding(self, text):
        """Embed ``text``; return the L2-normalised projection as a flat list of floats."""
        with torch.no_grad():
            features = self.model.text_encoder(text)
            projected = self.model.text_projection(features)
            return torch.nn.functional.normalize(projected, dim=-1).cpu().numpy().flatten().tolist()

    def generate_image_embedding(self, image_path, cropped_dir=None):
        """Embed one image; return the L2-normalised projection as a list of floats.

        Args:
            image_path: One of the following:
                - an http(s) URL (downloaded with a 10 s timeout);
                - a ``BytesIO`` holding encoded image bytes;
                - a local ``str``/``Path``. If that path is not a file, it is
                  retried relative to ``cropped_dir`` when one is given.
            cropped_dir: Optional base directory for relative local paths.

        Returns:
            ``list[float]`` on success. On failure, ``{"error": <message>}``;
            errors are reported in the return value, not raised.
        """
        image = None

        if isinstance(image_path, str) and image_path.startswith(("http://", "https://")):
            # NOTE(review): fetching caller-supplied URLs from a public API is an
            # SSRF risk; restrict hosts if this is reachable from untrusted clients.
            try:
                response = requests.get(image_path, timeout=10)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert("RGB")
            except Exception as e:
                return {"error": f"Failed to download image from URL: {str(e)}"}

        elif isinstance(image_path, BytesIO):
            try:
                image = Image.open(image_path).convert("RGB")
            except Exception as e:
                return {"error": f"Failed to read image from BytesIO: {str(e)}"}

        elif isinstance(image_path, (str, Path)):
            image_path = Path(image_path)
            if not image_path.is_file():
                if not cropped_dir:
                    return {"error": f"Image not found at {image_path}"}
                image_path = Path(cropped_dir) / image_path
                if not image_path.is_file():
                    return {"error": f"Image not found at {image_path}"}
            try:
                image = Image.open(image_path).convert("RGB")
            except Exception as e:
                return {"error": f"Failed to open image file: {str(e)}"}

        else:
            return {"error": f"Unsupported image input type: {type(image_path)}"}

        try:
            image_tensor = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model.image_encoder(image_tensor)
                projected = self.model.image_projection(features)
                return torch.nn.functional.normalize(projected, dim=-1).cpu().numpy().flatten().tolist()
        except Exception as e:
            return {"error": f"Failed during embedding generation: {str(e)}"}

    # ------------------------------------------------------------------ #
    # Search
    # ------------------------------------------------------------------ #
    def search_pinecone(self, query_embedding, top_k=5, filter=None, include_values=True, include_metadata=True):
        """Query this module's namespace and return the list of matches ([] if none).

        ``filter`` is a Pinecone metadata filter; ``None`` means no filtering.
        """
        search_results = self.index.query(
            vector=query_embedding, top_k=top_k, namespace=self.namespace,
            filter=filter if filter is not None else {},
            include_values=include_values, include_metadata=include_metadata,
        )
        return search_results.get("matches", [])

    def search_by_text(self, text_query, top_k=10, include_values=True, include_metadata=True, filter=TEXT_QUERY_FILTER):
        """Embed ``text_query`` and search the index.

        Returns:
            ``{"query", "query_embedding", "top_k_results"}``.
        """
        text_embedding = self.generate_text_embedding(text_query)
        return {
            "query": text_query,
            "query_embedding": text_embedding,
            "top_k_results": self.search_pinecone(query_embedding=text_embedding, top_k=top_k, filter=filter,
                                                  include_values=include_values, include_metadata=include_metadata)
        }

    def search_by_image(self, image_path, top_k=5, include_values=True, include_metadata=True, filter=IMAGE_QUERY_FILTER):
        """Embed an image (URL, BytesIO or local path) and search the index.

        Returns:
            ``{"query", "query_embedding", "top_k_results"}``.

        Raises:
            ValueError: if the image could not be loaded or embedded. The message
                is the one produced by ``generate_image_embedding``.
        """
        image_embedding = self.generate_image_embedding(image_path)
        if isinstance(image_embedding, dict):
            # Previously this error dict was sent to Pinecone as if it were a vector.
            raise ValueError(image_embedding.get("error", "Failed to embed image"))
        return {
            "query": image_path,
            "query_embedding": image_embedding,
            "top_k_results": self.search_pinecone(query_embedding=image_embedding, top_k=top_k, filter=filter,
                                                  include_values=include_values, include_metadata=include_metadata)
        }

    # ------------------------------------------------------------------ #
    # Evaluation helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _is_missing(value):
        """True for None and scalar NaN/NA values (pandas' placeholders for empty cells)."""
        if value is None:
            return True
        if isinstance(value, (list, tuple, dict, set, np.ndarray)):
            return False
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            return False

    @classmethod
    def _normalize_id(cls, value):
        """Canonical string form of a product id, or None when missing/empty.

        Pinecone stores numeric metadata as 64-bit floats, so an id indexed as 123
        can come back as 123.0. Integral floats map to "123", so they compare
        equal to DataFrame ids.
        """
        if cls._is_missing(value):
            return None
        if isinstance(value, (float, np.floating)) and float(value).is_integer():
            return str(int(value))
        text = str(value).strip()
        return text or None

    def _extract_matches(self, results):
        """Split Pinecone matches into aligned lists, dropping matches without a product id.

        Returns:
            ``(ids, scores, metadata, values)``, with one entry per kept match.
        """
        ids, scores, metadata, values = [], [], [], []
        for match in results:
            meta = match.get("metadata") or {}
            pid = self._normalize_id(meta.get("product_id"))
            if pid is None:
                continue
            ids.append(pid)
            scores.append(match.get("score", 0))
            metadata.append(meta)
            values.append(match.get("values"))
        return ids, scores, metadata, values

    def _same_product_index(self, df):
        """Map normalized id -> set of equivalent ids from ``same_product_ids`` (first row per id wins)."""
        index = {}
        if "same_product_ids" not in df.columns:
            return index
        for pid, same_ids in zip(df["id"], df["same_product_ids"]):
            key = self._normalize_id(pid)
            if key is None or key in index:
                continue
            if isinstance(same_ids, list):
                index[key] = {self._normalize_id(s) for s in same_ids} - {None}
            else:
                index[key] = set()
        return index

    @staticmethod
    def _new_metric_store(top_k_list):
        """Empty per-direction accumulator: top-k hits, mean similarity, RR, AP, nDCG."""
        return {
            "acc": {k: [] for k in top_k_list},
            "sim": {k: [] for k in top_k_list},
            "mrr": [],
            "map": [],
            "ndcg": [],
        }

    def _ndcg_score(self, relevance):
        """Normalized Discounted Cumulative Gain of a ranked relevance vector.

        The ideal DCG uses the same retrieved list sorted by relevance. The score
        therefore measures ordering quality among retrieved items only. Returns 0
        when nothing retrieved is relevant.
        """
        relevance = np.asarray(relevance, dtype=float)
        discounts = np.log2(np.arange(2, relevance.size + 2))
        dcg = np.sum((2 ** relevance - 1) / discounts)
        ideal_dcg = np.sum((2 ** np.sort(relevance)[::-1] - 1) / discounts)
        return dcg / ideal_dcg if ideal_dcg > 0 else 0

    def _ranking_metrics(self, relevance, scores):
        """(reciprocal rank, average precision, nDCG) for one ranked list.

        All three are 0 when no retrieved item is relevant. Failed queries
        therefore lower the averages instead of being silently excluded.
        """
        relevance = np.asarray(relevance)
        if relevance.sum() == 0:
            return 0.0, 0.0, 0.0
        reciprocal_rank = 1 / (np.argmax(relevance) + 1)
        average_precision = average_precision_score(relevance, scores)
        return reciprocal_rank, average_precision, self._ndcg_score(relevance)

    def _record_ranking(self, store, relevance, scores, top_k_list, embeddings=None, query_embedding=None):
        """Append one answered query's metrics to ``store``.

        The mean query/result similarity is recorded for cut-off ``k`` only when
        at least ``k`` result vectors are available.
        """
        reciprocal_rank, average_precision, ndcg = self._ranking_metrics(relevance, scores)
        store["mrr"].append(reciprocal_rank)
        store["map"].append(average_precision)
        store["ndcg"].append(ndcg)
        for k in top_k_list:
            store["acc"][k].append(bool(np.sum(relevance[:k]) > 0))
            if embeddings is not None and len(embeddings) >= k:
                top_vectors = embeddings[:k]
                if all(v is not None and len(v) > 0 for v in top_vectors):
                    sims = np.dot(np.asarray(top_vectors, dtype=float), np.asarray(query_embedding, dtype=float))
                    store["sim"][k].append(float(np.mean(sims)))

    @staticmethod
    def _record_miss(store, top_k_list):
        """Record a query that produced no usable results as a complete failure."""
        for key in ("mrr", "map", "ndcg"):
            store[key].append(0)
        for k in top_k_list:
            store["acc"][k].append(False)

    @staticmethod
    def _summarize(prefix, store, top_k_list, include_similarity):
        """Average a metric store into ``<prefix>_*`` result keys (0 for empty lists)."""
        def mean_or_zero(values):
            return np.mean(values) if values else 0

        summary = {f"{prefix}_top_k_accuracy": {k: mean_or_zero(store["acc"][k]) for k in top_k_list}}
        if include_similarity:
            summary[f"{prefix}_similarity"] = {k: mean_or_zero(store["sim"][k]) for k in top_k_list}
        summary[f"{prefix}_mrr"] = mean_or_zero(store["mrr"])
        summary[f"{prefix}_map"] = mean_or_zero(store["map"])
        summary[f"{prefix}_ndcg"] = mean_or_zero(store["ndcg"])
        return summary

    @staticmethod
    def _caption_matches(caption_lower, meta, category_levels):
        """True if the lower-cased caption occurs inside a category level or the product name."""
        fields = [meta.get(lvl, "") for lvl in category_levels] + [meta.get("product_name", "")]
        for field in fields:
            text = str(field).lower().strip()
            if text and caption_lower in text:
                return True
        return False

    # ------------------------------------------------------------------ #
    # Evaluation
    # ------------------------------------------------------------------ #
    def evaluate(self, df, top_k_list=(1, 5, 10), sample_size=1000, seed=None):
        """Evaluate retrieval with category keywords (text) and product images.

        Text queries are every distinct (case-insensitive), non-missing
        ``category_level_1..4`` value of a product, searched against
        ``TEXT_QUERY_FILTER`` vectors. Image queries embed a row's ``image_path``
        and search ``IMAGE_QUERY_FILTER`` vectors.

        A retrieved item is relevant when its product id equals the query
        product's id or appears in that product's ``same_product_ids``. Queries
        whose embedding fails or that return no usable matches are skipped.
        Answered queries without a relevant item score 0 on MRR/MAP/nDCG.

        Args:
            df: DataFrame with ``id`` and ``image_path``, and optionally
                ``category_level_1..4`` and list-valued ``same_product_ids``.
            top_k_list: Cut-offs for top-k accuracy and mean similarity.
            sample_size: Maximum number of text queries and of image queries.
            seed: Optional seed for reproducible query sampling. ``None`` uses
                the global ``random``/NumPy state, as before.

        Returns:
            dict with ``text_*`` (text -> image) and ``image_*`` (image -> text)
            metrics: top-k accuracy, similarity, MRR, MAP and nDCG.
        """
        rng = random if seed is None else random.Random(seed)
        max_k = max(top_k_list)
        same_ids_by_product = self._same_product_index(df)

        text_queries = []
        for _, row in df.iterrows():
            product_id = self._normalize_id(row["id"])
            if product_id is None:
                continue
            seen = set()
            for lvl in ("category_level_1", "category_level_2", "category_level_3", "category_level_4"):
                value = row.get(lvl)
                if self._is_missing(value):  # NaN cells would otherwise become "nan" queries
                    continue
                keyword = str(value).strip()
                if keyword and keyword.lower() not in seen:
                    seen.add(keyword.lower())
                    text_queries.append((product_id, keyword))

        text_queries = rng.sample(text_queries, min(sample_size, len(text_queries)))
        image_queries = df.sample(min(sample_size, len(df)), random_state=seed).to_dict(orient="records")

        # TEXT -> IMAGE
        text_store = self._new_metric_store(top_k_list)
        for product_id, keyword in tqdm(text_queries, desc="Evaluating text queries"):
            query_embedding = self.generate_text_embedding(keyword)
            results = self.search_pinecone(query_embedding, top_k=max_k, filter=self.TEXT_QUERY_FILTER)
            retrieved_ids, scores, _, embeddings = self._extract_matches(results)
            if not retrieved_ids:
                continue
            valid_ids = {product_id} | same_ids_by_product.get(product_id, set())
            relevance = np.array([1 if rid in valid_ids else 0 for rid in retrieved_ids])
            self._record_ranking(text_store, relevance, scores, top_k_list, embeddings, query_embedding)

        # IMAGE -> TEXT
        image_store = self._new_metric_store(top_k_list)
        for row in tqdm(image_queries, desc="Evaluating image queries"):
            product_id = self._normalize_id(row["id"])
            image_embedding = self.generate_image_embedding(row["image_path"])
            if isinstance(image_embedding, dict) and "error" in image_embedding:
                continue
            results = self.search_pinecone(image_embedding, top_k=max_k, filter=self.IMAGE_QUERY_FILTER)
            retrieved_ids, scores, _, embeddings = self._extract_matches(results)
            if not retrieved_ids:
                continue
            valid_ids = {product_id}
            same_ids = row.get("same_product_ids")
            if isinstance(same_ids, list):
                valid_ids.update(self._normalize_id(pid) for pid in same_ids)
            valid_ids.discard(None)
            relevance = np.array([1 if rid in valid_ids else 0 for rid in retrieved_ids])
            self._record_ranking(image_store, relevance, scores, top_k_list, embeddings, image_embedding)

        return {
            **self._summarize("text", text_store, top_k_list, include_similarity=True),
            **self._summarize("image", image_store, top_k_list, include_similarity=True),
        }

    def eval_test(self, test_df, top_k_list=(1, 5, 10), sample_size=1000, category_levels=None, seed=None):
        """
        Evaluate the model on a test set for text-to-image and image-to-text retrieval.

        Text queries use each unique (id, caption) pair; rows with a missing caption
        are dropped. A retrieved item is relevant when the caption occurs in one of
        its ``category_levels`` or its product name, or when its product id equals
        the query's id.
        NOTE(review): captions are usually longer than category names, so
        "caption in category" rarely holds. Confirm the intended direction.

        Image queries use each unique (id, image_path) pair. A retrieved item is
        relevant only when its product id equals the query's id.

        Every sampled query counts. Failed embeddings, empty results and result
        lists without a relevant item all score 0.

        Args:
            test_df (pd.DataFrame): DataFrame with columns 'id', 'caption', 'image_path'.
            top_k_list (iterable of int): Cut-offs for top-k accuracy (e.g. (1, 5, 10)).
            sample_size (int): Maximum number of text queries and of image queries.
            category_levels (list): Metadata category keys checked for relevance.
                If None, defaults to ['category_level_1', ..., 'category_level_4'].
            seed (int, optional): Seed for reproducible sampling. ``None`` uses the
                global ``random`` state, as before.

        Returns:
            dict: ``text_*`` and ``image_*`` top-k accuracy, MRR, MAP and nDCG.
        """
        if category_levels is None:
            category_levels = ["category_level_1", "category_level_2", "category_level_3", "category_level_4"]

        rng = random if seed is None else random.Random(seed)
        max_k = max(top_k_list)

        text_queries = test_df[['id', 'caption']].dropna().drop_duplicates().to_dict(orient="records")
        image_queries = test_df[['id', 'image_path']].drop_duplicates().to_dict(orient="records")
        text_queries = rng.sample(text_queries, min(sample_size, len(text_queries)))
        image_queries = rng.sample(image_queries, min(sample_size, len(image_queries)))

        # TEXT -> IMAGE
        text_store = self._new_metric_store(top_k_list)
        for query in tqdm(text_queries, desc="Evaluating text queries"):
            product_id = self._normalize_id(query["id"])
            caption = str(query["caption"])

            query_embedding = self.generate_text_embedding(caption)
            results = self.search_pinecone(query_embedding, top_k=max_k, filter=self.TEXT_QUERY_FILTER)
            retrieved_ids, scores, metadata, _ = self._extract_matches(results)
            if not retrieved_ids:
                self._record_miss(text_store, top_k_list)
                continue

            caption_lower = caption.lower().strip()
            relevance = np.array([
                1 if self._caption_matches(caption_lower, meta, category_levels) or rid == product_id else 0
                for rid, meta in zip(retrieved_ids, metadata)
            ])
            self._record_ranking(text_store, relevance, scores, top_k_list)

        # IMAGE -> TEXT
        image_store = self._new_metric_store(top_k_list)
        for query in tqdm(image_queries, desc="Evaluating image queries"):
            product_id = self._normalize_id(query["id"])

            image_embedding = self.generate_image_embedding(query["image_path"])
            if isinstance(image_embedding, dict) and "error" in image_embedding:
                self._record_miss(image_store, top_k_list)
                continue

            results = self.search_pinecone(image_embedding, top_k=max_k, filter=self.IMAGE_QUERY_FILTER)
            retrieved_ids, scores, _, _ = self._extract_matches(results)
            if not retrieved_ids:
                self._record_miss(image_store, top_k_list)
                continue

            relevance = np.array([1 if rid == product_id else 0 for rid in retrieved_ids])
            self._record_ranking(image_store, relevance, scores, top_k_list)

        return {
            **self._summarize("text", text_store, top_k_list, include_similarity=False),
            **self._summarize("image", image_store, top_k_list, include_similarity=False),
        }

    # ------------------------------------------------------------------ #
    # Visualisation
    # ------------------------------------------------------------------ #
    def visualize_query_example(self, text_query, image_base_dir="/content/drive/MyDrive/dataset/product_images"):
        """Show the top-10 text-search results as a row of images titled with their scores.

        Images are expected at ``<image_base_dir>/<product_id>_1.jpg``.
        """
        results = self.search_by_text(text_query, top_k=10)
        # ids and scores stay aligned: matches without a product id are dropped from both.
        product_ids, similarities, _, _ = self._extract_matches(results["top_k_results"])

        if not product_ids:
            print("No product IDs found in results.")
            return

        image_paths = []
        for pid in product_ids:
            image_path = os.path.join(image_base_dir, f"{pid}_1.jpg")
            if os.path.exists(image_path):
                image_paths.append(image_path)
            else:
                print(f"Image not found for product_id: {pid} at {image_path}")
                image_paths.append(None)

        if all(path is None for path in image_paths):
            print("No valid image paths found.")
            return

        n_panels = min(10, len(image_paths))
        fig, axes = plt.subplots(1, n_panels, figsize=(20, 5), squeeze=False)
        fig.suptitle(f"Query: {text_query}", fontsize=16)

        for ax, image_path, score in zip(axes[0], image_paths, similarities):
            if image_path and os.path.exists(image_path):
                with Image.open(image_path) as img:
                    ax.imshow(img.convert("RGB"))
                ax.set_title(f"Sim: {score:.2f}", fontsize=9)
            else:
                ax.text(0.5, 0.5, "Image not found", ha="center", va="center")
            ax.axis("off")

        plt.tight_layout()
        plt.show()

    def visualize_image_query(self, text_query, df, image_base_dir="/content/drive/MyDrive/dataset/product_images"):
        """
        Visualize an image query chosen by a text query.

        The method:
          1. finds the first row of ``df`` whose caption or product name contains
             ``text_query`` (case-insensitive);
          2. searches the vector database with that row's image;
          3. plots the query image with its caption, followed by up to 10 results
             showing their names, similarity scores and images.

        Args:
            text_query (str): Text used to pick the query product.
            df (pd.DataFrame): DataFrame with columns 'id', 'caption', 'image_path', 'product_name'.
            image_base_dir (str): Folder containing ``<product_id>_1.jpg`` result images.
        """
        try:
            import matplotlib
            print(f"Matplotlib backend: {matplotlib.get_backend()}")
        except ImportError:
            pass

        # Step 1: find the query product.
        text_query_lower = text_query.lower().strip()
        matching_row = None
        for _, row in df.iterrows():
            caption = str(row.get("caption", "")).lower().strip()
            product_name = str(row.get("product_name", "")).lower().strip()
            if text_query_lower in caption or text_query_lower in product_name:
                matching_row = row
                break

        if matching_row is None:
            print(f"No product found matching query: {text_query}")
            return

        # Step 2: query image and caption (NaN cells become None / "nan" instead of crashing).
        raw_path = matching_row.get("image_path")
        query_image_path = None if self._is_missing(raw_path) else str(raw_path)
        query_caption = str(matching_row.get("caption", "No caption available"))
        query_product_name = matching_row.get("product_name", "Unknown product")
        print(f"Query image path: {query_image_path}")
        print(f"Query caption: {query_caption}")

        # Step 3: the image must be a URL or an existing local file.
        is_url = bool(query_image_path) and query_image_path.startswith("http")
        if not query_image_path or (not is_url and not os.path.exists(query_image_path)):
            print(f"Image not found for product: {query_product_name} at {query_image_path}")
            return

        # Step 4: image search.
        try:
            search_results = self.search_by_image(query_image_path, top_k=10)
        except ValueError as e:
            print(f"Failed to embed query image: {e}")
            return
        top_k_results = search_results.get("top_k_results", [])
        print(f"Retrieved {len(top_k_results)} results")

        if not top_k_results:
            print("No results found in vector database.")
            return

        # Step 5: ids, scores and names stay aligned (matches without a product id are dropped together).
        product_ids, similarities, metadata, _ = self._extract_matches(top_k_results)
        product_names = [meta.get("product_name", "Unknown") for meta in metadata]
        print(f"Product IDs: {product_ids}")
        print(f"Product names: {product_names}")

        # Step 6: result image paths (None when missing).
        result_image_paths = []
        for pid in product_ids:
            image_path = os.path.join(image_base_dir, f"{pid}_1.jpg")
            exists = os.path.exists(image_path)
            print(f"Checking image: {image_path}, exists: {exists}")
            result_image_paths.append(image_path if exists else None)
        print(f"Result image paths: {result_image_paths}")

        # Step 7: first row = query image, following rows = results.
        n_results = min(10, len(product_ids))
        cols = 5
        rows = 1 + (n_results + cols - 1) // cols
        plt.figure(figsize=(20, 4 * rows))

        plt.subplot(rows, cols, (1, cols))  # span the whole first row
        if is_url:
            print(f"Fetching URL: {query_image_path}")
            try:
                response = requests.get(query_image_path, timeout=10)
                response.raise_for_status()
                query_image = Image.open(BytesIO(response.content)).convert("RGB")
            except Exception as e:
                print(f"Failed to load URL: {e}")
                query_image = None
        else:
            print(f"Loading local image: {query_image_path}")
            query_image = Image.open(query_image_path).convert("RGB") if os.path.exists(query_image_path) else None

        if query_image:
            plt.imshow(query_image)
            plt.title(f"Query: {query_product_name}\nCaption: {textwrap.shorten(query_caption, width=50)}", fontsize=12)
        else:
            plt.text(0.5, 0.5, "Query image not available", ha="center", va="center", fontsize=12)
        plt.axis("off")

        if n_results > 0:
            print(f"Rendering {n_results} results in {rows - 1} row(s) of {cols} columns")
            for i in range(n_results):
                plt.subplot(rows, cols, cols + 1 + i)  # results start on the second row
                if result_image_paths[i] and os.path.exists(result_image_paths[i]):
                    with Image.open(result_image_paths[i]) as img:
                        plt.imshow(img.convert("RGB"))
                else:
                    plt.text(0.5, 0.5, "Image not found", ha="center", va="center")
                plt.title(f"{textwrap.shorten(str(product_names[i]), width=20)}\nSim: {similarities[i]:.2f}", fontsize=9)
                plt.axis("off")

        plt.tight_layout()
        plt.show()


# Test model

In [20]:
triplets_attention_clip = AttentionCLIPTripletModel(
    image_encoder_name="swin-base-patch4-window7-224",
    text_encoder_name="vietnamese-document-embedding",
    alpha=0.5,
    projection_dim=512,
)
triplets_attention_clip.show_configs()

# Trained checkpoint. Its parent folder name also determines the Pinecone namespace.
model_path = (
    "/content/drive/MyDrive/models/"
    "swin_base_patch4_window7_224_vietnamese-document-embedding_CrossAttention_Triplet_from_dataset_adamw/"
    "best.pt"
)
csm = CLIPSearchModule(triplets_attention_clip, model_path)



Image Encoder: swin-base-patch4-window7-224
Text Encoder: vietnamese-document-embedding
Alpha: 0.5
Projection Dimension: 512
Cross-Attention Heads: 8
Cross-Attention Dropout: 0.1
Margin: 0.2
namespace: yolo-clip-swin-base-patch4-window7-224-vietnamese-document-embedding-crossattention-triplet-from-dataset-adamw


In [13]:
# Query text: "Sách văn học nước ngoài" ("foreign literature books"). Vectors are not needed for display.
results = csm.search_by_text("Sách văn học nước ngoài", include_values=False)

# Metadata fields printed after the "name: score" line, in this order.
DETAIL_FIELDS = ("benefits_text", "description", "installment_info_text",
                 "return_policies_text", "specification_text")

for res in results.get("top_k_results"):
    # Missing fields print as empty lines instead of raising on None.strip().
    meta = res.get("metadata") or {}
    print(f"{str(meta.get('product_name') or '').strip()}: {res.get('score')}")
    for field in DETAIL_FIELDS:
        print(str(meta.get(field) or "").strip())
    print("\n")  # two blank lines between products

Dẫn Dắt Một Bầy Sói Hay Chăn Một Đàn Cừu: 0.533748865
Được đồng kiểm khi nhận hàng <b>Được hoàn tiền 200%</b> nếu là hàng giả. Đổi trả miễn phí trong 30 ngày. Được đổi ý.
<p>Trong cuộc chiến thu hút khách hàng, các doanh nghiệp đầu tư hàng triệu đô la để cải thiện trải nghiệm của khách hàng. Họ giao hàng nhanh hơn, tung ra các sản phẩm mới và không ngừng cải tiến giao diện người dùng, và họ thường gây áp lực lớn hơn cho nhân viên vì lợi nhuận giảm dần. Theo tác giả Tiffani Bova, việc tập trung duy nhất vào trải nghiệm của khách hàng – mà không xem xét tác động đến nhân viên của bạn – thực sự cản trở sự phát triển về lâu dài. Các công ty thành công nhất áp dụng Tư duy trải nghiệm để củng cố cả trải nghiệm của nhân viên (EX) và trải nghiệm của khách hàng (

In [14]:
# Display the raw Pinecone matches returned by the previous search.
results.get("top_k_results", None)

[{'id': '34227041-04f9-49a3-8bd3-d223c7a0c150',
  'metadata': {'benefits_text': 'Được đồng kiểm khi nhận hàng <b>Được hoàn tiền '
                                '200%</b> nếu là hàng giả. Đổi trả miễn phí '
                                'trong 30 ngày. Được đổi ý.',
               'caption': 'Dẫn Dắt Một Bầy Sói Hay Chăn Một Đàn Cừu Nhà Sách '
                          'Tiki Sách tiếng Việt Sách kinh tế Sách quản trị, '
                          'lãnh đạo <p>Trong cuộc chiến thu hút khách hàng, các '
                          'doanh nghiệp đầu tư hàng triệu đô la để cải thiện '
                          'trải nghiệm của khách hàng. Họ giao hàng nhanh hơn, '
                          'tung ra các sản phẩm mới và không ngừng cải tiến '
                          'giao diện người dùng, và họ thường gây áp lực lớn '
                          

# API

In [13]:
triplets_attention_clip = AttentionCLIPTripletModel(
    image_encoder_name="swin-base-patch4-window7-224",
    text_encoder_name="vietnamese-document-embedding",
    alpha=0.5,
    projection_dim=512,
)
triplets_attention_clip.show_configs()

# Trained checkpoint. Its parent folder name also determines the Pinecone namespace.
model_path = (
    "/content/drive/MyDrive/models/"
    "swin_base_patch4_window7_224_vietnamese-document-embedding_CrossAttention_Triplet_from_dataset_adamw/"
    "best.pt"
)
csm = CLIPSearchModule(triplets_attention_clip, model_path)

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]



modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/171 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/6.09k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/54.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

configuration.py:   0%|          | 0.00/6.09k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/dangvantuan/Vietnamese_impl:
- configuration.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling.py:   0%|          | 0.00/53.6k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/dangvantuan/Vietnamese_impl:
- modeling.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.34k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

Image Encoder: swin-base-patch4-window7-224
Text Encoder: vietnamese-document-embedding
Alpha: 0.5
Projection Dimension: 512
Cross-Attention Heads: 8
Cross-Attention Dropout: 0.1
Margin: 0.2
namespace: yolo-clip-swin-base-patch4-window7-224-vietnamese-document-embedding-crossattention-triplet-from-dataset-adamw


In [14]:
import base64
import datetime
import html
import io
import logging
import os
import re
import tempfile
import time
from typing import Dict, List, Optional

from bs4 import BeautifulSoup
from fastapi import Body, FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel, Field, field_validator

import nest_asyncio
from pyngrok import ngrok
# Patch asyncio so nested event loops are allowed inside the notebook kernel.
nest_asyncio.apply()

# The FastAPI application that the endpoints below register on.
app = FastAPI()

# NOTE(review): CORS is wide open (any origin, method and header, with
# credentials). Acceptable for a demo tunnel; tighten before real exposure.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Plain-text log that log_to_file appends to.
LOG_FILE = "./log.txt"

def clean_html_xml(text: str) -> str:
    """Strip markup from an HTML/XML fragment and return normalized plain text.

    The cleaning steps run in this order:

    - ``script``, ``style``, ``noscript`` and ``iframe`` elements are removed
      together with their contents.
    - The remaining text nodes are joined with single spaces.
    - HTML entities are unescaped.
    - ``#word`` hashtags are dropped.
    - Whitespace runs are collapsed.

    Args:
        text: Markup to clean. Any falsy value (``None``, ``""``) yields ``""``.

    Returns:
        The cleaned text on a single line, without leading or trailing spaces.
    """
    if not text:
        return ""

    document = BeautifulSoup(text, "lxml")
    for node in document(["script", "style", "noscript", "iframe"]):
        node.decompose()

    plain = html.unescape(document.get_text(separator=' ', strip=True))
    plain = re.sub(r'#\w+', '', plain)
    return re.sub(r'\s+', ' ', plain).strip()

def parse_document(top_k_results):
    """Flatten search matches into the product dicts returned by the API.

    Args:
        top_k_results: Iterable of search matches, each read via
            ``.get("metadata")``. Presumably these are dict-like Pinecone
            matches (TODO confirm against CLIPSearchModule). ``None`` is
            treated as an empty list.

    Returns:
        A list with one dict per match, with keys ``name``, ``description``,
        ``specifications``, ``benefits``, ``price`` and ``database_id``.
        The three text fields are passed through ``clean_html_xml``. A match
        with no metadata yields ``None`` for the plain fields and ``""`` for
        the cleaned ones, instead of raising.
    """
    documents = []
    for result in top_k_results or []:
        # Previously `None.get(...)` raised AttributeError for matches without metadata.
        metadata = result.get("metadata") or {}
        documents.append({
            "name": metadata.get("product_name"),
            "description": clean_html_xml(metadata.get("description")),
            "specifications": clean_html_xml(metadata.get("specification_text")),
            "benefits": clean_html_xml(metadata.get("benefits_text")),
            "price": metadata.get("price"),
            "database_id": metadata.get("database_id"),
        })
    return documents

def log_to_file(message):
    """Append ``message`` to LOG_FILE as one line prefixed by a local timestamp.

    The line has the form ``[YYYY-mm-dd HH:MM:SS] <message>``. Non-string
    messages (for example the product list) are rendered with ``str()`` via
    f-string formatting.
    """
    stamp = f"{datetime.datetime.now():%Y-%m-%d %H:%M:%S}"
    entry = f"[{stamp}] {message}\n"
    with open(LOG_FILE, mode="a", encoding="utf-8") as log_handle:
        log_handle.write(entry)

def log_and_raise(status_code, error, message):
    """Record ``"<error>: <message>"`` in the log file, then abort the request.

    Args:
        status_code: HTTP status code for the response.
        error: Short error label.
        message: Human-readable detail.

    Raises:
        HTTPException: Always raised, with ``status_code`` and a detail of
            ``{"status": "error", "error": error, "message": message}``.
    """
    payload = {"status": "error", "error": error, "message": message}
    log_to_file(f"{error}: {message}")
    raise HTTPException(status_code=status_code, detail=payload)

# Search models
class ImageSearchRequest(BaseModel):
    """JSON body for an image search where the image is sent as base64 text.

    NOTE(review): no visible endpoint uses this model. The image endpoint
    takes a multipart ``UploadFile`` instead.
    """

    # Base64-encoded image bytes, checked strictly by ``validate_image``.
    image: str = Field(..., description="Base64-encoded image string")
    # Optional metadata filter for the product search.
    filter: Optional[Dict] = Field({}, description="Optional filter for product search")

    @field_validator("image")
    @classmethod
    def validate_image(cls, value):
        """Accept ``value`` only if it is strictly valid base64.

        Raises:
            ValueError: If decoding fails. Pydantic reports this as a
                validation error, and the original decode error is chained
                as its cause.
        """
        try:
            base64.b64decode(value, validate=True)
        except Exception as e:
            raise ValueError(f"Invalid base64 image: {str(e)}") from e
        return value

class ImageSearchResponse(BaseModel):
    """Response schema for an image search.

    NOTE(review): this model is not attached as a ``response_model``. The
    visible ``get_image_response`` returns only ``top_k_results``, ``status``
    and ``latency``, not ``query`` or ``query_embedding``.
    """
    query: str
    query_embedding: List[float]
    # Product dicts in the shape produced by parse_document.
    top_k_results: List[Dict]
    # "success" on the happy path.
    status: str
    # Search duration in seconds, as measured by _process_search.
    latency: float

class TextSearchRequest(BaseModel):
    """Request body for POST /api/v1/inference/search-text/."""
    # Free-text query. The endpoint rejects blank values with HTTP 400.
    query: str
    # Optional metadata filter forwarded to the search. The endpoint always
    # adds a `type == "1"` constraint. Because it is Optional, a client may
    # send null.
    filter: Optional[Dict] = {}

class TextSearchResponse(BaseModel):
    """Response schema for a text search.

    NOTE(review): this model is not attached as a ``response_model``. The
    visible ``get_text_response`` returns ``query``, ``top_k_results``,
    ``status`` and ``latency``, but no ``query_embedding``.
    """
    query: str
    query_embedding: List[float]
    # Product dicts in the shape produced by parse_document.
    top_k_results: List[Dict]
    # "success" on the happy path.
    status: str
    # Search duration in seconds, as measured by _process_search.
    latency: float

# Shared internal search handler
def _process_search(func, *args, **kwargs):
    start_time = datetime.datetime.now()
    try:
        results = func(*args, **kwargs)
    except Exception as e:
        log_and_raise(500, "Search failed", str(e))
    latency = (datetime.datetime.now() - start_time).total_seconds()
    return results, latency

# Image search endpoint
from fastapi import HTTPException
from PIL import UnidentifiedImageError

@app.post("/api/v1/inference/search-image/")
async def get_image_response(
    image: UploadFile = File(...)
):
    """Search for products that look similar to an uploaded image.

    The request is handled in four steps:

    1. The upload is decoded with PIL and converted to RGB.
    2. The image is written as JPEG to a temporary file, because
       ``csm.search_by_image`` takes a path.
    3. The search runs with a fixed ``type == "1"`` filter.
    4. The temporary file is removed, whichever step failed.

    Returns:
        ``{"top_k_results": [...], "status": "success", "latency": float}``,
        where each result is a product dict built by ``parse_document``.

    Raises:
        HTTPException: 400 if the upload is not a readable image. 500 if
            preprocessing fails, the search fails, or the search returns an
            unexpected shape.
    """
    # The product-type filter is fixed server-side; clients cannot change it.
    filter_dict = {"type": {"$eq": "1"}}
    temp_image_path = None

    try:
        try:
            # Read and decode the upload. Converting to RGB lets JPEG encoding
            # accept any input mode.
            image_bytes = await image.read()
            try:
                img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            except UnidentifiedImageError as e:
                log_and_raise(400, "Invalid image", f"Cannot identify image file: {str(e)}")

            # Save to a temp file (delete=False, so this code removes it).
            # The path is recorded before writing, so the `finally` below
            # cleans it up even if encoding fails. Writing through the open
            # handle avoids re-opening the file by name.
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
                temp_image_path = temp_file.name
                img.save(temp_file, "JPEG")
        except HTTPException:
            # Already logged and shaped by log_and_raise; let FastAPI return it.
            raise
        except Exception as e:
            log_and_raise(500, "Image processing error", f"Unexpected error when processing image: {str(e)}")

        try:
            results, latency = _process_search(
                csm.search_by_image,
                image_path=temp_image_path,
                top_k=10,
                include_values=False,
                include_metadata=True,
                filter=filter_dict
            )

            # Sanity-check the result shape before indexing into it.
            if not isinstance(results, dict) or "top_k_results" not in results:
                log_to_file(f"Invalid search results format: {results}")
                raise HTTPException(status_code=500, detail="Invalid result format from search")

            log_to_file(f"Image search success. Found {len(results['top_k_results'])} products.")
            # Build the payload once and reuse it for both the log and the response.
            documents = parse_document(results['top_k_results'])
            log_to_file(documents)

        except HTTPException:
            raise
        except Exception as e:
            log_and_raise(500, "Search failed", f"Error during image search: {str(e)}")

        return {
            'top_k_results': documents,
            'status': "success",
            'latency': latency
        }

    finally:
        # Always remove the temp file, whichever step failed.
        if temp_image_path and os.path.exists(temp_image_path):
            os.unlink(temp_image_path)


@app.post("/api/v1/inference/search-text/")
async def get_text_response(body: TextSearchRequest):
    """Search for products that match a free-text query.

    The client's ``filter`` is copied, and ``type == "1"`` is always set on the
    copy. The request model itself is not mutated, and a ``null`` filter is
    treated as an empty one.

    Returns:
        ``{"query": str, "top_k_results": [...], "status": "success", "latency": float}``.

    Raises:
        HTTPException: 400 for a missing or blank query. 500 with a generic
            message if the search fails; the details go only to the log.
    """
    if not body.query or not body.query.strip():
        raise HTTPException(status_code=400, detail="Missing or empty 'query' field")

    # `filter` is Optional, so a client may send null. Previously that made
    # `body.filter["type"] = ...` raise TypeError outside the try, which
    # produced an unlogged 500.
    filter_dict = dict(body.filter or {})
    filter_dict["type"] = {"$eq": "1"}
    log_to_file(f"Received text search request for query: {body.query}, filter: {filter_dict}")

    try:
        results, latency = _process_search(
            csm.search_by_text,
            text_query=body.query,
            top_k=10,
            include_values=False,
            include_metadata=True,
            filter=filter_dict
        )

        log_to_file(f"Text search success. Found {len(results['top_k_results'])} products.")
        # Build the payload once and reuse it for both the log and the response.
        documents = parse_document(results['top_k_results'])
        log_to_file(documents)

        return {
            'query': body.query,
            'top_k_results': documents,
            'status': "success",
            'latency': latency
        }

    except Exception as e:
        # Internal details stay in the log; the client gets a generic message.
        log_to_file(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to process text search") from e


In [15]:
import threading
import uvicorn
def run():
    """Serve ``app`` with uvicorn on 0.0.0.0:8000. This call blocks."""
    uvicorn.run(app, port=8000, host="0.0.0.0")


# The server runs on a daemon thread, so this cell returns immediately and
# the thread does not keep the interpreter alive at exit.
server_thread = threading.Thread(target=run, daemon=True)
server_thread.start()

In [16]:
from pyngrok import conf
import time


# Authenticate the ngrok agent with the token stored in Colab secrets. The
# token is also exported to the environment under the same name.
# NOTE(review): this value is used as the agent authtoken, not as an ngrok
# API key, so the secret's name is misleading.
os.environ["NGROK_API_KEY"] = userdata.get("NGROK_API_KEY")
conf.get_default().auth_token = os.environ["NGROK_API_KEY"]

# Close tunnels left over from a previous run of this cell.
for stale_tunnel in ngrok.get_tunnels():
    stale_url = stale_tunnel.public_url
    ngrok.disconnect(stale_url)
    print(f"Closed old tunnel: {stale_url}")

# Short pause before reconnecting. Presumably this lets the old tunnel release
# the reserved hostname (TODO confirm).
time.sleep(2)

ngrok_tunnel = ngrok.connect(addr="8000", proto="http", hostname="sunfish-pleased-privately.ngrok-free.app")
public_url = ngrok_tunnel.public_url
print(f"üîó Public ngrok URL: {public_url}/docs")

INFO:     Started server process [998]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


üîó Public ngrok URL: https://sunfish-pleased-privately.ngrok-free.app/docs
