In [None]:
import sys
import numpy as np
import pandas as pd
import cv2
from matplotlib import pyplot as plt
import PIL
from IPython.display import Image
from IPython.display import display as idisplay
import os
import random
import sys
import time
import albumentations as A
import importlib
from yagm.transforms import albumentations_custom as AC
import polars as pl
importlib.reload(AC)


def longest_resize(img, max_h = None, max_w = None, upscale = False, interpolation = cv2.INTER_LINEAR):
    """Resize ``img`` to fit inside ``max_h`` x ``max_w`` while keeping its aspect ratio.

    Args:
        img: image array whose first two dims are (height, width).
        max_h: maximum output height, or None for no height limit.
        max_w: maximum output width, or None for no width limit.
        upscale: if False (default), images already inside the limits are never enlarged.
        interpolation: OpenCV interpolation flag forwarded to ``cv2.resize``.

    Returns:
        The resized image, or ``img`` itself when no resize is needed.
    """
    if max_h is None and max_w is None:
        return img
    img_h, img_w = img.shape[:2]
    # A missing limit places no constraint on that axis.
    ratio_h = max_h / img_h if max_h is not None else float('inf')
    ratio_w = max_w / img_w if max_w is not None else float('inf')
    r = min(ratio_h, ratio_w)
    if not upscale:
        r = min(1.0, r)
    if r == 1.0:
        return img
    # Keep at least one pixel per side: cv2.resize rejects a zero-sized dsize.
    new_h, new_w = max(1, int(r * img_h)), max(1, int(r * img_w))
    # Pass `interpolation` by keyword: the 3rd positional parameter of cv2.resize is `dst`.
    return cv2.resize(img, (new_w, new_h), interpolation=interpolation)


def pad(img, new_h, new_w, pad_mode = 'top_left', pad_value = 'noise'):
    """Place ``img`` on a larger ``new_h`` x ``new_w`` canvas.

    Args:
        img: 2D (H, W) or 3D (H, W, C) image.
        new_h, new_w: canvas size; must be >= the image size.
        pad_mode: 'top_left' puts the image at the canvas origin; 'center' centers it.
        pad_value: 'noise' fills the canvas with uniform random values (``rand * 255``
            truncated for uint8 images, ``rand`` cast to ``img.dtype`` otherwise); any
            non-string value (a scalar or a per-channel sequence) is used as a constant fill.

    Returns:
        The padded image, with the same dtype and trailing dims as ``img``.

    Raises:
        ValueError: on an unknown ``pad_mode`` or an unknown string ``pad_value``.
    """
    ori_h, ori_w = img.shape[:2]
    assert new_h >= ori_h and new_w >= ori_w
    if pad_mode == 'center':
        pad_top = (new_h - ori_h) // 2
        pad_left = (new_w - ori_w) // 2
    elif pad_mode == 'top_left':
        pad_top, pad_left = 0, 0
    else:
        raise ValueError(f'Unknown pad_mode: {pad_mode!r}')

    # img.shape[2:] keeps the channel axis for color images and is empty for grayscale.
    canvas_shape = (new_h, new_w) + tuple(img.shape[2:])
    if isinstance(pad_value, str):
        if pad_value != 'noise':
            raise ValueError(f'Unknown pad_value: {pad_value!r}')
        padded = np.random.rand(*canvas_shape)
        if img.dtype == np.uint8:
            padded = (padded * 255).astype(np.uint8)
        else:
            padded = padded.astype(img.dtype)
    else:
        # Constant fill; a per-channel sequence broadcasts over the last axis.
        padded = np.full(canvas_shape, pad_value, dtype=img.dtype)
    padded[pad_top:pad_top + ori_h, pad_left:pad_left + ori_w] = img
    return padded


def concat_imgs(imgs, max_h = None, max_w = None, axis = 1, border_width = 5, border_color = (255, 0, 255)):
    """Scale images by one shared ratio, pad them to a common size and concatenate them.

    All images are scaled by the same ratio so that the largest height/width fits inside
    ``max_h`` x ``max_w`` (never upscaled). Each is padded on the bottom/right with random
    noise to the same size and joined along ``axis`` with a solid border between neighbors.

    Args:
        imgs: non-empty iterable of (H, W, C) images with the same channel count C.
        max_h, max_w: optional size limits; None means "use the largest input size".
        axis: concatenation axis (0 stacks vertically, 1 horizontally).
        border_width: border thickness in pixels along ``axis``.
        border_color: per-channel border color in 0..255 units (divided by 255 for
            non-uint8 images).

    Returns:
        The concatenated image.

    Raises:
        ValueError: if ``imgs`` is empty.
    """
    imgs = list(imgs)
    if not imgs:
        raise ValueError('concat_imgs needs at least one image')
    src_max_h = max(img.shape[0] for img in imgs)
    src_max_w = max(img.shape[1] for img in imgs)
    max_h = min(max_h, src_max_h) if max_h is not None else src_max_h
    max_w = min(max_w, src_max_w) if max_w is not None else src_max_w
    # One shared ratio keeps relative sizes; r <= 1 since the limits are capped at the source size.
    r = min(max_h / src_max_h, max_w / src_max_w)
    # Clamp to >= 1 px: cv2.resize rejects zero-sized outputs.
    dst_max_h, dst_max_w = max(1, int(r * src_max_h)), max(1, int(r * src_max_w))

    new_imgs = []
    for i, img in enumerate(imgs):
        h, w = img.shape[:2]
        new_img = cv2.resize(img, (max(1, int(w * r)), max(1, int(h * r))))
        if img.ndim == 3 and new_img.ndim == 2:
            # cv2.resize drops a trailing channel axis of size 1; restore it.
            new_img = new_img[..., None]
        new_img = pad(new_img, dst_max_h, dst_max_w, pad_mode = 'top_left', pad_value = 'noise')
        new_imgs.append(new_img)
        if i != len(imgs) - 1:
            border_shape = list(new_img.shape)
            border_shape[axis] = border_width
            border_value = np.array(border_color, dtype = np.uint8)[None, None]
            if new_img.dtype != np.uint8:
                # Non-uint8 images are treated as [0, 1] images.
                border_value = border_value / 255
            border = np.zeros(border_shape, dtype = new_img.dtype)
            border[..., :] = border_value
            new_imgs.append(border)
    return np.concatenate(new_imgs, axis = axis)


def float_to_uint8(imgs):
    """Convert one image or several images from [0, 1] floats to uint8.

    Args:
        imgs: a single ndarray, or an iterable of ndarrays, with values nominally in [0, 1].

    Returns:
        A list of uint8 arrays; a single input array yields a one-element list. Values are
        clipped to [0, 255] before the cast, so out-of-range inputs saturate instead of
        wrapping around. The fractional part is truncated.
    """
    if isinstance(imgs, np.ndarray):
        imgs = [imgs]
    return [np.clip(255 * img, 0, 255).astype(np.uint8) for img in imgs]


def display(img, max_h = None, max_w = None):
    """Show ``img`` inline in the notebook, shrunk to fit ``max_h`` x ``max_w`` if given.

    Non-uint8 images are treated as [0, 1] images. They are clipped to that range before
    conversion, so out-of-range values saturate instead of wrapping around.
    """
    # `import PIL` alone does not guarantee the `PIL.Image` submodule is loaded.
    import PIL.Image

    img = longest_resize(img, max_h, max_w)
    if img.dtype != np.uint8:
        img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    idisplay(PIL.Image.fromarray(img))
    

def _viz_to_uint8(img):
    """Return ``img`` as uint8 for display.

    uint8 input is returned unchanged. Other dtypes are treated as [0, 1] images. If their
    values fall outside [0, 1], they are min-max normalized first (with a warning); a
    constant image maps to zeros instead of dividing by zero.
    """
    if img.dtype == np.uint8:
        return img
    _min = img.min()
    _max = img.max()
    if _min < 0.0 or _max > 1.0:
        print(f'WARN: min={_min} max={_max}')
        value_range = _max - _min
        img = (img - _min) / value_range if value_range > 0 else np.zeros_like(img)
    return (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)


def viz(imgs, transforms, max_h = 320, max_w = 320, assert_no_change=False):
    """Apply an albumentations pipeline to each image and show before/after panels.

    For every image, prints the transform time and output statistics. It then displays
    ``[input | augmented | augmented + histogram equalization]`` side by side. All panels
    are converted to uint8 first, so uint8 and float images are never mixed in one panel.

    Args:
        imgs: iterable of images, each 2D (H, W) or 3D (H, W, C).
        transforms: an ``A.Compose``, or a list of transforms wrapped in ``A.Compose(p=1.0)``.
        max_h, max_w: size limits of the concatenated panel.
        assert_no_change: if True, assert the pipeline output equals the input exactly.

    Raises:
        AssertionError: for an image that is neither 2D nor 3D, or when
            ``assert_no_change`` is set and the output differs from the input.
    """
    if not isinstance(transforms, A.Compose):
        transform = A.Compose(transforms, p = 1.0)
    else:
        transform = transforms
    # Equalize keeps no per-image state, so build it once rather than once per image.
    eq = A.Equalize(mode='cv', by_channels=True, mask=None, mask_params=(), p=1.0)

    for img in imgs:
        if img.ndim not in (2, 3):
            raise AssertionError(f'Expected a 2D or 3D image, got shape {img.shape}')
        start = time.time()
        ret = transform(image = img)
        end = time.time()
        print('Take:', round((end - start) * 1000), 'ms')
        out = ret['image']
        print(out.dtype, out.sum(), out.min(), out.max())

        if assert_no_change:
            assert img.shape == out.shape
            # Compare in float64 so uint8 subtraction cannot wrap around.
            diff = np.abs(out.astype(np.float64) - img.astype(np.float64)).sum()
            print('Diff:', diff)
            assert diff == 0

        if out.dtype != np.uint8:
            print('Convert from float to uint8..')
        # Equalize expects uint8. Concatenating uint8 with float panels would also promote
        # everything to float and corrupt the uint8 panel in the final conversion.
        img1 = _viz_to_uint8(img)
        img2 = _viz_to_uint8(out)
        img3 = eq(image = img2)['image']
        print(img.shape, img2.shape, img3.shape)

        panels = []
        for panel_img in (img1, img2, img3):
            # Grayscale (H, W) or (H, W, 1) panels are repeated to 3 channels for the colored border.
            if panel_img.ndim == 2:
                panel_img = panel_img[..., None]
            if panel_img.shape[-1] == 1:
                panel_img = panel_img.repeat(3, axis = -1)
            panels.append(panel_img)

        panel = concat_imgs(panels, max_h = max_h, max_w = max_w)
        display(panel)
        print('\n----------------\n')

In [None]:
# Root of the processed BYU data; override with the BYU_PROCESSED_DIR environment variable.
BYU_PROCESSED_DIR = os.environ.get('BYU_PROCESSED_DIR', '/home/dangnh36/datasets/.comp/byu/processed')
df = pl.read_csv(os.path.join(BYU_PROCESSED_DIR, 'gt_v3.csv'))
df

In [None]:
SELECT_TOMO_IDS = [
    'tomo_1cc887', 'tomo_1ab322', 'tomo_ed1c97',
    'tomo_0a8f05', 'tomo_0363f2', 'tomo_10c564',
    'tomo_319f79', 'tomo_4555b6', 'tomo_868255',
    'tomo_b2ebbc', 'tomo_bede89', 'tomo_918e2b',
    'tomo_9c0253', 'tomo_b54396', 'tomo_dfc627',
]
# Images are rescaled by voxel_spacing / TARGET_SPACING, i.e. to a common spacing of
# TARGET_SPACING in the units of the `voxel_spacing` column.
TARGET_SPACING = 16.0
# Override the data root with the BYU_PROCESSED_DIR environment variable.
VIZ_ANNOTATION_DIR = os.path.join(
    os.environ.get('BYU_PROCESSED_DIR', '/home/dangnh36/datasets/.comp/byu/processed'),
    'viz', 'annotations',
)

imgs = []
for tomo_id in SELECT_TOMO_IDS:
    rows = df.filter(pl.col('tomo_id') == tomo_id)
    if rows.height == 0:
        raise KeyError(f'{tomo_id} not found in the ground-truth table')
    ori_spacing = float(rows['voxel_spacing'][0])
    img_path = os.path.join(VIZ_ANNOTATION_DIR, f'{tomo_id}.jpg')
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise FileNotFoundError(f'cv2.imread could not read: {img_path}')
    # Keep channel 1 only (green in OpenCV's BGR order) -> single-channel image.
    img = bgr[..., 1]
    ori_shape = img.shape
    scale = ori_spacing / TARGET_SPACING
    img = cv2.resize(img, None, fx = scale, fy = scale)
    print(tomo_id, ori_shape, '-->', img.shape)
    imgs.append(img)

In [None]:
# # TEMPLATE, COPY THIS

# transforms = [
#     A.
# ]
# viz(imgs, transforms, max_h = 640, max_w = 640)

In [None]:
# print(A.Compose.main_compose)

# a = A.Compose([A.HorizontalFlip(p=1.0)])
# b = A.Compose([A.VerticalFlip(p=1.0)])
# print(id(a.main_compose), id(b.main_compose))
# print(A.Compose.main_compose, a.main_compose, b.main_compose)
# a.main_compose = not a.main_compose
# A.Compose.main_compose = not A.Compose.main_compose
# print(A.Compose.main_compose, a.main_compose, a.__class__.main_compose, b.main_compose)

## MAIN

In [None]:
def _read_func(p):
    """Read the image at path ``p`` with ``cv2.imread`` (default color mode, BGR).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns None).
    """
    img = cv2.imread(p)
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read: {p}')
    # Grayscale alternatives kept for reference:
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # img = img[..., 1]
    return img


# Seed both Python's `random` and NumPy's global RNG. `viz` -> `concat_imgs` -> `pad`
# draws its noise background from np.random, and albumentations may sample from either
# RNG depending on its version.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
transforms = [
    A.Perspective(
        scale=(0.08, 0.08),
        keep_size=False,
        pad_mode=cv2.BORDER_CONSTANT,
        pad_val=128,
        fit_output=True,
        interpolation=cv2.INTER_LINEAR,
        p=1.0,
    ),
]

viz(imgs[:3], transforms, max_h = 640, max_w = 640, assert_no_change=False)

In [None]:
# CLAHE: uint8 input only; tile_grid_size candidates: 4, 8, 16, 32.
A.CLAHE(
    clip_limit=4.0,
    tile_grid_size=(8, 8),
    p=1.0,
)

# Downscale: produces jitter on float32 images.
# TODO: implement an AdaptiveDownscale whose range depends on the current resolution.
A.Downscale(
    scale_range=(0.4, 0.95),
    interpolation_pair={'upscale': cv2.INTER_LANCZOS4, 'downscale': cv2.INTER_AREA},
    p=1.0,
)

# Emboss: looks good, but gives a wrong result (bug?) when applied to a float32 image.
A.Emboss(
    alpha=(0.3, 0.8),
    strength=(0.2, 0.8),
    p=1.0,
)

# Equalize: uint8 input only.
# TODO: mask by the spine curve; consider histogram equalization as a fixed transformation.
A.Equalize(
    mode='cv',
    by_channels=True,
    mask=None,
    mask_params=(),
    p=1.0,
)

A.GaussNoise(
    var_limit=(10.0, 80.0),
    mean=0,
    per_channel=True,
    noise_scale_factor=1.0,
    p=1.0,
)


# [OPTIONAL] use with caution
# Bug: a uint8 input/reference returns a blank image, caused by a recent scikit-image change.
# ref: https://github.com/albumentations-team/albumentations/issues/1869
# TODO: implement histogram matching with a mask.
# NOTE(review): `float_imgs` is never defined in this notebook (NameError). It is rebuilt
# inline as float32 copies of `imgs` scaled to [0, 1] — confirm this matches the intent.
A.HistogramMatching([img.astype(np.float32) / 255.0 for img in imgs],
                    blend_ratio=(0.5, 1.0), read_fn=lambda x: x, p=1.0)


A.ImageCompression(compression_type='jpeg', quality_range=(30, 95), p=1.0)


A.MultiplicativeNoise(multiplier=(0.8, 1.2), per_channel=True, elementwise=True, p=1.0)


# [OPTIONAL]
# NOTE(review): `uint8_imgs` is never defined in this notebook (NameError). `imgs` (the
# uint8 images loaded from .jpg above) is used instead — confirm.
A.PixelDistributionAdaptation(imgs,
                              blend_ratio=(1.0, 1.0),
                              read_fn=lambda x: x, transform_type='pca', p=1.0)

A.Posterize(num_bits=4, p=1.0)

# Single-range alternative:
# A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.3), contrast_limit=(-0.4, 0.6), brightness_by_max=True, p=1.0)
# Brightness/contrast as a weighted OneOf over sign combinations of the two limits.
# Group weights: 0.3 for brightness_by_max=True, 0.7 for False. Within each group the
# (brightness -, contrast -), (-, +), (+, -), (+, +) cases get 0.4 / 0.4 / 0.15 / 0.05.
A.OneOf([
    A.RandomBrightnessContrast(brightness_limit=b_lim, contrast_limit=c_lim,
                               brightness_by_max=by_max, p=weight)
    for b_lim, c_lim, by_max, weight in [
        ((-0.1, 0.0), (-0.2, 0.0), True, 0.3 * 0.4),
        ((-0.15, 0.0), (0.0, 0.6), True, 0.3 * 0.4),
        ((0.0, 0.3), (-0.4, 0.0), True, 0.3 * 0.15),
        ((0.0, 0.2), (0.0, 0.5), True, 0.3 * 0.05),
        ((-0.3, 0.0), (-0.2, 0.0), False, 0.7 * 0.4),
        ((-0.4, 0.0), (0.0, 0.6), False, 0.7 * 0.4),
        ((0.0, 0.4), (-0.4, 0.0), False, 0.7 * 0.15),
        ((0.0, 0.3), (0.0, 0.5), False, 0.7 * 0.05),
    ]
], p=1.0)

A.RandomGamma(gamma_limit=(60, 150), p=1.0)

A.RandomToneCurve(scale=0.4, per_channel=True, p=1.0)

A.Sharpen(alpha=(0.1, 0.4), lightness=(0.0, 0.4), always_apply=None, p=1.0)


# NOTE(review): `height` and `width` were undefined names here (NameError). 512 is a
# placeholder matching the LongestMaxSize/PadIfNeeded target size used in this notebook.
A.RandomSizedBBoxSafeCrop(height=512, width=512, erosion_rate=0.0, interpolation=cv2.INTER_LINEAR, always_apply=None, p=1.0)



A.CoarseDropout(max_holes=None, max_height=None, max_width=None, min_holes=None, min_height=None, min_width=None, fill_value=0, mask_fill_value=None, num_holes_range=(1, 1), hole_height_range=(8, 8), hole_width_range=(8, 8), always_apply=None, p=0.5)

A.GridDropout(ratio=0.8, random_offset=False, fill_value=0, mask_fill_value=None, unit_size_range=(50, 50), holes_number_xy=(5, 5), shift_xy=(0, 0), p=1.0)

# [IGNORE] Just try adding these, to check that they degrade performance as expected.
A.HorizontalFlip(p=1.0)
A.VerticalFlip(p=1.0)

A.LongestMaxSize(max_size=512, interpolation=cv2.INTER_LINEAR, p=1.0)
A.PadIfNeeded(min_height=512, min_width=512, pad_height_divisor=None, pad_width_divisor=None, position='center',
              border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=None, p=1.0)

# PiecewiseAffine: very slow, and may reduce localization accuracy.
# TODO: handle keypoints that move outside the image, e.g. roll back to NoOp when the
# keypoint displacement exceeds keypoints_threshold.
A.OneOf([
    A.PiecewiseAffine(scale=(0.025, 0.025), nb_rows=4, nb_cols=4, interpolation=cv2.INTER_LINEAR, mask_interpolation=0, cval=0, cval_mask=0, mode='constant', absolute_scale=False, keypoints_threshold=0.01, p=1.0),
    A.PiecewiseAffine(scale=(0.016, 0.016), nb_rows=6, nb_cols=6, interpolation=cv2.INTER_LINEAR, mask_interpolation=0, cval=0, cval_mask=0, mode='constant', absolute_scale=False, keypoints_threshold=0.01, p=1.0),
    A.PiecewiseAffine(scale=(0.0125, 0.0125), nb_rows=8, nb_cols=8, interpolation=cv2.INTER_LINEAR, mask_interpolation=0, cval=0, cval_mask=0, mode='constant', absolute_scale=False, keypoints_threshold=0.01, p=1.0),
])

# FIXME(review): `AA` is never imported in this notebook, so this reference raised NameError.
# Kept as a note until the module providing RandomResizedCropNoResize is imported.
# AA.RandomResizedCropNoResize

# Risky: mask lengths are given in pixels, so their effect depends on image resolution.
# TODO: express mask lengths as a fraction of the image width/height.
A.XYMasking(
    num_masks_x=(3, 3),
    num_masks_y=(4, 4),
    mask_x_length=(5, 5),
    mask_y_length=(10, 10),
    fill_value=0,
    mask_fill_value=0,
    p=1.0,
)

##################################
# PARAMETERIZED

# Affine: OneOf two presets, both with p=1.0. Each preset scales, translates, rotates and
# shears. The first leans on rotation (rotate +/-30, shear up to 10/5); the second leans
# on shear (rotate +/-10, shear +/-30).
A.OneOf(
    [
        # rotation-heavy preset
        A.Affine(
            scale={'x': (0.6, 1.2), 'y': (0.6, 1.2)},
            translate_percent={'x': (-0.25, 0.25), 'y': (-0.4, 0.0)},
            rotate=(-30, 30),
            shear={'x': (-10, 10), 'y': (-5, 5)},
            interpolation=cv2.INTER_LINEAR,
            cval=0,
            mode=cv2.BORDER_CONSTANT,
            fit_output=False,
            keep_ratio=False,
            balanced_scale=True,
            p=1.0,
        ),
        # shear-heavy preset
        A.Affine(
            scale={'x': (0.6, 1.2), 'y': (0.6, 1.2)},
            translate_percent={'x': (-0.25, 0.25), 'y': (-0.4, 0.0)},
            rotate=(-10, 10),
            shear={'x': (-30, 30), 'y': (-30, 30)},
            interpolation=cv2.INTER_LINEAR,
            cval=0,
            mode=cv2.BORDER_CONSTANT,
            fit_output=False,
            keep_ratio=False,
            balanced_scale=True,
            p=1.0,
        ),
    ],
    p=1.0,
)

# ShiftScaleRotate: same family as Affine, but without separate x/y scale or shear.
A.ShiftScaleRotate(
    scale_limit=(-0.4, 0.2),
    rotate_limit=(-30, 30),
    interpolation=cv2.INTER_LINEAR,
    border_mode=cv2.BORDER_CONSTANT,
    value=0,
    shift_limit_x=(-0.25, 0.25),
    shift_limit_y=(-0.4, 0.0),
    p=1.0,
)

A.Perspective(
    scale=(0.05, 0.1),
    keep_size=False,
    pad_mode=cv2.BORDER_CONSTANT,
    pad_val=0,
    fit_output=True,
    interpolation=cv2.INTER_LINEAR,
    p=1.0,
)



In [None]:
# Scratch arithmetic, displayed as a tuple. NOTE(review): purpose is not documented here;
# explain it or remove the cell.
341 - 183, 577 - 260

## NOTES

In [None]:
# [OPTIONAL] use with caution
# Bug: a uint8 input/reference returns a blank image, caused by a recent scikit-image change.
# ref: https://github.com/albumentations-team/albumentations/issues/1869
# TODO: implement histogram matching with a mask.
# NOTE(review): `float_imgs` is never defined in this notebook (NameError). It is rebuilt
# inline as float32 copies of `imgs` scaled to [0, 1] — confirm this matches the intent.
A.HistogramMatching([img.astype(np.float32) / 255.0 for img in imgs],
                    blend_ratio=(0.5, 1.0), read_fn=lambda x: x, p=1.0)

# [OPTIONAL]
# NOTE(review): `uint8_imgs` is never defined in this notebook (NameError). `imgs` (the
# uint8 images loaded from .jpg above) is used instead — confirm.
A.PixelDistributionAdaptation(imgs,
                              blend_ratio=(1.0, 1.0),
                              read_fn=lambda x: x, transform_type='pca', p=1.0)





In [None]:
from functools import partial
from shapely.geometry import Polygon

def get_vb_union_bbox(params, data, filter_vb_ids = (0,), x_include_all = True, y_lowest_offset = True, verbose = True):
    """Compute the (x_min, y_min, x_max, y_max) box a bbox-safe crop should keep.

    Keypoints with weight -1 or class -1 are ignored. Each remaining keypoint gets the
    group id ``kpt_cls // 4`` (presumably a vertebra id with 4 keypoints each — confirm).
    Only groups listed in ``filter_vb_ids`` define the vertical extent of the box.

    Args:
        params: transform params; ``params['shape'][:2]`` is read as (img_h, img_w).
        data: dict with parallel sequences ``keypoints`` (each indexable as x, y),
            ``keypoint_classes`` and ``keypoint_weights``.
        filter_vb_ids: group ids whose keypoints bound the box.
        x_include_all: if True, the x-range spans all valid keypoints; otherwise it spans
            only the selected groups.
        y_lowest_offset: if True, extend y_max by half the square root of the area of the
            polygon formed by the keypoints of the smallest selected group id.
        verbose: print the resulting box.

    Returns:
        Tuple (x_min, y_min, x_max, y_max). If no keypoint belongs to a selected group,
        the box of all valid keypoints is returned. If there are no valid keypoints at
        all, this is (img_w, img_h, 0, 0), an inverted box — NOTE(review): confirm the
        crop handles that case.
    """
    assert len(data['keypoints']) == len(data['keypoint_classes']) == len(data['keypoint_weights'])
    img_h, img_w = params['shape'][:2]

    xs, ys, vb_ids = [], [], []
    min_x, max_x = img_w, 0
    min_y, max_y = img_h, 0
    for kpt, kpt_cls, kpt_weight in zip(data['keypoints'], data['keypoint_classes'], data['keypoint_weights']):
        if kpt_weight == -1 or kpt_cls == -1:
            continue
        x, y = kpt[0], kpt[1]
        min_x, max_x = min(min_x, x), max(max_x, x)
        min_y, max_y = min(min_y, y), max(max_y, y)
        vb_id = kpt_cls // 4
        if vb_id in filter_vb_ids:
            xs.append(x)
            ys.append(y)
            vb_ids.append(vb_id)

    if not vb_ids:
        ret = min_x, min_y, max_x, max_y
    else:
        delta_y = 0
        if y_lowest_offset:
            lowest_vb_id = min(vb_ids)
            lowest_vb_poly = [(px, py) for px, py, v in zip(xs, ys, vb_ids) if v == lowest_vb_id]
            # A polygon needs at least 3 vertices; with fewer there is no area to offset by.
            if len(lowest_vb_poly) >= 3:
                try:
                    delta_y = (Polygon(lowest_vb_poly).area ** 0.5) / 2
                except Exception:  # best-effort: degenerate geometry -> no offset
                    delta_y = 0
        assert delta_y >= 0
        if x_include_all:
            ret = min_x, min(ys), max_x, max(ys) + delta_y
        else:
            ret = min(xs), min(ys), max(xs), max(ys) + delta_y
    if verbose:
        print('keep bbox:', ret)
    return ret


# Custom bbox-safe crop whose kept box comes from get_vb_union_bbox (group id 0 only).
AC.CustomRandomSizedBBoxSafeCrop(
    scale=(0.25, 1.0),
    ratio=(0.25, 1.5),
    get_bbox_func=partial(get_vb_union_bbox, filter_vb_ids=[0], x_include_all=True, y_lowest_offset=True),
    retry=30,
    p=1.0,
)


AC.CustomCoarseDropout(
    fill_value=0,
    num_holes_range=(4, 8),
    hole_height_range=(0.1, 0.2),
    hole_width_range=(0.1, 0.2),
    p=1.0,
)
AC.CustomGridDropout(
    ratio=0.5,
    random_offset=True,
    fill_value=0,
    holes_number_xy=((3, 8), (3, 8)),
    p=1.0,
)

AC.CustomXYMasking(
    num_masks_x=(2, 5),
    num_masks_y=(2, 5),
    mask_x_length=(0.03, 0.05),
    mask_y_length=(0.03, 0.05),
    fill_value=0,
    p=1.0,
)

In [None]:
augment = A.Compose([
    # crop
    # FIXME(review): `AA` is never imported in this notebook, so
    # `AA.RandomResizedBBoxSafeCropNoResize(p=0.4)` raised NameError and the cell could not run.
    # It is disabled until the module providing it is imported. AC.CustomRandomSizedBBoxSafeCrop
    # (prototyped in this notebook) may be the intended replacement — confirm.
    # AA.RandomResizedBBoxSafeCropNoResize(p=0.4),

    # flip (both disabled with p=0.0)
    A.HorizontalFlip(p=0.0),
    A.VerticalFlip(p=0.0),

    # noise
    A.OneOf([
        A.GaussNoise(var_limit=(10.0, 80.0), mean=0, per_channel=True, noise_scale_factor=1.0, p=1.0),
        A.MultiplicativeNoise(multiplier=(0.8, 1.2), per_channel=True, elementwise=True, p=1.0),
    ], p = 1.0),

    # reduce quality
    A.OneOf([
        # jitter on float32; TODO: AdaptiveDownscale based on the current resolution
        A.Downscale(scale_range=(0.4, 0.95), interpolation_pair={'upscale': cv2.INTER_LANCZOS4, 'downscale': cv2.INTER_AREA}, p=1.0),
        A.ImageCompression(compression_type='jpeg', quality_range=(30, 95), p=1.0),
        A.Posterize(num_bits=4, p=1.0),
    ], p=1.0),

    # texture, contrast
    A.OneOf([
        # wrong result on float32 images
        A.Emboss(alpha=(0.3, 0.8), strength=(0.2, 0.8), p=1.0),
        A.Sharpen(alpha=(0.1, 0.4), lightness=(0.0, 0.4), p=1.0),
        A.OneOf([
            A.CLAHE(clip_limit=4.0, tile_grid_size=(4, 4), p = 1.0),
            A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p = 1.0),
            A.CLAHE(clip_limit=4.0, tile_grid_size=(16, 16), p = 1.0),
        ], p=1.0)
    ], p=1.0),

    # color
    # NOTE(review): unlike the other groups, this OneOf sets no `p`, so OneOf's default
    # probability applies — confirm that is intended.
    A.OneOf([
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.0), contrast_limit=(-0.2, 0.0), brightness_by_max=True, p=0.3 * 0.4),
            A.RandomBrightnessContrast(brightness_limit=(-0.15, 0.0), contrast_limit=(0.0, 0.6), brightness_by_max=True, p=0.3 * 0.4),
            A.RandomBrightnessContrast(brightness_limit=(0.0, 0.3), contrast_limit=(-0.4, 0.0), brightness_by_max=True, p=0.3 * 0.15),
            A.RandomBrightnessContrast(brightness_limit=(0.0, 0.2), contrast_limit=(0.0, 0.5), brightness_by_max=True, p=0.3 * 0.05),
            A.RandomBrightnessContrast(brightness_limit=(-0.3, 0.0), contrast_limit=(-0.2, 0.0), brightness_by_max=False, p=0.7 * 0.4),
            A.RandomBrightnessContrast(brightness_limit=(-0.4, 0.0), contrast_limit=(0.0, 0.6), brightness_by_max=False, p=0.7 * 0.4),
            A.RandomBrightnessContrast(brightness_limit=(0.0, 0.4), contrast_limit=(-0.4, 0.0), brightness_by_max=False, p=0.7 * 0.15),
            A.RandomBrightnessContrast(brightness_limit=(0.0, 0.3), contrast_limit=(0.0, 0.5), brightness_by_max=False, p=0.7 * 0.05)
        ], p=1.0),
        A.RandomToneCurve(scale=0.4, per_channel=True, p=1.0),
        A.RandomGamma(gamma_limit=(60, 150), p=1.0)
    ]),

    # geometric
    A.OneOf([
        A.OneOf([
            # rotation-heavy preset (also scales, translates and shears)
            A.Affine(scale={'x': (0.6, 1.2), 'y': (0.6, 1.2)},
                     translate_percent={'x': (-0.25, 0.25), 'y': (-0.4, 0.0)},
                     rotate=(-30, 30), shear={'x': (-10, 10), 'y': (-5, 5)},
                     interpolation=cv2.INTER_LINEAR, cval=0, mode=cv2.BORDER_CONSTANT, fit_output=False, keep_ratio=False, balanced_scale=True, p=1.0),
            # shear-heavy preset (also scales, translates and rotates)
            A.Affine(scale={'x': (0.6, 1.2), 'y': (0.6, 1.2)},
                     translate_percent={'x': (-0.25, 0.25), 'y': (-0.4, 0.0)},
                     rotate=(-10, 10), shear={'x': (-30, 30), 'y': (-30, 30)},
                     interpolation=cv2.INTER_LINEAR, cval=0, mode=cv2.BORDER_CONSTANT, fit_output=False, keep_ratio=False, balanced_scale=True, p=1.0)
        ], p=1.0),
        # ShiftScaleRotate: like Affine but without separate x/y scale or shear.
        # Its weight is p=0.0, so it is never picked inside this OneOf.
        A.ShiftScaleRotate(scale_limit=(-0.4, 0.2), rotate_limit=(-30, 30), interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_CONSTANT, value=0, shift_limit_x=(-0.25, 0.25), shift_limit_y=(-0.4, 0.0), p=0.0),
        A.Perspective(scale=(0.05, 0.1), keep_size=False, pad_mode=cv2.BORDER_CONSTANT, pad_val=0, fit_output=True, interpolation=cv2.INTER_LINEAR, p=1.0),
        # very slow + inaccurate keypoints
        # TODO: handle out-of-image keypoints, e.g. roll back to NoOp when the keypoint
        # displacement exceeds keypoints_threshold.
        A.OneOf([
            A.PiecewiseAffine(scale=(0.025, 0.025), nb_rows=4, nb_cols=4, interpolation=cv2.INTER_LINEAR, mask_interpolation=0, cval=0, cval_mask=0, mode='constant', absolute_scale=False, keypoints_threshold=0.01, p=1.0),
            A.PiecewiseAffine(scale=(0.016, 0.016), nb_rows=6, nb_cols=6, interpolation=cv2.INTER_LINEAR, mask_interpolation=0, cval=0, cval_mask=0, mode='constant', absolute_scale=False, keypoints_threshold=0.01, p=1.0),
            A.PiecewiseAffine(scale=(0.0125, 0.0125), nb_rows=8, nb_cols=8, interpolation=cv2.INTER_LINEAR, mask_interpolation=0, cval=0, cval_mask=0, mode='constant', absolute_scale=False, keypoints_threshold=0.01, p=1.0)
        ], p=1.0)
    ], p=1.0),

    # # dropout
    # A.OneOf([
    #
    # ], p=1.0)
], p = 1.0)


# Fixed preprocessing (every step has p=1.0): histogram-equalize, fit the longest side to
# 512, then pad to 512x512 with the image anchored at the top-left.
# Open question: whether histogram equalization should stay as a fixed transformation.
transform = A.Compose([
    A.Equalize(
        mode='cv',
        by_channels=True,
        mask=None,
        mask_params=(),
        p=1.0,
    ),
    A.LongestMaxSize(
        max_size=512,
        interpolation=cv2.INTER_LINEAR,
        p=1.0,
    ),
    A.PadIfNeeded(
        min_height=512,
        min_width=512,
        pad_height_divisor=None,
        pad_width_divisor=None,
        position='top_left',
        border_mode=cv2.BORDER_CONSTANT,
        value=0,
        mask_value=None,
        p=1.0,
    ),
])

In [None]:
def _read_func(p):
    """Read the image at path ``p`` with ``cv2.imread`` (default color mode, BGR).

    NOTE(review): this function is defined twice in this notebook with the same body;
    the later definition wins.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns None).
    """
    img = cv2.imread(p)
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read: {p}')
    # Grayscale alternatives kept for reference:
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # img = img[..., 1]
    return img

from functools import partial
from shapely.geometry import Polygon

def get_vb_union_bbox(params, data, filter_vb_ids = (0,), x_include_all = True, y_lowest_offset = True, verbose = True):
    """Compute the (x_min, y_min, x_max, y_max) box a bbox-safe crop should keep.

    NOTE(review): this function is defined twice in this notebook; the later definition wins.

    Keypoints with weight -1 or class -1 are ignored. Each remaining keypoint gets the
    group id ``kpt_cls // 4`` (presumably a vertebra id with 4 keypoints each — confirm).
    Only groups listed in ``filter_vb_ids`` define the vertical extent of the box.

    Args:
        params: transform params; ``params['shape'][:2]`` is read as (img_h, img_w).
        data: dict with parallel sequences ``keypoints`` (each indexable as x, y),
            ``keypoint_classes`` and ``keypoint_weights``.
        filter_vb_ids: group ids whose keypoints bound the box.
        x_include_all: if True, the x-range spans all valid keypoints; otherwise it spans
            only the selected groups.
        y_lowest_offset: if True, extend y_max by half the square root of the area of the
            polygon formed by the keypoints of the smallest selected group id.
        verbose: print the resulting box.

    Returns:
        Tuple (x_min, y_min, x_max, y_max). If no keypoint belongs to a selected group,
        the box of all valid keypoints is returned. If there are no valid keypoints at
        all, this is (img_w, img_h, 0, 0), an inverted box — NOTE(review): confirm the
        crop handles that case.
    """
    assert len(data['keypoints']) == len(data['keypoint_classes']) == len(data['keypoint_weights'])
    img_h, img_w = params['shape'][:2]

    xs, ys, vb_ids = [], [], []
    min_x, max_x = img_w, 0
    min_y, max_y = img_h, 0
    for kpt, kpt_cls, kpt_weight in zip(data['keypoints'], data['keypoint_classes'], data['keypoint_weights']):
        if kpt_weight == -1 or kpt_cls == -1:
            continue
        x, y = kpt[0], kpt[1]
        min_x, max_x = min(min_x, x), max(max_x, x)
        min_y, max_y = min(min_y, y), max(max_y, y)
        vb_id = kpt_cls // 4
        if vb_id in filter_vb_ids:
            xs.append(x)
            ys.append(y)
            vb_ids.append(vb_id)

    if not vb_ids:
        ret = min_x, min_y, max_x, max_y
    else:
        delta_y = 0
        if y_lowest_offset:
            lowest_vb_id = min(vb_ids)
            lowest_vb_poly = [(px, py) for px, py, v in zip(xs, ys, vb_ids) if v == lowest_vb_id]
            # A polygon needs at least 3 vertices; with fewer there is no area to offset by.
            if len(lowest_vb_poly) >= 3:
                try:
                    delta_y = (Polygon(lowest_vb_poly).area ** 0.5) / 2
                except Exception:  # best-effort: degenerate geometry -> no offset
                    delta_y = 0
        assert delta_y >= 0
        if x_include_all:
            ret = min_x, min(ys), max_x, max(ys) + delta_y
        else:
            ret = min(xs), min(ys), max(xs), max(ys) + delta_y
    if verbose:
        print('keep bbox:', ret)
    return ret

# Prototype run of the custom bbox-safe crop; seeds Python's `random` only.
random.seed(611)
transforms = [
    AC.CustomRandomSizedBBoxSafeCrop(scale=(0.25, 1.0), ratio=(0.25, 1.5),
                                     get_bbox_func = partial(get_vb_union_bbox, filter_vb_ids = [0], x_include_all = True, y_lowest_offset = True),
                                     retry=30, p=1.0)
]
# FIXME(review): this call does not match `viz(imgs, transforms, max_h, max_w, assert_no_change)`
# defined at the top of this notebook. `all_keypoints` lands in the `transforms` slot and
# `transforms` lands in `max_h`, which is then also passed by keyword, so this would raise
# TypeError (multiple values for 'max_h'). `uint8_imgs` and `all_keypoints` are never
# defined in this notebook either. `viz` only passes `image=` to the pipeline, while
# `get_vb_union_bbox` reads `data['keypoints']`, `data['keypoint_classes']` and
# `data['keypoint_weights']`. This presumably targets an older, keypoint-aware `viz` —
# port that version or pass keypoints through.
viz(
    # float_imgs,
    uint8_imgs,
    # rgb_uint8_imgs,
    all_keypoints,
    transforms, max_h = 640, max_w = 640, assert_no_change=False)

In [None]:
# Scratch arithmetic. NOTE(review): presumably a combined-probability check; the purpose
# is not documented here.
0.4 * 0.6 + 0.1

In [None]:
# Scratch arithmetic. NOTE(review): presumably a combined-probability check; the purpose
# is not documented here.
0.6 * 0.6