In [11]:
import os
import warnings
from pprint import pprint
import descarteslabs as dl

## Wrapper function to generate tiles over Africa, search for imagery, and run classification over each tile. 

In [14]:
def run_classification():
    """Train a random forest on labelled pixels and classify Landsat 8 tiles over Africa.

    The function is self-contained (all imports are local) because it is handed
    to ``AsyncTasks.create_function`` and executed remotely rather than in the
    notebook kernel.

    Steps:
      1. Rasterize the training vectors onto the grid of the training image
         and fit a ``RandomForestClassifier`` on the labelled pixels.
      2. Generate DLTiles covering a bounding box around Africa.
      3. For every tile, search June 2016 ``landsat:LC08:PRE:TOAR`` scenes,
         fetch the ``nir``/``swir1``/``red``/``alpha`` mosaic and predict a
         class for every pixel. Tiles without any matching scene are skipped.
      4. Trim the tile padding, write the result as a GeoTIFF georeferenced
         with the tile's own metadata, and upload it to the ``Africa_Urban``
         catalog product.

    NOTE(review): ``clipped_arr`` and ``clipped_meta`` are not defined in this
    function. They are presumably the training image and its raster metadata
    left in the notebook namespace by an earlier cell -- confirm they exist
    (and are captured) when the function is serialized.

    Returns:
        None. Results are written to local GeoTIFF files and uploaded.
    """
    import json

    import descarteslabs as dl
    import numpy as np
    from osgeo import gdal
    from sklearn.ensemble import RandomForestClassifier
    from descarteslabs.client.services import Catalog
    from descarteslabs.client.services import Storage

    # Shared settings, hoisted so tiling, raster retrieval and padding removal
    # cannot drift apart.
    resolution = 351.263936238588940
    tile_pad = 16
    product_id = '7294028cc01114d89a473cf055d29dc5cd5ffe88:Africa_Urban'

    # *************** TRAIN MODEL ***************
    def create_mask_from_vector(vector_data_path, cols, rows, geo_transform,
                                projection, target_value=1):
        """Rasterize the given vector (wrapper for gdal.RasterizeLayer).

        Returns an in-memory GDAL dataset of size ``cols`` x ``rows`` in which
        pixels covered by the first layer of the vector are burned with
        ``target_value``.
        """
        data_source = gdal.OpenEx(vector_data_path, gdal.OF_VECTOR)
        if data_source is None:
            # Without GDAL exceptions enabled a failed open returns None and
            # the next line would fail with an opaque AttributeError.
            raise IOError("Could not open vector data: {}".format(vector_data_path))
        layer = data_source.GetLayer(0)
        driver = gdal.GetDriverByName('MEM')  # In memory dataset
        target_ds = driver.Create('', cols, rows, 1, gdal.GDT_UInt16)
        target_ds.SetGeoTransform(geo_transform)
        target_ds.SetProjection(projection)
        gdal.RasterizeLayer(target_ds, [1], layer, burn_values=[target_value])
        return target_ds

    def vectors_to_raster(file_paths, rows, cols, geo_transform, projection):
        """Rasterize the given vectors into a single label image.

        The i-th path (0-based) is burned with label ``i + 1``; pixels not
        covered by any vector stay 0.
        """
        labeled_pixels = np.zeros((rows, cols))
        for i, path in enumerate(file_paths):
            label = i + 1
            ds = create_mask_from_vector(path, cols, rows, geo_transform,
                                         projection, target_value=label)
            band = ds.GetRasterBand(1)
            labeled_pixels += band.ReadAsArray()
            ds = None  # Release the in-memory dataset
        return labeled_pixels

    def write_geotiff(fname, data, geo_transform, projection):
        """Create a single-band Byte GeoTIFF file with the given 2-D data."""
        driver = gdal.GetDriverByName('GTiff')
        rows, cols = data.shape
        dataset = driver.Create(fname, cols, rows, 1, gdal.GDT_Byte)
        dataset.SetGeoTransform(geo_transform)
        dataset.SetProjection(projection)
        band = dataset.GetRasterBand(1)
        band.WriteArray(data)
        band = None
        dataset = None  # Close the file

    # Grid of the training image. These must be known *before* rasterizing the
    # training vectors.
    training_rows, training_cols, _ = clipped_arr.shape
    training_geo_transform = clipped_meta['geoTransform']
    training_proj = 'PROJCS["WGS 84 / UTM zone 36N",GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]],PROJECTION["Transverse_Mercator"],PARAMETER["latitude_of_origin",0],PARAMETER["central_meridian",33],PARAMETER["scale_factor",0.9996],PARAMETER["false_easting",500000],PARAMETER["false_northing",0],UNIT["metre",1,AUTHORITY["EPSG","9001"]],AXIS["Easting",EAST],AXIS["Northing",NORTH],AUTHORITY["EPSG","32636"]]'

    # NOTE(review): the stored value is iterated as a sequence of vector file
    # paths below -- confirm that is what this Storage key actually holds.
    shapefiles = Storage().get('urban_training_shapefiles')

    labeled_pixels = vectors_to_raster(shapefiles, training_rows, training_cols,
                                       training_geo_transform, training_proj)
    is_train = np.nonzero(labeled_pixels)
    training_labels = labeled_pixels[is_train]
    training_samples = clipped_arr[is_train]

    # Fixed seed so repeated runs build the same forest.
    classifier = RandomForestClassifier(n_jobs=-1, random_state=0)
    classifier.fit(training_samples, training_labels)

    # ************* GENERATE TILES OVER AFRICA ***************
    # GeoJSON linear rings must be closed: the last position repeats the first.
    africa_bounding_box = {
        'type': 'Polygon',
        'coordinates': [[[-25.360422, -34.821954],
                         [-25.360422, 37.345201],
                         [51.417038, 37.345201],
                         [51.417038, -34.821954],
                         [-25.360422, -34.821954]]],
    }

    tiles = dl.raster.dltiles_from_shape(
        resolution=resolution,
        tilesize=2048,
        pad=tile_pad,
        shape=africa_bounding_box)

    print("Total number of tiles for Africa: " + str(len(tiles['features'])))

    # ************* REGISTER OUTPUT PRODUCT ***************
    # Created before the loop so each tile can be uploaded as soon as it is
    # classified instead of holding every tile in memory.
    catalog = Catalog()
    catalog.add_product('Africa_Urban',
                        title='Africa_Urban',
                        description='Africa areas identified using the random forest classification.'
                        )
    # NOTE(review): the band is declared as 64-bit Float64 while write_geotiff
    # creates GDT_Byte files -- confirm which one is intended.
    catalog.add_band(product_id=product_id, name='urban', srcband=1, nbits=64,
                     dtype='Float64', type='class', data_range=[1.000, 2.000],
                     colormap_name='magma')

    # ************* ITERATE OVER EACH TILE, FIND IMAGERY AND CLASSIFY ***************
    date = ['2016-06-01', '2016-06-30']

    for tile_index, tile in enumerate(tiles['features']):
        # ************* SEARCH CODE ***************
        images = dl.metadata.search(
            products=["landsat:LC08:PRE:TOAR"],
            start_time=date[0],
            end_time=date[1],
            geom=json.dumps(tile['geometry']),
            cloud_fraction=0.2,
            limit=1000
        )

        ids = [image['id'] for image in images['features']]
        if not ids:
            # Nothing to mosaic for this tile (e.g. open ocean or persistent
            # cloud cover), so there is nothing to classify.
            continue

        arr, meta = dl.raster.ndarray(
            ids,
            bands=['nir', 'swir1', 'red', 'alpha'],
            scales=[[0, 6000], [0, 6000], [0, 6000], None],
            data_type='Byte',
            srs=tile['properties']['cs_code'],
            resolution=resolution,
            bounds=tile['properties']['outputBounds'],
            cutline=africa_bounding_box)

        # ************* CLASSIFICATION CODE ***************
        # Use the shape of the tile that was just fetched, not the shape of
        # the training image.
        rows, cols, n_bands = arr.shape
        flat_pixels = arr.reshape((rows * cols, n_bands))
        result = classifier.predict(flat_pixels)
        classification = result.reshape((rows, cols))

        # Drop the padding that was requested around every tile.
        classification = classification[tile_pad:-tile_pad, tile_pad:-tile_pad]

        # ************* SAVE TILE CLASSIFICATION TO PLATFORM ***************
        # Shift the raster origin by the trimmed padding so the written file
        # stays aligned with the tile's pixels.
        gt = meta['geoTransform']
        trimmed_geo_transform = (
            gt[0] + tile_pad * gt[1] + tile_pad * gt[2], gt[1], gt[2],
            gt[3] + tile_pad * gt[4] + tile_pad * gt[5], gt[4], gt[5],
        )
        output_fname = "urban_classification_africa_{:05d}.tiff".format(tile_index)
        write_geotiff(output_fname, classification, trimmed_geo_transform,
                      meta['coordinateSystem']['wkt'])

        catalog.upload_image(output_fname,
                             product_id,
                             acquired='2018-04-07')

## Serialize and run task on Descartes Labs Cloud

In [13]:
from descarteslabs.client.services.tasks import AsyncTasks, as_completed

In [16]:
# Create the tasks client and register `run_classification` with it, passing
# the task name and the container image the function should run in. The
# returned object is called in a later cell to launch the work.
# NOTE(review): the name mentions Zimbabwe although `run_classification`
# tiles an Africa-wide bounding box -- confirm the label is intended.
at = AsyncTasks()
async_func = at.create_function(
    run_classification,
    name='Generate_Tiles_Zimbabwe',
    image="us.gcr.io/dl-ci-cd/images/tasks/public/geospatial/geospatial-public:latest"
)

In [10]:
# Call the registered function with no arguments and keep the result in
# `task` (presumably a handle to the submitted cloud task -- confirm against
# the AsyncTasks documentation).
task = async_func()