# Generate Embeddings with BAAI/BGE

In [None]:
import timm
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import numpy as np
import json
def make_descriptor_sentence(descriptor: str) -> str:
    """Turn a bare descriptor phrase into a relative clause.

    The clause prefix is selected from the descriptor's first
    whitespace-separated word:

    * ``a`` / ``an`` / ``used`` -> ``"which is <descriptor>"``
    * ``has`` / ``often`` / ``typically`` / ``may`` / ``can``
      -> ``"which <descriptor>"``
    * anything else, including an empty descriptor
      -> ``"which has <descriptor>"``

    Args:
        descriptor: The descriptor phrase, e.g. ``"a long tail"``.

    Returns:
        The descriptor embedded in a clause starting with ``"which"``.
    """
    # Compare the whole first word rather than a raw character prefix:
    # a bare ``startswith('a')`` also matches "antlers" or "angular shape",
    # and ``startswith('can')`` matches "canvas", producing wrong clauses.
    words = descriptor.split(maxsplit=1)
    first_word = words[0] if words else ""

    if first_word in ("a", "an", "used"):
        return f"which is {descriptor}"
    if first_word in ("has", "often", "typically", "may", "can"):
        return f"which {descriptor}"
    return f"which has {descriptor}"

In [None]:
# ADE20K serves as the worked example here.
# To build embeddings for another dataset, change `dataset_name` below
# and pick a cluster count that suits that dataset.
dataset_name = 'ade20k'

# The descriptor file is a JSON object keyed by class name.
descriptor_path = f"./descriptors/descriptors_{dataset_name}_gpt3.5_cluster.json"
with open(descriptor_path) as json_file:
    descriptions = json.load(json_file)

# One top-level key per class.
num_classes = len(descriptions.keys())

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
# Empty container; nothing in the visible notebook populates it.
ade_desc_prompt_template_embeddings = {}

# Pull the tokenizer and the encoder from the HuggingFace Hub.
model_name_prefix = "BAAI/"
model_name = "bge-base-en-v1.5"
hub_model_id = f"{model_name_prefix}{model_name}"

tokenizer = AutoTokenizer.from_pretrained(hub_model_id)
model = AutoModel.from_pretrained(hub_model_id)
model = model.to(device)
# Put the encoder in evaluation mode before computing embeddings.
model.eval()

def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by the mask.

    Args:
        model_output: Indexable model output whose element ``0`` holds the
            per-token embeddings.
        attention_mask: Mask with one entry per token; it is broadcast over
            the embedding dimension and used as averaging weights.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total.
        The divisor is clamped to at least ``1e-9`` so a fully masked row
        yields zeros instead of a division by zero.
    """
    token_embeddings = model_output[0]
    # Stretch the mask to the embedding shape so it can weight every feature.
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    weighted_sum = (token_embeddings * weights).sum(1)
    weight_total = weights.sum(1).clamp(min=1e-9)
    return weighted_sum / weight_total


# Embed every class's descriptors; one tensor per class, in dict order.
all_descriptions_embeddings = []
for desc in descriptions.values():
    sentences = [sentence.lower() for sentence in desc]

    # Tokenize the whole class as one padded batch and move it to the device.
    encoded_input = tokenizer(
        sentences,
        padding=True,
        truncation=True,
        return_tensors='pt',
    ).to(device)

    # Inference only, so skip gradient tracking.
    with torch.no_grad():
        model_output = model(**encoded_input)
        pooled = mean_pooling(model_output, encoded_input['attention_mask'])
        # L2-normalize each sentence vector, then bring it back to the CPU.
        sentence_embeddings = torch.nn.functional.normalize(pooled, p=2, dim=1).cpu()

    all_descriptions_embeddings.append(sentence_embeddings)

# Stack the per-class tensors row-wise into one matrix of all descriptors.
all_descriptions_embeddings_tensor = torch.cat(all_descriptions_embeddings, dim=0)


In [None]:
from sklearn.cluster import KMeans

# The cluster count is dataset dependent; 256 is the value used for ADE20K.
n_clusters = 256

# Cluster all descriptor embeddings at once (fixed seed for repeatability).
kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto")
kmeans.fit(all_descriptions_embeddings_tensor.numpy())

# For each descriptor row, record the index of the class it came from.
desc_class_idx_list = torch.cat([
    torch.tensor([class_idx] * len(class_desc))
    for class_idx, class_desc in enumerate(all_descriptions_embeddings)
])

# Build, per class, the set of clusters its descriptors fall into and the
# matching multi-hot vector of length n_clusters.
ground_truth_all_classes = []
activated_clusters_all_classes = []
for class_idx in range(num_classes):
    belongs_to_class = desc_class_idx_list == class_idx
    activated_clusters = np.unique(kmeans.labels_[belongs_to_class])
    print(activated_clusters)

    ground_truth = np.zeros(n_clusters)
    ground_truth[activated_clusters] = 1

    activated_clusters_all_classes.append(activated_clusters)
    ground_truth_all_classes.append(ground_truth)

In [None]:
# Report classes whose multi-hot cluster vectors are identical: such classes
# cannot be told apart from their cluster labels alone.
label_matrix = np.array(ground_truth_all_classes)
unique_label, inverse_idx, unique_count = np.unique(
    label_matrix, axis=0, return_inverse=True, return_counts=True
)
print(unique_count)
if not (unique_count == 1).all():
    # `inverse_idx` maps every class row to its row in `unique_label`, so
    # `unique_count[inverse_idx]` is the number of classes sharing that
    # class's pattern. This handles any number of duplicated patterns,
    # whereas comparing `label_matrix` against `unique_label[unique_count>1]`
    # only broadcasts when exactly one pattern is duplicated.
    shared_by = unique_count[inverse_idx.reshape(-1)]
    confused_labels = np.where(shared_by > 1)
    class_names = tuple(descriptions.keys())
    for class_idx in confused_labels[0]:
        print('----------------')
        print(class_names[class_idx])
        print(descriptions[class_names[class_idx]])
        print('----------------')

# Swap the two axes of the stacked per-class label vectors so that clusters
# come first and classes second.
label_matrix_by_class = np.array(ground_truth_all_classes)
ground_truth_all_classes = torch.tensor(np.transpose(label_matrix_by_class, (1, 0)))

# Same axis swap for the KMeans centroids before turning them into a tensor.
cluster_embedding_bank = torch.tensor(np.transpose(kmeans.cluster_centers_, (1, 0)))

# The bank pairs the centroid matrix with the cluster-to-class label matrix.
cluster_bank = [cluster_embedding_bank, ground_truth_all_classes]

# Bundle the raw descriptions together with the bank under a descriptive key.
bank_key = f'{model_name}_gpt3.5_cluster_{n_clusters}_embeddings_and_labels'
cluster_dict = {
    'descriptions': descriptions,
    bank_key: cluster_bank,
}

In [None]:
import os

# Persist the descriptions and the cluster bank. The target directory is
# created first so the save does not fail on a fresh checkout after all the
# embedding and clustering work has already been done.
output_path = f'./embeddings/{dataset_name}_desc_{model_name}_gpt3.5_cluster_{n_clusters}_embedding_bank.pth'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torch.save(cluster_dict, output_path)