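# Data Generation

Parse the chart scenegraphs, generate L1 captions from them, join those with the L2/L3 captions, and export train/test/validation splits as JSON Lines files.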

## Imports & Variables

In [1]:
import io
import json
import os
import pickle
import random
import time

import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import utils_gen
import utils_helper

In [2]:
# Fix Python's global RNG; the same seed is reused as random_state for the sklearn splits.
seed: int = 2022
random.seed(seed)

In [3]:
# Run-mode switches:
#   debug    - print/display intermediate results and stop after the first parsed scenegraph
#   save_gen - write the parsed L1 records to ./data/L1gen/L1gen.pkl
debug: bool = False
save_gen: bool = True

In [4]:
# Inputs: a directory of per-chart scenegraph JSON files, and the file holding the
# L2/L3 captions (note: it is read with pickle further down despite its extension).
path_scenegraph: str = "./data/scenegraphs"
path_L2L3: str = "./data/L2L3_captions.json"

## Scenegraph Parsing & L1 Generation

In [5]:
# Parse every scenegraph JSON file and generate its L1 caption.
# Each data_L1 entry is (img_id, scenegraph, datatable, caption_L1, L1_properties).
data_L1 = []
n_success = 0
n_error = 0

# Match the extension (not a substring) so names like "x.json.bak" are skipped;
# sort for a deterministic processing order.
files_sg = sorted(name for name in os.listdir(path_scenegraph) if name.endswith('.json'))
assert len(files_sg) == 8822, f"expected 8822 scenegraph files, found {len(files_sg)}"

for file in tqdm(files_sg):
    img_id = os.path.splitext(file)[0]
    filepath = os.path.join(path_scenegraph, file)
    try:
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        # Reading inside the try means a malformed file is counted, not fatal.
        with open(filepath, encoding="utf-8") as f:
            read_sg = json.load(f)
        parsed_sg = utils_gen.parse_all_sg(read_sg)
        parsed_dt = utils_gen.parse_all_dt(read_sg)
        parsed_metadata = utils_gen.parse_all_sg(read_sg, src=False)
        # The second return value of generate_caption is not used here.
        parsed_caption, _ = utils_gen.generate_caption(*parsed_metadata)
    except Exception as e:
        # Best-effort: count and report the failing file, then keep going.
        n_error += 1
        print(f"{file}: {e}")
        continue
    data_L1.append((img_id, parsed_sg, parsed_dt, parsed_caption, parsed_metadata))
    # Count before the debug `break` so the summary below is also correct in debug mode.
    n_success += 1
    if debug:
        print(f"Parsed Scenegraph: {parsed_sg}\n")
        print(f"Parsed Datatable: {parsed_dt}\n")
        print(f"Parsed L1 Metadata: {parsed_metadata}")
        # `display` is provided by IPython inside Jupyter; close the image file afterwards.
        with Image.open(os.path.join("./data/images", img_id + ".png")) as img:
            display(img)
        print(f"Parsed Caption: {parsed_caption}")
        break
print(f"Successes: {n_success}")
print(f"Errors: {n_error}")

100%|██████████████████████████████████████████████████████████████████████████████| 8822/8822 [00:58<00:00, 151.53it/s]

Successes: 8822
Errors: 0





In [6]:
# Persist the parsed L1 records (data_L1) as a pickle when save_gen is set.
if save_gen:
    # A fresh checkout may not contain the output directory yet.
    os.makedirs("./data/L1gen", exist_ok=True)
    with open("./data/L1gen/L1gen.pkl", 'wb') as f:
        pickle.dump(data_L1, f)

## Data Preprocessing

In [20]:
# One row per parsed chart. img_id comes from the file-name stem (a string) and is
# converted to numeric, since the later join and sort are keyed on it.
L1_columns = ["img_id", "scenegraph", "datatable", "caption_L1", "L1_properties"]
df_L1 = pd.DataFrame(data_L1, columns=L1_columns)
df_L1 = df_L1.assign(img_id=pd.to_numeric(df_L1["img_id"]))
df_L1.head()

Unnamed: 0,img_id,scenegraph,datatable,caption_L1,L1_properties
0,1,title National Football League average ticket ...,National Football League average ticket price ...,National Football League average ticket price ...,"[line, National Football League average ticket..."
1,10,title Average ticket price Tampa Bay Lightning...,Average ticket price Tampa Bay Lightning (NHL)...,Here a bar plot is titled Average ticket price...,"[bar, Average ticket price Tampa Bay Lightning..."
2,100,title Annual number of hospital beds in the Un...,Annual number of hospital beds in the United K...,Annual number of hospital beds in the United K...,"[bar, Annual number of hospital beds in the Un..."
3,1000,title Average monthly salary in public sector ...,Average monthly salary in public sector in Swe...,Average monthly salary in public sector in Swe...,"[area, Average monthly salary in public sector..."
4,1001,title Household penetration rates for dog-owne...,Household penetration rates for dog-ownership ...,Household penetration rates for dog-ownership ...,"[area, Household penetration rates for dog-own..."


In [21]:
# Load the L2/L3 captions. Despite the .json extension, the file is a pickle whose payload
# is handed to read_json as JSON Lines (presumably a str — TODO confirm).
# NOTE(review): pickle.load can execute arbitrary code; only load this file from a trusted source.
with open(path_L2L3, 'rb') as f:
    raw_L2L3 = pickle.load(f)
if isinstance(raw_L2L3, bytes):
    raw_L2L3 = raw_L2L3.decode("utf-8")
# Wrap in StringIO: passing literal JSON strings to read_json is deprecated since pandas 2.1.
df_L2L3 = pd.read_json(io.StringIO(raw_L2L3), lines=True)
df_L2L3["caption_L2L3"] = df_L2L3["caption_L2L3"].map(utils_helper.cleanL2L3)
df_L2L3.head()

Unnamed: 0,img_id,caption_L2L3
0,7219,liverpool's broadcasting revenue has almost qu...
1,3451,people spent most of their money on shelter an...
2,2553,"as time has gone on, there have been less peop..."
3,5364,unemployment skyrocketed between ~2008-2012. t...
4,5879,population was already incredibly high. signif...


In [22]:
# Attach the L1 data to every L2/L3 caption; one chart can have several L2/L3 captions.
# validate="one_to_many" asserts img_id is unique on the L1 side. A stable sort keeps each
# chart's captions in their L2/L3 file order, so the caption numbering in the next cell does
# not depend on how the (unstable) default sort breaks ties on duplicate img_ids.
df_joint = (
    df_L1.merge(df_L2L3, on="img_id", how="right", validate="one_to_many")
    .sort_values(by="img_id", kind="stable")
    .reset_index(drop=True)
)
assert df_joint["caption_L1"].notna().all(), "some L2/L3 captions have no matching L1 entry"
df_joint.head()

Unnamed: 0,img_id,scenegraph,datatable,caption_L1,L1_properties,caption_L2L3
0,1,title National Football League average ticket ...,National Football League average ticket price ...,National Football League average ticket price ...,"[line, National Football League average ticket...",The ticket price shown for 2006 is just over 6...
1,2,title Average viewers of Minecraft on Twitch w...,Average viewers of Minecraft on Twitch worldwi...,This is a area plot titled Average viewers of ...,"[area, Average viewers of Minecraft on Twitch ...",Viewers of Minecraft on twitch has gradually i...
2,2,title Average viewers of Minecraft on Twitch w...,Average viewers of Minecraft on Twitch worldwi...,This is a area plot titled Average viewers of ...,"[area, Average viewers of Minecraft on Twitch ...",Minecraft viewing on Twitch sharply increased ...
3,3,title Worldwide number of Michelin 's full-tim...,Worldwide number of Michelin 's full-time empl...,This is a area plot labeled Worldwide number o...,"[area, Worldwide number of Michelin 's full-ti...",employees have remained pretty stable througho...
4,3,title Worldwide number of Michelin 's full-tim...,Worldwide number of Michelin 's full-time empl...,This is a area plot labeled Worldwide number o...,"[area, Worldwide number of Michelin 's full-ti...",The worldwide number of full time employees is...


In [23]:
# Build "<img_id>_<NN>" caption ids: number each image's captions 01, 02, ... in the
# current row order. img_id is turned into a string first so it can be concatenated.
df_joint["img_id"] = df_joint["img_id"].map(str)
caption_rank = df_joint.groupby("img_id").cumcount() + 1
captionnumber = df_joint["img_id"] + "_" + caption_rank.astype(str).str.zfill(2)
df_joint["caption_id"] = captionnumber
df_joint.head()

Unnamed: 0,img_id,scenegraph,datatable,caption_L1,L1_properties,caption_L2L3,caption_id
0,1,title National Football League average ticket ...,National Football League average ticket price ...,National Football League average ticket price ...,"[line, National Football League average ticket...",The ticket price shown for 2006 is just over 6...,1_01
1,2,title Average viewers of Minecraft on Twitch w...,Average viewers of Minecraft on Twitch worldwi...,This is a area plot titled Average viewers of ...,"[area, Average viewers of Minecraft on Twitch ...",Viewers of Minecraft on twitch has gradually i...,2_01
2,2,title Average viewers of Minecraft on Twitch w...,Average viewers of Minecraft on Twitch worldwi...,This is a area plot titled Average viewers of ...,"[area, Average viewers of Minecraft on Twitch ...",Minecraft viewing on Twitch sharply increased ...,2_02
3,3,title Worldwide number of Michelin 's full-tim...,Worldwide number of Michelin 's full-time empl...,This is a area plot labeled Worldwide number o...,"[area, Worldwide number of Michelin 's full-ti...",employees have remained pretty stable througho...,3_01
4,3,title Worldwide number of Michelin 's full-tim...,Worldwide number of Michelin 's full-time empl...,This is a area plot labeled Worldwide number o...,"[area, Worldwide number of Michelin 's full-ti...",The worldwide number of full time employees is...,3_02


## Train/Test/Val Split

In [24]:
# Image-id lists per split, filled by the split cell and written to ./data/splits.json.
L2L3_split = dict(train=[], test=[], val=[])

In [25]:
# Split over unique image ids (not caption rows) so all captions of a chart share a split:
# 80% train, then the remaining 20% halved into test and val (10% / 10%).
ids_L2L3 = list(df_joint["img_id"].unique())
assert(len(ids_L2L3) == 8822)
split_kwargs = dict(shuffle=True, random_state=seed)
L2L3_split['train'], L2L3_rest = train_test_split(ids_L2L3, train_size=0.8, **split_kwargs)
L2L3_split['test'], L2L3_split['val'] = train_test_split(L2L3_rest, test_size=0.5, **split_kwargs)

In [26]:
# Write the image-id split (keys 'train', 'test', 'val') to disk.
splits_path = './data/splits.json'
with open(splits_path, 'w') as out_file:
    json.dump(L2L3_split, out_file)

In [32]:
# Label every caption row with its image's split. Set lookups keep this O(rows) instead of
# scanning the id lists for every row. Ids in neither train nor test fall through to
# "validation" (stored under the 'val' key in splits.json); the assert checks that fallback.
train_ids = set(L2L3_split['train'])
test_ids = set(L2L3_split['test'])
val_ids = set(L2L3_split['val'])
df_joint["split"] = df_joint["img_id"].map(
    lambda x: "train" if x in train_ids else ("test" if x in test_ids else "validation")
)
assert df_joint.loc[df_joint["split"] == "validation", "img_id"].isin(val_ids).all()
df_joint = df_joint[["caption_id", "img_id", "split", "scenegraph", "datatable", "caption_L1", "caption_L2L3", "L1_properties"]]
df_joint.head()

Unnamed: 0,caption_id,img_id,split,scenegraph,datatable,caption_L1,caption_L2L3,L1_properties
0,1_01,1,train,title National Football League average ticket ...,National Football League average ticket price ...,National Football League average ticket price ...,The ticket price shown for 2006 is just over 6...,"[line, National Football League average ticket..."
1,2_01,2,train,title Average viewers of Minecraft on Twitch w...,Average viewers of Minecraft on Twitch worldwi...,This is a area plot titled Average viewers of ...,Viewers of Minecraft on twitch has gradually i...,"[area, Average viewers of Minecraft on Twitch ..."
2,2_02,2,train,title Average viewers of Minecraft on Twitch w...,Average viewers of Minecraft on Twitch worldwi...,This is a area plot titled Average viewers of ...,Minecraft viewing on Twitch sharply increased ...,"[area, Average viewers of Minecraft on Twitch ..."
3,3_01,3,validation,title Worldwide number of Michelin 's full-tim...,Worldwide number of Michelin 's full-time empl...,This is a area plot labeled Worldwide number o...,employees have remained pretty stable througho...,"[area, Worldwide number of Michelin 's full-ti..."
4,3_02,3,validation,title Worldwide number of Michelin 's full-tim...,Worldwide number of Michelin 's full-time empl...,This is a area plot labeled Worldwide number o...,The worldwide number of full time employees is...,"[area, Worldwide number of Michelin 's full-ti..."


## Export

In [30]:
# Partition caption rows by split label and check the total against the expected
# caption count (12441).
split_col = df_joint["split"]
df_train = df_joint[split_col.eq("train")]
df_test = df_joint[split_col.eq("test")]
df_val = df_joint[split_col.eq("validation")]
assert sum(map(len, (df_train, df_test, df_val))) == 12441

In [31]:
# Serialize each split as JSON Lines (one record per caption row) and write it out.
data_train, data_test, data_val = (
    frame.to_json(orient="records", lines=True) for frame in (df_train, df_test, df_val)
)
for payload, out_path in (
    (data_train, "./data/data_train.json"),
    (data_test, "./data/data_test.json"),
    (data_val, "./data/data_validation.json"),
):
    with open(out_path, "w") as f:
        f.write(payload)