## E4040 2024 Fall Project
### Improving CNN Robustness via CS Shapley Value-guided Augmentation

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import numpy as np
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import pandas as pd

# Display the GPU device name TensorFlow reports, as a quick check that
# training will run on an accelerator (the recorded output shows GPU:0).
tf.test.gpu_device_name()

'/device:GPU:0'

### Import CIFAR-10 data and train ResNet18

In [2]:
from utils.ResNet18_trainer import ResNet18_trainer, load_cifar10_dataset#, delete_previous_checkpoints

In [3]:
# --- Training configuration -------------------------------------------------
# Values handed to ResNet18_trainer / load_cifar10_dataset in the cells below.

# Data
batch_size = 128
num_classes = 10

# Optimiser settings (forwarded as-is to ResNet18_trainer)
lr = 0.1
momentum = 0.9
decay = 0.0005

# Schedule / bookkeeping
epochs = 2  # original inline note: 120
log_period = 100  # defined here but not passed to ResNet18_trainer in this notebook
checkpoint_dir = "./checkpoints"

In [4]:
# Build the CIFAR-10 train/test datasets with the configured batch size.
# Any preprocessing is defined inside utils.ResNet18_trainer.load_cifar10_dataset.
train_ds, test_ds = load_cifar10_dataset(batch_size)

In [5]:
# Collect every constructor argument in one place, then build the trainer.
trainer_kwargs = dict(
    train_ds=train_ds,
    test_ds=test_ds,
    num_classes=num_classes,
    epochs=epochs,
    batch_size=batch_size,
    lr=lr,
    momentum=momentum,
    decay=decay,
    checkpoint_dir=checkpoint_dir,
)
trainer = ResNet18_trainer(**trainer_kwargs)

In [6]:
# Remove every regular file directly inside the checkpoint directory so the
# next training run starts from scratch. Sub-directories are left untouched.
if os.path.exists(checkpoint_dir):
    checkpoint_entries = [
        os.path.join(checkpoint_dir, entry)
        for entry in os.listdir(checkpoint_dir)
    ]
    for stale_file in filter(os.path.isfile, checkpoint_entries):
        os.remove(stale_file)
    print("Deleted previous checkpoints.")
    

Deleted previous checkpoints.


In [7]:
# Run the training loop configured above. The recorded output shows one
# checkpoint file being written to checkpoint_dir after each epoch.
trainer.run()

No checkpoint found. Starting from scratch.
Training Epoch 1
Epoch 1, Loss: 1.72958505153656, Accuracy: 38.86600112915039, Test Loss: 1.3595706224441528, Test Accuracy: 49.97999954223633
Checkpoint saved at: ./checkpoints/ckpt_epoch_1.h5
Training Epoch 2
Epoch 2, Loss: 1.6087888479232788, Accuracy: 45.30799865722656, Test Loss: 1.3945369720458984, Test Accuracy: 52.52000045776367
Checkpoint saved at: ./checkpoints/ckpt_epoch_2.h5


In [20]:
# Export the trained network for the Shapley-value section.
# The directory is given relative to the notebook so that it matches
# model_path = "./save_model" used by the loading cell further down (and the
# "save_model/assets" location shown in the recorded output). The previous
# "~/save_model" literal did not point at the directory that cell loads from.
save_dir = "./save_model"
trainer.model.save(save_dir)
print(f"Model saved to {save_dir}")

INFO:tensorflow:Assets written to: save_model/assets
Model saved


In [18]:
#trainer.model.build((None, 32, 32, 3))
#dummy_input = tf.random.normal((1, 32, 32, 3))
#trainer.model(dummy_input, training=False)

<tf.Tensor: shape=(1, 10), dtype=float32, numpy=
array([[8.2136859e-04, 9.8826218e-05, 8.2579497e-03, 3.1761732e-03,
        8.6224480e-03, 4.8236409e-04, 9.7743452e-01, 2.0195763e-04,
        3.0673956e-04, 5.9767353e-04]], dtype=float32)>

### Calculation of Shapley values

In [None]:
import os
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import load_model
from pathlib import Path
import time
#########from utils.utils import set_seed, cifar10_std, cifar10_mean
from utils.ResNet18_trainer import std, mean
from tensorflow.keras.preprocessing.image import save_img


from utils.Shapley_sample import getShapley_pixel, getShapley_freq, getShapley_freq_dis, sample_mask, getShapley_freq_softmax, visual_shap


In [None]:
# --- Shapley-value experiment configuration ---------------------------------
model_name = 'ResNet18'
dataset_name = 'cifar10'
data_path = '~/data'
output_path = './output'
model_path = "./save_model"  # directory the trained model is loaded from

# Sampling settings forwarded to the Shapley estimator
sample_times = 2000
num_per_class = 150
#testdata = False
mask_size = 16
n_per_batch = 1
start_num = 0
get_freq_by_dis = False
fix_mask = False
split_n = 1
static_center = False
#norm = True
batchsize = 64
seed = 111


def set_seed(seed_value):
    """Seed the Python, NumPy and TensorFlow random number generators.

    Defined here because the notebook's import of `set_seed` from
    `utils.utils` is commented out, which made the call below raise
    NameError on a fresh kernel.

    Args:
        seed_value: integer seed applied to all three generators.
    """
    import random  # local import: `random` is not imported at the top of the notebook

    random.seed(seed_value)
    np.random.seed(seed_value)
    tf.random.set_seed(seed_value)


# Use the configured value rather than repeating the literal.
set_seed(seed)


In [None]:
# Load the CIFAR-10 arrays through Keras. No normalisation is applied in this
# cell; the arrays are used exactly as returned by load_data().
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

In [None]:
# Wrap each (images, labels) pair in a tf.data pipeline that yields
# fixed-size batches in the original sample order (no shuffling).
train_dataset, test_dataset = (
    tf.data.Dataset.from_tensor_slices(split).batch(batchsize)
    for split in ((x_train, y_train), (x_test, y_test))
)

In [None]:
# Load the trained model from model_path. This must be the same directory
# that the training section passed to trainer.model.save().
model = tf.keras.models.load_model(model_path)


In [None]:
# Compute frequency-domain Shapley values for a fixed number of test images
# per class and store, for every processed image, three files under
# shap_result/<label>/ :
#   <i>_freq.npy       raw Shapley values
#   <i>_freq_shap.png  visualisation produced by visual_shap
#   <i>.png            the image itself

# Number of images seen so far for each class label.
count = [0 for _ in range(num_classes)]

# NOTE(review): the configuration cell declares num_per_class = 150, but this
# loop has always stopped at 50 images per class. The value is kept to
# preserve behaviour; confirm which limit is intended.
samples_per_class = 50

dataset = test_dataset
for img, y in dataset:
    # Nothing left to do once every class has reached its quota.
    if min(count) >= samples_per_class:
        break

    # Convert the label batch once instead of calling .numpy() per access.
    labels = y.numpy().reshape(-1)

    for k in range(img.shape[0]):
        label = int(labels[k])

        # Skip the first `start_num` images of each class (the original code
        # compared against a literal 0, which made this branch unreachable).
        if count[label] < start_num:
            count[label] += 1
            continue
        if count[label] >= samples_per_class:
            continue

        sample_idx = count[label]
        shap_value = getShapley_freq_softmax(
            img, y, model, sample_times, mask_size, k,
            n_per_batch=n_per_batch,
            split_n=split_n,
            static_center=static_center
        )

        shap_path = os.path.join("shap_result", f"{label}")
        Path(shap_path).mkdir(parents=True, exist_ok=True)
        np.save(os.path.join(shap_path, f"{sample_idx}_freq.npy"), shap_value)

        visual_path = os.path.join(shap_path, f"{sample_idx}_freq_shap.png")
        visual_shap(shap_value, mask_size, mask_size, visual_path)

        # De-normalise in float32. The original assigned the float result back
        # into the array returned by .numpy(); for integer image data that
        # silently truncates every pixel.
        # assumes std / mean hold one value per colour channel — TODO confirm
        # NOTE(review): the dataset built in this notebook is not normalised,
        # so confirm that undoing a normalisation is actually wanted here.
        raw_img = img[k].numpy().astype(np.float32)
        raw_img = (
            raw_img * np.asarray(std, dtype=np.float32)
            + np.asarray(mean, dtype=np.float32)
        )
        img_path = os.path.join(shap_path, f"{sample_idx}.png")
        save_img(img_path, raw_img)

        count[label] += 1
        print(f"Class {label} sample {count[label]} completed.")


### Reconstructing Shapley Values

In [None]:
import os
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.image import resize
from pathlib import Path
import math
import matplotlib.pyplot as plt
from tensorflow.signal import fft2d, ifft2d
from tensorflow.keras.utils import save_img

In [None]:
# Root directory of the reconstruction stage. The Shapley-value section writes
# its results to "shap_result/" relative to the working directory, so the root
# defaults to "." to make the two stages line up. Change it only if the
# shap_result folder has been moved elsewhere.
path = "."
recon_path = os.path.join(path, "reconstruction")  # output: positive/negative reconstructions
ifft_path = os.path.join(path, "ifft")             # output directory created alongside
shap_path = os.path.join(path, "shap_result")      # input: per-class Shapley results

In [None]:
# Create both output directories up front; existing directories are reused.
for output_dir in (recon_path, ifft_path):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Normalisation settings for the reconstruction stage.
train_norm = False  # stands in for cfg.train.norm
# Per-channel statistics, one entry per colour channel.
cifar10_mean = [0.4914, 0.4822, 0.4465]
cifar10_std = [0.2023, 0.1994, 0.2010]

In [None]:
# Image loading / preprocessing helper
def preprocess_image(img_path):
    """Load an image file and return it as a 32x32 float32 tensor.

    Pixel values are divided by 255 before the image is resized to
    32x32 (the CIFAR-10 resolution assumed by this notebook).

    Args:
        img_path: path of the image file to read.

    Returns:
        A tf.float32 tensor holding the rescaled, resized image.
    """
    pixels = img_to_array(load_img(img_path))
    unit_range = pixels / 255.0
    resized = resize(unit_range, [32, 32])  # target size assumed to be CIFAR-10's
    return tf.convert_to_tensor(resized, dtype=tf.float32)

In [None]:
# List the entries of shap_path; the loop below treats each sub-directory as
# one class folder.
classes = os.listdir(shap_path)

In [None]:
def _reconstruct_pos_neg(img, freq_shap):
    """Rebuild an image from its positive- and negative-Shapley frequencies.

    The Shapley map is resized to the image resolution, turned into two
    binary masks (values > 0 and values < 0), and each mask is applied to the
    centred 2-D spectrum of every colour channel before transforming back.

    Args:
        img: float image tensor in height x width x channel layout, as
            returned by preprocess_image.
        freq_shap: NumPy array whose elements form a square map of
            frequency Shapley values.

    Returns:
        (pos_img, neg_img): float32 NumPy arrays in height x width x channel
        layout, ready for save_img.
    """
    height, width = int(img.shape[0]), int(img.shape[1])

    # tf.image.resize interprets a 3-D input as [height, width, channels], so
    # the map is resized as a single-channel image (not as [3, m, m]).
    side = int(math.sqrt(freq_shap.size))
    shap_map = tf.reshape(tf.cast(freq_shap, tf.float32), [side, side, 1])
    shap_map = tf.image.resize(shap_map, [height, width]).numpy()[..., 0]

    keep_pos = (shap_map > 0).astype(np.float32)
    keep_neg = (shap_map < 0).astype(np.float32)

    # Remove the centre frequency from both masks.
    # NOTE(review): this assumes the Shapley map is laid out over the centred
    # (fftshift-ed) spectrum, which is what "centre frequency" implies —
    # confirm against utils.Shapley_sample.
    keep_pos[height // 2, width // 2] = 0.0
    keep_neg[height // 2, width // 2] = 0.0

    # Only (de-)standardise when the pipeline is configured for normalised
    # inputs. Images loaded by preprocess_image are plain [0, 1] pixels.
    # NOTE(review): the original unconditionally de-standardised with std/mean
    # from another section; confirm that gating on train_norm is intended.
    channel_mean = tf.constant(cifar10_mean, dtype=tf.float32)
    channel_std = tf.constant(cifar10_std, dtype=tf.float32)
    if train_norm:
        img = (img - channel_mean) / channel_std

    # fft2d transforms the two innermost axes, so move channels first.
    channels_first = tf.transpose(img, [2, 0, 1])
    freq = tf.signal.fftshift(
        fft2d(tf.cast(channels_first, tf.complex64)), axes=(-2, -1)
    )

    def _to_image(keep_mask):
        """Apply one mask to the spectrum and return the spatial image (HWC)."""
        masked = freq * tf.cast(keep_mask, tf.complex64)  # broadcasts over channels
        spatial = tf.math.real(
            ifft2d(tf.signal.ifftshift(masked, axes=(-2, -1)))
        )
        spatial = tf.transpose(spatial, [1, 2, 0])
        if train_norm:
            spatial = spatial * channel_std + channel_mean
        return spatial.numpy()

    return _to_image(keep_pos), _to_image(keep_neg)


# Process every class folder produced by the Shapley-value stage.
for class_name in classes:
    result_path = os.path.join(shap_path, class_name)
    if not os.path.isdir(result_path):
        continue

    for file in os.listdir(result_path):
        # Keep only the raw images: skip .npy files and the *_shap.png plots.
        if ".png" not in file or "shap" in file:
            continue

        stem = file.split('.')[0]

        # Load the image and its frequency-domain Shapley values.
        img = preprocess_image(os.path.join(result_path, file))
        freq_shap = np.load(os.path.join(result_path, stem + "_freq.npy"))

        pos_img, neg_img = _reconstruct_pos_neg(img, freq_shap)

        # Save both reconstructions under reconstruction/<class>/.
        recon_class_path = os.path.join(recon_path, class_name)
        os.makedirs(recon_class_path, exist_ok=True)
        save_img(os.path.join(recon_class_path, f"{stem}_pos.png"), pos_img)
        save_img(os.path.join(recon_class_path, f"{stem}_neg.png"), neg_img)

### Train ResNet18 under AT with CSA