In [1]:
import os
import random
import torch
import json
import sys

import torchvision.transforms as T
from tqdm.auto import tqdm

In [2]:
from facade_project import FACADE_LABELME_ORIGINAL_DIR, FACADE_IMAGES_DIR, LABEL_NAME_TO_VALUE, PATH_TO_DATA, IMG_MAX_SIZE
from facade_project.utils.load import load_tuple_from_json
from facade_project.geometry.heatmap import extract_heatmaps_info
from facade_project.geometry.masks import crop_pil, get_bbox
from facade_project.geometry.image import resize, rotate
from facade_project.geometry.heatmap import \
    rotate as rotate_info,\
    crop as crop_info,\
    resize as resize_info

In [3]:
# Full path of every entry in the original LabelMe directory, in sorted
# filename order so that the index of a path is stable between runs.
img_paths = [
    os.path.join(FACADE_LABELME_ORIGINAL_DIR, labelme_name)
    for labelme_name in sorted(os.listdir(FACADE_LABELME_ORIGINAL_DIR))
]

# Displayed as the cell output: number of entries found.
len(img_paths)

418

# Cut borders to match a target aspect ratio

In [5]:
def cut_borders(img, lbl, ratios=(1, 4/3, 3/4)):
    """Crop an image and its label to a bbox extended towards a standard ratio.

    The bbox returned by ``get_bbox(lbl)`` is extended, for every candidate
    ratio, with ``extend_bbox_for_ratio``; the candidate with the smallest
    ratio distance wins (the first one on ties). Both ``img`` and ``lbl`` are
    then cropped to that bbox with ``crop_pil``.

    :param img: image exposing ``.size`` as ``(width, height)``
    :param lbl: label image passed to ``get_bbox`` and ``crop_pil``
        (presumably the same size as ``img`` -- verify against callers)
    :param ratios: candidate width/height ratios, tried in order
    :return: ``(cropped_img, cropped_lbl, bbox, closest_ratio)`` where bbox is
        ``(tl_x, tl_y, br_x, br_y)``
    :raises ValueError: if ``ratios`` is empty
    :raises AssertionError: if the selected bbox exceeds the image bounds
    """
    if len(ratios) == 0:
        raise ValueError('ratios must contain at least one ratio')

    label_bbox = get_bbox(lbl)

    # Keep the candidate whose extended bbox is the closest to its target ratio.
    best_ratio_dist = float('inf')
    closest_ratio = None
    best_bbox = None
    for ratio in ratios:
        bbox_extended, ratio_dist = extend_bbox_for_ratio(img.size, label_bbox, ratio)
        # strict comparison: the first candidate wins on ties
        if ratio_dist < best_ratio_dist:
            best_ratio_dist = ratio_dist
            closest_ratio = ratio
            best_bbox = bbox_extended

    # Sanity check: the selected bbox must lie inside the image.
    width, height = img.size
    tl_x, tl_y, br_x, br_y = best_bbox
    assert tl_x >= 0, '{} >= {}'.format(tl_x, 0)
    assert tl_y >= 0, '{} >= {}'.format(tl_y, 0)
    assert br_x <= width, '{} <= {}'.format(br_x, width)
    assert br_y <= height, '{} <= {}'.format(br_y, height)

    return crop_pil(img, best_bbox), crop_pil(lbl, best_bbox), best_bbox, closest_ratio

In [4]:
def _distribute_padding(slack_before, slack_after, missing):
    """Split ``missing`` pixels of padding between the two sides of one axis.

    :param slack_before: room available before the bbox (left or top)
    :param slack_after: room available after the bbox (right or bottom)
    :param missing: total padding needed to reach the target ratio
    :return: ``(add_before, add_after, matched)``; ``matched`` is False when
        the total slack is smaller than ``missing``, in which case all the
        available slack is used
    """
    if slack_before + slack_after < missing:
        return slack_before, slack_after, False

    # Enough slack: aim for an even split, but never take more than the
    # tighter side can offer; the other side absorbs the remainder.
    if slack_before <= slack_after:
        add_before = min(slack_before, missing // 2)
        return add_before, missing - add_before, True
    add_after = min(slack_after, missing // 2)
    return missing - add_after, add_after, True


def extend_bbox_for_ratio(dim, bbox, ratio):
    """Extend a bbox inside an image so that width / height approaches ``ratio``.

    The bbox is only ever enlarged: horizontally when it is too narrow for
    ``ratio``, vertically otherwise. The extension is limited by the image
    borders.

    :param dim: image dimensions as ``(width, height)``
    :param bbox: ``(tl_x, tl_y, br_x, br_y)``
    :param ratio: target width / height ratio
    :return: ``(extended_bbox, ratio_dist)``. ``ratio_dist`` is 0 when there
        was enough slack to reach the ratio (up to rounding of the padding),
        otherwise the relative distance ``abs(ratio - new_ratio) / ratio``
    :raises ZeroDivisionError: if the bbox has zero height
    """
    width, height = dim
    tl_x, tl_y, br_x, br_y = bbox

    bbox_width = br_x - tl_x
    bbox_height = br_y - tl_y

    if bbox_width / bbox_height < ratio:
        # Too narrow: widen the bbox.
        missing_width = round(ratio * bbox_height - bbox_width)
        left_add, right_add, matched_ratio = _distribute_padding(
            tl_x, width - br_x, missing_width)
        bbox = tl_x - left_add, tl_y, br_x + right_add, br_y
    else:
        # Too wide (or already at the ratio): make the bbox taller.
        # An exact fit (slack == missing) counts as matched, consistently with
        # the horizontal case.
        missing_height = round(bbox_width / ratio - bbox_height)
        top_add, bottom_add, matched_ratio = _distribute_padding(
            tl_y, height - br_y, missing_height)
        bbox = tl_x, tl_y - top_add, br_x, br_y + bottom_add

    if matched_ratio:
        return bbox, 0

    # Not enough room: report how far the clamped bbox is from the target.
    new_tl_x, new_tl_y, new_br_x, new_br_y = bbox
    new_ratio = (new_br_x - new_tl_x) / (new_br_y - new_tl_y)
    return bbox, abs(ratio - new_ratio) / ratio

# Generate
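Only run this section when the data needs to be (re)generated: it writes the image/label tensors and the heatmap-info JSON under `PATH_TO_DATA`.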

In [7]:
# Generate, for every LabelMe file, `num_rot` variants (the original plus
# randomly rotated copies), crop them to a standard ratio, resize them and
# save image/label tensors together with the transformed heatmap infos.

# in case we want to add more rotations
rotation_offset = 0
num_rot = 5
max_size = IMG_MAX_SIZE
dir_name = '{}/images/tensor/rotated_rescaled'.format(PATH_TO_DATA)
heatmap_infos_path = '{}/heatmaps/json/heatmaps_infos_rotated_rescaled.json'.format(PATH_TO_DATA)
# NOTE(review): this dict starts empty and the JSON file is rewritten from it,
# so running with rotation_offset > 0 drops previously saved infos -- merge
# with the existing file first if they must be kept.
heatmap_infos = dict()

# Neither torch.save nor open(..., 'w') creates missing directories.
os.makedirs(dir_name, exist_ok=True)
os.makedirs(os.path.dirname(heatmap_infos_path), exist_ok=True)

for idx, path in enumerate(tqdm(img_paths)):
    # `with` closes the LabelMe file instead of leaking one handle per image
    with open(path, mode='r') as labelme_file:
        heatmap_info = extract_heatmaps_info(json.load(labelme_file))

    img_original, lbl_original = load_tuple_from_json(path)
    img_pil, lbl_pil = T.ToPILImage()(img_original), T.ToPILImage()(lbl_original)

    heatmap_infos.setdefault(idx, dict())

    for jdx in range(num_rot):
        # Same index for the saved tensors and for the heatmap infos, so both
        # stay aligned when rotation_offset > 0.
        rot_idx = jdx + rotation_offset

        # NOTE(review): assumes rotate_info / crop_info / resize_info return
        # new objects and do not mutate `heatmap_info` -- confirm.
        info = heatmap_info

        # saving the original as rot_idx = 0
        if rot_idx > 0:
            # random angle in [-10, -1] U [1, 10] degrees
            angle = random.randint(1, 10)
            if random.randint(0, 1) == 0:
                angle *= -1
            info = rotate_info(info, angle)
            img = rotate(img_pil, angle, itp_name='BI')
            lbl = rotate(lbl_pil, angle, itp_name='NN')
            # cut borders
            img, lbl, bbox, closest_ratio = cut_borders(img, lbl)
        else:
            img, lbl, bbox, closest_ratio = cut_borders(img_pil, lbl_pil)

        info = crop_info(info, bbox)

        # resize: the longer side gets max_size, the other one follows the ratio
        max_width = max_size
        max_height = max_size
        if closest_ratio > 1:
            max_height = round(max_size / closest_ratio)
        elif closest_ratio < 1:
            max_width = round(max_size * closest_ratio)

        resize_size = (max_height, max_width)
        img = resize(img, size=resize_size, itp_name='BI')
        lbl = resize(lbl, size=resize_size, itp_name='NN')
        info = resize_info(info, resize_size)

        img = T.ToTensor()(img)
        lbl = (T.ToTensor()(lbl) * 255).int()

        heatmap_infos[idx][rot_idx] = info

        torch.save(img, '{}/img_{:03d}_{:03d}.torch'.format(dir_name, idx, rot_idx))
        torch.save(lbl, '{}/lbl_{:03d}_{:03d}.torch'.format(dir_name, idx, rot_idx))

    # Checkpoint once per image (instead of once per rotation) and close the
    # file so the JSON is fully flushed to disk.
    with open(heatmap_infos_path, mode='w') as heatmap_infos_file:
        json.dump(heatmap_infos, heatmap_infos_file)


HBox(children=(IntProgress(value=0, max=418), HTML(value='')))


