In [2]:
import torch
import pickle

## Make Local Tag Embedding Index For TopK Search

### Load Tag Embedding

In [3]:
TAG_VERSION = "20240726"

In [6]:
# Load the tag-embedding pickle for the configured TAG_VERSION.
# NOTE: pickle.load can execute arbitrary code - only open files from a trusted source.
tag_emb_path = f"../data/tag_emb_{TAG_VERSION}.pkl"
with open(tag_emb_path, "rb") as tag_emb_file:
    tag_emb = pickle.load(tag_emb_file)

  return torch.load(io.BytesIO(b))


In [7]:
tag_emb.keys()

dict_keys(['ACCOMMODATION', 'DINING', 'EXPERIENCE', 'ACCOMMODATION_TUI', 'ACCOMMODATION_IDS'])

In [8]:
category = tag_emb.keys()

In [9]:
# tag_emb['ACCOMMODATION']['Category']['Luxury']

### Build parallel tag lists and embedding lists

In [10]:
# Key under which each tag's embedding is stored inside tag_emb.
EMB_KEY = 'miniLM-L12-v2'


def _collect_tag_embeddings(subcat_tag_data, split_brand):
    """Flatten one category of ``tag_emb`` into parallel tag / embedding lists.

    Args:
        subcat_tag_data: mapping ``sub_category -> {tag -> {EMB_KEY: embedding}}``.
        split_brand: when True, entries of the "Brand" sub-category are
            collected separately from every other sub-category; when False
            everything goes into the first pair of lists.

    Returns:
        ``(tag_list, emb_list, brand_tag_list, brand_emb_list)``. Each tag list
        holds ``(tag, sub_category)`` tuples and is index-aligned with its
        embedding list. The brand lists are empty when ``split_brand`` is False.
    """
    tag_list, emb_list = [], []
    brand_tag_list, brand_emb_list = [], []
    for sub_cate, tag_data in subcat_tag_data.items():
        to_brand = split_brand and sub_cate == "Brand"
        tags_out = brand_tag_list if to_brand else tag_list
        embs_out = brand_emb_list if to_brand else emb_list
        for tag, data in tag_data.items():
            tags_out.append((tag, sub_cate))
            embs_out.append(data[EMB_KEY])
    return tag_list, emb_list, brand_tag_list, brand_emb_list


# Indexing tag_emb directly makes a missing category fail here with a KeyError
# instead of surfacing later as a NameError on an undefined list.
dining_tag_list, dining_emb_list = _collect_tag_embeddings(
    tag_emb["DINING"], split_brand=False
)[:2]

experience_tag_list, experience_emb_list = _collect_tag_embeddings(
    tag_emb["EXPERIENCE"], split_brand=False
)[:2]

(
    accomm_tag_list,
    accomm_tag_emb_list,
    accomm_brand_tag_list,
    accomm_brand_tag_emb_list,
) = _collect_tag_embeddings(tag_emb["ACCOMMODATION"], split_brand=True)

(
    tui_accomm_tag_list,
    tui_accomm_tag_emb_list,
    tui_accomm_brand_tag_list,
    tui_accomm_brand_tag_emb_list,
) = _collect_tag_embeddings(tag_emb["ACCOMMODATION_TUI"], split_brand=True)

(
    ids_accomm_tag_list,
    ids_accomm_tag_emb_list,
    ids_accomm_brand_tag_list,
    ids_accomm_brand_tag_emb_list,
) = _collect_tag_embeddings(tag_emb["ACCOMMODATION_IDS"], split_brand=True)

In [11]:
# Stack each list of per-tag embeddings along a new leading dimension,
# giving one tensor per index with one row per tag.
dining_emb = torch.stack(dining_emb_list, dim=0)
experience_emb = torch.stack(experience_emb_list, dim=0)
accomm_brand_tag_emb = torch.stack(accomm_brand_tag_emb_list, dim=0)
accomm_tag_emb = torch.stack(accomm_tag_emb_list, dim=0)
tui_accomm_tag_emb = torch.stack(tui_accomm_tag_emb_list, dim=0)
ids_accomm_brand_tag_emb = torch.stack(ids_accomm_brand_tag_emb_list, dim=0)

# Left disabled, as in the original cell. NOTE(review): presumably these two
# lists are empty for the current tag version (torch.stack rejects an empty
# list) - confirm before enabling.
# tui_accomm_brand_tag_emb = torch.stack(tui_accomm_brand_tag_emb_list, dim=0)
# ids_accomm_tag_emb = torch.stack(ids_accomm_tag_emb_list, dim=0)

In [12]:
dining_emb.shape, experience_emb.shape, accomm_brand_tag_emb.shape, accomm_tag_emb.shape, tui_accomm_tag_emb.shape, ids_accomm_brand_tag_emb.shape

(torch.Size([1551, 384]),
 torch.Size([255, 384]),
 torch.Size([5551, 384]),
 torch.Size([580, 384]),
 torch.Size([124, 384]),
 torch.Size([361, 384]))

In [13]:
len(dining_tag_list), len(experience_tag_list), len(accomm_brand_tag_list), len(accomm_tag_list), len(tui_accomm_tag_list), len(ids_accomm_brand_tag_list)

(1551, 255, 5551, 580, 124, 361)

In [14]:
# new added tag
("TUI Blue", "Brand") in accomm_brand_tag_list

False

In [15]:
("Free Bottled Water", "Special Property Features") in tui_accomm_tag_list

True

In [16]:
("Swire Hotels", "Brand") in ids_accomm_brand_tag_list

True

### Save

In [21]:
import os

# Folder that receives one pickle per index; created on first run.
LOCAL_INDEX_FOLDER = "../local_tag_emb_index"
os.makedirs(LOCAL_INDEX_FOLDER, exist_ok=True)

# Index name -> (tag list, stacked embeddings). Entry i of the tag list
# labels row i of the embedding tensor.
index2data = {
    "dining-tag-vector": (dining_tag_list, dining_emb),
    "experience-tag-vector": (experience_tag_list, experience_emb),
    "accommodation-tag-vector": (accomm_tag_list, accomm_tag_emb),
    "accommodation-brand-tag-vector": (accomm_brand_tag_list, accomm_brand_tag_emb),
    "tui-accommodation-tag-vector": (tui_accomm_tag_list, tui_accomm_tag_emb),
    "ids-accommodation-brand-tag-vector": (ids_accomm_brand_tag_list, ids_accomm_brand_tag_emb),
}

# Names of the indexes to save / load, in the insertion order of the dict above.
local_index = list(index2data)

In [22]:
# Write every index to <LOCAL_INDEX_FOLDER>/<index name>.pkl as a dict with
# the keys "tags" and "embeddings".
for index_name in local_index:
    index_tags, index_embeddings = index2data[index_name]
    index_path = os.path.join(LOCAL_INDEX_FOLDER, f"{index_name}.pkl")
    payload = {"tags": index_tags, "embeddings": index_embeddings}
    with open(index_path, 'wb') as index_file:
        pickle.dump(payload, index_file)

In [23]:
def load_index(index):
    """Read ``<LOCAL_INDEX_FOLDER>/<index>.pkl`` and return its tags and embeddings.

    Args:
        index: index name, i.e. the pickle file name without the ``.pkl`` suffix.

    Returns:
        A new dict holding exactly the "tags" and "embeddings" entries of the
        pickled payload.
    """
    index_path = os.path.join(LOCAL_INDEX_FOLDER, f"{index}.pkl")
    # NOTE: pickle.load can execute arbitrary code - only load trusted index files.
    with open(index_path, 'rb') as index_file:
        payload = pickle.load(index_file)
    return {key: payload[key] for key in ("tags", "embeddings")}


# Every index named in local_index, loaded once and kept in memory.
INDEX2DATA = {name: load_index(name) for name in local_index}


  return torch.load(io.BytesIO(b))


In [24]:
INDEX2DATA.keys()

dict_keys(['dining-tag-vector', 'experience-tag-vector', 'accommodation-tag-vector', 'accommodation-brand-tag-vector', 'tui-accommodation-tag-vector', 'ids-accommodation-brand-tag-vector'])

## Test Inference.py Locally

In [25]:
!pip install sentence-transformers --quiet

In [32]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import pickle
import os


# ----------------------------- #
#      Index For TopK Search    #
# ----------------------------- #
LOCAL_INDEX_FOLDER = "../local_tag_emb_index"

local_index = [
    "dining-tag-vector", 
    "experience-tag-vector", 
    "accommodation-tag-vector",
    "accommodation-brand-tag-vector",
    "tui-accommodation-tag-vector",
    "ids-accommodation-brand-tag-vector",
]


def load_index(index):
    """Read ``<LOCAL_INDEX_FOLDER>/<index>.pkl`` and return its tags and embeddings.

    Args:
        index: index name, i.e. the pickle file name without the ``.pkl`` suffix.

    Returns:
        A new dict holding exactly the "tags" and "embeddings" entries of the
        pickled payload.
    """
    index_path = os.path.join(LOCAL_INDEX_FOLDER, f"{index}.pkl")
    # NOTE: pickle.load can execute arbitrary code - only load trusted index files.
    with open(index_path, 'rb') as index_file:
        payload = pickle.load(index_file)
    return {key: payload[key] for key in ("tags", "embeddings")}


# Every index named in local_index, loaded once and kept in memory.
INDEX2DATA = {name: load_index(name) for name in local_index}

def retrieve_topk_tags(query_emb, query_index, topk=5):
    """Return the tags of one index that are most similar to a query embedding.

    Args:
        query_emb: query embedding tensor; must have shape ``(1, 384)``
            (enforced by the assert below).
        query_index: key into ``INDEX2DATA`` selecting the index to search.
        topk: maximum number of tags to return. ``None`` or a value larger
            than the index returns every tag; a value <= 0 returns ``[]``.

    Returns:
        Entries of the index's "tags" list, ordered by decreasing cosine
        similarity to ``query_emb``.

    Raises:
        AssertionError: if ``query_emb`` does not have shape ``(1, 384)``.
        KeyError: if ``query_index`` is not a key of ``INDEX2DATA``.
    """
    # remove when we start using other models
    assert query_emb.shape == (1, 384)

    index_data = INDEX2DATA[query_index]
    tag_list, tag_embeddings = index_data['tags'], index_data['embeddings']

    # One similarity score per row of tag_embeddings.
    similarities = F.cosine_similarity(query_emb, tag_embeddings, dim=1)

    # Clamp k to [0, number of scores]: torch.topk rejects an out-of-range k,
    # and a negative slice bound would silently drop entries from the tail.
    num_scores = similarities.shape[0]
    k = num_scores if topk is None else max(0, min(int(topk), num_scores))
    if k == 0:
        return []

    # topk selects the k best scores, already sorted in descending order,
    # without sorting the whole index.
    top_indices = torch.topk(similarities, k).indices.tolist()
    return [tag_list[idx] for idx in top_indices]
    
    
# Helper: attention-mask-aware mean pooling over the token dimension
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, counting only positions the mask marks as 1.

    Args:
        model_output: sequence whose first element holds the token embeddings.
        attention_mask: mask tensor; it is given a trailing dimension and
            expanded to the shape of the token embeddings.

    Returns:
        Masked sum over dimension 1 divided by the mask sum over dimension 1.
        The divisor is clamped to at least 1e-9, so an all-zero mask does not
        divide by zero.
    """
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand_as(token_embeddings).float()
    masked_sum = (token_embeddings * mask).sum(dim=1)
    mask_count = mask.sum(dim=1).clamp(min=1e-9)
    return masked_sum / mask_count


def model_fn(model_dir):
    """Load the tokenizer and the model from ``model_dir``.

    Args:
        model_dir: value passed unchanged to both ``from_pretrained`` calls.

    Returns:
        ``(model, tokenizer)`` - note the model comes first.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    loaded_model = AutoModel.from_pretrained(model_dir)
    return loaded_model, loaded_tokenizer


def predict_fn(data, model_and_tokenizer):
    """Embed the query text and return the most similar tags of one index.

    Args:
        data: request payload with the keys
            ``"inputs"`` - text to embed (required),
            ``"index"``  - name of the tag index to search (required),
            ``"topK"``   - number of tags to return (optional; when absent the
            default of ``retrieve_topk_tags`` applies).
            The payload is only read, never modified, so it can be reused.
        model_and_tokenizer: ``(model, tokenizer)`` pair as returned by ``model_fn``.

    Returns:
        ``{"topK_tags": [...]}`` with the tags returned by ``retrieve_topk_tags``.

    Raises:
        ValueError: if ``"inputs"`` or ``"index"`` is missing from ``data``.
    """
    # destruct model and tokenizer
    model, tokenizer = model_and_tokenizer

    # Fail early with a clear message instead of passing the payload dict
    # itself on as the query text or the index name.
    missing = [key for key in ("inputs", "index") if key not in data]
    if missing:
        raise ValueError(f"request payload is missing required key(s): {missing}")
    query = data["inputs"]
    query_index = data["index"]

    # Tokenize sentences
    encoded_input = tokenizer(query, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    query_emb = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    query_emb = F.normalize(query_emb, p=2, dim=1)

    if "topK" in data:
        topK_tags = retrieve_topk_tags(query_emb, query_index, data["topK"])
    else:
        topK_tags = retrieve_topk_tags(query_emb, query_index)

    # return dictionary, which will be json serializable
    return {"topK_tags": topK_tags}


  return torch.load(io.BytesIO(b))


In [33]:
# Example payload: free-text query against "accommodation-brand-tag-vector".
# NOTE: `model_dir` and `data` are reassigned by the cells that follow, so when
# the notebook runs top to bottom this payload is not the one given to predict_fn.
model_dir = "sentence-transformers/all-MiniLM-L12-v2"
data = {
    "inputs": "TUI Blue is a premium accommodation brand designed for leisure travelers seeking a blend of comfort, local culture, and personalized experiences. Offering stylish hotels and resorts in picturesque destinations, TUI Blue focuses on providing exceptional service, modern amenities, and a variety of activities to ensure a memorable and relaxing vacation.",
    "index": "accommodation-brand-tag-vector",
    "topK": 10 
}

In [28]:
# Example payload: the same query text, searched in "tui-accommodation-tag-vector".
# NOTE: `model_dir` and `data` are reassigned again by the next cell, so when the
# notebook runs top to bottom this payload is not the one given to predict_fn.
model_dir = "sentence-transformers/all-MiniLM-L12-v2"
data = {
    "inputs": "TUI Blue is a premium accommodation brand designed for leisure travelers seeking a blend of comfort, local culture, and personalized experiences. Offering stylish hotels and resorts in picturesque destinations, TUI Blue focuses on providing exceptional service, modern amenities, and a variety of activities to ensure a memorable and relaxing vacation.",
    "index": "tui-accommodation-tag-vector",
    "topK": 10 
}

In [34]:
# Example payload: short query searched in "ids-accommodation-brand-tag-vector".
# When the cells run in order this is the last assignment of `model_dir` and
# `data` before model_fn / predict_fn are called below.
model_dir = "sentence-transformers/all-MiniLM-L12-v2"
data = {
    "inputs": "Swire Hotels are my favorite.",
    "index": "ids-accommodation-brand-tag-vector",
    "topK": 10
}

In [35]:
model_and_tokenizer = model_fn(model_dir)

In [36]:
predict_fn(data, model_and_tokenizer)

{'topK_tags': [('Swire Hotels', 'Brand'),
  ('element', 'Brand'),
  ('SWOT', 'Brand'),
  ('Renaissance', 'Brand'),
  ('Destination Hotels', 'Brand'),
  ('Small Luxury Hotels', 'Brand'),
  ('VP Hotels', 'Brand'),
  ('On Hotels', 'Brand'),
  ('The Leading Hotels of the World', 'Brand'),
  ('W Hotels', 'Brand')]}