In [None]:
import tensorflow as tf
from tensorflow.python.keras.utils import data_utils
import numpy as np
import h5py

import randaugment

# Configuration: ImageNet training and validation data must be available as TFRecord shards.
#   SO_ABS_FILEPATH         - absolute path of the compiled custom-op shared library
#   IMAGENET_TRAIN_FILEPATH - directory containing the train-XXXXX-of-01024 shards
#   IMAGENET_VAL_FILEPATH   - directory containing the validation-XXXXX-of-00128 shards

SO_ABS_FILEPATH = ""
IMAGENET_TRAIN_FILEPATH = ""
IMAGENET_VAL_FILEPATH = ""


np.set_printoptions(precision=4)
# Uncomment to log which device every op is placed on (very verbose):
#tf.debugging.set_log_device_placement(True)
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    # Allocate GPU memory on demand instead of reserving all of it up front.
    tf.config.experimental.set_memory_growth(gpu, True)

# Load the custom ops used below (vgg16_weight_trans, vgg16_custom_infer,
# vgg16_custom_train_normal).
with tf.device('/cpu:0'):
    vgg16_module = tf.load_op_library(SO_ABS_FILEPATH)

print(tf.version.VERSION)
print(gpus)
# Must be *called*: printing the bare attribute only shows the function object.
print(tf.test.is_built_with_gpu_support())

In [2]:
def get_train_filenames(base_directory = IMAGENET_TRAIN_FILEPATH):
    """Return the paths of the 1024 ImageNet training TFRecord shards.

    Shard names follow ``train-NNNNN-of-01024`` with a zero-padded,
    five-digit shard index, joined to `base_directory` with a '/'.
    """
    return [base_directory + "/train-%05d-of-01024" % shard for shard in range(1024)]

def get_val_filenames(base_directory = IMAGENET_VAL_FILEPATH):
    """Return the paths of the 128 ImageNet validation TFRecord shards.

    Shard names follow ``validation-NNNNN-of-00128`` with a zero-padded,
    five-digit shard index, joined to `base_directory` with a '/'.
    """
    return [base_directory + "/validation-%05d-of-00128" % shard for shard in range(128)]


def save_weights_to_file(filename, weights, weights_vel, bias, bias_vel):
    """Write the four flat parameter tensors to an HDF5 file.

    The datasets are stored under the keys 'weights', 'weights_vel', 'bias'
    and 'bias_vel'; read_weights_from_file() reads the same keys back.

    Args:
        filename: destination path (overwritten if it exists).
        weights, weights_vel, bias, bias_vel: eager tensors / variables
            exposing `.numpy()`.
    """
    # The context manager closes the file even if a write fails, so a crash
    # mid-save does not leave an open handle on the checkpoint.
    with h5py.File(filename, 'w') as weights_hdf5:
        weights_hdf5.create_dataset('weights', data = weights.numpy())
        weights_hdf5.create_dataset('weights_vel', data = weights_vel.numpy())
        weights_hdf5.create_dataset('bias', data = bias.numpy())
        weights_hdf5.create_dataset('bias_vel', data = bias_vel.numpy())

def read_weights_from_file(filename):
    """Load the parameter tensors written by save_weights_to_file().

    Args:
        filename: path of the HDF5 file.

    Returns:
        (weights, weights_vel, bias, bias_vel) as float32 tensors.
    """
    # `with` guarantees the file is closed even if a key is missing; the
    # local name also no longer shadows the `file` builtin.
    with h5py.File(filename, 'r') as weights_file:
        # `[...]` reads each dataset fully into memory before the file closes.
        weights = tf.convert_to_tensor(weights_file['weights'][...], tf.float32)
        weights_vel = tf.convert_to_tensor(weights_file['weights_vel'][...], tf.float32)
        bias = tf.convert_to_tensor(weights_file['bias'][...], tf.float32)
        bias_vel = tf.convert_to_tensor(weights_file['bias_vel'][...], tf.float32)

    return weights, weights_vel, bias, bias_vel



# Schema of the ImageNet tf.train.Example records: every feature is a scalar
# (FixedLenFeature with shape []), keyed by its TFRecord feature name.
image_feature_description = {
    feature_name: tf.io.FixedLenFeature([], feature_dtype)
    for feature_name, feature_dtype in (
        ('image/height', tf.int64),
        ('image/width', tf.int64),
        ('image/colorspace', tf.string),
        ('image/channels', tf.int64),
        ('image/class/label', tf.int64),
        ('image/class/synset', tf.string),
        ('image/class/text', tf.string),
        ('image/format', tf.string),
        ('image/filename', tf.string),
        ('image/encoded', tf.string),
    )
}


@tf.function
def _parse_single_image_function(example_proto):
    """Decode one serialized tf.train.Example into a dict of feature tensors.

    Keys and dtypes come from the module-level `image_feature_description`.
    """
    features = tf.io.parse_single_example(serialized=example_proto,
                                          features=image_feature_description)
    return features


#function found in TF source, possibly modified a bit
def distorted_bounding_box_crop(image,
                                bbox = None,
                                min_object_covered=0.65,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.65, 0.85),
                                max_attempts=10,
                                ):
    # Returns a randomly distorted crop of `image`.
    # See `tf.image.sample_distorted_bounding_box` for more documentation.
    # Args:
    #   image: decoded image `Tensor` of shape [height, width, channels].
    #   bbox: bounding boxes arranged `[1, num_boxes, coords]` where each
    #     coordinate is in [0, 1] and ordered `[ymin, xmin, ymax, xmax]`.
    #     Defaults to `[[[0, 0, 1.0, 1.0]]]`, i.e. the whole image.
    #   min_object_covered: the cropped area must contain at least this
    #     fraction of a supplied bounding box. Defaults to 0.65.
    #   aspect_ratio_range: allowed aspect ratio (width / height) of the crop.
    #   area_range: allowed crop area as a fraction of the image area.
    #   max_attempts: number of attempts at sampling a crop satisfying the
    #     constraints; after `max_attempts` failures the entire image is used.
    # Returns:
    #   The cropped image `Tensor`.

    if bbox is None:
        # Build the default per call instead of sharing one mutable list
        # object between every call (mutable-default-argument pitfall).
        bbox = [[[0, 0, 1.0, 1.0]]]

    bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)

    # Crop the image to the sampled bounding box; the channel entries of
    # begin/size are ignored.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)

    return tf.image.crop_to_bounding_box(image, offset_y, offset_x, target_height, target_width)

#function found in TF source, possibly modified a bit
def _crop(image, offset_height, offset_width, crop_height, crop_width):
    # Crops a [height, width, channels] image to a window of the given size
    # starting at (offset_height, offset_width). Adapted from TF/slim source.
    # The image size does not need to be known statically, only its rank.
    # Args:
    #   image: image tensor of shape [height, width, channels].
    #   offset_height, offset_width: scalar offsets (cast to int32 here, so
    #     float offsets are truncated).
    #   crop_height, crop_width: size of the crop.
    # Returns:
    #   The cropped image, reshaped to [crop_height, crop_width, channels].
    # Raises:
    #   InvalidArgumentError: if the window does not fit inside the image.

    num_channels = tf.shape(image)[2]
    target_shape = tf.stack([crop_height, crop_width, num_channels])
    slice_begin = tf.cast(tf.stack([offset_height, offset_width, 0]), tf.int32)

    # tf.slice (rather than crop_to_bounding_box) accepts tensor-valued sizes.
    window = tf.slice(image, slice_begin, target_shape)
    return tf.reshape(window, target_shape)

#function found in TF source, possibly modified a bit
def _central_crop(image_list, crop_height, crop_width):
    # Returns a centred crop of every image in `image_list` (adapted from
    # TF/slim source). Images may differ in channel count.
    # Args:
    #   image_list: list of [height, width, channels] image tensors.
    #   crop_height, crop_width: size of each crop.
    # Returns:
    #   List of cropped images, in the same order as `image_list`.

    def crop_center(img):
        img_shape = tf.shape(img)
        # True division gives a float offset; _crop casts it back to int32.
        top = (img_shape[0] - crop_height) / 2
        left = (img_shape[1] - crop_width) / 2
        return _crop(img, top, left, crop_height, crop_width)

    return [crop_center(img) for img in image_list]

def _mean_image_subtraction(image, means):
    """Subtract per-channel `means` (3 values) from `image`.

    `means` is reshaped to [1, 1, 3] and broadcast, so `image` is expected
    to have a trailing channel axis of size 3.
    """
    channel_means = tf.reshape(means, [1, 1, 3])
    return tf.subtract(image, channel_means)

#function found in TF source, possibly modified a bit
def _smallest_size_at_least(height, width, smallest_side=256):
    # Computes a new (height, width) whose shorter side equals `smallest_side`
    # while preserving the aspect ratio (adapted from TF/slim source).
    # Args:
    #   height: int32 scalar tensor, current height.
    #   width: int32 scalar tensor, current width.
    #   smallest_side: Python int or scalar `Tensor`, target size of the
    #     shorter side.
    # Returns:
    #   (new_height, new_width) as int32 scalar tensors (float results are
    #   truncated by the cast).

    h = tf.cast(height, tf.float32)
    w = tf.cast(width, tf.float32)
    target = tf.cast(smallest_side, tf.float32)

    # Scale relative to the shorter side so that side maps onto `target`.
    scale = target / tf.minimum(h, w)
    return tf.cast(h * scale, tf.int32), tf.cast(w * scale, tf.int32)

#function found in TF source, possibly modified a bit
def _aspect_preserving_resize(image, smallest_side=256):
    # Bicubically resizes `image` so its shorter side equals `smallest_side`,
    # keeping the aspect ratio (adapted from TF/slim source).
    # Args:
    #   image: 3-D image `Tensor` [height, width, channels].
    #   smallest_side: Python int or scalar `Tensor`, target shorter side.
    # Returns:
    #   The resized image, reshaped to [new_height, new_width, 3].

    image_shape = tf.shape(image)
    target_h, target_w = _smallest_size_at_least(image_shape[0], image_shape[1], smallest_side)
    resized = tf.image.resize(image, [target_h, target_w], method='bicubic')
    return tf.reshape(resized, [target_h, target_w, 3])


@tf.function
def preprocess_for_eval(image, pretrained_tf_model = False, output_height = 224, output_width =224, resize_side = 256):
    # Preprocesses one encoded JPEG for evaluation.
    # Args:
    #   image: scalar string `Tensor` with the encoded JPEG bytes.
    #   pretrained_tf_model: if True, subtract the RGB means
    #     [123.68, 116.779, 103.939] and reverse the channel order
    #     (presumably matching the Keras pretrained VGG16 input convention);
    #     otherwise subtract [124, 117, 104] and divide by the per-channel std.
    #   output_height: The height of the image after preprocessing.
    #   output_width: The width of the image after preprocessing.
    #   resize_side: The smallest side of the image for aspect-preserving resizing.
    # Returns:
    #   A float32 image of shape [output_height, output_width, 3].

    image = tf.io.decode_jpeg(image)
    if tf.shape(image)[2] == 1:  # grayscale JPEG: replicate to 3 channels
        image = tf.image.grayscale_to_rgb(image)
    image = _aspect_preserving_resize(image, resize_side)
    image = _central_crop([image], output_height, output_width)[0]
    image.set_shape([output_height, output_width, 3])

    # Round-trip through uint8 to quantize the bicubic output to valid pixel values.
    image = tf.clip_by_value(image, 0.0, 255.0)
    image = tf.cast(image,tf.uint8)
    image = tf.cast(image,tf.float32)

    if pretrained_tf_model:
        # Previously the subtraction result was discarded, so the pretrained
        # path fed un-centred pixels to the network.
        image = _mean_image_subtraction(image, [123.68, 116.779, 103.939])
        image = image[..., ::-1]  # reverse channel order (RGB -> BGR)
    else:
        image = _mean_image_subtraction(image, [124.0, 117.0, 104.0])
        image = image/tf.reshape([58.393, 57.12, 57.375],[1,1,3])  # divide by std

    return image


@tf.function
def training_preprocessing(image):
    """Decode and augment one encoded JPEG for training.

    Pipeline:
      1. decode (grayscale is expanded to RGB);
      2. random distorted crop;
      3. aspect-preserving resize to a 256 shorter side;
      4. 224x224 central crop;
      5. uint8 quantization;
      6. RandAugment (magnitude 20);
      7. per-channel mean/std normalisation.

    Returns a float32 [224, 224, 3] image.
    """
    crop_h, crop_w, short_side = 224, 224, 256

    decoded = tf.io.decode_jpeg(image)
    if tf.shape(decoded)[2] == 1:  # make sure image is rgb
        decoded = tf.image.grayscale_to_rgb(decoded)

    distorted = distorted_bounding_box_crop(decoded)
    resized = _aspect_preserving_resize(distorted, short_side)
    centred = _central_crop([resized], crop_h, crop_w)[0]
    centred.set_shape([crop_h, crop_w, 3])
    pixels = tf.cast(tf.clip_by_value(centred, 0.0, 255.0), tf.uint8)

    augmented = randaugment.distort_image_with_randaugment(pixels, magnitude=20)

    # RGB mean subtraction followed by division by the per-channel std.
    centred_pixels = _mean_image_subtraction(tf.cast(augmented, tf.float32), [124.0, 117.0, 104.0])
    return centred_pixels / tf.reshape([58.393, 57.12, 57.375], [1, 1, 3])
    

def parse_and_preprocess_training_image(example_proto, tf_vgg16 = False): # @tf.function here seems slower
    """Parse one serialized training example into (image, one-hot label).

    Args:
        example_proto: serialized tf.train.Example string tensor.
        tf_vgg16: if True, emit float32 labels; otherwise int8 labels.

    Returns:
        (image, label): the augmented float32 image from
        training_preprocessing() and a 1000-way one-hot label.
    """
    parsed_example = _parse_single_image_function(example_proto)
    # (The former unused `image = parsed_example['image/encoded']` is gone.)
    image = training_preprocessing(parsed_example['image/encoded'])

    # The stored label is shifted by one before one-hot encoding; presumably the
    # TFRecords use 1-based class ids (1..1000). TODO confirm against the data.
    label_dtype = tf.float32 if tf_vgg16 else tf.int8
    label = tf.one_hot(parsed_example['image/class/label'] - 1, 1000, dtype=label_dtype)
    return (image, label)

def parse_and_preprocess_eval_image(example_proto, tf_vgg16 = False, pretrained_tf_model = False):
    """Parse one serialized validation example into (image, one-hot label).

    Labels are float32 when either `tf_vgg16` or `pretrained_tf_model` is
    set, int8 otherwise. The image comes from preprocess_for_eval().
    """
    features = _parse_single_image_function(example_proto)

    eval_image = preprocess_for_eval(features['image/encoded'], pretrained_tf_model)

    wants_float_labels = tf_vgg16 or pretrained_tf_model
    one_hot_dtype = tf.float32 if wants_float_labels else tf.int8
    one_hot_label = tf.one_hot(features['image/class/label'] - 1, 1000, dtype=one_hot_dtype)

    return (eval_image, one_hot_label)



def get_val_tf_dataset(times_to_run = 150, prefetch = True):
    """Build the (single-pass) validation pipeline.

    Yields (images, one-hot labels) batches of 32 * times_to_run examples,
    read from all validation shards with 8-way parallel, non-deterministic
    interleave and map. A one-batch prefetch is added when `prefetch` is set.
    """
    batch_size = 32 * times_to_run
    dataset = (
        tf.data.Dataset.from_tensor_slices(get_val_filenames())
        .interleave(tf.data.TFRecordDataset, cycle_length=8, num_parallel_calls=8, deterministic=False)
        .map(parse_and_preprocess_eval_image, num_parallel_calls=8, deterministic=False)
        .batch(batch_size)
    )
    return dataset.prefetch(1) if prefetch else dataset

def get_train_tf_dataset(times_to_run = 100):
    """Build the infinite, augmented training pipeline.

    Yields (images, one-hot int8 labels) batches of 32 * times_to_run
    examples. Shard order is shuffled with a buffer covering every shard
    before repeating forever.
    """
    shard_paths = get_train_filenames()
    batch_size = 32 * times_to_run

    def parse_record(record):
        return parse_and_preprocess_training_image(record, False)

    dataset = tf.data.Dataset.from_tensor_slices(shard_paths)
    dataset = dataset.shuffle(len(shard_paths)).repeat()
    dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=8, num_parallel_calls=8, deterministic=False)
    dataset = dataset.map(parse_record, num_parallel_calls=8, deterministic=False)
    return dataset.batch(batch_size).prefetch(1)

In [3]:
def get_vgg_pretrained_weights():
    """Download (once, cached) the Keras ImageNet VGG16 weights and flatten them.

    Returns:
        (weights, bias): two 1-D float32 tensors holding every layer's kernel
        and bias, each flattened and concatenated in HDF5 key order.
    """
    WEIGHTS_PATH = ('https://storage.googleapis.com/tensorflow/keras-applications/'
                    'vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5')
    # Public Keras API instead of the private tensorflow.python.keras module.
    weights_path = tf.keras.utils.get_file('vgg16_weights_tf_dim_ordering_tf_kernels.h5',
                                           WEIGHTS_PATH,
                                           cache_subdir='models',
                                           file_hash='64373286793e3c8b2b4e3219cbf3544b')

    kernel_parts = []
    bias_parts = []
    # The file is a two-level mapping: layer name -> parameter name -> array.
    # Layers without parameters have no second-level keys and are skipped.
    with h5py.File(weights_path, 'r') as weight_file:
        for layer_name in weight_file.keys():
            param_names = list(weight_file[layer_name].keys())
            if param_names:
                # As before: index 0 is taken as the kernel (HWCN for convs) and
                # index 1 as the bias, relying on h5py's key ordering.
                kernel_parts.append(weight_file[layer_name][param_names[0]][:].flatten())
                bias_parts.append(weight_file[layer_name][param_names[1]][:].flatten())

    # Concatenate once rather than re-copying the growing (~138M element)
    # buffer for every layer.
    vgg_weights = np.concatenate(kernel_parts) if kernel_parts else np.array([], float)
    vgg_bias = np.concatenate(bias_parts) if bias_parts else np.array([], float)

    return tf.convert_to_tensor(vgg_weights,tf.float32), tf.convert_to_tensor(vgg_bias,tf.float32)

def win_transform_weights(vgg_weights):
    """Return the Winograd-transformed version of the flat VGG16 weight vector.

    The custom `vgg16_weight_trans` op is pinned to the CPU.
    """
    with tf.device('/cpu:0'):
        winograd_weights = vgg16_module.vgg16_weight_trans(vgg_weights)
    return winograd_weights


def initialize_vgg16_weights():
    """Create freshly initialised VGG16 parameters as flat tf.Variables.

    Layer order: 13 3x3 conv kernels (HWCN), fc1 (25088x4096),
    fc2 (4096x4096), classifier (4096x1000).
    - Conv and hidden FC kernels: He-normal initialisation.
    - Classifier: Glorot-uniform initialisation.
    - Biases and both velocity (momentum) buffers: zeros.

    Returns:
        (weights, weights_vel, bias, bias_vel) as tf.Variables:
        - weights and weights_vel have 138,344,128 elements;
        - bias and bias_vel have 13,416 elements.
    """
    # (in_channels, out_channels) of the 13 3x3 conv layers.
    conv_channels = [(3, 64), (64, 64),
                     (64, 128), (128, 128),
                     (128, 256), (256, 256), (256, 256),
                     (256, 512), (512, 512), (512, 512),
                     (512, 512), (512, 512), (512, 512)]
    # He-initialised kernels; shapes assume (..., C, N) for conv and (C, N) for fc.
    he_shapes = [(3, 3, c_in, c_out) for c_in, c_out in conv_channels]
    he_shapes += [(25088, 4096), (4096, 4096)]
    classifier_shape = (4096, 1000)

    # A fresh initializer instance per layer: newer Keras versions return
    # identical values when one unseeded instance is called repeatedly, which
    # would make layers with the same shape start out identical.
    pieces = [tf.reshape(tf.keras.initializers.he_normal()(shape), [-1]) for shape in he_shapes]
    pieces.append(tf.reshape(tf.keras.initializers.GlorotUniform()(classifier_shape), [-1]))

    # Concatenate once instead of re-copying the growing buffer per layer.
    weights = tf.concat(pieces, axis=0)

    # One bias per output unit of every layer (13416 in total).
    num_bias = sum(shape[-1] for shape in he_shapes) + classifier_shape[-1]

    weights_vel = tf.zeros_like(weights)  # 138344128 = total weight params
    bias = tf.zeros([num_bias], tf.float32)
    bias_vel = tf.zeros([num_bias], tf.float32)

    return tf.Variable(weights),tf.Variable(weights_vel),tf.Variable(bias),tf.Variable(bias_vel)

def get_train_weights(initial_weights,w_filename):
    """Return starting parameters for training.

    Args:
        initial_weights: 'pre' (Keras pretrained, zero velocities),
            'file' (load from `w_filename`) or 'new' (fresh initialisation).
        w_filename: HDF5 file used when initial_weights == 'file'.

    Returns:
        (weights, bias, weights_vel, bias_vel). Note that this order differs
        from the (weights, weights_vel, bias, bias_vel) order used elsewhere.

    Raises:
        ValueError: for any other `initial_weights` value.
    """
    if initial_weights == 'file':
        weights, weights_vel, bias, bias_vel = read_weights_from_file(w_filename)
    elif initial_weights == 'new':
        weights, weights_vel, bias, bias_vel = initialize_vgg16_weights()
    elif initial_weights == 'pre':
        weights, bias = get_vgg_pretrained_weights()
        # Momentum buffers start at zero: 138344128 weights, 13416 biases.
        weights_vel = tf.zeros([138344128], tf.float32)
        bias_vel = tf.zeros([13416], tf.float32)
    else:
        raise ValueError("initial weights arg not recognized (pre,file,new)")

    return weights, bias, weights_vel, bias_vel

def get_val_weights(initial_weights,filename = None):
    """Return (weights, bias) for inference.

    Args:
        initial_weights: 'pre' for the Keras pretrained weights or 'file' to
            load them from `filename` (velocities are discarded).
        filename: HDF5 file used when initial_weights == 'file'.

    Raises:
        ValueError: for any other `initial_weights` value.
    """
    if initial_weights == 'pre':
        return get_vgg_pretrained_weights()
    if initial_weights == 'file':
        loaded_weights, _, loaded_bias, _ = read_weights_from_file(filename)
        return loaded_weights, loaded_bias
    raise ValueError("initial weights arg not recognized (pre,file)")

def get_accuracy(labels,guesses,N,H): #labels(N*H) guesses(N)
    """Fraction of the first N predictions whose one-hot label is set.

    Args:
        labels: flat one-hot labels, H entries per example (list, numpy array
            or eager tensor); entries beyond N*H are ignored.
        guesses: predicted class index per example (at least N entries).
        N: number of examples to score (int or scalar tensor).
        H: number of classes.

    Returns:
        Accuracy as a Python float; 0.0 when N == 0 instead of dividing by zero.
    """
    n = int(N)
    if n == 0:
        return 0.0

    labels_flat = np.asarray(labels).reshape(-1)
    predicted = np.asarray(guesses).reshape(-1)[:n].astype(np.int64)

    # Vectorised form of `labels[i*H + guesses[i]] == 1` for every i; avoids
    # one eager tensor index per example.
    hit_index = np.arange(n, dtype=np.int64) * H + predicted
    correct = int(np.count_nonzero(labels_flat[hit_index] == 1))
    return correct / n

#@tf.function
def vgg16_inference(val_dataset,initial_weights,filename = None):
    """Run the custom Winograd VGG16 inference op over a validation dataset.

    Args:
        val_dataset: batched dataset of (images, one-hot labels), e.g. from
            get_val_tf_dataset().
        initial_weights: 'pre' (Keras pretrained) or 'file' (load `filename`);
            see get_val_weights().
        filename: HDF5 weights file, used only when initial_weights == 'file'.

    Returns:
        (guesses_tensor, labels): predicted class per evaluated image (int32)
        and the flattened one-hot labels of every image. Score them with
        get_accuracy(labels, guesses_tensor, tf.shape(guesses_tensor)[0], 1000).
    """
    weights, bias = get_val_weights(initial_weights,filename)
    return _vgg16_infer_with_weights(val_dataset, weights, bias)


def _vgg16_infer_with_weights(val_dataset, weights, bias):
    """Inference with raw (non-Winograd) flat weights/bias already in memory."""
    win_weights = vgg16_module.vgg16_weight_trans(weights)  # get winograd weights

    # Collect per-batch results and concatenate once at the end instead of
    # re-copying the growing tensors on every batch.
    label_chunks = [tf.zeros([0],tf.int8)]
    guess_chunks = [tf.zeros([0],tf.int32)]

    for images, batch_labels in val_dataset:
        label_chunks.append(tf.reshape(batch_labels,[-1]))
        # Number of 32-image groups in this batch (truncated). NOTE(review): a
        # trailing partial group is presumably skipped by the op; its labels
        # stay at the tail and get_accuracy only reads len(guesses) rows.
        times_to_run = tf.cast(tf.shape(images)[0]/32,tf.int32)
        guess_chunks.append(vgg16_module.vgg16_custom_infer(images, win_weights, bias, times_to_run))

    return tf.concat(guess_chunks, axis=0), tf.concat(label_chunks, axis=0)


def _print_validation_accuracy(val_dataset, weights, bias, rolling_acc):
    """Print the running training accuracy and the full validation accuracy."""
    print("categorical accuracy: " + str(rolling_acc[0].numpy()))
    guesses_tensor, labels = _vgg16_infer_with_weights(val_dataset, weights, bias)
    print(get_accuracy(labels,guesses_tensor, tf.shape(guesses_tensor)[0], 1000))


def vgg16_custom_training(  train_dataset,
                            val_dataset,
                            weights,
                            weights_vel,
                            bias,
                            bias_vel,
                            lr,
                            momentum,
                            reg,
                            times_to_run,
                            out_mode,
                            steps,
                            save_filename):
    """Run the custom training op for `steps` steps.

    Each element of `train_dataset` is passed to one call of
    `vgg16_custom_train_normal`, which takes lr, momentum and reg (presumably
    SGD with momentum and weight decay) and returns the updated parameters,
    loss and accuracy.

    Progress is counted in 32-image sub-batches (current_step * times_to_run):
    - weights are saved to `save_filename` every 37500 sub-batches and at the end;
    - validation accuracy is printed every 18800 sub-batches and at the end.

    Returns:
        (weights, weights_vel, bias, bias_vel) after training, or after the
        dataset is exhausted if it is finite and shorter than `steps`.
    """
    current_step = 0
    rolling_acc = None  # exponential moving average of the per-step accuracy

    for batch in train_dataset:

        weights,weights_vel,bias,bias_vel,loss,acc = vgg16_module.vgg16_custom_train_normal(batch[0],
                                                                                            batch[1],
                                                                                            weights,
                                                                                            weights_vel,
                                                                                            bias,
                                                                                            bias_vel,
                                                                                            reg,
                                                                                            momentum,
                                                                                            lr,
                                                                                            times_to_run,
                                                                                            out_mode
                                                                                            )

        current_step +=1

        # A None sentinel, unlike the old `== 0` check, does not reset the
        # average when it legitimately reaches 0.
        rolling_acc = acc if rolling_acc is None else 0.95*rolling_acc + 0.05*acc

        print("loss: " + str(loss[0].numpy()) + " categorical accuracy: " + str(rolling_acc[0].numpy()), end="\r", flush=True)

        if current_step >= steps:
            # Previously the filename was missing here (TypeError), and
            # validation was called with tensors in place of the mode string.
            save_weights_to_file(save_filename,weights,weights_vel,bias,bias_vel)
            _print_validation_accuracy(val_dataset, weights, bias, rolling_acc)
            return weights,weights_vel,bias,bias_vel

        sub_batches_seen = current_step * times_to_run
        if sub_batches_seen % 37500 == 0:
            save_weights_to_file(save_filename,weights,weights_vel,bias,bias_vel)

        if sub_batches_seen % 18800 == 0:
            _print_validation_accuracy(val_dataset, weights, bias, rolling_acc)

    # Only reached when a finite dataset runs out before `steps`; return the
    # current parameters instead of an implicit None.
    return weights,weights_vel,bias,bias_vel

                                  
def train_custom_vgg(train_dataset,val_dataset,times_to_run,initial_weights,epochs,lr,save_filename,load_filename=None):
    """Train VGG16 with the custom op for `epochs` epochs.

    Starting parameters come from get_train_weights(initial_weights,
    load_filename). Fixed hyper-parameters: momentum 0.9, reg 4e-4, out_mode 2.

    Returns:
        (weights, weights_vel, bias, bias_vel) after training.
    """
    start_weights, start_bias, start_weights_vel, start_bias_vel = get_train_weights(initial_weights,load_filename)

    # One epoch is ~1.2M images, i.e. 37500 sub-batches of 32; each training
    # step consumes `times_to_run` sub-batches.
    sub_batches_per_epoch = 37500
    steps = int((epochs*sub_batches_per_epoch)/times_to_run)

    trained = vgg16_custom_training(train_dataset,
                                    val_dataset,
                                    start_weights,
                                    start_weights_vel,
                                    start_bias,
                                    start_bias_vel,
                                    lr=lr,
                                    momentum=0.9,
                                    reg=4e-4,
                                    times_to_run=times_to_run,
                                    out_mode=2,
                                    steps=steps,
                                    save_filename=save_filename)
    weights, weights_vel, bias, bias_vel = trained

    return weights,weights_vel,bias,bias_vel

In [None]:
# Example inference run: evaluate weights saved by save_weights_to_file()
# on the validation set and print the accuracy.

weights_filename = "weights filename"  # TODO: path to the HDF5 weights file

val_dataset = get_val_tf_dataset(times_to_run=150)  # batches of 32*150 images

with tf.device('/cpu:0'):
    guesses_tensor, labels = vgg16_inference(val_dataset, "file", weights_filename)
    num_evaluated = tf.shape(guesses_tensor)[0]
    print(get_accuracy(labels, guesses_tensor, num_evaluated, 1000))

In [None]:
# Example training run: resume from `load_filename` and checkpoint to
# `save_filename`.

save_filename = "save filename"  # TODO: checkpoint destination (HDF5)
load_filename = "load filename"  # TODO: starting weights for initial_weights='file'

# Number of 32-image sub-batches the C++ op processes per call: higher
# values require more RAM but reduce I/O overhead.
times_to_run_train = 100
train_dataset = get_train_tf_dataset(times_to_run_train)
val_dataset = get_val_tf_dataset(times_to_run=20, prefetch=False)

with tf.device('/cpu:0'):
    weights, weights_vel, bias, bias_vel = train_custom_vgg(
        train_dataset,
        val_dataset,
        times_to_run_train,
        initial_weights='file',
        epochs=1,
        lr=8e-5,  # 5e-3 max
        load_filename=load_filename,
        save_filename=save_filename,
    )