In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from glob import glob
from PIL import Image
import numpy as np
import copy

In [None]:
# Every entry directly under the high-resolution centre-camera training image
# folder; the path is relative to the notebook's working directory.
highres_image_paths = glob(
    "../data/osdar23/train/images/rgb_highres_center/*"
)

In [None]:
def create_grid_mask(img, prob, rotate=1):
    """Randomly blank out a regular grid of cells in an image array.

    With probability ``1 - prob`` (approximately; the test is
    ``np.random.rand() > prob``) the input is returned untouched. Otherwise a
    grid of horizontal and vertical bands is drawn on a canvas 1.5x the image
    size, optionally rotated, centre-cropped back to the image size and
    inverted, so that pixels lying on a band are kept and every other pixel is
    multiplied by zero.

    Args:
        img: NumPy array whose first two axes are (height, width). Any further
            axes (e.g. channels) are broadcast against the mask.
        prob: Probability threshold for applying the mask.
        rotate: Exclusive upper bound for the random rotation angle passed to
            ``PIL.Image.rotate``. The default of 1 always yields angle 0, i.e.
            no rotation, which matches the previous hard-coded behaviour.

    Returns:
        ``img`` itself when the mask is skipped or the image is too small to
        hold a grid (``min(height, width) <= 2``); otherwise ``img * mask``
        where ``mask`` is a float32 array of zeros and ones.
    """
    if np.random.rand() > prob:
        return img

    h = img.shape[0]
    w = img.shape[1]
    d1 = 2
    d2 = min(h, w)
    if d2 <= d1:
        # np.random.randint(d1, d2) requires d1 < d2; an image this small
        # cannot hold a grid cell, so leave it unmodified instead of raising.
        return img

    # Oversized canvas so that a rotated grid still covers the centre crop.
    hh = int(1.5 * h)
    ww = int(1.5 * w)

    # The order of the random draws is kept stable so seeded runs reproduce.
    d = np.random.randint(d1, d2)   # grid period in pixels
    band = np.random.randint(1, d)  # band thickness, strictly less than d
    mask = np.ones((hh, ww), np.float32)
    st_h = np.random.randint(d)
    st_w = np.random.randint(d)

    for i in range(hh // d):
        s = d * i + st_h
        mask[s:min(s + band, hh), :] = 0
    for i in range(ww // d):
        s = d * i + st_w
        mask[:, s:min(s + band, ww)] = 0

    r = np.random.randint(rotate)
    if r:
        # Only round-trip through PIL when there is an actual rotation; the
        # mask holds just 0/1 so the uint8 conversion is lossless.
        rotated = Image.fromarray(np.uint8(mask)).rotate(r)
        mask = np.asarray(rotated).astype(np.float32)

    top = (hh - h) // 2
    left = (ww - w) // 2
    mask = mask[top:top + h, left:left + w]

    # Invert: bands become 1 (kept), the cells between them become 0 (dropped).
    mask = 1 - mask

    # Append singleton axes so the 2-D mask broadcasts over channel axes and
    # also works for plain 2-D (grayscale) inputs.
    mask = mask.reshape(mask.shape + (1,) * (img.ndim - 2))

    return img * mask

def get_concat_v(im1, im2):
    """Stack two PIL images vertically, ``im1`` above ``im2``.

    The result is a new RGB image whose width is taken from ``im1`` alone and
    whose height is the sum of both input heights.
    """
    top_height = im1.height
    canvas = Image.new('RGB', (im1.width, top_height + im2.height))
    for tile, y_offset in ((im1, 0), (im2, top_height)):
        canvas.paste(tile, (0, y_offset))
    return canvas

def get_concat_h(im1, im2):
    """Place two PIL images side by side, ``im1`` to the left of ``im2``.

    The result is a new RGB image whose height is taken from ``im1`` alone and
    whose width is the sum of both input widths.
    """
    left_width = im1.width
    canvas = Image.new('RGB', (left_width + im2.width, im1.height))
    for tile, x_offset in ((im1, 0), (im2, left_width)):
        canvas.paste(tile, (x_offset, 0))
    return canvas

In [None]:
# Augmentation settings. transform() reads all of these as module-level
# globals, so rebinding one of them changes its behaviour on the next call.
# For the two-row settings, row 0 holds the bounds sampled with
# np.random.uniform in training mode and row 1 is averaged with np.mean in
# test mode.
resize_lim = [
    [0.38, 0.55],
    [0.48, 0.48],
]
bot_pct_lim = [
    [0.6, 0.7],
    [0.65, 0.65],
]
# Bounds for the random angle handed to PIL's Image.rotate (training only).
rot_lim = [-5.4, 5.4]
# Probability passed to create_grid_mask (training only).
gridmask_prob = 0.1
# NOTE(review): not read by any code in the visible cells.
gridmask_fixed_prob = True
# Output size, unpacked by transform() as (height, width).
final_dim = [256, 740]
# Enables the 50/50 horizontal flip in training mode.
rand_flip = True

def transform(img, test=False):
    """Resize, crop and (in training mode) augment a PIL image.

    All limits come from the module-level settings ``resize_lim``,
    ``bot_pct_lim``, ``rot_lim``, ``gridmask_prob``, ``final_dim`` and
    ``rand_flip``.

    Args:
        img: PIL image to process.
        test: When true, use the fixed test-time scale and a horizontally
            centred crop with no further augmentation. When false, sample the
            scale, crop position and rotation at random, optionally flip, and
            apply ``create_grid_mask``.

    Returns:
        A PIL image cropped to ``final_dim`` (height, width).
    """
    src_w, src_h = img.size
    out_h, out_w = final_dim

    if test:
        # Deterministic path: mean of the test-time limits, centred crop.
        scale = np.mean(resize_lim[1])
        scaled_w, scaled_h = int(src_w * scale), int(src_h * scale)
        top = int((1 - np.mean(bot_pct_lim[1])) * scaled_h) - out_h
        left = int(max(0, scaled_w - out_w) / 2)
        box = (left, top, left + out_w, top + out_h)
        return img.resize((scaled_w, scaled_h)).crop(box)

    # Random path. The draws happen in a fixed order (scale, vertical crop,
    # horizontal crop, angle, flip) so a seeded run is reproducible.
    scale = np.random.uniform(*resize_lim[0])
    scaled_w, scaled_h = int(src_w * scale), int(src_h * scale)
    top = int((1 - np.random.uniform(*bot_pct_lim[0])) * scaled_h) - out_h
    left = int(np.random.uniform(0, max(0, scaled_w - out_w)))
    box = (left, top, left + out_w, top + out_h)
    angle = np.random.uniform(*rot_lim)

    augmented = img.resize((scaled_w, scaled_h)).crop(box)
    if rand_flip and np.random.choice([0, 1]):
        augmented = augmented.transpose(method=Image.FLIP_LEFT_RIGHT)
    augmented = augmented.rotate(angle)

    # The grid mask works on arrays; convert back to a uint8 PIL image after.
    pixels = create_grid_mask(np.uint8(augmented), gridmask_prob)
    return Image.fromarray(pixels.astype(np.uint8))

random_indices = np.random.randint(len(highres_image_paths), size=5)

# The two bot_pct_lim settings being compared; each one produces one column
# (training sample on top, test sample below) in the preview image.
bot_pct_lim_candidates = (
    [[0.55, 0.65], [0.6, 0.6]],
    [[0.6, 0.7], [0.65, 0.65]],
)

for i in random_indices:
    source = Image.open(highres_image_paths[i]).convert("RGB")

    columns = []
    for candidate in bot_pct_lim_candidates:
        # transform() reads bot_pct_lim as a global, so rebind it here.
        bot_pct_lim = candidate
        # Re-seed per setting so both columns use the same random draws.
        np.random.seed(i)
        img_train = transform(copy.deepcopy(source), False)
        img_test = transform(copy.deepcopy(source), True)
        columns.append(get_concat_v(img_train, img_test))

    img_1, img_2 = columns
    img = get_concat_h(img_1, img_2)
    img.show()
