In [1]:
!pip install tensorflow_datasets




In [2]:
import os
# Expose only GPU 0 to this process. This is set before TensorFlow is imported in
# the next cell.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
from tfkerassurgeon.operations import delete_layer, replace_layer, insert_layer
import tensorflow as tf

from tensorflow.keras.applications import vgg16
from tensorflow.keras.applications.vgg16 import preprocess_input
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
# Require at least one visible GPU and turn on memory growth for the first one.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
tf.config.experimental.set_memory_growth(physical_devices[0], True)
# Second-GPU setup is disabled (only one device is configured above).
#tf.config.experimental.set_memory_growth(physical_devices[1], True)


# Workaround kept from the original author ("used to fix bug in keras
# preprocessing scope"): call preprocess_input once, eagerly, on a dummy batch
# before it is used inside the dataset map functions.
temp = tf.zeros([4, 32, 32, 3])  # dummy all-zero batch, only used for the warm-up call
preprocess_input(temp)
print("processed")

processed


In [4]:
# Spatial size every image is resized to in load_image_train / load_image_test
# (and the crop size used by the zoom augmentation).
IMAGE_SIZE = (224, 224)
# Example counts used to derive steps per epoch and validation steps.
TRAIN_SIZE = 50000
VALIDATION_SIZE = 10000
BATCH_SIZE_PER_GPU = 96
# The factor 1 is presumably the number of GPUs in use - confirm before changing.
global_batch_size = (BATCH_SIZE_PER_GPU * 1)
# Depth of the one-hot label encoding produced by the load_image_* functions.
NUM_CLASSES = 10

Dataset code

In [5]:
def flip(x: tf.Tensor) -> tf.Tensor:
    """Random mirror augmentation.

    Applies a random left-right flip followed by a random up-down flip.

    Args:
        x: Image to flip

    Returns:
        The (possibly) mirrored image
    """
    mirrored = tf.image.random_flip_left_right(x)
    return tf.image.random_flip_up_down(mirrored)

def color(x: tf.Tensor) -> tf.Tensor:
    """Random colour jitter.

    Applies, in this order: hue (max delta 0.08), saturation (factor in
    [0.6, 1.6]), brightness (max delta 0.05) and contrast (factor in
    [0.7, 1.3]).

    NOTE(review): the training pipeline calls this on the float output of
    tf.image.resize, before any rescaling - confirm the tf.image colour ops
    behave as intended on that value range.

    Args:
        x: Image

    Returns:
        Colour-jittered image
    """
    jitter_steps = (
        lambda img: tf.image.random_hue(img, 0.08),
        lambda img: tf.image.random_saturation(img, 0.6, 1.6),
        lambda img: tf.image.random_brightness(img, 0.05),
        lambda img: tf.image.random_contrast(img, 0.7, 1.3),
    )
    for step in jitter_steps:
        x = step(x)
    return x

def rotate(x: tf.Tensor) -> tf.Tensor:
    """Random quarter-turn rotation.

    Args:
        x: Image

    Returns:
        The image rotated by k * 90 degrees, where k is a random int32 drawn
        with minval=0 and maxval=4.
    """
    quarter_turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    return tf.image.rot90(x, quarter_turns)

def zoom(x: tf.Tensor) -> tf.Tensor:
    """Zoom augmentation

    Half of the time the image is returned unchanged; otherwise a randomly
    chosen central crop is taken and resized to IMAGE_SIZE.

    Args:
        x: Image

    Returns:
        Augmented image
    """

    # Candidate crop settings: keep the central 80% .. 99% of the image,
    # i.e. crop away between 20% and 1%, in 1% steps.
    scales = np.arange(0.8, 1.0, 0.01)
    num_boxes = len(scales)

    # One normalised, centred box per scale (the box is square, so the
    # coordinate order does not matter). Stored as float32, the dtype
    # tf.image.crop_and_resize uses for boxes.
    lower = 0.5 - (0.5 * scales)
    upper = 0.5 + (0.5 * scales)
    boxes = tf.constant(np.stack([lower, lower, upper, upper], axis=1).astype(np.float32))

    def random_crop(img):
        # Choose the box first and crop only once. The previous version
        # resized every candidate crop to IMAGE_SIZE and then threw all but
        # one of them away, which is ~20x more work per image.
        pick = tf.random.uniform(shape=[], minval=0, maxval=num_boxes, dtype=tf.int32)
        box = tf.gather(boxes, pick)[tf.newaxis, :]
        crop = tf.image.crop_and_resize(tf.expand_dims(img, 0),
                                        boxes=box,
                                        box_indices=tf.zeros([1], dtype=tf.int32),
                                        crop_size=IMAGE_SIZE)
        return crop[0]

    choice = tf.random.uniform(())

    # Only apply cropping 50% of the time
    return tf.cond(choice < 0.5, lambda: x, lambda: random_crop(x))

def normalize(input_image):
  """Run the VGG16 `preprocess_input` imported at the top of the notebook on `input_image`."""
  normalized = preprocess_input(input_image)
  return normalized

@tf.function
def load_image_train(datapoint):
  """Turn one dataset record into an augmented (image, one-hot label) training pair.

  The image is resized to IMAGE_SIZE, then each of flip, color, zoom and rotate
  is applied only when a fresh uniform sample exceeds 0.75, and finally the
  image is passed through `normalize`.

  Args:
    datapoint: mapping with an "image" entry and a 'label' entry.

  Returns:
    Tuple of the processed image and the label one-hot encoded with depth
    NUM_CLASSES.
  """
  label = datapoint['label']
  input_image = tf.image.resize(datapoint["image"], IMAGE_SIZE)

  pipeline = [flip, color, zoom, rotate]
  for augment in pipeline:
    should_apply = tf.random.uniform(()) > 0.75
    input_image = tf.cond(should_apply, lambda: augment(input_image), lambda: input_image)

  input_image = normalize(input_image)
  one_hot_label = tf.one_hot(label, depth=NUM_CLASSES)

  return input_image, one_hot_label

@tf.function
def load_image_test(datapoint):
  """Turn one dataset record into an (image, one-hot label) evaluation pair.

  Same as the training loader but without any augmentation: resize to
  IMAGE_SIZE, then `normalize`.

  Args:
    datapoint: mapping with an "image" entry and a 'label' entry.

  Returns:
    Tuple of the processed image and the label one-hot encoded with depth
    NUM_CLASSES.
  """
  label = datapoint['label']
  resized = tf.image.resize(datapoint["image"], IMAGE_SIZE)
  input_image = normalize(resized)
  one_hot_label = tf.one_hot(label, depth=NUM_CLASSES)

  return input_image, one_hot_label

In [6]:
import math  # kept from the original cell (it was imported twice); no longer needed by this block


class _LayerFeatureSequence(tf.keras.utils.Sequence):
    """Keras Sequence yielding (X, y) pairs computed by a feature-extractor model.

    Each item pulls the next batch from `dataset` and runs `input_model` on it.
    `input_model` must return exactly two outputs, which are handed to Keras as
    the training input `X` and the regression target `y`.

    Shared implementation for LayerBatch and LayerTest, which used to be
    copy-pasted duplicates differing only in the example count.

    Args:
        input_model: callable/model returning two outputs for a batch of images.
        dataset: iterable of batches; batches are consumed sequentially, so the
            `index` passed to `__getitem__` is ignored.
    """

    def __init__(self, input_model, dataset):
        self.input_model = input_model
        self.dataset = iter(dataset)

    def _num_examples(self):
        """Number of examples that make up one pass; defined by subclasses."""
        raise NotImplementedError

    def __len__(self):
        # Whole batches only. The original wrapped this floor division in
        # math.ceil, which was a no-op; the value is unchanged and matches the
        # `SIZE // global_batch_size` step counts used elsewhere in the notebook.
        return self._num_examples() // global_batch_size

    def __getitem__(self, index):
        batch = next(self.dataset)
        # The notebook's datasets yield (images, labels) tuples, but the
        # feature extractor has a single image input. Pass only the images
        # instead of relying on Keras to ignore the extra label tensor, which
        # is version-dependent behaviour.
        images = batch[0] if isinstance(batch, (tuple, list)) else batch
        X, y = self.input_model(images)
        return X, y


class LayerBatch(_LayerFeatureSequence):
    """Training sequence: one pass is TRAIN_SIZE examples."""

    def _num_examples(self):
        return TRAIN_SIZE


class LayerTest(_LayerFeatureSequence):
    """Validation sequence: one pass is VALIDATION_SIZE examples."""

    def _num_examples(self):
        return VALIDATION_SIZE

In [7]:
def add_layers(inputs, layers=2):
    """Build the separable-convolution replacement stack on top of `inputs`.

    The stack is: a 3x3 SeparableConv2D with a quarter of the input channels
    followed by ReLU, concatenated with `inputs`, then `layers - 1` further
    3x3 SeparableConv2D + ReLU pairs with as many filters as `inputs` has
    channels. Every conv/ReLU pair is named `sep_conv_<n>` / `relu_<n>` using
    `build_replacement.counter`, which is incremented once per pair.
    (BatchNormalization between conv and ReLU is intentionally not applied; it
    was disabled in the original code.)

    Args:
        inputs: input tensor; its last dimension is used as the channel count.
        layers: total number of conv/ReLU pairs to create.

    Returns:
        Output tensor of the last layer in the stack.
    """
    channels = inputs.get_shape()[-1]

    tag = build_replacement.counter
    squeezed = tf.keras.layers.SeparableConv2D(
        name=f'sep_conv_{tag}',
        filters=channels // 4,
        kernel_size=(3, 3),
        padding='Same',
    )(inputs)
    squeezed = tf.keras.layers.ReLU(name=f'relu_{tag}')(squeezed)
    build_replacement.counter += 1

    X = tf.keras.layers.concatenate([inputs, squeezed])

    for _ in range(layers - 1):
        tag = build_replacement.counter
        X = tf.keras.layers.SeparableConv2D(
            name=f'sep_conv_{tag}',
            filters=channels,
            kernel_size=(3, 3),
            padding='Same',
        )(X)
        X = tf.keras.layers.ReLU(name=f'relu_{tag}')(X)
        build_replacement.counter += 1

    return X
    
def build_replacement(get_output, layers=2):
    """Create a standalone model containing the replacement stack.

    Args:
        get_output: model whose first output gives the feature shape (without
            the batch dimension) that the replacement takes as input.
        layers: forwarded to `add_layers`.

    Returns:
        tf.keras.Model mapping a fresh Input of that shape to the output of
        `add_layers`.
    """
    feature_shape = get_output.output[0].shape[1:]
    replacement_input = tf.keras.Input(shape=feature_shape)
    replacement_output = add_layers(replacement_input, layers)
    return tf.keras.Model(inputs=replacement_input, outputs=replacement_output)


# Running index used by add_layers to give every created layer a unique name.
build_replacement.counter = 0

In [8]:

def make_list(X):
    """Return `X` itself when it is a list, otherwise wrap it in a one-element list."""
    return X if isinstance(X, list) else [X]

def list_no_list(X):
    """Unwrap a single-element sequence; return any other length unchanged."""
    return X[0] if len(X) == 1 else X

def replace_layer(model, replace_layer_subname, replacement_fn,
**kwargs):
    """
    Rebuild `model` layer by layer, swapping or dropping selected layers.

    Walks `model.layers` in order and re-applies each layer to the rebuilt
    tensors, which are tracked in a dict keyed by the ORIGINAL tensor names.
    A layer's inputs must therefore already have been produced by an earlier
    layer in `model.layers`, otherwise the dict lookup raises KeyError.

    NOTE: this definition shadows the `replace_layer` imported from
    tfkerassurgeon.operations at the top of the notebook.

    args:
        model :: keras Model instance to copy
        replace_layer_subname :: collection of layer names; a layer is selected
            when `layer.name in replace_layer_subname`. The notebook passes a
            list of exact names. NOTE(review): despite the parameter name, a
            plain str here would test whether the layer name is a substring of
            that str, not the other way round.
        replacement_fn :: callable invoked as `replacement_fn(input_tensor)`
            with the selected layer's rebuilt input; its result stands in for
            the selected layer's output
        **kwargs :: accepted but not used
    behavior for selected layers:
        name contains "relu" or 'bn' -> layer is dropped (input passed through)
        otherwise                     -> replaced by `replacement_fn(input)`
    other conventions:
        layers with 'input' in their name are treated as model inputs and
        their output tensors become the inputs of the new model
    returns:
        new tf.keras Model built from the collected inputs and outputs
    """
    model_inputs = []
    model_outputs = []
    # original tensor name -> corresponding tensor in the rebuilt graph
    tsr_dict = {}

    model_output_names = [out.name for out in make_list(model.output)]

    for i, layer in enumerate(model.layers):
        ### Loop over every inbound node, i.e. every place the layer is used
        for j in range(len(layer._inbound_nodes)):

            ### names of this node's input/output tensors in the original model
            inpt_names = [inp.name for inp in make_list(layer.get_input_at(j))]
            outp_names = [out.name for out in make_list(layer.get_output_at(j))]

            ### setup model inputs (input layers are detected by name)
            if 'input' in layer.name:
                for inpt_tsr in make_list(layer.get_output_at(j)):
                    model_inputs.append(inpt_tsr)
                    tsr_dict[inpt_tsr.name] = inpt_tsr
                continue

            ### setup layer inputs
            # Original author's note: "I added the exception model_3_3/Identity:0 I think
            # the problem is that is the input layer" (no such exception is present in the code below).
            inpt = list_no_list([tsr_dict[name]  for name in inpt_names])

            ### remake layer: drop, replace, or re-apply the original layer
            if layer.name in replace_layer_subname:
              if "relu" in layer.name or 'bn' in layer.name:
                print('deleting ' + layer.name)
                x = inpt
              else:
                print('replacing '+layer.name)
                x = replacement_fn(inpt)
            else:
                x = layer(inpt)

            ### register the rebuilt outputs under the original tensor names
            for name, out_tsr in zip(outp_names, make_list(x)):

                ### check if is an output of the original model
                if name in model_output_names:
                    model_outputs.append(out_tsr)
                tsr_dict[name] = out_tsr

    return tf.keras.models.Model(model_inputs, model_outputs)

In [9]:
def replac(inp):
    """Replacement callback: build the two-layer stack from `add_layers` on `inp`."""
    replacement = add_layers(inp, layers=2)
    return replacement

In [10]:
# Load CIFAR-10 through TensorFlow Datasets; with_info=True makes the call
# return the dataset metadata as a second value.
dataset, info = tfds.load('cifar10', with_info=True)

INFO:absl:Overwrite dataset info from restored data version.
INFO:absl:Reusing dataset cifar10 (/root/tensorflow_datasets/cifar10/1.0.2)
INFO:absl:Constructing tf.data.Dataset for split None, from /root/tensorflow_datasets/cifar10/1.0.2


Build the upscaled CIFAR-10 training and test input pipelines (images are resized to `IMAGE_SIZE`).

In [11]:
# Training pipeline: augment per example, shuffle, batch, repeat forever, prefetch.
train = dataset['train'].map(
    load_image_train,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
train_dataset = (
    train
    .shuffle(buffer_size=1000)
    .batch(global_batch_size)
    .repeat()
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

# Evaluation pipeline: no augmentation and no shuffling.
test_dataset = (
    dataset['test']
    .map(load_image_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(global_batch_size)
    .repeat()
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)


In [12]:
# Load the saved base model from an HDF5 file in the working directory.
model = tf.keras.models.load_model('./base_model_cifar10_resnet50.h5')

In [13]:
# Compile for the baseline evaluation below: MSE loss against the one-hot labels,
# with accuracy as the reported metric.
model.compile(optimizer=tf.optimizers.SGD(learning_rate=.01, momentum=.9, nesterov=True), loss='mse', metrics=['acc'])

In [14]:
# Baseline score of the unmodified model. test_dataset repeats forever, so the
# number of batches has to be given explicitly.
model.evaluate(test_dataset, steps=VALIDATION_SIZE//global_batch_size)

 10/104 [=>............................] - ETA: 46s - loss: 0.0091 - acc: 0.9363

KeyboardInterrupt: 

## Make list of target layers to replace

In [None]:
import pprint
# Every 3x3 Conv2D layer of the base model is a replacement target.
# The class-name comparison is deliberate: it matches Conv2D only, not subclasses.
targets = [
    {'name': layer.name, 'layer': idx}
    for idx, layer in enumerate(model.layers)
    if layer.__class__.__name__ == "Conv2D" and layer.kernel_size[0] == 3
]

pprint.pprint(targets)

# For each target, collect every layer whose name contains the target's name
# minus its last four characters, as (layer name, layer index) pairs in model
# order. The notebook assumes target names end in "conv", so the shared prefix
# also picks up the layers that belong to the same unit.
# (A stray `targets[0]['name'][:-4]` expression used to sit here; it displayed
# nothing and raised IndexError when no target was found, so it was removed.)
for target in targets:
    name_prefix = target['name'][:-4]
    target['to_replace'] = [
        (layer.name, idx)
        for idx, layer in enumerate(model.layers)
        if name_prefix in layer.name
    ]

In [None]:
from hyperdash import monitor_cell

In [None]:
%%monitor_cell "concat_layers"
for target in targets:
    
    print(f"training layer {target['name']}")
    tf.keras.backend.clear_session()
    model = tf.keras.models.load_model('base_model_cifar10_resnet50.h5')
    in_layer, out_layer = target['to_replace'][0][1], target['to_replace'][-1][1]
    get_output = tf.keras.Model(inputs=model.input, outputs=[model.layers[in_layer - 1].output, model.layers[out_layer].output])


    replacement_layers = build_replacement(get_output, layers=2)
    replacement_len = len(replacement_layers.layers)
    layer_train_gen = LayerBatch(get_output, train_dataset)
    layer_test_gen = LayerTest(get_output, test_dataset)




    MSE = tf.losses.MeanSquaredError()

    optimizer=tf.keras.optimizers.SGD(.01, momentum=.9, nesterov=True)
    replacement_layers.compile(loss=MSE, optimizer=optimizer)

    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(patience=5, min_lr=.0001, factor=.3, verbose=1)
    early_stop = tf.keras.callbacks.EarlyStopping(patience=15, min_delta=.001, restore_best_weights=True, verbose=1)
    history = replacement_layers.fit(x=layer_train_gen,
                                   epochs=100,
                                   steps_per_epoch=TRAIN_SIZE // global_batch_size,
                                   validation_data=layer_test_gen,
                                   shuffle=False,
                                   callbacks=[reduce_lr, early_stop],
                                   validation_steps=VALIDATION_SIZE // global_batch_size,
                                   verbose=2)
    target['weights'] = [replacement_layers.layers[1].get_weights(), replacement_layers.layers[3].get_weights()]

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
for target in targets:
    
    print(f'Replacing Layer {target["name"]}')
    
    tf.keras.backend.clear_session()
    
    model = tf.keras.models.load_model('base_model_cifar10_resnet50.h5')
    
    layers_to_replace = [x[0] for x in target['to_replace']]
    weights_to_replace = [target['to_replace'][0][1], target['to_replace'][-1][1]]
    
    new_model = replace_layer(model, layers_to_replace, lambda x: replac(x))
    new_model.layers[weights_to_replace[0]].set_weights(target['weights'][0])
    new_model.layers[weights_to_replace[1]].set_weights(target['weights'][1])
    new_model.compile(optimizer=tf.keras.optimizers.SGD(.1), loss="categorical_crossentropy", metrics=['accuracy'])
    target['score'] = new_model.evaluate(test_dataset, steps=VALIDATION_SIZE // global_batch_size)
    

In [None]:
# Summarise the evaluation result stored under 'score' for every target layer.
for target in targets:
    pprint.pprint(f"name: {target['name']} score: {target['score']}")