## E4040 2024 Fall Project
### Improving CNN Robustness via CS Shapley Value-guided Augmentation

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import numpy as np
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import pandas as pd

# Ask TensorFlow for the name of a GPU device; as the last expression in the
# cell, its return value is what the notebook displays.
tf.test.gpu_device_name()

### Import CIFAR-10 data and train ResNet18

In [None]:
from utils.ResNet18_trainer import ResNet18_trainer, load_cifar10_dataset

In [None]:
# Hyperparameters for the baseline ResNet18 run on CIFAR-10.

# Data: forwarded to load_cifar10_dataset() and to the trainer.
num_classes = 10
batch_size = 128

# Values handed to ResNet18_trainer as lr / momentum / decay.
lr = 0.01
momentum = 0.9
decay = 0.0005

# Run length; log_period is not referenced by any cell visible in this notebook.
epochs = 160
log_period = 100

In [None]:
# Build the train/test datasets with the project helper, using the configured batch size.
train_ds, test_ds = load_cifar10_dataset(batch_size)

In [None]:
# Assemble the ResNet18 trainer from the datasets and hyperparameters set in
# the cells above; every value is passed by keyword.
trainer = ResNet18_trainer(
    num_classes=num_classes,
    train_ds=train_ds,
    test_ds=test_ds,
    batch_size=batch_size,
    epochs=epochs,
    lr=lr,
    momentum=momentum,
    decay=decay,
)

In [None]:
# Start the training run configured on `trainer`.
trainer.run()

In [None]:
# Persist the trained model for the Shapley-value section.
# The path is relative on purpose: "~" is not expanded by Python/Keras file
# handling (it would create a literal "~" directory), and the Shapley section
# loads the model from "./save_model".
save_dir = "./save_model"
trainer.model.save(save_dir)
print(f"Model saved to {save_dir}")

### Calculation of Shapley values

In [None]:
import os
import tensorflow as tf
import numpy as np
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import load_model
from pathlib import Path
import time
#########from utils.utils import set_seed, cifar10_std, cifar10_mean
from tensorflow.keras.preprocessing.image import save_img


from utils.Shapley_sample import getShapley_pixel, getShapley_freq, getShapley_freq_dis, sample_mask, getShapley_freq_softmax, visual_shap


In [None]:
# Configuration for the Shapley-value computation.

# Identification and filesystem locations. Only model_path is read by the
# cells visible in this notebook (when loading the saved model).
model_name = 'ResNet18'
dataset_name = 'cifar10'
data_path = '~/data'
output_path = './output'
model_path = "./save_model"

# Sampling parameters forwarded to getShapley_freq_softmax in the loop below.
sample_times = 2000
mask_size = 16
n_per_batch = 1
split_n = 1

# Sample-selection and mode switches.
# NOTE(review): the visible Shapley loop hard-codes its own limits and flags
# instead of reading these — confirm which values are intended.
num_per_class = 150
start_num = 0
get_freq_by_dis = False
fix_mask = False
static_center = False
#testdata = False
#norm = True

# Batch size for the tf.data pipelines, and the RNG seed.
batchsize = 64
seed = 111

# Seed the Python, NumPy and TensorFlow RNGs from the configured `seed`.
# The project's set_seed helper is not available here (its import from
# utils.utils is commented out), so calling it would raise NameError; the
# seeding is done directly instead.
import random

random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)


In [None]:

# Load the CIFAR-10 train/test splits through Keras; no preprocessing is applied in this cell.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
#nclass = 10

In [None]:
# Wrap each (images, labels) split in a tf.data pipeline, then group it into
# batches of `batchsize`. No shuffling is applied.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.batch(batchsize)

test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(batchsize)

In [None]:
# Restore the trained network from `model_path` (load_model is imported from
# tensorflow.keras.models in the import cell of this section).
model = load_model(model_path)


In [None]:
# Compute frequency-domain Shapley values for a fixed number of test images
# per class. For each processed image the raw values are written to
# shap_result/<class>/<index>_freq.npy and a visualisation to
# shap_result/<class>/<index>_freq_shap.png.

# NOTE(review): the config cell defines num_per_class = 150, but this loop has
# always used 50 — the limit is kept at 50 here; confirm which is intended.
samples_per_class = 50

# count[c] = number of class-c images seen so far (skipped or processed).
count = [0] * num_classes

dataset = test_dataset
for img, y in dataset:
    # Once every class has reached its quota the remaining batches would do
    # nothing, so stop scanning the dataset.
    if all(seen >= samples_per_class for seen in count):
        break

    for k in range(img.shape[0]):
        # Labels are indexed as length-1 vectors; convert once to a plain int.
        label = int(y[k].numpy()[0])

        if count[label] < start_num:
            # Skip the first `start_num` images of each class (the original
            # compared against a literal 0, which made this branch unreachable).
            count[label] += 1
            continue
        if count[label] >= samples_per_class:
            continue

        sample_idx = count[label]  # zero-based index used in the file names

        shap_value = getShapley_freq_softmax(
            img, y, model, sample_times, mask_size, k,
            n_per_batch=n_per_batch,
            split_n=split_n,
            static_center=static_center,
        )

        shap_path = os.path.join("shap_result", f"{label}")
        Path(shap_path).mkdir(parents=True, exist_ok=True)
        np.save(os.path.join(shap_path, f"{sample_idx}_freq.npy"), shap_value)

        visual_path = os.path.join(shap_path, f"{sample_idx}_freq_shap.png")
        visual_shap(shap_value, mask_size, mask_size, visual_path)

        # Exporting the input image itself (<index>.png) is currently disabled.

        count[label] += 1
        # Printed after the increment, so the sample number shown is one-based.
        print(f"Class {label} sample {count[label]} completed.")


### Reconstructing Shapley Values

### Train ResNet18 under AT with CSA