In [1]:
import pathlib
import os
# Move the working directory up one level (presumably to the project root so that
# the `dnn` package imported below resolves — confirm against the repo layout).
# NOTE(review): this is not idempotent — re-running the cell climbs one more
# directory level each time, so only run it once per kernel session.
os.chdir(pathlib.Path(os.getcwd()).parent)
from dnn.model_pipeline import UNetPipeline, get_masked_daily_product
from dnn.model_pipeline import ndvi, get_cdl, isin, crops_list
import descarteslabs as dl
import numpy as np
from descarteslabs.client.services import Places
import matplotlib as mpl
import matplotlib.patches as patches
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
import shapely
import shapely.ops
import shapely.prepared
import rasterio.features
import ipyleaflet
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline

  self, resource_name
  self, resource_name


### Generate the Tiled Image Collection and Get Rid of Invalid Tiles

In [2]:
# Build the Central Valley area of interest as the union of the Sacramento
# Valley and San Joaquin Valley place shapes, then cut it into DLTiles.
sac_place = dl.places.shape("north-america_united-states_california_sacramento-valley")
sj_place = dl.places.shape("north-america_united-states_california_san-joaquin-valley")

# Convert the place geometries into shapely shapes so they can be unioned.
sac = shapely.geometry.shape(sac_place.geometry)
sj = shapely.geometry.shape(sj_place.geometry)
central_valley_aoi = sac.union(sj)

# Tile the AOI; tilesize=64 matches the 64x64 model input configured further down.
tiles = dl.scenes.DLTile.from_shape(
    central_valley_aoi,
    resolution=20,
    tilesize=64,
    pad=0,
)
print(f'Number of tiles: {len(tiles)}')

Number of tiles: 62635


In [3]:
# Create the image collections used for training: Landsat 8 imagery (inputs)
# and the CDL-derived crop mask (labels / validity).
# Both collections share one date window so they cannot drift apart.
start_datetime = "2017-01-01"
end_datetime = "2020-01-01"

# Landsat 8 image collection: five spectral bands plus NDVI appended as an extra band.
l8_daily = get_masked_daily_product(
    "landsat:LC08:01:T1:TOAR", start_datetime, end_datetime
).pick_bands("red green blue nir swir1")
l8_daily = l8_daily.concat_bands(ndvi(l8_daily))

# CDL image collection over the same window as the Landsat collection
# (previously the date literals were repeated here instead of reusing the variables).
cdl = get_cdl(start_date=start_datetime, end_date=end_datetime)

# Per-image flag: does the pixel belong to one of the crops of interest?
is_crops = isin(cdl, crops_list)

# The last image in the collection is added a second time below so that it
# counts double (presumably the 2019 layer, given the date window — confirm).
is_crops_19 = is_crops[-1]
four_year_combo = is_crops.sum(axis="images") + is_crops_19  # double-weight 2019

# A pixel counts as cropland when its weighted count reaches 2.
# NOTE(review): the "four_year" names are kept because later cells reference
# them, but the 2017-01-01..2020-01-01 window looks like three annual layers
# plus the double-counted last one — confirm.
four_year_binary = four_year_combo >= 2
cdl_mask = ~four_year_binary

In [4]:
# Find the valid tiles: rasterize the cropland flag over the whole AOI at a
# coarse 2048x2048 grid, vectorize the flagged pixels into polygons, and keep
# only the tiles that intersect those polygons.
central_valley_ctx = dl.scenes.AOI(central_valley_aoi, shape=(2048, 2048), crs="EPSG:4326")
all_cdl = four_year_binary.compute(central_valley_ctx)

# Affine transform mapping raster pixel coordinates to the result's georeferenced coordinates.
cdl_transform = rasterio.transform.Affine.from_gdal(*all_cdl.geocontext["gdal_geotrans"])

# Keep only the polygons traced around pixels whose value is 1 (flagged as cropland).
shapes = [
    geom
    for geom, value in rasterio.features.shapes(
        all_cdl.ndarray.astype("uint8"),
        transform=cdl_transform,
    )
    if value == 1
]
print(f"length of shapes: {len(shapes)}")

# Merge all polygons into one geometry and simplify it to keep intersection tests cheap.
# NOTE(review): the 0.3 tolerance is in the geometry's coordinate units
# (EPSG:4326 was requested above, so presumably degrees), which is coarse —
# confirm this is intended.
all_valid = shapely.ops.unary_union([shapely.geometry.shape(s) for s in shapes]).simplify(0.3)
print(f'Type of all_valid: {type(all_valid)}')

# A prepared geometry speeds up the repeated intersects() calls against every tile.
all_valid_prepped = shapely.prepared.prep(all_valid)
valid_tiles = [t for t in tqdm(tiles) if all_valid_prepped.intersects(t.geometry)]
print(f'No. Valid Tiles: {len(valid_tiles)}')
print(f'Percentage of valid tiles: {100*(len(valid_tiles) / len(tiles))}')


Job ID: 063013cfb98bfbeed5e6ce21f35719a741551977e37933c6
[######] | Steps: 10/10 | Stage: SUCCEEDED                                    
Job ID: 063013cfb98bfbeed5e6ce21f35719a741551977e37933c6
length of shapes: 4957
Type of all_valid: <class 'shapely.geometry.multipolygon.MultiPolygon'>


HBox(children=(FloatProgress(value=0.0, max=62635.0), HTML(value='')))


No. Valid Tiles: 19904
Percentage of valid tiles: 31.777760038317233


## Train the U-Net Model

In [5]:
# Class weights: weight 1 for the first class and 5 for every crop class.
# (The first class is presumably background / non-crop — confirm.)
# The length is derived from crops_list so it always agrees with 'nclasses'
# below (previously the crop count was hard-coded as 48).
class_weight = [1] + [5] * len(crops_list)
unet_params = {
    'img_height': 64,   # same value as the tilesize used when tiling the AOI
    'img_width': 64,
    'bands': 6,         # presumably the 5 picked Landsat bands + NDVI — confirm
    'time_steps': 12,
    'nclasses': len(crops_list) + 1,  # one class per crop plus one extra class
    'class_weights': class_weight,
    'learning_rate': 5e-3
}

# Reproducible 85% / 15% train/test split over the indices of the valid tiles.
random_seed = 2020
L = len(valid_tiles)
test_split = np.random.RandomState(random_seed)
tr_ix = test_split.choice(np.arange(L), int(0.85 * L), replace=False)
# Test indices are every index not drawn for training, in ascending order.
# np.setdiff1d replaces a per-element `k not in tr_ix` scan, which was quadratic.
tst_ix = np.setdiff1d(np.arange(L), tr_ix)

u_net = UNetPipeline(model_params=unet_params,
                     tiles=np.asanyarray(valid_tiles),
                     img_prod_id="landsat:LC08:01:T1:TOAR",
                     train_ix=tr_ix,
                     test_ix=tst_ix,
                     random_seed=random_seed
                     )

Model: "Unet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
image_input (InputLayer)        [(None, 12, 64, 64,  0                                            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 12, 64, 64, 6 24          image_input[0][0]                
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, 12, 64, 64, 1 2608        batch_normalization[0][0]        
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 12, 64, 64, 1 64          conv3d[0][0]                     
_______________________________________________________________________________________________

In [None]:
# Train the U-Net.
# NOTE(review): the three positional arguments are not named here. The logged
# output ("Epoch 1 / 2", a batch counter that advances by 15, 2115 batches in
# total) suggests they are (batch size = 8, epochs = 2, log every 15 batches)
# — confirm against UNetPipeline.train_model and consider passing keywords.
u_net.train_model(8, 2, 15)

Starting Training ...
Year 2017
Epoch 1 / 2: 
Batch 1 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 16 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 31 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 46 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 61 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 76 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 91 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 106 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 121 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 136 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 151 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 166 out of 2115
################################
Year 2017
Epoch 1 / 2: 
Batch 181 out 

In [None]:
#u_net.save_model(str(pathlib.Path(os.getcwd()) / 'saved_models' / 'model1.h5'))
#u_net.load_model(str(pathlib.Path(os.getcwd()) / 'saved_models' / 'model0.h5'))

#### Load some test data for validation

In [7]:
#test_data = u_net.data_loader(np.arange(5), '2019')

In [8]:
#preds = u_net.model.predict(test_data)

In [15]:
#import pandas as pd
#pd.Series(preds[0].argmax(1)).value_counts()

11    4038
32      56
12       2
dtype: int64