In [65]:
def jaccard(pred_labels, true_labels, weights=None):
    """Return the (optionally weighted) Jaccard similarity of two label collections.

    This is the Jaccard *index* |P & T| / |P | T| (1.0 means identical sets),
    not the Jaccard distance; the distance is ``1 - jaccard(...)``.

    Args:
        pred_labels: Iterable of predicted labels; duplicates are ignored.
        true_labels: Iterable of ground-truth labels; duplicates are ignored.
        weights: Optional mapping label -> weight. Each label contributes
            ``1 / weights[label]`` instead of 1, so smaller weights indicate
            higher importance. ``None`` or an empty mapping means unweighted.

    Returns:
        The similarity as a float; in [0, 1] when all weights are positive.
        Two empty label collections are treated as identical and give 1.0.

    Raises:
        KeyError: if ``weights`` is a dict missing a label from the union.
        ZeroDivisionError: if a label in the union has weight 0.
    """
    pred_labels = set(pred_labels)
    true_labels = set(true_labels)

    intersection = pred_labels & true_labels
    union = pred_labels | true_labels

    # Both inputs empty: the sets are identical; avoid evaluating 0 / 0.
    if not union:
        return 1.0

    if weights:
        intersection_size = sum(1 / weights[label] for label in intersection)
        union_size = sum(1 / weights[label] for label in union)
    else:
        intersection_size = len(intersection)
        union_size = len(union)

    return intersection_size / union_size

In [67]:
def precision(pred_labels, true_labels, weights=None):
    """Return the (optionally weighted) precision of predicted vs. true labels.

    Precision is |P & T| / |P|: the fraction of predicted labels that are
    correct.

    Args:
        pred_labels: Iterable of predicted labels; duplicates are ignored.
        true_labels: Iterable of ground-truth labels; duplicates are ignored.
        weights: Optional mapping label -> weight. Each label contributes
            ``1 / weights[label]`` instead of 1, so smaller weights indicate
            higher importance. ``None`` or an empty mapping means unweighted.

    Returns:
        The precision as a float. With no predicted labels the ratio is
        undefined (0 / 0); following the common zero-division convention,
        0.0 is returned instead of raising.

    Raises:
        KeyError: if ``weights`` is a dict missing a predicted label.
        ZeroDivisionError: if a predicted label has weight 0.
    """
    pred_labels = set(pred_labels)
    true_labels = set(true_labels)

    # No predictions: precision is undefined; report 0.0 rather than 0 / 0.
    if not pred_labels:
        return 0.0

    intersection = pred_labels & true_labels

    if weights:
        intersection_size = sum(1 / weights[label] for label in intersection)
        pred_size = sum(1 / weights[label] for label in pred_labels)
    else:
        intersection_size = len(intersection)
        pred_size = len(pred_labels)

    return intersection_size / pred_size

def recall(pred_labels, true_labels, weights=None):
    """Return the (optionally weighted) recall of predicted vs. true labels.

    Recall is TP / (TP + FN) = |P & T| / |T|: the fraction of true labels
    that were predicted.

    Args:
        pred_labels: Iterable of predicted labels; duplicates are ignored.
        true_labels: Iterable of ground-truth labels; duplicates are ignored.
        weights: Optional mapping label -> weight. Each label contributes
            ``1 / weights[label]`` instead of 1, so smaller weights indicate
            higher importance. ``None`` or an empty mapping means unweighted.

    Returns:
        The recall as a float. With no true labels the ratio is undefined
        (0 / 0); following the common zero-division convention, 0.0 is
        returned instead of raising.

    Raises:
        KeyError: if ``weights`` is a dict missing a true label.
        ZeroDivisionError: if a true label has weight 0.
    """
    pred_labels = set(pred_labels)
    true_labels = set(true_labels)

    # No true labels: recall is undefined; report 0.0 rather than 0 / 0.
    if not true_labels:
        return 0.0

    intersection = pred_labels & true_labels
    false_negatives = true_labels - pred_labels

    if weights:
        intersection_size = sum(1 / weights[label] for label in intersection)
        false_negative_size = sum(1 / weights[label] for label in false_negatives)
    else:
        intersection_size = len(intersection)
        # This used to be bound to an unused `pred_size`, leaving
        # `false_negative_size` unbound on the unweighted path.
        false_negative_size = len(false_negatives)

    return intersection_size / (intersection_size + false_negative_size)

def f1_score(pred_labels, true_labels, weights=None):
    """Return the (optionally weighted) F1 score of predicted vs. true labels.

    F1 is the harmonic mean of ``precision`` and ``recall`` computed on the
    same inputs; ``weights`` is forwarded unchanged to both, so smaller
    weights indicate higher importance.

    Args:
        pred_labels: Iterable of predicted labels (iterators are accepted).
        true_labels: Iterable of ground-truth labels (iterators are accepted).
        weights: Optional mapping label -> weight, or ``None`` for unweighted.

    Returns:
        The F1 score as a float. When precision and recall are both 0
        (e.g. disjoint label sets) the harmonic mean is 0 / 0; 0.0 is
        returned instead of raising.
    """
    # One-shot iterators (iter(x) is x) would be exhausted by precision(),
    # leaving recall() an empty input; materialize them so both see the data.
    if iter(pred_labels) is pred_labels:
        pred_labels = list(pred_labels)
    if iter(true_labels) is true_labels:
        true_labels = list(true_labels)

    precision_score = precision(pred_labels, true_labels, weights)
    recall_score = recall(pred_labels, true_labels, weights)

    denominator = precision_score + recall_score
    if denominator == 0:
        return 0.0
    return (2 * precision_score * recall_score) / denominator

In [68]:
# Example: weighted F1 where label 4 (weight 0.001, i.e. very important)
# dominates both sums, so the false positive 11 and the false negative 8
# only slightly lower the score below 1.
example_weights = {
    0: 1,
    1: 6,
    2: 1,
    3: 1,
    4: 0.001,
    5: 1,
    6: 1,
    7: 8,
    8: 1,
    11: 8,
}
f1_score(
    pred_labels=[1, 2, 3, 4, 5, 6, 7, 11],
    true_labels=[1, 2, 3, 4, 5, 6, 7, 8],
    weights=example_weights,
)

0.9994402172786266