In [34]:
import os
import xarray as xr
import pandas as pd
import numpy as np
import pygrib
import netCDF4
import scipy
import matplotlib.pyplot as plt
# Root directory under which all downloaded input data is cached.
DATA_DIR = "/root/data_downloads"

# Per-day ERA5 NetCDF files live in these two directories
# (pressure-level fields in "multi", single-level fields in "single").
era5_multi_level_dir = "/".join((DATA_DIR, "era5", "multi"))
era5_single_level_dir = "/".join((DATA_DIR, "era5", "single"))


## Get IBTrACS labels for current year

In [3]:
# Fetch the IBTrACS "since 1980" NetCDF file into the local data directory,
# skipping the download when a local copy already exists.

# Object key of the IBTrACS file inside the S3 bucket.
s3key = "data/ibtracs/IBTrACS.since1980.v04r00.nc"
ibtracs_dir = f"{DATA_DIR}/ibtracs"
# Define an IPython alias around `mkdir -p` and use it to create the target
# directory (no error if it already exists).
%alias mkdatadir mkdir -p %l
%mkdatadir $ibtracs_dir
# Alternative (full-archive) destination kept for reference:
# dest = "/data_downloads/ibtracs/IBTrACS.ALL.v04r00.nc"
ibtracs_dest = f"{ibtracs_dir}/IBTrACS.since1980.v04r00.nc"
if os.path.exists(ibtracs_dest):
    print(f"{ibtracs_dest} already downloaded")
else:
    # NOTE(review): neither `boto3` nor `bucket` is imported/defined in any
    # visible cell, so this branch raises NameError on a fresh kernel whenever
    # the file is not already cached. Add `import boto3` to the import cell and
    # define `bucket` in the configuration cell before relying on it.
    s3 = boto3.client('s3')
    s3.download_file(bucket, s3key, ibtracs_dest)
    print(f"downloaded {ibtracs_dest}")

/root/data_downloads/ibtracs/IBTrACS.since1980.v04r00.nc already downloaded


In [45]:
# Open the IBTrACS NetCDF file at `ibtracs_dest`.
#
# Plan: the per-storm records need to be combined by hour. Storms that share
# an hour are OR-ed together to produce one global mask of TC labels, which is
# then included in the output netcdf files.
ibtracs_nc = netCDF4.Dataset(ibtracs_dest)
def ibtracs_time_to_pd_timestamp(ibtracs_time):
    """Convert an IBTrACS time value to pandas via ``pd.to_datetime``.

    Parameters
    ----------
    ibtracs_time
        Anything ``pd.to_datetime`` accepts; the notebook passes the decoded
        ``iso_time`` string of a storm record.

    Returns
    -------
    The result of ``pd.to_datetime(ibtracs_time)`` (a ``pd.Timestamp`` for a
    single date string).
    """
    return pd.to_datetime(ibtracs_time)


# Index every storm record in the requested year range by its timestamp string.
ibtracs_roci = ibtracs_nc.variables["bom_roci"]

# Inclusive year range to keep, and how often (in storms) to print progress.
start_year = 2018
end_year = 2018
print_freq = 100

# Maps timestamp string -> list of (storm index, time index) pairs recorded
# at that timestamp.
timestamp_storms = {}

for storm_idx, storm_times in enumerate(ibtracs_nc.variables["iso_time"][:]):
    for time_idx, time in enumerate(storm_times):
        timestamp_str = time.tobytes().decode("utf-8")
        # An empty string marks the end of this storm's records.
        if not timestamp_str:
            break
        timestamp = ibtracs_time_to_pd_timestamp(timestamp_str)
        # Ignore records that fall outside the requested years.
        if timestamp.year < start_year or timestamp.year > end_year:
            continue
        storm_tup = (storm_idx, time_idx)
        timestamp_storms.setdefault(timestamp_str, []).append(storm_tup)
    if storm_idx % print_freq == 0:
        print(f"{storm_idx} of {ibtracs_nc.variables['iso_time'].shape[0]}")


0 of 4687
100 of 4687
200 of 4687
300 of 4687
400 of 4687
500 of 4687
600 of 4687
700 of 4687
800 of 4687
900 of 4687
1000 of 4687
1100 of 4687
1200 of 4687
1300 of 4687
1400 of 4687
1500 of 4687
1600 of 4687
1700 of 4687
1800 of 4687
1900 of 4687
2000 of 4687
2100 of 4687
2200 of 4687
2300 of 4687
2400 of 4687
2500 of 4687
2600 of 4687
2700 of 4687
2800 of 4687
2900 of 4687
3000 of 4687
3100 of 4687
3200 of 4687
3300 of 4687
3400 of 4687
3500 of 4687
3600 of 4687
3700 of 4687
3800 of 4687
3900 of 4687
4000 of 4687
4100 of 4687
4200 of 4687
4300 of 4687
4400 of 4687
4500 of 4687
4600 of 4687


In [31]:
# Report how many distinct timestamps in the selected year range have at least
# one storm record.
print(f"Storms found at {len(timestamp_storms)} timestamps in the {start_year}-{end_year} year range")

Storms found at 2759 timestamps in the 2018-2018 year range


## Process ERA5 data and output augmented netcdf file for each day
Note: the `time` variable in this dataset has units of hours since 1900-01-01 00:00:00.0.
It can be converted to a pandas timestamp with
```
pd.to_datetime(timestamp, unit='h', origin='1900-01-01 00:00:00.0')
```
The reverse conversion can be done via
```
origin = pd.Timestamp('1900-01-01 00:00:00')
hours_since_1900 = int((pandas_timestamp - origin).total_seconds() / 3600)
```


In [63]:
def _era5_field(nc, var_name):
    """Wrap one variable of an open ERA5 netCDF4 dataset as a (time, lat, lon) DataArray.

    The coordinates are read from the dataset's own ``time``, ``latitude`` and
    ``longitude`` variables.
    """
    return xr.DataArray(nc.variables[var_name], coords={
        'time': nc.variables["time"][:],
        'lat': nc.variables["latitude"][:],
        'lon': nc.variables["longitude"][:]},
        dims=('time', 'lat', 'lon'))


# Day to inspect; the ERA5 files are named <year>-<month>-<day>.nc.
year = "2018"
month = "01"
day = "01"
# filename=f'{era5_single_level_dir}/{year}-{month}-{day}.grib'
single_file = f'{era5_single_level_dir}/{year}-{month}-{day}.nc'
multi_file = f'{era5_multi_level_dir}/{year}-{month}-{day}.nc'

# Both datasets are left open: later cells read the grid from `single_nc`.
single_nc = netCDF4.Dataset(single_file)
multi_nc = netCDF4.Dataset(multi_file)
single_vars = single_nc.variables.keys()
multi_vars = multi_nc.variables.keys()

# Fields from the single-level file.
mtpr = _era5_field(single_nc, "mtpr")
sp = _era5_field(single_nc, "sp")
tcwv = _era5_field(single_nc, "tcwv")

# Fields from the multi-level file.
pv850 = _era5_field(multi_nc, "pv")
u850 = _era5_field(multi_nc, "u")
v850 = _era5_field(multi_nc, "v")

# Wind speed magnitude from the two wind components.
ws850 = (u850**2 + v850**2)**0.5
print(v850, ws850)


<xarray.DataArray 'v' (time: 8, lat: 721, lon: 1440)>
array([[[ 1.19968920e-04,  1.19968920e-04,  1.19968920e-04, ...,
          1.19968920e-04,  1.19968920e-04,  1.19968920e-04],
        [-2.07003972e+00, -2.07580619e+00, -2.08041936e+00, ...,
         -2.05389363e+00, -2.06081339e+00, -2.06542656e+00],
        [-1.80708908e+00, -1.81170225e+00, -1.81631542e+00, ...,
         -1.79440286e+00, -1.79901603e+00, -1.80362920e+00],
        ...,
        [-6.46293013e+00, -6.47676963e+00, -6.48945585e+00, ...,
         -6.42256490e+00, -6.43640440e+00, -6.45024391e+00],
        [-6.70396822e+00, -6.71434785e+00, -6.72588077e+00, ...,
         -6.67167603e+00, -6.68205566e+00, -6.69358859e+00],
        [ 1.19968920e-04,  1.19968920e-04,  1.19968920e-04, ...,
          1.19968920e-04,  1.19968920e-04,  1.19968920e-04]],

       [[ 1.19968920e-04,  1.19968920e-04,  1.19968920e-04, ...,
          1.19968920e-04,  1.19968920e-04,  1.19968920e-04],
        [-8.47549878e-01, -8.59082801e-01, -8.694

## Compute labels

In [79]:
def haversine_distance(lat, lon, lat_center, lon_center):
    """
    Great-circle (haversine) distance from every grid point to one center point.

    Parameters
    ----------
    lat : 1-D array-like
        Grid latitudes in degrees, length ``n_lat``.
    lon : 1-D array-like
        Grid longitudes in degrees, length ``n_lon``.
    lat_center, lon_center : scalar
        Latitude and longitude of the center point, in degrees.

    Returns
    -------
    Matrix of dim (lat, lon) with the distance in nautical miles between the
    center and each (lat[i], lon[j]) grid point.
    """
    earth_radius_km = 6371  # Radius of the Earth in kilometers
    km_per_nautical_mile = 1.852

    lat1, lon1, lat2, lon2 = map(np.radians, [lat, lon, lat_center, lon_center])
    # Haversine formula; lat varies along axis 0 and lon along axis 1.
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = np.sin(dlat[:,None] / 2) ** 2 + np.cos(lat1[:,None]) * np.cos(lat2) * np.sin(dlon[None,:] / 2) ** 2
    # Floating-point rounding can push `a` marginally above 1 for (near-)
    # antipodal points, which would make arcsin return NaN. Clamp it to the
    # valid range first.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))

    # Calculate the distance in kilometers, then convert to nautical miles.
    distance = c * earth_radius_km
    return distance / km_per_nautical_mile


def _era5_field_at(nc, var_name, out_name, hours_since_1900):
    """Read one variable from an open ERA5 netCDF4 dataset and select a single time step."""
    return xr.DataArray(nc.variables[var_name], name=out_name, coords={
        'time': nc.variables["time"][:],
        'lat': nc.variables["latitude"][:],
        'lon': nc.variables["longitude"][:]},
        dims=('time', 'lat', 'lon')).sel(time=hours_since_1900)


def make_labeled_dataset(timestamp, timestep_labels) -> "xr.Dataset":
    """
    Take a timestamp and the labels for the timestamp, combine them with the
    ERA5 data for that timestamp, and return an xarray Dataset. These datasets
    can later be concatenated along the time dimension before being
    written to a file.

    We try to match the format of the climatenet augmented dataset.
    The ClimateDatasetLabeled class concatenates files across the
    time dimension, so we can put as many timesteps as we want in
    each file. Here we take the timestep and find the matching
    ERA5 file and extract the variables for the corresponding timestamp.

    Only some variables are provided here.  If more are desired, they
    can be additionally downloaded from copernicus.

    Augmented climatenet dataset:
    dimensions(sizes): lat(768), lon(1152), time(1)
    variables(dimensions): float64 lat(lat), float64 lon(lon),
        float32 TMQ(time, lat, lon), float32 U850(time, lat, lon),
        float32 V850(time, lat, lon), float32 UBOT(time, lat, lon),
        float32 VBOT(time, lat, lon), float32 QREFHT(time, lat, lon),
        float32 PS(time, lat, lon), float32 PSL(time, lat, lon),
        float32 T200(time, lat, lon), float32 T500(time, lat, lon),
        float32 PRECT(time, lat, lon), float32 TS(time, lat, lon),
        float32 TREFHT(time, lat, lon), float32 Z1000(time, lat, lon),
        float32 Z200(time, lat, lon), float32 ZBOT(time, lat, lon),
        <class 'str'> time(time), int64 LABELS(lat, lon)

    Parameters
    ----------
    timestamp
        Anything ``pd.to_datetime`` accepts. It selects both the daily ERA5
        files (``<year>-<month>-<day>.nc``) and the hour within them.
    timestep_labels
        2-D (lat, lon) label array on the ERA5 grid.

    Returns
    -------
    xr.Dataset
        Variables PRECT, PSL, TMQ, VRT850, WS850 (each (lat, lon) with a scalar
        ``time`` coordinate) and LABELS (time, lat, lon).

    Raises
    ------
    Whatever ``netCDF4.Dataset`` raises when a daily file cannot be opened, and
    whatever ``DataArray.sel`` raises when the requested hour is not present in
    the file's ``time`` variable (presumably ``KeyError`` -- TODO confirm).
    Sub-hour timestamps are truncated to the whole hour before the lookup.

    NOTE(review): the field variables lose their ``time`` dimension (scalar
    ``.sel``) while LABELS keeps it, which is the reverse of the climatenet
    layout quoted above -- confirm this is intended before concatenating.
    NOTE(review): PSL is filled from the ERA5 ``sp`` variable and VRT850 from
    ``pv``; confirm that these are the intended source variables.
    """
    origin = pd.Timestamp('1900-01-01 00:00:00')
    pd_timestamp = pd.to_datetime(timestamp)
    # The ERA5 time axis is expressed in whole hours since 1900-01-01.
    hours_since_1900 = int((pd_timestamp - origin).total_seconds() / 3600)
    year = pd_timestamp.year
    month = pd_timestamp.strftime('%m')
    day = pd_timestamp.strftime('%d')
    # filename=f'{era5_single_level_dir}/{year}-{month}-{day}.grib'
    single_file = f'{era5_single_level_dir}/{year}-{month}-{day}.nc'
    multi_file = f'{era5_multi_level_dir}/{year}-{month}-{day}.nc'

    # Context managers close both files even if a variable or the requested
    # hour is missing. All data is read into memory inside the block, so the
    # returned Dataset does not depend on the files staying open.
    with netCDF4.Dataset(single_file) as single_nc, netCDF4.Dataset(multi_file) as multi_nc:
        mtpr = _era5_field_at(single_nc, "mtpr", "PRECT", hours_since_1900)
        sp = _era5_field_at(single_nc, "sp", "PSL", hours_since_1900)
        tcwv = _era5_field_at(single_nc, "tcwv", "TMQ", hours_since_1900)

        pv850 = _era5_field_at(multi_nc, "pv", "VRT850", hours_since_1900)
        u850 = _era5_field_at(multi_nc, "u", "U850", hours_since_1900)
        v850 = _era5_field_at(multi_nc, "v", "V850", hours_since_1900)

        # Wind speed magnitude from the two wind components.
        ws850 = ((u850**2 + v850**2)**0.5).rename("WS850")

        # Labels get an explicit length-1 time axis on the multi-level grid.
        labels = xr.DataArray(timestep_labels[None,:,:], name="LABELS", coords={
            'time': np.array([hours_since_1900]),
            'lat': multi_nc.variables["latitude"][:],
            'lon': multi_nc.variables["longitude"][:]},
            dims=('time', 'lat', 'lon'))

        dataset = xr.Dataset(
            data_vars={
                "PRECT": mtpr,
                "PSL": sp,
                "TMQ": tcwv,
                "VRT850": pv850,
                "WS850": ws850,
                "LABELS": labels
            },
        )
    return dataset

# max_roci = np.argmax(ibtracs_nc.variables["bom_roci"])
# max_roci = np.unravel_index(max_roci, ibtracs_nc.variables["bom_roci"].shape)
# print("roci max", np.max(ibtracs.variables["bom_roci"]), max_roci)


# need to treat fill values as none or numpy gets unhappy
def _mask_fill_values(nc_var):
    """Read a netCDF4 variable in full and mask every element equal to its ``_FillValue``."""
    values = nc_var[:]
    return np.ma.masked_array(values, mask=(values == nc_var._FillValue), fill_value=None)


masked_lat = _mask_fill_values(ibtracs_nc.variables["lat"])
masked_lon = _mask_fill_values(ibtracs_nc.variables["lon"])
masked_roci = _mask_fill_values(ibtracs_nc.variables["bom_roci"])


# generate a mask per storm
# Grid size comes from the ERA5 single-level file opened in an earlier cell.
n_lat = single_nc.dimensions["latitude"].size
n_lon = single_nc.dimensions["longitude"].size
# The original printed `tc_labels.shape`, but `tc_labels` is not defined
# anywhere in the notebook (NameError on a fresh kernel). Print the label grid
# shape directly instead.
print((n_lat, n_lon))
print(single_nc.variables["latitude"].shape)
print_freq = 100

# NOTE(review): this array is allocated but never written or read in this
# cell, and is one int8 grid per timestamp (~2.9 GB for 2759 x 721 x 1440).
# Drop it if nothing downstream needs it.
labeled_days = np.zeros((len(timestamp_storms), n_lat, n_lon), dtype=np.int8)

# Read the grid coordinates once instead of re-reading them from the file for
# every storm.
era5_lat = np.asarray(single_nc.variables["latitude"][:])
era5_lon = np.asarray(single_nc.variables["longitude"][:])

for iteration, (timestamp, storms) in enumerate(timestamp_storms.items()):
    # Union (logical OR) of the masks of all storms recorded at this timestamp.
    timestep_labels = np.zeros((n_lat, n_lon), dtype=np.int8)
    for storm_id, timestep in storms:
        storm_lat = masked_lat[storm_id, timestep]
        storm_lon = masked_lon[storm_id, timestep]
        storm_roci = masked_roci[storm_id, timestep]
        # A record with a missing position or radius cannot yield a mask; skip
        # it explicitly rather than comparing against masked values.
        if any(np.ma.is_masked(value) for value in (storm_lat, storm_lon, storm_roci)):
            continue

        distances = haversine_distance(era5_lat, era5_lon, storm_lat, storm_lon)
        # Label every grid point within the storm's `bom_roci` radius.
        storm_labels = np.where(distances <= storm_roci, 1, 0)
        timestep_labels = timestep_labels | storm_labels
    if iteration % print_freq == 0:
        print(f"{iteration} of {len(timestamp_storms)}")
    dataset = make_labeled_dataset(timestamp, timestep_labels)
    print(dataset)
    # Only the first timestamp is processed for now (inspection run).
    break
    # print(np.max(timestep_labels))
    
# # storm_id = 2403
# timestep = 5

# distances = haversine_distance(era5_single.variables["latitude"], era5_single.variables["longitude"], masked_lat[storm_id,timestep], masked_lon[storm_id,timestep])

# # print(distances.shape) 
# tc_label = np.where(distances <= masked_roci[storm_id,timestep], 1, 0)

(721, 1440)
(721,)
0 of 2759
<xarray.Dataset>
Dimensions:  (time: 1, lat: 721, lon: 1440)
Coordinates:
  * time     (time) int64 1034376
  * lat      (lat) float32 90.0 89.75 89.5 89.25 ... -89.25 -89.5 -89.75 -90.0
  * lon      (lon) float32 0.0 0.25 0.5 0.75 1.0 ... 359.0 359.2 359.5 359.8
Data variables:
    PRECT    (lat, lon) float64 9.452e-06 9.452e-06 ... -8.674e-19 -8.674e-19
    PSL      (lat, lon) float64 1.008e+05 1.008e+05 ... 6.938e+04 6.938e+04
    TMQ      (lat, lon) float64 2.157 2.157 2.157 2.157 ... 1.201 1.201 1.201
    VRT850   (lat, lon) float64 1.121e-06 1.121e-06 ... -1.719e-06 -1.719e-06
    WS850    (lat, lon) float64 0.0004198 0.0004198 ... 0.0004198 0.0004198
    LABELS   (time, lat, lon) int64 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0


In [None]:

# Filled-contour plot of the first time step of each field, one figure per field.
for field in (tcwv, pv850, mtpr, ws850):
    field.isel(time=0).plot.contourf()
    plt.show()
# plt.colorbar()
