<a href="https://colab.research.google.com/github/google-research/skai/blob/skai-colab/src/SKAI_2022_Colab_Github.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **The SKAI isn’t the limit 🚀**
***Assessing Post-Disaster Damage 🏚️ from Satellite Imagery 🛰️ using Semi-Supervised Learning Techniques 📔***

*by Amine Baha, WFP Innovation Accelerator, 04th June 2022*

## Intro 🏹

WFP partnered with Google Research to set up **SKAI**, a humanitarian response mapping solution powered by artificial intelligence — an approach that combines statistical methods, data and modern computing techniques to automate specific tasks. SKAI assesses damage to buildings by applying computer vision — computer algorithms that can interpret information extracted from visual materials such as, in this case, **satellite images of areas impacted by conflict, climate events, or other disasters**.

![Skai Logo](https://storage.googleapis.com/skai-public/skai_logo.png)

SKAI learns from a small number of labeled images and a large number of unlabeled images of affected buildings. It uses a ***semi-supervised learning technique*** that reduces the required number of labeled examples by an order of magnitude. As a result, SKAI models typically *need only a couple hundred labeled examples* to achieve high accuracy, which significantly shortens the time needed to obtain accurate results.

In [June 2020](https://ai.googleblog.com/2020/06/machine-learning-based-damage.html), Google Research presented this novel application of semi-supervised learning (SSL), which trains damage assessment models with a minimal amount of labeled data and a large amount of unlabeled data. Using state-of-the-art methods including [MixMatch](https://arxiv.org/abs/1905.02249) and [FixMatch](https://arxiv.org/abs/2001.07685), they compared performance against a supervised baseline for the 2010 Haiti earthquake, the 2017 Santa Rosa wildfire, and the 2016 armed conflict in Syria.

![SSL Approach](https://storage.googleapis.com/skai-public/ssl_diagram.png)

The [paper](https://arxiv.org/abs/2011.14004) by *Jihyeon Lee, Joseph Z. Xu, Kihyuk Sohn, Wenhan Lu, David Berthelot, Izzeddin Gur, Pranav Khaitan, Ke-Wei Huang, Kyriacos Koupparis, and Bernhard Kowatsch* shows that models trained with SSL methods can reach fully supervised performance while using only a fraction of the labeled data.
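
To make the semi-supervised idea more concrete, here is a minimal sketch of the unlabeled-data loss used by FixMatch-style training. It is an illustration in plain Python, not SKAI's training code: the function names, the two-class toy inputs, and the 0.95 confidence threshold (the default reported in the FixMatch paper) are assumptions made for this example. The model's prediction on a weakly augmented view becomes a pseudo-label only when it is confident, and the strongly augmented view of the same image is then trained to match it.

```python
import math


def softmax(logits):
    """Convert raw class scores into probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def fixmatch_unlabeled_loss(weak_logits, strong_logits, threshold=0.95):
    """Average pseudo-label cross-entropy over a batch of unlabeled images.

    weak_logits and strong_logits hold per-image class scores for a weakly
    and a strongly augmented view of the same unlabeled image.
    """
    total = 0.0
    for weak, strong in zip(weak_logits, strong_logits):
        weak_probs = softmax(weak)
        confidence = max(weak_probs)
        if confidence < threshold:
            continue  # Low-confidence predictions contribute no loss.
        pseudo_label = weak_probs.index(confidence)
        strong_probs = softmax(strong)
        total += -math.log(strong_probs[pseudo_label])
    # The average runs over the whole batch, including masked-out images.
    return total / len(weak_logits)


# Toy batch: the first image is confidently class 0, the second is ambiguous.
weak = [[4.0, 0.0], [0.2, 0.1]]
strong = [[2.0, 0.0], [0.0, 3.0]]
print(fixmatch_unlabeled_loss(weak, strong))
```

Only the first image passes the threshold here, so the ambiguous second image is ignored rather than being trained on a possibly wrong label.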


In [None]:
#@title Please run this cell first!

import base64
import collections
import datetime
import io
import json
import os
import re
import subprocess
import time

import ee
import folium
import ipyplot
import IPython.display
import pandas as pd
import pexpect
import pprint
import pyproj
import pytz
import requests
import smtplib
import ssl
import tensorflow as tf

from os import path
from google.appengine.api import mail
from google.cloud import monitoring_v3
from IPython.display import display, HTML, Javascript
from PIL import Image, ImageDraw, ImageFont

def launch_pexpect_process(venv_dir, skai_dir, script, arguments, use_pexpect):
  flags_str = ' '.join(f"--{f}='{v}'" for f, v in arguments.items())
  commands = '; '.join([
      f'set -e',
      f'source {venv_dir}/bin/activate',
      f'export GOOGLE_APPLICATION_CREDENTIALS=/root/service-account-private-key.json',
      f'python {skai_dir}/src/{script} {flags_str}'])
  sh_command = f'bash -c "{commands}" | tee /tmp/output.txt'
  if use_pexpect:
    return pexpect.spawn(sh_command)
  else:
    with open('/tmp/shell_command.sh', 'w') as f:
      f.write(commands)
    !bash "/tmp/shell_command.sh" | tee /tmp/output.txt

def make_gcp_http_request(url):
  variable=!(gcloud auth print-access-token)
  token=variable[0]
  response = requests.get(url=url, headers = {"Authorization": "Bearer {token}".format(token=token)})
  if not response.ok:
    response.raise_for_status()
  return response.json()

def bucket_exists(project, bucket_name):
  url = f'https://storage.googleapis.com/storage/v1/b?project={project}'
  data = make_gcp_http_request(url)
  # 'items' is absent from the response when the project has no buckets yet.
  buckets = [item['name'] for item in data.get('items', []) if item['kind'] == 'storage#bucket']
  return bucket_name in buckets

def create_bucket(project, location, bucket_name):
  %shell gsutil mb -p $project -l $location -b on gs://{bucket_name}

def progress(value, max=100):
  css = """
        <style>
          progress {
            border-radius: 7px;
            box-shadow: 1px 1px 2px rgba(0, 0, 0, 0.5) inset;
            width: 80%;
            height: 30px;
            display: block;
          }
          progress::-webkit-progress-bar {
            background-color: rgba(237, 237, 237, 0);
            border-radius: 7px;
          }
          progress::-webkit-progress-value {
            background-color: green;
            border-radius: 7px;
            box-shadow: 1px 1px 1px rgba(0, 0, 0, 0.1) inset;
          }
        </style>
        """
  html = """
          <progress
              value='{value}'
              max='{max}'
          >
            {value}%
          </progress>
        """.format(value=value, max=max)
  return HTML(css + html)


## Notebook Setup 📓

Specify the variables that define your damage assessment project, then press play:

In [None]:
#############################################
### INITIAL SETTING - PROJECT DESCRIPTION ###
#############################################
#@markdown ---
#@markdown Please enter here the parameters for your **project description**

#@markdown ---
Disaster = 'Cyclone' #@param ["Cyclone", "Earthquake", "Tsunami", "Flood", "Eruption", "Tornado", "Wind", "Wildfire", "Landslide", "Conflict"]
Year =  #@param {type:"integer"}
Month =   #@param {type:"integer"}
Name = '' #@param {type:"string"}
Country = '' #@param {type:"string"}
Organisation = '' #@param {type:"string"}
Author = '' #@param {type:"string"}
Run = '' #@param {type:"string"}

Project_description= f"{Organisation}-{Disaster}-{Name}-{Country}-{Year}{Month}_{Run}".lower()
print(f"\nYour project description: {Project_description}")

currentDateTime = datetime.datetime.now()
date = currentDateTime.date()
timestamp=currentDateTime.strftime('%Y%m%d%H%M%S')
year = date.strftime("%Y")

Tool="Skai"
Env="Test"
Version="102"

root_filesys='/content'
GD_DIRECTORY = f"{Project_description}".lower()

#############################################
### CLOUD SETTING - PROJECT CONFIGURATION ###
#############################################
#markdown ---
#markdown Please enter here the project and location names you want to use in your **google cloud platform account**

#markdown ---
GCP_PROJECT = "skai-2022" #param {type:"string"}
GCP_LOCATION = "europe-west1" #param {type:"string"}
GCP_LOCATION_LABELING=GCP_LOCATION
if "europe-" in GCP_LOCATION :
  GCP_LOCATION_LABELING= "europe-west4"
  if GCP_LOCATION!= "europe-west1" :
    GCP_LOCATION= "europe-west1"
    print(f"\nLocation region has been changed to {GCP_LOCATION} (Vertex AI features availability) ")
if "us-" in GCP_LOCATION :
  GCP_LOCATION_LABELING= "us-central1"
  if GCP_LOCATION!= "us-central1" :
    GCP_LOCATION= "us-central1"
    print(f"\nLocation region has been changed to {GCP_LOCATION} (Vertex AI features availability) ")

url='https://cloudresourcemanager.googleapis.com/v1/projects/{}'.format(GCP_PROJECT)
data = make_gcp_http_request(url)
GCP_PROJECT_ID=int(data['projectNumber'])

GCP_BUCKET = f"{Tool}{year}-Bucket-{Env}{Version}_{Author}".lower()
if not bucket_exists(GCP_PROJECT, GCP_BUCKET):
  # create_bucket expects (project, location, bucket_name).
  create_bucket(GCP_PROJECT, GCP_LOCATION, GCP_BUCKET)

print(f"\nYour project bucket in Google Cloud: {GCP_BUCKET} \nhttps://console.cloud.google.com/storage/browser/{GCP_BUCKET}")

pathgcp_outputdir=os.path.join(GCP_BUCKET,GD_DIRECTORY)

emailgcp_serviceaccount = 'skai-colab@skai-2022.iam.gserviceaccount.com'

os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/root/service-account-private-key.json'

#############################################
### CODE SETTING - ENVIRONMENT ACTIVATION ###
#############################################
pathsys_venv=os.path.join(root_filesys,'skai-env')
pathsys_actenv=os.path.join(pathsys_venv, 'bin/activate')

pathsys_skai=os.path.join(root_filesys, 'skai-src')

#########################################
### IMAGE SETTING - FILE & DIRECTORY ###
#########################################
#@markdown ---
#@markdown Please enter the paths to the pre- and post-disaster satellite images and to the area-of-interest file:

#@markdown ---
FILE_IMAGE_BEFORE = '' #@param {type:"string"}
FILE_IMAGE_AFTER = '' #@param {type:"string"}
FILE_IMAGE_AOI = '' #@param {type:"string"}

#@markdown ---
#@markdown Choose where to get building footprints from:
BUILDING_DETECTION_METHOD = "open_street_map" #@param ["open_street_map", "file"]
#@markdown If you chose "file", please enter path to CSV file here:
BUILDINGS_CSV = '' #@param {type:"string"}

pathgcp_imagebefore=FILE_IMAGE_BEFORE.replace('gs://','')
pathgcp_imageafter=FILE_IMAGE_AFTER.replace('gs://','')
pathgcp_aoi=FILE_IMAGE_AOI.replace('gs://','')

#########################################
### EXAMPLE SETTING - CLOUD DIRECTORY ###
#########################################
pathgcp_examples=os.path.join(pathgcp_outputdir,'examples')
pathgcp_importfile=os.path.join(pathgcp_examples,'labeling_images/import_file.csv')

###########################################
### LABELING SETTING - EMAIL PARAMETERS ###
###########################################
#@markdown ---
#@markdown Provide the email addresses of everyone who will help label images, separated by commas.
#@markdown Each labeler's email address must be linked to a Google account.

#@markdown ---
EMAIL_MANAGER = '' #@param {type:"string"}
EMAIL_ANNOTATORS = '' #@param {type:"string"}

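# Drop empty entries and the manager's own address from the annotator list so it is not listed twice.
manager_email = EMAIL_MANAGER.strip()
annotator_emails = [email.strip() for email in EMAIL_ANNOTATORS.split(',')]
annotator_emails = [email for email in annotator_emails if email and email != manager_email]
GCP_LABELER_EMAIL = ','.join([manager_email] + annotator_emails)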

################################################
### DATASET SETTING - FILE & CLOUD DIRECTORY ###
################################################
pathgcp_temp=os.path.join(pathgcp_outputdir,'temp')
pathgcp_unlabeled=os.path.join(pathgcp_examples,'unlabeled/*.tfrecord')

pathgcp_trainset=os.path.join(pathgcp_examples,'labeled_train_examples.tfrecord')
pathgcp_testset=os.path.join(pathgcp_examples,'labeled_test_examples.tfrecord')

#######################################
### MODEL SETTING - FILE & DIRECTORY ##
#######################################
pathsys_runjobs=os.path.join(root_filesys,'run_jobs')
if not os.path.exists(pathsys_runjobs):
  os.mkdir(pathsys_runjobs)

pathgcp_models=os.path.join(pathgcp_outputdir,'models')
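
The naming and region logic in the setup cell can be sketched as three small functions. This is a hedged illustration rather than part of the pipeline: the function names are hypothetical, and the regular expression covers only the Cloud Storage naming rules for names without dots (3 to 63 characters, lowercase letters, digits, dashes and underscores, starting and ending with a letter or digit). It shows why a form value such as an `Author` containing a space would produce a bucket name that Cloud Storage rejects.

```python
import re

# Assumed subset of the Cloud Storage bucket naming rules (names without dots).
_BUCKET_NAME_RE = re.compile(r'^[a-z0-9][a-z0-9_-]{1,61}[a-z0-9]$')


def build_bucket_name(tool, year, env, version, author):
    """Compose the bucket name the same way the setup cell does."""
    return f"{tool}{year}-Bucket-{env}{version}_{author}".lower()


def is_valid_bucket_name(name):
    """Check the name against the assumed naming rules above."""
    return bool(_BUCKET_NAME_RE.match(name))


def resolve_regions(location):
    """Return (compute_region, labeling_region) following the setup cell."""
    if "europe-" in location:
        return "europe-west1", "europe-west4"
    if "us-" in location:
        return "us-central1", "us-central1"
    return location, location


name = build_bucket_name("Skai", "2022", "Test", "102", "amine")
print(name, is_valid_bucket_name(name))
print(resolve_regions("europe-west3"))
```

Checking the name before calling `gsutil mb` gives a clearer error than a failed bucket creation halfway through the cell.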


In [None]:
#@title Optimize images for cloud

pathsys_images=os.path.join(root_filesys,'images')
pathsys_imagesfolder=os.path.join(pathsys_images,f"{Author}_{GD_DIRECTORY}")
# Create the images directory and the per-project subfolder if they do not exist yet.
os.makedirs(pathsys_imagesfolder, exist_ok=True)

temp_dir_image = os.path.join(pathsys_imagesfolder, 'tmp')
validator_file = os.path.join(temp_dir_image, 'validate_cog.py')
temp_out_file = os.path.join(temp_dir_image, 'temp_out_file.txt')

pre_image = os.path.join(pathsys_imagesfolder, FILE_IMAGE_BEFORE.split('/')[-1])
pre_image_copy = os.path.join(pathsys_imagesfolder, FILE_IMAGE_BEFORE.split('/')[-1].split('.tif')[0] + '_copy.tif')
post_image = os.path.join(pathsys_imagesfolder, FILE_IMAGE_AFTER.split('/')[-1])
post_image_copy = os.path.join(pathsys_imagesfolder, FILE_IMAGE_AFTER.split('/')[-1].split('.tif')[0] + '_copy.tif')

pathgcp_images=os.path.join(pathgcp_outputdir,'images')

!gsutil -m cp gs://{pathgcp_images}/* {pathsys_imagesfolder}/ 

def write_optimize_images_launch_script(**args):
  submission_ending='''
mkdir -p {temp_dir_image}
curl -s 'https://raw.githubusercontent.com/OSGeo/gdal/master/swig/python/gdal-utils/osgeo_utils/samples/validate_cloud_optimized_geotiff.py' > {validator_file}

python3.6 {validator_file} {pre_image} | tee {temp_out_file}
if grep -q 'NOT a valid' {temp_out_file}; then
  cp {pre_image} {pre_image_copy} 
  echo 'Converting pre_disaster image to COG...'
  gdaladdo -r average {pre_image_copy} 2 4 8 16
  gdal_translate {pre_image_copy} {pre_image} -co COMPRESS=LZW -co TILED=YES 
fi

python3.6 {validator_file} {post_image} | tee {temp_out_file}
if grep -q 'NOT a valid' {temp_out_file}; then
  cp {post_image} {post_image_copy}
  echo 'Converting post_disaster image to COG...'
  gdaladdo -r average {post_image_copy} 2 4 8 16
  gdal_translate {post_image_copy} {post_image} -co COMPRESS=LZW -co TILED=YES
fi

rm -rf {temp_dir_image}'''.format(**args)  
  
  with open(args['path_run'], 'w+') as file:
    file.write(submission_ending)

timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
file_runjob=f'run_optimize_as_child_process_{Author}_{timestamp}_{Project_description}.sh'
pathsys_runfile=os.path.join(pathsys_runjobs,file_runjob)

generate_script_args={
    'validator_file':validator_file,
    'temp_out_file':temp_out_file,
    'pre_image':pre_image,
    'pre_image_copy':pre_image_copy,
    'post_image':post_image,
    'post_image_copy':post_image_copy,
    'temp_dir_image':temp_dir_image,
    'path_run': pathsys_runfile,
}

write_optimize_images_launch_script(**generate_script_args)

!bash {pathsys_runfile}

!gsutil -m cp {pre_image} {FILE_IMAGE_BEFORE}
!gsutil -m cp {post_image} {FILE_IMAGE_AFTER}
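
The conversion step in the cell above follows a simple rule: validate each image, and only if the validator reports that it is not a valid Cloud Optimized GeoTIFF, build overviews and rewrite the file tiled and LZW-compressed. The sketch below restates that rule in Python without running GDAL. The function names are hypothetical, and the commands and flags are copied from the shell script in the cell.

```python
def needs_conversion(validator_output):
    """Mirror the script's grep: convert when the validator says 'NOT a valid'."""
    return 'NOT a valid' in validator_output


def cog_conversion_commands(image, image_copy, overview_levels=(2, 4, 8, 16)):
    """Return the commands the script runs for one image that fails validation."""
    levels = [str(level) for level in overview_levels]
    return [
        # Work on a copy so the original path can be overwritten with the result.
        ['cp', image, image_copy],
        # Build averaged overviews (reduced-resolution pyramids) on the copy.
        ['gdaladdo', '-r', 'average', image_copy] + levels,
        # Rewrite the copy to the original path, tiled and LZW-compressed.
        ['gdal_translate', image_copy, image, '-co', 'COMPRESS=LZW', '-co', 'TILED=YES'],
    ]


for command in cog_conversion_commands('pre.tif', 'pre_copy.tif'):
    print(' '.join(command))
```

Each returned list could be passed to `subprocess.run(command, check=True)` on a machine where GDAL is installed.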


In [None]:
#@title Visualize before and after images

# Add custom basemaps to folium.
basemaps = {
    'Google Maps': folium.TileLayer(
        tiles = 'https://mt1.google.com/vt/lyrs=m&x={x}&y={y}&z={z}',
        attr = 'Google',
        name = 'Google Maps',
        overlay = True,
        control = True
    ),
    'Google Satellite': folium.TileLayer(
        tiles = 'https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}',
        attr = 'Google',
        name = 'Google Satellite',
        overlay = True,
        control = True
    ),
    'Google Terrain': folium.TileLayer(
        tiles = 'https://mt1.google.com/vt/lyrs=p&x={x}&y={y}&z={z}',
        attr = 'Google',
        name = 'Google Terrain',
        overlay = True,
        control = True
    ),
    'Google Satellite Hybrid': folium.TileLayer(
        tiles = 'https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}',
        attr = 'Google',
        name = 'Google Satellite Hybrid',
        overlay = True,
        control = True
    ),
    'Esri Satellite': folium.TileLayer(
        tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
        attr = 'Esri',
        name = 'Esri Satellite',
        overlay = True,
        control = True
    )
}

def create_folium_map_with_images():
  # Load before image and get latitude/longitude of map center.
  before_image_path = 'gs://'+pathgcp_imagebefore
  before_map = ee.Image.loadGeoTIFF(before_image_path)
  before_map_id_dict = before_map.getMapId()
  x = before_map.getInfo()['bands'][0]['crs_transform'][2]
  y = before_map.getInfo()['bands'][0]['crs_transform'][-1]
  dim_x, dim_y = before_map.getInfo()['bands'][0]['dimensions']
  crs = before_map.getInfo()['bands'][0]['crs'].split(':')[-1]
  proj = pyproj.Transformer.from_crs(int(crs), 4326, always_xy=True)
  lon, lat = proj.transform(x + int(dim_x / 4), y - int(dim_y / 4))

  # Create a folium map object. Location is latitude, longitude.
  my_map = folium.Map(location=[lat, lon], zoom_start=12, max_zoom=25)

  # Add before and after disaster imagery.
  folium.raster_layers.TileLayer(
      tiles=before_map_id_dict['tile_fetcher'].url_format,
      attr='COG',
      name = 'Pre-Disaster Imagery',
      overlay = True,
      control = True,
      max_zoom = 25,
    ).add_to(my_map)

  after_image_path = 'gs://'+pathgcp_imageafter
  after_map_id_dict = ee.Image.loadGeoTIFF(after_image_path).getMapId()
  folium.raster_layers.TileLayer(
      tiles=after_map_id_dict['tile_fetcher'].url_format,
      attr='COG',
      name = 'Post-Disaster Imagery',
      overlay = True,
      control = True,
      max_zoom = 25,
    ).add_to(my_map)

  my_map.add_child(folium.LayerControl())
  IPython.display.display(my_map)

display(Javascript("google.colab.output.resizeIframeToContent()"))


# Prepare credentials for map visualization.
service_account = 'skai-colab@skai-2022.iam.gserviceaccount.com'
credentials = ee.ServiceAccountCredentials(
    service_account, '/root/service-account-private-key.json')
ee.Initialize(credentials)

create_folium_map_with_images()
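
The map cell derives a starting location from the image's `crs_transform` and `dimensions`. As a point of comparison, here is a minimal sketch of computing the exact image center from the same two fields. It is not the notebook's method, which uses a simpler offset. The function name is hypothetical, and the sketch assumes the transform is ordered `[x_scale, x_shear, x_origin, y_shear, y_scale, y_origin]`, which matches how the cell reads index 2 as x and the last element as y.

```python
def geotiff_center(crs_transform, dimensions):
    """Return the image center in the image's own coordinate reference system.

    crs_transform: [x_scale, x_shear, x_origin, y_shear, y_scale, y_origin]
    dimensions: [width, height] in pixels
    """
    x_scale, x_shear, x_origin, y_shear, y_scale, y_origin = crs_transform
    width, height = dimensions
    col, row = width / 2, height / 2
    # Apply the affine transform to the pixel at the middle of the image.
    center_x = x_origin + x_scale * col + x_shear * row
    center_y = y_origin + y_shear * col + y_scale * row
    return center_x, center_y


# Hypothetical 4000 x 2000 pixel image with 0.5 m pixels in a projected CRS.
print(geotiff_center([0.5, 0, 500000.0, 0, -0.5, 2000000.0], [4000, 2000]))
```

The resulting pair is still in the image's CRS, so it would go through the same `pyproj.Transformer` call as in the cell before being passed to `folium.Map`.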


## Data labeling 👷

Create examples of building images before and after the disaster, then classify each one as undamaged, possibly damaged, damaged/destroyed, or a bad example (e.g., cloud cover).

First, generate the building images. This task should take about 30 minutes.

In [None]:
#@title Generate Examples

## CLASS DEFINITION

class DataflowMetricFetcher:
  def __init__(self, project_id: str, job_name: str, metric_name: str):
    self._client = monitoring_v3.MetricServiceClient()
    self._project_id = project_id
    self._job_name = job_name
    self._metric_name = metric_name
    self._filter = self.make_filter()

  def make_filter(self):
    conditions = [
        'resource.type = "dataflow_job"',
        f'resource.labels.project_id = "{self._project_id}"',
        f'resource.labels.job_name = "{self._job_name}"',
        'metric.name = "dataflow.googleapis.com/job/user_counter"',
        f'metric.labels.metric_name = "{self._metric_name}"',
    ]
    return '({})'.format(' AND '.join(conditions))

  def get_latest_value(self):
    end_seconds = int(time.time())
    start_seconds = 1
    interval = monitoring_v3.TimeInterval({
        'start_time': { 'seconds': start_seconds},
        'end_time': { 'seconds': end_seconds }
    })
    request = {
        'name': f'projects/{self._project_id}',
        'filter': self._filter,
        'interval': interval,
        'view': monitoring_v3.ListTimeSeriesRequest.TimeSeriesView.FULL
    }
    results = self._client.list_time_series(request)
    if (len(results._response.time_series) == 0 or
        len(results._response.time_series[0].points) == 0):
      return None, None

    latest_point = results._response.time_series[0].points[0]
    return latest_point.interval.end_time, latest_point.value.double_value

class ProgressBar:
  def __init__(self, max):
    self._display = display(self.get_html(0, max), display_id=True)

  def get_html(self, value, max):
    return HTML(f'Num generated examples: {value}/{max}<progress value="{value}" max="{max}" style="width: 100%">{value}</progress>')

  def update(self, num_examples, max):
    self._display.update(self.get_html(num_examples, max))

def parse_dataflow_job_creation_params(param_str: str):
  params_dict = {}
  lines = [line.strip() for line in param_str.split('\r\n')]
  for line in lines:
    if not line:
      continue
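    # Split on the first colon only so values that contain ':' stay intact,
    # and skip lines that are not key/value pairs.
    key, separator, value = line.partition(':')
    if not separator:
      continue
    params_dict[key.strip()] = value.strip().strip("'")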
  return params_dict

def run_example_generation(generate_examples_args, pretty_output=True):
  if not pretty_output:
    launch_pexpect_process(
        pathsys_venv, 
        pathsys_skai,
        'generate_examples_main.py',
        generate_examples_args,
        use_pexpect=False)
    return

  progress_bar = ProgressBar(1)

  child = launch_pexpect_process(
      pathsys_venv, 
      pathsys_skai,
      'generate_examples_main.py',
      generate_examples_args,
      use_pexpect=True)

  JOB_CREATION_PATTERN = 'Create job: <Job(.*clientRequestId:.*)>'
  BUILDINGS_MATCHED_PATTERN = 'Found ([0-9]+) buildings in area of interest.'

  num_buildings = 1
  while child.isalive():
    i = child.expect([BUILDINGS_MATCHED_PATTERN, JOB_CREATION_PATTERN, pexpect.EOF], timeout=600)
    if i == 0:
      num_buildings = int(child.match.group(1))
      print(f'Found {num_buildings} buildings in area of interest.')
      progress_bar.update(0, num_buildings)
    elif i == 1:
      job_params = parse_dataflow_job_creation_params(child.match.group(1).decode())
      job_name = job_params['name']
      job_id = job_params['id']
      job_location = job_params['location']
      job_project = job_params['projectId']
      job_status_pattern = f'Job {job_id} is in state JOB_STATE_([A-Z]+)'
      print(f'Detailed monitoring page: https://console.cloud.google.com/dataflow/jobs/{job_location}/{job_id}?project={job_project}')
      break
    else:
      print(child.before.decode())
      child.close()
      raise Exception('Job terminated unexpectedly.')

  generated_examples_metric = DataflowMetricFetcher(job_project, job_name, 'generated_examples_count')
  rejected_examples_metric = DataflowMetricFetcher(job_project, job_name, 'rejected_examples_count')

  job_state = None
  while child.isalive():
    i = child.expect([job_status_pattern, pexpect.TIMEOUT, pexpect.EOF], timeout=15)
    if i == 0:
      job_state = child.match.group(1).decode()
      print(f'Dataflow job state: {job_state}')
    elif i == 1 or i == 2:
      if job_state == 'RUNNING':
        examples_processed = 0
        t, v = generated_examples_metric.get_latest_value()
        if t:
          examples_processed += int(v)
        t, v = rejected_examples_metric.get_latest_value()
        if t:
          examples_processed += int(v)
        progress_bar.update(examples_processed, num_buildings)
      if i == 2:
        child.close()
        break

## COMMAND RUN
generate_examples_args = {
    'cloud_project': GCP_PROJECT,
    'cloud_region': GCP_LOCATION,
    'before_image_path': f'gs://{pathgcp_imagebefore}',
    'after_image_path': f'gs://{pathgcp_imageafter}',
    'aoi_path': f'gs://{pathgcp_aoi}',
    'output_dir': f'gs://{pathgcp_outputdir}',
    'buildings_method': BUILDING_DETECTION_METHOD,
    'buildings_file': BUILDINGS_CSV,
    'worker_service_account': 'skai-colab@skai-2022.iam.gserviceaccount.com',
    'use_dataflow': 'true',
    'num_labeling_examples': 1000
}

run_example_generation(generate_examples_args)

def count_tfrecord(path):
  # Count records by streaming through the file instead of holding them all in a list.
  return sum(1 for _ in tf.data.TFRecordDataset(path))

total_example_counter=0
for k in range(20):
  file_directory='unlabeled/unlabeled-000{:02d}-of-00020.tfrecord'.format(k)
  tfrecord_path = os.path.join('gs://',pathgcp_examples,file_directory)
  total_example_counter+=count_tfrecord(tfrecord_path)

print('{} building examples were extracted in total from the Area Of Interest'.format(total_example_counter))
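
The counting loop above relies on the sharded file naming `unlabeled-000NN-of-00020.tfrecord`. A small helper makes that pattern explicit and lets the shard count vary. This is a sketch with a hypothetical function name, and it assumes both numbers are always zero-padded to five digits, as they are in the file names used above.

```python
def shard_paths(directory, prefix, num_shards, extension='tfrecord'):
    """List the paths of a sharded output, e.g. prefix-00003-of-00020.tfrecord."""
    return [
        f'{directory}/{prefix}-{index:05d}-of-{num_shards:05d}.{extension}'
        for index in range(num_shards)
    ]


paths = shard_paths('gs://my-bucket/examples/unlabeled', 'unlabeled', 20)
print(paths[0])
print(paths[-1])
```

The bucket path here is a placeholder. If the number of shards is not known in advance, listing the files that match `unlabeled/*.tfrecord` with `tf.io.gfile.glob` would avoid hard-coding it.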

Second, create the labeling task for the labelers. This should take about 15 minutes.
At the end of this step, you and each labeler will receive an email with instructions on how to perform the labeling task.

In [None]:
#@title Create Labeling Task

class LabelingJob:
  def __init__(self, endpoint, project, location, labeling_job):
    self._endpoint = endpoint
    self._project = project
    self._location = location
    self._labeling_job = labeling_job
    self._access_token = self.get_access_token()

    job_info = self.get_info()
    self._dataset = job_info['datasets'][0]

    assert len(job_info['specialistPools']) == 1
    # Has the format projects/{project_id}/locations/{location}/specialistPools/{pool_id}
    parts = job_info['specialistPools'][0].split('/')
    assert len(parts) == 6
    assert parts[4] == 'specialistPools'
    self._pool_id = parts[5]
    
  def get_access_token(self):
    return subprocess.check_output('gcloud auth print-access-token'.split()).decode().rstrip('.\r\n')
  
  def get_header(self):
    return {
      'Authorization': f'Bearer {self._access_token}',
      'Content-Type': 'application/json',
    }
  
  def get_info(self):
    '''Return the labeling job's metadata as a dict parsed from the JSON response.'''
    parent = f'projects/{self._project}/locations/{self._location}/dataLabelingJobs/{self._labeling_job}'
    url = f'https://{self._endpoint}/v1/{parent}'
    header = self.get_header()
    r = requests.get(url, headers=header)
    if not r.ok:
      r.raise_for_status()
    return r.json()

  def get_completion_percentage(self):
    '''Return the percentage of data items labeled.

    Warning: There is a long lag between when items are labeled and when this
    value is updated.
    '''
    info = self.get_info()
    return info.get('labelingProgress', 0)
    
  def get_data_items(self):
    parent = f'projects/{self._project}/locations/{self._location}/datasets/{self._dataset}/dataItems'
    url = f'https://{self._endpoint}/v1/{parent}'
    items = []
    page_token = None
    header = self.get_header()
    while True:
      if page_token:
        r = requests.get(url, headers=header, params={'pageToken': page_token})
      else:
        r = requests.get(url, headers=header)
      if not r.ok:
        r.raise_for_status()

      result_json = r.json()
      items.extend(result_json['dataItems'])
      if 'nextPageToken' in result_json:
        page_token = result_json['nextPageToken']
      else:
        break
    return items

  def get_labels(self, data_item_name):
    url = f'https://{self._endpoint}/v1/{data_item_name}/annotations'
    header = self.get_header()
    r = requests.get(url, headers=header)
    if not r.ok:
      r.raise_for_status()
    json = r.json()
    labels = []
    if 'annotations' in json:
      for a in json['annotations']:
        labels.append(a['payload']['displayName'])
    return labels

  def get_worker_url(self):
    '''Returns the URL workers can use to access the labeling interface.

    The syntax of the URL was determined by reverse engineering, so there's no
    guarantee that it won't change in the future.
    '''
    location = self._location.replace('-', '_')
    return f'https://datacompute.google.com/w/cloudml_data_specialists_{location}_{self._pool_id}'

  def get_manager_url(self):
    '''Returns the URL managers can use to access the task management interface.

    The syntax of the URL was determined by reverse engineering, so there's no
    guarantee that it won't change in the future.
    '''
    location = self._location.replace('-', '_')
    return f'https://datacompute.google.com/cm/cloudml_data_specialists_{location}_{self._pool_id}/tasks'

def run_labeling_task_creation(create_label_task_args, pretty_output=True):
  if not pretty_output:
    launch_pexpect_process(
        pathsys_venv, pathsys_skai, 'create_cloud_labeling_task.py',
        create_label_task_args, False)
    return None

  child = launch_pexpect_process(
      pathsys_venv, pathsys_skai, 'create_cloud_labeling_task.py',
      create_label_task_args, True)

  DATASET_CREATED_PATTERN = 'ImageDataset created. Resource name: projects/[^/]+/locations/[^/]+/datasets/([0-9]+)'
  LABELING_JOB_CREATED_PATTERN = 'Data labeling job created:'

  dataset_id = None  # Stays None if the dataset-created message never appears.
  output = b''
  try:
    while child.isalive():
      i = child.expect(
          [DATASET_CREATED_PATTERN,
           LABELING_JOB_CREATED_PATTERN,
           pexpect.EOF,
           pexpect.TIMEOUT], timeout=1800)
      if isinstance(child.before, bytes):
        output += child.before
      if isinstance(child.after, bytes):
        output += child.after
      if i == 0:
        dataset_id = child.match.group(1).decode()
      elif i == 1:
        print('Data labeling job created.')
      elif i == 2:
        break
      else:
        raise Exception('Job timed out. Full output:\n' + output.decode())
  finally:
    child.close()

  return dataset_id

## COMMAND RUN

timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
GCP_DATASET_NAME = f"{Author}_label_{timestamp}_{Project_description}"

create_labeling_task_args = {
    'cloud_project':GCP_PROJECT,
    'cloud_location':GCP_LOCATION_LABELING,
    'dataset_name': GCP_DATASET_NAME,
    'import_file': f'gs://{pathgcp_importfile}',
    'cloud_labeler_emails': GCP_LABELER_EMAIL
    }

print('Creating data labeling job.')
GCP_DATASET_ID = run_labeling_task_creation(create_labeling_task_args)

url='https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/dataLabelingJobs'.format(GCP_LOCATION_LABELING,GCP_PROJECT,GCP_LOCATION_LABELING)
data = make_gcp_http_request(url)
data = list(filter(lambda d: GCP_DATASET_NAME in d['displayName'], data['dataLabelingJobs']))[0]

GCP_DATASET_ID = int(data['datasets'][0].split('/')[-1])
GCP_DATASET_NAME = data['displayName']
GCP_LABELING_JOB= int(data['name'].split('/')[-1])
GCP_LABELING_INSTRUCTION= data['instructionUri']

print(f'\nLabeling dataset {GCP_DATASET_NAME} created, with ID {GCP_DATASET_ID}')
print(f'\nData Labeling job {GCP_LABELING_JOB} created')

labeling_job = LabelingJob(f'{GCP_LOCATION_LABELING}-aiplatform.googleapis.com', 
                           GCP_PROJECT, GCP_LOCATION_LABELING, GCP_LABELING_JOB)
print('Instruction URL: {}'.format(GCP_LABELING_INSTRUCTION.replace('gs://','https://storage.cloud.google.com/')))
print(f'Worker URL: {labeling_job.get_worker_url()}')
print(f'Manager URL: {labeling_job.get_manager_url()}')
print(f'Detailed monitoring page: https://console.cloud.google.com/vertex-ai/locations/{GCP_LOCATION_LABELING}/labeling-tasks/{GCP_LABELING_JOB}?project={GCP_PROJECT}')
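
The worker and manager URLs printed above are built from the specialist pool's resource name. The sketch below isolates that step, replacing the `assert` checks in `LabelingJob.__init__` with an explicit error. The function names and the example resource name are hypothetical. The URL format is the reverse-engineered one from `get_worker_url`, so it carries the same caveat that it may change.

```python
def parse_specialist_pool(resource_name):
    """Split projects/{project}/locations/{location}/specialistPools/{pool_id}."""
    parts = resource_name.split('/')
    if (len(parts) != 6 or parts[0] != 'projects'
            or parts[2] != 'locations' or parts[4] != 'specialistPools'):
        raise ValueError(f'Unexpected specialist pool name: {resource_name}')
    return {'project': parts[1], 'location': parts[3], 'pool_id': parts[5]}


def worker_url(location, pool_id):
    """Build the labeling interface URL in the same format as get_worker_url."""
    region = location.replace('-', '_')
    return f'https://datacompute.google.com/w/cloudml_data_specialists_{region}_{pool_id}'


pool = parse_specialist_pool('projects/123/locations/europe-west4/specialistPools/456')
print(worker_url(pool['location'], pool['pool_id']))
```

Raising `ValueError` with the offending name makes a format change on the API side easier to diagnose than a bare assertion failure.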

As the manager of the task, you can track labeling progress by running the script below, which prints the job's completion percentage, or by viewing the detailed monitoring page. For good model quality, we recommend having about 200 building labels from the damaged/destroyed and undamaged categories.

In [None]:
#@title Monitor Labeling Task
labeling_job = LabelingJob(f'{GCP_LOCATION_LABELING}-aiplatform.googleapis.com', 
                           GCP_PROJECT, GCP_LOCATION_LABELING, GCP_LABELING_JOB)
print(f'\nJob completion percentage: {labeling_job.get_completion_percentage()}%')
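
Instead of rerunning the monitoring cell by hand, the check can be wrapped in a polling loop. The sketch below is one possible way to do that. The function name, the five-minute interval, and the poll limit are assumptions, and the progress source is passed in as a callable so the loop can be tried without a real labeling job. Because the reported percentage lags behind the actual labeling, as the `LabelingJob` docstring warns, a long interval is reasonable.

```python
import time


def wait_for_labeling(get_percentage, target=100, poll_seconds=300,
                      max_polls=12, sleep=time.sleep):
    """Poll until the reported percentage reaches target or polls run out."""
    for attempt in range(max_polls):
        percentage = get_percentage()
        print(f'Poll {attempt + 1}: {percentage}% labeled')
        if percentage >= target:
            return True
        if attempt < max_polls - 1:
            sleep(poll_seconds)
    return False


# Stand-in progress values in place of a real labeling job.
fake_progress = iter([0, 40, 100])
print(wait_for_labeling(lambda: next(fake_progress), sleep=lambda seconds: None))
```

With a real job, the call would look like `wait_for_labeling(labeling_job.get_completion_percentage)`.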

## Create and inspect training and evaluation examples 🧩

Run this script to assign the labeled images to training and evaluation datasets.

In [None]:
#@title Create training and evaluation datasets

def create_labeled_dataset():
  child = launch_pexpect_process(pathsys_venv, pathsys_skai, 'create_labeled_dataset.py', {
      "cloud_project": GCP_PROJECT,
      "cloud_location": GCP_LOCATION_LABELING,
      "cloud_dataset_id": GCP_DATASET_ID,
      "cloud_temp_dir": 'gs://' + pathgcp_temp,
      "examples_pattern": 'gs://' + pathgcp_unlabeled,
      "train_output_path": 'gs://' + pathgcp_trainset,
      "test_output_path": 'gs://' + pathgcp_testset}, True)

  print('Creating labeled datasets...')
  child.expect(pexpect.EOF, timeout=None)
  child.close()
  if child.exitstatus != 0:
    print('An unexpected error occurred. Output of command was:')
    print(child.before.decode())
  else:
    print('Labeled dataset created.')

create_labeled_dataset()

(Optional) You can run the following script to inspect both datasets.


In [None]:
#@title Inspect the training dataset

## CLASS DEFINITION

#! pip install ipyplot

def concat_caption_pilimage(image_before, image_after):
  img_before=caption_pilformat(image_before, "before")
  img_after=caption_pilformat(image_after, "after")

  w, h=img_before.size

  img_concat = Image.new('RGB', (2*w, h), "white")
  img_concat.paste(img_before, (0, 0))
  img_concat.paste(img_after, (w, 0))

  return img_concat

def caption_pilformat(img_data, caption):
  # img_data already holds the raw PNG bytes, so wrap it directly in a file-like object.
  byte_encoded=io.BytesIO(img_data)

  img=Image.open(byte_encoded)
  wd, hg =img.size

  img_ = Image.new('RGB', (wd+int(wd/10), hg+int(hg/5)), "white")
  img_.paste(img, (int(wd/20),int(hg/5)))

  wd, hg =img_.size
  img_cap = ImageDraw.Draw(img_)
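  # textbbox (Pillow 8.0+) replaces textsize, which was removed in Pillow 10.
  left, top, right, bottom = img_cap.textbbox((0, 0), caption)
  w = right - left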
  img_cap.text(((wd-w)/2,0), caption, fill=(0, 0, 0))

  return img_

def ipyplot_tfrecord(path, max_examples=None):
  pre_images = []
  post_images = []
  labels = []
  labels_split=[]
  total_example_num=len(list(tf.data.TFRecordDataset(path)))
  print('Number of examples: {}.'.format(total_example_num))

  if max_examples is None:
    max_examples=total_example_num

  for record in tf.data.TFRecordDataset(path):
    e = tf.train.Example()
    e.ParseFromString(record.numpy())
    labels_split.append(e.features.feature['label'].float_list.value[0])
    if len(pre_images) < max_examples:
      pre_images.append(e.features.feature['pre_image_png'].bytes_list.value[0])
      post_images.append(e.features.feature['post_image_png'].bytes_list.value[0])
      labels.append(e.features.feature['label'].float_list.value[0])

  labels_counter=dict(collections.Counter(labels_split))
  # Use .get() so a dataset that lacks one of the two classes does not raise KeyError.
  map_value = {0: 'Undamaged/bad examples {}/{}'.format(int(len(labels)-sum(labels)),labels_counter.get(0, 0)),
               1: 'Damaged {}/{}'.format(int(sum(labels)),labels_counter.get(1, 0))}
  labels=list((pd.Series(labels)).map(map_value))
  
  images=[concat_caption_pilimage(pre_images[idx], post_images[idx]) for idx in range(len(pre_images))]

  ipyplot.plot_class_tabs(images, labels,max_imgs_per_tab=max_examples, tabs_order=[map_value[1],map_value[0]], img_width=200)

  return total_example_num

## COMMAND RUN
COUNT_TRAIN_LABELED=ipyplot_tfrecord(os.path.join("gs://",pathgcp_trainset),max_examples=40)


In [None]:
#@title Inspect the evaluation dataset

COUNT_TEST_LABELED=ipyplot_tfrecord(os.path.join("gs://",pathgcp_testset),max_examples=40)
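
The tab captions produced by `ipyplot_tfrecord` report how many examples carry label 0 (undamaged or bad) and label 1 (damaged). To check the class balance without plotting, the same tally can be sketched with the standard library alone. `summarize_labels` is a hypothetical helper, and the label list below is made-up sample data, not output from a real dataset.

```python
import collections

def summarize_labels(labels):
    """Count examples per class for float labels where 0 = undamaged/bad and 1 = damaged."""
    counts = collections.Counter(int(label) for label in labels)
    total = sum(counts.values())
    damaged = counts.get(1, 0)
    return {
        'undamaged_or_bad': counts.get(0, 0),
        'damaged': damaged,
        # Guard against an empty dataset to avoid dividing by zero.
        'damaged_fraction': damaged / total if total else 0.0,
    }

# Made-up labels, in the same float format the TFRecords use.
summary = summarize_labels([0.0, 1.0, 0.0, 0.0, 1.0])
print(summary)
```

A heavily skewed `damaged_fraction` is a hint that more examples of the rarer class may need to be labeled before training.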

## Model training, performance evaluation 🤖

Run the following script to train the machine learning model and test it on the evaluation dataset. Both steps use the examples you labeled.

The script runs in the background and may take up to 6 hours. You will be able to see the progress on this page and we will also send you an email when this step is done.
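
The training cell keeps a commented-out hyperparameter block that derives a training budget, `HP_TRAINKIMG`, from the batch size, the number of labeled training examples, and the number of epochs. The arithmetic is reproduced below as a small worked example. The value of 200 labeled examples is an assumption chosen for illustration, and reading the result as "thousands of images" (in units of 1024) is an interpretation of the formula rather than something the notebook states.

```python
def train_kimg(batch_size, num_labeled, epochs):
    """Training budget following the commented-out formula in the training cell."""
    return int(batch_size * num_labeled * epochs / 1024)

# HP_BATCH=16 and HP_EPOCH=128 come from the commented-out defaults;
# 200 labeled examples is a hypothetical dataset size.
print(train_kimg(batch_size=16, num_labeled=200, epochs=128))  # 400
```

Since the block is commented out in the notebook, the training job presumably falls back to whatever defaults the training image defines.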

In [None]:
#@title Train and evaluate model

#ML_METHOD = 'fixmatch'
#HP_EPOCH=128
#HP_UNLABELEDRATIO=1
#HP_BATCH=16

#HP_TRAINKIMG=int(HP_BATCH*COUNT_TRAIN_LABELED*HP_EPOCH/1024)

def write_train_and_eval_launch_script(**args):
  args['hyper_parameters_args']=''

  submission_ending = '''
export GOOGLE_APPLICATION_CREDENTIALS=/root/service-account-private-key.json

source {python_env} ; python {path_skai}/src/launch_vertex_job.py \\
  --location={cloud_region} \\
  --project={cloud_project} \\
  --job_type=train \\
  --display_name={display_name_train} \\
  --dataset_name={dataset_name} \\
  --train_worker_machine_type=n1-highmem-8 \\
  --train_docker_image_uri_path={train_docker_image_uri_path} \\
  --service_account={service_account} \\
  --train_dir={train_dir} \\
  --train_unlabel_examples={train_unlabel_examples} \\
  --train_label_examples={train_label_examples} \\
  --test_examples={test_examples} & \\
sleep 60 ; python {path_skai}/src/launch_vertex_job.py \\
  --location={cloud_region} \\
  --project={cloud_project} \\
  --job_type=eval \\
  --display_name={display_name_eval} \\
  --dataset_name={dataset_name} \\
  --eval_docker_image_uri_path={eval_docker_image_uri_path} \\
  --service_account={service_account} \\
  --train_dir={train_dir} \\
  --train_unlabel_examples={train_unlabel_examples} \\
  --train_label_examples={train_label_examples} \\
  --test_examples={test_examples}'''.format(**args)

  with open(args['path_run'], 'w+') as file:
    file.write(submission_ending)

def metrics(train_label_acc, train_label_auc, test_acc, test_auc):
  html = """
         <h2>Metrics (updated as training progresses):</h2>
         <h3>Labeled Training Set</h3> 
         <p>Accuracy: {train_label_acc}% | AUC: {train_label_auc}</p>
         <h3>Test Set</h3>
         <p>Accuracy: {test_acc}% | AUC: {test_auc}</p>
        """.format(train_label_acc=train_label_acc, train_label_auc=train_label_auc, test_acc=test_acc, test_auc=test_auc)
  return HTML(html)


def timestamp_to_datetime(timestamp):
  return pd.to_datetime(timestamp)

timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
GCP_EXPERIMENT_NAME=f"{Author}_experiment_{timestamp}_{Project_description}"
GCP_TRAINJOB_NAME=f"{Author}_train_{timestamp}_{Project_description}"
GCP_EVALJOB_NAME=f"{Author}_eval_{timestamp}_{Project_description}"

jobgcp_exper=GCP_EXPERIMENT_NAME+'_default'
jobgcp_train=GCP_TRAINJOB_NAME+'_default'
jobgcp_eval=GCP_EVALJOB_NAME+'_default'

pathgcp_exper=os.path.join(pathgcp_models, jobgcp_exper)

file_runjob=f'run_jobs_as_child_process_{Author}_{timestamp}_{Project_description}.sh'
pathsys_runfile=os.path.join(pathsys_runjobs,file_runjob)

generate_script_args={   
    'cloud_project':GCP_PROJECT,
    'cloud_region':GCP_LOCATION,
    'train_docker_image_uri_path':'gcr.io/disaster-assessment/ssl-train-uri',
    'eval_docker_image_uri_path':'gcr.io/disaster-assessment/ssl-eval-uri',
    'service_account':emailgcp_serviceaccount,
    'dataset_name':jobgcp_exper,
    'train_dir':f'gs://{pathgcp_exper}',
    'train_unlabel_examples':f'gs://{pathgcp_unlabeled}',
    'train_label_examples':f'gs://{pathgcp_trainset}',
    'test_examples':f'gs://{pathgcp_testset}',
    'display_name_train':jobgcp_train,
    #'method':ML_METHOD,
    #'unlabeled_ratio':HP_UNLABELEDRATIO,
    #'batch':HP_BATCH,
    #'train_kimg':HP_TRAINKIMG,
    'display_name_eval':jobgcp_eval,
    'python_env':pathsys_actenv,
    'path_skai':pathsys_skai,
    'path_run': pathsys_runfile,
}

write_train_and_eval_launch_script(**generate_script_args)
print(f"\nYour Custom Training job is :\n{jobgcp_train}")
print(f"\nYour Evaluation job is :\n{jobgcp_eval}\n")


# Create the progress bar and metrics displays.
progress_display = display(progress(0, 100), display_id=True)
metrics_display = None

# Store Job IDs of training and evaluation jobs.
# Keep track of the timestamp of most recent logs to process only fresher logs. 
train_job_id = None
eval_job_id = None
curr_epoch = None
total_num_epochs = None
train_most_recent_timestamp = pd.Timestamp.utcnow() 
eval_most_recent_timestamp = pd.Timestamp.utcnow() 


def update_job_id(job_id):
  global train_job_id
  global eval_job_id
  if train_job_id is None:
    train_job_id = job_id
    progress_display.update(progress(1, 100))
  elif eval_job_id is None:
    if job_id != train_job_id:
      eval_job_id = job_id 


# Run the child program.
child = pexpect.spawn(f'sh {pathsys_runfile}')
while not child.closed:
  # Expects 4 different patterns, or EOF (meaning the program terminated).
  # Each pattern is a regex and you can use regex match groups "()" to extract a
  # part of the matched text for later use.
  pattern_idx = child.expect([
    r'I.*] CustomJob created\. Resource name: .*/([0-9]*)',
    r'I.*] CustomJob .*/([0-9]*) current state:\r\nJobState.JOB_STATE_PENDING',
    r'I.*] CustomJob .*/([0-9]*) current state:\r\nJobState.JOB_STATE_RUNNING',
    r'I.*] CustomJob run completed\.',
    pexpect.EOF], timeout=None)
  if pattern_idx == 0:  # A job was created, so store its ID.
    job_id = child.match.group(1).decode()
    update_job_id(job_id)
  elif pattern_idx == 1:  # Jobs are pending, so update progress bar. 
    job_id = child.match.group(1).decode()
    if job_id == train_job_id:
      progress_display.update(progress(5, 100))
  elif pattern_idx == 2:  # Jobs are running, so update progress bar or metrics.
    job_id = child.match.group(1).decode()
    get_logs_status = os.system(f"""gcloud logging read 'resource.labels.job_id={job_id} severity=ERROR "Epoch"' --format json > /tmp/{job_id}_log""")
    if get_logs_status == 0:
      with open(f'/tmp/{job_id}_log', 'r') as log_file:
        log_data = json.load(log_file)    
        if job_id == train_job_id:
          # If training job, then update the progress bar.
          curr_epoch = None
          for log in log_data:
            log_timestamp = timestamp_to_datetime(log["timestamp"])
            if log_timestamp < train_most_recent_timestamp:
              # If logs have not been refreshed, ignore them.
              break
            else:
              train_most_recent_timestamp = log_timestamp
            if log_timestamp == train_most_recent_timestamp:
              matches = re.search('Epoch ([0-9]*/[0-9]*):   [0-9]*%', log['jsonPayload']['message'])
              if matches:
                matches = matches.groups()
                log_epoch = int(matches[-1].split('/')[0])
                if total_num_epochs is None:
                  total_num_epochs = int(matches[-1].split('/')[1])
                if curr_epoch is None or log_epoch > curr_epoch:
                  # Logs can be received at different times, so check for the
                  # highest epoch number logged.
                  curr_epoch = log_epoch
                  progress_display.update(progress(5 + int(95. * curr_epoch / (total_num_epochs + 1)), 100))
                  break
        else:
          # If evaluation job, then update the metrics display.
          train_label_acc, train_label_auc, test_acc, test_auc = None, None, None, None
          for log in log_data:
            log_timestamp = timestamp_to_datetime(log["timestamp"])
            if log_timestamp < eval_most_recent_timestamp:
              # If logs have not been refreshed, ignore them.
              break
            if train_label_acc is None:
              train_label_matches = re.search(r'Train_Label AUC: ([0-9]*\.[0-9]*), Train_Label Accuracy: ([0-9]*\.[0-9]*)', log['jsonPayload']['message'])
              if train_label_matches:
                train_label_auc = train_label_matches.groups()[0]
                train_label_acc = train_label_matches.groups()[1]
            elif test_acc is None:
              test_matches = re.search(r'Test AUC: ([0-9]*\.[0-9]*), Test Accuracy: ([0-9]*\.[0-9]*)', log['jsonPayload']['message'])
              if test_matches:
                test_auc = test_matches.groups()[0]
                test_acc = test_matches.groups()[1]
            else:
              eval_most_recent_timestamp = log_timestamp
              break
          if train_label_acc is not None and test_acc is not None:
            if metrics_display is None:
              # Show the parsed values right away instead of a row of zeros.
              metrics_display = display(metrics(train_label_acc, train_label_auc, test_acc, test_auc), display_id=True)
            else:
              metrics_display.update(metrics(train_label_acc, train_label_auc, test_acc, test_auc))          
  elif pattern_idx == 3:  # Job completed. Email user a notification.
    os.system(f"""printf 'Subject: Skai Training Complete\n\nTraining has completed! Please return to the Colab to visualize results.' | msmtp {EMAIL_MANAGER}""")
    progress_display.update(progress(100, 100))    
  else:
    child.close()
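
The monitoring loop above turns log messages into a progress percentage and a pair of test metrics. The core of that parsing can be sketched in isolation as shown below. The two sample log lines are invented for this example (real Vertex AI log messages may be formatted differently), and `progress_from_log` and `parse_test_metrics` are hypothetical helper names. The progress formula mirrors the one in the loop: 5% is reserved for job start-up, and the remaining 95% is spread over `total_epochs + 1` steps.

```python
import re

EPOCH_PATTERN = re.compile(r'Epoch ([0-9]+)/([0-9]+):')
TEST_PATTERN = re.compile(r'Test AUC: ([0-9]+\.[0-9]+), Test Accuracy: ([0-9]+\.[0-9]+)')

def progress_from_log(message, base=5, span=95):
    """Return a progress percentage for an epoch log line, or None if it does not match."""
    match = EPOCH_PATTERN.search(message)
    if match is None:
        return None
    curr_epoch, total_epochs = int(match.group(1)), int(match.group(2))
    return base + int(span * curr_epoch / (total_epochs + 1))

def parse_test_metrics(message):
    """Return (auc, accuracy) as floats, or None if the line holds no test metrics."""
    match = TEST_PATTERN.search(message)
    if match is None:
        return None
    return float(match.group(1)), float(match.group(2))

# Invented log lines, for illustration only.
print(progress_from_log('Epoch 32/128:  25%'))                       # 28
print(parse_test_metrics('Test AUC: 0.91, Test Accuracy: 84.50'))    # (0.91, 84.5)
```

Precompiled raw-string patterns avoid invalid-escape warnings for sequences such as `\.`, and returning `None` for non-matching lines lets the caller skip unrelated log entries.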

In [None]:
#@title View Results in Tensorboard

%load_ext tensorboard
%tensorboard --logdir gs://{pathgcp_exper}

## Inference prediction 🔮

Run the following script to use the trained model to create the damage assessment. When it finishes, you will see summary statistics for the disaster along with a map-based visualization of the damaged buildings.
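
The inference cell picks the model checkpoint by reading the last processed epoch number from Cloud Storage and zero-padding it to eight digits. That path construction can be sketched on its own as follows. `checkpoint_path` is a hypothetical helper and the bucket path is made up. Stripping whitespace before padding is a precaution added in this sketch, in case the epoch file ends with a newline.

```python
def checkpoint_path(experiment_dir, epoch_text):
    """Build the checkpoint URI for the epoch number stored as text."""
    # Strip any trailing newline, then left-pad with zeros to eight digits.
    epoch = epoch_text.strip().zfill(8)
    return f'gs://{experiment_dir}/checkpoints/model.ckpt-{epoch}'

# Made-up experiment directory and epoch file content.
print(checkpoint_path('my-bucket/models/demo_experiment', '12\n'))
# gs://my-bucket/models/demo_experiment/checkpoints/model.ckpt-00000012
```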

In [None]:
#@title Run Inference

# Add custom basemaps to folium.
basemaps = {
    'Google Maps': folium.TileLayer(
        tiles = 'https://mt1.google.com/vt/lyrs=m&x={x}&y={y}&z={z}',
        attr = 'Google',
        name = 'Google Maps',
        overlay = True,
        control = True
    ),
    'Google Satellite': folium.TileLayer(
        tiles = 'https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}',
        attr = 'Google',
        name = 'Google Satellite',
        overlay = True,
        control = True
    ),
    'Google Terrain': folium.TileLayer(
        tiles = 'https://mt1.google.com/vt/lyrs=p&x={x}&y={y}&z={z}',
        attr = 'Google',
        name = 'Google Terrain',
        overlay = True,
        control = True
    ),
    'Google Satellite Hybrid': folium.TileLayer(
        tiles = 'https://mt1.google.com/vt/lyrs=y&x={x}&y={y}&z={z}',
        attr = 'Google',
        name = 'Google Satellite Hybrid',
        overlay = True,
        control = True
    ),
    'Esri Satellite': folium.TileLayer(
        tiles = 'https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
        attr = 'Esri',
        name = 'Esri Satellite',
        overlay = True,
        control = True
    )
}

def run_generate_inference_script(**args):

  submission_ending = '''
export GOOGLE_APPLICATION_CREDENTIALS=/root/service-account-private-key.json

source {python_env} ; python {path_skai}/src/launch_vertex_job.py \\
  --location={cloud_region} \\
  --project={cloud_project} \\
  --job_type=eval \\
  --display_name={display_name_infer} \\
  --dataset_name={dataset_name} \\
  --eval_docker_image_uri_path={eval_docker_image_uri_path} \\
  --service_account={service_account} \\
  --train_dir={train_dir} \\
  --test_examples={test_examples} \\
  --eval_ckpt={eval_model_ckpt} \\
  --inference_mode=True \\
  --save_predictions=True'''.format(**args)

  with open(args['path_run'], 'w+') as file:
    file.write(submission_ending)

def create_folium_map(geojson_path):
  with open(geojson_path, 'r') as f:
    predictions = json.load(f)

  damaged_preds = {
      'type': predictions['type'],
      'features': []
  }
  undamaged_preds = {
      'type': predictions['type'],
      'features': []
  }
  
  # Count number of buildings per class.
  num_damaged_buildings = 0
  num_undamaged_buildings = 0
  for feat in predictions['features']:
    if feat['properties']['class_1'] >= 0.5:
      num_damaged_buildings += 1
      damaged_preds['features'].append(feat)
    else:
      num_undamaged_buildings += 1
      undamaged_preds['features'].append(feat)

  lat = predictions['features'][0]['properties']['latitude']
  lon = predictions['features'][0]['properties']['longitude']
  
  # Create a folium map object. Location is latitude, longitude.
  my_map = folium.Map(location=[lat, lon], zoom_start=16, max_zoom=20)

  # Add custom basemaps.
  basemaps['Google Maps'].add_to(my_map)
  basemaps['Google Satellite Hybrid'].add_to(my_map)

  after_image_path = 'gs://'+pathgcp_imageafter
  after_map_id_dict = ee.Image.loadGeoTIFF(after_image_path).getMapId()
  folium.raster_layers.TileLayer(
      tiles=after_map_id_dict['tile_fetcher'].url_format,
      attr='COG',
      name = 'Post-Disaster Imagery',
      overlay = True,
      control = True,
      max_zoom = 20,
    ).add_to(my_map)

  # Add predictions.
  folium.features.GeoJson(damaged_preds, name='Damaged Predictions', 
                          style_function=style_function,
                          marker=folium.CircleMarker(),
                          ).add_to(my_map)
  folium.features.GeoJson(undamaged_preds, name='Undamaged Predictions', 
                          style_function=style_function,
                          marker=folium.CircleMarker(),
                          ).add_to(my_map)                          

  my_map.add_child(folium.LayerControl())

  print('Number of Damaged Buildings: ', num_damaged_buildings)
  print('Number of Undamaged Buildings: ', num_undamaged_buildings)
  print('Total: ', int(num_undamaged_buildings) + int(num_damaged_buildings))
  IPython.display.display(my_map)

display(Javascript("google.colab.output.resizeIframeToContent()"))

# Identify epoch number of last checkpoint.
most_recent_epoch_file = os.path.join(f'gs://{pathgcp_exper}', 'checkpoints', 'last_processed_epoch')
os.system(f'gsutil cp {most_recent_epoch_file} /tmp/last_processed_epoch')
with open('/tmp/last_processed_epoch', 'r') as epoch_f:
  # Strip whitespace so a trailing newline does not end up in the checkpoint name.
  epoch_num = epoch_f.read().strip()
  epoch = epoch_num.zfill(8)

# Create inference script that will be run by child process.
timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
GCP_INFERENCE_NAME=f"{Author}_inference_{timestamp}_{Project_description}"

jobgcp_infer = GCP_INFERENCE_NAME + '_default'

file_runjob=f'run_inference_as_child_process_{Author}_{timestamp}_{Project_description}.sh'
pathsys_runfile=os.path.join(pathsys_runjobs,file_runjob)

generate_script_args={   
    'cloud_project':GCP_PROJECT,
    'cloud_region':"europe-west1",
    'eval_docker_image_uri_path':'gcr.io/disaster-assessment/ssl-eval-uri',
    'service_account':emailgcp_serviceaccount,
    'dataset_name':jobgcp_exper,
    'train_dir':'gs://'+pathgcp_exper,
    'test_examples':'gs://'+pathgcp_unlabeled,
    'display_name_infer':jobgcp_infer,
    'eval_model_ckpt': 'gs://'+pathgcp_exper+'/checkpoints/model.ckpt-'+epoch,
    'python_env':pathsys_actenv,
    'path_skai':pathsys_skai,
    'path_run': pathsys_runfile,
    }

run_generate_inference_script(**generate_script_args)


# Prepare credentials for map visualization.
service_account = 'skai-colab@skai-2022.iam.gserviceaccount.com'
credentials = ee.ServiceAccountCredentials(
    service_account, '/root/service-account-private-key.json')
ee.Initialize(credentials)
# Set style parameters.
style_function = lambda x: {
  'radius': 10,
  'weight': 1,
  'fill': True,
  'color': '#ff0000' if float(x['properties']['class_1']) >= 0.5 else '#00ff00',
  'fillColor': '#ff0000' if float(x['properties']['class_1']) >= 0.5 else '#00ff00',
  'fillOpacity': 0.3
}

# Initialize progress bar.
progress_display = display(progress(0, 100), display_id=True)
curr_idx = 0
map = None

# Run the child program.
child = pexpect.spawn(f'sh {pathsys_runfile}')
while not child.closed:
  # Expects 4 different patterns, or EOF (meaning the program terminated).
  # Each pattern is a regex and you can use regex match groups "()" to extract a
  # part of the matched text for later use.
  pattern_idx = child.expect([
    r'CustomJob created\.',
    r'JobState\.JOB_STATE_PENDING\r\n',
    r'JobState\.JOB_STATE_RUNNING\r\n',
    r'CustomJob run completed\.',
    pexpect.EOF], timeout=None)
  if pattern_idx == 0:  # A job was created.
    progress_display.update(progress(5, 100))
  elif pattern_idx == 1:  # Job Pending.
    progress_display.update(progress(10, 100))
  elif pattern_idx == 2:  # Job Running.
    starting_progress = 20
    max_progress = 90
    curr_progress = starting_progress + (curr_idx * 2)
    if curr_idx == 0:
      progress_display.update(progress(starting_progress, 100))
    elif curr_progress < max_progress:
      # Update while job is running only until progress hits 90.
      progress_display.update(progress(curr_progress, 100))
    curr_idx += 1
  elif pattern_idx == 3:  # Job Completed.
    progress_display.update(progress(100, 100))
  else:
    child.close()

preds_file = os.path.join(f'gs://{pathgcp_exper}', 'predictions', f'test_ckpt_{epoch_num}.geojson')
os.system(f'gsutil cp {preds_file} /tmp/predictions.geojson')
create_folium_map('/tmp/predictions.geojson')
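
The summary statistics printed by `create_folium_map` come from thresholding the `class_1` score of each predicted building at 0.5. The same counting step can be sketched without any map dependencies. `split_predictions` is a hypothetical helper, and the three features below are invented sample data that only mimic the `class_1` property used above.

```python
def split_predictions(feature_collection, threshold=0.5):
    """Split GeoJSON features into (damaged, undamaged) lists by their class_1 score."""
    damaged, undamaged = [], []
    for feature in feature_collection['features']:
        score = float(feature['properties']['class_1'])
        if score >= threshold:
            damaged.append(feature)
        else:
            undamaged.append(feature)
    return damaged, undamaged

# Invented predictions, for illustration only.
sample = {
    'type': 'FeatureCollection',
    'features': [
        {'type': 'Feature', 'properties': {'class_1': 0.91}},
        {'type': 'Feature', 'properties': {'class_1': 0.12}},
        {'type': 'Feature', 'properties': {'class_1': 0.50}},
    ],
}
damaged, undamaged = split_predictions(sample)
print('Damaged:', len(damaged), 'Undamaged:', len(undamaged))
```

Exposing the threshold as a parameter makes it easy to see how sensitive the damage count is to the cut-off, for instance by comparing the results at 0.5 and 0.7.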
