In [1]:
import os
os.environ['no_proxy'] = "127.0.0.1,0.0.0.1,localhost"

import json

from uuid import uuid4
from PIL import Image, ImageDraw

from selenium.webdriver.common.by import By


# Maps each UI element category name to the integer class id written into the label files.
type_to_id = {'button': 0, 'text': 1, 'input_field': 2, 'image': 3 }

## Set up the WebShop environment

In [None]:
from selenium import webdriver
# Start a Chrome session and fix the window size to 1300x1300.
driver = webdriver.Chrome()
driver.set_window_size(1300, 1300)

# Host and port that set_episode() uses to build the episode URL.
webshop_env_host_ip = "127.0.0.1"
webshop_env_host_port =3000

In [2]:
def _get_color(bbox, width, height):
    x = bbox[0]
    y = bbox[1]
    r_value = int((x/width) * 180 + 30)
    g_value = int((y/height) * 180 + 30)
    b_value = int(200 - (x/height) * 180)
    
    return (r_value, g_value, b_value)

def _compute_iou(box1, box2):
    """
    Computes the Intersection over Union (IoU) between two bounding boxes.
    Args:
        box1: A tuple of (x1, y1, x2, y2).
        box2: A tuple of (x1, y1, x2, y2).
    Returns:
        The IoU between the two bounding boxes.
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    if x2 <= x1 or y2 <= y1:
        return 0.0

    intersection = (x2 - x1) * (y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    iou = intersection / float(area1 + area2 - intersection)
    return iou

def _compute_iou_w_h(box1, box2):
    """
    Intersection over Union (IoU) of two origin/size-format bounding boxes.
    Args:
        box1: (x1, y1, width, height, ...) of the first box; extra items are ignored.
        box2: (x1, y1, width, height, ...) of the second box; extra items are ignored.
    Returns:
        The IoU computed by _compute_iou on the corner-format equivalents.
    """
    def _to_corners(box):
        # (x, y, w, h) -> (x1, y1, x2, y2)
        x, y, w, h = box[0], box[1], box[2], box[3]
        return (x, y, x + w, y + h)

    return _compute_iou(_to_corners(box1), _to_corners(box2))

In [3]:
def set_episode(driver, index_id:str):
    """Point the browser at the page for episode ``index_id``.

    The URL is built from the module-level ``webshop_env_host_ip`` and
    ``webshop_env_host_port`` settings.
    """
    episode_url = f"http://{webshop_env_host_ip}:{webshop_env_host_port}/{index_id}"
    driver.get(episode_url)

In [4]:
def save_full_screenshot(driver, save_image_full_path):
    """Save a screenshot of the largest element on the current page.

    "Largest" means the biggest ``width + height``; the first element reaching
    that maximum wins.

    Args:
        driver: Selenium WebDriver with a page loaded.
        save_image_full_path: Destination path of the screenshot file.

    Raises:
        RuntimeError: If the page has no element with a non-zero size.
    """
    fullscreen_element = None
    max_size = 0
    for e_element in driver.find_elements(By.TAG_NAME, '*'):
        # Read the size property once; each access queries the browser.
        size = e_element.size
        element_size = size['height'] + size['width']
        if element_size > max_size:
            max_size = element_size
            fullscreen_element = e_element

    # Without this guard the failure surfaced as an opaque AttributeError on None.
    if fullscreen_element is None:
        raise RuntimeError("no element with a non-zero size found; cannot take a screenshot")

    fullscreen_element.screenshot(save_image_full_path)


def width_then_height_key(box):
    """Sort key for (x, y, w, h, ...) boxes: order by width, then by height."""
    box_width = box[2]
    box_height = box[3]
    return box_width, box_height


def _non_maximum_subpression(box_list, iou_threshold):
    """Drop boxes that overlap an earlier box in ``box_list`` too much.

    Boxes are visited in list order, so earlier boxes take priority. A later
    box is suppressed when its IoU with any earlier box exceeds
    ``iou_threshold``. A box that has itself been suppressed still acts as an
    "earlier box" and can suppress boxes that follow it.

    Args:
        box_list: Sequence of (x, y, w, h, id) boxes; ids are expected to be
            hashable (the callers in this notebook use ints).
        iou_threshold: Overlap above which the later box is removed.

    Returns:
        A new list with the surviving boxes, in their original order.
    """
    # A set keeps the membership test O(1); with a list the test inside the
    # double loop made the whole pass cubic in the number of boxes.
    suppressed_ids = set()
    for outer_bbox_idx, outer_bbox in enumerate(box_list):
        for inner_bbox in box_list[outer_bbox_idx+1:]:
            if inner_bbox[4] in suppressed_ids:
                continue
            if _compute_iou_w_h(outer_bbox, inner_bbox) > iou_threshold:
                suppressed_ids.add(inner_bbox[4])

    return [e_bbox for e_bbox in box_list if e_bbox[4] not in suppressed_ids]


def fliter_in_window_element(box_list, image_height):
    """Keep only the boxes whose top edge lies above ``image_height``.

    Args:
        box_list: Sequence of (x, y, w, h, id) boxes.
        image_height: Boxes with ``y >= image_height`` are discarded.

    Returns:
        A new list with the remaining boxes, in their original order.
    """
    return [box for box in box_list if not box[1] >= image_height]


def get_fit_coordinate_for_left_space(img_original, start_x, start_y, width, height):
    """Trim uniform padding from the left side of a box.

    Starting at column ``start_x`` the function walks right one column at a
    time and compares each column's mean brightness (over the box's rows) with
    that of the starting column. It returns the x of the last column that
    still matched, i.e. the column just before the first one that differs.

    Args:
        img_original: Image-like object exposing ``size`` as (width, height)
            and ``getpixel((x, y))`` returning an (r, g, b) or (r, g, b, a) tuple.
        start_x: Left edge of the box.
        start_y: Top edge of the box.
        width: Box width; at most ``width - 1`` columns are scanned.
        height: Box height; rows outside the image are ignored.

    Returns:
        The adjusted left x coordinate, or ``start_x`` unchanged when every
        scanned column matches or when the box has no pixel inside the image.
    """
    image_width, image_height = img_original.size

    # range() needs integers, while sizes handed over by callers may be fractional.
    first_row = max(int(start_y), 0)
    last_row = min(int(start_y) + int(height), image_height)
    rows = range(first_row, last_row)

    # Nothing to sample (box below/above the image, zero height, or start
    # column outside the image): leave the coordinate untouched instead of
    # failing with ZeroDivisionError / IndexError.
    if len(rows) == 0 or not 0 <= start_x < image_width:
        return start_x

    def _column_brightness(x):
        # Mean of the per-pixel RGB averages along column x; alpha is ignored.
        samples = []
        for y in rows:
            r, g, b = img_original.getpixel((x, y))[:3]
            samples.append((r + g + b) / 3)
        return sum(samples) / len(samples)

    reference = _column_brightness(start_x)
    for offset in range(1, int(width)):
        moving_x = start_x + offset
        if moving_x >= image_width:
            # Every further column is outside the image as well.
            break
        if _column_brightness(moving_x) != reference:
            return moving_x - 1

    return start_x


def get_fit_coordinate_for_right_space(img_original, start_x, start_y, width, height):
    """Trim uniform padding from the right side of a box.

    Starting at column ``start_x`` the function walks left one column at a
    time and compares each column's mean brightness (over the box's rows) with
    that of the starting column. It returns the x just to the right of the
    first column that differs.

    Args:
        img_original: Image-like object exposing ``size`` as (width, height)
            and ``getpixel((x, y))`` returning an (r, g, b) or (r, g, b, a) tuple.
        start_x: Right edge of the box. It may equal the image width (an
            exclusive edge clamped to the image); the last image column is
            then used as the reference.
        start_y: Top edge of the box.
        width: Box width; at most ``width - 1`` columns are scanned.
        height: Box height; rows outside the image are ignored.

    Returns:
        The adjusted right x coordinate, or ``start_x`` unchanged when every
        scanned column matches or when the box has no pixel inside the image.
    """
    image_width, image_height = img_original.size

    # range() needs integers, while sizes handed over by callers may be fractional.
    first_row = max(int(start_y), 0)
    last_row = min(int(start_y) + int(height), image_height)
    rows = range(first_row, last_row)

    # Nothing to sample: leave the coordinate untouched instead of failing
    # with ZeroDivisionError / IndexError.
    if len(rows) == 0 or image_width <= 0 or start_x < 0:
        return start_x

    def _column_brightness(x):
        # Mean of the per-pixel RGB averages along column x; alpha is ignored.
        samples = []
        for y in rows:
            r, g, b = img_original.getpixel((x, y))[:3]
            samples.append((r + g + b) / 3)
        return sum(samples) / len(samples)

    # A right edge sitting on (or past) the image border has no pixels of its
    # own, so fall back to the last real column as the reference.
    reference = _column_brightness(min(start_x, image_width - 1))
    for offset in range(1, int(width)):
        moving_x = start_x - offset
        if moving_x >= image_width:
            # Still to the right of the image; keep walking left.
            continue
        if moving_x < 0:
            # Walked off the left border; never read wrapped-around pixels.
            break
        if _column_brightness(moving_x) != reference:
            return moving_x + 1

    return start_x


def show_preview_annotation_data(driver):
    """Display the current page with its candidate annotation boxes drawn on top.

    Every element of the page is inspected. An element becomes a candidate when
    it is larger than 2px in both dimensions and either is the search box
    (its outerHTML contains ``name="search_query"``), has visible text, or is
    an ``img`` tag. Candidates are sorted by (width, height), de-duplicated with
    _non_maximum_subpression at an IoU threshold of 0.05, and drawn on a
    screenshot of the largest element, which is written to ``./sample_bg.png``.
    Each box is labelled with its candidate index.
    """
    all_elements = driver.find_elements(By.TAG_NAME, '*')

    boxes = []
    kept_elements = []
    largest_element = None
    largest_extent = 0
    for element in all_elements:
        left = element.location['x']
        top = element.location['y']
        box_w = element.rect['width']
        box_h = element.rect['height']

        # Track the element with the biggest width + height as the page background.
        extent = box_w + box_h
        if extent > largest_extent:
            largest_extent = extent
            largest_element = element

        if box_w <= 2 or box_h <= 2:
            continue
        is_search_box = 'name="search_query"' in element.get_attribute('outerHTML')
        if not is_search_box and len(element.text) == 0 and element.tag_name != 'img':
            continue

        # The id stored in the box is the element's position in kept_elements.
        boxes.append((left, top, box_w, box_h, len(kept_elements)))
        kept_elements.append(element)

    boxes.sort(key=width_then_height_key)
    surviving_boxes = _non_maximum_subpression(boxes, 0.05)

    screenshot_path = './sample_bg.png'
    largest_element.screenshot(screenshot_path)

    canvas = Image.open(screenshot_path)
    painter = ImageDraw.Draw(canvas)
    canvas_w, canvas_h = canvas.size
    for box in surviving_boxes:
        left, top, box_w, box_h, box_id = box
        color = _get_color(box, canvas_w, canvas_h)
        # Outline, then a small filled tag in the corner carrying the box id.
        painter.rectangle((left, top, left + box_w, top + box_h), outline=color, width=3)
        painter.rectangle((left, top, left + 10, top + 10), fill=color)
        painter.text((left + 2, top), str(box_id), fill=(255, 255, 255))

    display(canvas)

In [5]:
def save_dataset(driver, type_to_id, target_image_name, save_image_base, target_label_base, commentator_label_base, fixed_window_heigt = 1153):
    """Screenshot the current page and write its element annotations.

    Candidate elements are those larger than 2px in both dimensions that are
    the search box, have visible text, or are ``img`` tags. They are sorted by
    (width, height), filtered to those starting above ``fixed_window_heigt - 2``
    and de-duplicated with _non_maximum_subpression (IoU threshold 0.05).

    Files written (``<stem>`` is ``target_image_name`` without its extension):
        * ``save_image_base/target_image_name``: screenshot of the largest element.
        * ``target_label_base/<stem>.txt``: one line per box,
          ``<class_id> <cx> <cy> <w> <h>``, normalised by the image size.
        * ``commentator_label_base/<stem>.txt``: JSON lines with ``type``
          (name) and ``content`` (element text).
        * ``commentator_label_base/<stem>_annotation.txt``: JSON lines with
          ``type`` (class id), ``coords`` ([x1, y1, x2, y2]) and ``content``.

    If anything fails, every file written by this call is removed again so that
    no screenshot is left in the dataset without its labels.

    Args:
        driver: Selenium WebDriver with the page to capture loaded.
        type_to_id: Mapping from element type name to integer class id.
        target_image_name: File name of the screenshot.
        save_image_base: Directory for the screenshot.
        target_label_base: Directory for the normalised box labels.
        commentator_label_base: Directory for the two JSON-lines files.
        fixed_window_heigt: Elements whose top edge is at or below this value
            minus 2 are ignored (parameter name kept as-is for callers).

    Raises:
        RuntimeError: If the page has no element with a non-zero size.
        ValueError: If a box is empty after clamping/trimming.
    """
    elements = driver.find_elements(By.TAG_NAME, '*')
    candidate_bounding_box_list = list()
    candidate_element_list = list()

    background_element = None
    max_element_size = 0

    for e_element in elements:
        # Read each property once; every access queries the browser.
        location = e_element.location
        rect = e_element.rect
        x1 = location['x']
        y1 = location['y']
        width = rect['width']
        height = rect['height']

        # The element with the biggest width + height serves as the page background.
        if max_element_size < width+height:
            max_element_size = width+height
            background_element = e_element

        if width <= 2 or height <= 2:
            continue
        if ('name="search_query"' not in e_element.get_attribute('outerHTML')) and len(e_element.text) == 0:
            if e_element.tag_name != 'img':
                continue
        # The id stored in the box is the element's index in candidate_element_list.
        candidate_bounding_box_list.append((x1, y1, width, height, len(candidate_element_list)))
        candidate_element_list.append(e_element)

    if background_element is None:
        raise RuntimeError("no element with a non-zero size found; cannot take a screenshot")

    candidate_bounding_box_list.sort(key=width_then_height_key)
    filtered_bbox_list = fliter_in_window_element(candidate_bounding_box_list, fixed_window_heigt-2)
    nmsed_box_list = _non_maximum_subpression(filtered_bbox_list, 0.05)

    # splitext copes with extensions that are not exactly four characters long.
    base_name = os.path.splitext(target_image_name)[0]
    base_annotation_file_name = base_name + '.txt'
    base_annotation_file_name_annotation = base_name + '_annotation.txt'

    target_image_path = os.path.join(save_image_base, target_image_name)
    yolo_label_path = os.path.join(target_label_base, base_annotation_file_name)
    pix2struct_label_path = os.path.join(commentator_label_base, base_annotation_file_name)
    pix2struct_annotation_path = os.path.join(commentator_label_base, base_annotation_file_name_annotation)

    save_data_list_for_yolo = list()
    save_data_list_for_pix2struct = list()
    save_data_list_for_pix2struct_annotation = list()

    written_paths = list()
    try:
        written_paths.append(target_image_path)
        background_element.screenshot(target_image_path)

        # The context manager releases the file handle even if no pixel is ever read.
        with Image.open(target_image_path) as background_img:
            image_width, image_height = background_img.size

            for e_nms_box in nmsed_box_list:
                element_id = e_nms_box[4]
                element_text = candidate_element_list[element_id].text
                element_outer_html = candidate_element_list[element_id].get_attribute('outerHTML')

                # Type is decided by substring checks on the element's outerHTML,
                # in this priority order; anything unmatched stays "text".
                element_type = "text"
                if 'img' in element_outer_html:
                    element_type ='image'
                    element_text = ''
                elif '<button' in element_outer_html or '<label' in element_outer_html:
                    element_type ='button'
                elif '<input id' in element_outer_html:
                    element_type ='input_field'

                x1 = e_nms_box[0]
                y1 = e_nms_box[1]
                # Clamp the far corner to the screenshot.
                x2 = min(x1 + e_nms_box[2], image_width)
                y2 = min(y1 + e_nms_box[3], image_height)

                if element_type == 'text':
                    # Shrink text boxes horizontally to the visible content.
                    x1 = get_fit_coordinate_for_left_space(background_img, x1, y1, e_nms_box[2], e_nms_box[3])
                    x2 = get_fit_coordinate_for_right_space(background_img, x2, y1, e_nms_box[2], e_nms_box[3])

                # Explicit check instead of assert, which is stripped under `python -O`.
                if not (x1 < x2 and y1 < y2):
                    raise ValueError(f"empty bounding box for element {element_id}: {(x1, y1, x2, y2)}")

                bbox_width = x2-x1
                bbox_height = y2-y1
                bbox_mx = x1 + bbox_width*0.5
                bbox_my = y1 + bbox_height*0.5

                # mx my w h normalized
                save_data_list_for_yolo.append(f"{type_to_id[element_type]} {bbox_mx/image_width} {bbox_my/image_height} {bbox_width/image_width} {bbox_height/image_height}")
                save_data_list_for_pix2struct.append({'type' : element_type, 'content' : element_text})
                save_data_list_for_pix2struct_annotation.append({'type' : type_to_id[element_type], 'coords' : [x1,y1,x2,y2], 'content' : element_text})

        label_files = (
            # normalised box labels
            (yolo_label_path, save_data_list_for_yolo),
            # pix2struct: type/content pairs
            (pix2struct_label_path, [json.dumps(row) for row in save_data_list_for_pix2struct]),
            # pix2struct: type id, pixel coordinates and content
            (pix2struct_annotation_path, [json.dumps(row) for row in save_data_list_for_pix2struct_annotation]),
        )
        for label_path, label_lines in label_files:
            written_paths.append(label_path)
            with open(label_path, 'w') as f:
                for label_line in label_lines:
                    f.write(label_line+"\n")
    except BaseException:
        # Roll back partial output so the dataset never holds an image without labels.
        for written_path in written_paths:
            if os.path.exists(written_path):
                os.remove(written_path)
        raise


In [6]:
# Load an arbitrary episode id ('abc'); presumably a manual check that the server responds (confirm).
set_episode(driver, 'abc')

## Create Training dataset

In [7]:
# Scratch directories: the first screenshot of an episode is written here and only
# moved into the dataset once the search is known to return results.
tmp_data_path = './tmp_data'
tmp_data_commentator_path = './tmp_data/commentator'
os.makedirs(tmp_data_path, exist_ok=True)
os.makedirs(tmp_data_commentator_path, exist_ok=True)

unique_episode_idx_list = list()
instruction_dict = {}
episode_key_list = list()

# Pass 1: open 130 random episodes and count how often each instruction text occurs.
while len(episode_key_list) < 130:

    episode_key = f'{uuid4()}'

    set_episode(driver, episode_key)
    episode_key_list.append(episode_key)

    # grab the instruction text
    target_xpath = '//*[@id="instruction-text"]/h4'
    instruction_text = driver.find_element(By.XPATH,target_xpath).text

    instruction_dict[instruction_text] = instruction_dict.get(instruction_text, 0) + 1


save_image_base = '../../datasets/webshop_dataset/images/train2024'
target_label_base = '../../datasets/webshop_dataset/labels/train2024'
commentator_label_base = '../../datasets/webshop_dataset/commentator_labels/train2024'

os.makedirs(save_image_base, exist_ok=True)
os.makedirs(target_label_base, exist_ok=True)
os.makedirs(commentator_label_base, exist_ok=True)

# Pass 2: for episodes with a unique instruction, capture three pages each
# (start page, search results, first item) until 100 episodes succeeded.
fine_data_count = 0
for episode_key in episode_key_list:
    if fine_data_count == 100:
        break
    try:
        # Reset episode
        set_episode(driver, episode_key)

        # grab the instruction text; skip instructions seen more than once in pass 1
        target_xpath = '//*[@id="instruction-text"]/h4'
        instruction_text = driver.find_element(By.XPATH,target_xpath).text
        if instruction_dict[instruction_text] > 1:
            continue


        # Page 1 goes to the scratch directory first.
        target_image_name = f"data_{episode_key}__01_first.png"
        save_dataset(driver, type_to_id, target_image_name, tmp_data_path, tmp_data_path, tmp_data_commentator_path)


        # image
        tmp_image_path = os.path.join(tmp_data_path, target_image_name)

        base_annotation_file_name = target_image_name[:-4] + '.txt'
        base_annotation_file_name_annotation = target_image_name[:-4] + '_annotation.txt'

        # yolo
        tmp_label_path = os.path.join(tmp_data_path, base_annotation_file_name)
        # pix2struct
        tmp_commentator_label_path = os.path.join(tmp_data_commentator_path, base_annotation_file_name)
        tmp_commentator_annotation_label_path = os.path.join(tmp_data_commentator_path, base_annotation_file_name_annotation)

        # get search_keyword: last line of the instruction, cut before ', and price'
        search_keyword = instruction_text.split('\n')[-1].split(', and price')[0]

        # type Search keywords
        target_xpath = '//*[@id="search_input"]'
        driver.find_element(By.XPATH,target_xpath).clear()
        search_input = driver.find_element(By.XPATH,target_xpath)
        search_input.send_keys(search_keyword)

        # Click Search button
        target_xpath = '//*[@id="form-buscar"]/div/div/span/button'
        driver.find_element(By.XPATH,target_xpath).click()

        # A page with exactly 6 elements, one of which contains 'Not Found',
        # is treated as an empty search result.
        is_not_found = False
        elements = driver.find_elements(By.TAG_NAME, '*')
        if len(elements) == 6:
            for e_element in elements:
                if 'Not Found' in e_element.text:
                    is_not_found = True
                    break

        if is_not_found is True:
            # Discard the scratch copy of page 1 and move on.
            os.remove(tmp_image_path)
            os.remove(tmp_label_path)
            os.remove(tmp_commentator_label_path)
            os.remove(tmp_commentator_annotation_label_path)
            continue

        # Promote page 1 from the scratch directory into the dataset.
        saved_image_path = os.path.join(save_image_base, target_image_name)
        saved_label_path = os.path.join(target_label_base, base_annotation_file_name)
        saved_commentator_label_path = os.path.join(commentator_label_base, base_annotation_file_name)
        saved_commentator_annotation_label_path = os.path.join(commentator_label_base, base_annotation_file_name_annotation)

        os.rename(tmp_image_path, saved_image_path)
        os.rename(tmp_label_path, saved_label_path)
        os.rename(tmp_commentator_label_path, saved_commentator_label_path)
        os.rename(tmp_commentator_annotation_label_path, saved_commentator_annotation_label_path)

        # Page 2: search results.
        target_image_name = f"data_{episode_key}__02_first_search_page.png"
        save_dataset(driver, type_to_id, target_image_name, save_image_base, target_label_base, commentator_label_base)


        # Click first result
        target_xpath = "/html/body/div/div[4]/div[1]/div[2]/ul/div/div/h4[1]/a"
        driver.find_element(By.XPATH,target_xpath).click()

        # Page 3: first item.
        target_image_name = f"data_{episode_key}__03_first_item.png"
        save_dataset(driver, type_to_id, target_image_name, save_image_base, target_label_base, commentator_label_base)

        fine_data_count += 1
    except Exception as exc:
        # Best effort: skip episodes that fail, but report them instead of hiding
        # the error, and let KeyboardInterrupt/SystemExit stop the loop.
        print(f"[train] skipping episode {episode_key}: {exc!r}")



## Create Validation dataset

In [8]:
# NOTE(review): the generated ids are discarded. uuid4() is documented as
# random, so this loop does not appear to influence the ids drawn below; confirm intent.
for _ in range(200):
    uuid4()


save_image_base = '../../datasets/webshop_dataset/images/valid2024'
target_label_base = '../../datasets/webshop_dataset/labels/valid2024'
commentator_label_base = '../../datasets/webshop_dataset/commentator_labels/valid2024'

os.makedirs(save_image_base, exist_ok=True)
os.makedirs(target_label_base, exist_ok=True)
os.makedirs(commentator_label_base, exist_ok=True)

unique_episode_idx_list = list()
# The counter is reset here, so duplicate instructions are only detected within
# this validation batch, not against the training episodes.
instruction_dict = {}
episode_key_list = list()

# Pass 1: open 30 random episodes and count how often each instruction text occurs.
while len(episode_key_list) < 30:

    episode_key = f'{uuid4()}'

    set_episode(driver, episode_key)
    episode_key_list.append(episode_key)

    # grab the instruction text
    target_xpath = '//*[@id="instruction-text"]/h4'
    instruction_text = driver.find_element(By.XPATH,target_xpath).text

    instruction_dict[instruction_text] = instruction_dict.get(instruction_text, 0) + 1

# Report instructions that occurred more than once (most frequent first).
inst_cnt = sorted(instruction_dict.items(), key=lambda x:x[1], reverse=True)

for inst, cnt in inst_cnt:
    if cnt != 1:
        print(inst[:100], cnt)

# Pass 2: for episodes with a unique instruction, capture three pages each
# (start page, search results, first item) until 10 episodes succeeded.
# tmp_data_path / tmp_data_commentator_path come from the training cell.
fine_data_count = 0
for episode_key in episode_key_list:
    if fine_data_count == 10:
        break
    try:
        # Reset episode
        set_episode(driver, episode_key)

        # grab the instruction text; skip instructions seen more than once in pass 1
        target_xpath = '//*[@id="instruction-text"]/h4'
        instruction_text = driver.find_element(By.XPATH,target_xpath).text
        if instruction_dict[instruction_text] > 1:
            continue


        # Page 1 goes to the scratch directory first.
        target_image_name = f"data_{episode_key}__01_first.png"
        save_dataset(driver, type_to_id, target_image_name, tmp_data_path, tmp_data_path, tmp_data_commentator_path)


        # image
        tmp_image_path = os.path.join(tmp_data_path, target_image_name)

        base_annotation_file_name = target_image_name[:-4] + '.txt'
        base_annotation_file_name_annotation = target_image_name[:-4] + '_annotation.txt'

        # yolo
        tmp_label_path = os.path.join(tmp_data_path, base_annotation_file_name)
        # pix2struct
        tmp_commentator_label_path = os.path.join(tmp_data_commentator_path, base_annotation_file_name)
        tmp_commentator_annotation_label_path = os.path.join(tmp_data_commentator_path, base_annotation_file_name_annotation)

        # get search_keyword: last line of the instruction, cut before ', and price'
        search_keyword = instruction_text.split('\n')[-1].split(', and price')[0]

        # type Search keywords
        target_xpath = '//*[@id="search_input"]'
        driver.find_element(By.XPATH,target_xpath).clear()
        search_input = driver.find_element(By.XPATH,target_xpath)
        search_input.send_keys(search_keyword)

        # Click Search button
        target_xpath = '//*[@id="form-buscar"]/div/div/span/button'
        driver.find_element(By.XPATH,target_xpath).click()

        # A page with exactly 6 elements, one of which contains 'Not Found',
        # is treated as an empty search result.
        is_not_found = False
        elements = driver.find_elements(By.TAG_NAME, '*')
        if len(elements) == 6:
            for e_element in elements:
                if 'Not Found' in e_element.text:
                    is_not_found = True
                    break

        if is_not_found is True:
            # Discard the scratch copy of page 1 and move on.
            os.remove(tmp_image_path)
            os.remove(tmp_label_path)
            os.remove(tmp_commentator_label_path)
            os.remove(tmp_commentator_annotation_label_path)
            continue

        # Promote page 1 from the scratch directory into the dataset.
        saved_image_path = os.path.join(save_image_base, target_image_name)
        saved_label_path = os.path.join(target_label_base, base_annotation_file_name)
        saved_commentator_label_path = os.path.join(commentator_label_base, base_annotation_file_name)
        saved_commentator_annotation_label_path = os.path.join(commentator_label_base, base_annotation_file_name_annotation)

        os.rename(tmp_image_path, saved_image_path)
        os.rename(tmp_label_path, saved_label_path)
        os.rename(tmp_commentator_label_path, saved_commentator_label_path)
        os.rename(tmp_commentator_annotation_label_path, saved_commentator_annotation_label_path)

        # Page 2: search results.
        target_image_name = f"data_{episode_key}__02_first_search_page.png"
        save_dataset(driver, type_to_id, target_image_name, save_image_base, target_label_base, commentator_label_base)


        # Click first result
        target_xpath = "/html/body/div/div[4]/div[1]/div[2]/ul/div/div/h4[1]/a"
        driver.find_element(By.XPATH,target_xpath).click()

        # Page 3: first item.
        target_image_name = f"data_{episode_key}__03_first_item.png"
        save_dataset(driver, type_to_id, target_image_name, save_image_base, target_label_base, commentator_label_base)

        fine_data_count += 1
    except Exception as exc:
        # Best effort: skip episodes that fail, but report them instead of hiding
        # the error, and let KeyboardInterrupt/SystemExit stop the loop.
        print(f"[valid] skipping episode {episode_key}: {exc!r}")