## **Data and code setup**

In [0]:
%%capture
# Environment setup (output hidden by %%capture).
# gdown is the command-line downloader used below to fetch the weights archive from Google Drive.
# NOTE(review): gpustat is not referenced anywhere in the visible notebook — confirm it is still needed.
!pip3 install gdown
!pip3 install gpustat

In [0]:
%%capture
# Fetch the ENet implementation into ./enet_tensorflow; that directory is appended to
# sys.path in the notebook-setup cell so that `utils` and `models` can be imported.
# %%capture hides all output of this cell, including any error message from git.
!git clone https://github.com/gevero/enet_tensorflow.git

In [0]:
%%capture
# Download the weights archive from Google Drive and unpack it into the working directory.
# `unzip -o` overwrites existing files without prompting, so the cell can be re-run.
# Presumably the archive contains the ./Enet512x512.tf checkpoint loaded later — verify.
!gdown https://drive.google.com/uc?id=1zQ6PCA7k-1d_s_zrZWftJ0OgS23wKIT_ -O EnetWeights.zip
!unzip -o EnetWeights.zip

## **Notebook Setup**

In [0]:
# update to tf 2.0
from __future__ import absolute_import, division, print_function, unicode_literals

# Select the TensorFlow version (this does not install anything)
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.1
except Exception:
  # Deliberate best-effort: any failure of the magic is ignored, and the import
  # below then uses whatever TensorFlow is already available in the environment.
  pass

# importing standard libraries
import tensorflow as tf
print(tf.__version__)
import matplotlib.pylab as plt
import numpy as np
import os, os.path
from functools import partial
from google.colab import files

# Importing utils and models
import time
import sys
sys.path.append('./enet_tensorflow')
from utils import preprocess_img_label, map_singlehead, map_doublehead, map_label, tf_dataset_generator, get_class_weights
from models import EnetModel

## **Selenium setup**

In [0]:
%%capture
# Install the Chromium driver, copy the chromedriver binary into /usr/bin so it can be
# found by name, and install the Selenium Python bindings (output hidden by %%capture).
# NOTE(review): selenium is installed unpinned, so the API version depends on install date.
!apt install chromium-chromedriver
!cp /usr/lib/chromium-browser/chromedriver /usr/bin
!pip install selenium

In [0]:
from selenium import webdriver

# Chrome preference: allow automatic downloads so that no dialog box is shown
prefs = {'profile.default_content_setting_values.automatic_downloads': 1}

# Options for running Chrome without a display.
# NOTE(review): --no-sandbox / --disable-dev-shm-usage are presumably needed because
# the notebook runs in a container (Colab) — confirm before reusing elsewhere.
options = webdriver.ChromeOptions()
options.add_argument('--headless')
options.add_argument('--no-sandbox')
options.add_argument('--disable-dev-shm-usage')
options.add_experimental_option("prefs", prefs)

# Create the webdriver. The driver path is intentionally NOT passed positionally:
# Selenium 4.10+ removed that parameter (the call would raise TypeError), while
# Selenium 3 already defaults to the "chromedriver" executable found on PATH,
# which is where the setup cell copied it (/usr/bin).
wd = webdriver.Chrome(options=options)

## **Load Weights**

In [7]:
# Build the ENet model. MultiObjective=True matters downstream: the inference cell
# unpacks two outputs from a single call to the model.
# NOTE(review): C=3 is presumably the number of classes and l2 a regularisation
# factor — confirm against models.EnetModel.
Enet = EnetModel(C=3,MultiObjective=True,l2=1e-3)
# Restore the checkpoint (presumably extracted from EnetWeights.zip — verify).
# load_weights is the last expression of the cell, so its return value is echoed as output.
Enet.load_weights('./Enet512x512.tf')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f319b43a978>

## **Get/refresh image from "this person does not exist"**

In [0]:
# Imported here rather than in the top import cell because selenium is only
# installed by the Selenium setup cell, which runs after that cell.
from selenium.webdriver.common.by import By

# get or refresh the image
wd.get('https://www.thispersondoesnotexist.com')
time.sleep(2)  # fixed pause to give the page time to load

# buttons
# find_element(By..., value) is used instead of the find_element_by_* helpers:
# those were removed in Selenium 4.3, whereas find_element exists in both 3.x and 4.x.
save_button = wd.find_element(By.ID, 'saveButton')
another_button = wd.find_element(By.XPATH, '//*[@title="Save this person"]')
time.sleep(2)

# click them to get image
# NOTE(review): the next cell reads ./person.jpg, so these clicks are presumably what
# triggers that download; nothing here waits for the file to finish writing — confirm.
another_button.click()
save_button.click()

In [0]:
   # decoding image
    img = tf.io.read_file('./person.jpg')
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img,(512,512))
    img = tf.reshape(img,[1,512,512,3])

In [0]:
def create_mask(pred_mask):
  """Turn a batch of per-class scores into a label map for the first item.

  The arg-max is taken over the last axis, a trailing singleton axis is
  appended, and the entry at index 0 of the leading axis is returned.
  """
  labels = tf.argmax(pred_mask, axis=-1)[..., tf.newaxis]
  return labels[0]

# Run the model on the single-image batch; it returns two outputs and only the
# second one is converted to a mask and displayed.
img_enc_probs, img_dec_probs = Enet(img[:1])
img_dec_out = create_mask(img_dec_probs)

photo = img.numpy()[0]
mask = img_dec_out[:, :, 0]

# One row, three panels: image | mask | image + mask
fig, (ax_photo, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=(20, 10))
for ax in (ax_photo, ax_mask, ax_overlay):
  ax.set_xticks([])
  ax.set_yticks([])

# image
ax_photo.imshow(photo)

# mask
ax_mask.imshow(mask, cmap='viridis')

# image + mask (mask drawn half-transparent on top of the image)
ax_overlay.imshow(photo)
ax_overlay.imshow(mask, alpha=0.5, cmap='viridis')

plt.tight_layout()
fig.subplots_adjust(wspace=0.0, hspace=0.0)