# In [None]:
import ee
import geemap
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Authenticate with Earth Engine and initialise the client. Every ee.* /
# geemap call below (starting with shp_to_ee) needs an initialised session,
# so these two statements must stay first.
ee.Authenticate()

ee.Initialize()

# ==================== 1. Configuration ====================
# Time window - extended to July-September. END_DATE is passed to
# filterDate(); presumably exclusive, so 2022-10-01 ends the window at the
# end of September (TODO confirm against the EE filterDate docs).
START_DATE = '2022-07-01'
END_DATE = '2022-10-01'

# Study area: local shapefile converted to an Earth Engine object; `bdy` is
# used as the boundary by every processing function in this file.
BOUNDARY_PATH = '/Users/hanxu/geemap/bdy.shp'
bdy = geemap.shp_to_ee(BOUNDARY_PATH)

# Common target grid - 30 m so that all data sources can be aligned.
TARGET_SCALE = 30  # 30 m resolution
# NOTE(review): TARGET_CRS is not referenced anywhere else in this file;
# reprojection uses the Landsat SR_B2 projection instead.
TARGET_CRS = 'EPSG:4326'  # WGS84

# Output locations. OUTPUT_DIR is created if missing; OUTPUT_CSV is a
# relative path, so the raw feature table is written to the current working
# directory rather than into OUTPUT_DIR.
OUTPUT_DIR = '/Users/hanxu/geemap/out_plots'
os.makedirs(OUTPUT_DIR, exist_ok=True)
OUTPUT_CSV = 'multi_source_features_2022_07_09.csv'
# Soil sample points CSV (read by geemap.csv_to_ee; must carry a `salinity` column).
SOIL_DATA_PATH = '/Users/hanxu/geemap/material/soil/soil_2022_08.csv'

# Run banner.
print("="*60)
print("多源遥感数据融合系统 - 盐渍化监测")
print("="*60)
print(f"时间范围: {START_DATE} 至 {END_DATE}")
print(f"空间分辨率: {TARGET_SCALE}米")
print(f"输出目录: {OUTPUT_DIR}")

# ==================== 2. Landsat 8 处理（包含热红外波段）====================

def process_landsat8_comprehensive(start_date, end_date, boundary, stats_scale=30):
    """
    Build a cloud-masked Landsat 8 median composite with spectral indices.

    Loads Landsat 8 Collection 2 Level-2 scenes, applies the C2 scale factors
    to the SR_B* and ST_B10 bands, masks cloud and cloud-shadow pixels,
    computes vegetation / water / salinity / soil / thermal indices per scene,
    and reduces the collection to a per-pixel median clipped to ``boundary``.

    Args:
        start_date: Start date string passed to ``filterDate``.
        end_date: End date string passed to ``filterDate``.
        boundary: Study-area FeatureCollection/geometry (filter and clip).
        stats_scale: Scale in metres of the region-wide LST min/max used to
            normalise the ``TVDI`` band. Defaults to 30.

    Returns:
        tuple: (ee.Image with exactly the bands listed in ``l8_bands``,
        ee.Projection of the SR_B2 band of the first Landsat 8 scene
        intersecting ``boundary`` - note: that lookup is not date-filtered).
    """
    print("\n处理Landsat 8数据...")

    def apply_scale_factors(image):
        """Apply the Collection 2 Level-2 scale/offset to SR and ST bands."""
        optical = image.select('SR_B.*').multiply(0.0000275).add(-0.2)
        thermal = image.select('ST_B10').multiply(0.00341802).add(149.0)
        # overwrite=True replaces the raw DN bands with the scaled values.
        return image.addBands(optical, None, True).addBands(thermal, None, True)

    def mask_l8_clouds(image):
        """Mask pixels flagged in QA_PIXEL bit 3 (cloud) or bit 4 (cloud shadow)."""
        qa = image.select('QA_PIXEL')
        mask = qa.bitwiseAnd(1 << 3).eq(0).And(qa.bitwiseAnd(1 << 4).eq(0))
        return image.updateMask(mask)

    def calculate_comprehensive_indices(image):
        """Append per-scene spectral indices (TVDI is added on the composite)."""
        # Base bands
        b2 = image.select('SR_B2')  # Blue
        b3 = image.select('SR_B3')  # Green
        b4 = image.select('SR_B4')  # Red
        b5 = image.select('SR_B5')  # NIR
        b6 = image.select('SR_B6')  # SWIR1
        b7 = image.select('SR_B7')  # SWIR2 (kept for reference; no index uses it)
        b10 = image.select('ST_B10')  # Thermal

        # Vegetation indices
        ndvi = b5.subtract(b4).divide(b5.add(b4)).rename('NDVI')
        evi = b5.subtract(b4).divide(b5.add(b4.multiply(6)).subtract(b2.multiply(7.5)).add(1)).multiply(2.5).rename('EVI')
        savi = b5.subtract(b4).divide(b5.add(b4).add(0.5)).multiply(1.5).rename('SAVI')
        msavi = (b5.multiply(2).add(1).subtract(
            (b5.multiply(2).add(1)).pow(2).subtract(b5.subtract(b4).multiply(8))
            .sqrt())).divide(2).rename('MSAVI')

        # Water indices
        ndwi = b3.subtract(b5).divide(b3.add(b5)).rename('NDWI')
        mndwi = b3.subtract(b6).divide(b3.add(b6)).rename('MNDWI')

        # Salinity indices
        si1 = b2.multiply(b4).sqrt().rename('SI1')
        si2 = b2.pow(2).add(b3.pow(2)).add(b4.pow(2)).sqrt().rename('SI2')
        si3 = b3.pow(2).add(b4.pow(2)).sqrt().rename('SI3')
        si4 = b2.subtract(b3).pow(2).add(b3.subtract(b4).pow(2)).sqrt().rename('SI4')

        # Salinity-related band ratios / normalised differences
        s1 = b2.divide(b4).rename('S1')
        s2 = b2.subtract(b4).divide(b2.add(b4)).rename('S2')
        s3 = b3.multiply(b4).divide(b2).rename('S3')
        s5 = b2.multiply(b4).divide(b3).rename('S5')
        s6 = b4.multiply(b5).divide(b3).rename('S6')

        # SWIR-based indices
        ndsi = b6.subtract(b5).divide(b6.add(b5)).rename('NDSI')
        si_msi = b6.divide(b5).rename('SI_MSI')

        # Soil indices (note: BI has the same formula as SI2)
        bsi = ((b6.add(b4)).subtract(b5.add(b2))).divide((b6.add(b4)).add(b5.add(b2))).rename('BSI')
        bi = b2.pow(2).add(b3.pow(2)).add(b4.pow(2)).sqrt().rename('BI')

        # Temperature / NDVI interaction
        temp_ndvi_ratio = b10.divide(ndvi.add(1)).rename('Temp_NDVI_ratio')

        return image.addBands([
            ndvi, evi, savi, msavi, ndwi, mndwi,
            si1, si2, si3, si4, s1, s2, s3, s5, s6,
            ndsi, si_msi, bsi, bi, temp_ndvi_ratio
        ])

    # Projection of the first scene over the boundary (any date).
    first_image = ee.Image(ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
                          .filterBounds(boundary).first())
    target_projection = first_image.select('SR_B2').projection()

    # Load and process the Landsat 8 collection
    l8_collection = (ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
                     .filterDate(start_date, end_date)
                     .filterBounds(boundary)
                     .map(apply_scale_factors)
                     .map(mask_l8_clouds)
                     .map(calculate_comprehensive_indices))

    print(f"  找到 {l8_collection.size().getInfo()} 幅Landsat 8影像")

    # Median composite
    l8_composite = l8_collection.median().clip(boundary)

    # TVDI (simplified): LST rescaled to [0, 1] by its min/max over the study
    # area. ee.Image.reduce() reduces across *bands* at each pixel, so applying
    # it to the single ST_B10 band returns the pixel itself and yields a
    # constant 0/0 band; a spatial reduceRegion is required instead.
    # NOTE: a full TVDI uses NDVI-dependent dry/wet edges; this min-max form is
    # a linear rescaling of ST_B10.
    lst = l8_composite.select('ST_B10')
    lst_range = lst.reduceRegion(
        reducer=ee.Reducer.minMax(),
        geometry=boundary,
        scale=stats_scale,
        bestEffort=True,
        maxPixels=1e13
    )
    lst_min = ee.Number(lst_range.get('ST_B10_min'))
    lst_max = ee.Number(lst_range.get('ST_B10_max'))
    tvdi = lst.subtract(lst_min).divide(lst_max.subtract(lst_min)).rename('TVDI')
    l8_composite = l8_composite.addBands(tvdi)

    # Output bands, in a fixed order
    l8_bands = [
        # Original bands
        'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7', 'ST_B10',
        # Derived indices
        'NDVI', 'EVI', 'SAVI', 'MSAVI', 'NDWI', 'MNDWI',
        'SI1', 'SI2', 'SI3', 'SI4', 'S1', 'S2', 'S3', 'S5', 'S6',
        'NDSI', 'SI_MSI', 'BSI', 'BI', 'TVDI', 'Temp_NDVI_ratio'
    ]

    return l8_composite.select(l8_bands), target_projection

# ==================== 3. Sentinel-1 处理（优化的雷达指数）====================

def process_sentinel1_enhanced(start_date, end_date, boundary, target_projection, target_scale):
    """
    Build a Sentinel-1 (IW, VV+VH, ascending) median composite with radar indices.

    Each scene is speckle-filtered (50 m circular focal median on the dB
    values). Derived bands are then computed: VV-VH difference (dB), RVI and
    DPSVI (on linear power), a VH-based soil-moisture proxy, and 0-1 rescaled
    VV/VH. The median composite is reprojected onto the target grid and
    clipped to the boundary.

    Args:
        start_date: Start date string passed to ``filterDate``.
        end_date: End date string passed to ``filterDate``.
        boundary: Study-area FeatureCollection/geometry.
        target_projection: ee.Projection of the shared grid.
        target_scale: Output resolution in metres.

    Returns:
        ee.Image with the bands listed in ``s1_bands``.

    Raises:
        ValueError: if no scene matches the filters.
    """
    print("\n处理Sentinel-1数据...")

    def preprocess_s1(image):
        """Speckle-filter VV/VH and append the derived radar bands."""
        vv = image.select('VV')
        vh = image.select('VH')
        angle = image.select('angle')

        # Speckle reduction with a 50 m circular median kernel (applied to dB).
        vv_filtered = vv.focal_median(radius=50, kernelType='circle', units='meters').rename('VV')
        vh_filtered = vh.focal_median(radius=50, kernelType='circle', units='meters').rename('VH')

        # In dB, VV - VH equals the linear-scale ratio VV/VH.
        cross_pol_diff = vv_filtered.subtract(vh_filtered).rename('VV_VH_diff')

        # Back to linear power for ratio-based indices.
        vv_linear = ee.Image(10).pow(vv_filtered.divide(10))
        vh_linear = ee.Image(10).pow(vh_filtered.divide(10))

        # Radar Vegetation Index (linear scale)
        rvi = vh_linear.multiply(4).divide(vv_linear.add(vh_linear)).rename('RVI')

        # Dual-polarisation SAR vegetation index
        dpsvi = vv_linear.add(vh_linear).divide(vv_linear).rename('DPSVI')

        # Polarisation ratio (in dB this is the same difference as above)
        pol_ratio = cross_pol_diff.rename('Pol_Ratio')

        # Simple linear soil-moisture proxy derived from VH
        soil_moisture_index = vh_filtered.multiply(-1).add(5).rename('SMI')

        # VV / VH rescaled to 0-1
        vv_norm = vv_filtered.unitScale(-30, 0).rename('VV_norm')
        vh_norm = vh_filtered.unitScale(-30, -5).rename('VH_norm')

        derived = ee.Image([
            vv_filtered, vh_filtered, cross_pol_diff,
            rvi, dpsvi, pol_ratio, soil_moisture_index,
            vv_norm, vh_norm, angle
        ])
        # overwrite=True: the input already has 'VV', 'VH' and 'angle'. Without
        # it, addBands renames the new copies ('VV_1', ...) and the final
        # select() would return the unfiltered originals.
        return image.addBands(derived, None, True)

    # Load the Sentinel-1 collection
    s1_collection = (ee.ImageCollection('COPERNICUS/S1_GRD')
                     .filterDate(start_date, end_date)
                     .filterBounds(boundary)
                     .filter(ee.Filter.eq('instrumentMode', 'IW'))
                     .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
                     .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
                     .filter(ee.Filter.eq('orbitProperties_pass', 'ASCENDING'))
                     .map(preprocess_s1))

    n_images = s1_collection.size().getInfo()
    print(f"  找到 {n_images} 幅Sentinel-1影像")
    if n_images == 0:
        # A median of an empty collection has no bands; fail with a clear message.
        raise ValueError("未找到符合条件的Sentinel-1影像，请检查时间范围、区域或轨道方向设置")

    # Median composite
    s1_composite = s1_collection.median()

    # Reproject onto the shared grid
    s1_resampled = s1_composite.reproject(
        crs=target_projection,
        scale=target_scale
    ).clip(boundary)

    # Output bands
    s1_bands = [
        'VV', 'VH', 'VV_VH_diff', 'RVI', 'DPSVI',
        'Pol_Ratio', 'SMI', 'VV_norm', 'VH_norm', 'angle'
    ]

    return s1_resampled.select(s1_bands)

# ==================== 4. 环境因子处理 ====================

def calculate_TWI(boundary, target_projection, target_scale):
    """
    Compute the Topographic Wetness Index from HydroSHEDS flow accumulation.

    TWI = ln(a / tan(beta)), where ``a`` is the specific catchment area
    (upslope area per unit contour width, in metres) and ``beta`` is the local
    slope from the SRTM 30 m DEM. A small epsilon (0.001) keeps flat cells
    finite.

    Args:
        boundary: Study-area FeatureCollection/geometry.
        target_projection: ee.Projection of the shared grid.
        target_scale: Output resolution in metres.

    Returns:
        ee.Image: single band 'TWI' on the target grid, clipped to boundary.

    Raises:
        ValueError: if the TWI band has no valid pixels in the boundary.
    """
    print("使用HydroSHEDS数据计算TWI...")

    # HydroSHEDS flow accumulation and SRTM DEM
    fa = ee.Image("WWF/HydroSHEDS/15ACC").clip(boundary).float()
    dem = ee.Image("USGS/SRTMGL1_003").clip(boundary)

    # Slope in degrees -> radians
    slope_rad = ee.Terrain.slope(dem).multiply(np.pi / 180).rename('slope_rad')

    # 15ACC counts upstream cells on its own 15 arc-second grid (~460 m), not
    # on 30 m cells. Specific catchment area = (upstream cells + the cell
    # itself) * cell_area / contour_width = (fa + 1) * cell_size.
    cell_size = fa.projection().nominalScale()
    area = fa.add(1).multiply(cell_size).rename('specific_catchment_area')
    TWI = area.log().subtract(slope_rad.tan().add(0.001).log()).rename('TWI')

    # Reproject onto the shared grid
    TWI_resampled = TWI.reproject(crs=target_projection, scale=target_scale).clip(boundary)

    # Sanity-check that the band actually has data
    twi_stats = TWI_resampled.reduceRegion(
        reducer=ee.Reducer.minMax(),
        geometry=boundary,
        scale=target_scale,
        bestEffort=True,
        maxPixels=1e13
    ).getInfo()

    print("✅ TWI波段统计值：", twi_stats)

    # .get(): the keys may be absent entirely when nothing was reduced.
    twi_min = twi_stats.get('TWI_min')
    twi_max = twi_stats.get('TWI_max')
    if twi_min is None or twi_max is None:
        raise ValueError("⚠️ TWI波段仍无有效数据，请再次检查")

    print(f"TWI 最小值: {twi_min:.4f}")
    print(f"TWI 最大值: {twi_max:.4f}")

    return TWI_resampled

def process_environmental_factors_enhanced(start_date, end_date, boundary, target_projection, target_scale):
    """
    Build a multi-band image of environmental covariates on the shared grid.

    Each group is added independently. A failing group is reported and
    skipped, so the remaining groups are still returned:
      * MODIS MOD16A2 (8-day): ET mean/max/std and PET mean (raw DN, no scale
        factor applied).
      * CHIRPS daily precipitation: total, daily mean, daily max, and the
        fraction of days with more than 1 mm.
      * Groundwater depth: first band of a private asset.
      * SRTM terrain: elevation, slope, aspect, TWI (via calculate_TWI) and
        hillshade.

    NOTE(review): Earth Engine evaluates lazily. The try/except blocks only
    catch errors raised client-side or by explicit getInfo() calls.
    Server-side failures (e.g. a missing asset) surface later, when the image
    is evaluated.

    Args:
        start_date: Start date string passed to ``filterDate``.
        end_date: End date string passed to ``filterDate``.
        boundary: Study-area FeatureCollection/geometry.
        target_projection: ee.Projection of the shared grid.
        target_scale: Output resolution in metres.

    Returns:
        ee.Image: every successfully built band, reprojected to
        ``target_projection`` at ``target_scale`` and clipped to ``boundary``.
    """
    print("\n处理环境因子数据...")

    def _resample(band):
        """Reproject a band onto the shared grid and clip it to the study area."""
        return band.reproject(crs=target_projection, scale=target_scale).clip(boundary)

    env_features = ee.Image().select([])

    # 4.1 MODIS ET (evapotranspiration) - 8-day product
    try:
        et_collection = (ee.ImageCollection('MODIS/061/MOD16A2')
                         .filterDate(start_date, end_date)
                         .filterBounds(boundary))

        # collection.size() is a server-side ee.Number, which is always truthy
        # in Python; the count must be fetched for the emptiness check to work.
        if et_collection.size().getInfo() > 0:
            et = et_collection.select('ET')
            et_mean = et.mean().rename('ET_mean')
            et_max = et.max().rename('ET_max')
            et_std = et.reduce(ee.Reducer.stdDev()).rename('ET_std')
            pet_mean = et_collection.select('PET').mean().rename('PET_mean')

            for band in [et_mean, et_max, et_std, pet_mean]:
                env_features = env_features.addBands(_resample(band))

            print("  ✓ ET数据处理完成")
        else:
            print("  ✗ ET数据处理失败: 时间范围内无可用MODIS ET影像")
    except Exception as e:
        print(f"  ✗ ET数据处理失败: {e}")

    # 4.2 CHIRPS precipitation
    try:
        precip_collection = (ee.ImageCollection('UCSB-CHG/CHIRPS/DAILY')
                             .filterDate(start_date, end_date)
                             .filterBounds(boundary))

        n_days = precip_collection.size().getInfo()
        if n_days > 0:
            precip_sum = precip_collection.sum().rename('precip_total')
            precip_mean = precip_collection.mean().rename('precip_mean')
            precip_max = precip_collection.max().rename('precip_max_daily')

            # Fraction of days with more than 1 mm of rain.
            precip_days = precip_collection.map(lambda img: img.gt(1)).sum()
            precip_frequency = precip_days.divide(n_days).rename('precip_frequency')

            for band in [precip_sum, precip_mean, precip_max, precip_frequency]:
                env_features = env_features.addBands(_resample(band))

            print("  ✓ 降水数据处理完成")
        else:
            print("  ✗ 降水数据处理失败: 时间范围内无可用CHIRPS影像")
    except Exception as e:
        print(f"  ✗ 降水数据处理失败: {e}")

    # 4.3 Groundwater depth (private asset, first band)
    try:
        gw_image = ee.Image('projects/ee-hanxu0223/assets/gw_2022_08')
        gw_band = gw_image.select([0]).rename('groundwater_depth')
        env_features = env_features.addBands(_resample(gw_band))
        print("  ✓ 地下水数据处理完成")
    except Exception as e:
        print(f"  ✗ 地下水数据处理失败: {e}")

    # 4.4 Terrain (TWI from HydroSHEDS via calculate_TWI)
    try:
        dem = ee.Image('USGS/SRTMGL1_003').clip(boundary)
        elevation = dem.select('elevation').rename('elevation')
        slope = ee.Terrain.slope(elevation).rename('slope')
        aspect = ee.Terrain.aspect(elevation).rename('aspect')
        hillshade = ee.Terrain.hillshade(elevation).rename('hillshade')

        # calculate_TWI raises when TWI has no data; keep the other terrain
        # bands in that case instead of discarding the whole group.
        try:
            TWI_resampled = calculate_TWI(boundary, target_projection, target_scale)
        except Exception as e:
            print(f"  ✗ TWI计算失败: {e}")
            TWI_resampled = None

        # Band order: elevation, slope, aspect, TWI, hillshade
        for band in [elevation, slope, aspect, TWI_resampled, hillshade]:
            if band is not None:
                env_features = env_features.addBands(_resample(band))

        if TWI_resampled is not None:
            print("  ✓ 地形数据处理完成（含有效TWI）")
        else:
            print("  ✓ 地形数据处理完成（不含TWI）")
    except Exception as e:
        print(f"  ✗ 地形数据处理失败: {e}")

    return env_features

# ==================== 5. 数据融合和掩膜 ====================

def create_comprehensive_landcover_mask(boundary):
    """
    Return a mask of ESA WorldCover pixels kept for salinity analysis.

    Pixels of class 50 (built-up), 80 (permanent water) or 70 (snow and ice)
    get 0. All other classes get 1.

    Args:
        boundary: Study-area FeatureCollection/geometry used to clip the map.

    Returns:
        ee.Image: single-band 1/0 mask.
    """
    print("\n创建土地覆盖掩膜...")

    worldcover = ee.ImageCollection("ESA/WorldCover/v200").first().clip(boundary)
    classes = worldcover.select('Map')

    # AND together "not class X" tests: built-up, then water, then snow/ice.
    excluded_codes = (50, 80, 70)
    valid_mask = classes.neq(excluded_codes[0])
    for code in excluded_codes[1:]:
        valid_mask = valid_mask.And(classes.neq(code))

    print("  ✓ 掩膜创建完成（排除建筑、水体、冰雪）")

    return valid_mask

# ==================== 6. 多源数据融合主函数 ====================

def fuse_multisource_data_comprehensive():
    """
    Fuse Landsat 8, Sentinel-1 and environmental features into one image.

    Uses the module-level START_DATE, END_DATE, bdy and TARGET_SCALE. The
    Landsat SR_B2 projection is the common grid for the other sources. The
    ESA WorldCover mask (built-up / water / snow excluded) is applied. A
    geemap map with preview layers is built along the way.

    Returns:
        tuple: (masked fused ee.Image, geemap.Map, list of band names,
        target ee.Projection).
    """
    print("\n" + "="*60)
    print("开始多源数据融合...")
    print("="*60)

    # Interactive map
    Map = geemap.Map()
    Map.add_basemap('HYBRID')
    Map.centerObject(bdy, 10)

    # 1. Landsat 8 (also provides the target projection)
    l8_features, target_projection = process_landsat8_comprehensive(START_DATE, END_DATE, bdy)

    # 2. Sentinel-1
    s1_features = process_sentinel1_enhanced(START_DATE, END_DATE, bdy, target_projection, TARGET_SCALE)

    # 3. Environmental factors
    env_features = process_environmental_factors_enhanced(START_DATE, END_DATE, bdy, target_projection, TARGET_SCALE)

    # 4. Stack all features
    print("\n合并多源特征...")
    all_features = ee.Image.cat([l8_features, s1_features, env_features])

    # 5. Land-cover mask
    landcover_mask = create_comprehensive_landcover_mask(bdy)
    all_features_masked = all_features.updateMask(landcover_mask)

    # 6. List band names grouped by source
    band_names = all_features_masked.bandNames().getInfo()
    print(f"\n融合后的特征波段总数: {len(band_names)}")

    # Sentinel-1 and environmental bands have distinctive prefixes; everything
    # else comes from Landsat. (A Landsat prefix of 'S' would also match
    # Sentinel-1's 'SMI', listing it twice.)
    s1_prefixes = ('VV', 'VH', 'RVI', 'DPSVI', 'Pol', 'SMI', 'angle')
    env_prefixes = ('ET', 'PET', 'precip', 'groundwater', 'elevation', 'slope',
                    'aspect', 'TWI', 'hillshade')
    s1_bands = [b for b in band_names if b.startswith(s1_prefixes)]
    env_bands = [b for b in band_names if b.startswith(env_prefixes)]
    non_l8 = set(s1_bands) | set(env_bands)
    l8_bands = [b for b in band_names if b not in non_l8]

    print("\nLandsat 8 特征:")
    for i, band in enumerate(l8_bands):
        print(f"  {i+1}. {band}")

    print(f"\nSentinel-1 特征:")
    for i, band in enumerate(s1_bands):
        print(f"  {i+1}. {band}")

    print(f"\n环境因子特征:")
    for i, band in enumerate(env_bands):
        print(f"  {i+1}. {band}")

    # 7. Preview layers
    Map.addLayer(bdy, {}, 'Study Area')

    # Landsat 8 true colour
    Map.addLayer(l8_features.select(['SR_B4', 'SR_B3', 'SR_B2']),
                 {'min': 0, 'max': 0.3}, 'Landsat 8 True Color')

    # Landsat 8 false colour (vegetation)
    Map.addLayer(l8_features.select(['SR_B5', 'SR_B4', 'SR_B3']),
                 {'min': 0, 'max': 0.4}, 'Landsat 8 False Color')

    # Thermal band
    Map.addLayer(l8_features.select('ST_B10'),
                 {'min': 290, 'max': 320, 'palette': ['blue', 'white', 'red']},
                 'Land Surface Temperature')

    # Sentinel-1 VV
    Map.addLayer(s1_features.select('VV'),
                 {'min': -25, 'max': 0}, 'Sentinel-1 VV')

    # VV-VH difference
    Map.addLayer(s1_features.select('VV_VH_diff'),
                 {'min': 0, 'max': 15, 'palette': ['blue', 'yellow', 'red']},
                 'VV-VH Difference')

    # ET - only if the ET group was actually built; selecting a missing band
    # would make addLayer fail and abort the whole pipeline.
    if 'ET_mean' in band_names:
        Map.addLayer(env_features.select('ET_mean'),
                     {'min': 0, 'max': 300, 'palette': ['red', 'yellow', 'green', 'blue']},
                     'Evapotranspiration', False)

    return all_features_masked, Map, band_names, target_projection

# ==================== 7. 样本数据提取和特征工程 ====================

def extract_and_engineer_features(feature_image, sample_points_path, band_names, projection, scale):
    """
    Sample the fused image at the soil points and save the result as CSV.

    Args:
        feature_image: Fused, masked ee.Image to sample.
        sample_points_path: CSV of sample points read with geemap.csv_to_ee;
            its ``salinity`` column is carried into the samples.
        band_names: Band names of ``feature_image`` (currently unused).
        projection: ee.Projection to sample in.
        scale: Sampling scale in metres.

    Returns:
        pandas.DataFrame: one row per retained point, with the sampled bands,
        ``salinity``, ``longitude``, ``latitude`` and, when NDVI and ST_B10
        exist, ``Temp_NDVI_interaction``. Also written to OUTPUT_CSV.

    Raises:
        ValueError: if no samples (or an empty table) come back.
    """
    print("\n特征提取和工程...")

    points = geemap.csv_to_ee(sample_points_path)
    print(f"加载了 {points.size().getInfo()} 个样本点")

    # Sample on the shared grid so every band is read at the same projection/scale.
    samples = feature_image.sampleRegions(
        collection=points,
        properties=['salinity'],
        scale=scale,
        projection=projection,
        geometries=True,
        tileScale=4
    )

    n_samples = samples.size().getInfo()
    print(f"提取到的有效样本数: {n_samples}")

    if n_samples == 0:
        raise ValueError("⚠️ 提取特征值失败，请检查样本点与影像是否匹配。")

    # Flatten each GeoJSON feature into properties + point coordinates.
    df = pd.DataFrame([
        {
            **feat['properties'],
            'longitude': feat['geometry']['coordinates'][0],
            'latitude': feat['geometry']['coordinates'][1],
        }
        for feat in samples.getInfo()['features']
    ])

    if df.empty:
        raise ValueError("⚠️ 提取的数据框为空，请重新检查数据提取过程。")

    # Example engineered feature
    has_inputs = {'NDVI', 'ST_B10'}.issubset(df.columns)
    if has_inputs:
        df['Temp_NDVI_interaction'] = df['ST_B10'] * df['NDVI']

    df.to_csv(OUTPUT_CSV, index=False)
    print(f"✓ 特征数据成功保存至: {OUTPUT_CSV}")
    print(f"  总样本数: {len(df)}")
    print(f"  总特征数: {len(df.columns)-3}")

    return df

# ==================== 8. 综合特征分析和可视化 ====================

def comprehensive_feature_analysis(df, output_dir, top_n=30):
    """
    Explore how the extracted features relate to soil salinity.

    Saves (to ``output_dir``) and shows up to five figures:
      1. scatter plots with a linear trend for the available key features;
      2. bar chart of the ``top_n`` features with the largest |r| to salinity;
      3. heatmap of correlations among the 20 highest-|r| features;
      4. mean |r| per data source;
      5. spatial maps of salinity, NDVI, LST and VV-VH at the sample points.

    Args:
        df: Sample table with ``salinity``, ``longitude``, ``latitude`` and one
            column per feature.
        output_dir: Directory the PNG files are written to.
        top_n: Number of features shown in the correlation bar chart.

    Returns:
        pandas.Series: Pearson r of every feature (NaN filled with 0) with
        salinity, sorted in descending order.
    """
    print("\n" + "="*60)
    print("综合特征分析...")
    print("="*60)

    # Prepare data
    feature_cols = [col for col in df.columns if col not in ['salinity', 'longitude', 'latitude']]
    X = df[feature_cols].fillna(0)
    y = df['salinity']

    # 1. Feature-vs-salinity scatter plots
    print("\n1. 生成特征分布图...")

    key_features = ['NDVI', 'ST_B10', 'VV_VH_diff', 'SI1', 'ET_mean', 'groundwater_depth']
    available_key_features = [f for f in key_features if f in df.columns]

    if available_key_features:
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        axes = axes.ravel()

        for ax, feature in zip(axes, available_key_features):
            ax.scatter(df[feature], y, alpha=0.6, c=y, cmap='RdYlBu_r')
            ax.set_xlabel(feature)
            ax.set_ylabel('Salinity')
            ax.set_title(f'{feature} vs Salinity')

            # Fit the trend only on rows where both values exist; filling NaN
            # with 0 would inject artificial points and bias the slope.
            valid = df[feature].notna() & y.notna()
            if valid.sum() >= 2:
                x_valid = df.loc[valid, feature]
                coeffs = np.polyfit(x_valid, y[valid], 1)
                x_line = np.sort(x_valid.to_numpy())
                ax.plot(x_line, np.poly1d(coeffs)(x_line), "r--", alpha=0.8)

        # Hide panels left empty when fewer than six key features exist.
        for ax in axes[len(available_key_features):]:
            ax.set_visible(False)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'feature_vs_salinity_scatter.png'), dpi=300, bbox_inches='tight')
        plt.show()

    # 2. Correlation with salinity
    print("\n2. 计算特征相关性...")

    correlations = X.corrwith(y).sort_values(ascending=False)

    # Rank by |r| so strongly negative predictors are shown too (sign kept).
    top_corr = correlations[correlations.abs().nlargest(top_n).index]

    plt.figure(figsize=(10, 12))
    # Reverse so the strongest correlation is drawn at the top.
    top_corr.iloc[::-1].plot(kind='barh')
    plt.xlabel('Correlation with Salinity')
    plt.title('Top Features Correlated with Salinity')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'feature_correlation_with_salinity.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # 3. Inter-feature correlation heatmap
    print("\n3. 生成特征相关性热力图...")

    high_corr_features = correlations.abs().nlargest(20).index.tolist()

    plt.figure(figsize=(12, 10))
    corr_matrix = df[high_corr_features + ['salinity']].corr()
    mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

    sns.heatmap(corr_matrix, mask=mask, annot=True, fmt='.2f',
                cmap='coolwarm', center=0, square=True, linewidths=0.5,
                cbar_kws={"shrink": 0.8})
    plt.title('Feature Correlation Heatmap (Top 20 Features)', fontsize=14)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'feature_correlation_heatmap.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # 4. Contribution by data source
    print("\n4. 分析不同数据源的贡献...")

    # NOTE: substring matching - a feature may count toward several sources
    # (e.g. 'VV_VH_diff' matches both 'VV' and 'diff').
    source_mapping = {
        'Landsat': ['SR_B', 'ST_B', 'NDVI', 'EVI', 'SAVI', 'MSAVI', 'NDWI', 'MNDWI',
                    'SI1', 'SI2', 'SI3', 'SI4', 'S1', 'S2', 'S3', 'S5', 'S6',
                    'NDSI', 'SI_MSI', 'BSI', 'BI', 'TVDI', 'Temp'],
        'Sentinel-1': ['VV', 'VH', 'RVI', 'DPSVI', 'Pol', 'SMI', 'angle'],
        'Environmental': ['ET', 'PET', 'precip', 'groundwater', 'elevation', 'slope',
                          'aspect', 'TWI', 'hillshade'],
        'Interaction': ['interaction', 'ratio', 'diff', 'Balance', 'Stress', 'Wetness']
    }

    source_corr = {}
    for source, keywords in source_mapping.items():
        source_features = [f for f in feature_cols if any(k in f for k in keywords)]
        if source_features:
            source_corr[source] = X[source_features].corrwith(y).abs().mean()

    plt.figure(figsize=(10, 6))
    sources = list(source_corr.keys())
    values = list(source_corr.values())
    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D']

    bars = plt.bar(sources, values, color=colors[:len(sources)])
    plt.xlabel('Data Source')
    plt.ylabel('Average Absolute Correlation with Salinity')
    plt.title('Contribution of Different Data Sources')

    # Value labels above the bars
    for bar, value in zip(bars, values):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005,
                 f'{value:.3f}', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'data_source_contribution.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # 5. Spatial distribution at the sample points
    if 'longitude' in df.columns and 'latitude' in df.columns:
        print("\n5. 生成空间分布图...")

        fig, axes = plt.subplots(2, 2, figsize=(14, 12))

        # (axis, column, colormap, title, colorbar label)
        panels = [
            (axes[0, 0], 'salinity', 'RdYlBu_r', 'Salinity Spatial Distribution', 'Salinity'),
            (axes[0, 1], 'NDVI', 'RdYlGn', 'NDVI Spatial Distribution', 'NDVI'),
            (axes[1, 0], 'ST_B10', 'hot', 'Land Surface Temperature Distribution', 'Temperature (K)'),
            (axes[1, 1], 'VV_VH_diff', 'viridis', 'VV-VH Difference Spatial Distribution', 'VV-VH (dB)'),
        ]
        for ax, column, cmap, title, label in panels:
            if column not in df.columns:
                continue
            points = ax.scatter(df['longitude'], df['latitude'],
                                c=df[column], s=50, cmap=cmap,
                                edgecolors='black', linewidth=0.5)
            ax.set_title(title)
            ax.set_xlabel('Longitude')
            ax.set_ylabel('Latitude')
            plt.colorbar(points, ax=ax, label=label)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'spatial_distribution_analysis.png'), dpi=300, bbox_inches='tight')
        plt.show()

    return correlations

# ==================== 9. 增强的特征选择流程 ====================

def enhanced_feature_selection_multisource(df, output_dir, min_features=15, max_features=30):
    """
    Select a compact, source-balanced feature subset for salinity modelling.

    Steps:
      1. Score every feature with seven methods (RF, GB, F-test, mutual
         information, LASSO, ElasticNet, |Pearson r|). Normalise each score
         column to a maximum of 1 and combine them with fixed weights.
      2. Guarantee a minimum number of features per data source, then fill up
         to ``max_features`` with the highest remaining scores.
      3. Drop near-duplicates (|r| > 0.95), keeping the higher-scoring one.
      4. Choose the subset size by 5-fold CV R^2 using a one-standard-deviation
         rule (smallest size whose mean R^2 is within one std of the best).
      5. Save plots and a CSV report to ``output_dir``.

    NOTE(review): scoring and CV use the same samples, so the reported CV R^2
    is optimistic.

    Args:
        df: Sample table with ``salinity``, ``longitude``, ``latitude`` and
            feature columns.
        output_dir: Directory for figures and ``feature_importance_report.csv``.
        min_features: Smallest subset size tried in CV (capped at the number
            of candidates).
        max_features: Largest candidate pool / subset size.

    Returns:
        tuple: (final feature list, importance DataFrame, CV score DataFrame).

    Raises:
        ValueError: if no candidate features remain for cross-validation.
    """
    print("\n" + "="*60)
    print("增强特征选择流程...")
    print("="*60)

    from sklearn.preprocessing import StandardScaler
    from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
    from sklearn.feature_selection import SelectKBest, f_regression, mutual_info_regression
    from sklearn.linear_model import LassoCV, ElasticNetCV
    from sklearn.model_selection import cross_val_score
    import warnings
    warnings.filterwarnings('ignore')

    # Features and target
    feature_cols = [col for col in df.columns if col not in ['salinity', 'longitude', 'latitude']]
    X = df[feature_cols].fillna(0)
    y = df['salinity']

    # Standardise. Keep df's index so index-aligned operations against y
    # (e.g. corrwith) line up row by row.
    scaler = StandardScaler()
    X_scaled = pd.DataFrame(scaler.fit_transform(X), columns=feature_cols, index=X.index)

    print(f"特征总数: {len(feature_cols)}")
    print(f"样本总数: {len(X)}")

    # 1. Multi-method importance
    print("\n1. 多方法特征重要性评估...")

    # Random Forest
    rf = RandomForestRegressor(n_estimators=200, max_depth=10, random_state=42, n_jobs=-1)
    rf.fit(X_scaled, y)
    rf_importance = pd.Series(rf.feature_importances_, index=feature_cols, name='RF_importance')

    # Gradient Boosting
    gb = GradientBoostingRegressor(n_estimators=100, max_depth=5, random_state=42)
    gb.fit(X_scaled, y)
    gb_importance = pd.Series(gb.feature_importances_, index=feature_cols, name='GB_importance')

    # F-statistic
    f_selector = SelectKBest(score_func=f_regression, k='all')
    f_selector.fit(X_scaled, y)
    f_importance = pd.Series(f_selector.scores_, index=feature_cols, name='F_score')

    # Mutual information
    mi_scores = mutual_info_regression(X_scaled, y, random_state=42, n_neighbors=5)
    mi_importance = pd.Series(mi_scores, index=feature_cols, name='MI_score')

    # LASSO
    lasso = LassoCV(cv=5, random_state=42, max_iter=1000)
    lasso.fit(X_scaled, y)
    lasso_importance = pd.Series(np.abs(lasso.coef_), index=feature_cols, name='LASSO_coef')

    # ElasticNet
    elastic = ElasticNetCV(cv=5, random_state=42, max_iter=1000)
    elastic.fit(X_scaled, y)
    elastic_importance = pd.Series(np.abs(elastic.coef_), index=feature_cols, name='ElasticNet_coef')

    # |Pearson r|
    corr_importance = X_scaled.corrwith(y).abs()
    corr_importance.name = 'Correlation'

    # Constant features yield NaN F-scores / correlations; treat them as
    # zero importance so they don't poison the weighted score.
    importance_df = pd.concat([
        rf_importance, gb_importance, f_importance, mi_importance,
        lasso_importance, elastic_importance, corr_importance
    ], axis=1).fillna(0)

    # Normalise each method to a maximum of 1
    for col in importance_df.columns:
        if importance_df[col].max() > 0:
            importance_df[col] = importance_df[col] / importance_df[col].max()

    # Weighted combined score
    weights = {
        'RF_importance': 0.20,
        'GB_importance': 0.15,
        'F_score': 0.15,
        'MI_score': 0.15,
        'LASSO_coef': 0.10,
        'ElasticNet_coef': 0.10,
        'Correlation': 0.15
    }

    importance_df['Weighted_score'] = sum(importance_df[col] * weight
                                          for col, weight in weights.items())

    # 2. Source-balanced selection
    print("\n2. 基于数据源的平衡选择...")

    def identify_source(feature_name):
        """Map a feature name to its data source by substring match."""
        if any(x in feature_name for x in ['SR_B', 'ST_B', 'NDVI', 'EVI', 'SAVI', 'MSAVI',
                                           'NDWI', 'MNDWI', 'SI', 'S1', 'S2', 'S3', 'S5', 'S6',
                                           'NDSI', 'BSI', 'BI', 'TVDI', 'Temp']):
            return 'Landsat'
        elif any(x in feature_name for x in ['VV', 'VH', 'RVI', 'DPSVI', 'Pol', 'SMI', 'angle']):
            return 'Sentinel-1'
        elif any(x in feature_name for x in ['ET', 'PET', 'precip', 'groundwater',
                                             'elevation', 'slope', 'aspect', 'TWI', 'hillshade']):
            return 'Environmental'
        else:
            return 'Interaction'

    importance_df['Source'] = importance_df.index.map(identify_source)

    # Minimum number of features guaranteed per source
    min_per_source = {
        'Landsat': 5,
        'Sentinel-1': 3,
        'Environmental': 3,
        'Interaction': 2
    }

    selected_features = []
    for source, min_count in min_per_source.items():
        source_features = importance_df[importance_df['Source'] == source].nlargest(
            min_count, 'Weighted_score'
        ).index.tolist()
        selected_features.extend(source_features)
        print(f"  {source}: 选择了 {len(source_features)} 个特征")

    # Top up with the best remaining features (never a negative budget).
    remaining_budget = max(0, max_features - len(selected_features))
    remaining_features = importance_df[~importance_df.index.isin(selected_features)].nlargest(
        remaining_budget, 'Weighted_score'
    ).index.tolist()
    selected_features.extend(remaining_features)

    # 3. Remove highly correlated duplicates
    print("\n3. 优化特征相关性...")

    def optimize_by_correlation(features, X, threshold=0.95):
        """Within each |r| > threshold group keep only the top-scoring feature."""
        corr_matrix = X[features].corr().abs()
        upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))

        to_drop = []
        for column in upper.columns:
            if column in to_drop:
                continue
            correlated = upper.index[upper[column] > threshold].tolist()
            if correlated:
                all_corr_features = [column] + correlated
                importance_scores = importance_df.loc[all_corr_features, 'Weighted_score']
                keep_feature = importance_scores.idxmax()
                drop_features = [f for f in all_corr_features if f != keep_feature]
                to_drop.extend(drop_features)

        return [f for f in features if f not in to_drop]

    selected_features = optimize_by_correlation(selected_features, X_scaled, threshold=0.95)
    print(f"  相关性优化后剩余特征数: {len(selected_features)}")

    # 4. Cross-validate the subset size
    print("\n4. 交叉验证确定最优特征数...")

    selected_features_sorted = importance_df.loc[selected_features].sort_values(
        'Weighted_score', ascending=False
    ).index.tolist()

    # If fewer candidates survive than min_features, the original range would
    # be empty and the CV table would have no columns; cap the start instead.
    n_candidates = min(len(selected_features_sorted), max_features)
    if n_candidates == 0:
        raise ValueError("没有可用于交叉验证的候选特征")
    start = max(1, min(min_features, n_candidates))
    feature_range = range(start, n_candidates + 1)

    cv_scores = []
    for n_features in feature_range:
        features_subset = selected_features_sorted[:n_features]
        X_subset = X_scaled[features_subset]

        # Lighter model for the CV loop
        rf_cv = RandomForestRegressor(n_estimators=50, max_depth=8, random_state=42, n_jobs=-1)
        scores = cross_val_score(rf_cv, X_subset, y, cv=5, scoring='r2')

        cv_scores.append({
            'n_features': n_features,
            'mean_r2': scores.mean(),
            'std_r2': scores.std()
        })

        print(f"  {n_features} 特征: R² = {scores.mean():.4f} ± {scores.std():.4f}")

    cv_scores_df = pd.DataFrame(cv_scores)

    # One-standard-deviation rule: the fewest features whose mean R^2 is
    # within one std of the best mean R^2.
    best_idx = cv_scores_df['mean_r2'].idxmax()
    best_score = cv_scores_df.loc[best_idx, 'mean_r2']
    best_std = cv_scores_df.loc[best_idx, 'std_r2']

    threshold = best_score - best_std
    optimal_idx = cv_scores_df[cv_scores_df['mean_r2'] >= threshold]['n_features'].idxmin()
    optimal_n = cv_scores_df.loc[optimal_idx, 'n_features']

    print(f"\n最优特征数: {optimal_n}")
    print(f"对应R²: {cv_scores_df.loc[optimal_idx, 'mean_r2']:.4f}")

    final_features = selected_features_sorted[:optimal_n]

    # 5. Visualisation
    print("\n5. 生成特征选择可视化...")

    # 5.1 Importance heatmap
    plt.figure(figsize=(14, 10))
    importance_matrix = importance_df.loc[final_features, list(weights.keys())]

    sns.heatmap(importance_matrix, annot=True, fmt='.3f', cmap='YlOrRd',
                cbar_kws={'label': 'Normalized Importance'})
    plt.title('Multi-Method Feature Importance Heatmap', fontsize=14)
    plt.xlabel('Method')
    plt.ylabel('Feature')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'feature_importance_heatmap.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # 5.2 Selection curve
    plt.figure(figsize=(10, 6))
    plt.plot(cv_scores_df['n_features'], cv_scores_df['mean_r2'], 'b-', marker='o', label='Mean R²')
    plt.fill_between(cv_scores_df['n_features'],
                     cv_scores_df['mean_r2'] - cv_scores_df['std_r2'],
                     cv_scores_df['mean_r2'] + cv_scores_df['std_r2'],
                     alpha=0.2, color='blue', label='±1 std')
    plt.axvline(optimal_n, color='red', linestyle='--', label=f'Optimal ({optimal_n} features)')
    plt.axhline(threshold, color='gray', linestyle=':', label='One std rule threshold')
    plt.xlabel('Number of Features')
    plt.ylabel('Cross-validation R²')
    plt.title('Feature Selection Performance Curve')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'feature_selection_curve.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # 5.3 Source distribution pie chart
    source_dist = importance_df.loc[final_features, 'Source'].value_counts()

    plt.figure(figsize=(8, 8))
    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D']
    plt.pie(source_dist.values, labels=source_dist.index, autopct='%1.1f%%',
            colors=colors[:len(source_dist)], startangle=90)
    plt.title('Distribution of Selected Features by Data Source')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'feature_source_distribution.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # 6. Save results
    print("\n6. 保存特征选择结果...")

    feature_report = importance_df.loc[final_features].copy()
    feature_report = feature_report.sort_values('Weighted_score', ascending=False)
    feature_report.to_csv(os.path.join(output_dir, 'feature_importance_report.csv'))

    print("\n" + "="*60)
    print("最终选择的特征:")
    print("="*60)

    for source in ['Landsat', 'Sentinel-1', 'Environmental', 'Interaction']:
        source_final = feature_report[feature_report['Source'] == source]
        if len(source_final) > 0:
            print(f"\n{source} ({len(source_final)}个):")
            for feat, score in source_final['Weighted_score'].items():
                print(f"  - {feat}: {score:.4f}")

    return final_features, importance_df, cv_scores_df

# ==================== 10. 主执行函数 ====================

def main():
    """
    Run the pipeline: fusion, sampling, analysis, feature selection, export.

    Returns:
        tuple: (geemap.Map, DataFrame of the selected features plus
        salinity/longitude/latitude, list of selected feature names), or
        (None, None, None) if any step raised (the traceback is printed).
    """
    try:
        fused_image, map_obj, band_names, target_projection = fuse_multisource_data_comprehensive()

        df = extract_and_engineer_features(
            fused_image,
            SOIL_DATA_PATH,
            band_names,
            projection=target_projection,
            scale=TARGET_SCALE,
        )

        # Figures only; the returned correlation series is not needed here.
        comprehensive_feature_analysis(df, OUTPUT_DIR)

        final_features, _importance_df, _cv_scores = enhanced_feature_selection_multisource(
            df, OUTPUT_DIR, min_features=15, max_features=35
        )

        export_columns = final_features + ['salinity', 'longitude', 'latitude']
        df_final = df[export_columns]
        final_csv_path = os.path.join(OUTPUT_DIR, 'final_selected_features.csv')
        df_final.to_csv(final_csv_path, index=False)

        rule = "="*60
        print("\n" + rule)
        print("处理完成！")
        print(rule)
        for summary_line in (
            f"✓ 原始特征数据: {OUTPUT_CSV}",
            f"✓ 最终特征数据: {final_csv_path}",
            f"✓ 所有图表保存在: {OUTPUT_DIR}",
            f"✓ 最终特征数: {len(final_features)}",
        ):
            print(summary_line)

        return map_obj, df_final, final_features

    except Exception as e:
        import traceback
        print(f"\n错误: {e}")
        traceback.print_exc()
        return None, None, None

# Run the pipeline when executed as a script or notebook cell.
if __name__ == "__main__":
    map_result, final_data, selected_features = main()
    # main() returns (None, None, None) on failure; compare with None instead
    # of relying on the truthiness of a map widget object.
    if map_result is not None:
        # A bare expression inside an `if` block is never auto-rendered by
        # Jupyter, so display the map explicitly when IPython is available.
        try:
            from IPython.display import display
        except ImportError:
            pass
        else:
            display(map_result)