In [3]:
import argparse
import glob
import json
import os
from itertools import chain

import pandas as pd
from joblib import Parallel, delayed
from omegaconf import OmegaConf
from sklearn.model_selection import StratifiedKFold


In [4]:
# ---
def print_line():
    """Print a horizontal separator: '#', one hundred dashes, then '#'."""
    dash_pair = "--"
    # 50 repetitions of the two-character unit give a 100-dash rule
    rule = "".join(["#", dash_pair * 50, "#"])
    print(rule)


In [17]:
def _process_json(fp):
    """process JSON files with annotations

    :param fp: file path
    :type fp: str
    :return: parsed annotation
    :rtype: dict
    """

    # read annotations ---
    with open(fp, "r") as f:
        anno = json.load(f)

    # store necessary data for labels ---
    chart_id = fp.split("/")[-1].split(".")[0]
    chart_source = anno["source"]
    chart_type = anno['chart-type']

    labels = []

    labels.append(
        {
            "id": chart_id,
            "source": chart_source,
            "chart_type": chart_type,
        }
    )

    # labels => [{'id': '8941f843bb04', 'source': 'generated', 'chart_type': 'dot'}]
    
    return labels


In [34]:
def process_annotations(cfg, num_jobs=8):
    """Parse every training annotation file into a single labels DataFrame.

    :param cfg: config object exposing ``competition_dataset.data_dir``
    :param num_jobs: number of joblib workers used to parse the files
    :type num_jobs: int
    :return: one row per label dict returned by ``_process_json``, ordered by
        annotation file path
    :rtype: pandas.DataFrame
    """
    data_dir = cfg.competition_dataset.data_dir.rstrip("/")

    # glob returns matches in arbitrary (filesystem-dependent) order; sort them so
    # the row order -- and therefore any seeded shuffle applied later -- is reproducible
    anno_paths = sorted(glob.glob(f"{data_dir}/train/annotations/*.json"))
    annotations = Parallel(n_jobs=num_jobs, verbose=1)(delayed(_process_json)(file_path) for file_path in anno_paths)

    # annotations is a list of single-element lists, e.g.
    #   [[{'id': '8941f843bb04', 'source': 'generated', 'chart_type': 'dot'}],
    #    [{'id': '7dbdb91fa4a7', 'source': 'generated', 'chart_type': 'scatter'}], ...]
    # chain.from_iterable flattens it to a flat list of dicts:
    #   [{'id': '8941f843bb04', 'source': 'generated', 'chart_type': 'dot'},
    #    {'id': '7dbdb91fa4a7', 'source': 'generated', 'chart_type': 'scatter'}, ...]
    labels_df = pd.DataFrame(list(chain.from_iterable(annotations)))

    # labels_df.head(3)
    #              id     source chart_type
    # 0  8941f843bb04  generated        dot
    # 1  7dbdb91fa4a7  generated    scatter
    # 2  956c946a123d  generated       line

    return labels_df


In [37]:
def create_cv_folds(cfg):
    """Create the cross-validation fold map for the MGA task and save it as parquet.

    Rows whose ``source`` is "extracted" keep the fold number assigned by a
    stratified (on ``chart_type``) K-fold split; every other row gets fold 99.
    The resulting ``id``/``kfold`` table is written to
    ``<fold_dir>/cv_map_<n_folds>_folds.parquet``.

    :param cfg: config object exposing ``fold_metadata.n_folds``,
        ``fold_metadata.seed`` and ``fold_metadata.fold_dir`` (plus whatever
        ``process_annotations`` reads from it)
    :return: None
    """
    print_line()
    print("creating folds ...")

    fold_df = process_annotations(cfg)
    fold_df = fold_df[["id", "source", "chart_type"]].copy()
    fold_df = fold_df.drop_duplicates()
    # a fresh 0..n-1 index makes the positional indices from skf.split valid .loc labels
    fold_df = fold_df.reset_index(drop=True)

    print(cfg.fold_metadata.n_folds)
    # ------
    skf = StratifiedKFold(
        n_splits=cfg.fold_metadata.n_folds,
        shuffle=True,
        random_state=cfg.fold_metadata.seed
    )

    # NOTE(review): the split is stratified over ALL rows, and non-extracted rows are
    # moved to fold 99 afterwards, so the extracted subset on its own is not strictly
    # stratified -- confirm this is intended before changing it (it would alter folds).
    for f, (_train_idx, v_) in enumerate(skf.split(fold_df, fold_df["chart_type"].values)):
        fold_df.loc[v_, "kfold"] = f

    # the column is created by partial assignment (float with NaN); every row has been
    # in exactly one validation split by now, so the cast to int is safe
    fold_df["kfold"] = fold_df["kfold"].astype(int)

    # allocate fold 99 to synthetic data
    # vectorized: keep the fold where source == "extracted", otherwise use 99
    # (avoids integer-position indexing on a label-indexed row Series, which pandas deprecates)
    fold_df["kfold"] = fold_df["kfold"].where(fold_df["source"] == "extracted", 99)

    # fold_df.kfold.unique() => array([99,  0,  1])

    print(fold_df["kfold"].value_counts())

    # save fold split ---
    save_dir = cfg.fold_metadata.fold_dir
    os.makedirs(save_dir, exist_ok=True)

    save_path = os.path.join(save_dir, f"cv_map_{cfg.fold_metadata.n_folds}_folds.parquet")
    fold_df = fold_df[["id", "kfold"]].copy()
    fold_df = fold_df.reset_index(drop=True)
    fold_df.to_parquet(save_path)
    print("done!")
    print_line()
    # ---
    # fold_df.head(3) =>
    #               id  kfold
    # 0  8941f843bb04     99
    # 1  7dbdb91fa4a7     99
    # 2  956c946a123d     99


In [38]:
if __name__ == "__main__":
    # Entry point: read the config path from the command line, load the YAML
    # config and build the fold map.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        default="./conf/tools/conf_folds.yaml",
    )  # required=True

    # parse_known_args tolerates arguments this script does not declare
    # (e.g. those injected when the cell runs inside a notebook kernel)
    args, _unrecognized = parser.parse_known_args()

    cfg = OmegaConf.load(args.config_path)
    create_cv_folds(cfg)


#----------------------------------------------------------------------------------------------------#
creating folds ...


[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  56 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done 37706 tasks      | elapsed:    0.9s
[Parallel(n_jobs=8)]: Done 60578 out of 60578 | elapsed:    1.2s finished


2
kfold
99    59460
0       564
1       554
Name: count, dtype: int64
done!
#----------------------------------------------------------------------------------------------------#
