In [1]:
from fastai.vision.all import *

In [2]:
# Google Drive link that is handed to `gdown` in the next cell
url = "https://drive.google.com/uc?id=18xM3jU2dSp1DiDqEM6PVXattNMZvsX4z"

In [3]:
!gdown {url}

In [5]:
from zipfile import ZipFile

# Unpack every member of the downloaded archive into the data directory
with ZipFile("../Portrait.zip", mode="r") as archive:
    archive.extractall(path="../data")

In [2]:
# Data root; later cells read the `GT_png` and `images_data_crop` folders under it
path = Path("../data")

In [4]:
# Show each entry of the data root together with whether it is a regular file
for entry in path.ls():
    entry_is_file = entry.is_file()
    print(repr(entry), entry_is_file)

In [5]:
# Look at the first entry listed in the GT_png folder
(path/"GT_png").ls()[0]

In [16]:
# Open the first file listed in the GT_png folder
mask = Image.open((path/"GT_png").ls()[0])

In [26]:
# Display the opened mask
mask

In [17]:
# Turn the opened mask into a NumPy array, then display the array
mask = np.asarray(mask)
mask

In [3]:
def get_codes(fnames) -> Dict[int,int]:
    """Build a `new_code:original_code` lookup for the pixel values found in segmentation masks.

    Every file in `fnames` is opened with `Image.open` and its distinct pixel
    values are collected. The distinct values are then numbered 0, 1, 2, ... in
    ascending order, so the numbering does not depend on set iteration order.

    Returns a dict whose keys are the new consecutive codes and whose values are
    the original pixel values. An empty `fnames` gives `{}`.
    """
    unique_codes = set()
    for fname in fnames:
        mask = np.asarray(Image.open(fname))
        # `.tolist()` converts the numpy scalars into plain Python numbers
        unique_codes.update(np.unique(mask).tolist())
    # Sort before numbering so the same files always produce the same mapping
    return dict(enumerate(sorted(unique_codes)))

In [4]:
# Collect the pixel codes from the first 20 entries of GT_png, then display the mapping
unique_codes = get_codes((path/"GT_png").ls()[:20])
unique_codes

In [24]:
# Work on a copy, rewrite every 255 pixel to 1, and list the values that remain
mask = mask.copy()
mask[mask == 255] = 1
np.unique(mask)

In [5]:
# Class names handed to MaskBlock; presumably list position matches the mask pixel value — confirm
codes = ["Background", "Face"]
# Input is an image block, target is a mask block carrying `codes`
blocks = (ImageBlock, MaskBlock(codes=codes))

In [31]:
# Display the code mapping again before it is used by `get_y`
unique_codes

In [10]:
def get_y(filename:Path, unique_codes:dict):
    """Load the mask that belongs to image `filename` and recode its pixel values.

    The mask is read from `path/"GT_png"/"<stem>_mask.png"`, where `<stem>` is
    `filename.stem` and `path` is the notebook-level data directory.
    `unique_codes` maps `new_value:old_value`; every pixel equal to an
    `old_value` is replaced by the matching `new_value`.

    Returns the recoded mask wrapped with `PILMask.create`.
    """
    mask_path = path/"GT_png"/f'{filename.stem}_mask.png'
    original = np.asarray(Image.open(mask_path))
    recoded = original.copy()
    for new_value, old_value in unique_codes.items():
        # Match against the untouched `original`, so a pixel that was already
        # recoded can never be picked up again by a later entry
        recoded[original == old_value] = new_value
    return PILMask.create(recoded)

In [11]:
# Recode the mask for the first file listed in images_data_crop
new_mask = get_y((path/"images_data_crop").ls()[0], unique_codes)
# Trailing `;` suppresses the return value of `show` in the cell output
new_mask.show(cmap="Blues");

In [19]:
# DataBlock for segmentation: image input, mask target recoded through `get_y`
block = DataBlock(
    blocks=blocks,
    # Fixed seed keeps the train/validation split identical across kernel restarts
    splitter=RandomSplitter(seed=42),
    get_y=partial(get_y, unique_codes=unique_codes),
    item_tfms=Resize(224),
    batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)]
)

In [20]:
# Build the DataLoaders from every image file under images_data_crop, 8 items per batch
dls = block.dataloaders(get_image_files(path/'images_data_crop'), bs=8)

In [21]:
# Preview a batch; vmin/vmax presumably pin the mask colour scale to the two codes 0 and 1 — confirm
dls.show_batch(cmap="Blues", vmin=0, vmax=1)

In [26]:
# List the image files once so the items and the split indices refer to the same list
fnames = get_image_files(path/'images_data_crop')
# Fixed seed keeps the train/validation split identical across kernel restarts
splitter = RandomSplitter(seed=42)
dsets = Datasets(
    fnames,
    tfms=[
        [PILImage.create],                            # x pipeline: open the image
        [partial(get_y, unique_codes=unique_codes)]   # y pipeline: load and recode the mask
    ],
    splits=splitter(fnames)
)

In [35]:
# DataLoaders built directly from `dsets`, with the item and batch transforms listed explicitly.
# NOTE(review): no `bs` is passed here, while the DataBlock version above uses bs=8 — confirm intended.
dls = dsets.dataloaders(
    after_item=[Resize(224), ToTensor(), AddMaskCodes(codes=codes)],
    after_batch=[*aug_transforms(), IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)]
)

In [37]:
# Preview a batch from the Datasets-based loaders, with the same display arguments as the earlier preview
dls.show_batch(cmap="Blues", vmin=0, vmax=1)