In [1]:
import torch, clip
import os.path as osp
import numpy as np
import faiss
import pandas as pd
import braceexpand
import webdataset as wds
from utils.myutils import extract_words
from gensim.models import Word2Vec, KeyedVectors
from sklearn.cluster import KMeans
clip_ckpts = {
        'clip-vit-b-32': 'ViT-B/32',
        'clip-vit-b-16': 'ViT-B/16',
        'clip-vit-l-14': 'ViT-L/14',
    }

model_name = 'clip-vit-b-16'
device = 'cuda:5'
model, _ = clip.load(clip_ckpts[model_name], device=device)



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
path = "/workspace/Dataset/local_data/clip-vit-b-16_8_c3_shards/"
prefix = "c3-000000.tar"
tar_file = osp.join(path, prefix)

dataset = (
    wds.WebDataset(tar_file, repeat=True)
    .decode('pil')
    .rename(image='jpg;png;jpeg', text='text;txt', keep=False,)
)


In [3]:
idx = 1024
keyword_list = []
for i, data in enumerate(dataset):
    img = data['image']
    text, nouns, keywords = extract_words(data['text'])
    keyword_list += keywords
    if i == idx:
        break
    
keyword_list = list(set(keyword_list))



In [4]:
text_tokens = clip.tokenize(keyword_list)
print()
text_embs = model.encode_text(text_tokens.to(device))

<class 'torch.Tensor'>


In [10]:
text_embs.dtype

torch.float16

: 

In [5]:
K = 16
kmeans = KMeans(n_clusters=K, max_iter=100).fit(text_embs.cpu().detach().numpy())

In [6]:
kmeans_archive = {}

labels = kmeans.labels_
for i in range(K):
    kmeans_archive[str(i)] = []
    
for i, label in enumerate(labels):
    kmeans_archive[str(label)].append(keyword_list[i])

In [7]:
cluster_centers = kmeans.cluster_centers_

# 각 데이터 포인트에서 가장 가까운 클러스터 중심까지의 거리 계산
distances = np.sqrt(((text_embs.cpu().detach().numpy() - cluster_centers[:, np.newaxis])**2).sum(axis=2))

# 각 클러스터에 대해 가장 중심에 가까운 데이터의 인덱스 찾기
closest_data_points = np.argmin(distances, axis=1)


In [8]:
keyword = [keyword_list[i] for i in closest_data_points]

keyword

['box',
 'skeleton',
 'eagle',
 'mountain',
 'ring',
 'tree',
 'road',
 'flowers',
 'table',
 'kids',
 'van',
 'feet',
 'plants',
 'football player',
 'food',
 'boxes']