In [None]:
import numpy as np
import os.path as osp
import pickle
import subprocess
from osgeo import gdal, osr
import matplotlib.pyplot as plt
import pandas as pd
# from utils import make_map, make_map2

## Get Station Data

In [None]:
# Station data: remote origin of the pickled station dictionary
file = "test_CA_202401.pkl"
durl = f"https://demo.openwfm.org/web/data/fmda/dicts/{file}"
# Local cache path for that dictionary (station data, not a raster band)
dpath = f"data/{file}"

In [None]:
# Download the station dictionary once; later runs reuse the local copy.
if not osp.exists(dpath):
    # Argument list (no shell) + check=True: a failed download raises here
    # instead of surfacing later as a confusing read_pickle error.
    subprocess.run(["wget", "-P", "data", durl], check=True)

# SECURITY: unpickling can execute arbitrary code — only load files from a trusted source.
dat = pd.read_pickle(dpath)

## Map Station Locations

In [None]:
# One row per station, built from each entry's "loc" record. Downstream code reads
# the 'STID', 'lat' and 'lon' columns (presumably 'elevation' too — confirm).
locs = [dat[stid]["loc"] for stid in dat]

df = pd.DataFrame(locs)
df

In [None]:
import plotly.express as px
import plotly.graph_objects as go

# Map stations, credit https://stackoverflow.com/questions/53233228/plot-latitude-longitude-from-csv-in-python-3-6

def make_st_map(df, zoom=5):
    """Build an interactive Mapbox map of station locations and their lon/lat bounding box.

    Credit: https://stackoverflow.com/questions/53233228/plot-latitude-longitude-from-csv-in-python-3-6

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'lat', 'lon' and 'STID' columns; STID is shown as hover text.
    zoom : int or float, default 5
        Initial mapbox zoom level.

    Returns
    -------
    plotly.graph_objects.Figure
        Station markers plus a black outline of the min/max lon/lat box,
        centered on the median station location.
    """
    lon_min, lon_max = df['lon'].min(), df['lon'].max()
    lat_min, lat_max = df['lat'].min(), df['lat'].max()

    # Station markers
    fig = go.Figure(go.Scattermapbox(
        lat=df['lat'],
        lon=df['lon'],
        mode='markers',
        marker=go.scattermapbox.Marker(size=10, opacity=0.7),
        text=df['STID'],
        showlegend=False,
    ))

    # Closed outline of the bounding box: SW -> NW -> NE -> SE -> SW
    fig.add_trace(go.Scattermapbox(
        mode="lines",
        lon=[lon_min, lon_min, lon_max, lon_max, lon_min],
        lat=[lat_min, lat_max, lat_max, lat_min, lat_min],
        line=dict(width=1.5, color="black"),
        showlegend=False,
    ))

    # Basemap, center (median station location), zoom and margins, set in one place
    fig.update_layout(
        mapbox_style="open-street-map",
        mapbox_center={"lat": df['lat'].median(), "lon": df['lon'].median()},
        mapbox_zoom=zoom,
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
    )
    return fig

In [None]:
# Render the station map; as the cell's last expression, Jupyter displays it
make_st_map(df=df)

## Add Raster Background

### Read Data 

In [None]:
# HRRR GeoTIFF (URL path: 20240101, t00z); downloaded once and cached under data/
turl = "https://demo.openwfm.org/web/data/fmda/tif/20240101/hrrr.t00z.wrfprsf00.616.tif"
tpath = "data/hrrr.t00z.wrfprsf00.616.tif"
if not osp.exists(tpath):
    # Argument list (no shell) + check=True so a failed download stops here
    subprocess.run(["wget", "-P", "data", turl], check=True)

ds = gdal.Open(tpath)
if ds is None:
    # gdal.Open returns None instead of raising unless gdal.UseExceptions() is enabled
    raise RuntimeError(f"GDAL could not open {tpath}")
band = ds.GetRasterBand(1)
data = band.ReadAsArray()
gt = ds.GetGeoTransform()
gp = ds.GetProjection()

In [None]:
# Sanity check on what band.ReadAsArray() handed back (expected: a NumPy array)
print(data.__class__)

In [None]:
# Full raster domain, with a colorbar so pixel values can be read off the image
fig, ax = plt.subplots()
im = ax.imshow(data)
fig.colorbar(im, ax=ax)
ax.set_title(osp.basename(tpath));

### Trim with the station bounding box

In [None]:
# Station bounding box in lon/lat, ordered [xmin, ymin, xmax, ymax]
bbox = [df['lon'].min(), df['lat'].min(), df['lon'].max(), df['lat'].max()]

In [None]:
def get_projection_info(ds, epsg=4326):
    """Return what is needed to map points given in `epsg` onto the pixel grid of `ds`.

    Parameters
    ----------
    ds : osgeo.gdal.Dataset
        Single-band raster (e.g. one HRRR band GeoTIFF).
    epsg : int, default 4326
        EPSG code of the input points; 4326 is WGS84 lon/lat.

    Returns
    -------
    (ct, gt_inv) : tuple
        ct : osgeo.osr.CoordinateTransformation from `epsg` to the file's CRS,
            taking points in traditional GIS (x/lon, y/lat) axis order.
        gt_inv : output of gdal.InvGeoTransform(ds.GetGeoTransform()); used with
            gdal.ApplyGeoTransform to turn file-CRS coordinates into fractional
            (column, row) pixel positions. Also reported by gdalinfo.

    Raises
    ------
    NotImplementedError
        If `ds` has more than one raster band.
    """
    if ds.RasterCount > 1:
        # Raise a catchable, descriptive error rather than terminating the process
        raise NotImplementedError('Not Implemented for multiple Raster bands')
    gt = ds.GetGeoTransform()
    gp = ds.GetProjection()

    # Source SRS of the input points (honours the `epsg` argument)
    point_srs = osr.SpatialReference()
    point_srs.ImportFromEPSG(epsg)
    # GDAL>=3: make sure it's x/y (lon/lat), not the authority's lat/lon order
    # see https://trac.osgeo.org/gdal/wiki/rfc73_proj6_wkt2_srsbarn
    point_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    # Target SRS: the raster's own projection
    file_srs = osr.SpatialReference()
    file_srs.ImportFromWkt(gp)

    ct = osr.CoordinateTransformation(point_srs, file_srs)
    gt_inv = gdal.InvGeoTransform(gt)
    return ct, gt_inv

In [None]:
# lon/lat -> raster-CRS transform and inverse geotransform for the HRRR band
ct, gt_inv = get_projection_info(ds, epsg=4326)

In [None]:
# Project the station bbox into the raster CRS. Unless that CRS is plain lon/lat,
# meridians and parallels are generally curved in it, so the lon/lat box need not be an
# axis-aligned rectangle there: projecting only the SW/NE corners can clip it. Project
# points along all four edges instead and keep the extreme x/y values.
n_edge = 21
edge_lons = np.linspace(bbox[0], bbox[2], n_edge).tolist()
edge_lats = np.linspace(bbox[1], bbox[3], n_edge).tolist()
edge_pts = ([(lon, bbox[1]) for lon in edge_lons]    # south edge
            + [(lon, bbox[3]) for lon in edge_lons]  # north edge
            + [(bbox[0], lat) for lat in edge_lats]  # west edge
            + [(bbox[2], lat) for lat in edge_lats]) # east edge
# ct expects (lon, lat) order (traditional GIS axis order set in get_projection_info)
proj_pts = [ct.TransformPoint(lon, lat) for lon, lat in edge_pts]
xmin = min(p[0] for p in proj_pts)
xmax = max(p[0] for p in proj_pts)
ymin = min(p[1] for p in proj_pts)
ymax = max(p[1] for p in proj_pts)

In [None]:
# Map the projected bbox corners to fractional (column, row) pixel positions via the
# inverse geotransform; names are <corner><axis>, e.g. ulx/uly = upper-left column/row.
(ulx, uly), (urx, ury), (llx, lly), (lrx, lry) = (
    gdal.ApplyGeoTransform(gt_inv, corner_x, corner_y)
    for corner_x, corner_y in (
        (xmin, ymax),  # upper-left
        (xmax, ymax),  # upper-right
        (xmin, ymin),  # lower-left
        (xmax, ymin),  # lower-right
    )
)

In [None]:
# Overlay the four bbox corners (pixel space) on the full raster to check the trim window
fig, ax = plt.subplots()
ax.imshow(data)
for (corner_col, corner_row), color, label in (
    ((ulx, uly), 'k', 'UL'),
    ((urx, ury), 'r', 'UR'),
    ((llx, lly), 'b', 'LL'),
    ((lrx, lry), 'y', 'LR'),
):
    ax.scatter(corner_col, corner_row, color=color, label=label)
ax.legend()
ax.set_title("Station bbox corners in pixel space");

In [None]:
# Read a pixel-aligned window covering all four bbox corners, clipped to the raster.
# Start offsets are floored and far edges ceiled so partial edge pixels are kept; the
# clip keeps the request inside the raster, since GDAL rejects out-of-range windows.
col_start = max(0, int(np.floor(min(ulx, llx))))
row_start = max(0, int(np.floor(min(uly, ury))))
col_stop = min(ds.RasterXSize, int(np.ceil(max(urx, lrx))))
row_stop = min(ds.RasterYSize, int(np.ceil(max(lly, lry))))
assert col_stop > col_start and row_stop > row_start, "station bbox does not overlap the raster"

data_subset = ds.ReadAsArray(xoff=col_start,
                             yoff=row_start,
                             xsize=col_stop - col_start,
                             ysize=row_stop - row_start)

In [None]:
# Raster trimmed to the station bbox, with a colorbar for reading values
fig, ax = plt.subplots()
im = ax.imshow(data_subset)
fig.colorbar(im, ax=ax)
ax.set_title(f"{osp.basename(tpath)} (station bbox)");

### Convert to lon/lat

In [None]:
# Helpers: pixel indices -> raster-CRS coordinates -> EPSG (e.g. lat/lon) coordinates
from osgeo import osr, ogr, gdal

def pixel_to_world(geo_matrix, x, y):
    """Convert a pixel (column, row) position to coordinates in the raster's CRS.

    Applies the full GDAL affine geotransform::

        X = GT[0] + x * GT[1] + y * GT[2]
        Y = GT[3] + x * GT[4] + y * GT[5]

    Parameters
    ----------
    geo_matrix : sequence of 6 floats
        Output of ``ds.GetGeoTransform()``:
        GT[0] x-coordinate of the upper-left corner of the upper-left pixel,
        GT[1] w-e pixel resolution / pixel width,
        GT[2] row rotation (typically zero),
        GT[3] y-coordinate of the upper-left corner of the upper-left pixel,
        GT[4] column rotation (typically zero),
        GT[5] n-s pixel resolution / pixel height (negative for a north-up image).
    x : int, float or numpy array
        Pixel column index (0 = leftmost column).
    y : int, float or numpy array
        Pixel row index (0 = top row). Array inputs are broadcast together.

    Returns
    -------
    (X, Y)
        Coordinates, in the raster CRS units, of the upper-left corner of the given
        pixel (pass ``x + 0.5, y + 0.5`` for pixel centres).
        Example: ``pixel_to_world(gt, 0, 0)`` returns ``(GT[0], GT[3])``.
    """
    ul_x = geo_matrix[0]
    x_res = geo_matrix[1]
    row_rot = geo_matrix[2]
    ul_y = geo_matrix[3]
    col_rot = geo_matrix[4]
    y_res = geo_matrix[5]
    # The rotation terms are zero for north-up rasters, which reduces this to plain
    # scale + offset, but including them keeps rotated/sheared rasters correct.
    world_x = ul_x + x * x_res + y * row_rot
    world_y = ul_y + x * col_rot + y * y_res
    return world_x, world_y


def build_transform_inverse(dataset, EPSG):
    """Build a transformation from the raster's own CRS to the CRS given by `EPSG`.

    This converts georeferenced (CRS) coordinates, such as the output of
    pixel_to_world; it does not produce pixel indices.

    Parameters
    ----------
    dataset : osgeo.gdal.Dataset
        Raster whose projection (WKT) is the source CRS.
    EPSG : int
        EPSG code of the target CRS (e.g. 4326 for WGS84).

    Returns
    -------
    osgeo.osr.CoordinateTransformation
    """
    raster_srs = osr.SpatialReference(wkt=dataset.GetProjection())
    target_srs = osr.SpatialReference()
    target_srs.ImportFromEPSG(EPSG)
    # NOTE(review): no SetAxisMappingStrategy call, so GDAL's default axis order applies
    # to the target. Under GDAL>=3 that is presumably lat/lon for EPSG:4326; the lon/lat
    # grid cell relies on this ("order flip is intentional") — confirm before changing.
    return osr.CoordinateTransformation(raster_srs, target_srs)

def world_to_epsg(wx, wy, trans):
    """Transform one georeferenced point with `trans` and return it as an OGR point.

    Parameters
    ----------
    wx, wy : float
        Coordinates in the source CRS of `trans`, e.g. from pixel_to_world.
    trans : osgeo.osr.CoordinateTransformation
        E.g. the output of build_transform_inverse.

    Returns
    -------
    osgeo.ogr.Geometry
        Point geometry holding the transformed coordinates (read with GetX()/GetY()).

    Raises
    ------
    RuntimeError
        If OGR reports a non-zero error code for the transformation.
    """
    point = ogr.Geometry(ogr.wkbPoint)
    point.AddPoint(wx, wy)
    err = point.Transform(trans)
    # Without UseExceptions(), OGR reports failure only through this return code and
    # the point would then hold invalid coordinates, so check it explicitly.
    if err:
        raise RuntimeError(f"OGR failed to transform point ({wx}, {wy}); error code {err}")
    return point

def find_spatial_coordinate_from_pixel(dataset, x, y, transform=None, epsg=4326):
    """Return the coordinates, in `epsg`, of pixel (x, y) of `dataset`.

    Parameters
    ----------
    dataset : osgeo.gdal.Dataset
        Raster from a GeoTIFF file.
    x : int
        Pixel column index.
    y : int
        Pixel row index; the upper-left pixel is usually (0, 0).
    transform : osgeo.osr.CoordinateTransformation, optional
        Output of build_transform_inverse(dataset, epsg). Supply it when calling
        repeatedly to avoid rebuilding it; if None it is built from `epsg`.
    epsg : int, default 4326 (WGS84)
        Target CRS; only used when `transform` is None.

    Returns
    -------
    (float, float)
        The transformed point's (GetX(), GetY()). Their meaning follows the
        transform's axis order — for EPSG:4326 from build_transform_inverse this is
        presumably (lat, lon) under GDAL>=3 (confirm for your GDAL version).
    """
    if transform is None:
        # Build from the dataset actually passed in, not a notebook global
        transform = build_transform_inverse(dataset, epsg)
    world_x, world_y = pixel_to_world(dataset.GetGeoTransform(), x, y)
    point = world_to_epsg(world_x, world_y, transform)
    return point.GetX(), point.GetY()

In [None]:
# Lon/lat of the upper-left corner of every pixel of `data`. Raster-CRS coordinates are
# computed for the whole grid at once with NumPy, then transformed one row at a time with
# TransformPoints instead of building one OGR geometry per pixel.
grid_shape = np.shape(data)
geo = ds.GetGeoTransform()
rows, cols = np.indices(grid_shape)
# Same affine mapping as pixel_to_world (x = column, y = row), including rotation terms
world_x = geo[0] + cols * geo[1] + rows * geo[2]
world_y = geo[3] + cols * geo[4] + rows * geo[5]

# get transformation once and reuse
transform = build_transform_inverse(ds, EPSG=4326)

lons = np.zeros(grid_shape)
lats = np.zeros(grid_shape)
for i in range(grid_shape[0]):
    row_pts = list(zip(world_x[i].tolist(), world_y[i].tolist()))
    row_geo = np.asarray(transform.TransformPoints(row_pts))
    # build_transform_inverse keeps GDAL's default axis mapping for EPSG:4326, so the
    # first output axis is treated as latitude (as the previous per-pixel loop did).
    lats[i] = row_geo[:, 0]
    lons[i] = row_geo[:, 1]

In [None]:
# Select grid cells whose lon/lat fall inside the station bounding box (inclusive).
# TODO(review): the original author flagged this as "THIS DOESNT WORK YET". Boolean
# indexing flattens the 2-D grid, so the subsets below are 1-D and lose row/column
# structure — confirm whether that (or lat/lon axis order) is the open issue.
min_lon, max_lon = df['lon'].min(), df['lon'].max()
min_lat, max_lat = df['lat'].min(), df['lat'].max()

mask_lon = (min_lon <= lons) & (lons <= max_lon)
mask_lat = (min_lat <= lats) & (lats <= max_lat)
mask = mask_lon & mask_lat
lon_subset, lat_subset, vals_subset = lons[mask], lats[mask], data[mask]

print("Shapes:")
for label, arr in (("Lons", lon_subset), ("Lats", lat_subset), ("Vals", vals_subset)):
    print(f"{label}: {arr.shape}")

### Map field values