# Spatial-temporal Correlated Random Fields (SCRF)

Author: Cindy Chiao  
Date: 12/29/2021

The purpose of this notebook is to generate spatially and temporally correlated random fields to be used to perturb the mean prediction result of GARD downscaling method. First, we use ERA5 observation data to find the appropriate length scales of correlation both spatially and temporally. Then, we use the correlation length scales to generate random fields covering the global domain in the resolution of ERA5. The package [gstools](https://geostat-framework.readthedocs.io/projects/gstools/en/stable/#pip) is used heavily in this process. A SCRF is generated for precipitation and another for temperature (will be used to perturb both tmin and tmax prediction result). 

Ideally, the entire available observation time series of the global domain would be used in determining the correlation length scales, and the SCRF of the entire future prediction period would be generated as one contiguous dataset. However, this proves to be prohibitive in terms of computation time due to the single threaded nature of the gstools algorithm. Thus, random subsamples of the observation time series were used to find the correlation length, and the SCRF was generated in 10 year long chunks. 

To find a representative spatial correlation length, we calculate the average spatial correlation lengths of 100 20x20 degree maps of 365 day time series. The 20x20 degree maps are constrained to areas of the globe that contain major landmass, avoiding the areas where it's majority ocean. 

To find a representative temporal correlation length, we use 10,000 samples of 365 day time series, again constrained to the areas containing major landmass. 

The spatial/temporal length scales for tmin and tmax are then averaged to be the length scale used to generate SCRF for temperature. 

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import time
import random 
import xarray as xr
import numpy as np
import pandas as pd
import gstools as gs
from cmip6_downscaling.workflows.paths import make_scrf_path
from cmip6_downscaling.methods.gard import generate_scrf
from cmip6_downscaling.data.observations import get_obs

random.seed(20211228)

In [3]:
from carbonplan_trace.tiles import tiles
from carbonplan_trace.v1 import utils


In [4]:
# import dask
# from dask.distributed import Client
# from dask_gateway import Gateway

# client = Client(n_workers=8)
# client

## Finding correlation length scales

In [5]:
# get all 20x20 degree tiles that covers at least one 10x10 degree tile that has major land mass 

expanded_tiles = []
for tile in tiles:
    lat, lon = utils.get_lat_lon_tags_from_tile_path(tile)
    min_lat, max_lat, min_lon, max_lon = utils.parse_bounding_box_from_lat_lon_tags(lat, lon)
    
    for i in [-1, 0]:
        for j in [-1, 0]: 
            lat_tag, lon_tag = utils.get_lat_lon_tags_from_bounding_box(max_lat + (i * 10.), min_lon + (j * 10.))
            expanded_tiles.append(f'{lat_tag}_{lon_tag}')
            
expanded_tiles = list(set(expanded_tiles))
expanded_tiles = [t for t in expanded_tiles if '190W' not in t and '170E' not in t and '80N' not in t and '80S' not in t]

In [6]:
def convert_long3_to_long1(long3):
    # see https://confluence.ecmwf.int/pages/viewpage.action?pageId=149337515
    long1 = (long3 + 180) % 360 - 180
    return long1

In [7]:
variables = ['tasmax', 'tasmin', 'pr']
seasonality_period = 31
temporal_scaler = 1000.0

sample_length = 365
n_samples_temporal = 10000
n_tiles_spatial = 100

In [8]:
def get_spatial_length(data):
    fields = data.values
    print(np.mean(fields))
    bin_center, gamma = gs.vario_estimate(
        pos=(data.lon.values, data.lat.values),
        field=fields,
        latlon=True,
        mesh_type='structured',
    )
    spatial = gs.Gaussian(dim=2, latlon=True, rescale=gs.EARTH_RADIUS)
    spatial.fit_variogram(bin_center, gamma, sill=np.mean(np.var(fields, axis=(1, 2))))

    return spatial.len_scale

In [9]:
# spatial length scale 

for v in variables:
    print(v)
    fname = f'{v}_spatial_length_scale.csv'
    if os.path.exists(fname):
        df = pd.read_csv(fname)
        df = df.loc[df.spatial_length_scale > 1]
        print(df.spatial_length_scale.mean())
    else:
        data = get_obs(
            obs='ERA5',
            train_period_start=1980,
            train_period_end=2020,
            variables=v,
            chunking_approach=None,
        )[v]

        if v == 'pr':
            data = data * 1e6

        # go from 0-360 to -180-180 longitude 
        data['lon'] = convert_long3_to_long1(data.lon)
        data = data.reindex(lon=sorted(data.lon.values))

        # detrend 
        seasonality = (
            data.rolling({'time': seasonality_period}, center=True, min_periods=1)
            .mean()
            .groupby('time.dayofyear')
            .mean()
        )
        detrended = data.groupby("time.dayofyear") - seasonality
        detrended = detrended.transpose('time', 'lon', 'lat')
        possible_time_starts = len(detrended.time) - sample_length

        spatial_length_scale = []
        chosen_tiles = random.sample(expanded_tiles, k=n_tiles_spatial)
        for tile in chosen_tiles:
            lat, lon = utils.get_lat_lon_tags_from_tile_path(tile)
            min_lat, max_lat, min_lon, max_lon = utils.parse_bounding_box_from_lat_lon_tags(lat, lon)
            max_lat += 10
            max_lon += 10 
            t = random.randint(a=0, b=possible_time_starts)
            sub = detrended.sel(lat=slice(max_lat, min_lat), lon=slice(min_lon, max_lon)).isel(time=slice(t, t+sample_length))
            # spatial_length_scale.append(client.persist(get_spatial_length(sub), retries=1))
            l = get_spatial_length(sub)
            spatial_length_scale.append(l)
            print(tile, l)

        df = pd.DataFrame({'tile': chosen_tiles, 'spatial_length_scale': spatial_length_scale})
        df.to_csv(f'{v}_spatial_length_scale.csv')
        df = df.loc[df.spatial_length_scale > 1]
        print(df.spatial_length_scale.mean())

tasmax
437.36955028922074
tasmin
419.73912697665384
pr
404.29331338341586


In [8]:
sample_length = 100
temporal_scaler = 100.

In [9]:
# temporal length scale 
for v in variables:
    print(v)
    data = get_obs(
        obs='ERA5',
        train_period_start=1980,
        train_period_end=2020,
        variables=v,
        chunking_approach=None,
    )[v]

    if v == 'pr':
        data = data * 1e3

    # go from 0-360 to -180-180 longitude 
    data['lon'] = convert_long3_to_long1(data.lon)
    data = data.reindex(lon=sorted(data.lon.values))
    
    # detrend 
    print('detrending')
    seasonality = (
        data.rolling({'time': seasonality_period}, center=True, min_periods=1)
        .mean()
        .groupby('time.dayofyear')
        .mean()
    )
    detrended = data.groupby("time.dayofyear") - seasonality
    detrended = detrended.transpose('time', 'lon', 'lat')
    possible_time_starts = len(detrended.time) - sample_length
    detrended = detrended.stack(point=['lat', 'lon'])
    
    print('building samples')
    points = []

    chosen_tiles = random.choices(tiles, k=n_samples_temporal)
    ii = random.choices(np.arange(40), k=n_samples_temporal)
    jj = random.choices(np.arange(40), k=n_samples_temporal)

    lat_lon_tags = [utils.get_lat_lon_tags_from_tile_path(tile) for tile in chosen_tiles]
    bounding_boxes = [
        utils.parse_bounding_box_from_lat_lon_tags(lat, lon)
        for lat, lon in lat_lon_tags
    ]

    for bounding_box, i, j in zip(bounding_boxes, ii, jj):
        min_lat, max_lat, min_lon, max_lon = bounding_box
        lat = min_lat + i * 0.25
        lon = min_lon + j * 0.25
        points.append((lat, lon))

    points.sort(key=lambda u: u[0])
    sub = detrended.sel(point=points).load()
    
    t_starts = np.array(random.choices(np.arange(possible_time_starts), k=n_samples_temporal))
    t_ends = t_starts + sample_length

    fields = []
    for i, (start, end) in enumerate(zip(t_starts, t_ends)):
        f = sub.isel(point=i, time=slice(start, end)).values
        fields.append(f)

    print('finding correlation length')
    t = np.arange(sample_length) / temporal_scaler
    bin_center, gamma = gs.vario_estimate(pos=t, field=fields, mesh_type='structured')
    temporal = gs.Gaussian(dim=1)
    temporal.fit_variogram(bin_center, gamma, sill=np.mean(np.var(fields, axis=1)))
    
    print(v, temporal.len_scale * temporal_scaler)

tasmax


    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


detrending
building samples
finding correlation length
tasmax 3.725572253986237
tasmin


    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


detrending
building samples
finding correlation length
tasmin 3.9400152223556937
pr


    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]


detrending
building samples
finding correlation length
pr 2.0457947933307676


## Generating SRCF

In [None]:
ss = {
    'tasmax': 437.36955028922074,
    'tasmin': 419.73912697665384, 
    'pr': 404.29331338341586
}

ts = {
    'tasmax': 3.725572253986237,
    'tasmin': 3.9400152223556937, 
    'pr': 2.0457947933307676
}

In [3]:
ss = 500
ts = 10

model = gs.Gaussian(dim=3, var=1.0, len_scale=[ss, ss, ts])
srf = gs.SRF(model, seed=0)

In [4]:
step = 25
nx = 100
ny = 100
nt = 365*30+8

x = np.arange(0, nx*step, step)
y = np.arange(0, ny*step, step)
t = np.arange(0, nt)

t1 = time.time()
field = srf.structured((x, y, t))
t2 = time.time()

print((t2-t1)/60)

62.33059689203898


In [5]:
step = 25
nx = 100
ny = 100
nt = 365*30+8

x = np.arange(0, nx*step, step)
y = np.arange(0, ny*step, step)
t = np.arange(0, nt)

len(x) * len(y) * len(t)

109580000

In [4]:
# 10 year chunks 
# 

step = 25
nx = 360*4
ny = 180*4
nt = 365

x = np.arange(0, nx*step, step)
y = np.arange(0, ny*step, step)
t = np.arange(0, nt)

t1 = time.time()
field = srf.structured((x, y, t))
t2 = time.time()

print((t2-t1)/60)

217.89220894575118


In [18]:
import matplotlib.pyplot as plt
import random

In [None]:
for _ in np.arange(30):
    i = random.randint(0, 365)
    plt.imshow(field[:, :, _])
    plt.show()
    plt.close()
    
# import numpy as np
# import cv2
# size = 720*16//9, 720
# duration = 2
# fps = 25
# out = cv2.VideoWriter('output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), fps, (size[1], size[0]), False)
# for _ in range(fps * duration):
#     data = np.random.randint(0, 256, size, dtype='uint8')
#     out.write(data)
# out.release()

In [6]:
step = 25
nx = 360*4
ny = 180*4
nt = 365

x = np.arange(0, nx*step, step)
y = np.arange(0, ny*step, step)
t = np.arange(0, nt)

len(x) * len(y) * len(t)

378432000

In [7]:
378432000 / 109580000

3.453476911845227

In [8]:
217.89220894575118 / 62.33059689203898

3.4957503988475507

In [9]:
217.89220894575118 * 30 / 60

108.94610447287559

In [4]:
ds = xr.DataArray(
    np.random.rand(360*4, 180*4, 365*30+8),
    dims=['lon', 'lat', 'time'],
    coords=[np.arange(360*4), np.arange(180*4), np.arange(365*30+8)]
)

In [6]:
ds.nbytes / 1e9

90.8900352