Reference slides: https://sciforum.net/manuscripts/5158/slides.pdf

In [1]:
import xarray
import numpy
import matplotlib.pyplot as plt
import pandas
import geopandas
import math

In [2]:
# Sentinel-2 MGRS tiles to process, paired with the site each one covers.
# Order matters: the batch cell further down processes a prefix of this list.
tile_sites = [
    ('11SLB', 'UCSB'),
    ('10SFH', 'LNBL 1'),
    ('10SGH', 'LNBL 2'),
    ('11SQD', 'DRI 1'),
    ('11SQC', 'DRI 2'),
    ('11TNK', 'BSU 1'),
    ('11TNJ', 'BSU 2'),
    ('11TPK', 'BSU 3'),
    ('11TPJ', 'BSU 5'),
]
tiles = [tile_id for tile_id, _site in tile_sites]

# For a quick single-tile run, override with: tiles = ['10SFH']

In [3]:
# Root of the Sentinel-2 data tree. No trailing slash: every path is built as
# f'{prefix}/zarrs/...', and a trailing slash here produced '//' in those paths.
prefix = '/data/sentinel2'

In [4]:
def read_10m(tile):
    """Open the 10 m zarr store for ``tile`` and keep bands B2, B3, B4 and B8.

    The store path is printed before opening. ``spatial_ref`` is dropped here so
    that, when the resolutions are combined, only the 20 m dataset's copy remains.
    """
    store_path = f'{prefix}/zarrs/{tile}_10m.zarr'
    print(store_path)
    return (
        xarray.open_zarr(store_path)
        .drop_vars('spatial_ref')  # keep spatial_ref from one resolution only (the 20 m one)
        .sel(band=['B2', 'B3', 'B4', 'B8'])
    )

def read_20m(tile):
    """Open the 20 m zarr store for ``tile`` and keep bands B5, B6, B7, B8A, B11 and B12.

    Unlike ``read_10m``, ``spatial_ref`` is kept.
    """
    bands_20m = ['B5', 'B6', 'B7', 'B8A', 'B11', 'B12']
    return xarray.open_zarr(f'{prefix}/zarrs/{tile}_20m.zarr').sel(band=bands_20m)

def read_60m(tile):
    """Open the 60 m zarr store for ``tile`` and keep bands B1 and B9.

    Not called by ``process_tile``.
    """
    bands_60m = ['B1', 'B9']
    return xarray.open_zarr(f'{prefix}/zarrs/{tile}_60m.zarr').sel(band=bands_60m)

def upscale(ds_10m, ds_20m):
    """Resample ``ds_20m`` onto the x/y grid of ``ds_10m`` with nearest-neighbour interpolation.

    Non-numeric variables are reindexed with ``method_non_numeric='pad'``.

    10 m grid points that fall outside the 20 m coordinate range (the first and
    last values along x and y) come out as NaN. It is tempting to trim them, but
    they are deliberately kept so that the output shape -- and therefore the
    chunking -- stays identical to the 10 m input.
    """
    target_grid = {'x': ds_10m.x, 'y': ds_10m.y}
    return ds_20m.interp(**target_grid, method='nearest', method_non_numeric='pad')

def sharpen1(ds_10m, ds_20m_upscale):
    """Method 1: ratio sharpening with pan band and intensity averaged over all bands.

    Each upsampled 20 m band is multiplied by ``panband / intensity``, where
    ``panband`` is the mean of all 10 m bands and ``intensity`` the mean of all
    upsampled 20 m bands. The mean is computed as ``sum / band count`` (NaNs are
    skipped by the sum but still counted in the divisor).

    Parameters
    ----------
    ds_10m : xarray.Dataset
        Dataset whose ``reflectance`` variable has a ``band`` dimension.
    ds_20m_upscale : xarray.Dataset
        20 m data already interpolated onto the 10 m grid (see ``upscale``).

    Returns
    -------
    xarray.DataArray
        Sharpened reflectance with the bands of ``ds_20m_upscale``. Pixels whose
        intensity is exactly 0 are NaN rather than +/-inf (same as ``sharpen3``).
    """
    reflectance_10m = ds_10m['reflectance']
    reflectance_20m = ds_20m_upscale['reflectance']
    # Divide by the actual band counts instead of hard-coded 4 / 6, so the
    # average stays correct if the band selections in read_10m/read_20m change.
    panband = reflectance_10m.sum(dim='band') / reflectance_10m.sizes['band']
    intensity = reflectance_20m.sum(dim='band') / reflectance_20m.sizes['band']
    # A zero intensity would turn the ratio into inf; mask it to NaN instead.
    intensity = intensity.where(intensity != 0)
    sharpend = reflectance_20m / intensity * panband
    return sharpend

def sharpen2(ds_10m, ds_20m_upscale, intensity_bands=('B5', 'B6', 'B7', 'B8A')):
    """Method 2: ratio sharpening with an intensity that excludes the SWIR bands.

    Same scheme as method 1, but ``intensity`` is the mean of only
    ``intensity_bands`` (by default B5, B6, B7 and B8A, i.e. without B11/B12).
    Means are computed as ``sum / band count`` (NaNs are skipped by the sum but
    counted in the divisor).

    Parameters
    ----------
    ds_10m : xarray.Dataset
        Dataset whose ``reflectance`` variable has a ``band`` dimension.
    ds_20m_upscale : xarray.Dataset
        20 m data already interpolated onto the 10 m grid (see ``upscale``).
    intensity_bands : sequence of str, optional
        Labels of the upsampled 20 m bands averaged into the intensity.

    Returns
    -------
    xarray.DataArray
        Sharpened reflectance with all bands of ``ds_20m_upscale``. Pixels whose
        intensity is exactly 0 are NaN rather than +/-inf (same as ``sharpen3``).

    Raises
    ------
    ValueError
        If ``intensity_bands`` is empty.
    """
    selected = list(intensity_bands)
    if not selected:
        raise ValueError('intensity_bands must not be empty')
    reflectance_10m = ds_10m['reflectance']
    reflectance_20m = ds_20m_upscale['reflectance']
    panband = reflectance_10m.sum(dim='band') / reflectance_10m.sizes['band']
    intensity = reflectance_20m.sel(band=selected).sum(dim='band') / len(selected)
    # A zero intensity would turn the ratio into inf; mask it to NaN instead.
    intensity = intensity.where(intensity != 0)
    sharpend = reflectance_20m / intensity * panband
    return sharpend

def sharpen3(ds_10m, ds_20m_upscale):
    """Method 3: ratio sharpening anchored on the NIR pair B8 (10 m) / B8A (20 m).

    Every upsampled 20 m band is scaled by ``B8 / B8A``. Pixels where B8A is
    exactly 0 are masked to NaN before dividing, so they come out as NaN.
    Returns the sharpened reflectance DataArray with the bands of ``ds_20m_upscale``.
    """
    upsampled = ds_20m_upscale['reflectance']
    nir_20m = upsampled.sel(band='B8A')
    nir_20m = nir_20m.where(nir_20m != 0)
    nir_10m = ds_10m['reflectance'].sel(band='B8')
    return upsampled / nir_20m * nir_10m

def stack_data(ds_10m, ds_20m_upscale, sharpend):
    """Combine the sharpened 20 m bands and the native 10 m bands into one dataset.

    ``ds_20m_upscale`` is used as the template: its ``reflectance``,
    ``viewing_azimuth_grid`` and ``viewing_zenith_grid`` variables and its
    ``band`` coordinate are dropped, then replaced by the concatenation along
    ``band`` of the 20 m data followed by the 10 m data.

    Parameters
    ----------
    ds_10m : xarray.Dataset
        10 m dataset (as returned by ``read_10m``); supplies the 10 m reflectance
        and viewing-angle grids.
    ds_20m_upscale : xarray.Dataset
        20 m dataset interpolated onto the 10 m grid (as returned by ``upscale``).
    sharpend : xarray.DataArray
        Sharpened 20 m reflectance (as returned by one of the ``sharpen*`` functions).

    Returns
    -------
    xarray.Dataset
        Dataset whose band order is the 20 m bands followed by the 10 m bands.
        ``band`` is cast to str. The data is rechunked to one chunk per time
        step and per 1000 rows of y, with all bands and all x columns in a
        single chunk (``-1``).
    """
    # Build the stacked arrays: 20 m (upscaled / sharpened) first, then 10 m.
    # The viewing-angle grids use the same order so their band axis lines up with reflectance.
    stacked_viewing_azimuth = xarray.concat([ds_20m_upscale['viewing_azimuth_grid'], ds_10m['viewing_azimuth_grid']], dim='band')
    stacked_viewing_zenith = xarray.concat([ds_20m_upscale['viewing_zenith_grid'], ds_10m['viewing_zenith_grid']], dim='band')
    stacked_reflectance = xarray.concat([sharpend, ds_10m['reflectance']], dim='band')

    # We need to drop the old reflectance + angle DataArrays from the Dataset and re-set them with the stacked (now 10 band) DataArrays.
    # If we don't do this, the band coordinates do not get updated.
    # Save the original reflectance attrs so they can be re-applied to the stacked array below.
    reflectance_attrs = ds_20m_upscale['reflectance'].attrs
    ds_20m_upscale = ds_20m_upscale.drop_vars('viewing_azimuth_grid')
    ds_20m_upscale = ds_20m_upscale.drop_vars('viewing_zenith_grid')
    # Drop the old (20 m only) band coordinate too, so the stacked band coordinate can replace it.
    ds_20m_upscale = ds_20m_upscale.drop_vars('reflectance').drop_vars('band')

    ds_20m_upscale['reflectance'] = stacked_reflectance
    ds_20m_upscale['reflectance'].attrs = reflectance_attrs
    ds_20m_upscale['viewing_zenith_grid'] = stacked_viewing_zenith
    ds_20m_upscale['viewing_azimuth_grid'] = stacked_viewing_azimuth
    # NOTE(review): presumably cast so the band labels have a uniform string dtype when written to zarr — confirm.
    ds_20m_upscale['band'] = ds_20m_upscale['band'].astype('str')
    ds_20m_upscale = ds_20m_upscale.chunk(band=-1, x=-1, y=1000, time=1)
    return ds_20m_upscale

def write_zarr(tile, ds_20m_upscale):
    """Write the stacked dataset for ``tile`` to ``<prefix>/zarrs/<tile>_sharpend.zarr``.

    An existing store at that path is overwritten (``mode='w'``). The store is
    written in zarr format version 2.
    """
    ds_20m_upscale.to_zarr(f'{prefix}/zarrs/{tile}_sharpend.zarr', zarr_format=2, mode='w')

def process_tile(tile):
    """Run the whole pipeline for one tile: read, upsample, sharpen (method 3), stack, write.

    A progress message is printed after each stage.
    """
    print(tile)
    ds_10m, ds_20m = read_10m(tile), read_20m(tile)
    print('finished reading data')
    upsampled = upscale(ds_10m, ds_20m)
    print('finished scaling')
    sharpened = sharpen3(ds_10m, upsampled)
    print('finished sharpening')
    stacked = stack_data(ds_10m, upsampled, sharpened)
    print('finished stacking')
    write_zarr(tile, stacked)

In [5]:
# Interactive run on one tile: the same steps as process_tile, minus the write,
# so the lazy (dask-backed) result can be inspected.
tile = '11SLB'
ds_10m, ds_20m = read_10m(tile), read_20m(tile)
ds_20m_upscale = upscale(ds_10m, ds_20m)
sharpend = sharpen3(ds_10m, ds_20m_upscale)
ds_20m_upscale = stack_data(ds_10m, ds_20m_upscale, sharpend)
ds_20m_upscale

/data/sentinel2//zarrs/11SLB_10m.zarr


Unnamed: 0,Array,Chunk
Bytes,1.15 MiB,2.07 kiB
Shape,"(568, 23, 23)","(1, 23, 23)"
Dask graph,568 chunks in 2 graph layers,568 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.15 MiB 2.07 kiB Shape (568, 23, 23) (1, 23, 23) Dask graph 568 chunks in 2 graph layers Data type float32 numpy.ndarray",23  23  568,

Unnamed: 0,Array,Chunk
Bytes,1.15 MiB,2.07 kiB
Shape,"(568, 23, 23)","(1, 23, 23)"
Dask graph,568 chunks in 2 graph layers,568 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,255.10 GiB,41.89 MiB
Shape,"(568, 10980, 10980)","(1, 1000, 10980)"
Dask graph,6248 chunks in 21 graph layers,6248 chunks in 21 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 255.10 GiB 41.89 MiB Shape (568, 10980, 10980) (1, 1000, 10980) Dask graph 6248 chunks in 21 graph layers Data type float32 numpy.ndarray",10980  10980  568,

Unnamed: 0,Array,Chunk
Bytes,255.10 GiB,41.89 MiB
Shape,"(568, 10980, 10980)","(1, 1000, 10980)"
Dask graph,6248 chunks in 21 graph layers,6248 chunks in 21 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,510.20 GiB,83.77 MiB
Shape,"(568, 10980, 10980)","(1, 1000, 10980)"
Dask graph,6248 chunks in 21 graph layers,6248 chunks in 21 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 510.20 GiB 83.77 MiB Shape (568, 10980, 10980) (1, 1000, 10980) Dask graph 6248 chunks in 21 graph layers Data type float64 numpy.ndarray",10980  10980  568,

Unnamed: 0,Array,Chunk
Bytes,510.20 GiB,83.77 MiB
Shape,"(568, 10980, 10980)","(1, 1000, 10980)"
Dask graph,6248 chunks in 21 graph layers,6248 chunks in 21 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,1.15 MiB,2.07 kiB
Shape,"(568, 23, 23)","(1, 23, 23)"
Dask graph,568 chunks in 2 graph layers,568 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.15 MiB 2.07 kiB Shape (568, 23, 23) (1, 23, 23) Dask graph 568 chunks in 2 graph layers Data type float32 numpy.ndarray",23  23  568,

Unnamed: 0,Array,Chunk
Bytes,1.15 MiB,2.07 kiB
Shape,"(568, 23, 23)","(1, 23, 23)"
Dask graph,568 chunks in 2 graph layers,568 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,510.20 GiB,83.77 MiB
Shape,"(568, 10980, 10980)","(1, 1000, 10980)"
Dask graph,6248 chunks in 21 graph layers,6248 chunks in 21 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 510.20 GiB 83.77 MiB Shape (568, 10980, 10980) (1, 1000, 10980) Dask graph 6248 chunks in 21 graph layers Data type float64 numpy.ndarray",10980  10980  568,

Unnamed: 0,Array,Chunk
Bytes,510.20 GiB,83.77 MiB
Shape,"(568, 10980, 10980)","(1, 1000, 10980)"
Dask graph,6248 chunks in 21 graph layers,6248 chunks in 21 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.22 kiB,4 B
Shape,"(568,)","(1,)"
Dask graph,568 chunks in 2 graph layers,568 chunks in 2 graph layers
Data type,,
"Array Chunk Bytes 2.22 kiB 4 B Shape (568,) (1,) Dask graph 568 chunks in 2 graph layers Data type",568  1,

Unnamed: 0,Array,Chunk
Bytes,2.22 kiB,4 B
Shape,"(568,)","(1,)"
Dask graph,568 chunks in 2 graph layers,568 chunks in 2 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,510.20 GiB,83.77 MiB
Shape,"(568, 10980, 10980)","(1, 1000, 10980)"
Dask graph,6248 chunks in 21 graph layers,6248 chunks in 21 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 510.20 GiB 83.77 MiB Shape (568, 10980, 10980) (1, 1000, 10980) Dask graph 6248 chunks in 21 graph layers Data type float64 numpy.ndarray",10980  10980  568,

Unnamed: 0,Array,Chunk
Bytes,510.20 GiB,83.77 MiB
Shape,"(568, 10980, 10980)","(1, 1000, 10980)"
Dask graph,6248 chunks in 21 graph layers,6248 chunks in 21 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,255.10 GiB,41.89 MiB
Shape,"(568, 10980, 10980)","(1, 1000, 10980)"
Dask graph,6248 chunks in 21 graph layers,6248 chunks in 21 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 255.10 GiB 41.89 MiB Shape (568, 10980, 10980) (1, 1000, 10980) Dask graph 6248 chunks in 21 graph layers Data type float32 numpy.ndarray",10980  10980  568,

Unnamed: 0,Array,Chunk
Bytes,255.10 GiB,41.89 MiB
Shape,"(568, 10980, 10980)","(1, 1000, 10980)"
Dask graph,6248 chunks in 21 graph layers,6248 chunks in 21 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,2.49 TiB,418.85 MiB
Shape,"(10, 568, 10980, 10980)","(10, 1, 1000, 10980)"
Dask graph,6248 chunks in 34 graph layers,6248 chunks in 34 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 2.49 TiB 418.85 MiB Shape (10, 568, 10980, 10980) (10, 1, 1000, 10980) Dask graph 6248 chunks in 34 graph layers Data type float32 numpy.ndarray",10  1  10980  10980  568,

Unnamed: 0,Array,Chunk
Bytes,2.49 TiB,418.85 MiB
Shape,"(10, 568, 10980, 10980)","(10, 1, 1000, 10980)"
Dask graph,6248 chunks in 34 graph layers,6248 chunks in 34 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,80.23 MiB,144.65 kiB
Shape,"(568, 10, 7, 23, 23)","(1, 10, 7, 23, 23)"
Dask graph,568 chunks in 42 graph layers,568 chunks in 42 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 80.23 MiB 144.65 kiB Shape (568, 10, 7, 23, 23) (1, 10, 7, 23, 23) Dask graph 568 chunks in 42 graph layers Data type float32 numpy.ndarray",10  568  23  23  7,

Unnamed: 0,Array,Chunk
Bytes,80.23 MiB,144.65 kiB
Shape,"(568, 10, 7, 23, 23)","(1, 10, 7, 23, 23)"
Dask graph,568 chunks in 42 graph layers,568 chunks in 42 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,80.23 MiB,144.65 kiB
Shape,"(568, 10, 7, 23, 23)","(1, 10, 7, 23, 23)"
Dask graph,568 chunks in 42 graph layers,568 chunks in 42 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 80.23 MiB 144.65 kiB Shape (568, 10, 7, 23, 23) (1, 10, 7, 23, 23) Dask graph 568 chunks in 42 graph layers Data type float32 numpy.ndarray",10  568  23  23  7,

Unnamed: 0,Array,Chunk
Bytes,80.23 MiB,144.65 kiB
Shape,"(568, 10, 7, 23, 23)","(1, 10, 7, 23, 23)"
Dask graph,568 chunks in 42 graph layers,568 chunks in 42 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
%%time
# Full pipeline (read -> sharpen -> write) for the UCSB tile.
tile = '11SLB'
process_tile(tile)

In [None]:
%%time
# Process the first three tiles (UCSB, LNBL 1, LNBL 2).
# Note: afterwards `tile` is left set to the last tile processed.
for tile in tiles[:3]:
    process_tile(tile)

# Manual, step-by-step run

## Load Data

In [None]:
# Pin the tile explicitly. Otherwise `tile` is whatever the last executed cell
# left behind (e.g. the batch loop above leaves the last tile it processed),
# which makes this section depend on execution order.
tile = '11SLB'
ds_10m = read_10m(tile)
ds_20m = read_20m(tile)
ds_20m_upscale = upscale(ds_10m, ds_20m)
sharpend = sharpen3(ds_10m, ds_20m_upscale)
ds_20m_upscale = stack_data(ds_10m, ds_20m_upscale, sharpend)
# Re-chunk into 500 x 500 spatial blocks with all 10 bands per chunk for the manual write below.
ds_20m_upscale = ds_20m_upscale.chunk(band=10, x=500, y=500)
ds_20m_upscale

In [None]:
%%time 
# Manual write to a separate location. Unlike write_zarr, no zarr_format is pinned here.
zarr_store = f'/tablespace/sentinel2/{tile}_sharpend.zarr'
ds_20m_upscale.to_zarr(store=zarr_store, mode='w')

# Make Plots

In [None]:
import matplotlib.pyplot as plt
# Acquisition date shown in the plots and GeoTIFF exports below.
date = '2024-02-25'

In [None]:
def normalize(data):
    """Linearly rescale ``data`` to the range [0, 1], ignoring NaNs.

    NaNs are preserved. A constant input (zero value range) is returned as all
    zeros instead of the 0/0 = NaN that the plain division would produce.

    Parameters
    ----------
    data : numpy.ndarray
        Array to rescale (any shape).

    Returns
    -------
    numpy.ndarray
        Rescaled float array of the same shape.
    """
    shifted = data - numpy.nanmin(data)
    value_range = numpy.nanmax(shifted)
    if value_range == 0:
        # Constant input: every value sits at the minimum, i.e. 0.
        return shifted.astype(float)
    return shifted / value_range

def mask(data):
    """Return a masked-array copy of ``data`` with NaNs replaced by 0 (for display).

    The input array is not modified. Previously the masked array shared memory
    with ``data``, so zero-filling also overwrote the caller's array.

    Parameters
    ----------
    data : numpy.ndarray
        Float array that may contain NaNs.

    Returns
    -------
    numpy.ma.MaskedArray
        Copy of ``data`` in which the former NaN positions hold 0 and are unmasked.
    """
    nan_locations = numpy.isnan(data)
    filled = numpy.ma.masked_array(data, nan_locations, copy=True)
    # Assigning into a (soft-)masked position writes the value and clears the mask there.
    filled[nan_locations] = 0
    return filled

In [None]:
# Panband as used by sharpen3 (the sharpening method used in this notebook's pipeline): the 10 m B8 band.
# `panband` only exists inside the sharpen* functions, so it is rebuilt here;
# without this the cell raises NameError on a fresh kernel.
panband = ds_10m['reflectance'].sel(band='B8')

fig, ax = plt.subplots(dpi=300)

rgb = panband.sel(time=date).squeeze().transpose('y', 'x').values

img = ax.imshow(rgb)
ax.axis('off')
ax.set_aspect('equal')
cbar = fig.colorbar(img, ax=ax, shrink=0.7)
rgb.min(), rgb.max()

In [None]:
# Intensity as used by sharpen3: the 20 m B8A band upsampled to 10 m, with zeros masked.
# `intensity` only exists inside the sharpen* functions, so it is rebuilt here.
# ds_20m_upscale already holds the sharpened/stacked dataset at this point,
# so ds_20m is upsampled again rather than reused.
intensity = upscale(ds_10m, ds_20m)['reflectance'].sel(band='B8A')
intensity = intensity.where(intensity != 0)

fig, ax = plt.subplots(dpi=300)

rgb = intensity.sel(time=date).squeeze().transpose('y', 'x').values
rgb = mask(rgb)  # NaN (masked zeros, grid edges) -> 0 for display

img = ax.imshow(rgb)
ax.axis('off')
ax.set_aspect('equal')
cbar = fig.colorbar(img, ax=ax, shrink=0.7)
rgb.min(), rgb.max()

In [None]:
# Native 10 m true-colour composite (B4, B3, B2), normalized to [0, 1] with NaNs shown as 0.
fig, ax = plt.subplots(dpi=300)

true_colour = ['B4', 'B3', 'B2']
scene = ds_10m['reflectance'].sel(time=date, band=true_colour)
rgb = mask(normalize(scene.squeeze().transpose('y', 'x', 'band').values))

ax.imshow(rgb)
ax.axis('off')
ax.set_aspect('equal')

In [None]:
# Native 20 m composite (B8A, B5, B6 as R, G, B), raw values (not normalized).
fig, ax = plt.subplots(dpi=300)

composite = ds_20m['reflectance'].sel(time=date, band=['B8A', 'B5', 'B6'])
rgb = composite.squeeze().transpose('y', 'x', 'band').values
# Optional: rgb = normalize(rgb)

ax.imshow(rgb)
ax.axis('off')
ax.set_aspect('equal')
rgb.min(), rgb.max()

In [None]:
# Sharpened 20 m composite (B8A, B5, B6 as R, G, B) on the 10 m grid; NaNs shown as 0.
fig, ax = plt.subplots(dpi=300)

composite = ds_20m_upscale['reflectance'].sel(time=date, band=['B8A', 'B5', 'B6'])
rgb = mask(composite.squeeze().transpose('y', 'x', 'band').values)

ax.imshow(rgb)
ax.axis('off')
ax.set_aspect('equal')
rgb.min(), rgb.max()

# Make geotiffs

In [None]:
# NOTE(review): the `.rio` accessor comes from rioxarray, which is not among the
# visible imports; it may only be registered implicitly — TODO confirm an explicit
# `import rioxarray` is present.

# Derive the UTM EPSG code from the MGRS tile id instead of hard-coding zone 11
# (EPSG:32611): the '10S..' tiles in `tiles` lie in UTM zone 10 (EPSG:32610).
# MGRS latitude bands N..X are northern hemisphere (326xx), C..M southern (327xx).
utm_zone = int(tile[:2])
hemisphere_base = 32600 if tile[2].upper() >= 'N' else 32700
epsg = hemisphere_base + utm_zone

# One GeoTIFF per product; the datasets are georeferenced on copies, so the
# notebook's ds_10m / ds_20m / ds_20m_upscale are not mutated in place.
for ds, suffix in ((ds_10m, '10m'), (ds_20m, '20m'), (ds_20m_upscale, 'sharpend')):
    georef = ds.rio.write_crs(epsg).rio.set_spatial_dims('x', 'y')
    georef['reflectance'].sel(time=date).squeeze().to_dataset('band').rio.to_raster(f'{date}_{suffix}.tiff')

# Sanity checks

In [None]:
# Per-time-step maximum of one 20 m band over the whole tile (x and y), as a quick
# outlier check. 'B6' is index 1 of the band list selected in read_20m; selecting
# it by label instead of positional `[:, 1, :, :]` removes the dependence on the
# array's dimension order.
# NOTE(review): the name `max` shadows the Python builtin; it is kept because the
# plotting cell below reads it.
max = ds_20m['reflectance'].sel(band='B6').max(dim=('x', 'y')).compute()

In [None]:
# Time series of the per-scene maximum computed in the previous cell.
fig, ax = plt.subplots()
ax.plot(max['time'], max)
ax.set_xlabel('time')
ax.set_ylabel('Max value')
# The maximum is taken over x and y only (one band selected), not over "dim2, dim3, dim4".
ax.set_title('Maximum value over x and y per time step')
plt.show()