In [1]:
# This file is a modified version of the original:
# https://github.com/rbiswasfc/benetech-mga/blob/main/gen/run_gen_vbar.py

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import argparse
import glob
import json
import os
import random
import sys
import warnings
from copy import deepcopy

import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from src.generator_utils import generate_random_string
from src.vbar_plot_base import VerticalBarPlot
from src.vbar_xy_generation import generate_from_synthetic, generate_from_wiki
from tqdm.auto import tqdm

# warnings.filterwarnings("ignore")

In [4]:
# Keyword strings dropped from the stem bank; matching is exact and case-sensitive.
STOPWORDS = ["ISBN", "exit", "edit"]


In [5]:
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in Python types.

    Use it as ``json.dump(obj, f, cls=NpEncoder)`` or ``json.dumps(obj, cls=NpEncoder)``.
    """

    def default(self, obj):
        """Return a JSON-serializable equivalent of ``obj``.

        Args:
            obj: A value the standard encoder could not serialize.

        Returns:
            ``bool`` for ``np.bool_``, ``int`` for NumPy integers, ``float`` for
            NumPy floats, and a (nested) ``list`` for ``np.ndarray``.

        Raises:
            TypeError: Raised by the base class for any other unsupported type.
        """
        # np.bool_ is not a subclass of np.integer, so it needs its own branch.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

In [6]:
def generate_plot_data(cfg, wiki_generator, synthetic_generator):
    """Draw one example from a data generator and clip it to the configured limits.

    The wiki generator is picked with weight 0.25 and the synthetic generator with
    weight 0.75. If drawing from the picked generator raises (including
    ``StopIteration`` on exhaustion), the example is drawn from
    ``synthetic_generator`` instead.

    Args:
        cfg: Config object exposing ``max_chars`` (characters kept per x label)
            and ``max_points`` (number of data points kept).
        wiki_generator: Iterator yielding dicts with 'x_series' and 'y_series'.
        synthetic_generator: Iterator yielding dicts of the same form; it is
            also used as the fallback source.

    Returns:
        The drawn dict, updated in place so that 'x_series' is a list of
        truncated strings and 'y_series' is a list, both limited to
        ``cfg.max_points`` entries. Empty series are returned as empty lists.
    """
    generator = random.choices(
        [wiki_generator, synthetic_generator],
        weights=[0.25, 0.75],
        k=1,
    )[0]

    try:
        data = next(generator)
    except Exception:
        # Best-effort fallback: any failure of the picked source is replaced by
        # a draw from the synthetic source.
        data = next(synthetic_generator)

    # Labels are rebuilt as new strings, so the source values are never mutated.
    x_series = [
        str(x_val)[:cfg.max_chars]
        for x_val in list(data['x_series'])[:cfg.max_points]
    ]
    y_series = list(deepcopy(data['y_series']))[:cfg.max_points]

    # If every value lies within 1e-3 of zero, overwrite one randomly chosen
    # entry with a value in [0.01, 0.99]. The emptiness guard avoids calling
    # min()/max() on an empty list, which would raise ValueError.
    if y_series and abs(min(y_series)) < 1e-3 and abs(max(y_series)) < 1e-3:
        idx = random.randint(0, len(y_series) - 1)
        y_series[idx] = random.uniform(0.01, 0.99)

    data['x_series'] = x_series
    data['y_series'] = y_series
    return data

In [7]:
def generate_annotation(data):
    """Build the MGA-format annotation for a vertical bar chart example.

    Args:
        data: Mapping with 'x_series' and 'y_series'. It is deep-copied first,
            so the returned annotation shares no objects with the input.

    Returns:
        A dict with 'chart-type' ('vertical_bar'), 'axes' (categorical x-axis,
        numerical y-axis) and 'data-series', a list of ``{'x': ..., 'y': ...}``
        points paired positionally and cut to the shorter of the two series.
    """
    example = deepcopy(data)
    points = [
        {'x': x_val, 'y': y_val}
        for x_val, y_val in zip(example['x_series'], example['y_series'])
    ]
    return {
        'chart-type': 'vertical_bar',
        'axes': {
            'x-axis': {'values-type': 'categorical'},
            'y-axis': {'values-type': 'numerical'},
        },
        'data-series': points,
    }


In [8]:

def _build_stem_bank(stem_df):
    """Convert the keyword frame into a ``{title: [keywords]}`` dict of usable entries."""
    # stem_df has a "title" column (e.g. 'page_2') and a "keywords" column holding
    # an array of strings (e.g. ['SENIOR CONTRIBUTING AUTHORS', 'COMMUNITY COLLEGE', ...]).
    stem_bank = dict(zip(stem_df["title"], stem_df["keywords"]))

    processed_stem_bank = dict()
    for key, values in stem_bank.items():
        key = key.replace("_", " ")  # 'page_2' -> 'page 2'
        values = [
            v for v in values
            if not v.startswith("[") and v not in STOPWORDS
        ]

        # NOTE(review): the size check runs before de-duplication, so an entry
        # can end up with fewer than 4 unique keywords - confirm this is intended.
        if len(values) >= 4:
            # Order-preserving de-duplication keeps the result identical between
            # runs, unlike set(), whose string order depends on hash randomization.
            processed_stem_bank[key] = list(dict.fromkeys(values))
    return processed_stem_bank


def main(args, cfg):
    """Generate synthetic vertical-bar plots and their MGA annotations.

    Args:
        args: Namespace providing ``wiki_path`` (JSON file), ``stem_path``
            (pickled DataFrame) and ``texture_dir`` (directory of ``*.png`` files).
        cfg: Config providing ``num_images``, ``output.image_dir`` and
            ``output.annotation_dir``; it is also passed on to
            ``generate_plot_data`` and ``VerticalBarPlot``.

    Returns:
        The processed stem bank, a dict mapping each title to its list of keywords.
    """
    # Read as UTF-8 explicitly rather than with the platform default encoding;
    # the wiki data contains non-ASCII text.
    with open(args.wiki_path, 'r', encoding='utf-8') as f:
        wiki_bank = json.load(f)
    # wiki_bank is a list of lists of dicts, each dict with keys such as
    # 'plot-title', 'series-name', 'data-type' and 'data-series'.

    # NOTE(review): unpickling can execute arbitrary code - only load trusted files.
    stem_df = pd.read_pickle(args.stem_path)
    processed_stem_bank = _build_stem_bank(stem_df)

    print(f"wiki bank size: {len(wiki_bank)}")
    print(f"stem bank size: {len(processed_stem_bank)}")

    wiki_generator = generate_from_wiki(wiki_bank)
    synthetic_generator = generate_from_synthetic(processed_stem_bank)

    # -- input/output ---
    os.makedirs(cfg.output.image_dir, exist_ok=True)
    os.makedirs(cfg.output.annotation_dir, exist_ok=True)
    texture_files = glob.glob(f"{args.texture_dir}/*.png")
    print(f"# texture files: {len(texture_files)}")

    # Iterating over tqdm advances the bar and closes it when the loop ends.
    for _ in tqdm(range(cfg.num_images)):
        base_image_id = f'syn_vbar_{generate_random_string()}'
        the_example = generate_plot_data(cfg, wiki_generator, synthetic_generator)

        # cast in the format of MGA
        mga_anno = generate_annotation(the_example)
        anno_path = os.path.join(cfg.output.annotation_dir, f"{base_image_id}.json")

        try:
            # Serialize before plotting so a serialization failure cannot leave
            # an image with a missing or truncated annotation file.
            anno_json = json.dumps(mga_anno, cls=NpEncoder)
            VerticalBarPlot(cfg, the_example, texture_files=texture_files).make_vertical_bar_plot(base_image_id)
            with open(anno_path, "w", encoding="utf-8") as f:
                f.write(anno_json)

        except Exception as e:
            # Best effort: report the failing example and continue with the next one.
            print(e)
            print("--"*40)
            print(the_example)
            print("--"*40)

    return processed_stem_bank

In [9]:
if __name__ == '__main__':
    # Every option has a default path, so the cell runs without any arguments.
    ap = argparse.ArgumentParser()
    for option, default_path in (
        ('--wiki_path', '../datasets/processed/deps/sanitized_wiki.json'),
        ('--stem_path', '../datasets/processed/deps/mga_stem_kws.pickle'),
        ('--conf_path', './conf/conf_vbar.yaml'),
        ('--texture_dir', '../datasets/processed/deps/mga_textures_cc/mga_textures_cc'),
    ):
        ap.add_argument(option, default=default_path, type=str)

    # parse_known_args returns unrecognised argv entries instead of exiting on
    # them - presumably used so the notebook kernel's own arguments are tolerated.
    args, unknown = ap.parse_known_args()
    cfg = OmegaConf.load(args.conf_path)

    processed_stem_bank = main(args, cfg)

wiki bank size: 175543
stem bank size: 96456
# texture files: 1109


  0%|          | 0/5000 [00:00<?, ?it/s]

  self.fig.savefig(save_path, format='jpg', bbox_inches='tight')
  self.fig.savefig(save_path, format='jpg', bbox_inches='tight')
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
  self.fig.savefig(save_path, format='jpg', bbox_inches='tight')
  self.fig.savefig(save_path, format='jpg', bbox_inches='tight')
  self.fig.savefig(save_path, format='jpg', bbox_inches='tight')
  self.fig.savefig(save_path, format='jpg', bbox_inches='tight')
