### **1. Load Libraries**

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import cartopy.feature as cfeature
from cartopy.util import add_cyclic_point
from matplotlib.ticker import MaxNLocator
import warnings
import os

from matplotlib.cm import ScalarMappable
from matplotlib.ticker import MaxNLocator

import pandas as pd


### **2. Pre-setup**

In [None]:
# Silence specific warnings, matched by the start of their message text.
_ignored_warning_messages = (
    "facecolor will have no effect as it has been defined as \"never\".",
    "This figure includes Axes that are not compatible with tight_layout",
    "Glyph 8322",
    "Warning: 'partition' will ignore the 'mask' of the MaskedArray.",
)
for _message in _ignored_warning_messages:
    warnings.filterwarnings("ignore", message=_message)

# Global matplotlib text settings.
plt.rcParams.update({
    "font.family": ["DejaVu Sans", "Arial"],
    "font.size": 12,  # base font size
    "axes.unicode_minus": False,  # avoid minus-sign rendering problems
})

### **3. Define datetime and files**

In [None]:
# Target day, and the time of day used to pick the closest time step.
target_date_str = "2020-08-15"
target_date = np.datetime64(target_date_str)
target_time = np.datetime64(f"{target_date_str}T13:30")

# Year/month tokens taken from target_date_str so that the monthly CAMS and ML
# file names follow the target date instead of being hard-coded.
target_year, target_month = target_date_str.split("-")[:2]

ct_file = f"/data3/interns/NRT_CO2_Emission_Map_Project/MingjuanZhang_work/CarbonTracker/xCO2_1330LST_global/CT2022.xCO2_1330_glb3x2_{target_date_str}.nc"
cams_file = f"/data3/interns/NRT_CO2_Emission_Map_Project/MingjuanZhang_work/CAMS/cams_data/cams73_latest_co2_col_surface_inst_{target_year}{target_month}.nc"
ml_file = f"/data3/interns/NRT_CO2_Emission_Map_Project/ML_XCO2/XCO2_prediction_full/monthly_xco2_full_{target_year}_{target_month}.npy"

# Report missing inputs up front; this only prints a warning and does not
# stop execution.
for file_path in [ct_file, cams_file, ml_file]:
    if not os.path.exists(file_path):
        print(f"Warning: File does not exist -> {file_path}")

### **4. Read Data (Daily)**

In [None]:
# CarbonTracker
# The context manager closes the file even if reading a variable fails.
with xr.open_dataset(ct_file) as ds_ct:
    # Coordinates are looked up as lon/lat, falling back to longitude/latitude.
    ct_lon_name = 'lon' if 'lon' in ds_ct.variables else 'longitude'
    ct_lat_name = 'lat' if 'lat' in ds_ct.variables else 'latitude'
    ct_lon = ds_ct[ct_lon_name].values
    ct_lat = ds_ct[ct_lat_name].values
    # Only the first time step of the file is used.
    ct_xco2 = ds_ct['xco2'].isel(time=0).values
# Append a cyclic longitude point to the field and its longitude coordinate.
ct_xco2, ct_lon = add_cyclic_point(ct_xco2, coord=ct_lon)

# CAMS
# The context manager closes the file even if reading a variable fails.
with xr.open_dataset(cams_file) as ds_cams:
    # Coordinates are looked up as lon/lat, falling back to longitude/latitude.
    cams_lon_name = 'lon' if 'lon' in ds_cams.variables else 'longitude'
    cams_lat_name = 'lat' if 'lat' in ds_cams.variables else 'latitude'
    cams_lon = ds_cams[cams_lon_name].values
    cams_lat = ds_cams[cams_lat_name].values
    cams_time = ds_cams['time'].values
    # Time steps inside the half-open window [target_date, target_date + 1 day).
    cams_day_mask = (cams_time >= target_date) & (cams_time < target_date + np.timedelta64(1, 'D'))
    if np.any(cams_day_mask):
        cams_day = ds_cams['XCO2'].isel(time=cams_day_mask)
        # Within the target day, keep the step closest to target_time.
        # NOTE(review): target_time is compared directly with the file's time
        # axis; confirm both use the same time reference.
        closest_idx = np.argmin(np.abs(cams_day['time'].values - target_time))
        cams_xco2_raw = cams_day.isel(time=closest_idx).values
    else:
        # Best-effort fallback: plot the first time step rather than failing.
        print(f"Warning: No data found for {target_date_str} in CAMS, using first time step")
        cams_xco2_raw = ds_cams['XCO2'].isel(time=0).values
# NOTE(review): presumably converts a mole fraction to ppm; confirm the file's units.
cams_xco2 = cams_xco2_raw * 1000000
# Append a cyclic longitude point to the field and its longitude coordinate.
cams_xco2, cams_lon = add_cyclic_point(cams_xco2, coord=cams_lon)

# Our model: grid every record of the ML file (no time filtering here).
ml_struct = np.load(ml_file)
# dtype.names is None for a non-structured array; treat that as "no fields" so
# the check below raises the intended ValueError instead of a TypeError.
ml_field_names = ml_struct.dtype.names or ()
if 'xco2_pred' in ml_field_names:
    ml_xco2_pred = ml_struct['xco2_pred']
    ml_lon = ml_struct['lon']
    ml_lat = ml_struct['lat']
else:
    raise ValueError("'xco2_pred' field not found in ML data file, please check data structure.")
# Build a (lat, lon) grid from the distinct coordinate values; cells without a
# record stay NaN.
unique_lon = np.unique(ml_lon)
unique_lat = np.unique(ml_lat)
ml_grid = np.full((len(unique_lat), len(unique_lon)), np.nan)
# Position of each record's coordinates within the sorted unique axes.
lon_indices = np.searchsorted(unique_lon, ml_lon)
lat_indices = np.searchsorted(unique_lat, ml_lat)
valid = (lon_indices < len(unique_lon)) & (lat_indices < len(unique_lat)) & ~np.isnan(ml_xco2_pred)
ml_grid[lat_indices[valid], lon_indices[valid]] = ml_xco2_pred[valid]
# Append a cyclic longitude point to the grid and its longitude coordinate.
ml_grid, unique_lon = add_cyclic_point(ml_grid, coord=unique_lon)

### **5. Match ML predictions to the target date and time**

In [None]:
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
# only use it on files from a trusted source.
ml_struct = np.load(ml_file, allow_pickle=True)

# Field names of the structured array (empty when the array is not structured).
available_fields = ml_struct.dtype.names or []

required_fields = {"xco2_pred", "lon", "lat"}
if not required_fields <= set(available_fields):
    raise ValueError(f"ML file missing fields. Need {required_fields}, got {ml_struct.dtype.names}")

ml_xco2_pred_all = ml_struct["xco2_pred"]
ml_lon_all = ml_struct["lon"]
ml_lat_all = ml_struct["lat"]

# Use the first candidate name that exists in the file as the time field.
time_field_candidates = ["time", "datetime", "date", "timestamp", "Time", "DATE", "Datetime"]
time_field = None
for candidate in time_field_candidates:
    if candidate in available_fields:
        time_field = candidate
        break

if time_field is None:
    raise ValueError(
        f"Your ML .npy has no time field ({time_field_candidates}). "
        "So I cannot select a specific day/time. Please check the file structure."
    )

# Values that cannot be parsed become NaT and never match the day window.
ml_time = pd.to_datetime(ml_struct[time_field], errors="coerce")
ml_time64 = ml_time.to_numpy(dtype="datetime64[ns]")

# Records inside the half-open window [target_date, target_date + 1 day).
next_day = target_date + np.timedelta64(1, "D")
day_mask = (ml_time64 >= target_date) & (ml_time64 < next_day)
if not day_mask.any():
    raise ValueError(f"No ML data found on {target_date_str} in field '{time_field}'.")

# Among that day's timestamps, pick the one closest to target_time.
day_times = ml_time64[day_mask]
closest_i = np.abs(day_times - target_time).argmin()
selected_time = day_times[closest_i]

# Keep the records at the selected timestamp, drop incomplete rows, and average
# records that share the same (lat, lon).
time_mask = (ml_time64 == selected_time)
df_ml = (
    pd.DataFrame({
        "lon": ml_lon_all[time_mask],
        "lat": ml_lat_all[time_mask],
        "xco2_pred": ml_xco2_pred_all[time_mask],
    })
    .dropna(subset=["lon", "lat", "xco2_pred"])
    .groupby(["lat", "lon"], as_index=False)["xco2_pred"]
    .mean()
)

ml_lon = df_ml["lon"].to_numpy()
ml_lat = df_ml["lat"].to_numpy()
ml_xco2_pred = df_ml["xco2_pred"].to_numpy()

print(f"[ML] Using time = {pd.Timestamp(selected_time).isoformat()} (closest to {pd.Timestamp(target_time).isoformat()})")
print(f"[ML] Points used = {len(ml_xco2_pred)}")

# Build a (lat, lon) grid from the distinct coordinates; empty cells stay NaN.
unique_lon = np.unique(ml_lon)
unique_lat = np.unique(ml_lat)
lon_indices = np.searchsorted(unique_lon, ml_lon)
lat_indices = np.searchsorted(unique_lat, ml_lat)

lon_in_range = (lon_indices >= 0) & (lon_indices < unique_lon.size)
lat_in_range = (lat_indices >= 0) & (lat_indices < unique_lat.size)
valid = lon_in_range & lat_in_range & ~np.isnan(ml_xco2_pred)

ml_grid = np.full((unique_lat.size, unique_lon.size), np.nan)
ml_grid[lat_indices[valid], lon_indices[valid]] = ml_xco2_pred[valid]

# Append a cyclic longitude point to the grid and its longitude coordinate.
ml_grid, unique_lon = add_cyclic_point(ml_grid, coord=unique_lon)

### **6. Plotting**

In [None]:

def _finite_values(field):
    """Return the unmasked, finite values of *field* as a flat plain ndarray.

    Masked entries, NaN and +/-inf are all dropped, so the percentiles below
    are computed on real data only and np.percentile never receives a
    MaskedArray.
    """
    return np.ma.masked_invalid(field).compressed()


ct_valid = _finite_values(ct_xco2)
cams_valid = _finite_values(cams_xco2)
ml_valid = _finite_values(ml_grid)
all_data = np.concatenate([ct_valid, cams_valid, ml_valid])
if all_data.size == 0:
    raise ValueError("No finite XCO2 values found in any data source; cannot build the colour scale.")

# Shared colour scale for all products: clip to the central 95% of the pooled
# values so outliers do not stretch the colour range.
vmin, vmax = np.percentile(all_data, [2.5, 97.5])
norm = Normalize(vmin=vmin, vmax=vmax)

# Map extents per region: "lon" is [min, max] longitude, "lat" is [min, max]
# latitude, passed to set_extent with a PlateCarree CRS.
regions = {
    "Asia": {"lon": [70, 140], "lat": [10, 50]},
    "Australia": {"lon": [110, 180], "lat": [-50, 0]},
    "North America": {"lon": [-130, -60], "lat": [20, 60]},
    "Europe and North Africa": {"lon": [-10, 60], "lat": [10, 70]},
}

In [None]:
# Directory that receives every saved figure; created if it does not exist.
output_dir = "3.2Comparison_xco2_split_figs"
os.makedirs(output_dir, exist_ok=True)

# One (field, longitudes, latitudes) triple per product, keyed by the label
# used in output file names. Insertion order fixes the plotting order.
_products = {
    "CarbonTracker": (ct_xco2, ct_lon, ct_lat),
    "CAMS": (cams_xco2, cams_lon, cams_lat),
    "Our Model": (ml_grid, unique_lon, unique_lat),
}
data_labels = list(_products)
data_sources = list(_products.values())

def safe_name(s):
    """Return *s* with spaces and "/" replaced by "_" and "&" replaced by "and"."""
    substitutions = str.maketrans({" ": "_", "&": "and", "/": "_"})
    return s.translate(substitutions)

def format_func(x, pos):
    """Format a tick value as an integer string.

    *pos* is unused; it is part of the FuncFormatter callback signature.
    """
    return f"{x:.0f}"


# Save one PNG per (region, product) pair, all sharing the colour scale `norm`.
for region_name, bounds in regions.items():
    lon_min, lon_max = bounds["lon"]
    lat_min, lat_max = bounds["lat"]

    # Ticks depend only on the region, so compute them once per region:
    # multiples of 20 inside the extent (1e-6 keeps the upper bound inclusive).
    xticks = np.arange(
        np.ceil(lon_min / 20) * 20,
        np.floor(lon_max / 20) * 20 + 1e-6,
        20
    )
    yticks = np.arange(
        np.ceil(lat_min / 20) * 20,
        np.floor(lat_max / 20) * 20 + 1e-6,
        20
    )

    for label, (data, lons, lats) in zip(data_labels, data_sources):

        fig = plt.figure(figsize=(5, 3.5))
        try:
            ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
            ax.set_extent([lon_min, lon_max, lat_min, lat_max],
                          crs=ccrs.PlateCarree())
            ax.add_feature(cfeature.LAND, alpha=0.3)
            ax.add_feature(cfeature.COASTLINE, linewidth=1.0)
            ax.add_feature(cfeature.BORDERS, linewidth=0.8, color='black')

            ax.pcolormesh(
                lons, lats, data,
                cmap='Spectral_r', norm=norm, shading='auto'
            )

            ax.set_xticks(xticks, crs=ccrs.PlateCarree())
            ax.set_yticks(yticks, crs=ccrs.PlateCarree())

            # A fresh formatter per axis: formatter instances are not shared.
            ax.xaxis.set_major_formatter(plt.FuncFormatter(format_func))
            ax.yaxis.set_major_formatter(plt.FuncFormatter(format_func))
            ax.tick_params(axis='both', labelsize=11)

            for spine in ax.spines.values():
                spine.set_edgecolor('black')
                spine.set_linewidth(1.5)

            fname = f"xco2_{safe_name(region_name)}_{safe_name(label)}_{target_date_str.replace('-', '')}.png"
            fpath = os.path.join(output_dir, fname)
            # Save this figure explicitly instead of the implicit current one.
            fig.savefig(fpath, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

        print(f"Saved: {fpath}")

# Stand-alone horizontal colorbar figure using the same norm and colormap as
# the maps.
cb_fname = f"xco2_colorbar_{target_date_str.replace('-', '')}.png"
cb_fpath = os.path.join(output_dir, cb_fname)

fig_cb = plt.figure(figsize=(6, 1.0))
# Axes rectangle is [left, bottom, width, height] in figure fractions.
ax_cb = fig_cb.add_axes([0.05, 0.4, 0.9, 0.25])

# A ScalarMappable with an empty array carries just the norm and colormap.
sm = ScalarMappable(norm=norm, cmap='Spectral_r')
sm.set_array([])

cbar = fig_cb.colorbar(sm, cax=ax_cb, orientation='horizontal')
cbar.ax.set_xlabel(r'XCO$\mathregular{_2}$ (ppm)', fontsize=14)
cbar.ax.tick_params(labelsize=12)
# Cap the number of tick intervals at 8.
cbar.locator = MaxNLocator(nbins=8)
cbar.update_ticks()

fig_cb.savefig(cb_fpath, dpi=300, bbox_inches='tight')
plt.close(fig_cb)
print(f"Saved colorbar: {cb_fpath}")