<a href="https://colab.research.google.com/github/nvnsudharsan/era5_to_prism/blob/main/era5_to_prism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Directory under which Google Drive is exposed; later cells read and write
# files below /content/drive/MyDrive.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install cdsapi xarray netCDF4 rioxarray geopandas matplotlib
!pip install -q xarray zarr gcsfs

Collecting cdsapi
  Downloading cdsapi-0.7.6-py2.py3-none-any.whl.metadata (3.0 kB)
Collecting netCDF4
  Downloading netCDF4-1.7.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting rioxarray
  Downloading rioxarray-0.19.0-py3-none-any.whl.metadata (5.5 kB)
Collecting ecmwf-datastores-client (from cdsapi)
  Downloading ecmwf_datastores_client-0.1.0-py3-none-any.whl.metadata (21 kB)
Collecting cftime (from netCDF4)
  Downloading cftime-1.6.4.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.7 kB)
Collecting rasterio>=1.4.3 (from rioxarray)
  Downloading rasterio-1.4.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.1 kB)
Collecting affine (from rasterio>=1.4.3->rioxarray)
  Downloading affine-2.4.0-py3-none-any.whl.metadata (4.0 kB)
Collecting cligj>=0.5 (from rasterio>=1.4.3->rioxarray)
  Downloading cligj-0.7.2-py3-none-any.whl.metadata (5.0 kB)
Collecting click-plugins (from rasterio>=1.4.3->rioxar

In [None]:
import xarray as xr

# Public ARCO-ERA5 zarr store on Google Cloud Storage, read anonymously.
ARCO_ERA5_STORE = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'

ds = xr.open_zarr(
    ARCO_ERA5_STORE,
    chunks=None,  # no dask chunking
    storage_options=dict(token="anon"),
)

# Keep only the time span that the store's own attributes declare as valid.
valid_period = slice(ds.attrs['valid_time_start'], ds.attrs['valid_time_stop'])
ds = ds.sel(time=valid_period)

In [None]:
# Space/time window around Austin, TX.
# - latitude: the slice runs from 31.0 down to 29.5, i.e. north to south
#   (presumably because the store's latitude coordinate is descending - confirm).
# - longitude: expressed on a 0-360 axis, so 98W becomes 360 - 98 = 262.
austin_window = dict(
    time=slice("2015-01-01", "2024-12-31"),
    latitude=slice(31.0, 29.5),
    longitude=slice(262.0, 263.5),
)
tp = ds['total_precipitation'].sel(**austin_window)

In [None]:
# Scale by 1000 to convert to mm, then add the values up per calendar day.
tp_mm = 1000 * tp
# The variable name "tp_mm_day" is what later cells use to read the saved file.
tp_daily = tp_mm.resample(time='1D').sum().rename("tp_mm_day")

In [None]:
# Persist the daily ERA5 totals on Drive; the alignment cell reloads this file.
era5_daily_path = "/content/drive/MyDrive/era5_tp_daily_austin_2015_2024_zarr_gcs.nc"
tp_daily.to_netcdf(era5_daily_path)

In [None]:
!pip install earthengine-api geemap --upgrade

Collecting earthengine-api
  Downloading earthengine_api-1.5.19-py3-none-any.whl.metadata (2.1 kB)
Collecting jedi>=0.16 (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Downloading earthengine_api-1.5.19-py3-none-any.whl (462 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m462.6/462.6 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m42.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: jedi, earthengine-api
  Attempting uninstall: earthengine-api
    Found existing installation: earthengine-api 1.5.18
    Uninstalling earthengine-api-1.5.18:
      Successfully uninstalled earthengine-api-1.5.18
Successfully installed earthengine-api-1.5.19 jedi-0.19.2


In [None]:
import ee

# Cloud project the Earth Engine session is initialised against.
EE_PROJECT = 'ee-my-naveensudharsan'

# Interactive step: run this to obtain the authorisation code, then initialise.
ee.Authenticate()
ee.Initialize(project=EE_PROJECT)

In [None]:
import geemap
from datetime import datetime

# Time range for the PRISM export.
# filterDate() includes the start date but EXCLUDES the end date, so the end is
# set to the day after the last wanted day. With '2024-12-31' as the end, the
# final day of 2024 would be missing while the ERA5 selection includes it.
start = '2015-01-01'
end = '2025-01-01'

# Austin region as [west, south, east, north] in degrees (-180..180 longitudes).
region = ee.Geometry.Rectangle([-98.0, 29.5, -96.5, 31.0])

# PRISM daily precipitation ('ppt' band only), limited to the window and region.
collection = ee.ImageCollection('OREGONSTATE/PRISM/AN81d') \
    .filterDate(start, end) \
    .filterBounds(region) \
    .select('ppt')


def format_band(img):
    """Rename the image's selected band to its date formatted as YYYYMMdd."""
    date_str = img.date().format('YYYYMMdd')
    return img.rename(date_str)


# Convert the collection to a single multi-band image (one band per day).
prism_bands = collection.map(format_band)
prism_image = prism_bands.toBands()

# Export to Drive as a GeoTIFF; the task runs asynchronously on Earth Engine.
task = ee.batch.Export.image.toDrive(
    image=prism_image.clip(region),
    description='PRISM_Precip_2015_2024_Austin',
    folder='era5_downscaled',  # GDrive folder
    fileNamePrefix='prism_ppt_austin_2015_2024',
    region=region,
    scale=800,  # output pixel size in metres
    maxPixels=1e13,
    fileFormat='GeoTIFF'
)
task.start()

In [None]:
import time

# Poll the Earth Engine export once a minute while it is still active.
while task.active():
    print('Exporting...', task.status()['state'])
    time.sleep(60)

# The loop also ends when the task fails or is cancelled, so check the final
# state explicitly instead of letting the next cell read a missing/stale file.
final_status = task.status()
print('Export finished with state:', final_status['state'])
if final_status['state'] != 'COMPLETED':
    raise RuntimeError(
        "PRISM export did not complete "
        f"(state={final_status['state']}): "
        f"{final_status.get('error_message', 'no error message reported')}"
    )

Exporting... RUNNING
Exporting... RUNNING
Exporting... RUNNING
Exporting... RUNNING
Exporting... RUNNING


In [None]:
import rioxarray as rxr
import xarray as xr

# Load PRISM high-res reference (one band per day; nodata masked to NaN).
prism = rxr.open_rasterio('/content/drive/MyDrive/era5_downscaled/prism_ppt_austin_2015_2024.tif', masked=True).squeeze()

# Load the ERA5 daily file saved earlier.
era5 = xr.open_dataset('/content/drive/MyDrive/era5_tp_daily_austin_2015_2024_zarr_gcs.nc')['tp_mm_day']

# The ERA5 subset was selected on a 0-360 longitude axis (262-263.5), whereas
# the PRISM export region uses -180..180 longitudes (-98..-96.5). Wrap the ERA5
# longitudes to -180..180 before tagging the data as EPSG:4326; otherwise the
# two grids do not overlap and the reprojection has no source data to sample.
era5_lon180 = era5.assign_coords(
    longitude=((era5.longitude + 180) % 360) - 180
).sortby('longitude')

# Match CRS and align to the PRISM grid (reproject_match's default resampling).
era5_aligned = era5_lon180.rio.write_crs("EPSG:4326").rio.reproject_match(prism)

# Save aligned ERA5 to Drive.
era5_aligned.to_netcdf('/content/drive/MyDrive/era5_downscaled/era5_tp_aligned_to_prism.nc')


In [None]:
import pandas as pd

# ERA5 field that the previous cell reprojected onto the PRISM grid.
era5 = xr.open_dataset('/content/drive/MyDrive/era5_downscaled/era5_tp_aligned_to_prism.nc')['tp_mm_day']

# PRISM multi-band GeoTIFF: one band per exported day, nodata masked to NaN.
prism = rxr.open_rasterio('/content/drive/MyDrive/era5_downscaled/prism_ppt_austin_2015_2024.tif', masked=True)

# Turn the band axis into a daily time axis starting on 2015-01-01.
# This assumes the bands are consecutive days with no gaps - TODO confirm.
n_days = prism.sizes['band']
dates = pd.date_range(start='2015-01-01', periods=n_days, freq='D')
prism = prism.rename(band='time').assign_coords(time=dates)

In [None]:
# Drop any length-1 dimensions together with their coordinates.
prism = prism.squeeze(drop=True)
# Remove the 'spatial_ref' coordinate when it is present; no-op otherwise.
prism = prism.drop_vars('spatial_ref', errors='ignore')

In [None]:
# 2015-2022 is used for training/validation, 2023-2024 is held out for testing.
trainval = prism.sel(time=slice('2015-01-01', '2022-12-31'))
test = prism.sel(time=slice('2023-01-01', '2024-12-31'))

# Chronological 80/20 split of the train/val period: the first 80% of the
# dates go to training, the remainder to validation.
train_dates = trainval.time.to_index()
split_idx = int(0.8 * len(train_dates))
train = trainval.isel(time=slice(None, split_idx))
val = trainval.isel(time=slice(split_idx, None))

# Pick exactly the same dates out of the aligned ERA5 field.
era5_train, era5_val, era5_test = (
    era5.sel(time=part.time) for part in (train, val, test)
)

In [None]:
# Give every split the same variable name so the saved files share one schema.
for _split in (train, val, test, era5_train, era5_val, era5_test):
    _split.name = "tp_mm_day"

In [None]:
# Output file -> split, written in order (PRISM targets first, then ERA5 inputs).
split_outputs = {
    '/content/drive/MyDrive/era5_downscaled/prism_train.nc': train,
    '/content/drive/MyDrive/era5_downscaled/prism_val.nc': val,
    '/content/drive/MyDrive/era5_downscaled/prism_test.nc': test,
    '/content/drive/MyDrive/era5_downscaled/era5_train.nc': era5_train,
    '/content/drive/MyDrive/era5_downscaled/era5_val.nc': era5_val,
    '/content/drive/MyDrive/era5_downscaled/era5_test.nc': era5_test,
}
for out_path, split_da in split_outputs.items():
    split_da.to_netcdf(out_path)

In [None]:
!pip install wandb



In [None]:
# Authenticate with Weights & Biases. On a fresh runtime this prompts for an
# API key; the call's return value is displayed as the cell output.
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnvnsudharsan[0m ([33mnvnsudharsan-the-university-of-texas-at-austin[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
!pip install xdownscale

Collecting xdownscale
  Downloading xdownscale-1.0.1-py3-none-any.whl.metadata (2.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->xdownscale)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->xdownscale)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->xdownscale)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->xdownscale)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->xdownscale)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->xdownscale)
  Downloading nvidia_cufft_

In [None]:
from xdownscale import Downscaler
import wandb
import torch
import os

import numpy as np
import xarray as xr

# All models you want to try
model_list = [
    "srcnn", "fsrcnn", "lapsr", "carnm", "falsra", "falsrb", "ssresnet",
    "carn", "oisrrk2", "mdsr", "san", "rcan", "unet", "dlgsanet", "dpmn",
    "safmn", "dpt", "distgssr", "swin"
]


def _test_metrics(predicted, true):
    """Return RMSE, MAE, bias and Pearson correlation as a dict of floats.

    Only positions where both arrays are finite contribute, so NaN cells
    (e.g. masked pixels) do not turn every metric into NaN. All metrics are
    NaN when no finite pair exists; the correlation is NaN with fewer than
    two pairs.
    """
    predicted, true = np.broadcast_arrays(
        np.asarray(predicted, dtype=float), np.asarray(true, dtype=float)
    )
    valid = np.isfinite(predicted) & np.isfinite(true)
    p, t = predicted[valid], true[valid]

    nan = float("nan")
    if p.size == 0:
        return {"test_rmse": nan, "test_mae": nan, "test_bias": nan, "test_corr": nan}

    err = p - t
    corr = float(np.corrcoef(p, t)[0, 1]) if p.size > 1 else nan
    return {
        "test_rmse": float(np.sqrt(np.mean(err ** 2))),
        "test_mae": float(np.mean(np.abs(err))),
        "test_bias": float(np.mean(err)),
        "test_corr": corr,
    }


# Load training data (train + val files combined; Downscaler re-splits them
# internally through val_split below).
input_da = xr.open_mfdataset([
    '/content/drive/MyDrive/era5_downscaled/era5_train.nc',
    '/content/drive/MyDrive/era5_downscaled/era5_val.nc'
]).to_array().squeeze()

target_da = xr.open_mfdataset([
    '/content/drive/MyDrive/era5_downscaled/prism_train.nc',
    '/content/drive/MyDrive/era5_downscaled/prism_val.nc'
]).to_array().squeeze()

# Load test data separately
input_test = xr.open_dataset('/content/drive/MyDrive/era5_downscaled/era5_test.nc').to_array().squeeze()
target_test = xr.open_dataset('/content/drive/MyDrive/era5_downscaled/prism_test.nc').to_array().squeeze()

# Output folder for model weights
model_dir = "/content/drive/MyDrive/era5_downscaled/model_weights"
os.makedirs(model_dir, exist_ok=True)

# NOTE(review): the recorded run of this cell ended with
# "ValueError: too many values to unpack (expected 2)". Presumably Downscaler
# expects 2-D (y, x) arrays while these inputs also carry a time dimension -
# confirm against the xdownscale documentation before relying on this loop.
# NOTE: `ds` below rebinds the name used for the ERA5 dataset in the first data
# cell; re-run that cell before re-running the ERA5 selection cell.
for model_name in model_list:
    run_name = f"{model_name}_run"
    wandb.init(project="xdownscale-austin", name=run_name, reinit=True)

    try:
        # Train model using the Downscaler class
        ds = Downscaler(
            input_da=input_da,
            target_da=target_da,
            model_name=model_name,
            patch_size=64,
            batch_size=16,
            epochs=50,
            val_split=0.2,
            test_split=0.0,  # We'll use separate test data
            device='cuda' if torch.cuda.is_available() else 'cpu',
            use_wandb=True,
            patience=10
        )

        # Save the weights held by the trained model.
        # NOTE(review): whether these are the best-epoch weights depends on
        # Downscaler's early-stopping behaviour - confirm.
        torch.save(ds.model.state_dict(), f"{model_dir}/{model_name}_best.pth")

        # Predict on test set
        pred = ds.predict(input_test)
        pred.name = "tp_mm_day_pred"
        pred.to_netcdf(f"/content/drive/MyDrive/era5_downscaled/pred_{model_name}_test.nc")

        # Log test stats
        wandb.log({
            "test_mean": float(pred.mean().values),
            "test_std": float(pred.std().values)
        })

        # Log NaN-safe evaluation metrics against the PRISM test target
        wandb.log(_test_metrics(pred.values, target_test.values))
    except Exception:
        # Close the run as failed so it is not left open, then surface the error.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()



ValueError: too many values to unpack (expected 2)

In [None]:
# Display the ERA5 test input to inspect its dimensions and coordinates.
input_test

In [None]:
# Display the PRISM test target to inspect its dimensions and coordinates.
target_test