In [20]:
from categories import secondary_category_BoW_flattened, primary_category_BoW, secondary_category_BoW
from predict_string import rank_categories_by_features
import pandas as pd

INFO:numexpr.utils:NumExpr defaulting to 8 threads.


In [8]:
# About GT labelling - construct value pool templates
concrete_primary_categories = set(primary_category_BoW.keys()) - set(secondary_category_BoW)

all_concrete_categories = concrete_primary_categories.union(secondary_category_BoW_flattened.keys())
value_pool = {}

for cat in all_concrete_categories:
    value_pool[cat] = {
        'general': []
    }
    
import json
print(json.dumps(value_pool, indent=4))

{
    "location": {
        "general": []
    },
    "news": {
        "general": []
    },
    "podcast": {
        "general": []
    },
    "city": {
        "general": []
    },
    "artist": {
        "general": []
    },
    "album": {
        "general": []
    },
    "playlist_url": {
        "general": []
    },
    "second": {
        "general": []
    },
    "name": {
        "general": []
    },
    "ratio": {
        "general": []
    },
    "auth_code": {
        "general": []
    },
    "description": {
        "general": []
    },
    "storage": {
        "general": []
    },
    "minute": {
        "general": []
    },
    "book": {
        "general": []
    },
    "url": {
        "general": []
    },
    "video": {
        "general": []
    },
    "iot_device": {
        "general": []
    },
    "payee": {
        "general": []
    },
    "balance": {
        "general": []
    },
    "department": {
        "general": []
    },
    "label": {
        "general": []
    

In [None]:
def evaluate_ranking(data_info, labels, raw_result=False, prune=False, prune_count=3, use_context=False, baseline=False):
    tf = data_info['textfield_info']

    if data_info['textfield_path'] not in labels:
        raise Exception(f'No label found for {data_info["textfield_path"]}')

    target_labels = convert_category_labels(labels[data_info['textfield_path']])

    if use_context:
        ranked_categories_with_distance = find_nearest_category_using_context(tf, data_info['view_tree'], k=CATEGORY_SIZE, prune=prune, prune_count=prune_count)
    elif baseline:
        ranked_categories_with_distance = find_nearest_category_baseline(tf, k=CATEGORY_SIZE)
    else:
        ranked_categories_with_distance = find_nearest_category(tf, k=CATEGORY_SIZE, prune=prune, prune_count=prune_count)

    ranked_categories = [category for category, _ in ranked_categories_with_distance]

    if ranked_categories_with_distance[0][1] > DISTANCE_THRESHOLD:
        ranked_categories.insert(0, 'any')
        ranked_categories_with_distance.insert(0, ('any', DISTANCE_THRESHOLD))

    if 'any' in target_labels:
        if raw_result:
            return None, ranked_categories_with_distance

        return None

    min_index = CATEGORY_SIZE
    for l in target_labels:
        i = ranked_categories.index(l)
        if i < min_index:
            min_index = i

    if raw_result:
        return (min_index + 1, ranked_categories_with_distance)

    return min_index + 1

In [16]:
def get_GT_categories(tf_label):
    categories = {}
    for cat in tf_label['GT_categories']:
        categories[cat[0]] = cat[1]
    
    pcats = list(categories.keys())
    
    return pcats, categories


def get_cat_rank(GT_cats, cats_ranked):
    cats_ranked = [c[0] for c in cats_ranked]
    min_rank = 100000
    for GT_cat in GT_cats:
        rank = cats_ranked.index(GT_cat) + 1
        if rank < min_rank:
            min_rank = rank
    
    return min_rank

In [26]:
queue = ['data/textfield_features_OSS.json', 'data/textfield_features_samsung.json']

labels = json.load(open('data/GT/labels.json'))

ranking_result = []

settings = []

for tf_feature_file in queue:
    tf_features = json.load(open(tf_feature_file))
    
    for tf_path, tf in tf_features.items():
        if tf_path not in labels:
            continue # skip not yet labelled data
        
        GT_pcats, GT_categories = get_GT_categories(labels[tf_path])
        
        primary_rank_result = rank_categories_by_features(tf['textfield_tokens'], tf['local_context'], tf['global_context'], category_BoW=primary_category_BoW, prune=True, prune_count=3, weight_g=0.0, extend_local_ctx=True)
        pcat = primary_rank_result[0][0]
        # TODO: try top2/top3 ranked primary categories too? 
        
        secondary_rank_result = []
        scat = None
        
        if pcat in secondary_category_BoW:
            secondary_rank_result = rank_categories_by_features(tf['textfield_tokens'], tf['local_context'], tf['global_context'], category_BoW=secondary_category_BoW[pcat], prune=True, prune_count=3, weight_g=0.5, extend_local_ctx=True)
            scat = secondary_rank_result[0][0]
        
        ranking_result.append({
            'source': tf_feature_file,
            'app_name': tf['app_name'],
            'tf_path': tf_path,
            'pred_top1': [pcat, scat],
            'pcat_rank': get_cat_rank(GT_pcats, primary_rank_result),
            'scat_rank': get_cat_rank(GT_categories[pcat], secondary_rank_result) if pcat in GT_pcats else -1,
            'primary_categories_ranked': primary_rank_result,
            'secondary_categories_ranked': secondary_rank_result
        })
        
rank_df = pd.DataFrame(ranking_result)

In [27]:
rank_df

Unnamed: 0,source,app_name,tf_path,pred_top1,pcat_rank,scat_rank,primary_categories_ranked,secondary_categories_ranked
0,data/textfield_features_OSS.json,Instagram,data/OSS/textfield_contexts/Instagram_1/email_...,"[profile, email]",1,1.0,"[(profile, 0.37276495123902953), (auth_code, 0...","[(email, 0.20243436718980473), (phone, 0.50456..."
1,data/textfield_features_OSS.json,HERE_WeGo_Maps_Navigation_v4.4.200,data/OSS/textfield_contexts/HERE_WeGo_Maps_Nav...,"[search, tv_show]",1,2.0,"[(search, 0.6540699079632759), (location, 0.71...","[(tv_show, 0.5913960039615631), (location, 0.6..."
2,data/textfield_features_OSS.json,me.hackerchick.catima,data/OSS/textfield_contexts/me.hackerchick.cat...,"[search, storage]",1,1.0,"[(search, 0.4863956930736701), (auth_code, 0.6...","[(storage, 0.5714529690643151), (web, 0.660388..."
3,data/textfield_features_OSS.json,me.hackerchick.catima,data/OSS/textfield_contexts/me.hackerchick.cat...,"[code, card_id]",1,1.0,"[(code, 0.23469410339991253), (profile, 0.6545...","[(card_id, 0.1142329846819242), (coupon_code, ..."
4,data/textfield_features_OSS.json,me.hackerchick.catima,data/OSS/textfield_contexts/me.hackerchick.cat...,"[profile, name]",2,,"[(profile, 0.40422751009464264), (label, 0.431...","[(name, 0.41733503341674805), (username, 0.485..."
...,...,...,...,...,...,...,...,...
100,data/textfield_features_samsung.json,Clock,data/samsung_internal/Clock/state_2022-11-01_1...,"[datetime, second]",1,1.0,"[(datetime, 0.6410585244496663), (profile, 0.8...","[(second, 0.5160047913280627), (time, 0.673015..."
101,data/textfield_features_samsung.json,Clock,data/samsung_internal/Clock/state_2022-11-01_1...,"[profile, name]",2,,"[(profile, 0.4420238062739372), (label, 0.4804...","[(name, 0.4581761391212543), (username, 0.5324..."
102,data/textfield_features_samsung.json,Calendar,data/samsung_internal/Calendar/state_2022-11-0...,"[label, None]",1,100000.0,"[(label, 0.20613006750742593), (profile, 0.600...",[]
103,data/textfield_features_samsung.json,Calendar,data/samsung_internal/Calendar/state_2022-11-0...,"[location, location]",1,1.0,"[(location, 0.20663351317246756), (profile, 0....","[(location, 0.11815407623847325), (city, 0.632..."
