In [1]:
!pip -q install ftfy regex tqdm
!pip -q install git+https://github.com/openai/CLIP.git
!pip -q install pytorch-metric-learning

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m72.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m64.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m34.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [

## Load dataset


In [2]:
#train-val-test: 70:15:15
!gdown 1u9XcSIB4-UMmgSGa_or2B3bBbgiNZe5I
!unzip -q data_cifar10_ent_split.zip

Downloading...
From (original): https://drive.google.com/uc?id=1u9XcSIB4-UMmgSGa_or2B3bBbgiNZe5I
From (redirected): https://drive.google.com/uc?id=1u9XcSIB4-UMmgSGa_or2B3bBbgiNZe5I&confirm=t&uuid=f807294f-eb81-40db-b0c4-554376971f6f
To: /kaggle/working/data_cifar10_ent_split.zip
100%|██████████████████████████████████████| 1.06G/1.06G [00:10<00:00, 99.3MB/s]


In [4]:

import os
import torch
import clip
from PIL import Image
import torch.nn.functional as F
import random
from sklearn.model_selection import train_test_split
import numpy as np

import torch.nn as nn
from pytorch_metric_learning import losses
from random import shuffle
from tqdm import tqdm
import torch.nn.init as init

from torch.optim.lr_scheduler import ReduceLROnPlateau

## Prepare Dataset

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load CLIP model (ViT-B/32)
model, preprocess = clip.load("ViT-B/32", device=device)

img_dir = "./train"
embeddings = {}

def load_images_recursively(folder):
    '''
    Embed all images in the training data and return them in a dictionary
    Return:
                embeddings = {
                    "img_name": embedding_tensor,
                    ...
                }
    '''
    embeddings = {}
    for root, dirs, files in os.walk(folder):
        for fn in files:
            if fn.lower().endswith((".png", ".jpg", ".jpeg")):
                path = os.path.join(root, fn)
                try:
                    image = preprocess(Image.open(path)).unsqueeze(0).to(device)
                    with torch.no_grad():
                        emb = model.encode_image(image)
                        emb = emb / emb.norm(dim=-1, keepdim=True)
                    rel_path = os.path.relpath(path, img_dir)
                    embeddings[rel_path] = emb.squeeze(0).cpu()
                except Exception as e:
                    print(f"Failed to load {path}: {e}")
    return embeddings

embeddings = load_images_recursively(img_dir)
print(f"Total images loaded: {len(embeddings)}")

100%|███████████████████████████████████████| 338M/338M [00:07<00:00, 49.9MiB/s]


Total images loaded: 1298


In [7]:

def create_cls_map(root_dir):
    """
    Create a mapping from relative image path to its class name.

    Args:
        root_dir (str): Root directory containing class-named subfolders with images.

    Returns:
        dict: {
            "class_name/image_name.jpg": "class_name",
            ...
        }
    """
    cls_map = {}
    for cls_name in os.listdir(root_dir):
        cls_path = os.path.join(root_dir, cls_name)
        if not os.path.isdir(cls_path):
            continue
        for img_name in os.listdir(cls_path):
            if img_name.lower().endswith((".png", ".jpg", ".jpeg")):
                rel_path = f"{cls_name}/{img_name}"  # or use os.path.join and os.path.relpath for robustness
                cls_map[rel_path] = cls_name
    return cls_map
root_dir = img_dir
cls_map = create_cls_map(root_dir)
print(f"Created cls_map for {len(cls_map)} images")

Created cls_map for 1298 images


In [34]:
import torch.nn.functional as F

def generate_parent_child(cls_map, embeddings, num_children=3, verbose=False):
    """
    Generate (parent, top-k children) pairs based on cosine similarity within the same class.

    Returns:
        - parent_child: dict of {img_name: [child1, child2, ...]}
        - class_to_imgs: dict of {class_name: [img_name1, img_name2, ...]}
    """
    class_to_imgs = {}
    for img_name, cls in cls_map.items():
        class_to_imgs.setdefault(cls, []).append(img_name)

    parent_child = {}
    skipped_classes = 0

    for cls, img_list in class_to_imgs.items():
        valid_imgs = [i for i in img_list if i in embeddings]
        if len(valid_imgs) < 2:
            skipped_classes += 1
            continue  # Không đủ ảnh để tạo parent-child

        for img_name in valid_imgs:
            anchor_emb = embeddings[img_name].unsqueeze(0)

            others = [i for i in valid_imgs if i != img_name]
            if not others:
                continue

            other_embs = torch.stack([embeddings[i] for i in others])

            sim_scores = F.cosine_similarity(anchor_emb, other_embs)
            topk = min(num_children, len(others))
            topk_indices = torch.topk(sim_scores, k=topk).indices

            children = [others[i] for i in topk_indices.tolist()]
            parent_child[img_name] = children

    if verbose:
        print(f"✅ Generated parent-child pairs: {len(parent_child)}")
        print(f"⚠️ Skipped classes with <2 valid embeddings: {skipped_classes}")

    return parent_child, class_to_imgs

parent_child, class_to_imgs = generate_parent_child(cls_map, embeddings, num_children=3)
print(f"Generated parent-child pairs for {len(parent_child)} images")

Generated parent-child pairs for 1298 images


In [10]:
import torch
import torch.nn.functional as F
import random

# Set random seeds for reproducibility
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Define class-wise incompatible (negative) relationships
negative_map = {
    "nose-right": ["nose-left"],
    "nose-left": ["nose-right"],
    "ear-right": ["ear-left"],
    "ear-left": ["ear-right"],
    "vc-open": ["vc-closed", "throat"],
    "vc-closed": ["vc-open"],
    "throat": ["vc-open", "vc-closed"],
}

# --- Function definitions ---

def get_hard_negative(anchor_name, anchor_emb, anchor_cls, class_to_imgs, embeddings, negative_map):
    """
    Select the hardest negative image (most similar but from a different class).
    Priority is given to predefined negative classes.
    """
    neg_classes = negative_map.get(anchor_cls, None)

    if not neg_classes:
        neg_cls_candidates = [c for c in class_to_imgs if c != anchor_cls]
        if not neg_cls_candidates:
            return None
        neg_cls = random.choice(neg_cls_candidates)
        neg_imgs = class_to_imgs[neg_cls]
    else:
        neg_imgs = []
        for neg_cls in neg_classes:
            neg_imgs.extend(class_to_imgs.get(neg_cls, []))
        if not neg_imgs:
            return None

    anchor_emb = anchor_emb.unsqueeze(0)
    max_sim = -1
    hard_neg = None
    for neg_img in neg_imgs:
        if neg_img not in embeddings:
            continue
        neg_emb = embeddings[neg_img].unsqueeze(0)
        sim = F.cosine_similarity(anchor_emb, neg_emb).item()
        if sim > max_sim:
            max_sim = sim
            hard_neg = neg_img

    return hard_neg


def filter_parent_child(parent_child_full, embeddings_subset):
    """
    Filter parent-child mapping để giữ các cặp có mặt trong embeddings hiện tại.
    """
    parent_child_filtered = {}
    valid_imgs = set(embeddings_subset.keys())
    for parent, children in parent_child_full.items():
        if parent in valid_imgs:
            filtered = [c for c in children if c in valid_imgs]
            if filtered:
                parent_child_filtered[parent] = filtered
    return parent_child_filtered


def create_triplets(embeddings, cls_map, parent_child, class_to_imgs):
    """
    Tạo các bộ 3 (anchor, positive, negative) từ tập embeddings hiện tại.
    """
    triplets = []
    for anchor_name, anchor_emb in embeddings.items():
        if anchor_name not in parent_child or len(parent_child[anchor_name]) == 0:
            continue
        anchor_cls = cls_map[anchor_name]
        pos_name = parent_child[anchor_name][0]
        neg_name = get_hard_negative(anchor_name, anchor_emb, anchor_cls, class_to_imgs, embeddings, negative_map)

        if neg_name is not None:
            triplets.append((
                anchor_emb,
                embeddings[pos_name],
                embeddings[neg_name],
                anchor_name,
                neg_name
            ))

    print(f"Generated {len(triplets)} triplets.")
    return triplets


def get_triplet_batch(triplets, batch_size=32):
    if len(triplets) < batch_size:
        return None
    batch = random.sample(triplets, batch_size)
    anchor = torch.stack([t[0] for t in batch])
    positive = torch.stack([t[1] for t in batch])
    negative = torch.stack([t[2] for t in batch])
    return anchor, positive, negative


parent_child = filter_parent_child(parent_child, embeddings)
triplets = create_triplets(embeddings, cls_map, parent_child, class_to_imgs)
batch = get_triplet_batch(triplets, batch_size=32)
if batch:
    anchor, positive, negative = batch
    print(anchor.shape, positive.shape, negative.shape)
else:
    print("Không đủ triplet để tạo batch.")


Generated 1298 triplets.
torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 512])


## Training

In [12]:
embed_dim = 512
class GaussianFourierProjection(nn.Module):
    def __init__(self, embed_dim, scale=10.0):
        super().__init__()
        # Fixed random weights for projecting scalar t to higher frequency space
        self.W = nn.Parameter(torch.randn(1, embed_dim // 2) * scale, requires_grad=False)

    def forward(self, t):
        # Ensure t has shape [B, 1]
        if t.ndim == 1:
            t = t.unsqueeze(-1)
        proj = t * self.W  # Shape: [B, D/2]
        # Return sinusoidal and cosinusoidal projection: [sin(tW), cos(tW)] → Shape: [B, D]
        return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)

class VectorField(nn.Module):
    def __init__(self, dim, t_dim=32, hidden_dim=256, n_heads=4, dropout_prob=0.1):
        super().__init__()
        self.x_norm = nn.LayerNorm(dim)  # Normalize input embeddings
        self.time_encoder = GaussianFourierProjection(t_dim)  # Time embedding module
        self.dropout = nn.Dropout(dropout_prob)

        # Create multiple independent heads (like a lightweight transformer block)
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(dim + t_dim, hidden_dim),     # Project input + time
                nn.LayerNorm(hidden_dim),
                nn.SiLU(),                               # Activation: SiLU 
                nn.Dropout(dropout_prob),
                nn.Linear(hidden_dim, dim)              # Back to original embedding dimension
            ) for _ in range(n_heads)
        ])

        self.res_weight = nn.Parameter(torch.tensor(1.0))  # Learnable residual scaling
        self.out_norm = nn.LayerNorm(dim)  # Final normalization (not applied directly here)
        self.initialize_weights()

    
    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Use Kaiming initialization (good for ReLU/SiLU)
                init.kaiming_uniform_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    init.zeros_(m.bias)


    def forward(self, x, t):
        # Handle scalar or 1D tensor time input → ensure shape [B, 1]
        if not isinstance(t, torch.Tensor):
            t = torch.full((x.shape[0], 1), t, device=x.device)
        elif t.ndim == 0:
            t = t.expand(x.shape[0], 1)
        elif t.ndim == 1:
            t = t.unsqueeze(-1)

        x_normed = self.x_norm(x)                     # Normalize input
        t_encoded = self.time_encoder(t.to(x.device)) # Encode time t
        inp = torch.cat([x_normed, t_encoded], dim=-1)  # Concatenate along feature dim

        # Pass through each head and average their outputs
        head_outs = [head(inp) for head in self.heads]
        out = torch.mean(torch.stack(head_outs), dim=0)

        # Add residual connection scaled by learnable weight
        return out + self.res_weight * x

In [43]:
def euler_integration(x0, vf, steps=10):
    """
    Args:
        x0: initial embeddings [B, D]
        vf: vector field model (takes in x and t, returns dx/dt)
        steps: number of integration steps
    Returns:
        Transformed embeddings x(T)
    """
    dt = 1.0 / steps
    x = x0
    for i in range(steps):
        t = i * dt
        k1 = vf(x, t)
        k2 = vf(x + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = vf(x + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = vf(x + dt * k3, t + dt)
        x = x + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4)
    return x

    #return correct / total if total > 0 else 0.0
def compute_triplet_loss(model, triplets, loss_func, cls_map, steps=10):
    """
       Compute loss using MultiSimilarityLoss
       It takes transformed embeddings + integer labels
   
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for i in range(0, len(triplets), 64):
            batch = triplets[i : i + 64]
            if len(batch) == 0:
                continue

            batch_embeddings, batch_labels = get_embeddings_labels_from_triplets(batch, cls_map, class2idx)
            if batch_embeddings is None or len(batch_embeddings) == 0:
                continue

            batch_embeddings = batch_embeddings.to(device).float()
            batch_labels = batch_labels.to(device).long()

            pred_embeddings = euler_integration(batch_embeddings, model, steps=steps)
            loss = loss_func(pred_embeddings, batch_labels)

            total_loss += loss.item()
            num_batches += 1

    return total_loss / max(1, num_batches)

def get_embeddings_labels_from_triplets(triplets_batch, cls_map, class2idx):
    '''
        Convert list of triplets into tensors for model input.
    '''
    embeddings = []
    labels = []

    for anchor_emb, positive_emb, negative_emb, anchor_name, negative_name in triplets_batch:
        anchor_cls = cls_map[anchor_name]
        negative_cls = cls_map[negative_name]

        embeddings.extend([anchor_emb, positive_emb, negative_emb])
        labels.extend([
            class2idx[anchor_cls],
            class2idx[anchor_cls],
            class2idx[negative_cls]
        ])

    embeddings_tensor = torch.stack(embeddings)
    labels_tensor = torch.tensor(labels)
    return embeddings_tensor, labels_tensor

In [36]:
root_dir_val = "./val"

def standardize_keys(d):
    return {
        os.path.join(os.path.basename(os.path.dirname(k)), os.path.basename(k)): v
        for k, v in d.items()
    }

# 1. Load embeddings
embeddings_val = load_images_recursively(root_dir_val)
print(f" Total images loaded: {len(embeddings_val)}")

# 2. Load cls_map
cls_map_val = create_cls_map(root_dir_val)
print(f" Created cls_map for {len(cls_map_val)} images")

# 3.  Standardize keys
embeddings_val = standardize_keys(embeddings_val)
cls_map_val = standardize_keys(cls_map_val)

# 4. Generate parent-child
parent_child_val, class_to_imgs_val = generate_parent_child(cls_map_val, embeddings_val, num_children=3, verbose=True)

# 5. Filter parent-child
parent_child_val = filter_parent_child(parent_child_val, embeddings_val)

# 6. Create triplets
triplets_val = create_triplets(embeddings_val, cls_map_val, parent_child_val, class_to_imgs_val)
print(f" Total triplets generated: {len(triplets_val)}")

# 7. Get batch
batch_val = get_triplet_batch(triplets_val, batch_size=32)
if batch_val:
    anchor_val, positive_val, negative_val = batch_val
    print(" Batch shapes:", anchor_val.shape, positive_val.shape, negative_val.shape)
else:
    print(" Không đủ triplet để tạo batch.")


 Total images loaded: 275
 Created cls_map for 275 images
✅ Generated parent-child pairs: 275
⚠️ Skipped classes with <2 valid embeddings: 0
Generated 275 triplets.
 Total triplets generated: 275
 Batch shapes: torch.Size([32, 512]) torch.Size([32, 512]) torch.Size([32, 512])


In [37]:
sample_cls_key = list(cls_map_val.keys())[0]
sample_emb_key = list(embeddings_val.keys())[0]
print("🔍 Sample key từ cls_map:", sample_cls_key)
print("🔍 Sample key từ embeddings:", sample_emb_key)


🔍 Sample key từ cls_map: vc-open/23146955_230624145953203831_954_image04.png
🔍 Sample key từ embeddings: vc-open/23146955_230624145953203831_954_image04.png


In [38]:
@torch.no_grad()
def compute_recall_at_1_embedded(embeddings, cls_map, type_map, vf, device="cuda", steps=10):
    """
    embeddings: dict tên ảnh → vector trước khi qua VF
    cls_map: dict tên ảnh → class (string)
    type_map: dict class → list các class cùng type
    vf: mô hình vector field (đã học)
    """
    keys = list(embeddings.keys())
    embedded_vectors = {}
    
    # Apply VF để lấy vector sau khi flow
    for k in keys:
        emb = embeddings[k].unsqueeze(0).to(device).float()
        emb = euler_integration(emb, vf, steps=steps)
        embedded_vectors[k] = emb.squeeze(0).cpu()  # chuyển lại về CPU để so cosine

    correct = 0
    total = 0

    for anchor_name in keys:
        anchor_emb = embedded_vectors[anchor_name].unsqueeze(0)
        anchor_cls = cls_map[anchor_name]
        anchor_type = set(type_map.get(anchor_cls, [])) | {anchor_cls}

        max_sim = -float("inf")
        top1_name = None

        for other_name in keys:
            if other_name == anchor_name:
                continue
            other_emb = embedded_vectors[other_name].unsqueeze(0)
            sim = F.cosine_similarity(anchor_emb, other_emb).item()
            if sim > max_sim:
                max_sim = sim
                top1_name = other_name

        if top1_name:
            pred_cls = cls_map[top1_name]
            if pred_cls in anchor_type:
                correct += 1
        total += 1

    recall = correct / total if total > 0 else 0.0
    return recall


In [46]:
#Define label mapping
class2idx = {'ear-left': 3,
 'ear-right': 2,
 'nose-left': 1,
 'nose-right': 0,
 'throat': 6,
 'vc-closed': 4,
 'vc-open': 5}

# Khởi tạo model, optimizer, loss
vf = VectorField(embed_dim).to(device).float()
optimizer = torch.optim.AdamW(vf.parameters(), lr=1e-4)
loss_func = losses.MultiSimilarityLoss()

# Hyperparameters cho scheduler và early stopping
warmup_epochs = 20
early_stop_patience = 40
scheduler_patience = 15
scheduler_factor = 0.8
delta = 5e-4

# LR Scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=scheduler_factor, patience=scheduler_patience)

# Biến theo dõi
best_val_loss = float('inf')
best_train_loss = float('inf')
epochs_no_improve = 0

# Huấn luyện
epochs = 200
batch_size = 32

for epoch in tqdm(range(epochs)):
    vf.train()
    epoch_loss = 0.0
    num_batches = 0

    random.seed(42)
    random.shuffle(triplets)

    for i in range(0, len(triplets), batch_size):
        batch = triplets[i: i + batch_size]
        if not batch:
            continue

        optimizer.zero_grad()
        batch_embeddings, batch_labels = get_embeddings_labels_from_triplets(batch, cls_map, class2idx)

        if batch_embeddings is None or len(batch_embeddings) == 0:
            continue

        batch_embeddings = batch_embeddings.to(device).float()
        batch_labels = batch_labels.to(device).long()
        pred_embeddings = euler_integration(batch_embeddings, vf, steps=10)

        loss = loss_func(pred_embeddings, batch_labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vf.parameters(), max_norm=5.0)
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # --- Evaluation ---
    vf.eval()
    val_loss = compute_triplet_loss(vf, triplets_val, loss_func,cls_map_val, steps=10)
    avg_train_loss = epoch_loss / max(1, num_batches)


    print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.4f}")

    # --- LR Warmup & Scheduler ---
    old_lr = optimizer.param_groups[0]['lr']
    if epoch >= warmup_epochs:
        scheduler.step(val_loss)
    else:
        scheduler.step(float('inf'))

    new_lr = optimizer.param_groups[0]['lr']
    if new_lr < old_lr:
        print(f"⚠️ LR reduced at epoch {epoch + 1} → {new_lr:.1e}")

    # --- Early Stopping ---
    if (val_loss < best_val_loss - delta) or (avg_train_loss < best_train_loss - delta):
        best_val_loss = min(val_loss, best_val_loss)
        best_train_loss = min(avg_train_loss, best_train_loss)
        epochs_no_improve = 0
        torch.save(vf.state_dict(), "best_vf.pt")
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= early_stop_patience:
        print(f"⛔ Early stopping at epoch {epoch + 1}")
        break

  0%|          | 1/200 [00:18<59:51, 18.05s/it]

[Epoch 1] Train Loss: 1.4995


  1%|          | 2/200 [00:36<1:00:39, 18.38s/it]

[Epoch 2] Train Loss: 1.5001


  2%|▏         | 3/200 [00:54<1:00:00, 18.28s/it]

[Epoch 3] Train Loss: 1.4906


  2%|▏         | 4/200 [01:13<1:00:05, 18.40s/it]

[Epoch 4] Train Loss: 1.4958


  2%|▎         | 5/200 [01:32<1:00:10, 18.51s/it]

[Epoch 5] Train Loss: 1.4962


  3%|▎         | 6/200 [01:50<59:21, 18.36s/it]  

[Epoch 6] Train Loss: 1.4896


  4%|▎         | 7/200 [02:08<59:16, 18.43s/it]

[Epoch 7] Train Loss: 1.4866


  4%|▍         | 8/200 [02:26<58:22, 18.24s/it]

[Epoch 8] Train Loss: 1.4780


  4%|▍         | 9/200 [02:44<58:07, 18.26s/it]

[Epoch 9] Train Loss: 1.4845


  5%|▌         | 10/200 [03:02<57:34, 18.18s/it]

[Epoch 10] Train Loss: 1.4754


  6%|▌         | 11/200 [03:21<57:47, 18.35s/it]

[Epoch 11] Train Loss: 1.4837


  6%|▌         | 12/200 [03:40<57:40, 18.40s/it]

[Epoch 12] Train Loss: 1.4841


  6%|▋         | 13/200 [03:57<56:48, 18.23s/it]

[Epoch 13] Train Loss: 1.4784


  7%|▋         | 14/200 [04:16<56:49, 18.33s/it]

[Epoch 14] Train Loss: 1.4861


  8%|▊         | 15/200 [04:34<55:59, 18.16s/it]

[Epoch 15] Train Loss: 1.4482


  8%|▊         | 16/200 [04:53<56:32, 18.44s/it]

[Epoch 16] Train Loss: 1.4393
⚠️ LR reduced at epoch 16 → 8.0e-05


  8%|▊         | 17/200 [05:11<55:56, 18.34s/it]

[Epoch 17] Train Loss: 1.4252


  9%|▉         | 18/200 [05:30<55:58, 18.46s/it]

[Epoch 18] Train Loss: 1.4029


 10%|▉         | 19/200 [05:49<56:05, 18.59s/it]

[Epoch 19] Train Loss: 1.3962


 10%|█         | 20/200 [06:07<55:10, 18.39s/it]

[Epoch 20] Train Loss: 1.4029


 10%|█         | 21/200 [06:25<55:00, 18.44s/it]

[Epoch 21] Train Loss: 1.3982


 11%|█         | 22/200 [06:43<54:22, 18.33s/it]

[Epoch 22] Train Loss: 1.3827


 12%|█▏        | 23/200 [07:02<54:35, 18.50s/it]

[Epoch 23] Train Loss: 1.3538


 12%|█▏        | 24/200 [07:20<54:06, 18.44s/it]

[Epoch 24] Train Loss: 1.3554


 12%|█▎        | 25/200 [07:39<54:10, 18.57s/it]

[Epoch 25] Train Loss: 1.3582


 13%|█▎        | 26/200 [07:58<53:57, 18.61s/it]

[Epoch 26] Train Loss: 1.3580


 14%|█▎        | 27/200 [08:16<53:18, 18.49s/it]

[Epoch 27] Train Loss: 1.3527


 14%|█▍        | 28/200 [08:35<53:17, 18.59s/it]

[Epoch 28] Train Loss: 1.3382


 14%|█▍        | 29/200 [08:54<52:58, 18.59s/it]

[Epoch 29] Train Loss: 1.3273


 15%|█▌        | 30/200 [09:12<52:48, 18.64s/it]

[Epoch 30] Train Loss: 1.3350


 16%|█▌        | 31/200 [09:31<52:29, 18.64s/it]

[Epoch 31] Train Loss: 1.3169


 16%|█▌        | 32/200 [09:49<51:45, 18.48s/it]

[Epoch 32] Train Loss: 1.3260


 16%|█▋        | 33/200 [10:08<51:29, 18.50s/it]

[Epoch 33] Train Loss: 1.3011


 17%|█▋        | 34/200 [10:25<50:36, 18.29s/it]

[Epoch 34] Train Loss: 1.2942


 18%|█▊        | 35/200 [10:44<50:36, 18.40s/it]

[Epoch 35] Train Loss: 1.2955


 18%|█▊        | 36/200 [11:02<49:53, 18.25s/it]

[Epoch 36] Train Loss: 1.3073


 18%|█▊        | 37/200 [11:21<50:04, 18.43s/it]

[Epoch 37] Train Loss: 1.2880


 19%|█▉        | 38/200 [11:40<50:11, 18.59s/it]

[Epoch 38] Train Loss: 1.2827
⚠️ LR reduced at epoch 38 → 6.4e-05


 20%|█▉        | 39/200 [11:59<49:59, 18.63s/it]

[Epoch 39] Train Loss: 1.2842


 20%|██        | 40/200 [12:17<49:50, 18.69s/it]

[Epoch 40] Train Loss: 1.2727


 20%|██        | 41/200 [12:35<48:57, 18.47s/it]

[Epoch 41] Train Loss: 1.2559


 21%|██        | 42/200 [12:54<48:43, 18.50s/it]

[Epoch 42] Train Loss: 1.2651


 22%|██▏       | 43/200 [13:12<48:09, 18.40s/it]

[Epoch 43] Train Loss: 1.2594


 22%|██▏       | 44/200 [13:31<48:02, 18.48s/it]

[Epoch 44] Train Loss: 1.2644


 22%|██▎       | 45/200 [13:51<48:52, 18.92s/it]

[Epoch 45] Train Loss: 1.2563


 23%|██▎       | 46/200 [14:09<48:01, 18.71s/it]

[Epoch 46] Train Loss: 1.2541


 24%|██▎       | 47/200 [14:28<47:48, 18.75s/it]

[Epoch 47] Train Loss: 1.2524


 24%|██▍       | 48/200 [14:46<46:51, 18.50s/it]

[Epoch 48] Train Loss: 1.2571


 24%|██▍       | 49/200 [15:04<46:41, 18.55s/it]

[Epoch 49] Train Loss: 1.2682


 25%|██▌       | 50/200 [15:22<46:03, 18.42s/it]

[Epoch 50] Train Loss: 1.2708


 26%|██▌       | 51/200 [15:41<45:57, 18.51s/it]

[Epoch 51] Train Loss: 1.2693


 26%|██▌       | 52/200 [16:00<46:15, 18.75s/it]

[Epoch 52] Train Loss: 1.2583


 26%|██▋       | 53/200 [16:19<45:38, 18.63s/it]

[Epoch 53] Train Loss: 1.2468


 27%|██▋       | 54/200 [16:39<46:07, 18.95s/it]

[Epoch 54] Train Loss: 1.2541
⚠️ LR reduced at epoch 54 → 5.1e-05


 28%|██▊       | 55/200 [16:58<46:11, 19.11s/it]

[Epoch 55] Train Loss: 1.2456


 28%|██▊       | 56/200 [17:17<46:08, 19.23s/it]

[Epoch 56] Train Loss: 1.2474


 28%|██▊       | 57/200 [17:37<45:46, 19.21s/it]

[Epoch 57] Train Loss: 1.2400


 29%|██▉       | 58/200 [17:55<44:52, 18.96s/it]

[Epoch 58] Train Loss: 1.2433


 30%|██▉       | 59/200 [18:14<44:29, 18.93s/it]

[Epoch 59] Train Loss: 1.2512


 30%|███       | 60/200 [18:32<43:43, 18.74s/it]

[Epoch 60] Train Loss: 1.2375


 30%|███       | 61/200 [18:51<43:40, 18.85s/it]

[Epoch 61] Train Loss: 1.2366


 31%|███       | 62/200 [19:10<43:26, 18.89s/it]

[Epoch 62] Train Loss: 1.2457


 32%|███▏      | 63/200 [19:29<42:48, 18.75s/it]

[Epoch 63] Train Loss: 1.2423


 32%|███▏      | 64/200 [19:48<42:45, 18.86s/it]

[Epoch 64] Train Loss: 1.2518


 32%|███▎      | 65/200 [20:06<41:59, 18.66s/it]

[Epoch 65] Train Loss: 1.2484


 33%|███▎      | 66/200 [20:25<41:47, 18.72s/it]

[Epoch 66] Train Loss: 1.2462


 34%|███▎      | 67/200 [20:43<41:12, 18.59s/it]

[Epoch 67] Train Loss: 1.2313


 34%|███▍      | 68/200 [21:02<41:01, 18.65s/it]

[Epoch 68] Train Loss: 1.2416


 34%|███▍      | 69/200 [21:21<40:50, 18.71s/it]

[Epoch 69] Train Loss: 1.2325


 35%|███▌      | 70/200 [21:39<40:14, 18.57s/it]

[Epoch 70] Train Loss: 1.2342
⚠️ LR reduced at epoch 70 → 4.1e-05


 36%|███▌      | 71/200 [21:58<39:58, 18.59s/it]

[Epoch 71] Train Loss: 1.2295


 36%|███▌      | 72/200 [22:16<39:20, 18.44s/it]

[Epoch 72] Train Loss: 1.2397


 36%|███▋      | 73/200 [22:34<39:08, 18.49s/it]

[Epoch 73] Train Loss: 1.2355


 37%|███▋      | 74/200 [22:53<38:57, 18.55s/it]

[Epoch 74] Train Loss: 1.2319


 38%|███▊      | 75/200 [23:11<38:19, 18.40s/it]

[Epoch 75] Train Loss: 1.2332


 38%|███▊      | 76/200 [23:30<38:15, 18.51s/it]

[Epoch 76] Train Loss: 1.2267


 38%|███▊      | 77/200 [23:48<37:26, 18.26s/it]

[Epoch 77] Train Loss: 1.2291


 39%|███▉      | 78/200 [24:06<37:15, 18.32s/it]

[Epoch 78] Train Loss: 1.2291


 40%|███▉      | 79/200 [24:24<36:45, 18.23s/it]

[Epoch 79] Train Loss: 1.2259


 40%|████      | 80/200 [24:43<36:59, 18.49s/it]

[Epoch 80] Train Loss: 1.2272


 40%|████      | 81/200 [25:02<36:50, 18.57s/it]

[Epoch 81] Train Loss: 1.2248


 41%|████      | 82/200 [25:20<36:23, 18.51s/it]

[Epoch 82] Train Loss: 1.2271


 42%|████▏     | 83/200 [25:39<36:21, 18.65s/it]

[Epoch 83] Train Loss: 1.2408


 42%|████▏     | 84/200 [25:58<36:06, 18.68s/it]

[Epoch 84] Train Loss: 1.2214


 42%|████▎     | 85/200 [26:17<35:53, 18.73s/it]

[Epoch 85] Train Loss: 1.2275


 43%|████▎     | 86/200 [26:35<35:11, 18.52s/it]

[Epoch 86] Train Loss: 1.2287
⚠️ LR reduced at epoch 86 → 3.3e-05


 44%|████▎     | 87/200 [26:54<34:59, 18.58s/it]

[Epoch 87] Train Loss: 1.2225


 44%|████▍     | 88/200 [27:12<34:45, 18.62s/it]

[Epoch 88] Train Loss: 1.2208


 44%|████▍     | 89/200 [27:30<34:06, 18.43s/it]

[Epoch 89] Train Loss: 1.2167


 45%|████▌     | 90/200 [27:49<33:53, 18.48s/it]

[Epoch 90] Train Loss: 1.2160


 46%|████▌     | 91/200 [28:07<33:16, 18.31s/it]

[Epoch 91] Train Loss: 1.2160


 46%|████▌     | 92/200 [28:25<33:03, 18.36s/it]

[Epoch 92] Train Loss: 1.2204


 46%|████▋     | 93/200 [28:43<32:36, 18.28s/it]

[Epoch 93] Train Loss: 1.2105


 47%|████▋     | 94/200 [29:02<32:36, 18.45s/it]

[Epoch 94] Train Loss: 1.2228


 48%|████▊     | 95/200 [29:21<32:28, 18.56s/it]

[Epoch 95] Train Loss: 1.2046


 48%|████▊     | 96/200 [29:39<31:54, 18.41s/it]

[Epoch 96] Train Loss: 1.1907


 48%|████▊     | 97/200 [29:58<31:56, 18.61s/it]

[Epoch 97] Train Loss: 1.2015


 49%|████▉     | 98/200 [30:17<31:33, 18.56s/it]

[Epoch 98] Train Loss: 1.1760


 50%|████▉     | 99/200 [30:36<31:28, 18.70s/it]

[Epoch 99] Train Loss: 1.1720


 50%|█████     | 100/200 [30:54<31:03, 18.64s/it]

[Epoch 100] Train Loss: 1.1655


 50%|█████     | 101/200 [31:13<30:59, 18.79s/it]

[Epoch 101] Train Loss: 1.1616


 51%|█████     | 102/200 [31:32<30:29, 18.67s/it]

[Epoch 102] Train Loss: 1.1333
⚠️ LR reduced at epoch 102 → 2.6e-05


 52%|█████▏    | 103/200 [31:50<29:46, 18.42s/it]

[Epoch 103] Train Loss: 1.1207


 52%|█████▏    | 104/200 [32:08<29:32, 18.47s/it]

[Epoch 104] Train Loss: 1.1119


 52%|█████▎    | 105/200 [32:26<29:01, 18.33s/it]

[Epoch 105] Train Loss: 1.0996


 53%|█████▎    | 106/200 [32:44<28:43, 18.34s/it]

[Epoch 106] Train Loss: 1.1197


 54%|█████▎    | 107/200 [33:03<28:20, 18.29s/it]

[Epoch 107] Train Loss: 1.0909


 54%|█████▍    | 108/200 [33:21<28:01, 18.27s/it]

[Epoch 108] Train Loss: 1.0799


 55%|█████▍    | 109/200 [33:39<27:45, 18.30s/it]

[Epoch 109] Train Loss: 1.0716


 55%|█████▌    | 110/200 [33:57<27:06, 18.07s/it]

[Epoch 110] Train Loss: 1.0754


 56%|█████▌    | 111/200 [34:15<26:53, 18.12s/it]

[Epoch 111] Train Loss: 1.1003


 56%|█████▌    | 112/200 [34:33<26:28, 18.06s/it]

[Epoch 112] Train Loss: 1.0656


 56%|█████▋    | 113/200 [34:52<26:25, 18.23s/it]

[Epoch 113] Train Loss: 1.0732


 57%|█████▋    | 114/200 [35:10<26:03, 18.18s/it]

[Epoch 114] Train Loss: 1.0474


 57%|█████▊    | 115/200 [35:28<26:02, 18.38s/it]

[Epoch 115] Train Loss: 1.0467


 58%|█████▊    | 116/200 [35:47<25:44, 18.39s/it]

[Epoch 116] Train Loss: 1.0648


 58%|█████▊    | 117/200 [36:05<25:14, 18.25s/it]

[Epoch 117] Train Loss: 1.0544


 59%|█████▉    | 118/200 [36:24<25:10, 18.42s/it]

[Epoch 118] Train Loss: 1.0364
⚠️ LR reduced at epoch 118 → 2.1e-05


 60%|█████▉    | 119/200 [36:42<24:50, 18.40s/it]

[Epoch 119] Train Loss: 1.0395


 60%|██████    | 120/200 [37:01<24:55, 18.69s/it]

[Epoch 120] Train Loss: 1.0377


 60%|██████    | 121/200 [37:19<24:21, 18.50s/it]

[Epoch 121] Train Loss: 1.0376


 61%|██████    | 122/200 [37:37<23:49, 18.33s/it]

[Epoch 122] Train Loss: 1.0283


 62%|██████▏   | 123/200 [37:56<23:45, 18.51s/it]

[Epoch 123] Train Loss: 1.0113


 62%|██████▏   | 124/200 [38:14<23:15, 18.36s/it]

[Epoch 124] Train Loss: 1.0325


 62%|██████▎   | 125/200 [38:33<23:00, 18.40s/it]

[Epoch 125] Train Loss: 1.0214


 63%|██████▎   | 126/200 [38:51<22:34, 18.31s/it]

[Epoch 126] Train Loss: 1.0167


 64%|██████▎   | 127/200 [39:10<22:29, 18.48s/it]

[Epoch 127] Train Loss: 1.0213


 64%|██████▍   | 128/200 [39:29<22:22, 18.64s/it]

[Epoch 128] Train Loss: 1.0169


 64%|██████▍   | 129/200 [39:47<22:00, 18.60s/it]

[Epoch 129] Train Loss: 1.0087


 65%|██████▌   | 130/200 [40:06<21:54, 18.78s/it]

[Epoch 130] Train Loss: 1.0319


 66%|██████▌   | 131/200 [40:24<21:19, 18.55s/it]

[Epoch 131] Train Loss: 1.0080


 66%|██████▌   | 132/200 [40:43<21:00, 18.54s/it]

[Epoch 132] Train Loss: 1.0384


 66%|██████▋   | 133/200 [41:01<20:30, 18.37s/it]

[Epoch 133] Train Loss: 1.0154


 67%|██████▋   | 134/200 [41:20<20:21, 18.51s/it]

[Epoch 134] Train Loss: 1.0159
⚠️ LR reduced at epoch 134 → 1.7e-05


 68%|██████▊   | 135/200 [41:39<20:07, 18.58s/it]

[Epoch 135] Train Loss: 1.0019


 68%|██████▊   | 136/200 [41:58<19:56, 18.69s/it]

[Epoch 136] Train Loss: 1.0076


 68%|██████▊   | 137/200 [42:16<19:42, 18.76s/it]

[Epoch 137] Train Loss: 1.0004


 69%|██████▉   | 138/200 [42:35<19:11, 18.58s/it]

[Epoch 138] Train Loss: 1.0067


 70%|██████▉   | 139/200 [42:53<18:55, 18.61s/it]

[Epoch 139] Train Loss: 1.0184


 70%|███████   | 140/200 [43:11<18:27, 18.46s/it]

[Epoch 140] Train Loss: 1.0073


 70%|███████   | 141/200 [43:30<18:13, 18.53s/it]

[Epoch 141] Train Loss: 0.9919


 71%|███████   | 142/200 [43:49<17:54, 18.53s/it]

[Epoch 142] Train Loss: 0.9962


 72%|███████▏  | 143/200 [44:07<17:30, 18.43s/it]

[Epoch 143] Train Loss: 0.9823


 72%|███████▏  | 144/200 [44:26<17:18, 18.54s/it]

[Epoch 144] Train Loss: 0.9969


 72%|███████▎  | 145/200 [44:44<16:55, 18.46s/it]

[Epoch 145] Train Loss: 1.0032


 73%|███████▎  | 146/200 [45:02<16:38, 18.49s/it]

[Epoch 146] Train Loss: 1.0165


 74%|███████▎  | 147/200 [45:21<16:14, 18.39s/it]

[Epoch 147] Train Loss: 1.0121


 74%|███████▍  | 148/200 [45:40<16:04, 18.55s/it]

[Epoch 148] Train Loss: 0.9892


 74%|███████▍  | 149/200 [45:58<15:45, 18.54s/it]

[Epoch 149] Train Loss: 0.9980


 75%|███████▌  | 150/200 [46:16<15:16, 18.33s/it]

[Epoch 150] Train Loss: 1.0080
⚠️ LR reduced at epoch 150 → 1.3e-05


 76%|███████▌  | 151/200 [46:34<14:59, 18.36s/it]

[Epoch 151] Train Loss: 0.9854


 76%|███████▌  | 152/200 [46:52<14:35, 18.24s/it]

[Epoch 152] Train Loss: 0.9849


 76%|███████▋  | 153/200 [47:11<14:21, 18.34s/it]

[Epoch 153] Train Loss: 1.0054


 77%|███████▋  | 154/200 [47:29<14:00, 18.27s/it]

[Epoch 154] Train Loss: 0.9889


 78%|███████▊  | 155/200 [47:47<13:46, 18.36s/it]

[Epoch 155] Train Loss: 0.9830


 78%|███████▊  | 156/200 [48:06<13:30, 18.42s/it]

[Epoch 156] Train Loss: 0.9988


 78%|███████▊  | 157/200 [48:24<13:06, 18.30s/it]

[Epoch 157] Train Loss: 0.9894


 79%|███████▉  | 158/200 [48:43<12:54, 18.44s/it]

[Epoch 158] Train Loss: 1.0053


 80%|███████▉  | 159/200 [49:01<12:32, 18.36s/it]

[Epoch 159] Train Loss: 0.9812


 80%|████████  | 160/200 [49:20<12:18, 18.46s/it]

[Epoch 160] Train Loss: 0.9994


 80%|████████  | 161/200 [49:38<11:57, 18.40s/it]

[Epoch 161] Train Loss: 0.9980


 81%|████████  | 162/200 [49:57<11:45, 18.57s/it]

[Epoch 162] Train Loss: 0.9860


 82%|████████▏ | 163/200 [50:16<11:27, 18.59s/it]

[Epoch 163] Train Loss: 0.9939


 82%|████████▏ | 164/200 [50:34<11:05, 18.50s/it]

[Epoch 164] Train Loss: 0.9817


 82%|████████▎ | 165/200 [50:53<10:50, 18.57s/it]

[Epoch 165] Train Loss: 0.9814


 83%|████████▎ | 166/200 [51:11<10:25, 18.40s/it]

[Epoch 166] Train Loss: 0.9957
⚠️ LR reduced at epoch 166 → 1.1e-05


 84%|████████▎ | 167/200 [51:30<10:14, 18.63s/it]

[Epoch 167] Train Loss: 0.9918


 84%|████████▍ | 168/200 [51:48<09:56, 18.65s/it]

[Epoch 168] Train Loss: 0.9825


 84%|████████▍ | 169/200 [52:07<09:33, 18.51s/it]

[Epoch 169] Train Loss: 0.9769


 85%|████████▌ | 170/200 [52:25<09:17, 18.57s/it]

[Epoch 170] Train Loss: 0.9868


 86%|████████▌ | 171/200 [52:44<08:56, 18.49s/it]

[Epoch 171] Train Loss: 0.9859


 86%|████████▌ | 172/200 [53:03<08:41, 18.63s/it]

[Epoch 172] Train Loss: 0.9801


 86%|████████▋ | 173/200 [53:21<08:20, 18.52s/it]

[Epoch 173] Train Loss: 0.9880


 87%|████████▋ | 174/200 [53:40<08:03, 18.59s/it]

[Epoch 174] Train Loss: 0.9760


 88%|████████▊ | 175/200 [53:59<07:50, 18.84s/it]

[Epoch 175] Train Loss: 0.9775


 88%|████████▊ | 176/200 [54:17<07:27, 18.63s/it]

[Epoch 176] Train Loss: 0.9786


 88%|████████▊ | 177/200 [54:36<07:08, 18.65s/it]

[Epoch 177] Train Loss: 0.9696


 89%|████████▉ | 178/200 [54:54<06:47, 18.51s/it]

[Epoch 178] Train Loss: 0.9815


 90%|████████▉ | 179/200 [55:13<06:30, 18.58s/it]

[Epoch 179] Train Loss: 1.0224


 90%|█████████ | 180/200 [55:32<06:14, 18.73s/it]

[Epoch 180] Train Loss: 0.9836


 90%|█████████ | 181/200 [55:51<05:55, 18.70s/it]

[Epoch 181] Train Loss: 0.9831


 91%|█████████ | 182/200 [56:10<05:39, 18.85s/it]

[Epoch 182] Train Loss: 0.9817
⚠️ LR reduced at epoch 182 → 8.6e-06


 92%|█████████▏| 183/200 [56:28<05:17, 18.69s/it]

[Epoch 183] Train Loss: 0.9913


 92%|█████████▏| 184/200 [56:47<04:58, 18.65s/it]

[Epoch 184] Train Loss: 0.9780


 92%|█████████▎| 185/200 [57:05<04:38, 18.54s/it]

[Epoch 185] Train Loss: 0.9803


 93%|█████████▎| 186/200 [57:24<04:20, 18.59s/it]

[Epoch 186] Train Loss: 0.9836


 94%|█████████▎| 187/200 [57:42<04:01, 18.56s/it]

[Epoch 187] Train Loss: 0.9768


 94%|█████████▍| 188/200 [58:00<03:40, 18.36s/it]

[Epoch 188] Train Loss: 0.9792


 94%|█████████▍| 189/200 [58:18<03:21, 18.34s/it]

[Epoch 189] Train Loss: 0.9770


 95%|█████████▌| 190/200 [58:36<03:02, 18.28s/it]

[Epoch 190] Train Loss: 0.9713


 96%|█████████▌| 191/200 [58:55<02:45, 18.43s/it]

[Epoch 191] Train Loss: 0.9699


 96%|█████████▌| 192/200 [59:14<02:27, 18.42s/it]

[Epoch 192] Train Loss: 0.9810


 96%|█████████▋| 193/200 [59:32<02:09, 18.53s/it]

[Epoch 193] Train Loss: 0.9741


 97%|█████████▋| 194/200 [59:51<01:50, 18.49s/it]

[Epoch 194] Train Loss: 0.9742


 98%|█████████▊| 195/200 [1:00:09<01:31, 18.34s/it]

[Epoch 195] Train Loss: 0.9824


 98%|█████████▊| 196/200 [1:00:27<01:13, 18.46s/it]

[Epoch 196] Train Loss: 0.9742


 98%|█████████▊| 197/200 [1:00:46<00:55, 18.35s/it]

[Epoch 197] Train Loss: 0.9754


 99%|█████████▉| 198/200 [1:01:04<00:36, 18.46s/it]

[Epoch 198] Train Loss: 0.9810
⚠️ LR reduced at epoch 198 → 6.9e-06


100%|█████████▉| 199/200 [1:01:23<00:18, 18.54s/it]

[Epoch 199] Train Loss: 0.9747


100%|██████████| 200/200 [1:01:43<00:00, 18.52s/it]

[Epoch 200] Train Loss: 0.9738





In [47]:
torch.save(vf.state_dict(), "vf_model.pth")

In [50]:
root_dir_test = "./test"

def standardize_keys(d):
    return {
        os.path.join(os.path.basename(os.path.dirname(k)), os.path.basename(k)): v
        for k, v in d.items()
    }

# 1. Load embeddings
embeddings_test = load_images_recursively(root_dir_test)
print(f" Total images loaded: {len(embeddings_test)}")

# 2. Load cls_map
cls_map_test = create_cls_map(root_dir_test)
print(f" Created cls_map for {len(cls_map_test)} images")

# 3.  Standardize keys
embeddings_test = standardize_keys(embeddings_test)
cls_map_test = standardize_keys(cls_map_test)



 Total images loaded: 284
 Created cls_map for 284 images


In [51]:
recall_val = compute_recall_at_1_embedded(embeddings_test, cls_map_test, type_map, vf, device=device)

In [52]:
recall_val

0.954225352112676