# 海洋流场超分辨率任务绘图标准与效果指标

本notebook介绍海洋流场超分辨率任务中的标准化绘图方法和常用效果指标，为研究和比较提供统一的评估框架。

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import json
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.ticker import FuncFormatter
import cv2
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.feature import graycomatrix, graycoprops
import warnings
# Globally silence every Python warning for cleaner notebook output.
# NOTE(review): this also hides numerical RuntimeWarnings (e.g. NaN or
# divide-by-zero) that the metric code may emit — consider narrowing the filter.
warnings.filterwarnings('ignore')

## 1. 海洋流场绘图标准

### 1.1 绘图标准概述
**全局绘图风格：**
- 开启 `LaTeX` 渲染（`text.usetex=True`）
- 统一使用 `Times New Roman` 字体，字号 `18pt`
- 坐标轴线宽、曲线线宽都固定为 `1pt`

**地理坐标处理：**
- 经度格式：`-150.0` → `150°W`
- 纬度格式：`45.0` → `45°N`
- 使用pcolormesh而非imshow以正确处理地理坐标

**颜色条设置：**
- 数据图：对称范围，以真值的最大绝对值为准
- 误差图：对称范围，以模型误差的最大绝对值为准；双三次误差图与模型误差图共用这一范围，便于直接对比（超出该范围的双三次误差会饱和显示）
- 统一使用 `seismic` 色图（蓝-白-红：负值为蓝、零值为白、正值为红）

**布局设计：**
- 2行分别对应u、v速度分量
- 6列：LR输入、HR真值、双三次预测、模型预测、双三次误差、模型误差
- 颜色条精确定位在GridSpec的指定列：第5列（索引4）放置数据图颜色条，第8列（索引7）放置误差图颜色条

In [None]:
def setup_ocean_plot_style():
    """Apply the standard ocean-current plotting style to matplotlib's global rcParams.

    Enables LaTeX text rendering (with the amsmath package in the preamble),
    Times New Roman at 18 pt, 1 pt axis and line widths, inward-pointing ticks
    drawn on all four sides, and 7.5 pt major-tick padding.
    """
    style = {
        # Text and font
        'text.usetex': True,
        'font.family': 'Times New Roman',
        'font.size': 18,
        # Axes frame and plotted lines
        'axes.linewidth': 1,
        'axes.labelsize': 18,
        'lines.linewidth': 1,
        'lines.markersize': 4,
        # Ticks: point inward and also appear on the top/right spines
        'xtick.direction': 'in',
        'ytick.direction': 'in',
        'xtick.top': True,
        'ytick.right': True,
        'xtick.major.pad': 7.5,
        'ytick.major.pad': 7.5,
        # amsmath provides \boldsymbol, used in the subplot titles
        'text.latex.preamble': r'\usepackage{amsmath}',
    }
    plt.rcParams.update(style)

# Longitude / latitude tick formatters (used with matplotlib FuncFormatter).
def format_lon(x, pos):
    """Format a longitude tick value as degrees East/West.

    Negative values are West, non-negative values East, e.g. ``-150.0`` becomes
    ``150$^\\circ$W`` (rendered as 150°W). Whole degrees are printed without a
    trailing ``.0`` and float noise is trimmed (``:g`` formatting).

    Args:
        x: tick position in degrees.
        pos: tick index supplied by FuncFormatter (unused).

    Returns:
        The LaTeX-ready tick label string.
    """
    # Escape the backslash: "\c" in a normal string is an invalid escape sequence.
    return f"{abs(x):g}$^\\circ${'E' if x >= 0 else 'W'}"

def format_lat(y, pos):
    """Format a latitude tick value as degrees North/South.

    Negative values are South, non-negative values North, e.g. ``45.0`` becomes
    ``45$^\\circ$N`` (rendered as 45°N). Whole degrees are printed without a
    trailing ``.0`` and float noise is trimmed (``:g`` formatting).

    Args:
        y: tick position in degrees.
        pos: tick index supplied by FuncFormatter (unused).

    Returns:
        The LaTeX-ready tick label string.
    """
    # Escape the backslash: "\c" in a normal string is an invalid escape sequence.
    return f"{abs(y):g}$^\\circ${'N' if y >= 0 else 'S'}"

def plot_field(ax, lon, lat, data, title, vmin, vmax, cmap, scale_factor=4):
    """
    Draw one 2-D field on a lon/lat grid with pcolormesh.

    Args:
        ax: matplotlib Axes to draw on.
        lon, lat: 2-D coordinate arrays of the high-resolution grid.
        data: 2-D field to plot.
        title: subplot title. A title containing ``'Input'`` marks the
            low-resolution input, whose coordinates are subsampled to match it.
        vmin, vmax: colour-scale limits.
        cmap: colormap name or object.
        scale_factor: stride used to subsample ``lon``/``lat`` for the
            low-resolution input (default 4, the previously hard-coded value).

    Returns:
        The object returned by ``ax.pcolormesh`` (usable as a colorbar mappable).
    """
    # The LR input lives on a coarser grid: subsample the HR coordinates so they
    # line up with it — unless the caller already passed coordinates matching
    # the data, in which case subsampling again would misalign the plot.
    if 'Input' in title and np.shape(lon) != np.shape(data):
        lon = lon[::scale_factor, ::scale_factor]
        lat = lat[::scale_factor, ::scale_factor]
    # pcolormesh (not imshow) so each cell is placed at its lon/lat position
    im = ax.pcolormesh(lon, lat, data, cmap=cmap, vmin=vmin, vmax=vmax, shading='auto')
    # Degree-style tick labels (e.g. 150°W, 45°N)
    ax.xaxis.set_major_formatter(FuncFormatter(format_lon))
    ax.yaxis.set_major_formatter(FuncFormatter(format_lat))
    ax.set_title(title, fontsize=16)
    # Square axes box so every panel in the grid has the same footprint
    ax.set_box_aspect(1)
    # Clamp the y-range to the valid latitude extent
    ax.set_ylim([np.nanmin(lat), np.nanmax(lat)])
    # NaN cells are not drawn, so they show this background (e.g. masked land)
    ax.set_facecolor('lightgray')
    return im

# Signal in the notebook output that the plotting helpers above are defined.
print("海洋流场绘图标准设置完成")

In [None]:
def _symmetric_limit(field, fallback=0.1):
    """Return max |field| (NaNs ignored) for a symmetric colour range, or ``fallback``.

    The fallback is used when the field is empty, all-NaN, identically zero or
    has a non-finite maximum, so pcolormesh never receives NaN/zero limits.
    """
    peak = np.nanmax(np.abs(field)) if np.size(field) else np.nan
    return float(peak) if np.isfinite(peak) and peak != 0 else fallback


def create_ocean_comparison_plot(data_dict, longitude, latitude, date_str, save_path=None):
    """
    Create the 2x6 ocean-current super-resolution comparison figure.

    Rows are the u (top) and v (bottom) velocity components. The columns are,
    in order: LR input, HR ground truth, bicubic prediction, model prediction,
    bicubic error and model error.

    The four data panels of a row share a symmetric colour range set by
    max |ground truth|. Both error panels of a row share a symmetric range set
    by max |model error|, so the two error maps are directly comparable.

    Args:
        data_dict: dict of 2-D fields with keys
            'ubar_input', 'ubar_gt', 'ubar_bicubic', 'ubar_model',
            'vbar_input', 'vbar_gt', 'vbar_bicubic', 'vbar_model'.
        longitude: longitude coordinates of the HR grid.
        latitude: latitude coordinates of the HR grid.
        date_str: date label drawn at the left of the figure.
        save_path: if given, the figure is also saved there at 300 dpi.
    """
    setup_ocean_plot_style()

    # Pointwise errors (prediction - reference)
    ubar_bicubic_error = data_dict['ubar_bicubic'] - data_dict['ubar_gt']
    ubar_model_error = data_dict['ubar_model'] - data_dict['ubar_gt']
    vbar_bicubic_error = data_dict['vbar_bicubic'] - data_dict['vbar_gt']
    vbar_model_error = data_dict['vbar_model'] - data_dict['vbar_gt']
    # Panel data, one row per velocity component
    data = [
        [data_dict['ubar_input'], data_dict['ubar_gt'], data_dict['ubar_bicubic'],
         data_dict['ubar_model'], ubar_bicubic_error, ubar_model_error],
        [data_dict['vbar_input'], data_dict['vbar_gt'], data_dict['vbar_bicubic'],
         data_dict['vbar_model'], vbar_bicubic_error, vbar_model_error]
    ]
    # Panel titles (LaTeX)
    titles = [
        [r'\begin{center} \textbf{Low-Resolution Input \\ ($\boldsymbol{u_{LR}}$)} \end{center}',
         r'\begin{center} \textbf{High-Resolution Ground Truth \\ (Reference $\boldsymbol{u_{ref}}$)} \end{center}',
         r'\begin{center} \textbf{Bicubic Prediction \\ (Predicted $\boldsymbol{u_{bicubic}}$)} \end{center}',
         r'\begin{center} \textbf{Model Prediction \\ (Predicted $\boldsymbol{u_{model}}$)} \end{center}',
         r'\begin{center} \textbf{Bicubic Pointwise Error} \\  $\boldsymbol{ u_{bicubic} - u_{ref} }$ \end{center}',
         r'\begin{center} \textbf{Model Pointwise Error} \\  $\boldsymbol{ u_{model} - u_{ref} }$ \end{center}'],
        [r'\begin{center} \textbf{Low-Resolution Input \\ ($\boldsymbol{v_{LR}}$)} \end{center}',
         r'\begin{center} \textbf{High-Resolution Ground Truth \\ (Reference $\boldsymbol{v_{ref}}$)} \end{center}',
         r'\begin{center} \textbf{Bicubic Prediction \\ (Predicted $\boldsymbol{v_{bicubic}}$)} \end{center}',
         r'\begin{center} \textbf{Model Prediction \\ (Predicted $\boldsymbol{v_{model}}$)} \end{center}',
         r'\begin{center} \textbf{Bicubic Pointwise Error} \\  $\boldsymbol{ v_{bicubic} - v_{ref} }$ \end{center}',
         r'\begin{center} \textbf{Model Pointwise Error} \\  $\boldsymbol{ v_{model} - v_{ref} }$ \end{center}']
    ]

    # Symmetric colour half-ranges per row: (data panels, error panels).
    # Both error columns deliberately use the *model* error range.
    row_limits = [
        (_symmetric_limit(data_dict['ubar_gt']), _symmetric_limit(ubar_model_error)),
        (_symmetric_limit(data_dict['vbar_gt']), _symmetric_limit(vbar_model_error)),
    ]

    fig = plt.figure(figsize=(36, 12))
    # 8 GridSpec columns: 4 data panels, a narrow data-colorbar column,
    # 2 error panels and a narrow error-colorbar column.
    gs = gridspec.GridSpec(2, 8, figure=fig,
                           width_ratios=[1, 1, 1, 1, 0.12, 1, 1, 0.12],
                           wspace=0.1, hspace=0.1)
    plot_col_map = [0, 1, 2, 3, 5, 6]  # GridSpec column of each panel
    n_rows, n_cols = len(data), len(plot_col_map)

    axes = [[fig.add_subplot(gs[i, plot_col_map[j]]) for j in range(n_cols)]
            for i in range(n_rows)]

    # Colorbar mappables per row: index 0 = data panels, index 1 = error panels
    mappables = [[None, None] for _ in range(n_rows)]

    for i in range(n_rows):  # i = 0 -> u, i = 1 -> v
        data_lim, err_lim = row_limits[i]
        for j in range(n_cols):
            ax = axes[i][j]
            is_error = j >= 4  # columns 4 and 5 are the error maps
            lim = err_lim if is_error else data_lim
            im = plot_field(ax, longitude, latitude, data[i][j], titles[i][j],
                            -lim, lim, 'seismic')
            # Panels within a group share limits, so any of them can feed the colorbar
            mappables[i][1 if is_error else 0] = im

            # Only the bottom row keeps x tick labels, only the first column y labels
            if i != n_rows - 1:
                ax.set_xticks([])
                ax.set_xlabel("")
            if j != 0:
                ax.set_yticks([])
                ax.set_ylabel("")

    # Colorbars: data colorbar in GridSpec column 4, error colorbar in column 7.
    # Each sits in an inset of an invisible container axes for precise placement.
    for i in range(n_rows):
        for k, gs_col in enumerate((4, 7)):
            container = fig.add_subplot(gs[i, gs_col])
            container.axis('off')
            cax = inset_axes(container, width="50%", height="90%", loc='center',
                             bbox_to_anchor=(-0.6, 0, 1, 1),
                             bbox_transform=container.transAxes)
            fig.colorbar(mappables[i][k], cax=cax)

    # Date label on the left of the figure
    fig.text(0.07, 0.5, date_str,
             ha='center', va='center', rotation='horizontal',
             fontsize=28, fontweight='bold')

    fig.tight_layout(rect=[0, 0, 1, 1])

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"海洋流场比较图已保存到 {save_path}")

    plt.show()

# Signal in the notebook output that the comparison-plot function is defined.
print("海洋流场比较绘图函数已创建")

## 2. 效果指标

### 2.1 海洋流场指标

针对海洋流场数据的特点，指标计算需要考虑：
- 向量场的特性（u、v分量）
- 地理空间数据的特点
- NaN值的处理

In [None]:
class OceanSuperResolutionMetrics:
    """Mask-aware quality metrics for ocean-current super-resolution.

    Every metric takes ``(img1, img2)`` = (prediction, reference). Pixels that
    are NaN in either array (e.g. land) are ignored, and 0.0 is returned when
    no valid pixel remains.
    """

    @staticmethod
    def _valid_mask(img1, img2):
        """Boolean mask of pixels that are non-NaN in both inputs."""
        return ~(np.isnan(img1) | np.isnan(img2))

    @staticmethod
    def _sobel_gradients(img, mask):
        """Return (grad_x, grad_y) 3x3 Sobel gradients of ``img`` with invalid pixels set to 0.

        NOTE(review): because of the zero fill, gradients at pixels next to
        masked cells (e.g. coastlines) include an artificial jump to 0.
        """
        filled = np.where(mask, img, 0).astype(np.float32)
        grad_x = cv2.Sobel(filled, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(filled, cv2.CV_64F, 0, 1, ksize=3)
        return grad_x, grad_y

    @staticmethod
    def _pearson(a, b):
        """Pearson correlation of two 1-D arrays; 0.0 when undefined (e.g. constant input)."""
        # np.corrcoef yields NaN (not an exception) for degenerate inputs.
        with np.errstate(divide='ignore', invalid='ignore'):
            r = np.corrcoef(a, b)[0, 1]
        return 0.0 if np.isnan(r) else r

    def psnr(self, img1, img2, data_range=None):
        """
        Peak signal-to-noise ratio over valid pixels.

            PSNR = 10 * log10(data_range^2 / MSE)

        Args:
            img1: predicted field.
            img2: reference field.
            data_range: dynamic range; defaults to max - min of img2's valid pixels.

        Returns:
            PSNR in dB (0.0 if no valid pixel). NOTE(review): presumably +inf for
            zero MSE and -inf for a zero data_range; these are not guarded here.
        """
        mask = self._valid_mask(img1, img2)
        if not np.any(mask):
            return 0.0
        pred, ref = img1[mask], img2[mask]
        if data_range is None:
            data_range = ref.max() - ref.min()
        # `psnr` here is the module-level skimage peak_signal_noise_ratio
        # (image_true, image_test), not this method.
        return psnr(ref, pred, data_range=data_range)

    def ssim(self, img1, img2, K1=0.01, K2=0.03, L=None, window_size=11):
        """
        Global (single-window) structural similarity over valid pixels.

            SSIM = ((2*mu_x*mu_y + C1) * (2*sigma_xy + C2)) /
                   ((mu_x^2 + mu_y^2 + C1) * (sigma_x^2 + sigma_y^2 + C2))

        The means, variances and covariance are taken over all valid pixels at
        once, with C1 = (K1*L)^2 and C2 = (K2*L)^2. This is not a sliding-window
        SSIM; ``window_size`` is accepted for API compatibility but unused.

        Args:
            img1: predicted field.
            img2: reference field.
            K1, K2: stabilising constants (default 0.01, 0.03).
            L: dynamic range; defaults to the joint range of both valid pixel sets.
            window_size: unused.

        Returns:
            SSIM as a float. Returns 0.0 if no valid pixel remains, and 1.0 when
            L == 0 and both inputs are identical, where the formula would
            otherwise be 0/0.
        """
        mask = self._valid_mask(img1, img2)
        if not np.any(mask):
            return 0.0

        x, y = img1[mask], img2[mask]

        if L is None:
            L = max(x.max(), y.max()) - min(x.min(), y.min())
        if L == 0 and np.array_equal(x, y):
            # Identical constant inputs: C1 = C2 = 0 and both variances vanish.
            return 1.0
        C1 = (K1 * L) ** 2
        C2 = (K2 * L) ** 2

        mu_x = np.mean(x)
        mu_y = np.mean(y)
        sigma_x2 = np.var(x)
        sigma_y2 = np.var(y)
        sigma_xy = np.cov(x, y, ddof=0)[0, 1]

        ssim_val = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / \
                   ((mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x2 + sigma_y2 + C2))
        return float(ssim_val)

    def mae(self, img1, img2):
        """
        Mean absolute error over valid pixels: mean(|img1 - img2|).
        """
        mask = self._valid_mask(img1, img2)
        if not np.any(mask):
            return 0.0
        return np.mean(np.abs(img1[mask] - img2[mask]))

    def mse(self, img1, img2):
        """
        Mean squared error over valid pixels: mean((img1 - img2)^2).
        """
        mask = self._valid_mask(img1, img2)
        if not np.any(mask):
            return 0.0
        return np.mean((img1[mask] - img2[mask]) ** 2)

    def rmse(self, img1, img2):
        """
        Root mean squared error over valid pixels: sqrt(MSE).
        """
        return np.sqrt(self.mse(img1, img2))

    def relative_l2_error(self, img1, img2):
        """
        Relative L2 error over valid pixels: ||img1 - img2||_2 / ||img2||_2.

        Returns 0.0 when the reference norm is zero.
        """
        mask = self._valid_mask(img1, img2)
        if not np.any(mask):
            return 0.0
        pred, ref = img1[mask], img2[mask]
        l2_error = np.linalg.norm(pred - ref)
        l2_norm = np.linalg.norm(ref)
        return l2_error / l2_norm if l2_norm > 0 else 0.0

    def fsim(self, img1, img2):
        """
        Gradient-magnitude similarity (simplified FSIM) over valid pixels.

            FSIM = mean((2*|G1|*|G2| + eps) / (|G1|^2 + |G2|^2 + eps)),  eps = 1e-8

        G1 and G2 are 3x3 Sobel gradient magnitudes. NOTE: only gradient
        magnitudes are used; there is no phase-congruency term.
        """
        mask = self._valid_mask(img1, img2)
        if not np.any(mask):
            return 0.0

        g1x, g1y = self._sobel_gradients(img1, mask)
        g2x, g2y = self._sobel_gradients(img2, mask)
        grad1_mag = np.sqrt(g1x ** 2 + g1y ** 2)
        grad2_mag = np.sqrt(g2x ** 2 + g2y ** 2)

        similarity = (2 * grad1_mag * grad2_mag + 1e-8) / \
                     (grad1_mag ** 2 + grad2_mag ** 2 + 1e-8)
        return float(np.mean(similarity[mask]))

    def gpp(self, img1, img2):
        """
        Gradient correlation ("GPP") over valid pixels.

            GPP = (corr(grad_x) + corr(grad_y)) / 2

        corr is the Pearson correlation of the 3x3 Sobel gradients at valid
        pixels. An undefined correlation (e.g. constant gradients) counts as 0.
        Despite the name, no phase is computed.
        """
        mask = self._valid_mask(img1, img2)
        if not np.any(mask):
            return 0.0

        g1x, g1y = self._sobel_gradients(img1, mask)
        g2x, g2y = self._sobel_gradients(img2, mask)
        corr_x = self._pearson(g1x[mask], g2x[mask])
        corr_y = self._pearson(g1y[mask], g2y[mask])
        return float((corr_x + corr_y) / 2)

    def compute_ocean_metrics(self, pred_u, pred_v, gt_u, gt_v, data_range=None):
        """
        Compute all metrics for the u and v components, their average, and the speed magnitude.

        Args:
            pred_u, pred_v: predicted u / v velocity components (2-D arrays).
            gt_u, gt_v: reference u / v velocity components (2-D arrays).
            data_range: optional PSNR dynamic range. It is applied to u, v and
                magnitude PSNR alike; when None each uses its own reference range.

        Returns:
            dict with keys 'u_<M>', 'v_<M>' and 'avg_<M>' for M in PSNR, SSIM,
            FSIM, MAE, MSE, RMSE, Relative_L2_Error and GPP. It also has
            'magnitude_PSNR', 'magnitude_SSIM', 'magnitude_MAE' and 'magnitude_RMSE'.
        """
        pred_u = np.asarray(pred_u, dtype=np.float64)
        pred_v = np.asarray(pred_v, dtype=np.float64)
        gt_u = np.asarray(gt_u, dtype=np.float64)
        gt_v = np.asarray(gt_v, dtype=np.float64)

        # Per-component metrics, in the order they appear in the result dict
        component_metrics = {
            'PSNR': lambda p, g: self.psnr(p, g, data_range),
            'SSIM': self.ssim,
            'FSIM': self.fsim,
            'MAE': self.mae,
            'MSE': self.mse,
            'RMSE': self.rmse,
            'Relative_L2_Error': self.relative_l2_error,
            'GPP': self.gpp,
        }

        metrics = {}
        for comp, pred, gt in (('u', pred_u, gt_u), ('v', pred_v, gt_v)):
            for name, fn in component_metrics.items():
                metrics[f'{comp}_{name}'] = fn(pred, gt)

        for name in component_metrics:
            metrics[f'avg_{name}'] = (metrics[f'u_{name}'] + metrics[f'v_{name}']) / 2

        # Speed (vector magnitude) metrics
        pred_mag = np.sqrt(pred_u ** 2 + pred_v ** 2)
        gt_mag = np.sqrt(gt_u ** 2 + gt_v ** 2)
        metrics['magnitude_PSNR'] = self.psnr(pred_mag, gt_mag, data_range)
        metrics['magnitude_SSIM'] = self.ssim(pred_mag, gt_mag)
        metrics['magnitude_MAE'] = self.mae(pred_mag, gt_mag)
        metrics['magnitude_RMSE'] = self.rmse(pred_mag, gt_mag)

        return metrics

# 创建海洋流场指标计算器实例
ocean_metrics_calculator = OceanSuperResolutionMetrics()
print("海洋流场效果指标计算器已初始化")

## 3. 实际应用示例

### 3.1 加载数据(使用Task1保存的数据)

In [None]:
# Fill in the paths of the arrays saved by Task 1.
u_gt_path = ''
u_lq_path = ''
u_sr_path = ''
v_gt_path = ''
v_lq_path = ''
v_sr_path = ''

lat_path = './data/static/lat.npy'
lon_path = './data/static/lon.npy'
mask_path = './data/static/mask.npy'

date_str = ""  # fill in the date of the Task 1 sample (used as the figure label)

# Super-resolution factor, assumed to be 4 — presumably must match Task 1.
# NOTE(review): plot_field subsamples LR coordinates with its own stride
# (4 by default); keep the two consistent.
scale_factor = 4

# Fail early with one clear message instead of an opaque np.load('') error.
_input_paths = {
    'u_gt_path': u_gt_path, 'u_lq_path': u_lq_path, 'u_sr_path': u_sr_path,
    'v_gt_path': v_gt_path, 'v_lq_path': v_lq_path, 'v_sr_path': v_sr_path,
    'lat_path': lat_path, 'lon_path': lon_path, 'mask_path': mask_path,
}
_missing = [f"{name}={path!r}" for name, path in _input_paths.items()
            if not os.path.isfile(path)]
if _missing:
    raise FileNotFoundError("Missing or unset input file(s): " + ", ".join(_missing))

# Load the u / v fields
u_gt_data = np.load(u_gt_path)
u_lq_data = np.load(u_lq_path)
u_sr_data = np.load(u_sr_path)
v_gt_data = np.load(v_gt_path)
v_lq_data = np.load(v_lq_path)
v_sr_data = np.load(v_sr_path)

# Stack to shape (2, ...) with index 0 = u and index 1 = v
gt_data = np.stack([u_gt_data, v_gt_data], axis=0)
lq_data = np.stack([u_lq_data, v_lq_data], axis=0)
sr_data = np.stack([u_sr_data, v_sr_data], axis=0)

lat = np.load(lat_path)
lon = np.load(lon_path)
mask = np.load(mask_path)

# Set cells where `mask` is falsy to NaN (presumably land — verify the mask convention)
gt_data = np.where(mask, gt_data, np.nan)
lq_data = np.where(mask, lq_data, np.nan)
sr_data = np.where(mask, sr_data, np.nan)

ocean_data = {
    # The LR input is a strided subsample of the masked ground truth (not loaded from Task 1)
    'ubar_input': gt_data[0, ::scale_factor, ::scale_factor],
    'vbar_input': gt_data[1, ::scale_factor, ::scale_factor],
    'ubar_gt': gt_data[0, :, :],
    'vbar_gt': gt_data[1, :, :],
    # The Task 1 "lq" arrays are used as the bicubic baseline
    'ubar_bicubic': lq_data[0, :, :],
    'vbar_bicubic': lq_data[1, :, :],
    'ubar_model': sr_data[0, :, :],
    'vbar_model': sr_data[1, :, :],
    'longitude': lon,
    'latitude': lat
}
print("海洋流场数据加载完成")

### 3.2 计算并打印指标

In [None]:
# Metric names reported for each of u, v and their average, in display order
_BASE_METRIC_NAMES = ['PSNR', 'SSIM', 'FSIM', 'MAE', 'MSE', 'RMSE', 'Relative_L2_Error', 'GPP']
_REPORT_SECTIONS = [
    ("\n--- u分量指标 ---", [f"u_{name}" for name in _BASE_METRIC_NAMES]),
    ("\n--- v分量指标 ---", [f"v_{name}" for name in _BASE_METRIC_NAMES]),
    ("\n--- 综合指标 ---", [f"avg_{name}" for name in _BASE_METRIC_NAMES]),
    ("\n--- 速度幅值指标 ---", [f"magnitude_{name}" for name in ('PSNR', 'SSIM', 'MAE', 'RMSE')]),
]


def _print_metric_report(metrics):
    """Print one method's metrics grouped as u / v / averaged / speed magnitude."""
    for heading, names in _REPORT_SECTIONS:
        print(heading)
        for name in names:
            print(f"{name}: {metrics[name]:.4f}")


# Bicubic baseline
print("=== 双三次插值方法指标 ===")
bicubic_metrics = ocean_metrics_calculator.compute_ocean_metrics(
    ocean_data['ubar_bicubic'], ocean_data['vbar_bicubic'],
    ocean_data['ubar_gt'], ocean_data['vbar_gt']
)
_print_metric_report(bicubic_metrics)

# Model prediction
print("\n" + "="*50)
print("=== 模型预测指标 ===")
model_metrics = ocean_metrics_calculator.compute_ocean_metrics(
    ocean_data['ubar_model'], ocean_data['vbar_model'],
    ocean_data['ubar_gt'], ocean_data['vbar_gt']
)
_print_metric_report(model_metrics)

# Model vs. bicubic
print("\n" + "="*50)
print("=== 指标改进情况 ===")
improvement_metrics = ['avg_PSNR', 'avg_SSIM', 'avg_FSIM', 'avg_GPP']  # higher is better
reduction_metrics = ['avg_MAE', 'avg_MSE', 'avg_RMSE', 'avg_Relative_L2_Error']  # lower is better

for metric in improvement_metrics:
    gain = model_metrics[metric] - bicubic_metrics[metric]
    print(f"{metric}改进: {gain:+.4f}")

for metric in reduction_metrics:
    baseline = bicubic_metrics[metric]
    drop = baseline - model_metrics[metric]
    drop_pct = (drop / baseline) * 100 if baseline != 0 else 0
    print(f"{metric}降低: {drop:.4f} ({drop_pct:.1f}%)")

### 3.3 保存指标与可视化

In [None]:
# 组织要保存的字典
results_to_save = {
    "bicubic_metrics": bicubic_metrics,
    "model_metrics": model_metrics,
    "improvement_metrics": {},
    "reduction_metrics": {}
}

# 计算改进情况
for metric in improvement_metrics:
    improvement = model_metrics[metric] - bicubic_metrics[metric]
    results_to_save["improvement_metrics"][metric] = improvement

for metric in reduction_metrics:
    improvement = bicubic_metrics[metric] - model_metrics[metric]
    reduction_pct = (improvement / bicubic_metrics[metric]) * 100 if bicubic_metrics[metric] != 0 else 0
    results_to_save["reduction_metrics"][metric] = {
        "absolute": improvement,
        "percentage": reduction_pct
    }

# 保存为 JSON 文件
output_file = "ocean_metrics_results.json"
with open(output_file, "w") as f:
    json.dump(results_to_save, f, indent=4)

print(f"指标结果已保存到 {output_file}")

In [None]:
# 创建海洋流场标准化可视化
# Render (and save) the standardized comparison figure for the loaded sample.
comparison_figure_path = "ocean_super_resolution_comparison.png"
create_ocean_comparison_plot(
    data_dict=ocean_data,
    longitude=ocean_data['longitude'],
    latitude=ocean_data['latitude'],
    date_str=date_str,
    save_path=comparison_figure_path,
)

## 注意事项

1. **数据格式**：确保数据为numpy数组，形状匹配
2. **NaN处理**：海洋数据通常包含陆地区域的NaN值，已在指标计算中处理
3. **坐标系统**：确保经纬度坐标与数据尺寸匹配