In [86]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [204]:
from collections import defaultdict
from fuzzywuzzy import fuzz
from tqdm import tqdm, tqdm_notebook
import random 

from knowledge_graph import KG
from task_utils import *

In [156]:
kg = KG("data/food/construction")

Loading data from ../data/food/construction...
Loaded 1001498 triples.


In [185]:
kg.load_tasks()

Task: entity_matching
entity_matching (train): 12920 samples
dict_keys(['train'])
entity_matching (valid): 1615 samples
dict_keys(['train', 'valid'])
entity_matching (test): 2465 samples
dict_keys(['train', 'valid', 'test'])

Task: arc_matching
arc_matching (train): 0 samples
dict_keys(['train'])
arc_matching (valid): 0 samples
dict_keys(['train', 'valid'])
arc_matching (test): 50 samples
dict_keys(['train', 'valid', 'test'])



## Entity Matching Task

Our baseline consists of string matching on names. 

In [146]:
def generate_candidates(triples, test_x):
    """ The entity matching task consists of mapping USDA entities to Open Food Facts (OFF) entities. 
    In order to make comparisons feasible, we first recommend performing a candidate generation step, where
    we generate a list of likely USDA candidates for every OFF entity we seek to match.
    
    Args:
        triples (list): list of all triples in the KG
        test_x: all unknown USDA entities in the test split
    """
    off2usda_arcs = {
        'name': ['name', 'long_name'],
        'product_name': ['name', 'long_name'],
        'type': ['type']
    }
    
    arcs2save = set()
    for k, v in off2usda_arcs.items():
        arcs2save.add(k)
        for vv in v:
            arcs2save.add(vv)
        
    entity2source = {} 
    entity2features = defaultdict(lambda: defaultdict(list))
    
    # collect features for each entity 
    print("Collecting Features")
    for head, arc, tail, tail_type, source in tqdm(triples):
        entity2source[head] = source
        if arc in arcs2save:
            entity2features[head][arc].append(tail)
    
    # generate candidates 
    candidate_scores = defaultdict(lambda: defaultdict(int))
    print("Scoring candidates")
    
    for off_entity in tqdm(test_x): 
        for candidate_entity in entity2features:
            if not candidate_entity in entity2features:
                continue
                
            if not off_entity in entity2features:
                continue

            # skip candidate if it belongs to USDA
            if entity2source[candidate_entity] == 'OFF':
                continue 
            
            candidate_features = entity2features[candidate_entity]
            # skip entity if it's not a food product
            if not 'food_product' in candidate_features['type']:
                continue
            
    
            '''print("Source entity:",off_entity, entity2source[off_entity])
            print("Source features:",entity2features[off_entity])
            print("Candidate:", candidate_entity, entity2source[candidate_entity])
            print("Candidate features:", candidate_features)'''
            score = 0
            for off_feature_name, usda_feature_name_list in off2usda_arcs.items():
                max_score = 0
                for usda_feature_name in usda_feature_name_list:
                    for candidate_val in candidate_features.get(usda_feature_name, []):
                        for off_val in entity2features[off_entity].get(off_feature_name, []):
                            #print(off_feature_name, off_val, usda_feature_name, candidate_val, fuzz.ratio(off_val, candidate_val))
                            max_score = max(fuzz.ratio(off_val, candidate_val), max_score)
                score += max_score 
            candidate_scores[off_entity][candidate_entity] = score
    return candidate_scores
    

In [147]:
# Get candidates
candidates = generate_candidates(kg.triples, kg.tasks['entity_matching']['test']['X'])

 22%|██▏       | 224072/1001498 [00:00<00:00, 1108640.31it/s]

Collecting Features


100%|██████████| 1001498/1001498 [00:00<00:00, 1192017.74it/s]
  0%|          | 1/2465 [00:00<05:58,  6.88it/s]

Scoring candidates


100%|██████████| 2465/2465 [05:24<00:00,  6.92it/s]


In [148]:
# predict top scoring candidates 
unfiltered_predictions = {} 
all_scores = []
for off_entity in kg.tasks['entity_matching']['test']['X']:
    max_score = 0
    max_candidate = ""
    for candidate, score in candidates[off_entity].items():
        if score > max_score:
            max_score = score
            max_candidate = candidate
    all_scores.append(max_score)
    unfiltered_predictions[off_entity] = (max_candidate, max_score)

# use median score as threshold 
threshold = sorted(all_scores)[int(len(all_scores) / 2)]

filtered_predictions = {}
for off_entity, candidate_tuple in unfiltered_predictions.items():
    candidate, score = candidate_tuple
    if score < threshold:
        filtered_predictions[off_entity] = "None"
    else:
        filtered_predictions[off_entity] = candidate
        
    

In [153]:
X = kg.tasks['entity_matching']['test']['X']
Y = kg.tasks['entity_matching']['test']['Y']
precision, recall, f1 = evaluate_entity_matching(filtered_predictions, X, Y)
print("Precision: %f. Recall: %f. F1: %f" % (precision, recall, f1))

Precision: 0.046661. Recall: 0.068155. F1: 0.055396


## Arc Matching 

We will align arcs for known paired entities

In [186]:
off_entities = kg.tasks['entity_matching']['train']['X']
usda_entities = kg.tasks['entity_matching']['train']['Y']

In [200]:
off_arcs = kg.tasks['arc_matching']['test']['X']
arc_scores = defaultdict(lambda: defaultdict(int))
pairs = list(zip(off_entities, usda_entities))
random.shuffle(pairs)
pairs = pairs[:1000]
for off_e, usda_e in tqdm_notebook(pairs):
    off_triples = kg.filter_triples(head_filter=[off_e])
    usda_triples = kg.filter_triples(head_filter=[usda_e])
    for _, usda_arc, usda_tail, usda_tail_type, _ in usda_triples: 
        if usda_tail_type == 'entity':
            if not usda_tail in entity2name:
                continue
            usda_tail = entity2name[usda_tail]
        for _, off_arc, off_tail, off_tail_type, _ in off_triples: 
            if not off_arc in off_arcs:
                continue 
            if off_tail_type == 'entity':
                if not off_tail in entity2name:
                    continue
                off_tail = entity2name[off_tail]
            arc_scores[off_arc][usda_arc] = fuzz.ratio(usda_tail, off_tail)
        

In [202]:
predictions = {}
for arc in off_arcs:
    best_score = 0
    best_candidate = ""
    for usda_arc, score in arc_scores[arc].items():
        if score > best_score:
            best_score = score 
            best_candidate = usda_arc
    if best_score > 90:
        predictions[arc] = best_candidate
    else:
        predictions[arc] = "None"

In [205]:
X = kg.tasks['arc_matching']['test']['X']
Y = kg.tasks['arc_matching']['test']['Y']
precision, recall, f1 = evaluate_arc_matching(predictions, X, Y)
print("Precision: %f. Recall: %f. F1: %f" % (precision, recall, f1))

Precision: 0.238095. Recall: 0.147059. F1: 0.181818
