In [1]:
!pip install sentence-transformers --quiet

In [2]:
from transformers import AutoTokenizer, AutoModel
from torch import Tensor
import torch
from tqdm import tqdm
import torch.nn.functional as F


  from .autonotebook import tqdm as notebook_tqdm


## Load tag descriptions

In [5]:
# Snapshot identifier of the tag-description data; it is interpolated into both the
# input JSON path and the output pickle path. Looks like a YYYYMMDD date — confirm.
TAG_VERSION: str = "20240726"

In [6]:
import json

# Tag descriptions for this snapshot. Decode explicitly as UTF-8 (the JSON
# interchange encoding) so loading does not depend on the OS default locale.
# Structure as used below: tag_desc[category][tag] -> [description, sub_category, ...]
tag_desc_path = f"../data/tag_desc_{TAG_VERSION}.json"

with open(tag_desc_path, 'r', encoding='utf-8') as f:
    tag_desc = json.load(f)

In [7]:
# Top-level keys of the description file are the categories to encode.
category = [key for key in tag_desc]
category

['ACCOMMODATION',
 'DINING',
 'EXPERIENCE',
 'ACCOMMODATION_TUI',
 'ACCOMMODATION_IDS']

## Encode tag descriptions

In [12]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions where ``attention_mask`` is set.

    Args:
        model_output: transformer forward output; element ``[0]`` holds the
            per-token embeddings (presumably ``(batch, seq, hidden)`` — TODO confirm).
        attention_mask: mask with the token embeddings' shape minus the last
            axis (presumably ``(batch, seq)``); padding positions are 0.

    Returns:
        Tensor of pooled embeddings, summed over axis 1 and divided by the
        number of unmasked tokens.
    """
    tokens = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
    summed = (tokens * mask).sum(dim=1)
    # The clamp keeps an all-padding row from dividing by zero.
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def model_fn(model_dir):
    """Load a Hugging Face encoder and its matching tokenizer.

    Args:
        model_dir: anything ``from_pretrained`` accepts — a local directory or
            a Hub model id (this notebook passes a Hub id).

    Returns:
        ``(model, tokenizer)`` — note the model comes first.
    """
    tok = AutoTokenizer.from_pretrained(model_dir)
    encoder = AutoModel.from_pretrained(model_dir)
    return (encoder, tok)


def encode(model, tokenizer, desc):
    """Embed text as L2-normalised, mean-pooled vectors.

    Args:
        model: Hugging Face encoder (e.g. from ``model_fn``). If it exposes a
            ``device`` attribute, the tokenized inputs are moved there, so this
            also works after the model has been placed on a GPU.
        tokenizer: tokenizer matching ``model``.
        desc: text to embed — a single string (as used in this notebook) or a
            list of strings.

    Returns:
        Tensor with one unit-norm embedding row per input text, on the same
        device as the model output.
    """
    encoded_input = tokenizer(desc, padding=True, truncation=True, return_tensors='pt')
    # Keep inputs next to the model; .to() is a no-op for the CPU-only setup here.
    device = getattr(model, "device", None)
    if device is not None:
        encoded_input = {name: tensor.to(device) for name, tensor in encoded_input.items()}
    with torch.no_grad():
        model_output = model(**encoded_input)
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    # Normalize so dot products between embeddings equal cosine similarity.
    return F.normalize(sentence_embeddings, p=2, dim=1)

In [13]:
# Hub id of the sentence-embedding model. Vectors produced from it are stored under
# the "miniLM-L12-v2" key in the encoding loop — keep the two in sync if this changes.
MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2"
model, tokenizer = model_fn(MODEL_NAME)

In [14]:
# Result layout: tag_emb[category][sub_category][tag] ->
#     {'description': str, "miniLM-L12-v2": CPU tensor embedding}
tag_emb = {}

for cate in category:
    tag_emb[cate] = {}
    tags = tag_desc[cate]
    # desc=cate labels each bar so the per-category progress bars are distinguishable.
    for tag, desc_sub_cate in tqdm(tags.items(), total=len(tags), desc=cate):
        if tag == "":
            # tqdm.write prints without corrupting the active progress bar (print does not).
            tqdm.write("skip empty tag")
            continue
        # Each entry is indexed as [description, sub_category, ...].
        desc, sub_cate = desc_sub_cate[0], desc_sub_cate[1]
        tag_emb[cate].setdefault(sub_cate, {})[tag] = {
            'description': desc,
            # encode() returns a batch of one row; keep that row, moved to CPU for pickling.
            "miniLM-L12-v2": encode(model, tokenizer, desc)[0].cpu()
        }

100%|██████████| 6131/6131 [02:04<00:00, 49.44it/s]
100%|██████████| 1551/1551 [00:31<00:00, 48.53it/s]
100%|██████████| 255/255 [00:05<00:00, 49.43it/s]
100%|██████████| 124/124 [00:02<00:00, 56.19it/s]
100%|██████████| 361/361 [00:07<00:00, 49.83it/s]


In [None]:
import pickle

# Persist the nested embedding dict (embeddings already moved to CPU) for this tag version.
with open(f'../data/tag_emb_{TAG_VERSION}.pkl', 'wb') as f:
    f.write(pickle.dumps(tag_emb))

### Convert embeddings to JSON

In [19]:
import pickle

# Reload the cached embeddings for this tag version. pickle can execute arbitrary
# code while loading, so only point this at files this notebook produced itself.
with open(f'../data/tag_emb_{TAG_VERSION}.pkl', 'rb') as f:
    tag_emb = pickle.loads(f.read())
    
# import torch
# import json

# def tensor_to_list(obj):
#     if isinstance(obj, dict):
#         return {k: tensor_to_list(v) for k, v in obj.items()}
#     elif isinstance(obj, list):
#         return [tensor_to_list(v) for v in obj]
#     elif isinstance(obj, torch.Tensor):
#         return obj.tolist()
#     else:
#         return obj

# # Convert tensors to lists
# converted_tag_emb = tensor_to_list(tag_emb)

# # Save to a file
# with open('../data/tag_emb.json', 'w') as f:
#     json.dump(converted_tag_emb, f, indent=2)