# 1. data pre-processing (GEE)

## initialize the gee

In [None]:
! pip install earthengine-api
import ee
ee.Authenticate()
ee.Initialize(project='crops-mapping-gaoyuan')
from ee.batch import Export
!pip install geemap
from google.colab import drive
drive.mount('/content/drive')
import time

## image process related function

In [None]:
# common functions
# helper functions used to build the Sentinel-2 image collection for the study date range and study area:
#   - maskEdges: exclude bad data at scene edges
#   - maskClouds: remove cloudy pixels
def maskEdges(s2_img):
    """Mask pixels of a Sentinel-2 image that are not valid in both B8A and B9.

    Args:
        s2_img: image exposing the 'B8A' and 'B9' bands.

    Returns:
        The input image with its mask updated by the combined B8A/B9 mask.
    """
    b8a_valid = s2_img.select('B8A').mask()
    b9_valid = s2_img.select('B9').mask()
    # A pixel survives only where both band masks are set.
    combined_valid = b8a_valid.updateMask(b9_valid)
    return s2_img.updateMask(combined_valid)

# Function to mask clouds in Sentinel-2 imagery.
def maskClouds(img):
    """Mask the pixels of ``img`` whose cloud probability is 5 or more.

    Args:
        img: image carrying a 'cloud_mask' property that holds an image with
            a 'probability' band (attached by the join in
            ``sentinel2_collection``).

    Returns:
        The image with only pixels below the probability limit left unmasked.
    """
    cloud_probability_limit = 5
    cloud_probability = ee.Image(img.get('cloud_mask')).select('probability')
    return img.updateMask(cloud_probability.lt(cloud_probability_limit))

def sentinel2_collection(start_data, end_data, roi):
    """Build an edge- and cloud-masked Sentinel-2 collection with renamed bands.

    Args:
        start_data: start of the date filter.
        end_data: end of the date filter.
        roi: region used for the geometry filter.

    Returns:
        An image collection restricted to the ten selected bands, renamed to
        B, G, R, RE1, RE2, RE3, NIR, RE4, SWIR1 and SWIR2.
    """
    # Original band name -> short name used by the rest of the notebook.
    band_aliases = {
        'B2': 'B', 'B3': 'G', 'B4': 'R',
        'B5': 'RE1', 'B6': 'RE2', 'B7': 'RE3',
        'B8': 'NIR', 'B8A': 'RE4',
        'B11': 'SWIR1', 'B12': 'SWIR2',
    }
    source_bands = list(band_aliases.keys())
    renamed_bands = list(band_aliases.values())

    # The same space/time constraint is applied to both collections.
    space_time_filter = ee.Filter.And(
        ee.Filter.geometry(roi),
        ee.Filter.date(start_data, end_data),
    )

    # NOTE(review): local names say "SR" but the collection ID is
    # S2_HARMONIZED; confirm whether S2_SR_HARMONIZED was intended.
    reflectance = (
        ee.ImageCollection("COPERNICUS/S2_HARMONIZED")
        .filter(space_time_filter)
        .map(maskEdges)
    )
    cloud_probability = (
        ee.ImageCollection("COPERNICUS/S2_CLOUD_PROBABILITY")
        .filter(space_time_filter)
    )

    # Attach the matching cloud-probability image (same system:index) to each
    # scene under the 'cloud_mask' property read by maskClouds.
    same_index = ee.Filter.equals(leftField="system:index", rightField="system:index")
    joined = ee.Join.saveFirst('cloud_mask').apply(
        primary=reflectance,
        secondary=cloud_probability,
        condition=same_index,
    )

    cloud_free = ee.ImageCollection(joined).map(maskClouds)
    return cloud_free.select(source_bands, renamed_bands)
# add EVI
def addEVI(image):
    """Append an 'EVI' band computed from the NIR, R and B bands.

    The expression is 2500 * (NIR - R) / (NIR + 6 * R - 7.5 * B + 10000).

    Args:
        image: image exposing bands named 'NIR', 'R' and 'B'.

    Returns:
        The input image with the EVI band added.
    """
    inputs = {band: image.select(band) for band in ('NIR', 'R', 'B')}
    evi_band = image.expression(
        'EVI = 2500 * (NIR - R) / (NIR + 6 * R - 7.5 * B + 10000)', inputs)
    return image.addBands(evi_band)

# start_date: start of the period to composite
# end_date: end of the period to composite
# interval: length of each compositing window, in days
# roi: target region of study
# output value of this function is the time series of Sentinel-2 median composites in the given time period
def get_s2sr_images(start_date, end_date,interval, roi):
    """Build a time series of median Sentinel-2 composites.

    One composite is produced per ``interval``-day window, with windows
    starting at ``start_date`` and stepping by ``interval`` days up to
    ``end_date``.

    Args:
        start_date: start of the period; anything ``ee.Date`` accepts.
        end_date: end of the period; anything ``ee.Date`` accepts.
        interval: window length in days.
        roi: region used to filter the source collection.

    Returns:
        An ``ee.ImageCollection`` with one median composite per window, each
        carrying the window start (milliseconds) as 'system:time_start'.
    """
    # Normalise both bounds once so that strings and ee.Date objects are
    # handled the same way (the original called .millis() on the raw inputs).
    start = ee.Date(start_date)
    end = ee.Date(end_date)

    # The source collection spans from 1 January of the start year to
    # 31 December of the following year.
    year = start.get('year')
    first_date = ee.Date.fromYMD(year, 1, 1)
    last_date = ee.Date.fromYMD(year.add(1), 12, 31)

    # Define the Sentinel-2 image collection
    s2SR_imgCol = sentinel2_collection(first_date, last_date, roi)

    # Window start times, in milliseconds, spaced `interval` days apart.
    millis_interval = ee.Number(interval).multiply(1000 * 60 * 60 * 24)
    dates = ee.List.sequence(start.millis(), end.millis(), millis_interval)

    def resampleTo10Days(date):
        """Median-composite all images in [date, date + interval days)."""
        currentDate = ee.Date(date)
        endDate = currentDate.advance(interval, 'day')
        summarizedImageCol = s2SR_imgCol.filterDate(currentDate, endDate)
        summarizedImage = summarizedImageCol.median()
        # median() drops image properties, so the window start is restored.
        summarizedImage = summarizedImage.set('system:time_start', date)
        return summarizedImage

    # Apply the time resampling function using map()
    resampledImages = ee.ImageCollection(dates.map(resampleTo10Days))
    return resampledImages

In [None]:
def get_random_samples(roi):
    """Draw random sample points from WorldCover pixels whose value is 40.

    Args:
        roi: region passed to ``sample``.

    Returns:
        A feature collection of point samples, each with a sequential 'FID'
        and 'longitude' / 'latitude' properties taken from its geometry.
    """
    worldcover = ee.ImageCollection("ESA/WorldCover/v100").first()
    # Keep only pixels of class value 40; everything else is masked out.
    class40_only = worldcover.updateMask(worldcover.eq(40))
    sampled = class40_only.sample(
        region=roi,
        scale=10,
        numPixels=10000,
        geometries=True,
    )

    def _with_fid(pair):
        # pair is [feature, running index]
        pair = ee.List(pair)
        return ee.Feature(pair.get(0)).set('FID', ee.Number(pair.get(1)))

    def _with_lonlat(feature):
        return feature.set({
            'longitude': feature.geometry().coordinates().get(0),
            'latitude': feature.geometry().coordinates().get(1),
        })

    # Pair every sampled feature with its position in the list to get an FID.
    running_index = ee.List.sequence(0, sampled.size().subtract(1))
    feature_list = sampled.toList(sampled.size())
    numbered = ee.FeatureCollection(feature_list.zip(running_index).map(_with_fid))
    return numbered.map(_with_lonlat)

def extract_points_value(imgCol, pts,fileName,folderName):
    """Sample every image of ``imgCol`` at ``pts`` and export the table to Drive.

    Args:
        imgCol: image collection to sample.
        pts: feature collection of sample points.
        fileName: description of the export task.
        folderName: Drive folder that receives the CSV.

    The function only starts the export task; it returns nothing.
    """
    def _accumulate(img, running):
        # Stamp each sampled feature with the formatted date of its image.
        stamp = ee.Date(img.date()).format()
        sampled = img.sampleRegions(
            collection=pts,
            properties=['FID','longitude','latitude','Label'],
            scale=10,
        )
        stamped = sampled.map(lambda f: f.set('date', stamp))
        return ee.FeatureCollection(running).merge(stamped)

    empty_start = ee.FeatureCollection(ee.List([]))
    all_samples = ee.FeatureCollection(imgCol.iterate(_accumulate, empty_start))

    export_task = ee.batch.Export.table.toDrive(
        collection=all_samples,
        description=fileName,
        folder=folderName,
        fileFormat='CSV',
    )
    export_task.start()

In [None]:
def get_attributes_ofSamples(year, interval, roi,tilename):
    """Export the composite time series of random samples for one tile and year.

    The period runs from 1 October of ``year - 1`` to 1 October of ``year``.

    Args:
        year: year in which the sampled period ends.
        interval: compositing window length passed to ``get_s2sr_images``.
        roi: region used for compositing and for drawing samples.
        tilename: tile name used in the Drive folder and export name.
    """
    drive_folder = f'test_{tilename}'
    export_name = f'primary_{tilename}_{year}'

    period_start = ee.Date.fromYMD(year-1, 10, 1)
    period_end = ee.Date.fromYMD(year, 10, 1)

    composites = get_s2sr_images(period_start, period_end, interval, roi)
    sample_points = get_random_samples(roi)
    extract_points_value(composites, sample_points, export_name, drive_folder)
    print(export_name)

## get random points in the study region and get the time series surface reflectance values of each random point in the reference year

In [None]:
# Reference-year run: draw random samples inside one Sentinel-2 tile and start
# the export of their composite time series.
year = 2020
interval = 15  # compositing window length, in days
tilename = '30TYR'
s2_borderIndex = ee.FeatureCollection("projects/crops-mapping-gaoyuan/assets/sentinel_2_index_shapefile");
# Keep only the feature(s) whose 'Name' property equals the tile name.
roi = s2_borderIndex.filter(ee.Filter.eq('Name', tilename))
get_attributes_ofSamples(year, interval, roi,tilename)

## get objects based on SNIC in the target year

In [None]:
def build_SNIC_objects(year, tilename):
    """Segment one tile with SNIC and export the cluster-label image.

    The segmentation input is the median of the composites between 1 April
    and 1 May of ``year`` (bands R, NIR and SWIR1). The resulting integer
    label image is exported both to Drive and to an Earth Engine asset.

    Args:
        year: target year.
        tilename: value of the 'Name' property of the tile in the tile index
            asset; also used in the export names.

    Returns:
        The integer 'clusters' label image.
    """
    folderName = 'test_' + str(tilename)
    filename = f'SNIC_{tilename}_{year}'
    interval = 15

    s2_borderIndex = ee.FeatureCollection("projects/crops-mapping-gaoyuan/assets/sentinel_2_index_shapefile")
    roi = s2_borderIndex.filter(ee.Filter.eq('Name', tilename))
    # The exports need a Geometry, not a FeatureCollection.
    roi_geometry = roi.first().geometry()

    startDate = ee.Date.fromYMD(year, 4, 1)
    endDate = ee.Date.fromYMD(year, 5, 1)

    s2SR_imgCol = get_s2sr_images(startDate, endDate, interval, roi)
    s2SR_median = s2SR_imgCol.median()
    segmentation_features = s2SR_median.select(['R', 'NIR', 'SWIR1'])

    seeds = ee.Algorithms.Image.Segmentation.seedGrid(36)
    snic = ee.Algorithms.Image.Segmentation.SNIC(
        image=segmentation_features,
        size=36,
        compactness=5,
        connectivity=8,
        neighborhoodSize=256,
        seeds=seeds
    )

    clusters = snic.select('clusters') \
                   .reproject(segmentation_features.select(0).projection(), None, 10) \
                   .toInt() \
                   .rename('clusters')

    task_Drive = Export.image.toDrive(
        image=clusters,
        description=filename,
        folder=folderName,
        fileNamePrefix=filename,
        region=roi_geometry,
        scale=10,
        maxPixels=1e10
    )
    task_Drive.start()

    assetId = f'projects/crops-mapping-gaoyuan/assets/SNIC_{tilename}_{year}'
    task_Asset = Export.image.toAsset(
        image=clusters,      # integer label image
        description=filename,
        assetId=assetId,
        region=roi_geometry,  # must be a Geometry, not a FeatureCollection
        scale=10,
        maxPixels=1e10,
        # Labels are categorical: build the asset pyramid with 'mode' so that
        # coarser levels hold real cluster ids instead of averaged values.
        pyramidingPolicy={'.default': 'mode'}
    )
    task_Asset.start()

    print('SNIC object get done!')
    return clusters


In [None]:
# 生成对象质心点：数值显式转型
def clusters_to_centroids(clusters, roi,
                          export_name='SNIC_30TYR_2021_point',
                          asset_root='projects/crops-mapping-gaoyuan/assets'):
    """Vectorise a cluster-label image to centroid points and export them.

    Args:
        clusters: label image to vectorise.
        roi: region passed to ``reduceToVectors``.
        export_name: task description and asset name of the export. The
            default reproduces the previously hard-coded name.
        asset_root: asset folder that receives the export.

    Returns:
        The centroid feature collection with the properties 'FID',
        'longitude', 'latitude' and 'cluster'.
    """
    vecs = clusters.reduceToVectors(
        geometry=roi, scale=10, labelProperty='cluster',
        geometryType='centroid', eightConnected=True, maxPixels=1e13
    )

    def add_index(feature, index):
        """Attach a sequential FID and the centroid coordinates."""
        ll = feature.geometry().coordinates()
        return (feature
                .set('FID', index)  # running index, used instead of the cluster id
                .set('longitude', ee.Number(ll.get(0)))
                .set('latitude', ee.Number(ll.get(1))))

    # Pair each feature with its position in the list to obtain the FID.
    indexed = ee.FeatureCollection(
        ee.List(vecs.toList(vecs.size()))
          .zip(ee.List.sequence(0, vecs.size().subtract(1)))
          .map(lambda pair: add_index(ee.Feature(ee.List(pair).get(0)),
                                      ee.List(pair).get(1)))
    )

    centroids = indexed.select(['FID', 'longitude', 'latitude', 'cluster'])

    # Start the asset export from inside the function.
    task = ee.batch.Export.table.toAsset(
        collection=centroids,
        description=export_name,
        assetId=f'{asset_root}/{export_name}'
    )
    task.start()

    return centroids

def to_object_median_image(img, clusters):
    """Replace pixel values by the median over each connected labelled object.

    Args:
        img: image to aggregate.
        clusters: label image whose band is named 'cluster'.

    Returns:
        The object-level median image, carrying the 'system:time_start'
        property of ``img``.
    """
    labelled = img.addBands(clusters)
    per_object_median = labelled.reduceConnectedComponents(
        reducer=ee.Reducer.median(),
        labelBand='cluster',
    )
    return per_object_median.copyProperties(img, ['system:time_start'])

# 采样：固定选择标量波段 + 明确日期为字符串
def sample_timeseries_by_points(imgCol, pts, fileName, folderName):
    """Sample each image of ``imgCol`` at ``pts`` and export one long table.

    Images are sampled one by one (no ``toBands``); every sampled feature
    receives a 'date' string formatted as YYYY-MM-dd.

    Args:
        imgCol: image collection whose images carry 'system:time_start'.
        pts: points with the properties 'FID', 'longitude' and 'latitude'.
        fileName: description of the Drive export task.
        folderName: Drive folder that receives the CSV.

    Returns:
        The merged feature collection that is being exported.
    """
    def _collect(img, running):
        day = ee.Date(img.get('system:time_start')).format('YYYY-MM-dd')
        sampled = img.sampleRegions(
            collection=pts,
            properties=['FID','longitude','latitude'],  # add 'Label' when labels exist
            scale=10,
            geometries=False,
        )
        dated = sampled.map(lambda f: f.set('date', day))
        return ee.FeatureCollection(running).merge(dated)

    merged = ee.FeatureCollection(
        imgCol.iterate(_collect, ee.FeatureCollection([])))

    export_task = ee.batch.Export.table.toDrive(
        collection=merged,
        description=fileName,
        folder=folderName,
        fileFormat='CSV',
    )
    export_task.start()
    return merged

def get_attribute_fromObjects(year, interval, roi, tilename, SNIC_objects):
    """Export object-level composite time series sampled at object centroids.

    The period runs from 1 October of ``year - 1`` to 1 October of ``year``.

    Args:
        year: year in which the sampled period ends.
        interval: compositing window length passed to ``get_s2sr_images``.
        roi: region used to build the composites.
        tilename: tile name used in the Drive folder and export name.
        SNIC_objects: image with a 'clusters' band of object labels.

    Returns:
        dict with 'objCol' (object-median composites), 'objPts' (centroid
        points) and 'samples' (the sampled feature collection).
    """
    drive_folder = f'test_{tilename}'
    export_name = f'primary_{tilename}_{year}_objects'

    composites = get_s2sr_images(
        ee.Date.fromYMD(year-1, 10, 1),
        ee.Date.fromYMD(year, 10, 1),
        interval,
        roi,
    )

    clusters = SNIC_objects.select('clusters').toInt().rename('cluster')

    # Put every composite on the projection of the label image.
    def _to_cluster_projection(img):
        return img.reproject(clusters.projection())

    s2 = composites.map(_to_cluster_projection)
    print('s2:',ee.Image(s2.first()).bandNames().getInfo())

    # Per-image median over each object.
    def _object_median(img):
        return to_object_median_image(img, clusters)

    objCol = s2.map(_object_median)

    # Centroids are read from a pre-exported asset;
    # clusters_to_centroids(clusters, roi) can be used to build them instead.
    objPts = ee.FeatureCollection("projects/crops-mapping-gaoyuan/assets/SNIC_30TYR_2021_point")

    # Sample and export.
    samples = sample_timeseries_by_points(
        imgCol=objCol, pts=objPts, fileName=export_name, folderName=drive_folder
    )
    return {'objCol': objCol, 'objPts': objPts, 'samples': samples}


In [None]:
# Target-year run: export the object-level time series for one tile.
year = 2021
interval = 15  # compositing window length, in days
tilename = '30TYR'
s2_borderIndex = ee.FeatureCollection("projects/crops-mapping-gaoyuan/assets/sentinel_2_index_shapefile");
# Keep only the feature(s) whose 'Name' property equals the tile name.
roi = s2_borderIndex.filter(ee.Filter.eq('Name', tilename))

# Uncomment to run the segmentation and start its exports instead of loading the asset below.
#SNIC_objects = build_SNIC_objects (year, tilename)

# Load the previously exported label image.
SNIC_objects = ee.Image("projects/crops-mapping-gaoyuan/assets/SNIC_30TYR_2021")
get_attribute_fromObjects(year, interval, roi, tilename, SNIC_objects)

# 2. get reference curve library (python)

## get the curve of each reference point in the reference year

In [None]:
! pip install earthengine-api
import ee
ee.Authenticate()
ee.Initialize(project='crops-mapping-gaoyuan')
!pip install geemap
from google.colab import drive
drive.mount('/content/drive')
!pip install rasterio
!pip install cleanlab
import os
import glob
import pandas as pd
import numpy as np
from datetime import datetime
import gc
import rasterio
import traceback
from scipy.optimize import curve_fit
from scipy.signal import savgol_filter
from scipy.optimize import curve_fit
from cleanlab.classification import CleanLearning
from sklearn.ensemble import RandomForestClassifier
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from matplotlib import cm
from scipy.optimize import curve_fit
from datetime import datetime
from scipy.signal import savgol_filter

In [None]:
# get the reference crop label of each sample point from the reference label map (asset 30TYR_RPG_2020_label)
def get_crop_label(df, cl_file, chunk_size=5000):
    """Look up the reference label of every unique sample point and save it.

    Each unique (FID, latitude, longitude) row of ``df`` is sampled against
    the label image asset and the result is written to ``cl_file`` with the
    columns FID, latitude, longitude and Label.

    Args:
        df: DataFrame with 'FID', 'latitude' and 'longitude' columns.
        cl_file: path of the CSV to write.
        chunk_size: number of points sent to Earth Engine per request.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    # Unique points; the positional index doubles as the join key below.
    unique_fid_lat_lon = (
        df[['FID', 'latitude', 'longitude']]
        .drop_duplicates()
        .reset_index(drop=True)
    )

    # Load the label image asset.
    wheat_map = ee.Image("projects/crops-mapping-gaoyuan/assets/30TYR_RPG_2020_label")

    def _as_builtin(value):
        # numpy scalars are turned into plain Python numbers before they are
        # sent as feature properties.
        return value.item() if hasattr(value, 'item') else value

    label_by_row = {}
    n_points = len(unique_fid_lat_lon)
    for start in range(0, n_points, chunk_size):
        block = unique_fid_lat_lon.iloc[start:start + chunk_size]

        # Build only this chunk's points, so each request carries one chunk
        # rather than the full point list.
        chunk = ee.FeatureCollection([
            ee.Feature(
                ee.Geometry.Point(float(lon), float(lat)),
                {'FID': _as_builtin(fid), 'row_id': start + offset}
            )
            for offset, (fid, lat, lon) in enumerate(
                block.itertuples(index=False, name=None))
        ])

        sampled_chunk = wheat_map.reduceRegions(
            collection=chunk,
            reducer=ee.Reducer.first(),
            scale=10  # replace with the resolution of the label image
        )

        for feature in sampled_chunk.getInfo()['features']:
            props = feature['properties']
            label_by_row[int(props['row_id'])] = props.get('first', None)

    # Join by row id rather than by float coordinates, which need not
    # survive the round trip unchanged.
    merged = unique_fid_lat_lon.copy()
    merged['Label'] = [label_by_row.get(i) for i in range(n_points)]

    # Save the result.
    merged.to_csv(cl_file, index=False)

def sdiv(num, den, eps=1e-6):
    """Divide ``num`` by ``den``, yielding NaN where ``den`` is close to zero.

    Args:
        num: numerator (pandas Series or scalar).
        den: denominator as a pandas Series.
        eps: denominators with an absolute value below this are treated as
            zero. The original body read ``eps`` without defining it, so any
            call raised NameError; it is now a parameter with a default.

    Returns:
        The element-wise quotient, NaN instead of inf for tiny denominators.
    """
    return num / den.mask(den.abs() < eps)
# add VIs
def add_VIs(df):
    """Add vegetation-index columns to ``df`` in place and return it.

    Adds NDVI, EVI, LSWI, RVI, OSAVI, GCVI, GVMI and NDRE (all multiplied by
    1000) and REP. Where a denominator is zero the result is NaN rather than
    +/-inf, so the gap can be interpolated downstream.

    Args:
        df: DataFrame with the columns B, G, R, RE1, RE2, RE3, NIR and SWIR1.

    Returns:
        The same DataFrame object with the index columns added.
    """
    # NDVI
    df['NDVI'] = 1000 * (df['NIR'] - df['R']) / (df['NIR'] + df['R'])
    # EVI (Enhanced Vegetation Index)
    df['EVI'] = 1000 * 2.5 * (df['NIR'] - df['R']) / (df['NIR'] + 6 * df['R'] - 7.5 * df['B'] + 10000)
    # LSWI (Land Surface Water Index)
    df['LSWI'] = 1000 * (df['NIR'] - df['SWIR1']) / (df['NIR'] + df['SWIR1'])
    # RVI (Ratio Vegetation Index)
    df['RVI'] = 1000 * (df['RE3'] / df['G'])
    # OSAVI (Optimized Soil-Adjusted Vegetation Index)
    df['OSAVI'] = 1000 * (df['NIR'] - df['R']) / (df['NIR'] + df['R'] + 1600)
    # GCVI (Green Chlorophyll Vegetation Index)
    df['GCVI'] = 1000 * (df['NIR'] / df['G'] - 1)
    # GVMI (Global Vegetation Moisture Index)
    df['GVMI'] = 1000 * ((df['NIR'] + 1000) - (df['SWIR1'] + 200)) / ((df['NIR'] + 1000) + (df['SWIR1'] + 200))
    # NDRE (Normalized Difference Red-Edge)
    df['NDRE'] = 1000 * (df['RE2'] - df['RE1']) / (df['RE2'] + df['RE1'])
    # REP (red edge position)
    df['REP'] = 705 + ((35*(0.5*(df['R'] + df['RE3'])- df['RE1'])) / (df['RE2'] - df['RE1']))

    # Division by zero yields +/-inf in pandas; turn those into NaN.
    vi_cols = ['NDVI', 'EVI', 'LSWI', 'RVI', 'OSAVI', 'GCVI', 'GVMI', 'NDRE', 'REP']
    df[vi_cols] = df[vi_cols].replace([np.inf, -np.inf], np.nan)
    return df

def add_VIs_safe(df, scale=1000, eps=1e-6):
    """Add vegetation-index columns to ``df`` with guarded divisions.

    Band values are coerced to float64 on a working copy, values outside
    [-1000, 20000] are treated as missing, denominators smaller than ``eps``
    in absolute value give NaN, and every index is clipped to a fixed range.
    The nine index columns are written back to ``df`` as float32.

    Args:
        df: DataFrame with the columns NIR, R, B, SWIR1, RE3, G, RE2, RE1.
        scale: multiplier applied to every index except REP.
        eps: threshold below which a denominator counts as zero.

    Returns:
        The same DataFrame object with the index columns added.

    Raises:
        KeyError: if a required band column is missing.
    """
    required = {'NIR','R','B','SWIR1','RE3','G','RE2','RE1'}
    absent = required.difference(df.columns)
    if absent:
        raise KeyError(f"Missing columns: {sorted(absent)}")

    # Work on a copy so the band columns of the caller's frame stay untouched.
    work = df.copy()
    for band in required:
        as_float = pd.to_numeric(work[band], errors='coerce').astype('float64')
        # Values outside the plausible reflectance range become NaN.
        work[band] = as_float.where(as_float.between(-1000, 20000))

    def _ratio(top, bottom):
        # Near-zero denominators are blanked so the quotient is NaN, not inf.
        return top / bottom.where(bottom.abs() >= eps)

    nir, red, blue, green = work['NIR'], work['R'], work['B'], work['G']
    swir1 = work['SWIR1']
    re1, re2, re3 = work['RE1'], work['RE2'], work['RE3']

    work['NDVI'] = scale * _ratio(nir - red, nir + red)
    work['EVI'] = scale * 2.5 * _ratio(nir - red, nir + 6*red - 7.5*blue + 10000)
    work['LSWI'] = scale * _ratio(nir - swir1, nir + swir1)
    work['RVI'] = scale * _ratio(re3, green)
    work['OSAVI'] = scale * _ratio(nir - red, nir + red + 1600)
    work['GCVI'] = scale * (_ratio(nir, green) - 1.0)
    work['GVMI'] = scale * _ratio((nir + 1000) - (swir1 + 200),
                                  (nir + 1000) + (swir1 + 200))
    work['NDRE'] = scale * _ratio(re2 - re1, re2 + re1)
    # REP is a wavelength in nm and is therefore not multiplied by `scale`.
    work['REP'] = 705.0 + _ratio(35.0 * (0.5*(red + re3) - re1), re2 - re1)

    vi_cols = ['NDVI','EVI','LSWI','RVI','OSAVI','GCVI','GVMI','NDRE','REP']
    work[vi_cols] = work[vi_cols].replace([np.inf, -np.inf], np.nan)

    # Clipping ranges; RVI and GCVI get a wide upper bound.
    clip_ranges = (
        ('NDVI', -scale, scale),
        ('EVI', -scale, scale),
        ('LSWI', -scale, scale),
        ('OSAVI', -scale, scale),
        ('GVMI', -scale, scale),
        ('NDRE', -scale, scale),
        ('RVI', 0, 10*scale),
        ('GCVI', -scale, 10*scale),
        ('REP', 680, 750),
    )
    for name, low, high in clip_ranges:
        work[name] = work[name].clip(lower=low, upper=high)

    # Only the index columns are written back to the caller's frame.
    df[vi_cols] = work[vi_cols].astype('float32')
    return df

# function of S-G filtering
def SG_filtering(df, bandList,SG_file, window_length=5, polyorder=2):
    """Write gap-filled, Savitzky-Golay-smoothed time series to ``SG_file``.

    The long table ``df`` (one row per FID and date) is reshaped, band by
    band, into one row per FID and one column ``<band>_<doy>`` per date.
    Gaps are filled by linear interpolation along time, then each row is
    smoothed along time. The output CSV holds all band columns followed by
    'FID'.

    Fixes over the previous version:
      * the filter runs along each sample's time series; ``DataFrame.apply``
        without ``axis`` ran it down each date column, i.e. across samples;
      * date columns are in chronological order, not order of appearance;
      * 'date' may be datetime values or any string ``pd.to_datetime`` can
        parse (both 'YYYY-MM-DD' and ISO timestamps work);
      * duplicate (FID, doy) records no longer misalign rows (the first
        record is kept).

    As before, ``df`` gets its 'date' column converted and a 'doy' column
    added (days since 1 January of the earliest year, starting at 1).

    Args:
        df: DataFrame with 'FID', 'date' and the columns in ``bandList``.
        bandList: names of the columns to process.
        SG_file: path of the CSV to write.
        window_length: Savitzky-Golay window size (odd number of dates).
        polyorder: Savitzky-Golay polynomial order.
    """
    if not pd.api.types.is_datetime64_any_dtype(df['date']):
        df['date'] = pd.to_datetime(df['date'])

    fids = df['FID'].unique()

    # Day count relative to 1 January of the earliest year, so a season that
    # crosses New Year keeps increasing past 365.
    base_date = datetime(df['date'].min().year, 1, 1)
    df['doy'] = (df['date'] - base_date).dt.days + 1
    doys = np.sort(df['doy'].unique())

    # One record per (FID, doy); the first one wins if there are duplicates.
    unique_records = df.drop_duplicates(['FID', 'doy']).set_index(['FID', 'doy'])

    band_tables = []
    for bandname in bandList:
        # Rows follow the FID order of first appearance, columns follow time.
        wide = (
            unique_records[bandname]
            .unstack('doy')
            .reindex(index=fids, columns=doys)
            .astype('float64')
        )
        wide.columns = [f'{bandname}_{int(doy)}' for doy in doys]

        # Fill gaps along time (columns), including leading/trailing gaps.
        wide = wide.interpolate(method='linear', axis=1, limit_direction='both')

        values = wide.to_numpy(dtype='float64', copy=True)
        if values.shape[1] >= window_length:
            # Rows that are still NaN have no observation at all; skip them.
            complete = ~np.isnan(values).any(axis=1)
            if complete.any():
                values[complete] = savgol_filter(
                    values[complete], window_length, polyorder, axis=1)

        # Kept from the previous version so the output has no NaN: a sample
        # with no data at all borrows the values of a neighbouring row.
        smoothed = pd.DataFrame(values, columns=wide.columns).ffill().bfill()
        band_tables.append(smoothed)

    if band_tables:
        result = pd.concat(band_tables, axis=1)
    else:
        result = pd.DataFrame(index=range(len(fids)))
    result['FID'] = fids

    result.to_csv(SG_file, index=False)

# main process to get the reference curve
def process_tileFile(filename,start_date,end_date,sr_bands,vi_bands,out_filenameSuffix):
    """Turn one exported sample table into smoothed SR and VI feature files.

    Reads the exported CSV, keeps the rows dated within
    [``start_date``, ``end_date``], and writes
    ``<out_filenameSuffix>SR_SGfeatures.csv`` and
    ``<out_filenameSuffix>VI_SGfeatures.csv`` through ``SG_filtering``.

    The 'date' column is parsed without a fixed format, so both the
    'YYYY-MM-DD' strings and the ISO timestamps produced by the two export
    functions are accepted without editing this function.

    Args:
        filename: path of the exported CSV.
        start_date: first date to keep (``datetime``).
        end_date: last date to keep (``datetime``).
        sr_bands: surface-reflectance columns to smooth.
        vi_bands: vegetation-index columns to smooth.
        out_filenameSuffix: path prefix of the output files.
    """
    VI_SG_file = out_filenameSuffix + 'VI_SGfeatures.csv'
    SR_SG_file = out_filenameSuffix + 'SR_SGfeatures.csv'
    cl_file = out_filenameSuffix + 'cropLabel.csv'

    primary_data = pd.read_csv(filename)
    # Export bookkeeping columns are not needed; tolerate their absence.
    data = primary_data.drop(columns=['system:index', '.geo'], errors='ignore')

    # Let pandas infer the date format (see docstring).
    data['date'] = pd.to_datetime(data['date'])

    filtered_data = data[(data['date'] >= start_date) & (data['date'] <= end_date)].copy()

    # Day count relative to 1 January of the start year.
    base_date = datetime(start_date.year,1,1)
    filtered_data['doy'] = (filtered_data['date'] - base_date).dt.days + 1

    # The reference-label lookup is disabled here; enable the next line to
    # write `cl_file` for the reference year.
    #get_crop_label(filtered_data,cl_file)
    print('  get crop label done.')
    # get S-G filtering surface value
    SG_filtering(filtered_data,sr_bands,SR_SG_file)
    print('  get sr sg filtering done.')
    filtered_data_vis = add_VIs(filtered_data)
    SG_filtering(filtered_data_vis,vi_bands,VI_SG_file)
    print('  get vi sg filtering done.')

In [None]:
# Reference year: smooth the exported sample table of tile 30TYR.
year = 2020
tilename = '30TYR'
filename = f'/content/drive/MyDrive/test_{tilename}/primary_{tilename}_{year}.csv'
output_fileSuff = f'/content/drive/MyDrive/test_{tilename}/{tilename}_{year}_'
# Rows dated outside [start_date, end_date] are dropped by process_tileFile.
start_date = datetime(year-1, 9, 1)
end_date = datetime(year, 9, 1)
SR_bands = ['B', 'G', 'R', 'RE1', 'RE2', 'RE3', 'NIR', 'RE4', 'SWIR1', 'SWIR2']
VI_bands = ['NDVI','EVI','LSWI','OSAVI','RVI','GCVI','GVMI','NDRE','REP']
process_tileFile(filename,start_date,end_date,SR_bands,VI_bands,output_fileSuff)

## target sr pre-processing in target year

In [None]:
# Target year: smooth the exported object-level table of tile 30TYR.
year = 2021
tilename = '30TYR'
filename = f'/content/drive/MyDrive/test_{tilename}/primary_{tilename}_{year}_objects.csv'
output_fileSuff = f'/content/drive/MyDrive/test_{tilename}/{tilename}_{year}_'
# Rows dated outside [start_date, end_date] are dropped by process_tileFile.
start_date = datetime(year-1, 9, 1)
end_date = datetime(year, 9, 1)
SR_bands = ['B', 'G', 'R', 'RE1', 'RE2', 'RE3', 'NIR', 'RE4', 'SWIR1', 'SWIR2']
VI_bands = ['NDVI','EVI','LSWI','OSAVI','RVI','GCVI','GVMI','NDRE','REP']
process_tileFile(filename,start_date,end_date,SR_bands,VI_bands,output_fileSuff)

## confident learning to remove noisy points

In [None]:
import glob
import os
import rasterio
import numpy as np
import pandas as pd
from cleanlab.classification import CleanLearning
from sklearn.ensemble import RandomForestClassifier
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from matplotlib import cm
from scipy.optimize import curve_fit
from datetime import datetime
from scipy.signal import savgol_filter
import re
from sklearn.preprocessing import LabelEncoder

In [None]:
def clean_referenceLabel(tilename, year, debug=True, write_report=True):
    """Iteratively relabel suspected label errors of the reference samples.

    Reads ``<tile>_<year>_cropLabel.csv``, ``..._VI_SGfeatures.csv`` and
    ``..._SR_SGfeatures.csv`` from the tile folder on Drive, runs up to 20
    rounds of cleanlab ``CleanLearning`` with a random forest, replaces each
    flagged label by the predicted one, and writes
    ``<tile>_<year>_cleaned_labels.csv``.

    Fixes over the previous version: the early-stop message referenced an
    undefined name (``it``) and raised NameError; labelled FIDs without a
    feature row raised KeyError instead of being skipped.

    Args:
        tilename: tile name used to build the folder and file names.
        year: year used to build the file names.
        debug: print diagnostics.
        write_report: when bad feature values are found, write a per-FID
            report before raising.

    Raises:
        ValueError: if no feature column is selected (debug only) or the
            features contain inf, NaN or values beyond the float32 range.
    """
    out_dir = f'/content/drive/MyDrive/test_{tilename}'
    out_filepath = f'{out_dir}/{tilename}_{year}_cleaned_labels.csv'
    report_path = f'{out_dir}/{tilename}_{year}_bad_features_report.csv'

    label_file = f'{out_dir}/{tilename}_{year}_cropLabel.csv'
    VI_file = f'{out_dir}/{tilename}_{year}_VI_SGfeatures.csv'
    SR_file = f'{out_dir}/{tilename}_{year}_SR_SGfeatures.csv'

    label_df = pd.read_csv(label_file, index_col='FID')
    VI_df = pd.read_csv(VI_file)
    SR_df = pd.read_csv(SR_file)
    feature_df = pd.merge(VI_df, SR_df, on='FID')

    # Select the feature columns: <band>_<number> for the listed bands only.
    pat = re.compile(r'^(EVI|LSWI|OSAVI|RVI|R|NIR|SWIR1|SWIR2)_(\d+)$')
    selected_columns = [c for c in feature_df.columns if pat.match(c)]
    print(f"Selected cols: {len(selected_columns)}")

    if debug:
        print(f"Total cols: {len(feature_df.columns)}, Selected cols: {len(selected_columns)}")
        if len(selected_columns) == 0:
            print("Sample of columns:", feature_df.columns[:50].tolist())
            raise ValueError("selected_columns is empty. Check prefixes / column names.")

    # Align features with labels.
    fid_column = label_df.index
    cdl_labels_binary = pd.to_numeric(label_df['Label'], errors='coerce').astype('Int64')
    feature_df = feature_df[feature_df['FID'].isin(fid_column)]
    features_by_fid = feature_df.set_index('FID')
    labelled_fids = cdl_labels_binary[cdl_labels_binary.notna()].index
    # Keep only labelled samples that have a feature row; the rest end up
    # with an empty Confidence_Label and Is_Noise = 1 in the output.
    valid_fids = labelled_fids[labelled_fids.isin(features_by_fid.index)]

    X = features_by_fid.loc[valid_fids, selected_columns].astype('float64').copy()
    y = cdl_labels_binary.loc[valid_fids].astype(int)

    if debug:
        print("X shape:", X.shape, "y shape:", y.shape)
        uniq = np.unique(y)
        print("y unique:", uniq)

    # ---- diagnose non-finite / out-of-range feature values ----
    max32 = np.finfo(np.float32).max
    vals = X.values
    mask_inf = np.isinf(vals)
    mask_nan = np.isnan(vals)
    mask_big = np.abs(vals) > max32
    mask_bad = mask_inf | mask_nan | mask_big

    bad_rows = mask_bad.any(axis=1)
    bad_cols = mask_bad.any(axis=0)

    n_bad_rows = int(bad_rows.sum())
    n_bad_cols = int(bad_cols.sum())

    if debug:
        print(f"Bad rows: {n_bad_rows}, Bad cols: {n_bad_cols}")
        if n_bad_cols:
            bad_cols_counts = pd.Series(mask_bad.sum(axis=0), index=X.columns).sort_values(ascending=False)
            print("Top bad columns:\n", bad_cols_counts[bad_cols_counts > 0].head(20))

        if n_bad_rows:
            bad_fids = X.index[bad_rows]
            print("First bad FIDs:", bad_fids[:10].tolist())
            # Show the offending columns of the first five bad rows.
            for fid in bad_fids[:5]:
                ri = X.index.get_loc(fid)
                cols = [c for c, b in zip(X.columns, mask_bad[ri]) if b]
                print(f"FID {fid} bad cols: {cols[:15]}")

    # Write the full report and stop before the classifier fails on the data.
    if n_bad_rows or n_bad_cols:
        if write_report:
            records = []
            row_idx = np.where(bad_rows)[0]
            for r in row_idx:
                fid = X.index[r]
                cols = [c for c, b in zip(X.columns, mask_bad[r]) if b]
                records.append({"FID": fid, "bad_cols": "|".join(cols)})
            pd.DataFrame(records).to_csv(report_path, index=False)
            print("Bad-feature report ->", report_path)
        raise ValueError("Features contain inf/NaN/too-large values. See printed diagnostics/report.")

    # ---- train only once the features are clean ----
    noise_threshold = 50
    max_iterations = 20

    # Encode labels as 0..K-1 for the classifier.
    le = LabelEncoder()
    y_enc = pd.Series(le.fit_transform(y.values), index=y.index)
    K = len(le.classes_)
    previous_labels = y_enc.copy()

    for iteration in range(max_iterations):
        if debug: print(f"Fitting iteration {iteration+1} ...")

        # Cross-validation needs at least two samples of every class.
        class_counts = np.bincount(previous_labels.values, minlength=K)
        minc = int(class_counts.min())
        if minc < 2:
            if debug: print(f"Stop at iter {iteration+1}: min class count < 2")
            break
        cv = min(5, minc)

        # A fresh CleanLearning per round avoids carrying state over.
        cl = CleanLearning(clf=RandomForestClassifier(), cv_n_folds=cv, verbose=debug)
        cl.fit(X, previous_labels.values)
        li = cl.find_label_issues(X=X, labels=previous_labels.values)

        idx_issue = np.where(li['is_label_issue'].values)[0]
        preds = li['predicted_label'].astype(int).values

        # Replace every flagged label by the predicted label.
        next_labels = previous_labels.copy()
        if len(idx_issue) > 0:
            next_labels.iloc[idx_issue] = preds[idx_issue]

        previous_labels = next_labels

        current_noise_count = int(li['is_label_issue'].sum())
        print(f"Iteration {iteration + 1}: Noise points = {current_noise_count}")
        if current_noise_count < noise_threshold:
            print(f"Stopped at iteration {iteration + 1}: Noise < {noise_threshold}")
            break

    # A sample is noise when its final label differs from its original one
    # (compared in the encoded domain).
    is_noise_series = (previous_labels != y_enc).astype(int)

    # Decode the final labels back to the original label values.
    conf_label_enc = previous_labels.reindex(label_df.index)  # align with export index
    mask = conf_label_enc.notna()
    conf_label = pd.Series(index=label_df.index, dtype='float64')
    conf_label[mask] = le.inverse_transform(conf_label_enc[mask].astype(int).values)

    # Samples that were not part of the training set are flagged as noise.
    is_noise = is_noise_series.reindex(label_df.index).fillna(1).astype(int).to_numpy()

    result_df = pd.DataFrame({
        'FID': label_df.index,
        'latitude': label_df['latitude'],
        'longitude': label_df['longitude'],
        'Map_Label': label_df['Label'],                 # original label values
        'Confidence_Label': conf_label.to_numpy(),      # decoded final labels
        'Is_Noise': is_noise
    })
    result_df.to_csv(out_filepath, index=False)
    if debug: print("Saved cleaned labels ->", out_filepath)

In [None]:
# Clean the reference labels of tile 30TYR for the reference year.
tilename = '30TYR'
year = 2020
clean_referenceLabel(tilename,year)

## reclustering

In [None]:
# reclustering the wheat samples and non wheat samples using EVI value after SG filtering
def reclustering(tilename, year, debug=True):
    """Sub-cluster the noise-free reference samples with KMeans on their EVI series.

    Reads ``{tilename}_{year}_cleaned_labels.csv`` and
    ``{tilename}_{year}_VI_SGfeatures.csv`` from
    ``/content/drive/MyDrive/test_{tilename}``, keeps the samples whose
    ``Is_Noise`` flag is 0/False, binarises their label (label == 1 -> 1,
    everything else -> 0) and runs KMeans separately on each binary group
    (5 clusters for label 0, 2 clusters for label 1) using the ``EVI_<n>``
    columns. The result is written to
    ``{tilename}_{year}_clustered_labels.csv``.

    Cluster codes are ``(k + 1) * 10`` for label 0 and ``(k + 1) * 10 + 1``
    for label 1, where ``k`` is the KMeans cluster id. Samples that are noisy
    or have no feature row keep ``NaN`` in the ``Cluster`` column.

    Args:
        tilename: Tile identifier used to build the folder and file names.
        year: Year used in the file names.
        debug: When true, print progress information.

    Raises:
        KeyError: If either file lacks an ``FID`` column, or the label file
            has none of the supported label columns.
        ValueError: If the feature file has no ``EVI_<n>`` column.
    """
    out_dir = f'/content/drive/MyDrive/test_{tilename}'
    out_filepath = f'{out_dir}/{tilename}_{year}_clustered_labels.csv'

    result_file = f'{out_dir}/{tilename}_{year}_cleaned_labels.csv'
    VI_file     = f'{out_dir}/{tilename}_{year}_VI_SGfeatures.csv'

    # Load the cleaned labels and the SG-filtered VI features.
    result_df  = pd.read_csv(result_file)
    feature_df = pd.read_csv(VI_file)

    # Both tables are aligned on FID.
    if 'FID' not in result_df.columns or 'FID' not in feature_df.columns:
        raise KeyError("Both files must contain 'FID' column.")
    feature_df = feature_df.set_index('FID')
    result_df  = result_df.set_index('FID', drop=False)

    # Select the EVI_<n> feature columns.
    evi_cols = [c for c in feature_df.columns if re.match(r'^EVI_\d+$', c)]
    if debug:
        print("EVI cols:", len(evi_cols))
    if not evi_cols:
        raise ValueError("No EVI_* columns found.")

    features_VI = feature_df[evi_cols].astype('float64')
    # Fill missing values with the per-column median.
    if features_VI.isna().any().any():
        imp = SimpleImputer(strategy='median')
        features_VI = pd.DataFrame(imp.fit_transform(features_VI),
                                   index=features_VI.index, columns=features_VI.columns)

    # Drop noisy samples; the flag may be stored as 0/1 or as True/False.
    if result_df['Is_Noise'].dropna().isin([0,1]).all():
        keep_mask = result_df['Is_Noise'] == 0
    else:
        keep_mask = result_df['Is_Noise'] == False

    noise_free = result_df[keep_mask]
    if debug:
        print("After noise filter:", noise_free.shape[0], "samples")

    # Keep only samples that also have a feature row. Without this, the
    # per-class ``.loc`` lookups below raise KeyError for any FID that is
    # present in the label file but missing from the feature file.
    has_features = noise_free.index.isin(features_VI.index)
    if debug and not has_features.all():
        print("Samples without EVI features (skipped):", int((~has_features).sum()))
    filtered_data = noise_free[has_features].copy()
    filtered_features = features_VI.loc[filtered_data.index]

    # Binarise the label: label == 1 -> 1, anything else -> 0.
    if 'Confidence_Label_Bin' in filtered_data.columns:
        y_bin = filtered_data['Confidence_Label_Bin'].astype(int)
    elif 'Confidence_Label' in filtered_data.columns:
        y_bin = (pd.to_numeric(filtered_data['Confidence_Label'], errors='coerce').fillna(-1).astype(int) == 1).astype(int)
    elif 'Map_Label' in filtered_data.columns:
        y_bin = (pd.to_numeric(filtered_data['Map_Label'], errors='coerce').fillna(-1).astype(int) == 1).astype(int)
    else:
        raise KeyError("Need one of ['Confidence_Label_Bin','Confidence_Label','Map_Label'] in result file.")

    if debug:
        print("Binary label counts:", y_bin.value_counts().to_dict())

    # Split the samples by binary label.
    idx0 = filtered_data.index[y_bin == 0]
    idx1 = filtered_data.index[y_bin == 1]
    X0 = filtered_features.loc[idx0]
    X1 = filtered_features.loc[idx1]

    def safe_kmeans(X, k, name):
        """Run KMeans, tolerating empty input and fewer samples than clusters."""
        n = len(X)
        if n == 0:
            if debug: print(f"{name}: 0 samples. Skip clustering.")
            return np.array([], dtype=int)
        if k > n:
            if debug: print(f"{name}: reduce n_clusters {k}->{n}")
            k = n
        return KMeans(n_clusters=k, random_state=42).fit_predict(X)

    # Cluster each group independently.
    clusters_label_0 = safe_kmeans(X0, 5, "label==0")
    clusters_label_1 = safe_kmeans(X1, 2, "label==1")

    # Encode the cluster ids: 10, 20, ... for label 0 and 11, 21, ... for label 1.
    lab0 = pd.Series((clusters_label_0 + 1) * 10, index=idx0) if len(clusters_label_0) else pd.Series(dtype='float64')
    lab1 = pd.Series((clusters_label_1 + 1) * 10 + 1, index=idx1) if len(clusters_label_1) else pd.Series(dtype='float64')

    # Write the codes back onto the full label table.
    result_df['Cluster'] = np.nan
    result_df.loc[lab0.index, 'Cluster'] = lab0
    result_df.loc[lab1.index, 'Cluster'] = lab1

    # Export the columns that exist.
    final_cols = [c for c in ['FID','latitude','longitude','Map_Label','Confidence_Label','Confidence_Label_Bin','Cluster'] if c in result_df.columns]
    final_result = result_df[final_cols].reset_index(drop=True)
    final_result.to_csv(out_filepath, index=False)
    if debug:
        print("Saved ->", out_filepath, "| rows:", len(final_result))


In [None]:
# Run the re-clustering step for tile '30TYR' and year 2020.
tilename = '30TYR'
year = 2020
reclustering(tilename,year)

## generate reference curve

In [None]:
#生成谐波参数名
def genetate_feature_names(bandname):
    """Return the seven harmonic feature names of ``bandname``.

    The names are ordered constant, cos_1..cos_3, sin_1..sin_3, and each one
    carries the band name twice (e.g. ``constant_EVI_EVI``).
    """
    terms = ['constant']
    terms += [f'cos_{order}' for order in (1, 2, 3)]
    terms += [f'sin_{order}' for order in (1, 2, 3)]
    return [f'{term}_{bandname}_{bandname}' for term in terms]

# 定义谐波函数
def harmonic_function(x, feature_row,bandname):
    """Evaluate the third-order harmonic series of ``bandname`` at ``x``.

    ``feature_row`` is indexed with the names produced by
    ``genetate_feature_names``; the result is
    ``constant + sum_k (cos_k * cos(k*x) + sin_k * sin(k*x))`` for k = 1..3.
    """
    names = genetate_feature_names(bandname)
    total = feature_row[names[0]]
    # names[k] is the cosine coefficient of order k, names[k + 3] the sine one.
    for order in (1, 2, 3):
        angle = order * x
        total = total + feature_row[names[order]] * np.cos(angle) + feature_row[names[order + 3]] * np.sin(angle)
    return total

# 定义要拟合的谐波模型
def harmonic_model(x, constant, cos_1, cos_2, cos_3, sin_1, sin_2, sin_3):
    """Third-order harmonic model used as the fitting function.

    Returns ``constant + sum_k (cos_k * cos(k*x) + sin_k * sin(k*x))`` for
    k = 1..3.
    """
    cos_coefs = (cos_1, cos_2, cos_3)
    sin_coefs = (sin_1, sin_2, sin_3)
    value = constant
    for order, (cos_coef, sin_coef) in enumerate(zip(cos_coefs, sin_coefs), start=1):
        value = value + cos_coef * np.cos(order * x) + sin_coef * np.sin(order * x)
    return value

# 获取谐波参数 based on VI value
def get_harmonic_parameter_usingValue(features_df,start_date,end_date,labels_df,bandname,labelColumn):
    """Fit one harmonic curve per label class to the median band time series.

    Feature columns are expected to be named ``{bandname}_{doy}``. Only the
    columns whose DOY lies on the 15-day grid
    ``arange(start_doy - 1, end_doy, 15)`` are used, where DOY is counted from
    1 January of ``start_date.year``. For each distinct non-null value of
    ``labelColumn`` the per-column median over the class is fitted with
    ``harmonic_model`` at ``t = doy * pi / 365``.

    Args:
        features_df: Table with an ``FID`` column and ``{bandname}_{doy}`` columns.
        start_date: First date of the period (``datetime``).
        end_date: Last date of the period (``datetime``).
        labels_df: Table with ``FID`` and ``labelColumn``.
        bandname: Band or index name used as the column prefix.
        labelColumn: Name of the class column in ``labels_df``.

    Returns:
        A list of ``(class_value, popt)`` tuples, where ``popt`` holds the
        fitted parameters in the order constant, cos1, cos2, cos3, sin1,
        sin2, sin3.

    Raises:
        ValueError: If no feature column falls on the requested DOY grid.
    """
    merged_df = features_df.merge(labels_df[['FID', labelColumn]], on='FID', how='inner')

    # DOY grid of the period, counted from 1 January of the start year.
    base_date = datetime(start_date.year,1,1)
    start_doy = (start_date- base_date).days + 1
    end_doy = (end_date- base_date).days + 1
    wanted_doys = set(np.arange(start_doy-1, end_doy,15).tolist())

    # Collect (doy, column) pairs and sort them by DOY so that every fitted
    # value is paired with its own time, whatever the column order in the CSV.
    prefix = f'{bandname}_'
    dated_columns = []
    for col in features_df.columns:
        if not col.startswith(prefix):
            continue
        doy = int(col.split('_')[1])
        if doy in wanted_doys:
            dated_columns.append((doy, col))
    dated_columns.sort()

    if not dated_columns:
        raise ValueError(f"No '{prefix}<doy>' column matches the requested DOY grid.")

    filtered_features_name = [col for _, col in dated_columns]
    t_values = np.array([doy for doy, _ in dated_columns]) * np.pi / 365

    all_popt = []
    for cluster in merged_df[labelColumn].dropna().unique():
        # Median time series of the samples of this class.
        cluster_data = merged_df[merged_df[labelColumn] == cluster]
        features_median = cluster_data[filtered_features_name].median()
        popt, _ = curve_fit(harmonic_model, t_values, features_median)
        all_popt.append((cluster, popt))

    return all_popt

# 导出谐波参数到csv文件
def export_cluster_parameters_to_csv(all_bands_popt, output_file):
    """Write the fitted harmonic parameters to a CSV, one row per cluster.

    Args:
        all_bands_popt: List of ``(band_name, [(cluster, popt), ...])`` pairs,
            where each ``popt`` holds seven values in the order constant,
            cos1, cos2, cos3, sin1, sin2, sin3.
        output_file: Destination CSV path.

    The file has a ``Cluster`` column followed by seven ``{band}_<param>``
    columns per band. Rows follow the order in which clusters first appear in
    ``all_bands_popt``, so the output is the same on every run. A cluster that
    is missing for a band gets empty cells for that band.
    """
    param_suffixes = ('constant', 'cos1', 'cos2', 'cos3', 'sin1', 'sin2', 'sin3')

    columns = ['Cluster']
    for band, _ in all_bands_popt:
        columns.extend(f'{band}_{suffix}' for suffix in param_suffixes)

    # Distinct clusters in first-seen order (a plain set has no stable order).
    cluster_keys = list(dict.fromkeys(
        cluster
        for _, clusters_popt in all_bands_popt
        for cluster, _ in clusters_popt
    ))

    # One lookup table per band; the first entry wins if a cluster repeats.
    band_lookups = []
    for _, clusters_popt in all_bands_popt:
        lookup = {}
        for cluster, popt in clusters_popt:
            lookup.setdefault(cluster, popt)
        band_lookups.append(lookup)

    missing = [None] * len(param_suffixes)
    result_rows = []
    for cluster in cluster_keys:
        row = [cluster]
        for lookup in band_lookups:
            row.extend(lookup.get(cluster, missing))
        result_rows.append(row)

    df = pd.DataFrame(result_rows, columns=columns)
    df.to_csv(output_file, index=False)

# get reference curve of EVI and SR
def get_reference_curve_parameters(tilename,year):
    """Fit and export the SR and VI reference harmonic parameters of a tile.

    Loads the cleaned labels and the SG-filtered SR / VI feature tables of
    ``tilename`` and ``year``, fits one harmonic curve per
    ``Confidence_Label`` class and band over 1 October ``year`` to 1 July
    ``year + 1``, and writes the parameters to
    ``{tilename}_{year}_SR_Reference_parameters.csv`` and
    ``{tilename}_{year}_VI_Reference_parameters.csv`` in the tile folder.
    """
    out_dir = f'/content/drive/MyDrive/test_{tilename}'
    season_start = datetime(year, 10, 1)
    season_end = datetime(year + 1, 7, 1)
    class_column = 'Confidence_Label'

    # Inputs.
    labels = pd.read_csv(f'{out_dir}/{tilename}_{year}_cleaned_labels.csv')
    vi_features = pd.read_csv(f'{out_dir}/{tilename}_{year}_VI_SGfeatures.csv')
    sr_features = pd.read_csv(f'{out_dir}/{tilename}_{year}_SR_SGfeatures.csv')

    # Surface-reflectance bands.
    sr_fits = [
        (band, get_harmonic_parameter_usingValue(sr_features,season_start,season_end,labels,band,class_column))
        for band in ['B', 'G', 'R', 'RE1', 'RE2', 'RE3', 'NIR', 'RE4', 'SWIR1', 'SWIR2']
    ]
    export_cluster_parameters_to_csv(sr_fits, f'{out_dir}/{tilename}_{year}_SR_Reference_parameters.csv')
    print('SR reference curve export done.')

    # Vegetation indices.
    vi_fits = [
        (band, get_harmonic_parameter_usingValue(vi_features,season_start,season_end,labels,band,class_column))
        for band in ['EVI','LSWI','OSAVI','RVI']
    ]
    export_cluster_parameters_to_csv(vi_fits, f'{out_dir}/{tilename}_{year}_VI_Reference_parameters.csv')
    print('VI reference curve export done.')


In [None]:
# Fit and export the reference-curve parameters for tile '30TYR' and year 2020.
tilename = '30TYR'
year = 2020
get_reference_curve_parameters(tilename,year)

## plot the reference curve

In [None]:
from datetime import datetime
def generate_parameter_names(bandname):
    """Return the seven harmonic parameter column names of ``bandname``.

    Order: constant, cos1, cos2, cos3, sin1, sin2, sin3.
    """
    suffixes = ('constant', 'cos1', 'cos2', 'cos3', 'sin1', 'sin2', 'sin3')
    return [f'{bandname}_{suffix}' for suffix in suffixes]
# 定义谐波函数
def harmonic_function(x, feature_row,bandname):
    """Evaluate the third-order harmonic series of ``bandname`` at ``x``.

    Coefficients are read from ``feature_row`` using the names returned by
    ``generate_parameter_names``.
    """
    names = generate_parameter_names(bandname)
    result = feature_row[names[0]]
    # names[k] is the cosine coefficient of order k, names[k + 3] the sine one.
    for k in (1, 2, 3):
        result = result + feature_row[names[k]] * np.cos(k * x)
        result = result + feature_row[names[k + 3]] * np.sin(k * x)
    return result
def plot_reference_curve(tilename, year=2020):
    """Plot the EVI reference harmonic curve of every class of a tile.

    Reads ``{tilename}_{year}_VI_Reference_parameters.csv`` from
    ``/content/drive/MyDrive/test_{tilename}`` and, for each class 1..6,
    averages the harmonic curves of the matching rows over a 15-day DOY grid
    from 1 October ``year`` to 1 July ``year + 1``.

    Args:
        tilename: Tile identifier used to locate the parameter file.
        year: Reference year of the parameter file (default 2020).
    """
    label_filepath = f'/content/drive/MyDrive/test_{tilename}/{tilename}_{year}_VI_Reference_parameters.csv'
    label_df = pd.read_csv(label_filepath)

    band = 'EVI'
    feature_EVI_name = generate_parameter_names(band)
    clusters = [1,2,3,4,5,6]

    # DOY grid counted from 1 January of the start year (runs past 365).
    start_date = datetime(year, 10, 1)
    end_date = datetime(year + 1, 7, 1)
    base_date = datetime(start_date.year,1,1)
    start_doy = (start_date- base_date).days + 1
    end_doy = (end_date- base_date).days + 1
    x_values = np.arange(start_doy, end_doy + 1,15)
    t_values = x_values * np.pi / 365

    cluster_name_map = {
        1: 'Winter Wheat',
        2: 'winter rapeseed',
        3: 'other winter crops',
        4: 'spring crops',
        5: 'summer crops',
        6: 'Other crops'
    }

    for cluster in clusters:
        # Rows of the parameter table that belong to this class.
        cluster_data = label_df[label_df['Cluster'] == cluster]
        print(cluster,' sample number:',len(cluster_data))
        if cluster_data.empty:
            # Nothing to average; skipping avoids a 0/0 division and an all-NaN line.
            continue

        # Average the harmonic curves of all rows of the class.
        mean_harmonic_values = np.zeros_like(t_values)
        for _, row in cluster_data[feature_EVI_name].iterrows():
            mean_harmonic_values += harmonic_function(t_values, row, band)
        mean_harmonic_values /= len(cluster_data)

        plt.plot(x_values, mean_harmonic_values, label=cluster_name_map.get(cluster, f'Cluster {cluster}'))

    # Legend, title and axis labels.
    plt.title('Harmonic Curves for Each Cluster')
    plt.xlabel('DOY')
    plt.ylabel('Harmonic Value')
    plt.legend()
    plt.show()

# Plot the EVI reference curves for tile '30TYR'.
plot_reference_curve('30TYR')

# 3. get training objects (GEE & Python)

## DTSS method to calculate similarity distance

In [None]:
!pip install rasterio matplotlib
!pip install tslearn joblib

In [None]:
import numpy as np
import pandas as pd
import os
import glob
from tslearn.metrics import dtw_path
from joblib import Parallel, delayed
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
from datetime import datetime

In [None]:
#**** functions related to get refer VI curve
# add VIs
def add_VIs(df):
    """Add an ``EVI`` column to ``df`` in place and return ``df``.

    EVI is computed from the ``NIR``, ``R`` and ``B`` columns as
    ``1000 * 2.5 * (NIR - R) / (NIR + 6 * R - 7.5 * B + 10000)``.
    """
    nir = df['NIR']
    red = df['R']
    blue = df['B']
    numerator = 1000 * 2.5 * (nir - red)
    denominator = nir + 6 * red - 7.5 * blue + 10000
    df['EVI'] = numerator / denominator
    return df

# generate harmonic parameter name based on a given bandname
def generate_parameter_names(bandname):
    """Build the harmonic parameter column names of one band.

    Returns seven names, ``{bandname}_constant`` followed by the three cosine
    and then the three sine terms.
    """
    names = [f'{bandname}_constant']
    for kind in ('cos', 'sin'):
        names.extend(f'{bandname}_{kind}{order}' for order in (1, 2, 3))
    return names

# define the harmonic function
def harmonic_function(x, feature_row,bandname):
    """Evaluate the harmonic series of ``bandname`` at ``x``.

    The seven coefficients are looked up in ``feature_row`` under the names
    from ``generate_parameter_names``; the value is the constant plus the
    cosine and sine terms of orders 1 to 3.
    """
    parameter_names = generate_parameter_names(bandname)
    harmonic_values = feature_row[parameter_names[0]]
    for order in range(1, 4):
        # Cosine coefficient at position ``order``, sine at ``order + 3``.
        cos_term = feature_row[parameter_names[order]] * np.cos(order * x)
        harmonic_values = harmonic_values + cos_term
        sin_term = feature_row[parameter_names[order + 3]] * np.sin(order * x)
        harmonic_values = harmonic_values + sin_term
    return harmonic_values

#get refer vi curve of whole growing period based on the harmonic parameters
def get_refer_VIcurve(VI_file,t_values,bandname):
    """Build the reference VI curve of every row of a parameter file.

    Reads the harmonic parameters from ``VI_file`` and evaluates
    ``harmonic_function`` at ``t_values * pi / 365`` for each row. The result
    has one ``Time_<int(t)>`` column per entry of ``t_values`` plus the
    ``Cluster`` column of the parameter file.
    """
    params_df = pd.read_csv(VI_file)
    phase = t_values * np.pi / 365
    time_labels = [f'Time_{int(t)}' for t in t_values]

    def evaluate(row):
        return harmonic_function(phase, row, bandname)

    curves = params_df.apply(evaluate, axis=1).tolist()
    curve_df = pd.DataFrame(curves, columns=time_labels)
    curve_df['Cluster'] = params_df['Cluster']
    return curve_df

# get the spectral reflectance curve based on the harmonic parameters and given time
def get_refer_srCurve(refer_SRpara,SR_bandnames,start_date,time_Ts,interval):
    """Evaluate the reference surface-reflectance curves at given time steps.

    Args:
        refer_SRpara: DataFrame of harmonic parameters; only its first row is used.
        SR_bandnames: Band names whose parameters are evaluated.
        start_date: Date of time step 1; its day of year is the DOY origin.
        time_Ts: 1-based time-step indices, one per sample. ``None`` entries
            are treated as missing and yield NaN.
        interval: Number of days between consecutive time steps.

    Returns:
        Array with one row per entry of ``time_Ts`` and one column per band.
    """
    refer_SRpara = refer_SRpara.iloc[0]
    # dtype=float turns None (no matching time found) into NaN instead of
    # failing on ``None - 1`` below.
    time_Ts = np.asarray(time_Ts, dtype=float)

    start_doy = start_date.timetuple().tm_yday
    # The whole DOY is scaled by pi/365, as in get_refer_VIcurve. The previous
    # expression scaled only the offset and added the raw start DOY to it.
    doys = start_doy + (time_Ts - 1) * interval
    x = doys * np.pi / 365

    refer_SR_curve_list = [
        harmonic_function(x, refer_SRpara, sr_bandname)
        for sr_bandname in SR_bandnames
    ]
    return np.array(refer_SR_curve_list).T

# get the refer curve time of given time T based on the path result of dtw
def get_referTime(optimal_paths,time_T):
    """Return, per DTW path, the reference index first matched to ``time_T``.

    Each path is a sequence of ``(refer_index, sample_index)`` pairs. For a
    path without a pair whose sample index equals ``time_T`` the entry is
    ``None``.
    """
    return [
        next((ref_idx for ref_idx, smp_idx in path if smp_idx == time_T), None)
        for path in optimal_paths
    ]

# get the spectral reflectance curve of target samples based on the given filename
def get_target_srCurve(file,timeT,sr_bandnames,FID_df):
    """Return the band values at ``timeT`` of the samples listed in ``FID_df``.

    Reads ``file``, drops the ``system:index`` and ``.geo`` columns, parses
    ``date`` and keeps the rows dated ``timeT`` whose ``FID`` occurs in
    ``FID_df['FID']``. The result is a list of rows, each holding the values
    of ``sr_bandnames``.
    """
    table = pd.read_csv(file).drop(['system:index','.geo'], axis=1)
    table['date'] = pd.to_datetime(table['date'], format='%Y-%m-%dT%H:%M:%S')

    at_time = table['date'] == timeT
    wanted_fid = table['FID'].isin(FID_df['FID'])
    selected = table[at_time & wanted_fid]
    return selected[sr_bandnames].values.tolist()

# get the spectral reflectance curve of target samples based on the given filename after SG filtering
def get_target_srCurve_SG(file,start_date, timeT,sr_bandnames,FID_df):
    """Return the SG-filtered band values at ``timeT`` for the samples in ``FID_df``.

    Band columns in ``file`` are expected to be named ``{band}_{doy}``, with
    DOY counted from 1 January of ``start_date.year``.

    Args:
        file: CSV with an ``FID`` column and ``{band}_{doy}`` columns.
        start_date: Start of the period; only its year is used.
        timeT: Date whose band values are requested.
        sr_bandnames: Band names to extract.
        FID_df: Table whose ``FID`` column lists the wanted samples.

    Returns:
        A list with one row of band values per row of ``FID_df``, in the same
        order as ``FID_df`` (callers pair it positionally with that table).
        An FID missing from ``file`` yields a row of NaN. Returns ``[]`` when
        none of the requested columns exists.

    Raises:
        KeyError: If only some of the requested columns exist.
    """
    primary_df = pd.read_csv(file)
    base_date = datetime(start_date.year, 1, 1)
    doy = (timeT - base_date).days + 1
    sr_columns = [f'{band}_{doy}' for band in sr_bandnames]

    if not any(col in primary_df.columns for col in sr_columns):
        return []

    # A left merge keeps the row order of FID_df, so row i of the result
    # belongs to row i of FID_df regardless of the row order in the file.
    ordered = FID_df[['FID']].merge(primary_df[['FID'] + sr_columns], on='FID', how='left')
    return ordered[sr_columns].values.tolist()

#*** get target VI value of given time period based on the primary surface reflectance value
def get_target_VIs(target_features_file, VI_bandname, start_date, timeT):
    """Build a per-sample VI time series from the raw surface-reflectance file.

    Reads ``target_features_file`` (one row per sample and date), adds the VI
    with ``add_VIs`` and pivots the observations between ``start_date`` and
    ``timeT`` into one row per ``FID`` with one ``{VI_bandname}_{doy}`` column
    per observed day of year (counted from 1 January of ``start_date.year``).
    Only samples that have an observation exactly at ``timeT`` are kept.

    Gaps are interpolated linearly along each sample's series. Any sample that
    is still entirely missing is then filled from neighbouring rows, as the
    previous implementation did.

    Returns:
        DataFrame with columns ``FID``, ``latitude``, ``longitude`` and the VI
        columns in chronological order.
    """
    primary_data = pd.read_csv(target_features_file)
    primary_data = primary_data.drop(['system:index', '.geo'], axis=1)
    primary_data['date'] = pd.to_datetime(primary_data['date'], format='%Y-%m-%dT%H:%M:%S')
    time_T_FIDs = primary_data.loc[primary_data['date'] == timeT, 'FID'].unique()

    data = add_VIs(primary_data)

    # Convert each date to a day of year.
    base_date = datetime(start_date.year, 1, 1)
    data['doy'] = (data['date'] - base_date).dt.days + 1

    in_window = (data['date'] >= start_date) & (data['date'] <= timeT)
    filtered_data = data[in_window & data['FID'].isin(time_T_FIDs)]
    if filtered_data.empty:
        return pd.DataFrame(columns=['FID', 'latitude', 'longitude'])

    # One location per sample.
    locations = filtered_data.drop_duplicates(subset='FID').set_index('FID')[['latitude', 'longitude']]

    # Pivot on FID so every sample keeps its full series. Concatenating the
    # per-date frames side by side aligned them on the raw row index, which
    # left each sample with a single real value.
    vi_wide = filtered_data.groupby(['FID', 'doy'])[VI_bandname].first().unstack('doy')
    vi_wide = vi_wide.sort_index(axis=1)
    vi_wide.columns = [f'{VI_bandname}_{int(doy)}' for doy in vi_wide.columns]

    # Fill gaps along time first, then across samples for all-missing rows.
    vi_wide = vi_wide.interpolate(method='linear', axis=1, limit_direction='both')
    vi_wide = vi_wide.ffill().bfill()

    combined_data = locations.join(vi_wide).reset_index()
    return combined_data

#*** get target VI value of given time period based on the SG filtering value
def get_target_VIs_SG(target_features_file, FID_file,VI_bandname, start_date, timeT):
    """Select the SG-filtered VI series of every sample up to ``timeT``.

    Args:
        target_features_file: CSV with ``FID`` and ``{VI_bandname}_{doy}`` columns.
        FID_file: CSV with ``FID``, ``latitude`` and ``longitude``; duplicate
            FIDs are dropped (the first row is kept).
        VI_bandname: Vegetation-index name used as the column prefix.
        start_date: First date of the series.
        timeT: Last date of the series.

    Returns:
        DataFrame with ``FID``, the VI columns whose DOY lies on the 15-day
        grid from ``start_date`` to ``timeT`` (DOY counted from 1 January of
        ``start_date.year``), then ``latitude`` and ``longitude``. The VI
        columns are in chronological order. If no column matches, an empty
        frame with columns ``FID``, ``latitude`` and ``longitude`` is returned.
    """
    FID_df = pd.read_csv(FID_file).drop_duplicates(subset='FID')
    FID_location = FID_df.set_index('FID')[['latitude', 'longitude']]

    primary_data = pd.read_csv(target_features_file)

    # DOY grid of the requested period.
    base_date = datetime(start_date.year, 1, 1)
    start_doy = (start_date - base_date).days + 1
    timeT_doy = (timeT - base_date).days + 1
    wanted_doys = set(range(start_doy, timeT_doy + 1, 15))

    # Keep the matching columns, sorted by DOY. The DTW step downstream reads
    # the columns positionally as a time series, so their order must not
    # depend on the column order of the CSV.
    prefix = f'{VI_bandname}_'
    dated_columns = []
    for col in primary_data.columns:
        if not col.startswith(prefix):
            continue
        doy = int(col.split('_')[1])
        if doy in wanted_doys:
            dated_columns.append((doy, col))
    dated_columns.sort()
    selected_columns = [col for _, col in dated_columns]

    if not selected_columns:
        # No matching column: return an empty frame.
        return pd.DataFrame(columns=['FID', 'latitude', 'longitude'])

    filtered_data = primary_data[['FID'] + selected_columns].copy()
    filtered_data['latitude'] = filtered_data['FID'].map(FID_location['latitude'])
    filtered_data['longitude'] = filtered_data['FID'].map(FID_location['longitude'])
    return filtered_data

#*** functions related to the DTLS similarity calculation
# function to compute the DTW path and distance for one time series
def compute_dtw(reference,sample,search_radius):
    """Run Sakoe-Chiba constrained DTW between ``reference`` and ``sample``.

    Returns the two values of ``tslearn.metrics.dtw_path`` in the order that
    function yields them, i.e. ``(path, score)``: the list of
    ``(reference_index, sample_index)`` pairs followed by the DTW score.
    """
    warping_path, dtw_score = dtw_path(
        reference,
        sample,
        global_constraint='sakoe_chiba',
        sakoe_chiba_radius=search_radius,
    )
    return warping_path, dtw_score

# function to compute SAD of the target surface reflectance at given time T and refer surface reflectance at related time refer_T
def compute_sad(refer_SRs, target_SRs):
    """Compute the scaled spectral angle distance between paired spectra.

    Args:
        refer_SRs: 2-D array-like (samples x bands) of reference spectra.
        target_SRs: 2-D array-like of target spectra with the same shape.

    Returns:
        1-D array with ``10000 * 2 * arccos(cos_theta) / pi`` per row, where
        ``cos_theta`` is the cosine similarity of the two spectra. Rows in
        which either spectrum has zero magnitude get ``10000``.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    # Work in float64 so integer reflectances cannot overflow in the products.
    refer_SRs = np.asarray(refer_SRs, dtype=float)
    target_SRs = np.asarray(target_SRs, dtype=float)
    if refer_SRs.shape != target_SRs.shape:
        print('refer_SRs',refer_SRs.shape)
        print('target_SRs',target_SRs.shape)
        raise ValueError("refer_SRs 和 target_SRs 的形状必须匹配！")
    dot_product = np.einsum('ij,ij->i', refer_SRs, target_SRs)
    refer_magnitudes = np.linalg.norm(refer_SRs, axis=1)
    target_magnitudes = np.linalg.norm(target_SRs, axis=1)
    # Rows with a zero-magnitude spectrum have no defined angle.
    zero_mask = (refer_magnitudes == 0) | (target_magnitudes == 0)
    # ``out`` gives the masked rows a defined value; with ``where`` alone
    # numpy leaves those entries uninitialised.
    cos_theta = np.divide(
        dot_product,
        refer_magnitudes * target_magnitudes,
        out=np.zeros_like(dot_product),
        where=~zero_mask,
    )
    # Clamp to [-1, 1] so rounding errors cannot push arccos out of its domain.
    cos_theta = np.clip(cos_theta, -1.0, 1.0)
    # Angle in radians scaled by 2/pi.
    spectral_angle_distance = 2 * np.arccos(cos_theta) / np.pi
    # Indicator value for rows without a defined angle.
    spectral_angle_distance[zero_mask] = 1.0
    return spectral_angle_distance*10000

# DTLS method for sample similarity distance calculation for a given cluster and a given time T
def DTLS_singleCluster(refer_curve,target_curve_df,target_lastT_srCurve,start_date,timeT,timeT_index,refer_SRpara,SR_bandnames,search_radius,interval):
    """Compute the DTW and spectral-angle distances of all samples to one cluster.

    Args:
        refer_curve: DataFrame whose first row holds the cluster's reference
            VI curve in its ``Time_*`` columns.
        target_curve_df: Samples table; its ``EVI_*`` columns form the series.
        target_lastT_srCurve: Band values of every sample at the last time step.
        start_date: Date of time step 1.
        timeT: Date of the last time step (not used in the computation).
        timeT_index: 1-based index of the last time step.
        refer_SRpara: Harmonic SR parameters of the cluster.
        SR_bandnames: Band names used for the spectral angle.
        search_radius: Sakoe-Chiba radius, also the number of padded steps.
        interval: Days between consecutive time steps.

    Returns:
        ``(dtw_distance, spectral_angle_distance)``: a tuple of DTW scores and
        an array of scaled spectral angle distances, one entry per sample.
    """
    # Reference curve as an array for the DTW input.
    time_columns = [col for col in refer_curve.columns if col.startswith('Time_')]
    refer_curve_array = refer_curve.iloc[0][time_columns].values

    # Sample series, padded on both sides of the time axis by repeating the
    # edge value ``search_radius`` times.
    vi_columns = [col for col in target_curve_df.columns if col.startswith('EVI_')]
    samples_curve_series = target_curve_df.loc[:, vi_columns].values
    extended_samples_curve_series = np.pad(
        samples_curve_series,
        pad_width=((0, 0), (search_radius, search_radius)),
        mode='edge'
    )

    # Run the DTW of every sample in parallel on all available cores.
    num_cores = -1
    results = Parallel(n_jobs=num_cores)(delayed(compute_dtw)(refer_curve_array, sample,search_radius) for sample in extended_samples_curve_series)
    # compute_dtw returns (path, score) pairs.
    optimal_path,dtw_distance = zip(*results)

    # Position of the last real time step inside the padded sample series.
    relavent_target_lastT = timeT_index + search_radius - 1
    refer_positions = get_referTime(optimal_path,relavent_target_lastT)
    # Reference position j is time index j + 1 - search_radius, because the
    # reference curve built by DTLS_distance_calculate starts at time index
    # 1 - search_radius. get_refer_srCurve expects that time index.
    refer_lastT = [None if pos is None else pos + 1 - search_radius for pos in refer_positions]
    refer_lastT_srCurve = get_refer_srCurve(refer_SRpara,SR_bandnames,start_date,refer_lastT,interval)

    # Spectral angle distance between the matched reference and target spectra.
    spectral_angle_distance = compute_sad(refer_lastT_srCurve,target_lastT_srCurve)

    return dtw_distance,spectral_angle_distance

#*** main procedure of DTLS for sample similarity distance calculation for each cluster at given time time_T
def DTLS_distance_calculate(target_features_fileDir,refer_SR_file,refer_VI_file,VI_bandname,SR_bandnames,out_fileDir,start_date,timeT,timeT_index,search_radius,tilname,interval,target_year):
    """Compute per-cluster DTW and SAD distances of all samples at one time step.

    Writes ``sample_distance_{timeT}_{tilname}.csv`` to ``out_fileDir`` with
    the columns ``FID``, the VI value at ``timeT``, ``{cluster}_dtw_distance``
    and ``{cluster}_sad_distance`` per cluster, ``cluster label``,
    ``latitude`` and ``longitude``. Nothing is written when the target data
    for this time step is missing.

    Args:
        target_features_fileDir: Folder with the target year's
            ``{tilname}_{target_year}_VI_SGfeatures.csv``,
            ``{tilname}_{target_year}_SR_SGfeatures.csv`` and
            ``primary_{tilname}_{target_year}_objects.csv`` files.
        refer_SR_file: CSV of reference SR harmonic parameters.
        refer_VI_file: CSV of reference VI harmonic parameters.
        VI_bandname: Vegetation index used for the DTW step.
        SR_bandnames: Bands used for the spectral angle.
        out_fileDir: Output folder.
        start_date: Date of time step 1.
        timeT: Date of the current time step.
        timeT_index: 1-based index of the current time step.
        search_radius: Sakoe-Chiba radius.
        tilname: Tile identifier used in the file names.
        interval: Days between time steps.
        target_year: Year used in the target file names.
    """
    # Reference SR harmonic parameters.
    refer_SRpara_df = pd.read_csv(refer_SR_file)
    # Reference VI curve, extended by search_radius steps on both sides.
    referT_indexs = np.linspace(1-search_radius, timeT_index+search_radius, timeT_index+2*search_radius)
    start_doy = start_date.timetuple().tm_yday
    referT_values = start_doy + (referT_indexs-1)*interval
    refer_VI_curve_df = get_refer_VIcurve(refer_VI_file,referT_values,VI_bandname)

    # Target VI series of every sample. The FID file is located from the
    # function arguments rather than a fixed tile/year path.
    target_VI_file = os.path.join(target_features_fileDir, f'{tilname}_{target_year}_VI_SGfeatures.csv')
    FID_file = os.path.join(target_features_fileDir, f'primary_{tilname}_{target_year}_objects.csv')
    target_VI_df = get_target_VIs_SG(target_VI_file,FID_file,VI_bandname,start_date,timeT)

    if target_VI_df.empty:
        print(timeT," is empty.")
        return

    target_column = f'{VI_bandname}_{int(start_doy + (timeT_index-1)*interval)}'
    if target_column not in target_VI_df.columns:
        print(f"Column '{target_column}' does not exist. Exiting function.")
        return

    # Target surface reflectance at the current time step.
    target_sr_file = os.path.join(target_features_fileDir, f'{tilname}_{target_year}_SR_SGfeatures.csv')
    target_lastT_srCurve = get_target_srCurve_SG(target_sr_file,start_date,timeT,SR_bandnames,target_VI_df)
    if len(target_lastT_srCurve) == 0:
        # No SR data for this date; skip instead of failing on a shape mismatch.
        print(timeT," has no surface reflectance data.")
        return

    sample_cluster_resultDF = pd.DataFrame()
    sample_cluster_resultDF['FID'] = target_VI_df['FID']
    # Stored under VI_bandname, the column DTLS_distance_combinate reads.
    sample_cluster_resultDF[VI_bandname] = target_VI_df[target_column]

    clusters = refer_VI_curve_df['Cluster'].unique()
    for cluster in clusters:
        print('  cluster ',cluster,' labeling calculating ....')
        # Reference curve and SR parameters of this cluster.
        refer_VI_curve_cluster = refer_VI_curve_df[refer_VI_curve_df['Cluster'] == cluster]
        refer_SRpara_cluster = refer_SRpara_df[refer_SRpara_df['Cluster'] == cluster]

        dtw_distance_column,sad_distance_column = DTLS_singleCluster(refer_VI_curve_cluster,target_VI_df,target_lastT_srCurve,start_date,timeT,timeT_index,refer_SRpara_cluster,SR_bandnames,search_radius,interval)

        # One DTW and one SAD column per cluster.
        dtw_distance_columnName = f'{int(cluster)}_dtw_distance'
        sad_distance_columnName = f'{int(cluster)}_sad_distance'
        sample_cluster_resultDF[dtw_distance_columnName] = dtw_distance_column
        sample_cluster_resultDF[sad_distance_columnName] = sad_distance_column

    # Provisional label: the distance column with the smallest value.
    sample_cluster_resultDF['cluster label'] = sample_cluster_resultDF[[distance_columnName for distance_columnName in sample_cluster_resultDF.columns if 'distance' in distance_columnName]].idxmin(axis=1)
    sample_cluster_resultDF['cluster label'] = sample_cluster_resultDF['cluster label'].str.replace('_distance', '', regex=False)
    sample_cluster_resultDF['latitude'] = target_VI_df['latitude']
    sample_cluster_resultDF['longitude'] = target_VI_df['longitude']
    # Write the distances of every sample to every cluster.
    sample_label_file = os.path.join(out_fileDir, f'sample_distance_{timeT}_{tilname}.csv')
    sample_cluster_resultDF.to_csv(sample_label_file, index=False)

def DTLS_distance_combinate(Results_dir,VI_bandname,tilname,timeT_indexs):
    """Combine the DTW and SAD distances of each time step into final labels.

    For every ``time_T`` in ``timeT_indexs`` the file
    ``sample_distance_{time_T}_{tilname}.csv`` in ``Results_dir`` is read
    (missing files are skipped) and ``sample_label_{time_T}_{tilname}.csv`` is
    written. Per cluster, the combined distance is the DTW distance of the
    current step plus a SAD distance (divided by 10): the current step's SAD
    where the current VI value exceeds the running maximum VI, otherwise the
    SAD stored with the running maximum. The label of a sample is the cluster
    with the smallest combined distance.

    The running maximum is seeded from the first ``sample_distance_*`` file
    that ``os.listdir`` returns.
    """
    sample_label_file0_basename = next(f for f in os.listdir(Results_dir) if f.startswith('sample_distance_'))
    sample_label_file0 = os.path.join(Results_dir, sample_label_file0_basename)
    sample_label_df0 = pd.read_csv(sample_label_file0)

    max_EVI = sample_label_df0.set_index('FID')[VI_bandname]  # indexed by FID
    max_sad_distances = sample_label_df0.set_index('FID')[[col for col in sample_label_df0.columns if '_sad_distance' in col]].copy()
    max_sad_distances = max_sad_distances/10
    cluster_names = set(col.split('_')[0] for col in sample_label_df0.columns if '_dtw_distance' in col)

    for time_T in timeT_indexs:
        #time_T = int(time_T)
        sample_label_file = os.path.join(Results_dir, f'sample_distance_{time_T}_{tilname}.csv')
        if not os.path.exists(sample_label_file):
            print(f"File {sample_label_file} does not exist. Skipping...")
            continue

        final_result_file = os.path.join(Results_dir, f'sample_label_{time_T}_{tilname}.csv')
        distance_df = pd.read_csv(sample_label_file).set_index('FID')
        final_distance_df = pd.DataFrame(index=distance_df.index)

        EVI_cur = distance_df[VI_bandname]
        dtw_distance_curs = distance_df[[col for col in distance_df.columns if '_dtw_distance' in col]]
        sad_distance_curs = distance_df[[col for col in distance_df.columns if '_sad_distance' in col]]
        sad_distance_curs = sad_distance_curs/10

        # Extend the running state to FIDs first seen at this step; their
        # running maximum starts at -inf.
        max_EVI = max_EVI.reindex(max_EVI.index.union(distance_df.index), fill_value=-np.inf)
        max_sad_distances = max_sad_distances.reindex(max_EVI.index)

        merged_EVI = pd.concat([EVI_cur, max_EVI], axis=1, keys=['EVI_cur', 'EVI_max'])

        for cluster in cluster_names:
            sad_distance_cur = sad_distance_curs[f'{cluster}_sad_distance']
            sad_distance_max = max_sad_distances[f'{cluster}_sad_distance']

            # Use the current step's SAD where the current VI exceeds the
            # running maximum (or the maximum is NaN); otherwise keep the SAD
            # stored with the running maximum.
            updated_sad_distance = sad_distance_cur.where(
                (merged_EVI['EVI_cur'] > merged_EVI['EVI_max']) | merged_EVI['EVI_max'].isna(),
                sad_distance_max
            )

            # Update the running maximum VI and the SAD stored with it.
            max_EVI = merged_EVI['EVI_cur'].where(merged_EVI['EVI_cur'] > merged_EVI['EVI_max'], merged_EVI['EVI_max'])
            max_sad_distances[f'{cluster}_sad_distance'] = updated_sad_distance

            # Combined (DTLS) distance of this cluster.
            dtw_distance = dtw_distance_curs[f'{cluster}_dtw_distance']
            dtls_distance = dtw_distance + updated_sad_distance
            final_distance_df[f'{cluster}_distance'] = dtls_distance

        # The cluster with the smallest combined distance becomes the label.
        final_distance_df['cluster label'] = final_distance_df[[col for col in final_distance_df.columns if 'distance' in col]].idxmin(axis=1)
        final_distance_df['cluster label'] = final_distance_df['cluster label'].str.replace('_distance', '', regex=False)
        final_distance_df['latitude'] = distance_df['latitude']
        final_distance_df['longitude'] = distance_df['longitude']

        # Move FID back into a column and save.
        final_distance_df.reset_index(inplace=True)  # so the CSV contains the FID column
        final_distance_df.to_csv(final_result_file, index=False)
        print(f'Time step {time_T} calculation done.')

# main procedure for generating training samples in the target year
def generate_training_samples(target_features_fileDir, refer_SR_file, refer_VI_file, target_sampleResults_dir, tilename, refer_year, target_year):
    """Label the target-year samples at every 15-day step of the season.

    The season runs from 1 October of ``target_year - 1`` to 1 July of
    ``target_year``. For each step the per-cluster distances are computed with
    ``DTLS_distance_calculate``; afterwards ``DTLS_distance_combinate`` merges
    them into the final labels. ``refer_year`` is accepted but not used here.
    """
    VI_bandname = 'EVI'
    SR_bandnames = ['B', 'G', 'R', 'RE1', 'RE2', 'RE3', 'NIR', 'RE4', 'SWIR1', 'SWIR2']
    search_radius, interval = 2, 15
    start_date = datetime(target_year - 1, 10, 1)
    end_date = datetime(target_year, 7, 1)

    timeT_indexs = pd.date_range(start=start_date, end=end_date, freq='15D')
    timeT_index = 0
    for timeT in timeT_indexs:
        timeT_index += 1  # 1-based position of the time step
        print('time ',timeT,'th labeling calculating....')
        DTLS_distance_calculate(
            target_features_fileDir=target_features_fileDir,
            refer_SR_file=refer_SR_file,
            refer_VI_file=refer_VI_file,
            VI_bandname=VI_bandname,
            SR_bandnames=SR_bandnames,
            out_fileDir=target_sampleResults_dir,
            start_date=start_date,
            timeT=timeT,
            timeT_index=timeT_index,
            search_radius=search_radius,
            tilname=tilename,
            interval=interval,
            target_year=target_year,
        )
        print('time ',timeT,'th labeling finished....')
    DTLS_distance_combinate(target_sampleResults_dir,VI_bandname,tilename,timeT_indexs)


In [None]:
# Generate training samples for target year 2021 of tile '30TYR' using the
# reference parameters fitted for 2020.
refer_year = 2020
target_year = 2021
tilename = '30TYR'
# Folder holding the target-year feature files.
target_features_fileDir = f'/content/drive/MyDrive/test_{tilename}/'
# Reference harmonic parameter files of the reference year.
refer_SR_file = f'/content/drive/MyDrive/test_{tilename}/{tilename}_{refer_year}_SR_Reference_parameters.csv'
refer_VI_file = f'/content/drive/MyDrive/test_{tilename}/{tilename}_{refer_year}_VI_Reference_parameters.csv'
# Output folder for the per-time-step distance and label files.
target_sampleResults_dir = f'/content/drive/MyDrive/test_{tilename}/target_samples/'
os.makedirs(target_sampleResults_dir, exist_ok=True)
generate_training_samples(target_features_fileDir, refer_SR_file, refer_VI_file, target_sampleResults_dir, tilename, refer_year, target_year)

## Organize the label data into a single CSV file

In [None]:
import os
import re
import pandas as pd

def extract_timeT_from_filename(filename, pattern_prefix, tilename):
    """Return the 'YYYY-MM-DD' date embedded in a result filename, or None.

    The expected layout is
    ``<pattern_prefix><YYYY-MM-DD> <HH:MM:SS>_<tilename>.csv``, e.g.
    ``sample_distance_2020-10-01 00:00:00_30TYR.csv`` yields ``'2020-10-01'``.
    The prefix and tile name are matched literally (regex-escaped); the
    time-of-day part is required but discarded.
    """
    pattern = rf"{re.escape(pattern_prefix)}(\d{{4}}-\d{{2}}-\d{{2}}) \d{{2}}:\d{{2}}:\d{{2}}_{re.escape(tilename)}\.csv"
    found = re.search(pattern, filename)
    if found is None:
        return None
    return found.group(1)

def _collect_timewise_labels(results_dir, filenames, prefix, tilename, labels_are_text):
    """Read per-time-step CSVs and return ({'T<date>': label Series}, lat/lon frame or None).

    Files whose name does not match the expected pattern, or that lack a
    'cluster label' column, are skipped. Each Series is indexed by FID.
    """
    columns = {}
    latlon = None
    for name in filenames:
        timeT = extract_timeT_from_filename(name, prefix, tilename)
        if timeT is None:
            continue
        df = pd.read_csv(os.path.join(results_dir, name)).set_index('FID')
        if 'cluster label' not in df.columns:
            continue
        raw = df['cluster label']
        if labels_are_text:
            # Keep only the first run of digits found in each label value.
            raw = raw.astype(str).str.extract(r'(\d+)')[0]
        label = raw.astype(int)
        label.name = f'T{timeT}'
        columns[label.name] = label
        # Coordinates are taken from the first file that provides both columns.
        if latlon is None and 'latitude' in df.columns and 'longitude' in df.columns:
            latlon = df[['latitude', 'longitude']]
    return columns, latlon


def _assemble_wide_table(columns, latlon):
    """Join the per-date label Series into one FID-keyed table with date-sorted columns."""
    wide = pd.concat(list(columns.values()), axis=1)
    if latlon is not None:
        wide = wide.join(latlon, how='left')
    wide = wide.reset_index()
    # ISO dates sort chronologically as plain strings.
    time_cols = sorted((col for col in wide.columns if col.startswith('T')),
                       key=lambda col: col[1:])
    # Only select coordinate columns that exist: when no input file carried
    # latitude/longitude, selecting them unconditionally raises KeyError.
    coord_cols = [col for col in ('latitude', 'longitude') if col in wide.columns]
    if not coord_cols:
        print('[warn] No latitude/longitude columns found; writing labels without coordinates.')
    return wide[['FID'] + coord_cols + time_cols]


def build_all_timewise_labels(results_dir, tilename, out_prefix='timewise'):
    """Merge per-time-step label CSVs of one tile into wide, date-per-column CSVs.

    Two families of files in ``results_dir`` are processed independently:

    * ``sample_distance_*<tilename>.csv``: 'cluster label' values are parsed
      by extracting their first run of digits; written to
      ``<out_prefix>_labels_from_distance_<tilename>.csv``.
    * ``sample_label_version2*<tilename>.csv``: 'cluster label' values are
      cast to int directly; written to
      ``<out_prefix>_labels_from_labelFiles_<tilename>.csv``.

    Each output has the columns FID, latitude, longitude (when available in
    the inputs) and one ``T<YYYY-MM-DD>`` column per time step, sorted by date.
    A warning is printed for a family that yields no usable file.

    Args:
        results_dir: Directory that holds the inputs and receives the outputs.
        tilename: Tile identifier used in the input and output filenames.
        out_prefix: Prefix of the output filenames.
    """
    suffix = f'{tilename}.csv'
    # List the directory once, before anything is written into it.
    candidates = [f for f in os.listdir(results_dir) if f.endswith(suffix)]

    # (filename prefix, output tag, labels need digit extraction, warning text)
    families = (
        ('sample_distance_', 'distance', True,
         '[warn] No distance-based label files found.'),
        ('sample_label_version2', 'labelFiles', False,
         '[warn] No version2 label files found.'),
    )

    # Read everything first so a bad input fails before any output is written.
    collected = []
    for prefix, tag, labels_are_text, warning in families:
        matching = [f for f in candidates if f.startswith(prefix)]
        columns, latlon = _collect_timewise_labels(
            results_dir, matching, prefix, tilename, labels_are_text)
        collected.append((tag, warning, columns, latlon))

    for tag, warning, columns, latlon in collected:
        if not columns:
            print(warning)
            continue
        out_path = os.path.join(results_dir, f'{out_prefix}_labels_from_{tag}_{tilename}.csv')
        _assemble_wide_table(columns, latlon).to_csv(out_path, index=False)
        print(f'[saved] {out_path}')


In [None]:
# Merge the per-time-step results of one tile into the wide label tables.
tilname = '30TYR'
Results_dir = '/content/drive/MyDrive/test_30TYR/target_samples/'
build_all_timewise_labels(Results_dir, tilname)

# 4. multiple random forest classification (GEE)

## upload the samples to asset for classification
Take the time step T in mid-April (2021-04-14) as an example.


In [None]:
! pip install earthengine-api
import ee
ee.Authenticate()
ee.Initialize(project='crops-mapping-gaoyuan')
from ee.batch import Export
!pip install geemap
from google.colab import drive
drive.mount('/content/drive')
import time
import numpy as np

In [None]:
def _sample_negatives(neg_pool_all, label_col, fid_col, n_neg, prefer_neg_labels, random_state):
    """Draw up to ``n_neg`` rows from the negative pool, spread over the preferred labels.

    The quota is split as evenly as possible over the preferred labels that
    actually occur in the pool; classes that are too small contribute all of
    their rows and the shortfall is filled from the remaining negatives.
    """
    present_values = neg_pool_all[label_col].unique()  # computed once, not per label
    # De-duplicate while keeping order so a repeated label cannot skew the quotas.
    present_pref_labels = [lab for lab in dict.fromkeys(prefer_neg_labels) if lab in present_values]

    picks = []
    if present_pref_labels:
        base, rem = divmod(n_neg, len(present_pref_labels))
        leftover = 0
        for pos, lab in enumerate(present_pref_labels):
            want = base + (1 if pos < rem else 0)  # first ``rem`` classes get one extra
            cls_df = neg_pool_all[neg_pool_all[label_col] == lab]
            got = min(len(cls_df), want)
            if got > 0:
                picks.append(cls_df.sample(n=got, random_state=random_state))
            leftover += want - got

        # Preferred classes fell short: top up from the non-preferred negatives.
        if leftover > 0:
            others = neg_pool_all[~neg_pool_all[label_col].isin(present_pref_labels)]
            if len(others) > 0:
                picks.append(others.sample(n=min(leftover, len(others)), random_state=random_state))
    else:
        # None of the preferred labels occur: sample straight from the whole pool.
        need = min(n_neg, len(neg_pool_all))
        if need > 0:
            picks.append(neg_pool_all.sample(n=need, random_state=random_state))

    if picks:
        neg_sample = pd.concat(picks, axis=0).drop_duplicates(subset=[fid_col])
    else:
        neg_sample = neg_pool_all.iloc[0:0]

    # Still short of n_neg: fill from whatever negatives have not been taken yet.
    if len(neg_sample) < n_neg:
        remain_pool = neg_pool_all.drop(neg_sample.index, errors='ignore')
        need = min(n_neg - len(neg_sample), len(remain_pool))
        if need > 0:
            neg_sample = pd.concat(
                [neg_sample, remain_pool.sample(n=need, random_state=random_state)], axis=0)

    # Never return more than n_neg rows.
    if len(neg_sample) > n_neg:
        neg_sample = neg_sample.sample(n=n_neg, random_state=random_state)
    return neg_sample


def export_points_from_csv_balanced(
    csv_path,
    date_str,                                  # e.g. '2020-04-14'
    out_asset_id,                              # e.g. 'users/your_name/samples_30TYR_20200414'
    n_pos=1000, n_neg=1000,
    pos_label=1,
    prefer_neg_labels=(2,3,4,5,6),            # non-positive labels to cover first
    lon_col='longitude', lat_col='latitude', fid_col='FID',
    random_state=42
):
    """Sample a positive/negative balanced point set from a wide label CSV and export it.

    The label of each point is read from column ``T<date_str>``. Up to
    ``n_pos`` rows with ``pos_label`` and up to ``n_neg`` rows with any other
    label are drawn (negatives are spread over ``prefer_neg_labels`` first),
    converted to an ``ee.FeatureCollection`` of points and exported to the
    Earth Engine asset ``out_asset_id``. The export task is started, not
    awaited.

    Args:
        csv_path: CSV with the FID, longitude, latitude and ``T<date_str>`` columns.
        date_str: Date whose label column is used; also stored on each feature.
        out_asset_id: Destination Earth Engine asset id.
        n_pos: Maximum number of positive samples.
        n_neg: Maximum number of negative samples.
        pos_label: Label value treated as the positive class.
        prefer_neg_labels: Negative labels that share the ``n_neg`` quota first.
        lon_col, lat_col, fid_col: Column names in the CSV.
        random_state: Seed passed to every pandas ``sample`` call.

    Returns:
        The ``ee.FeatureCollection`` handed to the export task; each feature
        carries the properties FID, label, date, longitude and latitude.

    Raises:
        ValueError: If a required column is missing, or if no row with valid
            coordinates and label remains (nothing would be exported).
    """
    # 1) Read the CSV and validate the required columns.
    df = pd.read_csv(csv_path)
    label_col = f'T{date_str}'
    need_cols = [fid_col, lon_col, lat_col, label_col]
    for c in need_cols:
        if c not in df.columns:
            raise ValueError(f'CSV 缺少必要列: {c}')

    # Basic cleaning: coordinates and label must be present and numeric.
    sub = df[need_cols].dropna(subset=[lon_col, lat_col, label_col]).copy()
    for col in (lon_col, lat_col, label_col):
        sub[col] = pd.to_numeric(sub[col], errors='coerce')
    sub = sub.dropna(subset=[lon_col, lat_col, label_col])

    # 2) Stratified sampling.
    pos_pool = sub[sub[label_col] == pos_label]
    if len(pos_pool) <= n_pos:
        pos_sample = pos_pool  # too few positives: take them all
    else:
        pos_sample = pos_pool.sample(n=n_pos, random_state=random_state)

    neg_pool_all = sub[sub[label_col] != pos_label].copy()
    neg_sample = _sample_negatives(
        neg_pool_all, label_col, fid_col, n_neg, prefer_neg_labels, random_state)

    final_df = pd.concat([pos_sample, neg_sample], axis=0)

    print(f'正类(= {pos_label}) 可用 {len(pos_pool)}，抽取 {len(pos_sample)}')
    print(f'非{pos_label} 可用 {len(neg_pool_all)}，抽取 {len(neg_sample)}')
    print('最终样本量：', len(final_df))

    # Fail fast instead of starting an export task for an empty collection.
    if final_df.empty:
        raise ValueError(
            f'No usable samples for column {label_col!r} in {csv_path!r}; nothing to export.')

    # Shuffle so the classes are not stored as two contiguous runs.
    final_df = final_df.sample(frac=1.0, random_state=random_state)

    # 3) Convert to an ee.FeatureCollection of point features.
    feats = []
    for r in final_df.to_dict('records'):
        lon = float(r[lon_col])
        lat = float(r[lat_col])
        fid = r[fid_col]
        try:
            fid = int(fid)
        except (TypeError, ValueError, OverflowError):
            pass  # keep a non-integer FID as it is
        props = {
            'FID': fid,
            'label': int(r[label_col]),
            'date': date_str,
            'longitude': lon,
            'latitude': lat
        }
        feats.append(ee.Feature(ee.Geometry.Point([lon, lat]), props))
    fc = ee.FeatureCollection(feats)

    # 4) Export to an Earth Engine asset.
    desc = f'export_samples_balanced_{date_str.replace("-","")}'
    task = ee.batch.Export.table.toAsset(
        collection=fc,
        description=desc,
        assetId=out_asset_id
    )
    task.start()
    print(f'Started task: {desc}')
    print(f'Asset: {out_asset_id}')
    return fc

# ===== Usage example =====
# Mount your Drive in Colab first:
# from google.colab import drive
# drive.mount('/content/drive')

# Wide "one column per date" CSV produced earlier (label or distance version both work).
csv_path = '/content/drive/MyDrive/test_30TYR/target_samples/timewise_labels_from_labelFiles_30TYR.csv'
date_str = '2021-04-14'
out_asset_id = 'projects/crops-mapping-gaoyuan/assets/samples_30TYR_20210414'  # change to your own asset path

_ = export_points_from_csv_balanced(
    csv_path=csv_path,
    date_str=date_str,
    out_asset_id=out_asset_id,
    n_pos=1000,
    n_neg=1000,
    pos_label=1,
    prefer_neg_labels=(2, 3, 4, 5, 6),
    random_state=2025,
)

## multiple Random Forest classification in GEE

In [None]:
# ---------- 1. Add vegetation indices ----------
def add_VIs(img):
    """Return ``img`` with three extra bands: EVI, LSWI and REP.

    * EVI  = 2.5*(NIR - R)/(NIR + 6*R - 7.5*B + 1)
    * LSWI = normalized difference of NIR and SWIR1
    * REP  = (RE3 + RE1)/2 - RE2

    The image must already carry the renamed bands B, R, NIR, SWIR1, RE1,
    RE2 and RE3.
    """
    # NOTE(review): the "+ 1" term presumably expects reflectance scaled to 0-1 - confirm upstream scaling.
    evi_inputs = {
        'NIR': img.select('NIR'),
        'R': img.select('R'),
        'B': img.select('B'),
    }
    evi_band = img.expression('2.5*(NIR - R)/(NIR + 6*R - 7.5*B + 1)', evi_inputs).rename('EVI')

    lswi_band = img.normalizedDifference(['NIR', 'SWIR1']).rename('LSWI')

    red_edge = img.select(['RE1', 'RE2', 'RE3'])
    rep_band = red_edge.expression('((b("RE3") + b("RE1")) / 2) - b("RE2")').rename('REP')

    return img.addBands([evi_band, lswi_band, rep_band])

# ---------- 2. Binary relabelling of the samples ----------
def relabel(f):
    """Add a binary 'code' property: 1 where 'label' equals 1 (winter wheat), else 0."""
    is_wheat = ee.Number(f.get('label')).eq(1)
    return f.set('code', is_wheat)

# ---------- 3. Parameters ----------
tilename = '30TYR'
year = 2021
interval = 15

# Region of interest: geometry of the tile-index features whose 'Name' equals the tile.
s2_borderIndex = ee.FeatureCollection("projects/crops-mapping-gaoyuan/assets/sentinel_2_index_shapefile")
roi = s2_borderIndex.filter(ee.Filter.eq('Name', tilename)).geometry()

# Uploaded samples, with the binary 'code' property added.
samples = ee.FeatureCollection("projects/crops-mapping-gaoyuan/assets/samples_30TYR_20210414").map(relabel)

# Image period: 1 October of the previous year to 1 October of ``year``.
start_date = ee.Date.fromYMD(year - 1, 10, 1)
end_date = ee.Date.fromYMD(year, 10, 1)

# Image collection for the period, with the vegetation-index bands appended.
s2SR = get_s2sr_images(start_date, end_date, interval, roi).map(add_VIs)

# Feature bands used by the classifier.
vi_bands = ['EVI', 'REP', 'LSWI']
sr_bands = ['B','G','R','RE1','RE2','RE3','NIR','SWIR1','SWIR2']
all_bands = sr_bands + vi_bands

# ---------- 4. Per-image classifier ----------
def classify_image(img):
    """Train a random forest on one image and return its wheat probability (0-100).

    The bands listed in ``all_bands`` are sampled at the ``samples`` points;
    a 100-tree random forest is trained on the 'code' property and applied
    to the same image in PROBABILITY mode. The result is scaled to 0-100,
    cast to uint8 and masked with the image's 'R' band mask.

    If no training point could be sampled, a fully masked image is returned
    so that it contributes nothing to the later sum/count over the collection.

    Returns:
        ee.Image with a single uint8 band named 'classification'; the 'R'
        band mask of the input is also stored under the 'mask' property.
    """
    features = img.select(all_bands)
    valid_mask = img.select('R').mask()

    # Sample the feature bands at the training points.
    training = features.sampleRegions(
        collection=samples,
        properties=['code'],
        scale=10,
        tileScale=4
    )

    classifier = (
        ee.Classifier.smileRandomForest(100)
          .train(training, 'code', all_bands)
          .setOutputMode('PROBABILITY')
    )
    probability = (
        features.classify(classifier)
          .multiply(100).toUint8()
          .updateMask(valid_mask)  # keep the original mask
    )

    # Fully masked placeholder. It must carry the same band name and type as
    # the classified images ('classification', uint8): collection reducers
    # match bands by name, so a band called 'constant' would make the
    # collection inhomogeneous.
    empty = (
        ee.Image.constant(0)
          .rename('classification')
          .toUint8()
          .updateMask(ee.Image.constant(0))
    )

    # Classify when samples exist, otherwise fall back to the masked placeholder.
    classified = ee.Algorithms.If(training.size().gt(0), probability, empty)

    return ee.Image(classified).set('mask', valid_mask)

# ---------- 5. Classify every time step ----------
classified_list = s2SR.map(classify_image)

# ---------- 6. Fuse all time steps ----------
classified_ic = ee.ImageCollection(classified_list)
image_count = classified_ic.count()   # per-pixel number of unmasked classifications
final_result = classified_ic.sum()    # per-pixel sum of the 0-100 probabilities

# Mean probability per pixel, and the mask of pixels at or above 50.
final_prob = final_result.divide(image_count)
final_mask = final_prob.gte(50).selfMask()

# ---------- 7. Export ----------
export_params = {
    'image': final_prob.clip(roi),
    'description': 'export_wheat_prop_2021',
    'folder': 'test_30TYR',
    'fileNamePrefix': 'final_wheat_30TYR_prop_2021',
    'region': roi.bounds(),
    'scale': 10,
    'maxPixels': 1e13,
    'fileFormat': 'GeoTIFF',
}
task = ee.batch.Export.image.toDrive(**export_params)
task.start()
print('Export task started to Google Drive.')
