
 * @Author: cunyu277 2465899266@qq.com
 * @Date: 2025-04-14 15:29:44
 * @LastEditors: cunyu277 2465899266@qq.com
 * @LastEditTime: 2025-04-14 15:58:56
 * @FilePath: \crop_yield_prediction\cunuu\clean\two.ipynb
 * @Description: Compute NDVI, NDMI, EVI, GNDVI and SIPI time-series rasters from clipped multi-band MODIS stacks
 * 
 * Copyright (c) 2025 by yh, All Rights Reserved. 



In [1]:
import os
import numpy as np
import rasterio
from tqdm import tqdm
import re

# ---- Paths ----
# Root of the project data; clipped input stacks live in its "result" folder.
BASE_DIR = "D:\\Crop\\NorthChina"
CLIPPED_DIR = os.path.join(BASE_DIR, "result")

# One output folder per vegetation index, each named after the index itself.
OUTPUT_DIRS = {
    index_name: os.path.join(BASE_DIR, index_name)
    for index_name in ("NDVI", "NDMI", "EVI", "GNDVI", "SIPI")
}

# Ensure every output folder exists before anything is written (runs at import time).
for dir_path in OUTPUT_DIRS.values():
    os.makedirs(dir_path, exist_ok=True)

# ---- MODIS stack layout ----
BANDS_PER_TIME = 7  # bands stored for each time step
TOTAL_TIMES = 16    # time steps per yearly stack

# 0-based offset of each band inside one BANDS_PER_TIME group.
RED_BAND_IDX = 0    # B1 (red)
NIR_BAND_IDX = 1    # B2 (near infrared)
BLUE_BAND_IDX = 2   # B3 (blue)
GREEN_BAND_IDX = 3  # B4 (green)
SWIR_BAND_IDX = 4   # B5 (shortwave infrared 1)

def parse_modis_filename(filename):
    """Split a MODIS stack file name into its components.

    The whole name must have the form
    ``<data_type>_<6 digits>_<6 digits>_<4 digits>.tif``,
    e.g. ``ENTIRE_410000_419001_2022.tif``.

    Args:
        filename: Base name of the file (no directory part).

    Returns:
        dict with string values under 'data_type', 'prov_code', 'loc_code'
        and 'year'.

    Raises:
        ValueError: if the name does not match the expected form in full.
    """
    # fullmatch (not match): reject names with trailing junk such as
    # "...2022.tiff" or "...2022.tif.aux.xml", which re.match would accept.
    match = re.fullmatch(r"(.*?)_(\d{6})_(\d{6})_(\d{4})\.tif", filename)
    if match is None:
        raise ValueError(f"文件名格式错误: {filename}")

    data_type, prov_code, loc_code, year = match.groups()
    return {
        'data_type': data_type,
        'prov_code': prov_code,
        'loc_code': loc_code,
        'year': year,
    }

def _nodata_mask(bands, nodata):
    """Return a bool array that is True wherever any array in *bands* equals *nodata*."""
    mask = np.zeros(np.shape(bands[0]), dtype=bool)
    if nodata is None:
        # No declared nodata value: invalid pixels are caught later by the
        # finiteness check in _apply_nodata.
        return mask
    # NaN never compares equal to itself, so a NaN nodata needs isnan().
    nodata_is_nan = np.isnan(nodata)
    for band in bands:
        mask |= np.isnan(band) if nodata_is_nan else (band == nodata)
    return mask


def _compute_indices(blue, red, nir, green, swir):
    """Return the raw (unmasked) NDVI/NDMI/EVI/GNDVI/SIPI arrays for one time step."""
    # EVI coefficients: C1, C2 (aerosol terms), canopy background L, gain G.
    c1, c2, canopy_l, gain = 6.0, 7.5, 1.0, 2.5
    # NOTE(review): the L=1 offset assumes reflectance scaled to 0-1; if the
    # bands are stored as scaled integers they should be rescaled first - confirm.
    # errstate keeps the warning suppression local instead of changing numpy's
    # global error state as np.seterr did.
    with np.errstate(divide='ignore', invalid='ignore'):
        evi = gain * (nir - red) / (nir + c1 * red - c2 * blue + canopy_l)
        # Clip only finite values; NaN/inf are turned into nodata later.
        evi = np.where(np.isfinite(evi), np.clip(evi, -1.0, 1.0), evi)
        return {
            'NDVI': (nir - red) / (nir + red),
            'NDMI': (nir - swir) / (nir + swir),
            'EVI': evi,
            'GNDVI': (nir - green) / (nir + green),
            # SIPI = (NIR - Blue) / (NIR - Red) (Penuelas et al., 1995);
            # the previous code divided by (NIR + Red).
            'SIPI': (nir - blue) / (nir - red),
        }


def _apply_nodata(values, mask, nodata_value=-9999):
    """Replace masked or non-finite (NaN/inf, e.g. zero-denominator) entries with *nodata_value*."""
    return np.where(mask | ~np.isfinite(values), nodata_value, values)


def calculate_all_indices(input_path):
    """Compute five vegetation indices for every time step of one MODIS stack.

    The raster must contain ``TOTAL_TIMES`` consecutive groups of
    ``BANDS_PER_TIME`` bands. Within each group, the ``*_BAND_IDX`` constants
    give the 0-based position of the red, NIR, blue, green and SWIR bands.

    Args:
        input_path: Path to the multi-band raster. Its base name must be
            accepted by ``parse_modis_filename``.

    Returns:
        Tuple ``(file_info, results, meta)``:
        - ``file_info`` is the dict from ``parse_modis_filename``.
        - ``results`` maps 'NDVI', 'NDMI', 'EVI', 'GNDVI' and 'SIPI' to float
          arrays stacked along a leading time axis of length ``TOTAL_TIMES``.
        - ``meta`` is the source's ``src.meta``.
        Pixels that are nodata in any input band, or whose index value is not
        finite, are set to -9999.

    Raises:
        ValueError: if the band count is wrong or the file name does not parse.
    """
    with rasterio.open(input_path) as src:
        expected = BANDS_PER_TIME * TOTAL_TIMES
        if src.count != expected:
            raise ValueError(f"波段数量不符: 预期{BANDS_PER_TIME*TOTAL_TIMES}，实际{src.count}")

        file_info = parse_modis_filename(os.path.basename(input_path))
        nodata = src.nodata

        results = {name: [] for name in ('NDVI', 'NDMI', 'EVI', 'GNDVI', 'SIPI')}

        for time_idx in range(TOTAL_TIMES):
            # rasterio band numbers are 1-based, hence the +1.
            offset = time_idx * BANDS_PER_TIME + 1
            blue = src.read(offset + BLUE_BAND_IDX).astype(float)
            red = src.read(offset + RED_BAND_IDX).astype(float)
            nir = src.read(offset + NIR_BAND_IDX).astype(float)
            green = src.read(offset + GREEN_BAND_IDX).astype(float)
            swir = src.read(offset + SWIR_BAND_IDX).astype(float)

            mask = _nodata_mask((blue, red, nir, green, swir), nodata)
            for name, values in _compute_indices(blue, red, nir, green, swir).items():
                results[name].append(_apply_nodata(values, mask))

        # One array per index with a leading time axis.
        stacked = {name: np.stack(layers) for name, layers in results.items()}
        return file_info, stacked, src.meta

def save_index_result(data_stack, meta, output_path):
    """Write a stack of index layers to a float32 raster with nodata -9999.

    Args:
        data_stack: Array-like of shape ``(bands, height, width)``. Each
            leading-axis slice becomes one output band (normally
            ``TOTAL_TIMES`` layers). A single 2-D layer is also accepted.
        meta: Raster metadata mapping (e.g. ``src.meta``) used as the template
            for the output. It is copied, not modified.
        output_path: Destination file path.
    """
    stack = np.asarray(data_stack, dtype='float32')
    if stack.ndim == 2:
        stack = stack[np.newaxis]  # treat a lone 2-D layer as a 1-band stack

    # Derive the band count from the data rather than assuming TOTAL_TIMES,
    # so the header always matches what is written.
    new_meta = {**meta, 'count': stack.shape[0], 'dtype': 'float32', 'nodata': -9999}

    with rasterio.open(output_path, 'w', **new_meta) as dst:
        dst.write(stack)  # a 3-D array writes bands 1..count in one call

def batch_process_indices():
    """Compute and save all indices for every ``ENTIRE_*.tif`` in ``CLIPPED_DIR``.

    Files are processed in sorted name order. Each result is written to
    ``OUTPUT_DIRS[<index>]/<index>_<prov_code>_<loc_code>_<year>.tif``.
    An error on one file is reported and the batch continues with the next.

    Returns:
        List of ``(filename, error_message)`` tuples for files that failed.
        The list is empty when every file succeeded.
    """
    # Sort for a deterministic processing order; os.listdir order is arbitrary.
    modis_files = sorted(
        f for f in os.listdir(CLIPPED_DIR)
        if f.startswith('ENTIRE_') and f.endswith('.tif')
    )

    print(f"发现{len(modis_files)}个MODIS文件待处理...")

    failures = []
    with tqdm(modis_files, desc="处理进度") as pbar:
        for filename in pbar:
            pbar.set_postfix_str(filename)
            input_path = os.path.join(CLIPPED_DIR, filename)
            try:
                file_info, results, meta = calculate_all_indices(input_path)

                base_name = f"{file_info['prov_code']}_{file_info['loc_code']}_{file_info['year']}"
                for index_name, data_stack in results.items():
                    output_path = os.path.join(OUTPUT_DIRS[index_name], f"{index_name}_{base_name}.tif")
                    save_index_result(data_stack, meta, output_path)
            except Exception as e:  # batch boundary: record and move on to the next file
                failures.append((filename, str(e)))
                # tqdm.write prints above the bar without corrupting it (plain print does).
                tqdm.write(f"处理 {filename} 时出错: {str(e)}")

    if failures:
        print(f"{len(failures)}个文件处理失败: " + ", ".join(name for name, _ in failures))
    return failures

if __name__ == "__main__":
    batch_process_indices()
    # Report where each index's rasters were written, one line per index.
    summary = ["处理完成！结果保存在以下目录:"]
    summary.extend(f"- {name}: {path}" for name, path in OUTPUT_DIRS.items())
    print("\n".join(summary))


发现585个MODIS文件待处理...


处理进度: 100%|██████████| 585/585 [02:07<00:00,  4.60it/s, ENTIRE_410000_419001_2022.tif]

处理完成！结果保存在以下目录:
- NDVI: D:\Crop\NorthChina\NDVI
- NDMI: D:\Crop\NorthChina\NDMI
- EVI: D:\Crop\NorthChina\EVI
- GNDVI: D:\Crop\NorthChina\GNDVI
- SIPI: D:\Crop\NorthChina\SIPI



