In [1]:
from fastai.vision.all import *

In [2]:
# Google Drive direct-download link for the dataset archive fetched in the next cell.
url = "https://drive.google.com/uc?id=18xM3jU2dSp1DiDqEM6PVXattNMZvsX4z"

In [None]:
# Shell escape: download the archive at `url` with the gdown CLI (must be installed in the kernel's environment).
!gdown {url}

In [6]:
from zipfile import ZipFile

# Unpack the downloaded archive into the local "data" directory.
with ZipFile("Portrait.zip", mode="r") as archive:
    archive.extractall(path="data")

In [7]:
# Root of the extracted dataset; later cells (including `get_y`) read this global.
path = Path("data")

In [None]:
# Show every entry directly under the dataset root and whether it is a regular file.
for entry in path.ls():
    print(repr(entry), entry.is_file())

In [None]:
# Peek at the first entry of the mask directory to see how mask files are named.
(path/"GT_png").ls()[0]

In [10]:
# Open the first mask file as a PIL image for inspection.
mask = Image.open((path/"GT_png").ls()[0])

In [None]:
# Display the mask image inline.
mask

In [None]:
# Rebind `mask` to its numpy array form so the raw pixel values can be inspected.
mask = np.asarray(mask)
mask

In [13]:
def get_codes(fnames) -> Dict[int,int]: 
    """Build a `new_code:original_code` lookup for the pixel values found in segmentation masks.

    Every file in `fnames` is opened and its distinct pixel values are collected.
    The distinct values are then sorted and numbered from 0, so the mapping is
    deterministic and the smallest original value always gets new code 0.

    Args:
        fnames: iterable of mask image files accepted by `Image.open`.

    Returns:
        Dict mapping each new consecutive code (key) to the original pixel value
        (value), both as plain Python ints. Empty if `fnames` is empty.
    """
    unique_codes = set()
    for fname in fnames:
        # Context manager closes the underlying file handle once the pixels are read.
        with Image.open(fname) as mask:
            # np.unique flattens the array; tolist() yields plain Python scalars.
            unique_codes.update(np.unique(np.asarray(mask)).tolist())
    # Sort before numbering: set iteration order is an implementation detail.
    return dict(enumerate(sorted(unique_codes)))

In [None]:
# Collect the pixel codes from only the first 20 masks (a sample, not the full directory),
# then display the resulting lookup.
unique_codes = get_codes((path/"GT_png").ls()[:20])
unique_codes

In [None]:
# Work on a copy of the mask array, then collapse the 255 pixels to 1
# and confirm which values remain.
mask = mask.copy()
mask[mask == 255] = 1
np.unique(mask)

In [16]:
# Class names for the mask codes, and the (input, target) block pair used by the DataBlock below.
codes = ["Background", "Face"]
blocks = (ImageBlock, MaskBlock(codes=codes))

In [None]:
# Re-display the code lookup that `get_y` will consume.
unique_codes

In [18]:
def get_y(filename:Path, unique_codes:dict):
    """Load the mask belonging to image `filename` and remap its pixel values.

    The mask is looked up as `<stem>_mask.png` inside the `GT_png` folder of the
    module-level `path`.

    Args:
        filename: path of the input image; only its stem is used.
        unique_codes: dict of `new_code:original_code`; every pixel equal to an
            original code is replaced by the matching new code.

    Returns:
        The remapped mask wrapped by `PILMask.create`.
    """
    filename = path/"GT_png"/f'{filename.stem}_mask.png'
    # Read the pixels into an owned array and release the file handle right away.
    with Image.open(filename) as img:
        original = np.array(img)
    remapped = original.copy()
    for new_value, old_value in unique_codes.items():
        # Compare against the untouched original so a pixel that was already
        # remapped can never be matched again by a later entry.
        remapped[original == old_value] = new_value
    return PILMask.create(remapped)

In [None]:
# Sanity check: build the remapped mask for the first input image and plot it.
new_mask = get_y((path/"images_data_crop").ls()[0], unique_codes)
new_mask.show(cmap="Blues");

In [20]:
# High-level pipeline: image inputs, mask targets produced by `get_y`, a random
# train/valid split, per-item resize to 224, then augmentation and ImageNet normalisation per batch.
block = DataBlock(
    blocks=blocks,
    get_y=partial(get_y, unique_codes=unique_codes),
    splitter=RandomSplitter(),
    item_tfms=Resize(224),
    batch_tfms=[
        *aug_transforms(),
        Normalize.from_stats(*imagenet_stats),
    ],
)

In [21]:
# Build the DataLoaders from the list of image files, eight items per batch.
dls = block.dataloaders(get_image_files(path/'images_data_crop'), bs=8)

In [None]:
# Visual check of one batch from the DataBlock-built loaders; vmin/vmax pin the mask colour scale to the 0-1 codes.
dls.show_batch(cmap="Blues", vmin=0, vmax=1)

In [23]:
# Mid-level API version of the same data setup.
splitter = RandomSplitter()
# List the image files once so the items and the split indices are guaranteed
# to refer to the same listing (and the directory is not scanned twice).
image_files = get_image_files(path/'images_data_crop')
dsets = Datasets(
    image_files,
    tfms=[
        [PILImage.create],                            # x: open the image
        [partial(get_y, unique_codes=unique_codes)]   # y: load and remap the mask
    ],
    splits=splitter(image_files)
)

In [24]:
# DataLoaders from the mid-level Datasets: item transforms run per sample,
# batch transforms run on the collated batch.
dls = dsets.dataloaders(
    after_item=[Resize(224), ToTensor(), AddMaskCodes(codes=codes)],
    after_batch=[*aug_transforms(), IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)],
    bs=8,
)

In [None]:
# Visual check of one batch from the Datasets-built loaders, using the same colour scale as before.
dls.show_batch(cmap="Blues", vmin=0, vmax=1)

In [27]:
# U-Net learner on a resnet34 backbone with self-attention and Mish activations.
# Both the loss and the accuracy metric are told to use axis 1.
learn = unet_learner(
    dls, resnet34,
    loss_func=CrossEntropyLossFlat(axis=1),
    metrics=partial(accuracy, axis=1),
    act_cls=Mish,
    self_attention=True,
)

In [None]:
# Print the learner's model summary.
learn.summary()

In [None]:
# First training stage: 10 epochs of one-cycle training with a maximum learning rate of 1e-3.
# NOTE(review): presumably the pretrained encoder is still frozen here, since unfreeze() is only called later; confirm.
learn.fit_one_cycle(10, 1e-3)

In [None]:
# Checkpoint the weights after the first stage under the name "stage_1".
learn.save("stage_1")
# To resume from this checkpoint instead of retraining, run: learn.load("stage_1")

In [None]:
# Show up to four results after the first training stage.
learn.show_results(max_n=4, figsize=(12,6))

In [None]:
# Second stage: unfreeze the model and train 4 more epochs with a learning-rate
# slice running from 1e-3/400 up to 1e-3/4.
# NOTE(review): presumably the slice spreads these rates across layer groups; confirm against fastai docs.
learn.unfreeze()
learn.fit_one_cycle(4, slice(1e-3/400, 1e-3/4))

In [None]:
# Show up to four results again after fine-tuning, for comparison with the first stage.
learn.show_results(max_n=4, figsize=(12,6))

In [None]:
# Build a test DataLoader over the first five image files and preview it.
dl = learn.dls.test_dl((path/'images_data_crop').ls()[:5])
dl.show_batch()

In [None]:
# Run the learner over the test DataLoader; later cells read the predictions from preds[0].
preds = learn.get_preds(dl=dl)

In [None]:
# Inspect the shape of the prediction tensor.
preds[0].shape

In [45]:
# Take the first item's prediction and reduce it with argmax over its dim 0
# (presumably the class axis, giving a per-pixel class index; confirm).
pred = preds[0][0].argmax(dim=0)

In [None]:
# Inspect the shape after the argmax reduction.
pred.shape

In [None]:
# Plot the predicted index map; the trailing semicolon suppresses the text repr.
plt.imshow(pred);

In [48]:
# Convert the predicted index map to an 8-bit grayscale image and save it as "mask.png".
pred = pred.numpy()
# Min-max stretch onto [0, 255]. A constant prediction (e.g. every pixel is
# background) has zero range; map it to 0 instead of dividing by zero.
value_range = pred.max() - pred.min()
scale = 255.0 / value_range if value_range > 0 else 0.0
rescaled = ((pred - pred.min()) * scale).astype(np.uint8)
im = Image.fromarray(rescaled)
im.save("mask.png")

In [None]:
# Display the saved grayscale mask inline.
im

In [72]:
# Manual inference without the Learner: rebuild the preprocessing pipelines by hand,
# run the raw model on the first five images, and save one PNG per prediction.
fnames = (path/'images_data_crop').ls()[:5]


def _mask_to_image(class_map):
    "Min-max rescale a numpy index map to 0-255 (constant maps become all zeros) and wrap it in a PIL image"
    low, high = class_map.min(), class_map.max()
    # Guard the zero-range case (e.g. every pixel predicted as background).
    scale = 255.0 / (high - low) if high > low else 0.0
    return Image.fromarray(((class_map - low) * scale).astype(np.uint8))


# Per-item preprocessing, built with split_idx=1.
item_tfms = Pipeline([
    PILImage.create, 
    RandomResizedCrop(224), 
    ToTensor()
], split_idx=1)

# Per-batch preprocessing: int -> float conversion, then ImageNet normalisation.
batch_tfms = Pipeline([
    IntToFloatTensor(), 
    Normalize.from_stats(*imagenet_stats)
])

model = learn.model
model.eval()
# Run on whatever device the model already lives on rather than assuming CUDA.
device = next(model.parameters()).device

batch = torch.stack([item_tfms(fname) for fname in fnames], dim=0)
batch = batch_tfms(batch.to(device))

with torch.no_grad():
    preds = model(batch)

for i, pred in enumerate(preds):
    # Reduce each prediction with argmax over its dim 0, then move it to numpy.
    pred = pred.argmax(0).cpu().numpy()
    im = _mask_to_image(pred)
    im.save(f'pred_{i}.png')