In [1]:
import os
import pandas as pd
import numpy as np
import svgpathtools
import cv2
import glob

from collections import defaultdict
from tqdm import tqdm
from pathlib import Path
from cairosvg import svg2png

In [2]:
# Project root relative to this notebook; every artefact lives under <base_dir>/results.
base_dir = './../'
results_dir = os.path.join(base_dir, "results")

# csv_dir holds the input table; svg_dir / png_dir / roi_dir are the output folders.
csv_dir, svg_dir, png_dir, roi_dir = (
    os.path.join(results_dir, sub_dir) for sub_dir in ("csv", "svg", "png", "roi")
)

# Stroke annotation table (one row per stroke); the first CSV column becomes the index.
annotation_csv = os.path.join(csv_dir, "causaldraw_annotation_preprocessed_final_svg_data.csv")
df = pd.read_csv(annotation_csv, index_col=0)
df

Unnamed: 0,sketchID,strokeIndex,condition,orig_gameID,strokeLabel,strokeType,strokeRoiNum,svg,arcLength
0,gears_1.0219-e77f751a-a934-4602-97a0-f2c0bd8bd638,0,explanatory,0219-e77f751a-a934-4602-97a0-f2c0bd8bd638,gear,causal,4.0,"M315,266c0,3.66667 0,7.33333 0,11c0,3 0,6 0,9c...",937.999262
1,gears_1.0219-e77f751a-a934-4602-97a0-f2c0bd8bd638,1,explanatory,0219-e77f751a-a934-4602-97a0-f2c0bd8bd638,gear,causal,3.0,"M253,190c16.21414,0 17.30774,1.30774 28,12c2.5...",1450.279985
2,gears_1.0366-5aa209bd-fda5-4afd-b896-9371d73ab1df,0,depictive,0366-5aa209bd-fda5-4afd-b896-9371d73ab1df,background,background,5.0,"M176,245c0,-11.78835 -0.1136,-106.8864 0,-107c...",671.450565
3,gears_1.0366-5aa209bd-fda5-4afd-b896-9371d73ab1df,1,depictive,0366-5aa209bd-fda5-4afd-b896-9371d73ab1df,gear,functional,1.0,"M209,167c-11.32421,0 -26.73795,-4.78615 -31,8c...",252.747348
4,gears_1.0366-5aa209bd-fda5-4afd-b896-9371d73ab1df,2,depictive,0366-5aa209bd-fda5-4afd-b896-9371d73ab1df,gear,functional,1.0,"M201,198c0,0.23509 -3.72683,25.72683 8,14c2.31...",55.661950
...,...,...,...,...,...,...,...,...,...
4492,pulleys_2.9810-988306f1-19e0-4158-a007-261fc2d...,8,depictive,9810-988306f1-19e0-4158-a007-261fc2d3ae0b,wheel,functional,3.0,"M386,81c-0.66667,0.33333 -1.4274,0.52283 -2,1c...",120.559612
4493,pulleys_2.9810-988306f1-19e0-4158-a007-261fc2d...,9,depictive,9810-988306f1-19e0-4158-a007-261fc2d3ae0b,string,functional,6.0,"M386,111c0,-5 0,10 0,15c0,10 0,20 0,30c0,19 0,...",152.076427
4494,pulleys_2.9810-988306f1-19e0-4158-a007-261fc2d...,10,depictive,9810-988306f1-19e0-4158-a007-261fc2d3ae0b,wheel,causal,1.0,"M258,22c-5.3108,-5.3108 -13.60541,31.19729 -8,...",153.588471
4495,pulleys_2.9810-988306f1-19e0-4158-a007-261fc2d...,11,depictive,9810-988306f1-19e0-4158-a007-261fc2d3ae0b,wheel,causal,2.0,"M263,38c4.67856,0 1.75173,9.38223 1,14c-1.7544...",376.025906


In [3]:
# Create the three output folders (and any missing parents).
# exist_ok=False makes mkdir raise FileExistsError when a folder is already
# present, so this cell only succeeds on a clean results tree
# (presumably to keep stale files from an earlier run out of the outputs).
for out_dir in (svg_dir, png_dir, roi_dir):
    Path(out_dir).mkdir(parents=True, exist_ok=False)

In [4]:
# Write every row of the annotation table to its own SVG file named
# <sketchID>_<strokeIndex>_<strokeLabel>_<strokeType>_<roi>.svg inside svg_dir.
for row_label, record in tqdm(df.iterrows(), total=len(df)):

    # "_" separates the filename fields, so underscores inside the sketch id
    # are replaced with dashes (presumably to keep the name splittable later).
    sketch_name = record["sketchID"].replace("_", "-")

    # A missing ROI number is stored as 0; otherwise the value is cast to int.
    roi_value = record["strokeRoiNum"]
    roi_number = 0 if pd.isna(roi_value) else int(roi_value)

    # Parse the SVG path string of this stroke.
    parsed_stroke = svgpathtools.parse_path(record["svg"])

    name_fields = (sketch_name,
                   record["strokeIndex"],
                   record["strokeLabel"],
                   record["strokeType"],
                   roi_number)
    svg_target = os.path.join(svg_dir, "%s_%02d_%s_%s_%d.svg" % name_fields)

    # Report a name collision; the existing file is still overwritten below.
    if os.path.exists(svg_target):
        print(row_label, svg_target)

    # Draw the stroke as a black, round-capped, unfilled 5-unit line
    # on a 500 x 500 viewbox.
    svgpathtools.wsvg(
        paths=parsed_stroke,
        attributes=[{'stroke-width': 5,
                     'stroke-linecap': "round",
                     'stroke': "black",
                     'fill': "none"}],
        viewbox=(0, 0, 500, 500),
        filename=svg_target,
    )

100%|██████████| 4445/4445 [00:02<00:00, 1785.45it/s]


In [5]:
# Rasterise every stroke SVG into a PNG with the same base name in png_dir.
svg_files = glob.glob(os.path.join(svg_dir, "*.svg"))
for svg_path in tqdm(svg_files):
    # Swap the ".svg" extension for ".png", keeping the rest of the name.
    png_name = os.path.splitext(os.path.basename(svg_path))[0] + ".png"
    png_target = os.path.join(png_dir, png_name)
    svg2png(url=svg_path, write_to=png_target)

100%|██████████| 4445/4445 [00:48<00:00, 90.75it/s]


In [6]:
# Group the stroke PNGs by (sketch id, ROI number); both are parsed from the
# file name.
png_files = glob.glob(os.path.join(png_dir, "*.png"))
roi_dict = defaultdict(list)

for png_path in tqdm(png_files):
    # Drop the ".png" extension, then split the name into its five
    # "_"-separated fields; any other field count raises ValueError here.
    stem = os.path.basename(png_path)[:-4]
    sketch_id, _stroke_idx, _stroke_label, _stroke_type, roi_number = stem.split("_")
    roi_dict[(sketch_id, roi_number)].append(png_path)

100%|██████████| 4445/4445 [00:00<00:00, 471180.78it/s]


In [7]:
# Merge all stroke PNGs that share a (sketch id, ROI) key into a single
# 500 x 500 single-channel 8-bit image, saved as <sketch id>_<roi>.png in roi_dir.
for key, sub_files in tqdm(roi_dict.items()):
    # Accumulator for the summed stroke coverage, as floats in [0, n_strokes].
    img = np.zeros((500, 500))
    for sub_file in sub_files:
        sub_img = cv2.imread(sub_file, cv2.IMREAD_UNCHANGED)
        # cv2.imread reports failure by returning None rather than raising;
        # fail loudly here instead of with an opaque TypeError below.
        if sub_img is None:
            raise IOError("could not read image: %s" % sub_file)
        # Channel index 3 is used below (taken to be alpha — assumes a
        # 4-channel PNG), so reject images that do not have it.
        if sub_img.ndim != 3 or sub_img.shape[2] < 4:
            raise ValueError("expected a 4-channel image, got shape %s: %s"
                             % (sub_img.shape, sub_file))
        # Scale the 8-bit channel to [0, 1] and add it to the accumulator.
        sub_img = sub_img[..., 3] / 255
        img = img + sub_img
    # Overlapping strokes can sum above 1; clip, then convert back to 8-bit.
    img = np.clip(img, 0, 1.0) * 255
    img = np.uint8(img)
    filename = key[0] + "_" + key[1] + ".png"
    filename = os.path.join(roi_dir, filename)
    # cv2.imwrite returns False on failure instead of raising.
    if not cv2.imwrite(filename, img):
        raise IOError("could not write image: %s" % filename)

100%|██████████| 1453/1453 [00:11<00:00, 121.96it/s]
