# MAIN — train the `myfishnet` model on the TFRecord image shards

In [1]:
import tensorflow as tf
import numpy as np
import model_tf2.net_factory as netf
import os

In [2]:
# Sanity check: which GPU device TensorFlow reports (an empty string means none was found).
gpu_device = tf.test.gpu_device_name()
gpu_device

'/device:GPU:0'

In [3]:
# Global configuration shared by the data-loading and training cells.

# Dataset sharding: number of TFRecord shard files per split.
num_train_files = 128  # training shards
num_val_files = 64     # validation shards

# NOTE(review): despite its name, buffer_size is used as the number of records
# read per shard when sizing the eagerly loaded arrays (num_files * buffer_size) — confirm.
buffer_size = 100

# Image / label geometry.
img_size = 64       # images are img_size x img_size pixels
num_channels = 3    # RGB (JPEGs are decoded with 3 channels)
num_classes = 200

In [4]:
def get_filenames(is_training, data_dir="data/tf_records", num_files=None):
    """
    Build the list of TFRecord shard paths for one split.

    input:
        is_training: True for the training split, False for validation.
        data_dir: root directory holding the ``train/`` and ``val/`` shard folders.
        num_files: number of shards in the split; defaults to the global
            ``num_train_files`` (training) or ``num_val_files`` (validation).
    output: a list of shard paths named ``<split>-%05d-of-%05d``, e.g.
        ``data/tf_records/train/train-00000-of-00128``.
    """
    split = "train" if is_training else "val"
    if num_files is None:
        num_files = num_train_files if is_training else num_val_files
    # Derive the "-of-NNNNN" suffix from the shard count instead of hard-coding
    # 00128 / 00064, so the names stay correct if the shard counts change.
    return [
        f"{data_dir}/{split}/{split}-{i:05d}-of-{num_files:05d}"
        for i in range(num_files)
    ]

def parse_record(record):
    """
    Decode one serialized tf.train.Example read from a TFRecord file.

    input: record -- a scalar string tensor holding the serialized example.
    output: dict mapping feature name to a scalar tensor:
        'image/class/label'  -> int64
        'image/class/synset' -> string
        'image/encoded'      -> string (the encoded image bytes)
    """
    # Every feature is a fixed-length scalar; only the dtype differs.
    dtypes = {
        'image/class/label': tf.int64,
        'image/class/synset': tf.string,
        'image/encoded': tf.string,
    }
    spec = {key: tf.io.FixedLenFeature([], dtype) for key, dtype in dtypes.items()}
    return tf.io.parse_single_example(record, spec)

def preprocess_data(is_training):
    """
    Eagerly load one split into memory as an (X, y) pair ready to feed into models.

    input: bool is_training -- True for the training shards, False for validation.
    output: tuple (X, y)
        X: float32 tensor of shape (n, num_channels, img_size, img_size) with pixel
           values scaled to [0, 1]. This is channels-first, the same layout that
           format_image produces for the tf.data pipeline.
        y: int64 tensor of shape (n,) holding 0-based class indices (stored
           label - 1, the same shift read_parsed applies).
        n is at most (number of shards) * buffer_size, and is smaller if the
        shards hold fewer records.
    raises: ValueError if no records could be read.
    """
    filenames = get_filenames(is_training)
    parsed_dataset = tf.data.TFRecordDataset(filenames).map(parse_record)
    # Cap on records read; presumably buffer_size is the record count per shard.
    num_samples = len(filenames) * buffer_size

    images, labels = [], []
    for parsed in parsed_dataset.take(num_samples):
        images.append(tf.io.decode_jpeg(parsed['image/encoded'], channels=num_channels))
        labels.append(parsed['image/class/label'])
    if not images:
        raise ValueError(f"no records could be read from {len(filenames)} TFRecord files")

    # decode_jpeg yields (height, width, channels) images. Transpose to
    # (n, channels, height, width). A tf.reshape to that shape would reinterpret
    # the memory layout and scramble pixels across channels.
    X = tf.transpose(tf.stack(images), [0, 3, 1, 2])
    X = tf.cast(X, tf.float32) / 255.
    # Stored labels look 1-based (read_parsed subtracts 1). Shift them to 0-based
    # so they are valid targets for sparse_categorical_crossentropy over num_classes.
    y = tf.stack(labels) - 1
    return X, y


In [5]:
def format_image(image):
    """
    Decode an encoded JPEG into a normalized, channels-first float image.

    input: image -- scalar string tensor with the JPEG bytes.
    output: float32 tensor of shape (num_channels, img_size, img_size) with
        values in [0, 1].
    """
    pixels = tf.cast(tf.io.decode_jpeg(image, channels=3), tf.float32)
    # Pin the static (height, width, channels) shape, then move channels first.
    hwc = tf.reshape(pixels, (img_size, img_size, num_channels))
    chw = tf.transpose(hwc, [2, 0, 1])
    return chw / 255.

def read_parsed(parsed):
    """
    Turn one parsed example dict (see parse_record) into a training pair.

    output: (image, label), where image comes from format_image and label is
        the stored class label minus one. Stored labels appear to be 1-based;
        the shift gives 0-based indices for sparse_categorical_crossentropy.
    """
    encoded, stored_label = parsed['image/encoded'], parsed['image/class/label']
    return format_image(encoded), stored_label - 1
    
def get_dataset(filenames, batch_size=32, shuffle_buffer=84):
    """
    Build an infinitely repeating, shuffled, batched tf.data input pipeline.

    input:
        filenames: list of TFRecord paths (e.g. from get_filenames).
        batch_size: examples per batch. Without drop_remainder, the last batch
            of each pass over the files may be smaller.
        shuffle_buffer: number of examples held in the shuffle buffer. The
            default of 84 is the value that was previously hard-coded; larger
            values shuffle more thoroughly.
    output: tf.data.Dataset of (image, label) batches, repeated forever.
        Callers must bound iteration, e.g. with steps_per_epoch in model.fit.
    """
    dataset = (
        tf.data.TFRecordDataset(filenames)
        .map(parse_record, num_parallel_calls=tf.data.AUTOTUNE)
        .map(read_parsed, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(shuffle_buffer)
        .batch(batch_size)
        .repeat()
        # Prefetch last so whole batches are prepared while the model trains.
        # Placed before batch(), it would only buffer individual examples.
        .prefetch(tf.data.AUTOTUNE)
    )
    return dataset

In [6]:
def create_model():
    """
    Build a fresh network from model_tf2.net_factory.myfishnet and compile it.

    The model is compiled with Adam, sparse categorical cross-entropy (integer
    class labels) and an accuracy metric.
    """
    net = netf.myfishnet()
    net.compile(
        optimizer="adam",
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return net

def train(model, batch_size=256, epochs=10, cpnum=0,
          steps_per_epoch=100, validation_steps=50):
    """
    Fit `model` on the training shards, validating on the validation shards.

    input:
        model: a compiled Keras model (e.g. from create_model()).
        batch_size: examples per batch for both datasets.
        epochs: number of epochs to run.
        cpnum: checkpoint run index. Weights are written to
            checkpoints/training_{cpnum}/cp.ckpt after every epoch.
        steps_per_epoch, validation_steps: batches drawn per epoch. These are
            required because get_dataset repeats forever, so they define an epoch.
    output: the same model object, trained in place.
    """
    train_ds = get_dataset(get_filenames(is_training=True), batch_size=batch_size)
    val_ds = get_dataset(get_filenames(is_training=False), batch_size=batch_size)

    # Multiply the LR by sqrt(0.1) when the monitored metric stops improving.
    # NOTE(review): with patience=5, a call with epochs <= 5 can never trigger a
    # reduction, because this callback is recreated on every train() call.
    lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
        factor=np.sqrt(0.1), cooldown=0, patience=5, min_lr=0.5e-6)

    # Save weights every epoch (same path, overwritten) so a later run can resume.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=f"checkpoints/training_{cpnum}/cp.ckpt",
        save_weights_only=True,
        verbose=1)

    model.fit(
        train_ds, validation_data=val_ds,
        epochs=epochs, steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        callbacks=[lr_reducer, cp_callback],
    )
    return model

In [7]:
# Eager execution of tf.functions is a debugging aid and typically slows Keras training.
# NOTE(review): set to False unless the model genuinely requires eager mode.
RUN_EAGERLY = True
tf.config.run_functions_eagerly(RUN_EAGERLY)

# First training run: 15 epochs from scratch, checkpointed under checkpoints/training_0/.
cpnum = 0
model = train(create_model(), batch_size=256, epochs=15, cpnum=cpnum)

  "Even though the tf.config.experimental_run_functions_eagerly "


Epoch 1/15

Epoch 00001: saving model to checkpoints/training_0/cp.ckpt
Epoch 2/15

Epoch 00002: saving model to checkpoints/training_0/cp.ckpt
Epoch 3/15

Epoch 00003: saving model to checkpoints/training_0/cp.ckpt
Epoch 4/15

Epoch 00004: saving model to checkpoints/training_0/cp.ckpt
Epoch 5/15

Epoch 00005: saving model to checkpoints/training_0/cp.ckpt
Epoch 6/15

Epoch 00006: saving model to checkpoints/training_0/cp.ckpt
Epoch 7/15

Epoch 00007: saving model to checkpoints/training_0/cp.ckpt
Epoch 8/15

Epoch 00008: saving model to checkpoints/training_0/cp.ckpt
Epoch 9/15

Epoch 00009: saving model to checkpoints/training_0/cp.ckpt
Epoch 10/15

Epoch 00010: saving model to checkpoints/training_0/cp.ckpt
Epoch 11/15

Epoch 00011: saving model to checkpoints/training_0/cp.ckpt
Epoch 12/15

Epoch 00012: saving model to checkpoints/training_0/cp.ckpt
Epoch 13/15

Epoch 00013: saving model to checkpoints/training_0/cp.ckpt
Epoch 14/15

Epoch 00014: saving model to checkpoints/traini

In [10]:
# Resume training ten times, 5 epochs each. Run k loads the weights written by
# run k-1 and saves to checkpoints/training_{k}/ (k = 1..10).
# NOTE(review): each pass rebuilds and recompiles the model, and train() creates a
# new ReduceLROnPlateau (patience=5), so the LR is never reduced within these
# 5-epoch runs — confirm this is intended.
for cpnum in range(1, 11):
    model = create_model()
    model.load_weights(f'checkpoints/training_{cpnum - 1}/cp.ckpt')
    model = train(model, batch_size=256, epochs=5, cpnum=cpnum)

Epoch 1/5

Epoch 00001: saving model to checkpoints/training_1/cp.ckpt
Epoch 2/5

Epoch 00002: saving model to checkpoints/training_1/cp.ckpt
Epoch 3/5

Epoch 00003: saving model to checkpoints/training_1/cp.ckpt
Epoch 4/5

Epoch 00004: saving model to checkpoints/training_1/cp.ckpt
Epoch 5/5

Epoch 00005: saving model to checkpoints/training_1/cp.ckpt
Epoch 1/5

Epoch 00001: saving model to checkpoints/training_2/cp.ckpt
Epoch 2/5

Epoch 00002: saving model to checkpoints/training_2/cp.ckpt
Epoch 3/5

Epoch 00003: saving model to checkpoints/training_2/cp.ckpt
Epoch 4/5

Epoch 00004: saving model to checkpoints/training_2/cp.ckpt
Epoch 5/5

Epoch 00005: saving model to checkpoints/training_2/cp.ckpt
Epoch 1/5

Epoch 00001: saving model to checkpoints/training_3/cp.ckpt
Epoch 2/5

Epoch 00002: saving model to checkpoints/training_3/cp.ckpt
Epoch 3/5

Epoch 00003: saving model to checkpoints/training_3/cp.ckpt
Epoch 4/5

Epoch 00004: saving model to checkpoints/training_3/cp.ckpt
Epoch 


Epoch 00001: saving model to checkpoints/training_9/cp.ckpt
Epoch 2/5

Epoch 00002: saving model to checkpoints/training_9/cp.ckpt
Epoch 3/5

Epoch 00003: saving model to checkpoints/training_9/cp.ckpt
Epoch 4/5

Epoch 00004: saving model to checkpoints/training_9/cp.ckpt
Epoch 5/5

Epoch 00005: saving model to checkpoints/training_9/cp.ckpt
Epoch 1/5

Epoch 00001: saving model to checkpoints/training_10/cp.ckpt
Epoch 2/5

Epoch 00002: saving model to checkpoints/training_10/cp.ckpt
Epoch 3/5

Epoch 00003: saving model to checkpoints/training_10/cp.ckpt
Epoch 4/5

Epoch 00004: saving model to checkpoints/training_10/cp.ckpt
Epoch 5/5

Epoch 00005: saving model to checkpoints/training_10/cp.ckpt
