In [1]:
import os
import re
import numpy as np
import pandas as pd
import pickle as pkl
from osgeo import gdal
from FunctionalCode.StatsPlotFuncs import StatsPlot

In [None]:
### 1. Histogram of one band of the input images
# NOTE(review): this section was titled "training data", but the split that is
# actually plotted is whatever `files` is set to below (currently the test split).
input_dir = "/fossfs/skguan/data_fusion/landsat"
dataset_file = "/fossfs/skguan/data_fusion/dataset.csv"
df = pd.read_csv(dataset_file)
train_files = df['train'].dropna().astype(str).tolist()
train_files += df['val'].dropna().astype(str).tolist()
test_files = df['test'].dropna().astype(str).tolist()
files = test_files  # TODO: choose the split to plot (train_files or test_files)

SITE_SLICE = slice(10, 20)  # subset of sites to include in the histogram
BAND_INDEX = 1              # index along axis 0 of each site's .npy array
# The site name is the parent directory of each listed path.
site_names = [x.split("/")[-2] for x in files][SITE_SLICE]

# Collect only the requested band of every site and concatenate once at the end:
# growing an array with np.concatenate inside the loop copies all previously
# loaded data on every iteration, and keeping every band wastes memory.
band_chunks = []
for site_name in site_names:
    input_file = os.path.join(input_dir, site_name, site_name + ".npy")
    imgs = np.load(input_file)  # assumed (bands, n_images, 256, 256) — TODO confirm
    # float64 so the plotted values do not depend on the on-disk dtype.
    band_chunks.append(np.asarray(imgs[BAND_INDEX], dtype=np.float64).ravel())

data = np.concatenate(band_chunks) if band_chunks else np.empty(0)
data = data[data > 0]  # drop non-positive values (presumably nodata/fill) before plotting
sp = StatsPlot()
sp.hist_plot(data)

In [None]:
### 2. Plot the accuracy curves recorded in a training log
LOG_FILE = "./logs/training_20250620_vit(LMF)_2148206.out"
SPLITS = ("train", "valid")
key = ["r2", "rmse", "psnr", "ssim"]

# One list per "<split>_<metric>" series: train_r2, train_rmse, ..., valid_ssim.
metrics = {"%s_%s" % (split, name): [] for split in SPLITS for name in key}

# A signed number with optional fraction and exponent ("5", "-0.12", "3e-05"),
# or nan/inf. A pattern like `\d+\.?\d+` needs two digits and would reject
# single-digit values and scientific notation; any metric (incl. SSIM) may be negative.
NUMBER = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|[-+]?(?:nan|inf)"
pattern = re.compile(
    r".+, (?P<metric>\w+)_r2: (?P<r2>{0}), \w+_rmse: (?P<rmse>{0}), "
    r"\w+_psnr: (?P<psnr>{0}), \w+_ssim: (?P<ssim>{0})".format(NUMBER)
)

with open(LOG_FILE, 'r') as f:
    for line in f:
        if "train_r2" not in line and "valid_r2" not in line:
            continue
        match = pattern.match(line)
        if match is None or match.group("metric") not in SPLITS:
            # Fail with the offending line instead of an opaque AttributeError/KeyError.
            raise ValueError("Unrecognized metric line in %s: %r" % (LOG_FILE, line))
        split = match.group("metric")
        for name in key:
            metrics["%s_%s" % (split, name)].append(float(match.group(name)))

# Alternative: restore previously saved metrics instead of parsing the log.
# Only unpickle trusted files — pickle.load can execute arbitrary code.
# with open("metrics copy.pkl", 'rb') as f:
#     metrics = pkl.load(f)

output_name = "metric_plot_vit(LMF)"  # NOTE(review): not used below; presumably meant as the figure's save name
sp = StatsPlot()
sp.line_plot(metrics, key, col=2)

In [4]:
### 3. Export one validation sample (label and prediction) as GeoTIFF
SAMPLE_INDEX = 10  # which sample of the validation results to export
SCALE = 65535.0    # presumably undoes a [0, 1] normalisation of uint16 data — TODO confirm
# NpzFile keeps the archive open until closed; the context manager releases it.
with np.load("val_result_fold0_vit(SEASON_L).npz") as data:
    # arr_0 / arr_1 are np.savez's positional names; assumed label / prediction.
    label_data = data['arr_0'][SAMPLE_INDEX] * SCALE
    pred_data = data['arr_1'][SAMPLE_INDEX] * SCALE

from osgeo import gdal
def save_image(output_path, data):
    """Write a 2-D or 3-D array to a GeoTIFF file (no georeferencing is set).

    Args:
        output_path: destination file path.
        data: array of shape (rows, cols) for a single band, or
            (bands, rows, cols) for a multi-band image (including bands == 1).

    The GDAL pixel type follows the array dtype: uint8 -> Byte,
    int8 -> Int16 (array is upcast), uint16 -> UInt16, int16 -> Int16,
    uint32 -> UInt32, int32 -> Int32; any other dtype (floats, 64-bit ints,
    bool) is written as Float32.

    Raises:
        ValueError: if ``data`` is not 2-D or 3-D.
        RuntimeError: if GDAL cannot create the output file.
    """
    data = np.asarray(data)
    if data.ndim not in (2, 3):
        raise ValueError("save_image expects a 2-D or 3-D array, got shape %s" % (data.shape,))

    # GDT_Int8 only exists in GDAL >= 3.7; Int16 represents every int8 value exactly.
    if data.dtype == np.int8:
        data = data.astype(np.int16)

    # Signed types must map to signed GDAL types, or negative values get corrupted.
    type_map = {
        "uint8": gdal.GDT_Byte,
        "uint16": gdal.GDT_UInt16,
        "int16": gdal.GDT_Int16,
        "uint32": gdal.GDT_UInt32,
        "int32": gdal.GDT_Int32,
    }
    datatype = type_map.get(data.dtype.name, gdal.GDT_Float32)

    # Treat a 2-D array as one band so (H, W) and (1, H, W) share the same write path.
    bands = data if data.ndim == 3 else data[np.newaxis]
    band_num, rows, cols = bands.shape

    driver = gdal.GetDriverByName("GTiff")
    # GDAL's Create signature is (path, xsize=cols, ysize=rows, band count, type).
    out_ds = driver.Create(output_path, cols, rows, band_num, datatype)
    if out_ds is None:
        raise RuntimeError("GDAL could not create %r" % (output_path,))

    for i in range(band_num):
        out_ds.GetRasterBand(i + 1).WriteArray(bands[i])

    out_ds.FlushCache()
    # Dropping the last reference closes the dataset and finalizes the file.
    del out_ds

# Write the exported label and prediction rasters side by side for comparison.
for _out_path, _raster in (("label.tif", label_data), ("pred.tif", pred_data)):
    save_image(_out_path, _raster)