In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
import torchkeras
from plotly import graph_objects as go
from sklearn.preprocessing import MinMaxScaler

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import mean_squared_error as mse
from sklearn.metrics import r2_score

import matplotlib as mpl
from matplotlib import cm

# Compute device: use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 1.读取数据

In [None]:
# Load the stacked raster series; axis 0 indexes snapshots and axes 1-2 are
# the spatial grid (per the reshape in the POD cell). Print the raw value range.
data = np.load(r'../POD_Predictor/ningxia-mod13a2_2011-2020.npy')
print(data.min(), data.max())
W, H = data.shape[1:3]

In [None]:
# Preprocessing: raise every value below -3000 up to -3000, modifying `data` in place.
# A single vectorised masked assignment replaces the original per-element loop,
# which re-assigned all matching elements once per match (quadratic work).
data[data < -3000] = -3000

print(data.min(), data.max())

In [None]:
# Visual check: per-pixel average over all snapshots (axis 0).
plt.imshow(
    data.mean(axis=0),
    cmap='RdYlGn_r',
    interpolation='nearest',
)
plt.colorbar(shrink=.92)

# 2.POD分解

In [None]:
def POD_svd(X, threshold=0.995):
    '''
    Proper orthogonal decomposition of a snapshot matrix via a thin SVD.

    input:
        X : m*n matrix with m features and n snapshots, expected to be
            centred already (mean removed from each row).
        threshold : cumulative energy fraction the kept modes must strictly
            exceed (default 0.995, i.e. 99.5%).
    return:
        k : number of modes kept - one past the index of the first mode whose
            cumulative energy exceeds `threshold`; 1 if no mode does
            (e.g. X is all zeros and the ratios are NaN).
        u[:, :k] : spatial modes, one per column u[:, i]
        s[:k] : corresponding singular values
        vh[:k, :] : temporal coefficient matrix, one mode per row vh[i, :]
        C_per : cumulative energy fraction for every mode (full length)
    '''
    u, s, vh = np.linalg.svd(X, full_matrices=False)
    # Per-mode energy; the 1/n factor cancels in the cumulative ratio below.
    s1 = s**2 / X.shape[1]
    C_per = np.cumsum(s1) / s1.sum()
    # First mode index at which the cumulative energy exceeds the threshold.
    above = np.flatnonzero(C_per > threshold)
    if above.size:
        k = int(above[0])
        print(C_per[k])
    else:
        k = 0
    return k+1, u[:, :k+1], s[:k+1], vh[:k+1, :], C_per

In [None]:
# Prediction horizon: number of trailing snapshots left out of the POD fit
# (the SVD is computed on all columns except the last `num_predict`).
num_predict = 20

In [None]:
# POD decomposition: rows of XX are pixels (W*H), columns are snapshots.
XX = data.reshape(data.shape[0], (W*H)).T
assert 0 <= num_predict < XX.shape[1], 'num_predict must leave at least one training snapshot'
# Explicit split index: `XX[:, :-num_predict]` would be empty when num_predict == 0.
n_train = XX.shape[1] - num_predict
# Centre with the mean of the training snapshots only, so the matrix handed to
# POD_svd really is zero-mean and the held-out tail does not leak into the fit.
XX_mean = np.mean(XX[:, :n_train], axis=1, keepdims=True)
XX = XX - XX_mean
print(XX.shape)
k, u, s, vh, s_per = POD_svd(XX[:, :n_train])
A = vh.T  # temporal coefficients: one row per training snapshot, one column per mode
k, u.shape, s.shape, A.shape

In [None]:
# Save the cumulative explained-variance curve of the kept modes.
os.makedirs('results', exist_ok=True)
plt.figure(figsize=(10, 3))
plt.plot(s_per[:k], marker='o', color='darkorange', markersize=10, markerfacecolor='none')
plt.grid(True)
plt.tick_params(labelsize=20)
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
plt.xticks([i*5 for i in range(k // 5)] + [k-1])
# Label the first, a middle, and the last *plotted* value. The last plotted
# index is k-1; s_per[k] lies off the curve and raises IndexError when every
# mode is kept.
plt.yticks([s_per[0], s_per[k // 3], s_per[k - 1]])
plt.savefig('results/result_feature_values_plot.png', dpi=300, bbox_inches='tight')

In [None]:
# Save the average of all kept spatial modes, reshaped back to the W x H grid.
plt.imshow(u.mean(axis=1).reshape(W, H), cmap='rainbow')
plt.axis('off')
plt.margins(0, 0)
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
cbar = plt.colorbar(shrink=.92, pad=0.05)
cbar.outline.set_visible(False)
plt.savefig(f'results/result_spatial_mode_avg.png', dpi=300, bbox_inches='tight')

In [None]:
# Save the leading spatial modes: at most three, fewer if POD kept fewer
# (indexing u[:, 2] would raise IndexError when k < 3).
os.makedirs('results', exist_ok=True)
for idx in range(min(3, u.shape[1])):
    plt.figure()
    plt.imshow(u[:, idx].reshape(W, H), cmap='rainbow')
    plt.axis('off')
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    cbar = plt.colorbar(shrink=.92, pad=0.05)
    cbar.outline.set_visible(False)
    plt.savefig(f'results/result_spatial_mode_{idx}.png', dpi=300, bbox_inches='tight')

In [None]:
# Plot the temporal coefficient series of the leading spatial modes
# (at most three; fewer if POD kept fewer than three modes).
for idx in range(min(3, vh.shape[0])):
    plt.plot(vh[idx, :], label=f'mode {idx}')
plt.xlabel('snapshot index')
plt.ylabel('temporal coefficient')
plt.legend();

# 3.时间系数预插值（线性插值）

In [None]:
# Parameters
time_step = 10     # left snapshot index; valid range [0, n-2] with n = A.shape[0]
time_delta = 0.5   # fractional offset from time_step towards time_step+1, in [0, 1]

assert 0 <= time_step <= A.shape[0] - 2, 'time_step must lie in [0, n-2]'
assert 0 <= time_delta <= 1, 'time_delta must lie in [0, 1]'

# Linear interpolation of every mode's temporal coefficient at
# t = time_step + time_delta: weight (1 - time_delta) on the left snapshot and
# time_delta on the right, so time_delta -> 0 recovers A[time_step].
print(A.shape)
A_interpolation = A[time_step] * (1 - time_delta) + A[time_step + 1] * time_delta
# One row per interpolated time point, one column per mode.
A_interpolation = A_interpolation.reshape(1, -1)
print(A_interpolation.shape)

# 4.合成插值图像（由插值后的时间系数重构）

In [None]:
# Rebuild full W x H fields from the interpolated temporal coefficients.
# POD was fitted on mean-centred data (XX - XX_mean), so the centring mean
# must be added back to obtain values on the original data scale.
modes_scaled = u * s           # spatial modes scaled by their singular values
pixel_mean = XX_mean[:, 0]     # per-pixel mean removed before the POD fit
data_interpolation_list = []

for coeffs in A_interpolation:
    img_interpolation = (modes_scaled @ coeffs + pixel_mean).reshape(W, H)
    data_interpolation_list.append(img_interpolation)

data_interpolation = np.array(data_interpolation_list)
data_interpolation.shape

# 5.保存结果

In [None]:
# Save the POD interpolation results: the raw array in the working directory
# and one PNG per reconstructed field under results/.
os.makedirs('results', exist_ok=True)
np.save('result_array.npy', data_interpolation)
for i, field in enumerate(data_interpolation):
    fig = plt.figure()
    plt.imshow(field, cmap="gist_earth_r")    # alternative colormap: Spectral
    plt.axis('off')
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    cbar = plt.colorbar(shrink=.92)
    cbar.outline.set_visible(False)
    # NOTE(review): these images are interpolated reconstructions; the "_true"
    # suffix is kept only for compatibility with existing output names.
    plt.savefig(f'results/result_img_{i+1}_true.png', dpi=300, bbox_inches='tight')
    # Close by handle so figures do not accumulate in memory across iterations.
    plt.close(fig)