In [None]:
import numpy as np
import os.path as osp
import pickle
import subprocess
from osgeo import gdal, osr
import matplotlib.pyplot as plt
import pandas as pd
# from utils import make_map, make_map2

## Get Station Data

In [None]:
# Station observations: a pickled object indexed by station key (presumably one entry per station).
file = "test_CA_202401.pkl"
# Local cache location; the next cell downloads the file here when it is missing.
dpath = f"data/{file}"
# Remote origin of the same file on the OpenWFM demo server.
durl = f"https://demo.openwfm.org/web/data/fmda/dicts/{file}"

In [None]:
if not osp.exists(dpath):
    # Pass an argument list (no shell string interpolation). check=True makes a
    # failed download stop here, instead of surfacing later as a confusing
    # FileNotFoundError from read_pickle.
    subprocess.run(["wget", "-P", "data", durl], check=True)

# SECURITY: unpickling can execute arbitrary code -- only load files from a trusted source.
dat = pd.read_pickle(dpath)

## Map Station Locations

In [None]:
# Collect each station's "loc" record into one table (one row per entry of `dat`).
# make_map below expects this table to provide STID, lat and lon columns.
locs = [dat[key]["loc"] for key in dat]

df = pd.DataFrame(locs)

# Show df
df

In [None]:
import plotly.express as px
import plotly.graph_objects as go

# Map stations, credit https://stackoverflow.com/questions/53233228/plot-latitude-longitude-from-csv-in-python-3-6

def make_map(df, zoom=6, height=600, width=800):
    """Plot station locations on an OpenStreetMap basemap, outlined by their bounding box.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain numeric ``lat`` and ``lon`` columns and an ``STID`` column
        (used as the hover label).
    zoom : int, optional
        Initial map zoom level (default 6).
    height, width : int, optional
        Figure size in pixels (defaults 600 x 800).

    Returns
    -------
    plotly.graph_objects.Figure
        Station markers plus a black rectangle spanning the min/max lon/lat.
    """
    fig = px.scatter_mapbox(
        df,
        lat="lat",
        lon="lon",
        hover_name="STID",
        zoom=zoom,
        height=height,
        width=width,
    )
    # Enlarge the station markers (the bounding-box trace added below is lines-only).
    fig.update_traces(marker=dict(size=10))

    lon_min, lon_max = df["lon"].min(), df["lon"].max()
    lat_min, lat_max = df["lat"].min(), df["lat"].max()

    # Center the view on the median station position, which is less sensitive to
    # a few outlying stations than the mean.
    fig.update_layout(
        mapbox_style="open-street-map",
        mapbox_center=dict(lat=df["lat"].median(), lon=df["lon"].median()),
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
    )

    # Closed rectangle around all stations: SW -> NW -> NE -> SE -> SW.
    fig.add_trace(go.Scattermapbox(
        mode="lines",
        lon=[lon_min, lon_min, lon_max, lon_max, lon_min],
        lat=[lat_min, lat_max, lat_max, lat_min, lat_min],
        line=dict(width=1.5, color="black"),
        showlegend=False,
    ))
    return fig

In [None]:
# Interactive map of all stations together with their lon/lat bounding box.
make_map(df)

## Add Raster Background

In [None]:
# HRRR GeoTIFF (presumably the 2024-01-01 00z analysis; "616" is the band id in the
# file name -- NOTE(review): likely 2 m temperature, confirm against the data source).
turl = "https://demo.openwfm.org/web/data/fmda/tif/20240101/hrrr.t00z.wrfprsf00.616.tif"
tpath = "data/hrrr.t00z.wrfprsf00.616.tif"
if not osp.exists(tpath):
    # Argument list (no shell); check=True fails fast if the download fails.
    subprocess.run(["wget", "-P", "data", turl], check=True)

ds = gdal.Open(tpath)
if ds is None:
    # Without GDAL exceptions enabled, gdal.Open signals failure by returning None.
    raise RuntimeError(f"GDAL could not open {tpath}")
band = ds.GetRasterBand(1)
data = band.ReadAsArray()   # 2-D pixel values of band 1
gt = ds.GetGeoTransform()   # 6-coefficient affine pixel -> CRS transform
gp = ds.GetProjection()     # WKT of the raster's CRS

In [None]:
# Quick sanity check of what band.ReadAsArray() returned (expected: a numpy ndarray).
print(type(data))

### Convert to lon/lat

In [None]:
# Helpers that convert raster pixel indices to geographic (EPSG) coordinates
from osgeo import osr, ogr, gdal

def pixel_to_world(geo_matrix, x, y):
    """Convert a pixel (column, row) index to coordinates in the raster's CRS.

    Applies the full GDAL affine geotransform, including the rotation terms::

        X = GT[0] + x * GT[1] + y * GT[2]
        Y = GT[3] + x * GT[4] + y * GT[5]

    For the usual north-up raster (GT[2] == GT[4] == 0) this reduces to
    ``X = GT[0] + x * GT[1]`` and ``Y = GT[3] + y * GT[5]``.

    Parameters
    ----------
    geo_matrix : sequence of 6 floats
        Output of ``ds.GetGeoTransform()``:
        GT[0] x-coordinate of the upper-left corner of the upper-left pixel,
        GT[1] w-e pixel resolution (pixel width),
        GT[2] row rotation (typically zero),
        GT[3] y-coordinate of the upper-left corner of the upper-left pixel,
        GT[4] column rotation (typically zero),
        GT[5] n-s pixel resolution (negative for a north-up image).
    x : int, float or numpy array
        Pixel column index (0 = left edge).
    y : int, float or numpy array
        Pixel row index (0 = top edge).

    Returns
    -------
    (X, Y) : tuple
        Absolute coordinates in the raster's own CRS, in the geotransform's
        units (e.g. metres for a projected grid). Integer indices give the
        upper-left *corner* of the pixel; add 0.5 to each index for its centre.

    Example: ``pixel_to_world(gt, 0, 0)`` returns ``(gt[0], gt[3])``.
    """
    gt = geo_matrix
    world_x = gt[0] + x * gt[1] + y * gt[2]
    world_y = gt[3] + x * gt[4] + y * gt[5]
    return world_x, world_y


def build_transform_inverse(dataset, EPSG):
    """Build an OSR transformation from a dataset's CRS to an EPSG-coded CRS.

    Despite the name, this maps coordinates *from* the raster's native CRS
    (e.g. the output of ``pixel_to_world``) *to* the ``EPSG`` CRS. It does not
    produce pixel indices.

    Parameters
    ----------
    dataset : gdal.Dataset
        Opened raster; its WKT projection defines the source CRS.
    EPSG : int
        EPSG code of the target CRS (e.g. 4326 for WGS84).

    Returns
    -------
    osr.CoordinateTransformation
    """
    src_srs = osr.SpatialReference(wkt=dataset.GetProjection())
    dst_srs = osr.SpatialReference()
    dst_srs.ImportFromEPSG(EPSG)
    return osr.CoordinateTransformation(src_srs, dst_srs)

def world_to_epsg(wx, wy, trans):
    """Reproject a single point with an OSR coordinate transformation.

    Parameters
    ----------
    wx, wy : float
        Coordinates in the transformation's source CRS, e.g. as returned by
        ``pixel_to_world`` for the raster's native CRS.
    trans : osgeo.osr.CoordinateTransformation
        Transformation to the target CRS, e.g. from ``build_transform_inverse``.

    Returns
    -------
    ogr.Geometry
        Point geometry holding the transformed coordinates; read them with
        ``GetX()`` / ``GetY()``.
    """
    geom = ogr.Geometry(ogr.wkbPoint)
    geom.AddPoint(wx, wy)
    geom.Transform(trans)
    return geom

def find_spatial_coordinate_from_pixel(dataset, x, y, transform=None, epsg=4326):
    """Return the coordinates, in an EPSG CRS, of pixel (x, y) of a GDAL dataset.

    Parameters
    ----------
    dataset : gdal.Dataset
        Opened raster (e.g. from a GeoTIFF).
    x : int
        Pixel column index (0 = left edge).
    y : int
        Pixel row index (0 = top edge).
    transform : osr.CoordinateTransformation, optional
        Pre-built transformation from ``build_transform_inverse``. Pass one when
        calling this repeatedly to avoid rebuilding it for every pixel. If None,
        it is built from ``dataset`` and ``epsg``.
    epsg : int, optional
        Target EPSG code, used only when ``transform`` is None (default 4326, WGS84).

    Returns
    -------
    (float, float)
        ``(point.GetX(), point.GetY())`` in the target CRS, in the axis order the
        transformation produces. NOTE(review): with GDAL >= 3 and EPSG:4326 the
        default authority axis order is (lat, lon), so the first value is
        latitude -- the grid-filling loop in this notebook relies on that;
        confirm for your GDAL version.
    """
    if transform is None:
        # Build from the dataset we were given, not the notebook-global `ds`.
        transform = build_transform_inverse(dataset, epsg)
    world_x, world_y = pixel_to_world(dataset.GetGeoTransform(), x, y)
    point = world_to_epsg(world_x, world_y, transform)
    return point.GetX(), point.GetY()

In [None]:
# Geographic coordinates of every raster pixel, on the same (row, col) grid as `data`.
# Values are for each pixel's upper-left corner, since pixel_to_world uses integer indices.
grid_shape = np.shape(data)
lons = np.zeros(grid_shape)
lats = np.zeros(grid_shape)

# Build the raster-CRS -> WGS84 transformation once and reuse it for every pixel.
transform = build_transform_inverse(ds, EPSG=4326)

for row in range(grid_shape[0]):        # row index = raster y (top to bottom)
    for col in range(grid_shape[1]):    # column index = raster x (left to right)
        # find_spatial_coordinate_from_pixel takes (x, y) = (col, row).
        first, second = find_spatial_coordinate_from_pixel(ds, col, row, transform=transform)
        # NOTE(review): the first component is stored as latitude, matching GDAL >= 3's
        # authority axis order (lat, lon) for EPSG:4326 -- confirm for your GDAL version.
        lats[row, col] = first
        lons[row, col] = second

In [None]:
# Restrict the raster grid to the stations' lon/lat bounding box.
min_lon = df['lon'].min()
max_lon = df['lon'].max()
min_lat = df['lat'].min()
max_lat = df['lat'].max()

mask_lon = (lons >= min_lon) & (lons <= max_lon)
mask_lat = (lats >= min_lat) & (lats <= max_lat)
# A pixel lies inside the box only when BOTH conditions hold. Indexing with this
# combined mask keeps the outputs aligned: lon_subset[k], lat_subset[k] and
# data_subset[k] all describe the same pixel. Indexing each array with its own
# mask instead gives 1-D arrays of different lengths that do not pair up.
mask = mask_lon & mask_lat
lon_subset = lons[mask]
lat_subset = lats[mask]
data_subset = data[mask]