In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
import torchkeras
from plotly import graph_objects as go
from sklearn.preprocessing import MinMaxScaler

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import mean_squared_error as mse
from sklearn.metrics import r2_score

import matplotlib as mpl
from matplotlib import cm
from matplotlib.ticker import FormatStrFormatter

# Compute device: the CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
import os

# Names of the output directories for saved results.
folders = [
    "result array",
    "result dataset img",
    "result feature values plot",
    "result img 1",
    "result img 2",
    "result spatial mode",
    "result time series",
]

# Create every directory up front. exist_ok=True means an already existing
# directory is not an error, so the cell can be re-run safely.
for folder in folders:
    os.makedirs(folder, exist_ok=True)

# 1.读取数据

In [None]:
# Frame size: every snapshot is later flattened to W*H features and
# reshaped back to (W, H) for plotting.
W, H = 128, 128
# Label embedded in the file names of the saved results.
GROUP = '1'
# Snapshot stack; later cells treat axis 0 as the time (snapshot) axis.
data = np.load('XXX.npy')

In [None]:
# Quick look at the field averaged over axis 0 of the snapshot stack.
mean_frame = data.mean(axis=0)
plt.imshow(mean_frame, cmap='RdYlGn_r', interpolation='nearest')
plt.colorbar(shrink=.92)

# 2.POD分解

In [None]:
def POD_svd(X, energy_threshold=0.995):
    '''
    Proper orthogonal decomposition (POD) of a snapshot matrix via thin SVD.

    input:
        X : m*n matrix with m features and n snapshots. It is expected to be
            mean-centered by the caller; no centering is done here.
        energy_threshold : cumulative-variance fraction the retained modes
            must exceed (default 0.995, i.e. 99.5%).
    return:
        k : number of retained modes, i.e. the smallest count whose
            cumulative variance fraction is strictly greater than
            energy_threshold. If no count exceeds it, all modes are kept.
        u[:, :k] : retained spatial modes, one per column, shape (m, k).
        s[:k] : singular values of the retained modes. These are NOT the
            eigenvalues s**2 / n that are used for the variance ranking.
        vh[:k, :] : temporal coefficient matrix, one mode per row, shape (k, n).
        C_per : cumulative variance fraction of ALL modes, length min(m, n).
    '''
    u, s, vh = np.linalg.svd(X, full_matrices=False)
    # Variance captured by each mode (eigenvalues of the covariance matrix).
    eigvals = s**2 / X.shape[1]
    C_per = np.cumsum(eigvals) / eigvals.sum()

    exceeds = C_per > energy_threshold
    if exceeds.any():
        # argmax on a boolean array returns the first True position.
        n_modes = int(np.argmax(exceeds)) + 1
        print(C_per[n_modes - 1])
    else:
        # Threshold never exceeded (e.g. threshold >= 1): keep everything.
        n_modes = len(C_per)
    return n_modes, u[:, :n_modes], s[:n_modes], vh[:n_modes, :], C_per

In [None]:
# Flatten every snapshot into one column: XX is (W*H, n_snapshots).
XX = data.reshape(data.shape[0], (W*H)).T

# Center with the mean of the TRAINING snapshots only (even time indices).
# Using the mean over all snapshots would leak the held-out odd frames into
# the prediction (the mean is added back after reconstruction) and would
# leave data_train not exactly centered, which POD_svd expects.
XX_mean = np.mean(XX[:, ::2], axis=1, keepdims=True)
XX = XX - XX_mean
print(XX.shape)

data_train = XX[:, ::2]  # even time indices (t=0,2,4,...)
data_test = XX[:, 1::2]  # odd time indices (t=1,3,5,...)

k, u, s, vh, s_per = POD_svd(data_train)
# A holds the temporal coefficients with time along axis 0: (n_train, k).
A = vh.T
k, u.shape, s.shape, A.shape

In [None]:
# Sanity check: the cumulative variance fraction of the last mode should be ~1.
s_per[-1]

In [None]:
# Cumulative variance fraction of the k retained modes.
plt.figure(figsize=(10,3))
plt.plot(s_per[:k], marker='o', color='darkorange', markersize=6, markerfacecolor='none')
plt.grid(True)
plt.tick_params(labelsize=20)
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
# x ticks every 50 modes plus the index of the last retained mode.
plt.xticks([i*50 for i in range(k // 50)] + [k-1])
# The last plotted value is s_per[k-1]; s_per[k] is one past the curve and
# raises IndexError when every mode is retained.
plt.yticks([s_per[0], s_per[k // 3], s_per[k-1]])
plt.savefig(f'result feature values plot/{GROUP}.png', dpi=300, bbox_inches='tight')

In [None]:
# Save the average of the retained spatial modes as an image.
mode_mean = u.mean(axis=1, keepdims=True).reshape(W, H)
plt.imshow(mode_mean, cmap='rainbow')
plt.axis('off')
# Remove all padding so the saved image holds only the field and its colorbar.
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.margins(0, 0)
cbar = plt.colorbar(shrink=.92, pad=0.05)
cbar.outline.set_visible(False)
plt.savefig(f'result spatial mode/{GROUP}_mean.png', dpi=300, bbox_inches='tight')

In [None]:
# Save the leading spatial modes (at most three) as images.
# u has one column per retained mode, so cap the count to avoid an
# IndexError when fewer than three modes were kept.
for idx in range(min(3, u.shape[1])):
    plt.figure()
    plt.imshow(u[:,idx].reshape(W,H), cmap='rainbow')
    plt.axis('off')
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    cbar = plt.colorbar(shrink=.92, pad=0.05)
    cbar.outline.set_visible(False)
    plt.savefig(f'result spatial mode/{GROUP}_{idx}.png', dpi=300, bbox_inches='tight')

In [None]:
# Plot the temporal coefficients of the leading modes (at most three).
# vh has one row per retained mode, so cap the count to avoid an IndexError
# when fewer than three modes were kept.
for mode_idx in range(min(3, vh.shape[0])):
    plt.plot(vh[mode_idx, :], label=f'mode {mode_idx}')
plt.xlabel('training snapshot index')
plt.ylabel('temporal coefficient')
plt.legend()
plt.show()

# 3.时间系数插值

In [None]:
def temporal_interpolation_only(vh):
    """
    Linearly interpolate every series at the midpoint between consecutive
    samples. Only the midpoints are returned; the original samples are not.
    :param vh: shape (n_series, t)
    :return: shape (n_series, t - 1)
    """
    earlier = vh[:, :-1]
    later = vh[:, 1:]
    return (earlier + later) / 2

# Interpolate the temporal coefficients; the printed shapes should show the
# time axis shrinking by one (t -> t - 1) while the mode axis is unchanged.
print(vh.shape)
vh_interp = temporal_interpolation_only(vh)
print(vh_interp.shape)

In [None]:
# Visualize the interpolation result for the first temporal mode.
# vh is (n_modes, t): a time series is a ROW (vh[mode, :]), not a column.
mode_idx = 0
t_train = np.arange(vh.shape[1])
fig = go.Figure()
fig.add_trace(go.Scatter(x=t_train, y=vh[mode_idx, :], name='y_true'))
# Each interpolated value lies halfway between two training samples, so it is
# drawn at a half-integer position instead of being shifted by a padded 0.
fig.add_trace(go.Scatter(x=t_train[:-1] + 0.5, y=vh_interp[mode_idx, :], name='y_pred'))
fig.show()

# 4.合成

In [None]:
# Interpolated temporal coefficients with time along axis 0: (n_interp, k).
A_interp = vh_interp.T

# Scale each spatial mode by its singular value once; this product does not
# depend on the frame, so it is hoisted out of the per-frame loop.
modes_scaled = u * s
# One matrix product reconstructs all frames: (n_interp, k) @ (k, W*H).
# The frames are still mean-centered at this point.
data_interp = (A_interp @ modes_scaled.T).reshape(A_interp.shape[0], W, H)
# Kept because later cells use len(data_interp_list) as the frame count.
data_interp_list = list(data_interp)

print(data_interp.min(), data_interp.max())
print(data_interp.shape)

# 5.评价

In [None]:
# Ground truth: the odd snapshots that were held out of the POD basis.
# Interpolated frame i lies midway between training snapshots 2i and 2i+2,
# i.e. at original index 2i+1 = data[1::2][i]. With an even number of
# snapshots there is one more odd snapshot than interpolated frames, so trim
# the truth to the predicted count; otherwise an unmatched frame would
# influence the shared normalization range below.
n_eval = len(data_interp)
data_true = data[1::2][:n_eval].astype(np.float32)
# Undo the centering.
data_pred = data_interp + XX_mean.reshape(W, H)
print(data_true.min(), data_true.max())
print(data_pred.min(), data_pred.max())

# Normalize both stacks with ONE shared range so data_range=1.0 is valid for
# SSIM / PSNR and the errors of the two stacks stay comparable.
data_min, data_max = min(data_true.min(), data_pred.min()), max(data_true.max(), data_pred.max())
print(data_min, data_max)
data_true = (data_true - data_min) / (data_max - data_min)
data_pred = (data_pred - data_min) / (data_max - data_min)

# Per-frame SSIM, PSNR and MSE.
ssim_list = []
psnr_list = []
mse_list = []
for i, (img_true, img_pred) in enumerate(zip(data_true, data_pred)):
    ssim_list.append(ssim(img_true, img_pred, data_range=1.0))
    psnr_list.append(psnr(img_true, img_pred, data_range=1.))
    mse_list.append(mse(img_true, img_pred))
    print(i, ssim_list[-1], psnr_list[-1], mse_list[-1])

# Averages over all evaluated frames.
avg_ssim = np.mean(ssim_list)
avg_psnr = np.mean(psnr_list)
avg_mse = np.mean(mse_list)

# Summary.
print("Average SSIM: {:.4f}".format(avg_ssim))
print("Average PSNR: {:.2f} dB".format(avg_psnr))
print("Average MSE: {:.6f}".format(avg_mse))
print(f'k:{k}')

In [None]:
# Display the shapes of the ground-truth and predicted stacks side by side.
data_true.shape, data_pred.shape

In [None]:
# Preview the first ground-truth frame.
first_true = data_true[0]
plt.imshow(first_true, cmap='gist_earth_r')
plt.colorbar()

In [None]:
# Save the predicted array and one true/predicted image pair per frame.
cmap_name = 'gist_earth_r'


def _save_frame(img, path, cmap, norm=None):
    """Render one frame without axes, add a frameless colorbar, save to path."""
    plt.figure()
    plt.imshow(img, cmap=cmap, norm=norm)
    plt.axis('off')
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    cbar = plt.colorbar(shrink=.92)
    cbar.outline.set_visible(False)
    plt.savefig(path, dpi=300, bbox_inches='tight')
    # Close the figure so a long loop does not accumulate open figures.
    plt.close()


# Rebuild both stacks in physical units directly from their sources instead
# of un-scaling data_true / data_pred in place. The in-place version applied
# the inverse scaling again on every re-run of this cell and then silently
# saved wrong values.
data_pred = data_interp + XX_mean.reshape(W, H)
data_true = data[1::2][:len(data_pred)].astype(np.float32)
print(data_true.min(), data_true.max())
print(data_pred.min(), data_pred.max())

np.save(f'result array/{GROUP}.npy', data_pred)

os.makedirs(f'result img/{GROUP}', exist_ok=True)

for i in range(len(data_pred)):
    print(i, data_pred[i].min(), data_pred[i].max())

    _save_frame(data_true[i], f'result img/{GROUP}/{i+1}_true.png', cmap_name)

    # Reuse the ground-truth color range for the prediction so the two images
    # of a pair are directly comparable.
    norm = mpl.colors.Normalize(vmin=data_true[i].min(), vmax=data_true[i].max())
    _save_frame(data_pred[i], f'result img/{GROUP}/{i+1}_pred.png', cmap_name, norm=norm)