In [1]:
"""Script to compute label-label co-occurrences and pairwise conditional probabilities."""
import argparse

try:
    from .utils import isnotebook
except ImportError:
    from utils import isnotebook
from pathlib import Path
import json
import sys
sys.path.append('../')
from multilabel_learning.dataset_readers.common import JSONTransform
from multilabel_learning.dataset_readers.utils import smart_read
from allennlp.common.params import Params
import itertools
import numpy as np
import pandas as pd
from fractions import Fraction

In [2]:
def get_args() -> argparse.Namespace:
    """Parse the command-line options of the co-occurrence script.

    When ``isnotebook()`` reports that we are running inside a notebook, a
    hard-coded argument string (pointing at the blurb-genre sample data) is
    parsed instead of ``sys.argv``.

    Returns:
        Namespace with ``input_file`` (Path), ``output_file`` (Path),
        ``label_field_name`` (str) and ``json_transform`` (the object built by
        ``JSONTransform.from_params`` for the given registered name).
    """
    parser = argparse.ArgumentParser(
        description="Compute cooccurance score for a dataset"
    )
    # Required: without it ``input_file`` would silently be None and the error
    # would only surface later, deep inside the dataset reader.
    parser.add_argument(
        "-i",
        "--input-file",
        type=Path,
        required=True,
        help="dataset file whose examples are read with smart_read",
    )
    parser.add_argument(
        "-o",
        "--output-file",
        type=Path,
        default=Path("cooccurrences.csv"),
        help="path of the CSV file the label-by-label table is written to",
    )
    parser.add_argument(
        "-l",
        "--label-field-name",
        default="labels",
        help="name of the field in (transformed) json that contains labels as list",
    )
    parser.add_argument(
        "-t",
        "--json-transform",
        # Build the transform from its registered name given on the CLI.
        type=(lambda x: JSONTransform.from_params(Params({'type': x}))),
        default=JSONTransform.from_params(Params(dict(type="identity"))),
        help='Registered child of "dataset_readers.common.JSONTransform"',
    )
    if isnotebook():
        # Notebook kernels have no meaningful sys.argv; use fixed sample args.
        import shlex  # noqa

        args_str = (
            "-i ../.data/blurb_genre_collection/sample_train.json -o "
            "../.data/blurb_genre_collection/sample_train_cooccurrences.csv "
            "-t from-blurb-genre"
        )
        args = parser.parse_args(shlex.split(args_str))
    else:
        args = parser.parse_args()
    return args

In [3]:
def compute_cooccurrence_table(label_sets):
    """Compute exact label-label co-occurrence probabilities.

    Each example's labels are de-duplicated first, so a label repeated inside
    one example is counted once (otherwise values could exceed 1).

    The returned table is square, indexed and columned by the labels in a
    deterministic order (sorted by ``str``), and holds ``Fraction`` values:

    * diagonal ``table.loc[a, a]`` = P(a)
      = (#examples containing ``a``) / (#examples)
    * off-diagonal ``table.loc[b, a]`` (row ``b``, column ``a``) = P(a | b)
      = (#examples containing both ``a`` and ``b``) / (#examples containing ``b``)

    Args:
        label_sets: sequence with one iterable of hashable labels per example.

    Returns:
        object-dtype ``pd.DataFrame`` as described above; empty when there
        are no labels.
    """
    example_labels = [set(labels) for labels in label_sets]
    num_examples = len(example_labels)
    # Sorted list instead of a set: set iteration order depends on per-process
    # string-hash randomisation (non-reproducible CSV layout), and recent
    # pandas versions refuse a set as index/columns.
    all_labels = sorted(
        {label for labels in example_labels for label in labels}, key=str
    )

    # pair_counts[(a, b)] = #examples containing both a and b;
    # pair_counts[(a, a)] = #examples containing a.
    pair_counts = {}
    for labels in example_labels:
        for pair in itertools.product(labels, repeat=2):
            pair_counts[pair] = pair_counts.get(pair, 0) + 1

    def cell(row, col):
        joint = pair_counts.get((col, row), 0)
        if row == col:
            return Fraction(joint, num_examples)
        # pair_counts[(row, row)] >= 1: every label comes from some example.
        return Fraction(joint, pair_counts[(row, row)])

    # Build the frame in one go rather than mutating it cell by cell via
    # chained indexing (df[a][b] += ...), which does not write back under
    # pandas Copy-on-Write.
    return pd.DataFrame(
        [[cell(row, col) for col in all_labels] for row in all_labels],
        index=all_labels,
        columns=all_labels,
        dtype=object,
    )


if __name__ == "__main__":
    args = get_args()
    label_sets = [
        args.json_transform(ex)[args.label_field_name]
        for ex in smart_read(args.input_file)
    ]
    cooccurrence_df = compute_cooccurrence_table(label_sets)
    cooccurrence_df.to_csv(args.output_file, index_label='labels')