In [1]:
import glob
import math
import os
import subprocess
import sys
import warnings

import numpy as np
import rasterio
from osgeo import gdal
import matplotlib.pyplot as plt

# Make the planetunet project modules importable before importing them.
sys.path.append("/home/jovyan/work/notebooks/satellite_data/SA_segmentation/pytorch_segmentation/planetunet")
from config.config_types import PostprocessingConfig
import postprocessing

In [2]:
# Years to process (inclusive 2014..2020), the inference output root, and the
# model run whose per-year predictions live under <path>/<model_name>/<year>.
years = [year for year in range(2014, 2020 + 1)]
path = "/home/jovyan/work/satellite_data/tmp/inference/"
model_name = "smp_unet_mitb3_08_03_2023_170715.pth"

# Update year VRTs

In [3]:
# Rebuild one VRT per year that mosaics that year's prediction tiles.
model_dir = os.path.join(path, model_name)
for year in years:
    print(f"Year: {year}")
    vrt_fp = os.path.join(model_dir, f"{year}.vrt")
    # Start from a fresh VRT rather than whatever a previous run left behind.
    if os.path.isfile(vrt_fp):
        os.remove(vrt_fp)
    # Run from model_dir so the tile paths handed to gdalbuildvrt are relative to the
    # VRT's directory; the shell is used only to expand the tile glob. Passing the
    # directory via cwd= avoids interpolating it (unquoted) into the command, and
    # check=True raises CalledProcessError if the build fails.
    subprocess.run(f"gdalbuildvrt {year}.vrt {year}/*_*.tif", shell=True, cwd=model_dir, check=True)

Year: 2014
0...10...20...30...40...50...60...70...80...90...100 - done.
Year: 2015
0...10...20...30...40...50...60...70...80...90...100 - done.
Year: 2016
0...10...20...30...40...50...60...70...80...90...100 - done.
Year: 2017
0...10...20...30...40...50...60...70...80...90...100 - done.
Year: 2018
0...10...20...30...40...50...60...70...80...90...100 - done.
Year: 2019
0...10...20...30...40...50...60...70...80...90...100 - done.
Year: 2020
0...10...20...30...40...50...60...70...80...90...100 - done.


In [4]:
# Disabled experiment kept for reference: resample a single 2014 prediction tile to
# ~10 m using a metre -> degree approximation at the tile's centre latitude, with
# 255 treated as source nodata and average resampling into Float32. An active
# variant of this check with a 100 m target appears further down the notebook.
# import math
# raster_fp = os.path.join(path,model_name,"2014","tmp_shape_15_03_2023_213818_105.tif")
# src = rasterio.open(raster_fp)
# out_fp = os.path.join(path,model_name,"test2.tif")
# latitude = (src.bounds[3] + src.bounds[1]) / 2
# x_res = 10 / (111320 * math.cos(math.radians(abs(latitude))))  # at equator 1°lon ~= 111.32 km
# y_res = 10 / 110540  # and        1°lat ~= 110.54 km


# warp_options = dict(
#     xRes=x_res,
#     yRes=y_res,
#     srcNodata=255,
#     dstNodata=None,
#     resampleAlg=gdal.GRA_Average, #gdal.GRA_NearestNeighbour ,#
#     outputType=gdal.GDT_Float32,
#     creationOptions=["COMPRESS=LZW", "TILED=YES", "BIGTIFF=IF_SAFER", "NUM_THREADS=ALL_CPUS"],
#     warpOptions=["NUM_THREADS=ALL_CPUS"],
#     warpMemoryLimit=1000000000,
#     multithread=True
# )
# ds = gdal.Warp(out_fp, raster_fp, **warp_options)
# del ds

# Create cover maps

In [None]:
# Run postprocessing.postprocess_all for every year to produce 100 m canopy cover rasters.
for year in years:
    print(f"Year: {year}")
    year_dir = os.path.join(path, model_name, str(year))

    # Delete canopy cover rasters from a previous run so stale files cannot end up in
    # the mosaic built below. Like `rm <dir>/*`, this removes plain files and links
    # only (no dotfiles, no sub-directories), but needs no shell and is a silent
    # no-op when the directory does not exist yet.
    for stale_fp in glob.glob(os.path.join(glob.escape(year_dir), "canopy_cover_rasters", "*")):
        if os.path.isfile(stale_fp) or os.path.islink(stale_fp):
            os.remove(stale_fp)

    postproc_config = PostprocessingConfig(
        # Presumably becomes the "_test" suffix matched by the mosaic glob
        # (canopy_cover_*m_test.tif) further down -- keep the two in sync.
        run_name="test",
        postprocessing_dir=year_dir,
        create_polygons=False,
        create_centroids=False,
        create_density_maps=False,
        create_canopy_cover_maps=True,
        postproc_workers=25,
        postproc_gridsize=(8, 8),
        canopy_resolutions=(100,),
        density_resolutions=(100,),
        area_thresholds=(3, 15, 50, 200),
        canopy_map_dtype='float32',
        #canopy_map_dtype='uint8',
        no_vsimem=True,
    )
    postprocessing.postprocess_all(postproc_config)

Year: 2014
Starting postprocessing.
Postprocessing predictions in /home/jovyan/work/satellite_data/tmp/inference/smp_unet_mitb3_08_03_2023_170715.pth/2014
Creating canopy cover maps


Creating 100m canopy cover map:   0%|          | 0/237 [00:00<?, ?it/s]

In [6]:
# Shell command that mosaics every year's canopy cover rasters into SA_cover_map.vrt,
# treating 255 as nodata on input and output and averaging when resolutions differ.
cmd = (
    f"cd {os.path.join(path, model_name)} && "
    "gdalbuildvrt -srcnodata 255 -vrtnodata 255 -overwrite -r average "
    "SA_cover_map.vrt */canopy_cover_rasters/canopy_cover_*m_test.tif"
)

In [7]:
# Execute the mosaic command; stop here if it fails instead of continuing with a
# missing or stale SA_cover_map.vrt. The status is still displayed as cell output.
exit_status = os.system(cmd)
if exit_status != 0:
    raise RuntimeError(f"gdalbuildvrt failed (os.system status {exit_status}): {cmd}")
exit_status

0...10...20...30...40...50...60...70...80...90...100 - done.


0

# Byte map: 8-bit cover map with an embedded viridis palette

In [8]:
# Convert the cover mosaic to an 8-bit GeoTIFF: values are linearly stretched from
# [0, 100] to [0, 254], and 255 is reserved as the nodata value.
options = {
    "format": "Gtiff",
    "outputType": gdal.GDT_Byte,  # alternative: gdal.GDT_Float32
    "resampleAlg": gdal.GRA_Average,
    "noData": 255,
    "creationOptions": ["BIGTIFF=IF_SAFER", "COMPRESS=LZW", "PREDICTOR=2", "NUM_THREADS=ALL_CPUS"],
    "scaleParams": [[0, 100, 0, 254]],
    "stats": True,
}
ds = gdal.Translate(
    os.path.join(path, model_name, "SA_cover_map_org.tif"),
    os.path.join(path, model_name, "SA_cover_map.vrt"),
    **options,
)
del ds  # closes and flushes the output dataset

In [9]:
# Re-write the 8-bit cover map with an embedded viridis palette and internal
# overviews. Pixel values 0..254 are the scaled cover values; 255 is nodata.
src_file = os.path.join(path, model_name, "SA_cover_map_org.tif")
output_file = os.path.join(path, model_name, "SA_cover_map.tif")

ds = gdal.Open(src_file)
band = ds.GetRasterBand(1)
data = band.ReadAsArray()

# Palette index i -> viridis(i / 255). Index 255 is the nodata value: give it a fully
# transparent entry so renderers that use the palette but ignore the nodata tag do
# not draw it in (almost) the same colour as the highest cover class (254).
colormap = gdal.ColorTable()
for idx in range(255):
    r, g, b = plt.cm.viridis(idx / 255.0)[:3]
    colormap.SetColorEntry(idx, (int(r * 255), int(g * 255), int(b * 255), 255))
colormap.SetColorEntry(255, (0, 0, 0, 0))

if os.path.isfile(output_file):
    os.remove(output_file)
driver = gdal.GetDriverByName("Gtiff")
# BIGTIFF=IF_SAFER matches the creation options used for SA_cover_map_org.tif.
out_ds = driver.Create(output_file, ds.RasterXSize, ds.RasterYSize, 1, gdal.GDT_Byte,
                       options=["COMPRESS=LZW", "PREDICTOR=2", "BIGTIFF=IF_SAFER"])
out_ds.SetProjection(ds.GetProjection())
out_ds.SetGeoTransform(ds.GetGeoTransform())
out_band = out_ds.GetRasterBand(1)
out_band.SetRasterColorTable(colormap)
out_band.SetNoDataValue(255)
out_band.WriteArray(data)

# Release the GDAL handles so the output is closed before gdaladdo opens it.
del ds, band, out_band, out_ds

# Build overviews. An argument list avoids shell quoting issues with the path, and
# check=True raises CalledProcessError if gdaladdo fails.
subprocess.run(["gdaladdo", output_file, "4", "16", "32", "64", "128", "256", "512", "1024"], check=True)

0...10...20...30...40...50...60...70...80...90...100 - done.


0

# Float map: Float32 version of the cover map

In [18]:
# Same conversion as the Byte map, but with Float32 output so averaged values are not
# quantised to integers. NOTE(review): scaleParams still stretches [0, 100] onto
# [0, 254] and 255 is nodata -- confirm that scaling is wanted for the float product.
options = dict(
    format="Gtiff",
    outputType=gdal.GDT_Float32,
    resampleAlg=gdal.GRA_Average,
    noData=255,
    # PREDICTOR=3 is GDAL's floating-point predictor, intended for float samples;
    # PREDICTOR=2 (horizontal differencing) targets integer data.
    creationOptions=["BIGTIFF=IF_SAFER", "COMPRESS=LZW", "PREDICTOR=3", "NUM_THREADS=ALL_CPUS"],
    scaleParams=[[0, 100, 0, 254]],
    stats=True,
)
ds = gdal.Translate(os.path.join(path, model_name, "SA_cover_map_org_float32.tif"),
                    os.path.join(path, model_name, "SA_cover_map.vrt"), **options)
del ds  # closes and flushes the output dataset

# Florian's tree cover map: rescale to 8 bit with a viridis palette

In [76]:
# Florian's tree cover map: stretch to 8 bit and embed a viridis palette for display.
# Output encoding: 0 = nodata (transparent), 1..255 = linear stretch of the valid
# source range (minimum -> 1, maximum -> 255).
florian_dir = "/home/jovyan/work/satellite_data/ku_sync/South_Africa/tree_cover_map"
input_file = os.path.join(florian_dir, "florian_tree_cover_map.tif")
output_file = os.path.join(florian_dir, "florian_tree_cover_map_scaled.tif")

ds = gdal.Open(input_file)
band = ds.GetRasterBand(1)
raw = band.ReadAsArray().astype(np.float64)

# Keep the source nodata value (if one is declared) and non-finite values out of the
# min/max; otherwise a nodata sentinel would dominate the stretch.
valid = np.isfinite(raw)
src_nodata = band.GetNoDataValue()
if src_nodata is not None:
    valid &= raw != src_nodata

# Map valid pixels onto 1..255. Index 0 is the output nodata value, so the minimum
# valid value must not land on it (a constant offset added before a min-max
# normalisation cancels out and cannot achieve this). A constant-valued input maps
# to 1 instead of dividing by zero.
data = np.zeros(raw.shape, dtype=np.uint8)
if valid.any():
    valid_values = raw[valid]
    lo, hi = valid_values.min(), valid_values.max()
    span = hi - lo
    norm = (valid_values - lo) / span if span > 0 else np.zeros_like(valid_values)
    data[valid] = np.rint(1 + norm * 254).astype(np.uint8)

# Palette index i -> viridis(i / 255); index 0 (nodata) is fully transparent.
colormap = gdal.ColorTable()
colormap.SetColorEntry(0, (0, 0, 0, 0))
for idx in range(1, 256):
    r, g, b = plt.cm.viridis(idx / 255.0)[:3]
    colormap.SetColorEntry(idx, (int(r * 255), int(g * 255), int(b * 255), 255))

if os.path.isfile(output_file):
    os.remove(output_file)
driver = gdal.GetDriverByName("Gtiff")
out_ds = driver.Create(output_file, ds.RasterXSize, ds.RasterYSize, 1, gdal.GDT_Byte,
                       options=["COMPRESS=LZW", "PREDICTOR=2"])
out_ds.SetProjection(ds.GetProjection())
out_ds.SetGeoTransform(ds.GetGeoTransform())
out_band = out_ds.GetRasterBand(1)
out_band.SetRasterColorTable(colormap)
out_band.SetNoDataValue(0)
out_band.WriteArray(data)

# Release the GDAL handles so the output is closed before gdaladdo opens it.
del ds, band, out_band, out_ds

# Build overviews; check=True raises CalledProcessError if gdaladdo fails.
subprocess.run(["gdaladdo", output_file, "4", "16", "32", "64", "128", "256", "512", "1024"], check=True)

0...10...20...30...40...50...60...70...80...90...100 - done.


0

In [13]:
# Sanity check: resample one 2018 prediction tile to roughly TARGET_RES_M metres with
# average resampling, writing the Float32 result to ./test2.tif.
TARGET_RES_M = 100
raster_fp = os.path.join(path, model_name, "2018", "tmp_shape_15_03_2023_001614_244.tif")
out_fp = "test2.tif"

# Only the bounds are needed, so close the dataset right away.
with rasterio.open(raster_fp) as src:
    bounds = src.bounds

# Convert metres to degrees at the tile's centre latitude using a spherical
# approximation (1 deg lon ~= 111.32 km * cos(lat), 1 deg lat ~= 110.54 km).
# NOTE(review): assumes the tile is in a geographic lon/lat CRS -- confirm.
latitude = (bounds.top + bounds.bottom) / 2
x_res = TARGET_RES_M / (111320 * math.cos(math.radians(abs(latitude))))
y_res = TARGET_RES_M / 110540

warp_options = dict(
    xRes=x_res,
    yRes=y_res,
    srcNodata=255,  # 255 is the nodata value used for the prediction rasters in this notebook
    dstNodata=None,
    resampleAlg=gdal.GRA_Average,  # alternative: gdal.GRA_NearestNeighbour
    outputType=gdal.GDT_Float32,
    creationOptions=["COMPRESS=LZW", "TILED=YES", "BIGTIFF=IF_SAFER", "NUM_THREADS=ALL_CPUS"],
    warpOptions=["NUM_THREADS=ALL_CPUS"],
    warpMemoryLimit=1000000000,
    multithread=True,
)
ds = gdal.Warp(out_fp, raster_fp, **warp_options)
del ds  # closes and flushes the output dataset

In [14]:
# Load all bands of the resampled test raster into memory and close the file handle
# immediately instead of leaving the dataset open.
with rasterio.open(out_fp) as src:
    arr = src.read()

In [16]:
# Smallest value over all bands and pixels of the resampled test raster.
np.min(arr)

0.0