In [1]:
import torch
import numpy as np
from tqdm import tqdm
import os
from sklearn.metrics.pairwise import cosine_similarity
from model import SGRAF
from data import get_dataset, get_loader
from vocab import deserialize_vocab
import argparse
import json

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def extract_image_embeddings(model_path, opt, save_path):
    """Stage 2 Step 1: run a trained SGRAF image encoder over the training loader.

    Rebuilds SGRAF from ``opt``, loads the ``"model_A"`` weights from the
    checkpoint, encodes every batch yielded by the "warmup" training loader
    with ``model.img_enc`` and saves the stacked result to ``save_path``.

    Args:
        model_path: checkpoint file holding a ``"model_A"`` state dict.
        opt: training options (argparse.Namespace). Reads ``vocab_path``,
            ``data_path``, ``data_name``, ``batch_size``, ``workers``,
            ``noise_ratio`` and ``noise_file``; sets ``opt.vocab_size`` as a
            side effect because SGRAF needs it and the saved config may lack it.
        save_path: destination of the ``np.save`` output.

    Returns:
        np.ndarray: ``img_enc`` outputs concatenated along axis 0, one row per
        loader sample, in the order the loader yielded them.
    """
    print(f"Loading model from {model_path}")
    checkpoint = torch.load(model_path)

    # Patch opt.vocab_size, which SGRAF needs but the saved config may not contain.
    vocab = deserialize_vocab(os.path.join(opt.vocab_path, f"{opt.data_name}_vocab.json"))
    opt.vocab_size = len(vocab)

    # Rebuild the model and load the first network's weights ("model_B" is the alternative).
    model = SGRAF(opt)
    model.load_state_dict(checkpoint["model_A"])
    model.val_start()

    # Load the raw training split (captions + precomputed image features) that the
    # loader wraps. NOTE(review): get_dataset(data_path, data_name, split, vocab)
    # argument order is assumed from the data module — confirm.
    captions, images = get_dataset(opt.data_path, opt.data_name, "train", vocab)
    loader, _, _ = get_loader(
        captions, images, "warmup",
        opt.batch_size, opt.workers,
        opt.noise_ratio, opt.noise_file,
    )

    all_img_embs = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Extracting image embeddings"):
            # Separate name so the dataset array `images` above is not shadowed.
            img_batch = batch[0].cuda()  # batch = (images, captions, ...)
            img_embs = model.img_enc(img_batch)  # e.g. (B, 36, 1024)
            all_img_embs.append(img_embs.cpu().numpy())

    # NOTE(review): rows follow the loader's iteration order. If the "warmup"
    # loader shuffles, row i is NOT dataset sample i — confirm before using these
    # rows to index `captions`.
    img_embs = np.concatenate(all_img_embs, axis=0)
    np.save(save_path, img_embs)
    print(f"Saved image embeddings to {save_path}")

    return img_embs

In [3]:
def build_batches_from_embeddings(img_embs, batch_size=128, rng=None):
    """Stage 2 Step 2: group samples into batches of mutually similar images.

    Greedy procedure: pick a random anchor, put the ``batch_size`` not-yet-used
    samples most cosine-similar to it into one batch (most similar first), then
    choose the next anchor as the unused sample least similar to the previous
    anchor. This repeats until every sample is used. The returned batches
    therefore partition ``range(len(img_embs))``: no index appears twice, and
    only the last batch may be shorter than ``batch_size``.

    Args:
        img_embs: array of shape (N, D), or (N, ..., D) such as per-region
            features (N, R, D). Any middle axes are mean-pooled into a single
            D-dim vector per sample before comparing.
        batch_size: maximum number of indices per batch; must be >= 1.
        rng: optional ``np.random.Generator`` used to pick the first anchor.
            When None, the global ``np.random`` state is used.

    Returns:
        list of 1-D integer index arrays, one per batch (empty for N == 0).

    Raises:
        ValueError: if ``batch_size`` < 1 or ``img_embs`` is not at least 2-D.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    embs = np.asarray(img_embs)
    n = len(embs)
    batches = []

    if n > 0:
        if embs.ndim < 2:
            raise ValueError(f"img_embs must be at least 2-D (N, D), got shape {embs.shape}")
        if embs.ndim > 2:
            # cosine similarity needs one vector per sample: average the region axis.
            embs = embs.reshape(n, -1, embs.shape[-1]).mean(axis=1)

        # L2-normalise once so a dot product equals cosine similarity; zero
        # vectors stay zero (same convention as sklearn's cosine_similarity).
        norms = np.linalg.norm(embs, axis=1, keepdims=True)
        unit = embs / np.where(norms == 0, 1, norms)

        used = np.zeros(n, dtype=bool)
        anchor_idx = rng.choice(n) if rng is not None else np.random.choice(n)
        while True:
            # One anchor row at a time instead of an N x N matrix (O(N*D) memory).
            anchor_sim = unit @ unit[anchor_idx]

            # Rank only unused samples, so a short final batch never re-takes
            # indices that already belong to an earlier batch.
            candidates = np.flatnonzero(~used)
            order = np.argsort(-anchor_sim[candidates], kind="stable")
            batch = candidates[order[:batch_size]]
            batches.append(batch)
            used[batch] = True

            # Next anchor: the unused sample least similar to the current anchor.
            remaining = np.flatnonzero(~used)
            if remaining.size == 0:
                break
            anchor_idx = remaining[np.argmin(anchor_sim[remaining])]

    print(f"Built {len(batches)} batches of size {batch_size}")
    return batches

In [4]:
def detect_noisy_pairs(model_path, opt, batch_indices, captions):
    """Stage 2 Step 3: flag likely-noisy samples within each similarity batch.

    For every batch of dataset indices, the corresponding captions are encoded
    with ``model.txt_enc``. Each caption's token embeddings are averaged over
    its real (non-padded) positions, and the pairwise cosine similarities
    between the batch's captions are computed. Captions whose mean similarity
    to the batch falls below the batch median are flagged. Only
    caption-caption similarity is used here; image features are not consulted.

    Args:
        model_path: checkpoint file holding a ``"model_A"`` state dict.
        opt: training options; ``opt.vocab_size`` is set from the vocabulary
            file so SGRAF can be rebuilt without relying on earlier cells.
        batch_indices: iterable of index sequences, e.g. the output of
            ``build_batches_from_embeddings``.
        captions: indexable collection where ``captions[i]`` is the token-id
            sequence for dataset index ``i``.

    Returns:
        list: one list per input batch holding the flagged dataset indices, in
        their original within-batch order (empty for an empty batch).
    """
    checkpoint = torch.load(model_path)

    # SGRAF(opt) needs vocab_size; set it here too so this function works even
    # if extract_image_embeddings has not patched `opt` in this session.
    vocab = deserialize_vocab(os.path.join(opt.vocab_path, f"{opt.data_name}_vocab.json"))
    opt.vocab_size = len(vocab)

    model = SGRAF(opt)
    model.load_state_dict(checkpoint["model_A"])
    model.val_start()
    model.txt_enc.eval()

    noisy_indices = []
    for batch in tqdm(batch_indices, desc="Detecting noisy pairs"):
        if len(batch) == 0:
            noisy_indices.append([])
            continue

        # Encode longest caption first: torch's pack_padded_sequence requires
        # decreasing lengths by default (NOTE(review): confirm txt_enc packs).
        # order[j] is the within-batch position of the j-th encoded caption.
        order = sorted(range(len(batch)), key=lambda k: len(captions[batch[k]]), reverse=True)
        sorted_caps = [captions[batch[k]] for k in order]
        lengths = [len(cap) for cap in sorted_caps]
        cap_tensor = torch.zeros(len(sorted_caps), max(lengths), dtype=torch.long)
        for row, cap in enumerate(sorted_caps):
            cap_tensor[row, :len(cap)] = torch.as_tensor(cap, dtype=torch.long)
        cap_tensor = cap_tensor.cuda()

        with torch.no_grad():
            cap_emb = model.txt_enc(cap_tensor, lengths)  # presumably (B, L, D)

        # Average over real tokens only so padding does not dilute short captions.
        steps = torch.arange(cap_emb.size(1), device=cap_emb.device)
        lens = torch.as_tensor(lengths, device=cap_emb.device)
        mask = (steps[None, :] < lens[:, None]).unsqueeze(-1).to(cap_emb.dtype)
        cap_avg_sorted = (cap_emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        # Undo the length sort so row k again belongs to batch[k].
        cap_avg = torch.empty_like(cap_avg_sorted)
        cap_avg[torch.as_tensor(order, device=cap_avg.device)] = cap_avg_sorted

        sim = cosine_similarity(cap_avg.cpu().numpy())

        # Flag captions whose mean similarity to the batch is below the batch median.
        avg_sim = sim.mean(axis=1)
        threshold = np.median(avg_sim)
        noisy = np.where(avg_sim < threshold)[0]
        noisy_indices.append([batch[i] for i in noisy])

    return noisy_indices

In [5]:
# Load the run's saved config.json into a Namespace so it can be passed
# wherever `opt` is expected (SGRAF(opt), the loader/vocab path arguments).
with open('./output/2025_04_07_15_16_52/config.json', 'r') as config_file:
    config = json.load(config_file)

opt = argparse.Namespace()
vars(opt).update(config)

In [6]:
# Checkpoint to encode with, and the .npy file the embeddings are written to.
img_emb_save_path = "./train_img_embs.npy"
model_path = "./output/2025_04_07_15_16_52/model_best.pth.tar"

img_embs = extract_image_embeddings(
    model_path=model_path,
    opt=opt,
    save_path=img_emb_save_path,
)

Loading model from ./output/2025_04_07_15_16_52/model_best.pth.tar
load /home/capstone_nc/NCR-data/data/f30k_precomp / train data: 29000 images, 145000 captions
=> load noisy index from output/2025_04_07_15_16_52/f30k_precomp_0.2.npy
train  data has a size of 145000


Extracting image embeddings: 100%|██████████| 1133/1133 [00:47<00:00, 23.81it/s]


Saved image embeddings to ./train_img_embs.npy


In [None]:
# Shape of a single sample's embedding.
np.shape(img_embs[0])

(36, 1024)

In [None]:
# Group samples into batches of visually similar images, 128 per batch.
batches_first = build_batches_from_embeddings(img_embs, 128)