In [None]:
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
from matplotlib import rcParams
import numpy as np
import warnings
from matplotlib.patches import Patch

# Global font configuration: list CJK-capable sans-serif fonts first so the
# Chinese text can render, and draw minus signs as a plain ASCII hyphen
# (axes.unicode_minus = False). Failures are reported as a warning, not raised.
try:
    rcParams.update({
        'font.sans-serif': ['SimHei', 'Arial Unicode MS', 'Microsoft YaHei'],
        'axes.unicode_minus': False,
    })
except Exception as exc:
    warnings.warn(f"字体设置失败: {str(exc)}")

def _load_year_data(base_path, year):
    """Load one year's evaluation workbook; return an empty DataFrame on failure.

    Reads ``{base_path}/{year}年产量预测评估_3DCNN_xLSTM.xlsx`` and casts the
    ``province_code`` / ``city_code`` join keys to ``int`` so they match the
    keys derived from the shapefile. Loading is best-effort: any exception is
    reported with ``warnings.warn`` instead of aborting the whole figure.
    """
    try:
        df = pd.read_excel(f"{base_path}/{year}年产量预测评估_3DCNN_xLSTM.xlsx")
        df[['province_code', 'city_code']] = df[['province_code', 'city_code']].astype(int)
        return df
    except Exception as e:
        warnings.warn(f"{year}年数据加载失败: {str(e)}")
        return pd.DataFrame()


def _merge_relative_error(gdf, data):
    """Left-join ``data`` onto the boundary frame and add an ``error`` column.

    ``error = (预测值 - 真实值) / 真实值``. A zero true value would produce
    +/-inf; it is mapped to NaN so the region is treated as "no data" instead
    of breaking the shared colour scale.
    """
    merged = gdf.merge(data, on=['province_code', 'city_code'], how='left')
    error = (merged['预测值'] - merged['真实值']) / merged['真实值']
    merged['error'] = error.replace([np.inf, -np.inf], np.nan)
    return merged


def _degree_ticks(lo, hi, step):
    """Return whole-degree ticks spaced ``step`` apart that lie inside [lo, hi].

    The grid starts at ``floor(lo)``. Ticks outside the range are dropped
    because ``Axes.set_xticks``/``set_yticks`` expand the view limits so that
    every given tick is visible, which would undo the tight map extent.
    """
    ticks = np.arange(np.floor(lo), np.ceil(hi) + 1, step)
    return ticks[(ticks >= lo) & (ticks <= hi)]


def plot_yield_comparison(base_path, years, shapefile_path, output_filename=None):
    """Draw one map panel per year showing the relative yield-prediction error.

    For every year, ``{base_path}/{year}年产量预测评估_3DCNN_xLSTM.xlsx`` is
    left-joined to the shapefile on ``province_code``/``city_code``. The
    relative error ``(预测值 - 真实值) / 真实值`` is then drawn with one
    ``coolwarm`` colour scale shared by all panels. Regions without a value
    are drawn hatched. Each panel title shows the mean absolute relative
    error (MRE) in percent.

    Args:
        base_path (str): Directory containing the per-year Excel files.
        years (iterable): Years to draw, one panel each, left to right.
        shapefile_path (str): Boundary shapefile; must contain the ``省级码``
            and ``区划码`` columns.
        output_filename (str, optional): If given, save the figure to this
            file and close it; otherwise display it with ``plt.show()``.

    Returns:
        None

    Raises:
        ValueError: If no requested year yields any valid relative-error value.
    """
    years = list(years)  # iterated more than once below; accept any iterable

    # 1. Boundaries, with join keys taken from the shapefile's code columns.
    gdf = gpd.read_file(shapefile_path)
    gdf[['province_code', 'city_code']] = gdf[['省级码', '区划码']].astype(int)
    min_lon, min_lat, max_lon, max_lat = gdf.total_bounds

    # 2. Load and merge every year exactly once. A failed load (already warned
    #    about) becomes None; merging the empty frame would raise KeyError.
    panels = []
    for year in years:
        data = _load_year_data(base_path, year)
        panels.append((year, None if data.empty else _merge_relative_error(gdf, data)))

    # 3. Shared colour range over all valid errors of all years.
    error_range = []
    for _, merged in panels:
        if merged is not None:
            error_range.extend(merged['error'].dropna().tolist())
    if not error_range:
        raise ValueError("所有年份均无有效的相对误差数据，无法绘图")
    vmin, vmax = min(error_range), max(error_range)

    # 4. Longitude / latitude ticks (3° and 2° spacing) inside the map extent.
    lon_ticks = _degree_ticks(min_lon, max_lon, 3)
    lat_ticks = _degree_ticks(min_lat, max_lat, 2)

    # 5. Canvas: one column per year (18 x 7 inches for three years). The
    #    short bottom gridspec row keeps space free for the colour bar.
    ncols = len(years)
    fig = plt.figure(figsize=(6 * ncols, 7), dpi=300, facecolor='white')
    gs = fig.add_gridspec(2, ncols, height_ratios=[0.9, 0.1], hspace=0.15, wspace=0.02)
    axs = [fig.add_subplot(gs[0, i]) for i in range(ncols)]

    # 6. One panel per year.
    for ax, (year, merged) in zip(axs, panels):
        if merged is None:
            ax.set_axis_off()  # data could not be loaded: leave the slot blank
            continue

        mre = merged['error'].abs().mean() * 100  # mean absolute relative error, %

        # Regions with an error value: coloured on the shared scale.
        valid = merged[merged['error'].notna()]
        if not valid.empty:
            valid.plot(column='error', ax=ax, cmap='coolwarm',
                       vmin=vmin, vmax=vmax, edgecolor='k', linewidth=0.5)

        # Regions without an error value: hatched outline only.
        invalid = merged[merged['error'].isna()]
        if not invalid.empty:
            invalid.plot(ax=ax, facecolor='none', edgecolor='gray',
                         linewidth=0.8, hatch='///', alpha=0.3)

        # Ticks first, limits last: set_xticks/set_yticks may widen the view.
        ax.set_xticks(lon_ticks)
        ax.set_yticks(lat_ticks)
        ax.set_xticklabels([f'{x:.0f}°E' for x in lon_ticks], fontsize=14)
        ax.set_yticklabels([f'{y:.0f}°N' for y in lat_ticks], fontsize=14)
        ax.set_xlim(min_lon, max_lon)
        ax.set_ylim(min_lat, max_lat)

        # Axis styling.
        ax.tick_params(axis='both', which='both', length=5)
        for spine in ax.spines.values():
            spine.set_linewidth(1.5)
            spine.set_color('black')

        ax.set_title(f'{year} (MRE: {mre:.2f}%)', fontsize=18, pad=5)
        ax.grid(True, linestyle='--', alpha=0.5)

    # 7. Horizontal colour bar shared by all panels.
    sm = plt.cm.ScalarMappable(cmap='coolwarm', norm=plt.Normalize(vmin=vmin, vmax=vmax))
    sm.set_array(error_range)
    cbar_ax = fig.add_axes([0.2, 0.09, 0.6, 0.04])  # [left, bottom, width, height]
    cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal', format='%.2f')
    cbar.set_label('Relative Error (RE)', fontsize=16, labelpad=10)
    cbar.ax.tick_params(labelsize=14, pad=5)

    # 8. Save (and release this specific figure) or display.
    if output_filename:
        fig.savefig(output_filename, bbox_inches='tight', dpi=300)
        plt.close(fig)
        print(f"图表已保存为: {output_filename}")
    else:
        plt.show()

# Usage example: draw three sample years and save the figure to a PNG file.
if __name__ == "__main__":
    shapefile_path = "D:/Crop/NCC/NCC.shp"  # boundary shapefile
    base_path = "D:/python/crop_yield_prediction"  # directory holding the per-year Excel files
    years = [2010, 2016, 2022]  # one map panel per year

    # output_filename is optional; omit it to show the figure interactively.
    plot_yield_comparison(
        base_path=base_path,
        years=years,
        shapefile_path=shapefile_path,
        output_filename="yield_comparison.png",
    )