In [1]:
import os
import re
import numpy as np
import pandas as pd
from sklearn.metrics import r2_score as R2
from sklearn.metrics import root_mean_squared_error as RMSE
from sklearn.metrics import mean_absolute_error as MAE
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM

In [None]:
### 1. 删除不符合要求的数据点
categories = ["train", "val", "test"]
for category in categories:
    file_path = "/lustre1/g/geog_geors/skguan/dataset_23(season)_num12.csv"
    df = pd.read_csv(file_path)
    df = pd.DataFrame(df[f'{category}'].dropna(), columns=[f'{category}'])
    pattern = "(?P<site>.+) (?P<year>\\d+)-(?P<month>\\d+)-(?P<day>\\d+)"
    with open("logs/training_20250625_vit(Season_2Sat)_2164022.out", 'r') as f:
        lines = f.readlines()
        for line in lines:
            if "00:00:00" in line:
                sites = np.array(list(map(lambda x: x.split("/")[-2], df[f'{category}'])))
                dates = np.array(list(map(lambda x: re.match(".*_\\d{6}_(\\d{8})*", x.split("/")[-1]).group(1),
                                        df[f'{category}'])))
                regex = re.match(pattern, line)
                site = regex.group("site")
                date = regex.group("year") + regex.group("month") + regex.group("day")
                idx = np.argwhere((sites == site) & (dates == date))
                if len(idx) == 0:
                    continue
                index = int(idx.squeeze())
                index_names = df[df[f'{category}'] == df[f'{category}'].iloc[index]].index
                df.drop(index_names, inplace=True)
    df.to_csv(os.path.splitext(file_name)[0] + f"_{category}.csv")

In [None]:
### 2. 删除重复的数据点
file_path = "/lustre1/g/geog_geors/skguan/dataset_23(season)_num6.csv"
df = pd.read_csv(file_path)
df_stack = df[df.columns[1:]].stack().reset_index(drop=True)
df_sorted = df_stack.sort_values(key=lambda x: x.map(
    lambda y: re.match(".+_(\\d{6})_.+", os.path.basename(y)).group(1)
))
file_num = len(df_sorted)
tiles = set(map(
    lambda x: re.match("\\w+_(\\d{6})_", os.path.split(x)[-1]).group(1),
    df_sorted.values
))
dataset = {
    "train": [],
    "val": [],
    "test": []
}
train_num = int(file_num * .8)
val_num = train_num + int(file_num * .1)
num = 0
for tile in tiles:
    temp_files = filter(
        lambda x: re.match(f"LC09.*_{tile}_2023.*.tif", os.path.basename(x)),
        df_sorted.values
    )
    temp_files = list(temp_files)
    num += len(temp_files)
    train_ratio = (num - train_num) / file_num
    val_ratio = (num - val_num) / file_num
    if train_ratio < 0:
        dataset["train"].extend(temp_files)
    elif val_ratio < 0:
        dataset["val"].extend(temp_files)
    else:
        dataset["test"].extend(temp_files)
df = pd.DataFrame(
    dict([(k, pd.Series(v)) for k, v in dataset.items()])
)
df.to_csv(file_path)

In [2]:
### 3. 计算各波段的统计指标
def SAM(reference_spectrum, target_spectrum):
    """
    Calculate the spectral angle between two spectra.

    Parameters:
    - reference_spectrum: numpy array, the reference spectrum (e.g., endmember).
    - target_spectrum: numpy array, the target spectrum to compare.

    Returns:
    - angle: float, spectral angle in radians.
    """
    # Ensure inputs are numpy arrays
    reference_spectrum = np.array(reference_spectrum)
    target_spectrum = np.array(target_spectrum)

    # Compute dot product and norms
    dot_product = np.dot(reference_spectrum, target_spectrum)
    norm_ref = np.linalg.norm(reference_spectrum)
    norm_target = np.linalg.norm(target_spectrum)

    # Calculate the spectral angle
    cos_theta = dot_product / (norm_ref * norm_target)
    cos_theta = np.clip(cos_theta, -1, 1)  # Avoid numerical issues
    angle = np.arccos(cos_theta)

    return angle

def ERGAS(reference, fused, resolution_ratio):
    """
    Calculate the ERGAS metric.

    Parameters:
        reference (numpy.ndarray): Reference image (e.g., high-resolution ground truth).
        fused (numpy.ndarray): Fused or processed image.
        resolution_ratio (float): Ratio of the spatial resolutions (e.g., high-res/low-res).

    Returns:
        float: ERGAS value.
    """
    if reference.shape != fused.shape:
        raise ValueError("Reference and fused images must have the same dimensions.")

    # Flatten the images along the spectral bands
    bands = reference.shape[1]
    ergas_sum = 0

    for band in range(bands):
        ref_band = reference[:, band, :, :]
        fused_band = fused[:, band, :, :]
        # Mean of the reference band
        mean_ref = np.mean(ref_band)
        # Root Mean Square Error (RMSE) for the band
        rmse = np.sqrt(np.mean((ref_band - fused_band) ** 2))
        # Accumulate the ERGAS numerator
        ergas_sum += (rmse / mean_ref) ** 2

    # Final ERGAS calculation
    ergas = resolution_ratio * np.sqrt(ergas_sum / bands)
    return ergas
file_dir = "/lustre1/g/geog_geors/skguan/output/data_fusion/val_files"
file_path = os.path.join(file_dir, "val_result_fold0_swin(4-band).npz")
label_max = np.array([65454., 65454., 65454., 65455.,])
label_min = np.array([0., 0., 0., 0.])

data = np.load(file_path)
label_data = data['label']
pred_data = data['pred']
label_data = np.power(label_data, 3)
pred_data = np.power(pred_data, 3)
label_data = label_data * (label_max[:, None, None] - label_min[:, None, None]) + label_min[:, None, None]
pred_data = pred_data * (label_max[:, None, None] - label_min[:, None, None]) + label_min[:, None, None]
label_data = label_data * 2.75e-5 - 0.2
pred_data = pred_data * 2.75e-5 - 0.2

labels = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2"]

sam = SAM(label_data.reshape(-1,), pred_data.reshape(-1,))
ergas = ERGAS(label_data, pred_data, 1.0)
print(f"SAM: {sam:.4f}, ERGAS: {ergas:.4f}")

for i, label_name in enumerate(labels):
    label = label_data[:, i, :, :]
    pred = pred_data[:, i, :, :]
    psnr = PSNR(label, pred, data_range=1)
    ssim = SSIM(label, pred, data_range=1)
    label_list = label.reshape(-1,)
    pred_list = pred.reshape(-1,)
    r2 = R2(label_list, pred_list)
    rmse = RMSE(label_list, pred_list)
    mae = MAE(label_list, pred_list)
    print(f"{label_name}, R2: {r2:.3f}, RMSE: {rmse:.4f}, MAE: {mae:.3f}, "
          f"PSNR: {psnr:.3f}, SSIM: {ssim: .3f}")

SAM: 0.4098, ERGAS: 0.7458
Blue, R2: 0.628, RMSE: 0.0994, MAE: 0.032, PSNR: 20.054, SSIM:  0.789
Green, R2: 0.644, RMSE: 0.0915, MAE: 0.033, PSNR: 20.773, SSIM:  0.818
Red, R2: 0.664, RMSE: 0.0946, MAE: 0.036, PSNR: 20.480, SSIM:  0.819
NIR, R2: 0.655, RMSE: 0.0853, MAE: 0.045, PSNR: 21.379, SSIM:  0.822


IndexError: index 4 is out of bounds for axis 1 with size 4