第一部分：模型训练 (Part 1: Model training)

1. Connect to Google Drive

In [1]:
# Mount Google Drive at /content/drive; the input, model and output paths used
# by the later cells all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


2. Install the bayesnf library

In [2]:
pip install bayesnf

Collecting bayesnf
  Downloading bayesnf-0.1.3-py3-none-any.whl.metadata (4.3 kB)
Collecting jaxtyping (from bayesnf)
  Downloading jaxtyping-0.3.2-py3-none-any.whl.metadata (7.0 kB)
Collecting wadler-lindig>=0.1.3 (from jaxtyping->bayesnf)
  Downloading wadler_lindig-0.1.7-py3-none-any.whl.metadata (17 kB)
Downloading bayesnf-0.1.3-py3-none-any.whl (25 kB)
Downloading jaxtyping-0.3.2-py3-none-any.whl (55 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.4/55.4 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading wadler_lindig-0.1.7-py3-none-any.whl (20 kB)
Installing collected packages: wadler-lindig, jaxtyping, bayesnf
Successfully installed bayesnf-0.1.3 jaxtyping-0.3.2 wadler-lindig-0.1.7


In [4]:
import warnings
# NOTE(review): this silences every Python warning for the rest of the session,
# including pandas parsing/deprecation warnings — consider a narrower filter.
warnings.simplefilter('ignore')
import jax
import time
import pandas as pd
import numpy as np
from bayesnf.spatiotemporal import BayesianNeuralFieldMAP

In [5]:
# Load the training table (first CSV column is the index).
# The `datetime` column is converted below as a day count relative to
# 1899-12-30 (the Excel serial-date origin), so it must reach that step as a
# number. It is therefore NOT passed through `parse_dates`: a serial number is
# not a parsable date string, and had parsing ever succeeded the unit='D'
# conversion below would have received nanosecond integers and overflowed.
df_train = pd.read_csv('/content/drive/MyDrive/xlc/input/train_CMEMS_SST.csv', index_col=0)

# Anything that is not a number becomes NaN here (and NaT after conversion).
excel_days = pd.to_numeric(df_train['datetime'], errors='coerce')
df_train['datetime'] = pd.to_datetime(excel_days, unit='D', origin='1899-12-30')
df_train.shape

(242091, 5)

In [6]:
# Configure the BayesianNeuralFieldMAP estimator for the training table.
model = BayesianNeuralFieldMAP(
    # Input columns and regression target.
    feature_cols=['datetime', 'longitude', 'latitude'],
    target_col='sla',
    # Time axis: `datetime` is handled as an index with daily frequency.
    timetype='index',
    freq='D',
    # Only the spatial coordinates are standardized.
    standardize=['longitude', 'latitude'],
    # Seasonal terms: week and month periods with 2 and 4 harmonics respectively.
    seasonality_periods=['W', 'M'],
    num_seasonal_harmonics=[2, 4],
    # Network size and observation likelihood.
    width=256,
    depth=2,
    observation_model='NORMAL',
)

In [8]:
import os
import time
import threading
import pynvml
import jax

# Use on-demand GPU memory allocation instead of pre-allocating a large block.
# NOTE(review): this presumably only takes effect if it is set before JAX
# initializes its GPU backend — keep it ahead of the first JAX computation.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# ------------------- Initialize GPU monitoring -------------------
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # Select GPU 0


def monitor_gpu(interval=60, stop_event=None):
    """Print GPU memory usage every `interval` seconds.

    Args:
        interval: Seconds to wait between two readings.
        stop_event: Optional `threading.Event`. When given, the loop returns
            as soon as the event is set (checked while waiting, so shutdown is
            immediate). When omitted, the loop never ends, which is only
            acceptable inside a daemon thread.
    """
    while True:
        meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        used = meminfo.used / 1024**3
        total = meminfo.total / 1024**3
        print(f"[GPU Memory] {used:.2f} GB / {total:.2f} GB")
        if stop_event is None:
            time.sleep(interval)
        elif stop_event.wait(interval):
            # Event was set while waiting: stop reporting.
            return


# Start the monitoring thread. The stop event lets this cell shut the thread
# down once training ends, so it does not keep printing into later cells and
# re-running the cell does not accumulate extra monitor threads.
stop_monitoring = threading.Event()
monitor_thread = threading.Thread(
    target=monitor_gpu, args=(60, stop_monitoring), daemon=True
)
monitor_thread.start()

# ------------------- Train the model (single fit call) -------------------
start_time = time.time()
try:
    model = model.fit(
        df_train,
        seed=jax.random.PRNGKey(0),  # Random seed
        ensemble_size=1,             # Train 1 model at a time
        learning_rate=0.005,
        num_epochs=5000              # Total number of training epochs
    )
    end_time = time.time()
finally:
    # Always stop the monitor, even when training fails or is interrupted.
    stop_monitoring.set()
    monitor_thread.join(timeout=5)

print("Total training time:", end_time - start_time, "s")



[GPU Memory] 0.38 GB / 15.00 GB
[GPU Memory] 1.94 GB / 15.00 GB
[GPU Memory] 1.94 GB / 15.00 GB
[GPU Memory] 1.94 GB / 15.00 GB
[GPU Memory] 1.94 GB / 15.00 GB
[GPU Memory] 1.94 GB / 15.00 GB
[GPU Memory] 1.94 GB / 15.00 GB
[GPU Memory] 1.94 GB / 15.00 GB
[GPU Memory] 1.94 GB / 15.00 GB
[GPU Memory] 1.94 GB / 15.00 GB
[GPU Memory] 1.94 GB / 15.00 GB
Total training time: 318.5527732372284 s


In [None]:
import os
import cloudpickle

# Save the fitted model to Drive so the prediction section can reload it
# without retraining.
model_path = '/content/drive/MyDrive/xlc/model/model_CMEMS_SST.pkl'
# Create the target folder first: otherwise open() raises FileNotFoundError
# and the training result is not saved.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
with open(model_path, 'wb') as f:
    cloudpickle.dump(model, f)

第二部分：模型加载与预测 (Part 2: Model loading and prediction)

In [9]:
# Load the fitted model back from the pickle stored on Drive.
# Unpickling can execute arbitrary code: only load files you produced yourself.
import cloudpickle
with open('/content/drive/MyDrive/xlc/model/model_CMEMS_SST.pkl', 'rb') as pickled_model:
    model = cloudpickle.load(pickled_model)

In [10]:
# Load the validation table and predict the 0.01 ... 0.99 quantiles for every row.
# numpy is imported here as well so this section also runs on a fresh kernel
# where only the model is reloaded.
import numpy as np
import pandas as pd

# `datetime` is converted below as a day count relative to 1899-12-30 (the
# Excel serial-date origin), so it is read as a plain number rather than via
# `parse_dates` (which cannot parse serial numbers as date strings).
df_test = pd.read_csv('/content/drive/MyDrive/xlc/input/validation_CMEMS_SST.csv', index_col=0)
df_test['datetime'] = pd.to_numeric(df_test['datetime'], errors='coerce')
df_test['datetime'] = pd.to_datetime(df_test['datetime'], unit='D', origin='1899-12-30')

# Missing values in the last column are replaced by the sentinel 9999; the MAE
# cell later masks those rows out.
# NOTE(review): the last column is assumed to be the target — confirm.
last_col = df_test.columns[-1]
df_test[last_col] = df_test[last_col].fillna(9999)

# Build the levels from integers: np.arange with a 0.01 float step yields
# values such as 0.060000000000000005, which later leak into the CSV headers.
quantiles = tuple(np.arange(1, 100) / 100)
yhat, yhat_quantiles = model.predict(df_test, quantiles=quantiles)

[GPU Memory] 0.38 GB / 15.00 GB


In [11]:
# Stack the per-quantile prediction vectors into one (n_rows, n_quantiles) matrix.
# np.asarray converts each vector in a single step instead of materialising a
# Python list of floats per quantile; dtype=float64 keeps the values identical
# to what the list-based conversion produced.
yhat_matrix = np.column_stack([np.asarray(q, dtype=np.float64) for q in yhat_quantiles])
# Observed target values as a plain array.
gt_matrix = df_test.sla.to_numpy()
yhat_matrix.shape

(360254, 99)

In [18]:
# Optional sanity check: mean absolute error of one predicted quantile against
# the observed values.
# Column i of yhat_matrix follows the order of `quantiles` (0.01 ... 0.99), so
# column 47 is the 0.48 quantile; the median would be column 49.
# NOTE(review): confirm that column 47 is the intended point estimate.
mask = df_test[last_col].ne(9999)               # rows whose target was not filled with the 9999 sentinel
diff = df_test[last_col].sub(yhat_matrix[:, 47])
valid_diff = diff.loc[mask]
mae = valid_diff.abs().mean()
print("MAE:", mae)


MAE: 0.1588131631405801


In [None]:
import os

# Save yhat_matrix to CSV
output_csv_path = '/content/drive/MyDrive/xlc/output/SSIM_CMEMS_SST (2).csv'
# Make sure the output folder exists before writing.
os.makedirs(os.path.dirname(output_csv_path), exist_ok=True)

# One column per quantile level. The level is rounded to two decimals before
# formatting so headers read 'Quantile_0.06' rather than carrying float noise
# such as 'Quantile_0.060000000000000005'.
col_names = [f'Quantile_{round(float(q), 2)}' for q in quantiles]
df_yhat = pd.DataFrame(yhat_matrix, columns=col_names)
df_yhat.to_csv(output_csv_path, index=False)

print(f"yhat_matrix has been saved to {output_csv_path}")