In [1]:
# general
import numpy as np
import cupy as cp
import xarray as xr
from tqdm.auto import tqdm
#from cmcrameri import cm

# plotting
import matplotlib.pyplot as plt

# io
from tqdm.auto import tqdm
import os
from pathlib import Path
import time
import numbers
import warnings

import sys
sys.path.append('..')

import gstatsim_custom as gsim
from gstatsim_custom import utilities_gpu

In [2]:
%%time
# Open the input grid; xarray reads each variable when it is first accessed.
ds = xr.open_dataset(Path('./bedmap3_mod_500.nc'))
# Alternative input: xr.load_dataset(Path('./bedmap3_mod_1000.nc'))

# 2-D coordinate grids spanning the dataset's x/y axes.
xx, yy = np.meshgrid(ds.x, ds.y)

# add exposed bedrock to conditioning data: cells with mask == 4 get zero thickness
thick_cond = np.where(ds.mask.values == 4, 0, ds.thick_cond.values)

# Cells whose mask value is 1, 2 or 4 keep their bed value; all others become NaN.
ice_rock_msk = np.isin(ds.mask.values, (1, 2, 4))
bed_cond = np.where(ice_rock_msk, ds.surface_topography.values - thick_cond, np.nan)

CPU times: user 1.74 s, sys: 2.27 s, total: 4.01 s
Wall time: 4.37 s


In [3]:
%%time

ds = xr.open_dataset(Path('./bedmap3_mod_500.nc'))

# Host arrays copied to the GPU as CuPy arrays (thickness stays on the host
# until it is needed below).
mask_gpu = cp.asarray(ds.mask.values)
surface_topo_gpu = cp.asarray(ds.surface_topography.values)
thick_cond_cpu = ds.thick_cond.values

# 2-D coordinate grids built from the dataset's x/y axes as CuPy arrays.
xx_gpu, yy_gpu = cp.meshgrid(cp.asarray(ds.x.values), cp.asarray(ds.y.values))

# add exposed bedrock to conditioning data: cells with mask == 4 get zero thickness
thick_cond_gpu = cp.where(mask_gpu == 4, 0, cp.asarray(thick_cond_cpu))

# Cells whose mask value is 1, 2 or 4 keep their bed value; all others become NaN.
ice_rock_msk_gpu = (mask_gpu == 1) | (mask_gpu == 2) | (mask_gpu == 4)
bed_cond_gpu = cp.where(ice_rock_msk_gpu, surface_topo_gpu - thick_cond_gpu, cp.nan)


CPU times: user 1 s, sys: 2.17 s, total: 3.17 s
Wall time: 3.92 s


In [4]:
%%time 

# Conditioning cells are the ones where bed_cond holds a value (not NaN).
cond_msk = np.logical_not(np.isnan(bed_cond))
coordinates = (xx[cond_msk], yy[cond_msk])
x_cond, y_cond = coordinates
data_cond = bed_cond[cond_msk]

# Residual = conditioning bed minus the trend stored in the dataset.
trend = ds['trend'].values
res_cond = bed_cond - trend

# Project helper; its two return values are kept as the transformed residual
# and the fitted transform object (names as used downstream).
res_norm, nst_trans = gsim.utilities.gaussian_transformation(res_cond, cond_msk)

# Variogram parameters are read from a separate file, fully loaded into memory.
dsv = xr.load_dataset(Path('./continental_variogram_500.nc'))

# Parameters are passed through as stored in the file (arrays, not scalars);
# nugget is fixed at 0 and the model type is 'matern'.
vario = {
    'azimuth': dsv['azimuth'].values,
    'nugget': 0,
    'major_range': dsv['major_range'].values,
    'minor_range': dsv['minor_range'].values,
    'sill': dsv['sill'].values,
    's': dsv['smooth'].values,
    'vtype': 'matern',
}

CPU times: user 1.56 s, sys: 6.66 s, total: 8.22 s
Wall time: 8.88 s


In [5]:
%%time 

# Conditioning cells are the ones where bed_cond_gpu holds a value (not NaN).
cond_msk_gpu = cp.logical_not(cp.isnan(bed_cond_gpu))
coordinates_gpu = (xx_gpu[cond_msk_gpu], yy_gpu[cond_msk_gpu])
x_cond_gpu, y_cond_gpu = coordinates_gpu
data_cond_gpu = bed_cond_gpu[cond_msk_gpu]

# Trend copied to the GPU so the residual is computed between CuPy arrays.
trend_gpu = cp.asarray(ds['trend'].values)
res_cond_gpu = bed_cond_gpu - trend_gpu

# Project helper (GPU variant); its two return values are kept as the
# transformed residual and the fitted transform object.
res_norm_gpu, nst_trans_gpu = utilities_gpu.gaussian_transformation_gpu(res_cond_gpu, cond_msk_gpu)


# Variogram parameters are read from a separate file, fully loaded into memory.
dsv = xr.load_dataset(Path('./continental_variogram_500.nc'))

# Parameters exactly as stored in the file (host NumPy arrays); nugget is
# fixed at 0 and the model type is 'matern'.
vario_gpu = {
    'azimuth': dsv['azimuth'].values,
    'nugget': 0,
    'major_range': dsv['major_range'].values,
    'minor_range': dsv['minor_range'].values,
    'sill': dsv['sill'].values,
    's': dsv['smooth'].values,
    'vtype': 'matern',
}

# Fix variogram parameters
def extract_scalar(arr):
    """Collapse a scalar or array-like parameter to a single Python float.

    Parameters
    ----------
    arr : scalar, sequence or array-like
        A plain number, a list/tuple of numbers, or an object that already
        exposes ``.size`` (e.g. a NumPy array or NumPy scalar).

    Returns
    -------
    float
        The value itself when the input holds exactly one element, otherwise
        the median of all elements with NaNs ignored.

    Raises
    ------
    ValueError
        If the input is empty, since there is no representative value.
    """
    # Plain Python numbers and sequences have no ``.size``; convert them so
    # lists/tuples get the same median treatment as arrays instead of
    # silently returning their first element.
    if not hasattr(arr, 'size'):
        arr = np.asarray(arr, dtype=float)

    if arr.size == 0:
        raise ValueError("extract_scalar() needs at least one value, got an empty input")
    if arr.size == 1:
        return float(arr.item())
    # Use median for robustness; NaN entries are ignored.
    return float(np.nanmedian(arr))

# Scalar version of the variogram: every parameter read from `dsv` is reduced
# to one float by extract_scalar; nugget and model type are fixed constants.
vario_gpu_scaler = {
    'azimuth': extract_scalar(dsv['azimuth'].values),
    'nugget': 0,
    'major_range': extract_scalar(dsv['major_range'].values),
    'minor_range': extract_scalar(dsv['minor_range'].values),
    'sill': extract_scalar(dsv['sill'].values),
    's': extract_scalar(dsv['smooth'].values),
    'vtype': 'matern',
}

CPU times: user 4.33 s, sys: 4.45 s, total: 8.78 s
Wall time: 8.87 s


In [6]:
from gstatsim_custom import interpolate_gpu

In [7]:
%%time 

# Run the GPU sequential simulation on the residual field, time it, and write
# both the simulated field and the elapsed time to .npy files.

# Copy the surface topography to the GPU (trend_gpu is defined in an earlier cell).
surface_topo_gpu = cp.asarray(ds.surface_topography.values)

# Bounds tuple: (-9999, surface - trend), computed on the GPU.
# NOTE(review): bounds_gpu is not passed to sgs_gpu below — confirm whether
# the simulation is meant to be bounded.
bounds_gpu = (-9999, surface_topo_gpu - trend_gpu)

# Batch size handed to sgs_gpu. Earlier notes suggested 8192-32768 (with 16384
# as a default) for a high-end GPU like the B200, but this run uses 4096.
# NOTE(review): reconcile the recommendation with the value actually set; tune against available GPU memory.
optimal_batch_size = 4096
print(f"B200 maximized batch size: {optimal_batch_size}")

# Wall-clock timer covers only the sgs_gpu call.
print('starting optimized simulation')
tic = time.time()

sim_gpu = interpolate_gpu.sgs_gpu(
    xx_gpu, yy_gpu, res_cond_gpu, vario_gpu_scaler,
    radius=50e3,
    num_points=32,
    seed=0,
    batch_size=optimal_batch_size,  # 4096 in this run (see note above)
    quiet=False,
    sim_mask=ice_rock_msk_gpu,
    max_memory_gb=150.0  # NOTE(review): presumably a memory budget in GB — confirm against sgs_gpu
)


toc = time.time()
print(f'Optimized simulation completed in {toc-tic:.2f} seconds')
# Save the simulated field and the elapsed wall time (seconds) as .npy files.
cp.save(Path('./stationary_sim_500_ep50.npy'), sim_gpu)
np.save(Path('./time_500_ep50.npy'),toc-tic)



B200 maximized batch size: 4096
starting optimized simulation


SGS:  12%|█▏        | 5718016/48387085 [13:31<1:40:03, 7107.43pts/s]

KeyboardInterrupt: 

In [8]:
# Reload the saved simulation from disk. np.load returns a host-side NumPy
# array, so despite its name `sim_gpu` is not a CuPy array after this cell.
sim_gpu = np.load(Path('./stationary_sim_500_ep50.npy'))

In [9]:
# Sanity check: True when at least one cell of the loaded field is not NaN.
has_non_nan = bool(np.any(~np.isnan(sim_gpu)))
has_non_nan

True

SGS:  12%|█▏        | 5718016/48387085 [13:50<1:40:03, 7107.43pts/s]

In [10]:
from cmcrameri import cm

# Row/column index window used to trim the grid before plotting.
# NOTE(review): the *2 factor presumably rescales indices chosen on a coarser
# grid to this one — confirm.
ilow = 1000*2
ihigh = 5800*2
jlow = 700*2
jhigh = 6200*2

# Bring the GPU arrays back to the host for matplotlib. `sim_gpu` is used
# as-is because it was reloaded with np.load and is already a NumPy array.
xx = xx_gpu.get()
yy = yy_gpu.get()
sim = sim_gpu#.get()
trend = trend_gpu.get()

xx_trim = xx[ilow:ihigh,jlow:jhigh]
yy_trim = yy[ilow:ihigh,jlow:jhigh]
sim_trim = sim[ilow:ihigh,jlow:jhigh]
trend_trim = trend[ilow:ihigh,jlow:jhigh]

# Plot simulated residual + trend; coordinates are divided by 1000 to match
# the km axis labels.
fig, ax = plt.subplots(figsize=(13,10))
im = ax.pcolormesh(xx_trim/1000, yy_trim/1000, sim_trim+trend_trim, cmap=cm.batlowW)
ax.axis('scaled')
ax.set_xlabel('X [km]')
ax.set_ylabel('Y [km]')
ax.set_title('SGS simulation+trend, matern covariance, 500 ep50 stationary, anisotropic')
fig.colorbar(im, ax=ax, pad=0.03, aspect=40, shrink=0.7, label='bed elevation [meters]')
# The figure shows the stationary (scalar-variogram) run, so the file name
# says "stationary" to match the title and the saved .npy file.
fig.savefig(Path('./full_simulation_stationary_500_ep50.png'), dpi=300, bbox_inches='tight')
plt.show()

OSError: [Errno 107] Transport endpoint is not connected: '/opt/conda/lib/python3.12/.matplotlib-repo'