# Import

In [1]:
import os

os.chdir(os.path.dirname(os.getcwd()))

In [2]:
import json
from tqdm import tqdm
from collections import defaultdict

import faiss
import pickle
import numpy as np

from run import ModelRunner
from jovis_model.config import Config
from jovis_model.utils.helper import build_faiss_index
from jovis_model.utils.report import ReportMaker

# InternVL Report

### Runner

In [3]:
params = {
    "pkg": "llm",
    "task": "sentence_embedding",
    "use_hf_model": True,
    "params": {
        "hf_name": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    }
}
config = Config(**params)
runner = ModelRunner(
    config=config,
    mode="inference"
)



### Helper

In [4]:
def parse_desc(desc, target:list, dropna: bool = True):
    parsed = {}
    try:
        desc = json.loads(desc)
        for k, v in desc.items():
            if k in target: 
                if k == "Fashion attribute":
                    if isinstance(v, dict):
                        fa_dict = {}
                        for k_, v_ in v.items():
                            if v_ != "N/A":
                                fa_dict[k_] = v_
                        v = fa_dict
                    parsed[k] = v 
                else:
                    if v != "N/A":
                        parsed[k] = v
    except Exception as e:
        return None
    return parsed

### descriptions

In [5]:
with open("outputs/skb/descriptions_original.json", "r") as f:
    desc = json.load(f)

### build search embeddings

In [6]:
target_columns = ["Season", "Place", "Occasion", "Style", "Gender", "Background", "Model", "Fashion attribute"]
dropna=True

In [7]:
pids = []
embeddings = []
for pid, des in tqdm(desc.items()):
    des = parse_desc(
        des,
        target=target_columns,
        dropna=dropna
    )
    if des:
        des = json.dumps(des)
        pids.append(pid)
        embeddings.append(runner.run([des]).detach().cpu().numpy()[0])
        
print(f"Total: {len(desc)}, Success: {len(pids)}")

100%|██████████| 19180/19180 [03:58<00:00, 80.42it/s] 

Total: 19180, Success: 19174





In [8]:
build_faiss_index(
    embeddings=embeddings,
    save_path="outputs/skb",
    save_name="descriptions_all_mpnet",
    pids=pids
)

### build quey embeddings

In [9]:
target_columns = ["Season", "Place", "Occasion", "Style"]
dropna=True

In [10]:
pids = []
embeddings = []
for pid, des in tqdm(desc.items()):
    des = parse_desc(
        des,
        target=target_columns,
        dropna=dropna
    )
    if des:
        des = json.dumps(des)
        pids.append(pid)
        embeddings.append(runner.run([des]).detach().cpu().numpy()[0].tolist())
print(f"Total: {len(desc)}, Success: {len(pids)}")

100%|██████████| 19180/19180 [03:29<00:00, 91.35it/s] 

Total: 19180, Success: 17368





In [11]:
f_name = "query_embeddings_4cols_mpnet"

In [12]:
with open(f"outputs/skb/{f_name}.json", "w") as f:
    json.dump({k:v for k, v in zip(pids, embeddings)}, f)

### Evaluation

In [13]:
index_file_name = "descriptions_all_mpnet"

In [14]:
query_file_name = "query_embeddings_4cols_mpnet"

In [15]:
desc_index = faiss.read_index(f"outputs/skb/{index_file_name}.index")
with open(f"outputs/skb/{index_file_name}_map.json", "r") as f:
    desc_map = json.load(f)
with open(f"outputs/skb/{query_file_name}.json", "r") as f:
    desc_text_embeddings = json.load(f)

In [17]:
len(desc_text_embeddings)

17368

In [19]:
res = {}
hit = 0
topk = 10
for pid, embed in tqdm(list(desc_text_embeddings.items())):
    scores, ids = desc_index.search(np.array(embed).reshape(1, -1), topk)
    scores = [f"[invl] {s:.4f}" for s in scores[0]]
    ids = [desc_map[str(i)] for i in ids[0]]
    if pid in ids:
        hit += 1

100%|██████████| 17368/17368 [38:37<00:00,  7.49it/s]


### for report

In [72]:
target_columns = ["Season", "Place", "Occasion", "Style", "Fashion attribute"]
dropna=True

In [73]:
querys = {}
for pid, des in tqdm(desc.items()):
    des = parse_desc(
        des,
        target=target_columns,
        dropna=dropna
    )
    if des:
        f_des = []
        for t in target_columns:
            tmp = des.get(t, "")
            if tmp != "":
                if t == "Fashion attribute":
                    if isinstance(tmp, dict):
                        for k, v in tmp.items():
                            f_des.append(f"{k}: {v}")
                    else:
                        f_des.append(f"{t}: {tmp}")
                else:
                    f_des.append(f"{t}: {tmp}")
        f_des = "<br>".join(f_des)
        querys[pid] = f_des 

100%|██████████| 19180/19180 [00:00<00:00, 60666.32it/s]


In [80]:
topk = 5

In [81]:
res = {}
hit = 0
for pid, embed in tqdm(list(desc_text_embeddings.items())[:100]):
    res[pid] = defaultdict(list)
    scores, ids = desc_index.search(np.array(embed).reshape(1, -1), topk)
    scores = ["[invl] {:.4f}".format(s) for s in scores[0]]
    ids = [desc_map[str(i)] for i in ids[0]]
    scores = [[i, s] for i, s in zip(ids, scores)]
    desc_str = querys[pid]
    res[pid][desc_str].append({"text": scores, "image": ids})
    if pid in ids:
        hit += 1

100%|██████████| 100/100 [00:01<00:00, 84.06it/s]


In [82]:
rm = ReportMaker(
    data_dict=res,
    image_path="/data/local/multimodal_for_skb/images/skb",
    max_len=10
)

In [83]:
rm.make_report(
    save_path="outputs/skb",
    save_name="multimodal_internvl_1"
)