# Preprocess

In [None]:
import numpy as np
import tensorflow as tf
import dataset
from tensorflow import estimator as tf_estimator
import models.losses as losses
import tensorflow as tf
from models.metrics import *
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from seggradcam.seggradcam import SegGradCAM, ClassRoI, PixelRoI
from seggradcam.seggradcam_block import SegGradCAM as SegGradCAM_block
from seggradcam.seggradcam_block import ClassRoI as ClassRoI_block
from seggradcam.seggradcam_block import PixelRoI as PixelRoI_block
from seggradcam.visualize_sgc import SegGradCAMplot
from matplotlib import colors
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# NOTE(review): presumably the probability cut-off for calling a pixel
# "fire"; it is not referenced by any of the cells visible here.
predict_threshold = 0.5

# Settings handed to dataset.make_dataset when the input pipeline is built.
hparams = dict(
    # Data paths
    train_path='../dataset/next_day_wildfire_spread_train*',
    eval_path='../dataset/next_day_wildfire_spread_eval*',
    test_path='../dataset/next_day_wildfire_spread_test*',

    # Features
    input_features=[
        'elevation', 'pdsi', 'NDVI', 'pr', 'sph', 'th',
        'tmmn', 'tmmx', 'vs', 'erc', 'population', 'PrevFireMask',
    ],
    output_features=['FireMask'],

    # Azimuth channels
    azimuth_in_channel=None,
    azimuth_out_channel=None,

    # Data and model parameters
    data_sample_size=64,
    sample_size=32,
    output_sample_size=32,
    batch_size=128,
    shuffle=False,
    shuffle_buffer_size=10000,
    compression_type=None,
    input_sequence_length=1,
    output_sequence_length=1,
    repeat=False,
    clip_and_normalize=True,
    clip_and_rescale=False,

    # Data augmentation
    random_flip=False,
    random_rotate=False,
    random_crop=False,
    center_crop=True,

    # Other parameters
    downsample_threshold=0.0,
    binarize_output=True,
)

# Panel captions; the embedded newlines keep the labels narrow.
TITLES = [
    'Elevation', 'Wind\ndirection', 'Wind\nvelocity',
    'Min\ntemp', 'Max\ntemp', 'Humidity',
    'Precip', 'Drought', 'Vegetation',
    'Population\ndensity', 'Energy\nrelease\ncomponent',
    'Previous\nfire\nmask', 'Fire\nmask', 'Predict\nmask',
]

n_rows = 30      # number of rows of data samples to plot
n_features = 12  # number of data variables

# NOTE(review): the original comment here announced variables controlling the
# colour map for the fire masks, but none are defined in this cell.

# Build the input pipeline from hparams in PREDICT mode.
# NOTE(review): which of the *_path patterns PREDICT mode reads is decided
# inside dataset.make_dataset - confirm it is 'test_path'.
test_dataset = dataset.make_dataset(hparams, mode=tf_estimator.ModeKeys.PREDICT)

# Integrated Gradients

In [None]:
# Positions, within the first batch, of the samples to compute attributions for.
image_list = [5, 15, 17, 21, 24, 25, 29, 33, 35, 36, 38, 39, 41, 42, 43, 44, 45, 46, 125]
# Position of the single sample used for one-off inspection.
image_id = 46

# Only the first batch yielded by the dataset is pulled; every index above
# refers to a position inside that batch.
inputs, labels = next(iter(test_dataset))
image = inputs[image_id]

In [None]:
def integral_approximation(gradients):
    """
    Average a stack of gradients using the trapezoidal rule.

    Neighbouring entries along the first axis are averaged pairwise, and the
    resulting midpoint values are then averaged along that same axis.

    :param gradients: gradients stacked along axis 0.
    :return: mean of the pairwise averages, with axis 0 reduced away.
    """
    leading = gradients[:-1]
    trailing = gradients[1:]
    midpoints = (leading + trailing) / 2.0
    return np.mean(midpoints, axis=0)


def integrated_gradients(input_image, model, baseline=None, steps=500):
    """
    Compute integrated-gradients attributions for a single input image.

    The image is linearly interpolated from ``baseline`` to ``input_image``
    at ``steps + 1`` evenly spaced points. The whole stack is passed through
    ``model`` as one batch, the gradient of the sigmoid of the model output
    with respect to every interpolated image is averaged with
    ``integral_approximation``, and the average is scaled element-wise by
    ``input_image - baseline``.

    :param input_image: input image without a batch axis (documented by the
        original author as shape (32, 32, 12)).
    :param model: trained model, called on the batch of interpolated images;
        its output is treated as logits.
    :param baseline: reference image to compare against. If ``None``, an
        all-zero image with the shape of ``input_image`` is used.
    :param steps: number of interpolation intervals; must be at least 1.
    :return: attributions, shaped like ``input_image - baseline``.
    :raises ValueError: if ``steps`` is smaller than 1.
    """
    # steps == 0 would divide by zero and a negative value would produce an
    # empty batch, so reject both up front with a clear message.
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps!r}")

    # Fall back to an all-zero baseline when none is supplied.
    if baseline is None:
        baseline = np.zeros(input_image.shape)

    # Computed once: drives both the interpolation and the final scaling.
    delta = input_image - baseline

    # Linear interpolation, vectorised: one alpha per point, broadcast over
    # the image axes, instead of one tensor operation per step.
    delta_np = np.asarray(delta, dtype=np.float32)
    baseline_np = np.asarray(baseline, dtype=np.float32)
    alphas = (np.arange(steps + 1) / steps).astype(np.float32)
    alphas = alphas.reshape((-1,) + (1,) * delta_np.ndim)
    interpolated_images = tf.convert_to_tensor(
        baseline_np + alphas * delta_np, dtype=tf.float32
    )

    # Forward pass under the tape; the input is a plain tensor, so it has to
    # be watched explicitly.
    with tf.GradientTape() as tape:
        tape.watch(interpolated_images)
        logits = model(interpolated_images)
        predictions = tf.math.sigmoid(logits)

    # Gradient of the (implicitly summed) predictions w.r.t. every
    # interpolated image.
    gradients = tape.gradient(predictions, interpolated_images)

    # Average gradient along the interpolation path.
    avg_gradients = integral_approximation(gradients)

    # Scale by the input-baseline difference to obtain the attributions.
    return delta * avg_gradients

In [None]:
def _load_saved_model(model_name):
    """
    Load ``saved_model/<model_name>_model`` with its custom loss and metric.

    A fresh loss function and metric object are created for every call.

    :param model_name: prefix of the saved-model directory name.
    :return: the loaded Keras model.
    """
    custom_objects = {
        'masked_weighted_cross_entropy_with_logits':
            losses.weighted_cross_entropy_with_logits_with_masked_class(pos_weight=3),
        'AUCWithMaskedClass': AUCWithMaskedClass(with_logits=True),
    }
    return tf.keras.models.load_model(
        f'saved_model/{model_name}_model',
        custom_objects=custom_objects,
    )


autoencoder = _load_saved_model('autoencoder')
unet = _load_saved_model('unet')
resnet = _load_saved_model('resnet')
vit = _load_saved_model('vit')

In [None]:
import os


def _save_positive_attribution_maps(attributions, out_path, n_channels=12):
    """
    Plot the positive part of each channel's attributions and save the figure.

    One row of ``n_channels`` panels is drawn, each with its own colour bar,
    and the figure is written to ``out_path`` and then closed.

    :param attributions: attribution array indexed as [row, column, channel].
    :param out_path: destination path of the image file.
    :param n_channels: number of channel panels to draw.
    """
    # Negative attributions are zeroed once here rather than once per panel.
    positive = np.where(attributions > 0, attributions, 0)

    fig, axes = plt.subplots(1, n_channels, figsize=(40, 3))
    for channel, ax in enumerate(np.atleast_1d(axes)):
        im = ax.imshow(positive[:, :, channel], cmap='viridis')
        fig.colorbar(im, ax=ax)
        ax.set_title(f'Feature {channel + 1}')
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(out_path)
    # Three figures are produced per sample; without closing them pyplot
    # keeps every one alive for the whole loop.
    plt.close(fig)


# savefig / to_csv do not create missing directories.
os.makedirs('./ig', exist_ok=True)

# One three-row frame (AutoEncoder / Resnet / Unet) per sample, concatenated
# once after the loop instead of growing a DataFrame on every iteration.
frames = []
for sample_id in image_list:
    image = inputs[sample_id]
    ae_ig_attributions = integrated_gradients(image, autoencoder)
    resnet_ig_attributions = integrated_gradients(image, resnet)
    unet_ig_attributions = integrated_gradients(image, unet)

    # Per-channel importance: attributions summed over both spatial axes.
    ae_feature_importance = np.sum(ae_ig_attributions, axis=(0, 1))
    resnet_feature_importance = np.sum(resnet_ig_attributions, axis=(0, 1))
    unet_feature_importance = np.sum(unet_ig_attributions, axis=(0, 1))

    all_feature_importance = np.vstack(
        [ae_feature_importance, resnet_feature_importance, unet_feature_importance]
    )
    df = pd.DataFrame(all_feature_importance)
    df.insert(0, 'Model', ['AutoEncoder', 'Resnet', 'Unet'])
    df.insert(0, 'Image_ID', [sample_id] * 3)
    frames.append(df)

    _save_positive_attribution_maps(ae_ig_attributions, f'./ig/{sample_id}_ae.png')
    _save_positive_attribution_maps(resnet_ig_attributions, f'./ig/{sample_id}_resnet.png')
    _save_positive_attribution_maps(unet_ig_attributions, f'./ig/{sample_id}_unet.png')

# pd.concat rejects an empty list, so keep the original empty-frame result
# when no samples were selected.
result = pd.concat(frames, axis=0) if frames else pd.DataFrame()

result.to_csv("./ig/feature_importance.csv")

In [None]:
# Bare expression at the end of the cell: shows the accumulated
# feature-importance table as the cell output.
result

In [None]:
# ae_ig_attributions = integrated_gradients(image, autoencoder)
# resnet_ig_attributions = integrated_gradients(image, resnet)
# unet_ig_attributions = integrated_gradients(image, unet)
# # vit_ig_attributions = integrated_gradients(image, vit)

In [None]:
# ae_feature_importance = np.sum(ae_ig_attributions, axis=(0, 1))
# resnet_feature_importance = np.sum(resnet_ig_attributions, axis=(0, 1))
# unet_feature_importance = np.sum(unet_ig_attributions, axis=(0, 1))
# # vit_feature_importance = np.sum(vit_ig_attributions, axis=(0, 1))

# # 打印每个通道的重要性
# print(ae_feature_importance)
# print(resnet_feature_importance)
# print(unet_feature_importance)
# # print(vit_feature_importance)

In [None]:
# # 将 NumPy 数组转换为 Pandas DataFrame
# result = pd.DataFrame()
# all_feature_importance = np.vstack([ae_feature_importance, resnet_feature_importance, unet_feature_importance])
# df = pd.DataFrame(all_feature_importance)
# df.insert(0, 'Model', ['AutoEncoder', 'Resnet', 'Unet'])
# df.insert(0, 'Image_ID', [image_id, image_id, image_id])

# result = pd.concat([result, df], axis = 0)
# result



In [None]:
# fig, axes = plt.subplots(1, 12, figsize=(40,3))
# for i in range(12):
#     ax = axes[i]
#     im = ax.imshow(np.where(ae_ig_attributions > 0, ae_ig_attributions, 0)[:, :, i], cmap='viridis')
#     fig.colorbar(im, ax=ax)
#     ax.set_title(f'Feature {i+1}')
#     ax.axis('off')
# plt.tight_layout()
# plt.savefig(f'./ig/{image_id}_ae.png', dpi=300)
# plt.show()

# fig, axes = plt.subplots(1, 12, figsize=(40,3))
# for i in range(12):
#     ax = axes[i]
#     im = ax.imshow(np.where(resnet_ig_attributions > 0, resnet_ig_attributions, 0)[:, :, i], cmap='viridis')
#     fig.colorbar(im, ax=ax)
#     ax.set_title(f'Feature {i+1}')
#     ax.axis('off')
# plt.tight_layout()
# plt.savefig(f'./ig/{image_id}_resnet.png')
# plt.show()

# fig, axes = plt.subplots(1, 12, figsize=(40,3))
# for i in range(12):
#     ax = axes[i]
#     im = ax.imshow(np.where(unet_ig_attributions > 0, unet_ig_attributions, 0)[:, :, i], cmap='viridis')
#     fig.colorbar(im, ax=ax)
#     ax.set_title(f'Feature {i+1}')
#     ax.axis('off')
# plt.tight_layout()
# plt.savefig(f'./ig/{image_id}_unet.png')
# plt.show()

# # fig, axes = plt.subplots(1, 12, figsize=(40,3))
# # for i in range(12):
# #     ax = axes[i]
# #     im = ax.imshow(np.where(vit_ig_attributions > 0, vit_ig_attributions, 0)[:, :, i], cmap='viridis')
# #     fig.colorbar(im, ax=ax)
# #     ax.set_title(f'Feature {i+1}')
# #     ax.axis('off')
# # plt.tight_layout()
# # plt.show()

In [None]:
# positive_attributions = np.where(ig_attributions > 0, ig_attributions, 0)

# # 现在对这个只包含正值的数组求和
# feature_importance = np.sum(positive_attributions, axis=(0, 1))

# # 打印每个通道的重要性
# print(feature_importance)

In [None]:
# normalized_distribution
# positive_attributions = np.where(normalized_distribution > 0, normalized_distribution, 0)

# # 现在对这个只包含正值的数组求和
# feature_importance = np.sum(positive_attributions, axis=(0, 1))

# # 打印每个通道的重要性
# print(feature_importance)

In [None]:
# masked_feature_importance = np.sum(masked_ig_attributions, axis=(0, 1))

# # 打印每个通道的重要性
# print(masked_feature_importance)

In [None]:
# import matplotlib.pyplot as plt

# fig, axes = plt.subplots(1, 12, figsize=(40,3))

# # 遍历所有的特征通道
# for i in range(12):
#     ax = axes[i]
#     im = ax.imshow(ig_attributions[:, :, i], cmap='viridis')
#     # 为每个子图添加颜色条
#     fig.colorbar(im, ax=ax)
#     ax.set_title(f'Feature {i+1}')
#     ax.axis('off')

# plt.tight_layout()
# plt.show()

In [None]:
# fig, axes = plt.subplots(1, 12, figsize=(40,3))

# # 遍历所有的特征通道
# for i in range(12):
#     ax = axes[i]
#     im = ax.imshow(positive_attributions[:, :, i], cmap='viridis')
#     # 为每个子图添加颜色条
#     fig.colorbar(im, ax=ax)
#     ax.set_title(f'Feature {i+1}')
#     ax.axis('off')

# plt.tight_layout()
# plt.show()