In [1]:
# auto reload
%load_ext autoreload
%autoreload 2

import os
import shutil
from concurrent.futures import ProcessPoolExecutor
from datetime import datetime, timedelta

import imageio
import numpy as np
import utm
from tqdm.auto import tqdm

from eolearn.core import FeatureType, LinearWorkflow, EOExecutor, EOTask, SaveTask, LoadTask, OverwritePermission
from eolearn.mask import AddValidDataMaskTask
from eolearn.io import SentinelHubInputTask
from eolearn.coregistration import ThunderRegistration

from sentinelhub import DataSource, BBox, CRS
# NOTE(review): the plotting cells below use `plt` (matplotlib.pyplot), which is not imported here.

  from collections import Iterable
  args = inspect.getargspec(func)


# Task definitions

In [2]:
class ValidData:
    """Callable that builds the per-pixel valid-data mask for AddValidDataMaskTask.

    A pixel is valid when it is set in the `is_data_mask` MASK feature and is NOT
    set in the `clm_mask` MASK feature (CLM — presumably Sentinel Hub's cloud mask).
    """

    def __init__(self, is_data_mask, clm_mask):
        # names of the MASK features looked up in eopatch.mask
        self.is_data_mask = is_data_mask
        self.clm_mask = clm_mask

    def __call__(self, eopatch):
        """Return a boolean array: data present AND not flagged in the CLM mask."""
        # builtin `bool` instead of the `np.bool` alias, which NumPy 1.24 removed
        has_data = eopatch.mask[self.is_data_mask].astype(bool)
        flagged = eopatch.mask[self.clm_mask].astype(bool)
        return np.logical_and(has_data, np.logical_not(flagged))
    

class ValidCoverageTask(EOTask):
    """Compute, per time frame, the fraction of pixels equal to 1 in a mask feature.

    Args:
        feature_in: mask feature to read, e.g. (FeatureType.MASK, 'VALID_DATA').
        feature_out: feature receiving the coverage, e.g. (FeatureType.SCALAR, 'VALID_COVERAGE').
    """

    def __init__(self, feature_in, feature_out):
        self.feature_in = feature_in
        self.feature_out = feature_out

    def execute(self, eopatch):
        """Write the per-frame coverage into `feature_out` and return the EOPatch.

        Axes 1 and 2 are treated as the spatial axes. Assuming the usual eolearn mask
        layout (t, h, w, d), the result has shape (t, d). Each entry is the fraction of
        the h*w pixels of that frame/channel that equal 1.
        """
        mask = eopatch[self.feature_in]
        # Averaging over the spatial axes normalises by h*w only. The previous
        # denominator, prod(shape[1:]), also included the channel axis and so
        # under-reported coverage by a factor d when d > 1 (identical for d == 1).
        coverage = np.mean(mask == 1, axis=(1, 2))
        eopatch[self.feature_out] = coverage
        return eopatch

In [3]:
# Download task: true-colour (RGB = B04, B03, B02) Sentinel-2 L2A bands for plotting,
# plus the auxiliary layers used for masking and for the solar-angle graphs.
# maxcc=1.0 presumably disables scene filtering by cloud cover.
download_task = SentinelHubInputTask(
    bands_feature=(FeatureType.DATA, 'RGB'),
    bands=['B04', 'B03', 'B02'],
    resolution=10,
    maxcc=1.0,
    time_difference=timedelta(hours=2),
    data_source=DataSource.SENTINEL2_L2A,
    max_threads=10,
    additional_data=[
        (FeatureType.MASK, 'dataMask'),
        (FeatureType.MASK, 'CLM'),
        (FeatureType.DATA, 'CLP'),
        (FeatureType.DATA, 'sunZenithAngles'),
        (FeatureType.DATA, 'sunAzimuthAngles'),
    ],
)

# Valid-data mask task: a pixel is valid when it is in dataMask and not in CLM.
valid_data_task = AddValidDataMaskTask(
    ValidData(is_data_mask='dataMask', clm_mask='CLM'),
    'VALID_DATA',
)

# Coverage task: per-frame fraction of valid pixels, stored as a SCALAR feature.
valid_cov_task = ValidCoverageTask(
    feature_in=(FeatureType.MASK, 'VALID_DATA'),
    feature_out=(FeatureType.SCALAR, 'VALID_COVERAGE'),
)

# Co-registration of the time frames on RGB channel 0 (B04), guided by the valid-data mask.
coreg_task = ThunderRegistration(
    (FeatureType.DATA, 'RGB'),
    valid_mask_feature=(FeatureType.MASK, 'VALID_DATA'),
    channel=0,
)

# Persistence: both tasks work on the ./eopatches folder.
load_task = LoadTask('./eopatches', lazy_loading=True)
save_task = SaveTask(
    './eopatches',
    overwrite_permission=OverwritePermission.OVERWRITE_PATCH,
)

# Download

In [7]:
# Download, mask, co-register and save one EOPatch per location.

# Start from an empty output folder (portable replacement for `rm -rf eopatches && mkdir eopatches`).
shutil.rmtree('eopatches', ignore_errors=True)
os.makedirs('eopatches', exist_ok=True)

workflow = LinearWorkflow(
    download_task,
    valid_data_task,
    valid_cov_task,
    coreg_task,
    save_task,
)

# [(lat, lon), half-width in meters] — the bbox extends half-width on each side of the centre.
location_data = [
    [[30.962476, 34.730068], 1000],  # ashalim
    [[25.197020, 55.274212], 1250],  # burj khalifa
    [[29.979221, 31.134213], 1000],  # pyramids
    [[35.710054, 139.810714], 1250],  # tokyo sky tree
]

time_interval = [datetime(2019, 1, 1), datetime(2019, 12, 31)]

bbox_list = []
for (lat, lon), half_width in location_data:
    x, y, zone, letter = utm.from_latlon(lat, lon)
    # round the UTM centre to the nearest 10 (m)
    x, y = np.round([x, y], -1)
    # zone letters 'N' and above map to the northern UTM CRS, the rest to the southern one
    hemisphere = 'N' if letter >= 'N' else 'S'
    # look the enum member up by name instead of eval()-ing a code string
    crs = getattr(CRS, f'UTM_{zone}{hemisphere}')
    bbox_list.append(BBox((x - half_width, y - half_width, x + half_width, y + half_width), crs))

# one workflow run per location, saved as eopatches/eopatch_<index>
execution_args = [
    {
        download_task: {'bbox': bbox, 'time_interval': time_interval},
        save_task: {'eopatch_folder': f'eopatch_{idx}'},
    }
    for idx, bbox in enumerate(bbox_list)
]

executor = EOExecutor(workflow, execution_args, save_logs=True)
executor.run(workers=8, multiprocess=False)

executor.make_report()

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




# Load

In [14]:
# Load the EOPatches written by the download cell: one LoadTask run per
# location, reading eopatch_0 ... eopatch_{n-1} (load_task uses lazy_loading=True).
workflow = LinearWorkflow(load_task)

execution_args = [
    {load_task: {'eopatch_folder': f'eopatch_{i}'}}
    for i in range(len(bbox_list))
]

eopatches = [workflow.execute(run_args).eopatch() for run_args in execution_args]

# Create animation

In [25]:
# Create one true-colour GIF per location from the frames with enough valid coverage.

# Fresh scratch folder for the per-frame PNGs (portable `rm -rf graphs && mkdir graphs`).
shutil.rmtree('graphs', ignore_errors=True)
os.makedirs('graphs', exist_ok=True)
# The GIFs go to ./figs, which nothing else in the notebook creates.
os.makedirs('figs', exist_ok=True)

# Per-location multiplier applied to the RGB values before clipping to [0, 1].
factors = [2.0, 2.5, 2.5, 2.75]
# Minimum VALID_COVERAGE for a frame to be included in the animation.
th = 0.9


def plot_image(idx, f):
    """Render frame `idx` of the global `eop` to graphs/true_color_{name}_{idx}.png.

    Reads the globals `eop` and `name` set by the loop below. Worker processes see them
    only if they inherit the parent's globals (e.g. the 'fork' start method) — NOTE(review).
    NOTE(review): `plt` is not imported in this notebook's import cell — confirm it is in scope.

    Args:
        idx: time index into eop.data['RGB'].
        f: brightness multiplier; the image is multiplied by f and clipped to [0, 1].
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(np.clip(eop.data['RGB'][idx] * f, 0, 1))
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('off')
    fig.savefig(f'graphs/true_color_{name}_{idx}.png', dpi=50, bbox_inches='tight')
    plt.close(fig)


for idx, eop in enumerate(tqdm(eopatches)):
    f = factors[idx]
    name = f'loc{idx}'
    n_frames = len(eop.timestamp)

    # Render all frames in parallel; `eop` and `name` must be set (above) before the pool starts.
    with ProcessPoolExecutor(max_workers=8) as executor:
        _ = list(tqdm(executor.map(plot_image, range(n_frames), [f] * n_frames),
                      total=n_frames, leave=False))

    # Only frames whose valid-pixel coverage exceeds `th` go into the GIF.
    valid_frames = np.ravel(eop.scalar['VALID_COVERAGE']) > th
    n_valid = np.count_nonzero(valid_frames)
    if n_valid == 0:
        # Avoid a ZeroDivisionError in the frame duration and an empty GIF.
        tqdm.write(f'{name}: no frame with VALID_COVERAGE > {th}, skipping GIF')
        continue

    # per-frame duration shrinks with the number of frames, so every GIF has the same total length
    with imageio.get_writer(f'figs/true_color_{name}.gif', mode='I', duration=2 / n_valid) as writer:
        for i in np.flatnonzero(valid_frames):
            writer.append_data(imageio.imread(f'graphs/true_color_{name}_{i}.png'))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=73.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=72.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=72.0), HTML(value='')))




# Create graph animation

In [26]:
# define function for plotting time frame


def _last_valid_frame(valid_mask, idx):
    """Return the index of the latest valid frame at or before `idx`.

    Falls back to the last valid frame of the whole series when none precedes
    `idx`, and returns None when the series has no valid frame at all.
    """
    valid_so_far = np.flatnonzero(valid_mask[:idx + 1])
    if valid_so_far.size:
        return int(valid_so_far[-1])
    valid_all = np.flatnonzero(valid_mask)
    return int(valid_all[-1]) if valid_all.size else None


def plot_graph(idx, f, name):
    """Render frame `idx` as a 2x2 panel saved to graphs/graph_{name}_{idx}.png.

    Panels:
        (0,0) mean zenith vs mean azimuth;
        (0,1) zenith over time;
        (1,0) azimuth over time, with time on the vertical axis;
        (1,1) the most recent valid RGB frame.
    The trajectory up to `idx` is drawn in red, the current frame as a red dot,
    and valid frames before `idx` as black crosses.

    Reads the globals `eop` (current EOPatch) and `th` (coverage threshold).
    NOTE(review): `plt` is not imported in this notebook's import cell — confirm it is in scope.

    Args:
        idx: time index of the frame to highlight.
        f: brightness multiplier applied to the RGB image before clipping to [0, 1].
        name: location tag used in the output file name.
    """
    # Spatial mean of the first band of each solar-angle feature, one value per frame.
    x = np.mean(eop.data['sunAzimuthAngles'][..., 0], axis=(1, 2))  # azimuth
    y = np.mean(eop.data['sunZenithAngles'][..., 0], axis=(1, 2))  # zenith
    valid_mask = np.ravel(eop.scalar['VALID_COVERAGE']) > th
    n_frames = len(eop.timestamp)

    # Equinox/solstice ticks; the year is hard-coded to the 2019 download interval.
    transition_dates = [datetime(2019, 3, 20), datetime(2019, 6, 21), datetime(2019, 9, 23), datetime(2019, 12, 22)]
    transition_doys = [(day - datetime(2019, 1, 1)).days for day in transition_dates]
    transition_labels = [day.strftime('%Y-%m') for day in transition_dates]

    # Frame indices are mapped linearly onto days of the year (assumes roughly uniform sampling).
    def idx_to_days(frame):
        return frame * 365 / n_frames

    def days_to_idx(days):
        return days * n_frames / 365

    frames = np.arange(len(y))
    # valid frames strictly before idx (idx itself gets the red dot)
    plot_ids = np.flatnonzero(valid_mask[:idx])

    fig, axs = plt.subplots(2, 2, figsize=(10, 10))

    ax1 = axs[0, 0]
    ax1.plot(x, y)
    ax1.plot(x[:idx + 1], y[:idx + 1], 'r')
    ax1.plot(x[plot_ids], y[plot_ids], 'kx')
    ax1.plot(x[idx], y[idx], 'r', marker='o')
    # x holds the azimuth and y the zenith (the labels used to be swapped)
    ax1.set_xlabel('Azimuth [°]')
    ax1.set_ylabel('Zenith [°]')

    ax2 = axs[0, 1]
    ax2.plot(y)
    ax2.plot(frames[:idx + 1], y[:idx + 1], 'r')
    ax2.plot(plot_ids, y[plot_ids], 'kx')
    ax2.plot(idx, y[idx], 'r', marker='o')
    ax2.set_xlabel('Time')
    ax2.set_xticks([days_to_idx(doy) for doy in transition_doys])
    ax2.set_xticklabels(transition_labels, ha='right')
    ax2.set_yticks([])

    ax3 = axs[1, 0]
    ax3.plot(x, frames)
    ax3.plot(x[:idx + 1], frames[:idx + 1], 'r')
    ax3.plot(x[plot_ids], plot_ids, 'kx')
    ax3.plot(x[idx], idx, 'r', marker='o')
    ax3.set_xticks([])
    ax3.set_yticks([])

    sax3 = ax3.secondary_yaxis('right', functions=(idx_to_days, days_to_idx))
    sax3.set_ylabel('Time')
    sax3.set_yticks(transition_doys)
    sax3.set_yticklabels(transition_labels, rotation=90, va='top')

    # Dashed guides through the current point. clip_on=False together with ymin=-1.2 /
    # xmax=1.2 lets the ax1 guides extend past ax1 into the neighbouring panels.
    guide = dict(c="gray", linewidth=1, linestyle='dashed', zorder=0, clip_on=False)
    ax1.axvline(x=x[idx], ymin=-1.2, ymax=1, **guide)
    ax3.axvline(x=x[idx], ymin=0, ymax=1, **guide)
    ax1.axhline(y=y[idx], xmin=0, xmax=1.2, **guide)
    ax2.axhline(y=y[idx], xmin=0, xmax=1, **guide)

    # most recent valid RGB frame; left blank when the series has no valid frame
    ax4 = axs[1, 1]
    last_valid = _last_valid_frame(valid_mask, idx)
    if last_valid is not None:
        ax4.imshow(np.clip(eop.data['RGB'][last_valid] * f, 0, 1))
    ax4.set_xticks([])
    ax4.set_yticks([])
    ax4.axis('off')

    # Save the frame and free the figure (comment these out to display it inline instead).
    fig.savefig(f'graphs/graph_{name}_{idx}.png', dpi=100, bbox_inches='tight')
    plt.close(fig)

In [27]:
# create RGB and solar angle animation: one GIF per location, one panel per time frame

# The GIFs go to ./figs; make sure it exists even if the previous animation cell was skipped.
os.makedirs('figs', exist_ok=True)

for idx, eop in enumerate(tqdm(eopatches)):
    name = f'loc{idx}'
    f = factors[idx]
    n_frames = len(eop.timestamp)
    if n_frames == 0:
        # nothing to animate, and the frame duration below would divide by zero
        continue

    # Fresh scratch folder for this location's frames (portable `rm -rf graphs && mkdir graphs`).
    shutil.rmtree('graphs', ignore_errors=True)
    os.makedirs('graphs', exist_ok=True)

    # plot_graph reads the global `eop` set by this loop; f and name are passed explicitly.
    with ProcessPoolExecutor(max_workers=8) as executor:
        _ = list(tqdm(executor.map(plot_graph, range(n_frames), [f] * n_frames, [name] * n_frames),
                      total=n_frames, leave=False))

    with imageio.get_writer(f'figs/graph_{name}.gif', mode='I', duration=4.0 / n_frames) as writer:
        for i in range(n_frames):
            writer.append_data(imageio.imread(f'graphs/graph_{name}_{i}.png'))

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=71.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=73.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=72.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=72.0), HTML(value='')))


