In [1]:
%matplotlib notebook

import os
import re
import sys
import json
import shapely
import requests
import rasterio
import numpy as np
import pandas as pd
import datetime as dt
from cartopy import crs
import geopandas as gpd
import plotly.express as px
from rasterio.mask import mask
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from rasterio.transform import Affine
from shapely.geometry import mapping, Point

In [2]:
# Each matching GeoTIFF in the data directory is one tile of the visualization.
directory_path = './data'
# Sort so tiles are always processed in the same order (os.listdir order is arbitrary).
files = sorted(file for file in os.listdir(directory_path) if file.endswith('med300.tif'))

In [3]:
# Adjust longitude from [0,360] to [-180,180]
def adjust_longitude(lon):
    """Map a scalar longitude from the [0, 360] convention onto [-180, 180].

    Any value strictly greater than 180 is shifted down by one full turn
    (360 degrees); all other values are returned unchanged.
    """
    return lon - 360 if lon > 180 else lon

In [None]:
# Country polygons used for the outline overlay drawn on top of the terrain.
# NOTE(review): `gpd.datasets` was deprecated and removed in geopandas 1.0; on
# newer versions this file has to be obtained from Natural Earth directly.
world_gdf = gpd.read_file(gpd.datasets.get_path('naturalearth_lowres'))
world_gdf = world_gdf.set_geometry('geometry')   # set the active geometry
# The overlay is plotted on longitude/latitude axes, so geometries must be in
# WGS84 degrees: declare the CRS if none is set, otherwise reproject into it.
if world_gdf.crs is None:
    world_gdf = world_gdf.set_crs('EPSG:4326')
else:
    world_gdf = world_gdf.to_crs('EPSG:4326')

In [5]:
# Calculate the raster row/column index window that lies within the given geographic bounds
def calculate_indices(transform, width, height, min_lon, max_lon, min_lat, max_lat):
    """Return the pixel window of a raster that falls inside lon/lat bounds.

    Columns run along longitude (x) and rows along latitude (y).

    Args:
        transform: affine geotransform; only ``a`` (x pixel size), ``e``
            (y pixel size), ``c`` (x of the top-left corner) and ``f``
            (y of the top-left corner) are read. Rotation terms are ignored,
            so a north-up, unrotated grid is assumed.
        width, height: raster size in columns / rows.
        min_lon, max_lon, min_lat, max_lat: geographic bounds, in the same
            units as the transform.

    Returns:
        ``(row_start, row_end, col_start, col_end)``. Each index is clamped to
        ``[0, height]`` (rows) or ``[0, width]`` (columns). A tile with no
        overlap therefore yields an empty window (end <= start).
    """
    lon_per_pixel = transform.a
    lat_per_pixel = abs(transform.e)  # rows go downwards (south) from the top edge
    start_lon = transform.c
    start_lat = transform.f

    def _clamp(index, upper):
        # Keep an index inside the valid half-open extent [0, upper].
        return min(max(index, 0), upper)

    # Columns: distance east of the left edge, in x pixels.
    col_start = _clamp(int((min_lon - start_lon) / lon_per_pixel), width)
    col_end = _clamp(int((max_lon - start_lon) / lon_per_pixel), width)
    # Rows: distance south of the top edge, in y pixels.
    row_start = _clamp(int((start_lat - max_lat) / lat_per_pixel), height)
    row_end = _clamp(int((start_lat - min_lat) / lat_per_pixel), height)

    return row_start, row_end, col_start, col_end

In [None]:
directory_path = './data'
# Sorted so the trace order is reproducible (os.listdir order is arbitrary).
files = sorted(file for file in os.listdir(directory_path) if file.endswith('med300.tif'))

# create visualization
fig = go.Figure()

# Target number of samples per tile along each axis (tiles are decimated to at most n x n points).
n = 250
global_min_elevation = 0
global_max_elevation = 8848  # meters (presumably the summit of Mount Everest)

# Set geographic bounds
min_lon, max_lon = -180, 180  # Example longitude bounds
min_lat, max_lat = -90, 90  # Example latitude bounds

first_trace = True
for file in files:

    with rasterio.open(os.path.join(directory_path, file)) as src:
        img = src.read(1)  # Read the first band
        meta = src.meta

    transform = meta['transform']
    width, height = meta['width'], meta['height']
    nodata = meta.get('nodata')

    row_start, row_end, col_start, col_end = calculate_indices(
        transform, width, height, min_lon, max_lon, min_lat, max_lat)

    # Nothing of this tile lies inside the bounds.
    if row_end <= row_start or col_end <= col_start:
        continue

    # Reduce the sample count proportionally when only part of the tile is in bounds.
    n_lat = n_lon = n
    if row_end - row_start != height or col_end - col_start != width:
        n_lat = max(int(n * (row_end - row_start) / height), 1)
        n_lon = max(int(n * (col_end - col_start) / width), 1)

    # Evenly spaced pixel indices sampled inside the window.
    row_indices = np.linspace(row_start, row_end - 1, n_lat, dtype=int)
    col_indices = np.linspace(col_start, col_end - 1, n_lon, dtype=int)
    row_indices = np.clip(row_indices, 0, height - 1)
    col_indices = np.clip(col_indices, 0, width - 1)

    # Pixel-centre coordinates of exactly the sampled pixels, so x/y line up with z
    # (assumes an unrotated transform: only a, e, c, f are used).
    lon_axis = transform.c + (col_indices + 0.5) * transform.a
    lat_axis = transform.f + (row_indices + 0.5) * transform.e
    lon, lat = np.meshgrid(lon_axis, lat_axis)

    z_data = img[np.ix_(row_indices, col_indices)].astype(float)
    if nodata is not None:
        # Leave no-data pixels as gaps instead of drawing the sentinel value as terrain.
        z_data[z_data == nodata] = np.nan

    # Add to plot
    fig.add_trace(go.Surface(
        z=z_data, x=lon, y=lat,
        colorscale='IceFire',
        cmin=global_min_elevation,
        # Colour scale tops out at max/1.5 (~5.9 km); higher terrain saturates.
        cmax=global_max_elevation/1.5,
        colorbar=dict(title='elevation(m)'),
        # All tiles share cmin/cmax, so a single colorbar suffices; one per
        # tile would stack identical bars on top of each other.
        showscale=first_trace,
        showlegend=False))
    first_trace = False

# Initial camera placement for the 3-D scene.
camera = dict(
    up=dict(x=0, y=0, z=1),  # sets the "up" direction in terms of the plot's x, y, z axes
    center=dict(x=-0.2, y=-0.2, z=-.02),  # point the view is centred on
    eye=dict(x=-0.3, y=-0.45, z=.5)  # sets the position of the camera in x, y, z coordinates
)
fig.update_layout(
    title='Earth Terrain Basemap',
    showlegend=False,
    scene=dict(
        aspectmode='manual',
        aspectratio=dict(x=1.5, y=1, z=1),
        # x holds longitudes (range min_lon..max_lon).
        xaxis=dict(
            range=[min_lon, max_lon], title='Longitude',
            backgroundcolor="rgb(0,0,0)",
            gridcolor="grey",
            showbackground=True,
            zerolinecolor="grey"),
        # y holds latitudes (range min_lat..max_lat).
        yaxis=dict(
            range=[min_lat, max_lat], title='Latitude',
            backgroundcolor="rgb(0,0,0)",
            gridcolor="grey",
            showbackground=True,
            zerolinecolor="grey"),
        # Upper limit (300 km) is far above any terrain elevation, which keeps the
        # relief visually flat relative to the horizontal extent.
        zaxis=dict(
            range=[0.0000001, 300000], title='Altitude',
            gridcolor="grey",
            showbackground=True,
            zerolinecolor="black"),
    ),
    width=1250,  # Width of the figure in pixels
    height=900,  # Height of the figure in pixels
    autosize=False,
    scene_camera=camera
)
# Overlay country outlines from world_gdf as a single 3-D line trace.
# Plotly breaks a line at every (None, None) point, so each ring is followed by one.
coord_list = []

for shapely_object in world_gdf['geometry']:
    # Missing or empty geometries have nothing to draw.
    if shapely_object is None or shapely_object.is_empty:
        continue
    if shapely_object.geom_type == 'Polygon':
        polygons = [shapely_object]
    elif shapely_object.geom_type == 'MultiPolygon':
        polygons = list(shapely_object.geoms)
    else:
        continue
    for polygon in polygons:
        # Draw holes (interior rings) as well as the outer boundary.
        for ring in [polygon.exterior, *polygon.interiors]:
            # Keep only x/y so geometries carrying a z coordinate still unpack cleanly.
            coord_list.extend((pt[0], pt[1]) for pt in ring.coords)
            coord_list.append((None, None))

# Lift the outlines h metres above zero on the elevation axis.
h = 2
if coord_list:
    x, y = zip(*coord_list)
    z = h*np.ones(len(x))
    fig.add_scatter3d(x=x, y=y, z=z, mode='lines', line_color="rgb(70, 70, 70)", line_width=1, showlegend=False)

fig.show()