<a href="https://colab.research.google.com/github/iskra3138/ImageSr/blob/master/XAI_%ED%95%99%EC%8A%B5_DEMO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#@title DEMO 사전 준비 [Run Me!!!!!]
%tensorflow_version 2.x

import tensorflow as tf
print("Tensorflow version " + tf.__version__)

try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
  raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.experimental.TPUStrategy(tpu)

import re
import numpy as np
from matplotlib import pyplot as plt

BUCKET = "gs://iskra3138_share"


import cv2

class GradCAM:
  """Grad-CAM heat-map generator for a Keras classification model.

  Args:
    model: Keras model whose output is the class-score vector.
    activation_layer: name of the layer in `model` whose activations are
      weighted by their gradients to build the heat map.
  """

  def __init__(self, model, activation_layer):
    self.model = model
    self.activation_layer = activation_layer
    self.tensor_function = self._get_gradcam_tensor_function()

  def _get_gradcam_tensor_function(self):
    """Build a sub-model mapping the model input to (A_k, y_c).

    A_k is the output of `activation_layer` and y_c is the model output. The
    sub-model is assembled from the existing layers of `self.model`.
    """
    model_input = self.model.input
    y_c = self.model.output
    A_k = self.model.get_layer(self.activation_layer).output
    return tf.keras.models.Model([model_input], [A_k, y_c])

  def generate(self, input_tensor):
    """Compute the Grad-CAM map for the top-scoring class of one image.

    Args:
      input_tensor: batch holding one image; assumed channels-last
        (1, height, width, channels), as produced by np.expand_dims on a
        dataset image. Only the first element is used.

    Returns:
      (grad_cam, preds, class_idx):
        grad_cam: non-negative heat map resized to the input's height/width.
        preds: NumPy vector of model outputs for the image.
        class_idx: index of the highest-scoring class (the one explained).
    """
    with tf.GradientTape() as tape:
      # A single forward pass yields both activations and predictions. The
      # explained class therefore comes from the same pass as the gradients.
      conv_outputs, predictions = self.tensor_function(input_tensor)
      preds = predictions[0].numpy()
      class_idx = int(np.argmax(preds))
      loss = predictions[:, class_idx]

    grads = tape.gradient(loss, conv_outputs)[0].numpy()
    output = conv_outputs[0].numpy()

    # Average the gradients over the spatial axes: one weight per channel.
    weights = np.mean(grads, axis=(0, 1))
    grad_cam = np.dot(output, weights)
    grad_cam = np.maximum(grad_cam, 0)

    # Upsample to the input resolution; cv2.resize takes (width, height).
    height, width = int(input_tensor.shape[1]), int(input_tensor.shape[2])
    grad_cam = cv2.resize(grad_cam, (width, height))
    return grad_cam, preds, class_idx
  
  
import os

AUTO = tf.data.experimental.AUTOTUNE


IMG_WIDTH = 224
IMG_HEIGHT = 224
IMAGE_SIZE = [IMG_HEIGHT, IMG_WIDTH]

# Global batch size: 8 examples per TPU replica.
batch_size = 8 * tpu_strategy.num_replicas_in_sync

## This experiment splits 16 tfrecord files into train/validation sets.
## If dedicated train-only / validation-only tfrecord files exist, assign them
## to train_fns / validation_fns explicitly as lists instead.
gcs_pattern = os.path.join(BUCKET, '*.tfrec')
validation_split = 0.19
# Sort so the split does not depend on the order glob happens to return.
filenames = sorted(tf.io.gfile.glob(gcs_pattern))
if not filenames:
  raise FileNotFoundError('No .tfrec files match ' + gcs_pattern)
# The last int(len * validation_split) files (rounded down) are for validation.
split = len(filenames) - int(len(filenames) * validation_split)
train_fns = filenames[:split]
validation_fns = filenames[split:]

def parse_tfrecord(example):
    """Decode one serialized TFRecord example into an (image, label) pair.

    Returns:
      image: float32 tensor resized to IMAGE_SIZE, 3 channels.
      label: one-hot encoding (depth 5) of the integer "label" feature.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # encoded JPEG bytes
        "file_name": tf.io.FixedLenFeature([], tf.string),  # parsed but unused
        "label_name": tf.io.FixedLenFeature([], tf.string),  # parsed but unused
        "label": tf.io.FixedLenFeature([], tf.int64),  # scalar class index
    }
    parsed = tf.io.parse_single_example(example, feature_spec)

    label = tf.one_hot(indices=parsed['label'], depth=5)

    # convert_image_dtype rescales uint8 [0, 255] to float [0, 1]. It only
    # rescales non-float input, so it must run before resize. Resize returns
    # float32 for every method except NEAREST_NEIGHBOR.
    image = tf.image.decode_jpeg(parsed['image'], channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, IMAGE_SIZE)

    return image, label

def load_dataset(filenames):
  """Read the given TFRecord files and decode them into (image, label) pairs."""
  # Reads from several files are interleaved in parallel for throughput.
  raw_records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
  decoded = raw_records.map(parse_tfrecord, num_parallel_calls=AUTO)
  return decoded

def get_training_dataset():
  """Return the augmented, repeated, shuffled and batched training pipeline."""
  def flip_randomly(image, label):
    # Augmentation: random horizontal, then random vertical flip. The label
    # is passed through unchanged.
    flipped = tf.image.random_flip_left_right(image)
    flipped = tf.image.random_flip_up_down(flipped)
    return flipped, label

  pipeline = load_dataset(train_fns).map(flip_randomly, num_parallel_calls=AUTO)
  pipeline = pipeline.repeat().shuffle(2048).batch(batch_size)
  # Prefetch the next batch while training (autotuned buffer size).
  return pipeline.prefetch(AUTO)

# Class names, indexed by the argmax of the one-hot label.
CLASSES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

training_dataset = get_training_dataset()
# Validation data is batched only: no augmentation, shuffling or repetition.
validation_dataset = load_dataset(validation_fns).batch(batch_size).prefetch(AUTO)

def display_one(image, title, subplot, color):
  """Draw `image` in subplot slot `subplot` with a colored title, axes hidden."""
  plt.subplot(subplot)
  plt.imshow(image)
  plt.title(title, fontsize=16, color=color)
  plt.axis('off')
  
def display_nine(images, titles, title_colors=None):
  """Show the first nine images in a 3x3 grid with per-image titles.

  Args:
    images: sequence of at least nine images accepted by plt.imshow.
    titles: sequence of at least nine title strings.
    title_colors: optional sequence of nine title colors; all titles are
      drawn in black when omitted.
  """
  plt.figure(figsize=(13, 13))
  for i in range(9):
    color = 'black' if title_colors is None else title_colors[i]
    # 331 + i selects cell i (row-major) of a 3x3 subplot grid.
    display_one(images[i], titles[i], 331 + i, color)
  plt.tight_layout()
  plt.subplots_adjust(wspace=0.1, hspace=0.1)
  plt.show()

def get_dataset_iterator(dataset, n_examples):
  """Re-batch `dataset` into groups of `n_examples` and iterate it as NumPy."""
  rebatched = dataset.unbatch().batch(n_examples)
  return rebatched.as_numpy_iterator()

# NumPy iterator over batches of 9 augmented training examples.
training_viz_iterator = training_dataset.unbatch().batch(9).as_numpy_iterator()

def create_model():
  """Build and compile an Xception-based 5-class classifier.

  The Xception backbone is built without its top layers and is fully
  trainable. No `weights` argument is passed, so Keras's default is used.
  Its feature map is global-average-pooled and fed to a softmax Dense head
  named 'prediction'. The model is compiled with Adam and categorical
  cross-entropy.
  """
  backbone = tf.keras.applications.Xception(input_shape=[*IMAGE_SIZE, 3],
                                            include_top=False)
  backbone.trainable = True

  pooled = tf.keras.layers.GlobalAveragePooling2D()(backbone.output)
  head = tf.keras.layers.Dense(5, activation='softmax', name='prediction')
  model = tf.keras.Model(inputs=backbone.input, outputs=head(pooled))

  model.compile(optimizer='adam',
                loss='categorical_crossentropy',
                metrics=['accuracy'])
  return model

# Build the model inside the TPUStrategy scope so it is trained on the TPU.
with tpu_strategy.scope():
  model = create_model()

# Snapshot of the freshly built weights; the training demo restores it.
weights = model.get_weights()

# Record count embedded in a file name, e.g. "flowers00-230.tfrec" -> "230".
_RECORD_COUNT_RE = re.compile(r"-([0-9]+)\.")

def count_data_items(filenames):
  """Return the total number of records in the given .tfrec files.

  The record count is read from each file name, e.g. flowers00-230.tfrec
  holds 230 records.

  Args:
    filenames: iterable of file paths or URLs (e.g. gs://bucket/x-230.tfrec).

  Returns:
    Total record count as a Python int (0 for an empty iterable, so integer
    division of the result stays an int).

  Raises:
    ValueError: if a file name has no "-<count>." segment.
  """
  total = 0
  for filename in filenames:
    # Only the base name is searched, so "-<digits>." in a bucket or
    # directory name cannot be mistaken for the record count.
    match = _RECORD_COUNT_RE.search(filename.rsplit("/", 1)[-1])
    if match is None:
      raise ValueError("Cannot read a record count from file name: {!r}".format(filename))
    total += int(match.group(1))
  return total

n_train = count_data_items(train_fns)
n_valid = count_data_items(validation_fns)

# Whole batches per epoch; any trailing partial batch is not counted.
train_steps, valid_steps = n_train // batch_size, n_valid // batch_size

print("TRAINING IMAGES: ", n_train, ", STEPS PER EPOCH: ", train_steps)
print("VALIDATION IMAGES: ", n_valid, ", STEPS PER EPOCH: ", valid_steps)

# Layer of the Xception model whose activations Grad-CAM visualizes.
activation_layer = 'block14_sepconv2_act'

EPOCHS = 12

# Learning-rate schedule parameters read by lrfn(). The peak rate scales with
# the number of TPU replicas.
start_lr, min_lr = 0.00001, 0.00001
max_lr = 0.00005 * tpu_strategy.num_replicas_in_sync
rampup_epochs, sustain_epochs = 5, 0
exp_decay = .8

def lrfn(epoch):
  """Learning rate for `epoch`: linear warm-up, plateau, then exponential decay.

  The rate rises linearly from start_lr toward max_lr during the first
  `rampup_epochs` epochs and holds at max_lr for `sustain_epochs` epochs.
  After that, the excess over min_lr shrinks by a factor of `exp_decay`
  per epoch.
  """
  if epoch < rampup_epochs:
    slope = (max_lr - start_lr) / rampup_epochs
    return slope * epoch + start_lr
  if epoch < rampup_epochs + sustain_epochs:
    return max_lr
  decay_steps = epoch - rampup_epochs - sustain_epochs
  return (max_lr - min_lr) * exp_decay ** decay_steps + min_lr
    
# Keras callback that sets the optimizer's learning rate from lrfn each epoch.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)

# Plot the schedule over the planned epochs for a quick visual check.
rang = np.arange(EPOCHS)
y = list(map(lrfn, rang))
plt.plot(rang, y)
print('Learning rate per epoch:')



In [0]:
#@title Class별 Sampling [Run Me!!!!]

from google.colab import widgets

def _make_class_filter(class_id):
  """Return a tf.data predicate keeping examples whose label argmax is class_id."""
  def _is_class(image, label):
    return tf.math.equal(tf.argmax(label, axis=0), class_id)
  return _is_class

## Sample one validation image for each class.
sample_images = []
sample_labels = []
for class_id in range(len(CLASSES)):
  # The class id is bound through the factory instead of being captured
  # from the loop variable (avoids the late-binding closure pitfall).
  sample_iterator = (load_dataset(validation_fns)
                     .shuffle(100)
                     .filter(_make_class_filter(class_id))
                     .batch(1)
                     .as_numpy_iterator())
  batch = next(sample_iterator, None)
  if batch is None:
    raise ValueError('No validation example found for class ' + CLASSES[class_id])
  image, label = batch
  sample_images.append(image[0])
  sample_labels.append(label[0])

gradcam_gen = GradCAM(model, activation_layer)

grid = widgets.Grid(2, len(sample_images), header_row=True, header_column=True)
grads, preds, class_idxs = [], [], []
for sample_image in sample_images:
  img_tensor = np.expand_dims(sample_image, axis=0)
  grad, pred, class_idx = gradcam_gen.generate(img_tensor)
  grads.append(grad)
  preds.append(pred)
  class_idxs.append(class_idx)

for n in range(len(sample_images)):
  # Row 0: the sample image, titled with its true class.
  with grid.output_to(0, n):
    plt.imshow(sample_images[n])
    plt.title(CLASSES[np.argmax(sample_labels[n])].title())
    plt.axis('off')
  # Row 1: predicted scores, then the Grad-CAM overlay titled with the
  # predicted class.
  with grid.output_to(1, n):
    for pred in preds[n]:
      print('{:.2f}'.format(pred), end=',')
    print('\n')

    plt.imshow(sample_images[n])
    plt.imshow(grads[n], cmap='jet', alpha=0.5)
    plt.title(CLASSES[class_idxs[n]].title())
    plt.axis('off')
    

# 학습 DEMO

In [0]:
#@title 학습 DEMO [Run ME!!!]
from google.colab import widgets

# Reset the model to the weights captured right after create_model().
# NOTE(review): optimizer (Adam) state is not part of get_weights(), so it is
# presumably not reset on re-runs -- confirm.
model.set_weights(weights)

# Built once: the Grad-CAM sub-model reuses `model`'s layers, so it sees the
# updated weights after every epoch without being rebuilt.
gradcam_gen = GradCAM(model, activation_layer)
n_samples = len(sample_images)

for epoch in range(EPOCHS):
  # Train exactly one epoch per iteration. initial_epoch keeps the epoch index
  # seen by lr_callback in step with this loop.
  history = model.fit(training_dataset, validation_data=validation_dataset,
                      steps_per_epoch=train_steps, epochs=epoch + 1,
                      initial_epoch=epoch, callbacks=[lr_callback])

  grid = widgets.Grid(2, n_samples, header_row=True, header_column=True)
  grads, preds, class_idxs = [], [], []
  for sample_image in sample_images:
    img_tensor = np.expand_dims(sample_image, axis=0)
    grad, pred, class_idx = gradcam_gen.generate(img_tensor)
    grads.append(grad)
    preds.append(pred)
    class_idxs.append(class_idx)

  for n in range(n_samples):
    # Row 0: the sample image, titled with its true class.
    with grid.output_to(0, n):
      plt.imshow(sample_images[n])
      plt.title(CLASSES[np.argmax(sample_labels[n])].title())
      plt.axis('off')
    # Row 1: predicted scores, then the Grad-CAM overlay titled with the
    # predicted class.
    with grid.output_to(1, n):
      for pred in preds[n]:
        print('{:.2f}'.format(pred), end=',')
      print('\n')

      plt.imshow(sample_images[n])
      plt.imshow(grads[n], cmap='jet', alpha=0.5)
      plt.title(CLASSES[class_idxs[n]].title())
      plt.axis('off')
