In [2]:
%matplotlib inline
from math import ceil
from pathlib import Path
from typing import Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
plt.rcParams["figure.figsize"] = [10,10]
from fastai.vision import *
from itertools import product

In [3]:
# Root folder for the generated datasets; one sub-folder per background id is created below it.
out_dir = Path("dataset_segmentation/custom_bg")
# Source images whose background gets replaced.
imgs_dir = Path("dataset_segmentation/images")
# Segmentation masks, matched to the source images by file name (see get_mask).
masks_dir = Path("dataset_segmentation/labels")
# Folder holding the background image files.
bg_dir = Path("backgrounds")

In [4]:
def repeated_bg(bg: np.ndarray, size: Tuple[int, int]) -> np.ndarray:
    """Tile a background image so that it covers an area of the requested size.

    Args:
        bg: background as an array whose first two axes are the spatial axes;
            any further axes (e.g. colour channels) are kept as they are.
        size: ``(x, y)`` target extent along the first and second axis of ``bg``.

    Returns:
        An array of shape ``(x, y, *bg.shape[2:])`` made of whole copies of
        ``bg`` laid side by side and cropped to the requested extent.
    """
    x, y = size
    bg_x, bg_y = bg.shape[:2]
    nx = ceil(x / bg_x)
    ny = ceil(y / bg_y)
    # Repeat only the two spatial axes; a factor of 1 on every remaining axis
    # leaves the colour components untouched and lets 2-D (grayscale) input work.
    reps = (nx, ny) + (1,) * (bg.ndim - 2)
    return np.tile(bg, reps)[:x, :y]

In [5]:
def repeat_bg_trfs(bg: Image, size: TensorImageSize, tfms: TfmList)-> Image:
    """Build a background of the requested size from independently transformed copies of `bg`.

    `size` is unpacked as (_, y, x) and `bg.shape` as (z, y, x), i.e. the same
    axis order on both -- presumably fastai's (channels, height, width); TODO confirm.
    Every tile comes from its own `bg.apply_tfms(tfms)` call and is written into
    a slot that has the spatial size of `bg`. The assembled mosaic is cropped to
    (z, y, x) before being wrapped in an `Image`.
    """
    _, target_h, target_w = size
    channels, tile_h, tile_w = bg.shape
    cols = ceil(target_w / tile_w)
    rows = ceil(target_h / tile_h)
    canvas = torch.empty((channels, rows * tile_h, cols * tile_w), dtype=bg.data.dtype)
    # Walk the grid column by column (rows vary fastest), one apply_tfms call per tile.
    for col in range(cols):
        left = col * tile_w
        for row in range(rows):
            top = row * tile_h
            canvas[:, top:top + tile_h, left:left + tile_w] = bg.apply_tfms(tfms).data
    return Image(canvas[:, :target_h, :target_w])

In [6]:
def change_bg(img: Image, mask: Image, new_bg: Image)-> Image:
    """Replace the background of `img` with `new_bg`.

    Pixels where `mask` is non-zero are taken from `img`; all other pixels are
    taken from `new_bg`. The three arguments are fastai image objects whose
    `.data` tensors must be broadcastable against each other (in practice the
    same height and width).
    """
    # Comparing against zero yields the condition dtype torch.where expects on
    # every PyTorch release (uint8 before 1.2, bool afterwards), keeps the mask
    # on its current device, and does not wrap label values >= 256 to 0 the way
    # a cast to ByteTensor would.
    foreground = mask.data != 0
    return Image(torch.where(foreground, img.data, new_bg.data))

In [97]:
#create a new dataset with the custom background
def get_mask(path):
    """Return the path inside `masks_dir` that has the same file name as `path`."""
    return masks_dir.joinpath(path.name)
def gen_custom_bg_dataset(bg_path, bg_id, trfms):
    """Write a copy of every file in `imgs_dir` composited over a new background.

    Args:
        bg_path: path of the background image file to open.
        bg_id: name of the sub-folder of `out_dir` that receives the results.
        trfms: transforms handed to `repeat_bg_trfs` for each background tile.

    Each output keeps the file name of its source image. The mask of an image
    is looked up with `get_mask`, and one line is printed per saved file.
    """
    # Open the background once and make sure the output folder exists.
    bg = open_image(bg_path)
    out_path = out_dir / bg_id
    out_path.mkdir(parents=True, exist_ok=True)
    for imgp in get_files(imgs_dir):
        img = open_image(imgp)
        mask = open_mask(get_mask(imgp), div=True)
        # A fresh tiled background is generated for every image.
        new_bg = repeat_bg_trfs(bg, img.shape, trfms)
        new_img = change_bg(img, mask, new_bg)
        new_img.save(out_path / imgp.name)
        print(f"saved {imgp.name}")

In [7]:
# Name of the output sub-folder for this background / transform combination.
bg_id = "0000_random_rotation"
# Background image file to tile.
bgp = bg_dir/"0000.tif"
# Only element 0 of what get_transforms returns is used
# (presumably the training-time transform list -- confirm against the fastai docs).
trfms = get_transforms(max_rotate=90, flip_vert=True)[0]

In [9]:
# Display the transform list that is passed to gen_custom_bg_dataset.
trfms

[RandTransform(tfm=TfmCrop (crop_pad), kwargs={'row_pct': (0, 1), 'col_pct': (0, 1), 'padding_mode': 'reflection'}, p=1.0, resolved={}, do_run=True, is_random=True, use_on_y=True),
 RandTransform(tfm=TfmAffine (dihedral_affine), kwargs={}, p=1.0, resolved={}, do_run=True, is_random=True, use_on_y=True),
 RandTransform(tfm=TfmCoord (symmetric_warp), kwargs={'magnitude': (-0.2, 0.2)}, p=0.75, resolved={}, do_run=True, is_random=True, use_on_y=True),
 RandTransform(tfm=TfmAffine (rotate), kwargs={'degrees': (-90, 90)}, p=0.75, resolved={}, do_run=True, is_random=True, use_on_y=True),
 RandTransform(tfm=TfmAffine (zoom), kwargs={'scale': (1.0, 1.1), 'row_pct': (0, 1), 'col_pct': (0, 1)}, p=0.75, resolved={}, do_run=True, is_random=True, use_on_y=True),
 RandTransform(tfm=TfmLighting (brightness), kwargs={'change': (0.4, 0.6)}, p=0.75, resolved={}, do_run=True, is_random=True, use_on_y=True),
 RandTransform(tfm=TfmLighting (contrast), kwargs={'scale': (0.8, 1.25)}, p=0.75, resolved={}, do_r

In [99]:
# Generate the dataset for the configured background; prints one line per saved file.
gen_custom_bg_dataset(bgp, bg_id, trfms)

saved lime8.png
saved pere9.png
saved mele9.png
saved mele17.png
saved patate2.png
saved pesche1.png
saved mele2.png
saved pomodoro3.png
saved carotebuccia4.png
saved patate1.png
saved banane5.png
saved lime2.png
saved cipolla14.png
saved lime3.png
saved lime6.png
saved pesche5.png
saved mele4.png
saved prugnerosse3.png
saved kiwi3.png
saved carotebuccia1.png
saved pesche6.png
saved kiwi6.png
saved pesche4.png
saved cipolla11.png
saved mele16.png
saved carotebuccia3.png
saved pomodoro5.png
saved prugnerosse1.png
saved patate5.png
saved zucchine2.png
saved pere11.png
saved carota3.png
saved kiwi8.png
saved pomodoro2.png
saved peperoncino1.png
saved zucchine1.png
saved banane4.png
saved prugnerosse7.png
saved pesche2.png
saved albicocche4.png
saved prugnerosse5.png
saved albicocche1.png
saved lime7.png
saved cipolla5.png
saved mele14.png
saved kiwi1.png
saved albicocche6.png
saved mele8.png
saved pomodoro4.png
saved lime9.png
saved pere3.png
saved albicocche5.png
saved patate8.png
saved 

testing and experiments

In [27]:
# Paths of the single image, its mask and the background used in the experiment cells.
imgp = imgs_dir / "albicocche1.png"
maskp = masks_dir / "albicocche1.png"
bgp = bg_dir/"0000.tif"

In [48]:
# Load the sample inputs; the mask is opened with div=True, as in gen_custom_bg_dataset.
img = open_image(imgp)
mask = open_mask(maskp, div=True)
bg = open_image(bgp)

In [None]:
# Build a tiled background of the image's size; `nbg` is reused by the cells that follow.
nbg = repeat_bg_trfs(bg, img.shape, get_transforms(max_rotate=90, flip_vert=True)[0])

In [None]:
# Composite the sample image over the generated background and show the result.
new_img = change_bg(img, mask, nbg)
new_img.show(figsize=(10,10))

In [None]:
# Try a single rotation in (-180, 180) degrees with p=1 instead of the get_transforms list.
repeat_bg_trfs(bg, img.shape, [rotate(degrees=(-180,180), p=1)])

In [64]:
# Time the compositing step alone; `nbg` must already exist from an earlier cell.
%timeit change_bg(img, mask, nbg) #using torch.where

34.8 ms ± 1.91 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# Build a tiled background without storing it, so the cell displays the resulting Image.
repeat_bg_trfs(bg, img.shape, get_transforms(max_rotate=90, flip_vert=True)[0])

In [74]:
# Time the tiled-background generation; the timed statement also rebuilds the transform list on every run.
%timeit nbg = repeat_bg_trfs(bg, img.shape, get_transforms(max_rotate=90, flip_vert=True)[0])

1.3 s ± 19.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
