# Final Project for IANNwTF 2022/23 


Learning to colorize grayscale dog pictures with the Stanford Dog Dataset.  

In [498]:
import tensorflow as tf
import numpy as np
import tensorboard
from PIL import Image
import os
from datetime import datetime
from skimage.color import rgb2lab, rgb2gray, lab2rgb
from skimage.io import imread, imshow
import matplotlib.pyplot as plt
import random
from tqdm import tqdm
import pickle
from keras.layers import Dense, Conv2D, Reshape, GlobalAveragePooling2D, MaxPooling2D, UpSampling2D, Flatten


In [499]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)


1 Physical GPUs, 1 Logical GPUs


In [500]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [501]:
# prepare data

# makes images same size and fills gaps at the edges with black pixels

def distortion_free_resize(image, img_size):
    w, h = img_size
    image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)
    # Check tha amount of padding needed to be done.
    pad_height = h - tf.shape(image)[0]
    pad_width = w - tf.shape(image)[1]

    # Only necessary if you want to do same amount of padding on both sides.
    if pad_height % 2 != 0:
        height = pad_height // 2
        pad_height_top = height + 1
        pad_height_bottom = height
    else:
        pad_height_top = pad_height_bottom = pad_height // 2

    if pad_width % 2 != 0:
        width = pad_width // 2
        pad_width_left = width + 1
        pad_width_right = width
    else:
        pad_width_left = pad_width_right = pad_width // 2

    image = tf.pad(
        image,
        paddings=[
            [pad_height_top, pad_height_bottom],
            [pad_width_left, pad_width_right],
            [0, 0],
        ],
    )

    #image = tf.transpose(image, perm=[1, 0, 2])
    return image


In [502]:
def prepare_datasets():
    # go through folders 
    # make pairs of images + breed
    # (not needed for grayscale but might need it later)
    # divide into test and train
    base_path = "data/Images"
    lookup_table_breeds = {}
    inverse_lookup_table = {}
    train_img = []
    train_lbl = []
    test_img = []
    test_lbl = []
    for num,folder in enumerate(os.listdir(base_path)):
        lookup_table_breeds[folder[10:]] = num
        image_paths = os.path.join(base_path, folder)
        for count, image_path in enumerate(os.listdir(image_paths)):
            path = os.path.join(image_paths, image_path)
            if 0.9 * len(list(folder)) < count:
                # makes images same size and fills gaps at the edges with black pixels
                image = distortion_free_resize(tf.image.decode_jpeg(tf.io.read_file(path),3), (128,128))
                # convert into Lab color space
                train_img.append(rgb2lab(image/255))
                train_lbl.append(lookup_table_breeds[folder[10:]])

            else:
                # makes images same size and fills gaps at the edges with black pixels
                image = distortion_free_resize(tf.image.decode_jpeg(tf.io.read_file(path),3), (128,128))
                # convert into Lab color space
                test_img.append(rgb2lab(image/255))            
                test_lbl.append(lookup_table_breeds[folder[10:]])

    inverse_lookup_table = {v: k for k, v in lookup_table_breeds.items()}
    with open('saved_lookup_table.pkl', 'wb') as f:
        pickle.dump(inverse_lookup_table, f)

    train_images = tf.data.Dataset.from_tensor_slices(train_img)
    tf.data.Dataset.save(train_images, "saved_datasets/train_images")
    print(train_images)
    train_labels = tf.data.Dataset.from_tensor_slices(train_lbl)
    tf.data.Dataset.save(train_labels, "saved_datasets/train_labels")
    print(train_labels)

    test_images = tf.data.Dataset.from_tensor_slices(test_img)
    tf.data.Dataset.save(test_images, "saved_datasets/test_images")
    print(test_images)
    test_labels = tf.data.Dataset.from_tensor_slices(test_lbl)
    tf.data.Dataset.save(test_labels, "saved_datasets/test_labels")
    print(test_labels)

    return train_images, train_labels, test_images, test_labels, inverse_lookup_table


In [503]:
def load_datasets():
    train_images = tf.data.Dataset.load("saved_datasets/train_images")
    train_labels = tf.data.Dataset.load("saved_datasets/train_labels")
    test_images = tf.data.Dataset.load("saved_datasets/test_images")
    test_labels = tf.data.Dataset.load("saved_datasets/test_labels")

    with open('saved_lookup_table.pkl', 'rb') as f:
        inverse_lookup_table = pickle.load(f)


    return train_images, train_labels, test_images, test_labels, inverse_lookup_table

datasets_stored = True

if datasets_stored:
    train_images, train_labels, test_images, test_labels, inverse_lookup_table = load_datasets()
else:
    train_images, train_labels, test_images, test_labels, inverse_lookup_table = prepare_datasets()


In [504]:
batch_size = 64

def preprocess_dataset(images, labels):
    
    # flip each image left-right with a chance of 0.5
    images = images.map(lambda x: (tf.reverse(x, axis=[-2])) if random.random() < 0.5 else (x))
    images = images.map(lambda x: (tf.reverse(x, axis=[-3])) if random.random() < 0.5 else (x))

    # divide into greyscale input and color output

    images = images.map(lambda x: ((tf.expand_dims(x[:,:,0], -1))/100, (x[:,:,1:]/128)))
    labels = labels.map(lambda x: tf.one_hot(x, 120))
    labels = labels.map(lambda x: (tf.cast(x, tf.int16)))

    
    # or zip first and then do calculations??
    zipped = tf.data.Dataset.zip((images, labels))
    
    zipped = zipped.cache().shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return zipped



train_dataset = preprocess_dataset(train_images, train_labels)
test_dataset = preprocess_dataset(test_images, test_labels)


print(train_dataset)
print(test_dataset)

# the dataset has the format
# greyscale images (64,64), a and b terms from lab color space (64,64,2), onehotted labels (120)

<PrefetchDataset element_spec=((TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 128, 128, 2), dtype=tf.float32, name=None)), TensorSpec(shape=(None, 120), dtype=tf.int16, name=None))>
<PrefetchDataset element_spec=((TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 128, 128, 2), dtype=tf.float32, name=None)), TensorSpec(shape=(None, 120), dtype=tf.int16, name=None))>


In [505]:


# or take different crops from the pictures

# show sample pictures from dataset


In [506]:
class Low_Level_Features(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(64, 3, activation='relu', padding='same', strides=1) 
        self.conv2 = Conv2D(128, 3, activation='relu', padding='same', strides=1) 
        self.conv3 = Conv2D(128, 3, activation='relu', padding='same', strides=2) 
        self.conv4 = Conv2D(256, 3, activation='relu', padding='same', strides=1) 
        self.conv5 = Conv2D(256, 3, activation='relu', padding='same', strides=2) 
        self.conv6 = Conv2D(512, 3, activation='relu', padding='same', strides=1) 

    def __call__(self, x, training=False):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)

        return x

In [507]:
class Mid_Level_Features(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(512, 3, activation='relu', padding='same', strides=1) 
        self.conv2 = Conv2D(256, 3, activation='relu', padding='same', strides=1) 

    def __call__(self, x, training=False):
        x = self.conv1(x)
        x = self.conv2(x)

        return x

In [508]:
class High_Level_Features(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(512, 3, activation='relu', padding='same', strides=2) 
        self.conv2 = Conv2D(512, 3, activation='relu', padding='same', strides=1) 
        self.conv3 = Conv2D(512, 3, activation='relu', padding='same', strides=2) 
        self.conv4 = Conv2D(512, 3, activation='relu', padding='same', strides=1) 
        self.flatten = Flatten()
        self.dense1 = Dense(1024, activation="relu")
        self.dense2 = Dense(512, activation="relu")
        self.dense3 = Dense(256, activation="relu")

    def __call__(self, x, training=False):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)

        return x

In [509]:
class Classification_Network(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = Dense(256, activation="relu")
        self.dense2 = Dense(120, activation="softmax")

    def __call__(self, x, training=False):
        x = self.dense1(x)
        x = self.dense2(x)
        return x

In [510]:
class Fusion_Layer(tf.keras.Model):
    def __init__(self):
        super().__init__()
        #32,256,256
        #

        self.repeat_layer = tf.keras.layers.RepeatVector(32*32)
        self.reshape = tf.keras.layers.Reshape(([32,32,256]))
        self.concat = tf.keras.layers.Concatenate(axis=3)
        self.conv = Conv2D(256, kernel_size=1,strides=1, activation="relu", padding="same")


    def __call__(self, mid_level, global_vector, training=False):
        x = self.repeat_layer(global_vector) 
        x = self.reshape(x)
        x = self.concat([mid_level, x]) 
        x = self.conv(x)

        return x 


In [511]:
class Colorization_Network(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(128, 3, activation='relu', padding='same', strides=1) 
        self.upsampling1 = UpSampling2D(2)
        self.conv2 = Conv2D(64, 3, activation='relu', padding='same', strides=1) 
        self.conv3 = Conv2D(64, 3, activation='relu', padding='same', strides=1)
        self.upsampling2 = UpSampling2D(2) 
        self.conv4 = Conv2D(32, 3, activation='relu', padding='same', strides=1)
        self.conv5 = Conv2D(2, 3, activation='tanh', padding='same', strides=1) 

    def __call__(self, input):
        x = self.conv1(input)
        x = self.upsampling1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.upsampling2(x)
        x = self.conv4(x)
        x = self.conv5(x)

        return x


In [512]:
class Colorization_Model(tf.keras.Model):
    def __init__(self, low_level):
        super().__init__()
        self.low_level = low_level
        self.mid_level = Mid_Level_Features()
        self.fusion = Fusion_Layer()
        self.colorization = Colorization_Network()


    def call(self, input, high_level_input, training=False):
        low = self.low_level(input)
        middle = self.mid_level(low)
        fused = self.fusion(middle, high_level_input)
        colored = self.colorization(fused)
        return colored



In [513]:
class Classification_Model(tf.keras.Model):
    def __init__(self, low_level):
        super().__init__()
        self.low_level = low_level 
        self.high_level = High_Level_Features()
        self.classification = Classification_Network()




    def call(self, input, training=False):
        low = self.low_level(input)
        high = self.high_level(low)
        label = self.classification(high)
        return high, label

 


In [514]:
class Only_Colorization_Model(tf.keras.Model):
    def __init__(self, optimizer, loss_function_color):
        super().__init__()
        self.low_level = Low_Level_Features()
        self.mid_level = Mid_Level_Features()
        self.colorization = Colorization_Network()


        self.metrics_list = [
            tf.keras.metrics.Mean(name="loss_color"),
        ]

        self.optimizer = optimizer
        self.loss_function_color = loss_function_color

        

    @property
    def metrics(self):
        return self.metrics_list

    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_state()


    def call(self, input, training=False):
        low = self.low_level(input)
        middle = self.mid_level(low, training=training)
        colored = self.colorization(middle)
        return colored


    @tf.function
    def train_step(self, data):
        images,  label = data
        grey_image, color_image = images
        with tf.GradientTape() as color_tape: 
            predicted_color = self(grey_image, training = True)
            loss_color = self.loss_function_color(color_image, predicted_color)

        gradients_color = color_tape.gradient(loss_color, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients_color, self.trainable_variables))
        self.metrics[0].update_state(loss_color)  
        return predicted_color

    @tf.function
    def test_step(self, data):
        images, label = data
        grey_image, color_image = images    
        predicted_color = self(grey_image, training = False)
        loss_color = self.loss_function_color(color_image, predicted_color)
        self.metrics[0].update_state(loss_color)  
        return predicted_color


In [515]:
class Only_Classification_Model(tf.keras.Model):
    def __init__(self, optimizer, loss_function_category):
        super().__init__()
        self.low_level = Low_Level_Features()        
        self.classification_model = Classification_Model(self.low_level)

        self.metrics_list = [
            tf.keras.metrics.Mean(name="loss_category")]

        self.optimizer = optimizer
        self.loss_function_category = loss_function_category

    @property
    def metrics(self):
        return self.metrics_list

    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_state()

    def call(self, input, training=False):
        _, label = self.classification_model(input)
        return label

    @tf.function
    def train_step(self, data):
        images,  label = data
        grey_image, color_image = images
        with tf.GradientTape() as class_tape: 
            predicted_label = self(grey_image, training = True)
            loss_category = self.loss_function_category(label, predicted_label)

        gradients_category = class_tape.gradient(loss_category, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients_category, self.trainable_variables))
        self.metrics[0].update_state(loss_category)  
        return predicted_label

    @tf.function
    def test_step(self, data):
        images, label = data
        grey_image, color_image = images    
        predicted_label = self(grey_image, training = False)
        loss_category  = self.loss_function_category(label, predicted_label)            
        self.metrics[0].update_state(loss_category)  
        return predicted_label


In [516]:
class Model(tf.keras.Model):
    def __init__(self, optimizer, loss_function_color, loss_function_category):
        super().__init__()
        self.low_level = Low_Level_Features()        

        self.colorization_model = Colorization_Model(self.low_level)
        self.classification_model = Classification_Model(self.low_level)

        self.metrics_list = [
            tf.keras.metrics.Mean(name="loss_color"),
            tf.keras.metrics.Mean(name="loss_category")]

        self.optimizer = optimizer
        self.loss_function_color = loss_function_color
        self.loss_function_category = loss_function_category

    @property
    def metrics(self):
        return self.metrics_list

    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_state()

    def call(self, input, training=False):
        high_level_info, label = self.classification_model(input)
        colored = self.colorization_model(input, high_level_info)
        return colored, label

    @tf.function
    def train_step(self, data):
        images,  label = data
        grey_image, color_image = images
        with tf.GradientTape() as color_tape, tf.GradientTape() as class_tape: 
            predicted_color, predicted_label = self(grey_image, training = True)
            loss_color = self.loss_function_color(color_image, predicted_color)
            loss_category = self.loss_function_category(label, predicted_label)

        gradients_color = color_tape.gradient(loss_color, self.colorization_model.trainable_variables)
        gradients_category = class_tape.gradient(loss_category, self.classification_model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients_color, self.colorization_model.trainable_variables))
        self.optimizer.apply_gradients(zip(gradients_category, self.classification_model.trainable_variables))
        self.metrics[0].update_state(loss_color)  
        self.metrics[1].update_state(loss_category)  
        return predicted_color, predicted_label

    @tf.function
    def test_step(self, data):
        images, label = data
        grey_image, color_image = images    
        predicted_color, predicted_label = self(grey_image, training = False)
        loss_color = self.loss_function_color(color_image, predicted_color)
        loss_category  = self.loss_function_category(label, predicted_label)            
        self.metrics[0].update_state(loss_color)  
        self.metrics[1].update_state(loss_category)  
        return predicted_color, predicted_label


In [517]:
# autoencoder from https://arxiv.org/pdf/1712.03400.pdf

# model

# create the whole autoencoder model
# (steal from https://towardsdatascience.com/image-colorization-using-convolutional-autoencoders-fdabc1cb1dbe )

#encoder
class Encoder(tf.keras.Model):
  def __init__(self):
    super().__init__()
    #input 1,128,128
    self.conv1 = Conv2D(64, 3, activation='relu', padding='same', strides=1) 
    self.conv2 = Conv2D(128, 3, activation='relu', padding='same', strides=2) 
    self.conv3 = Conv2D(128, 3, activation='relu', padding='same', strides=1) 
    self.conv4 = Conv2D(256, 3, activation='relu', padding='same', strides=2) 
    self.conv5 = Conv2D(256, 3, activation='relu', padding='same', strides=1) 
    self.conv6 = Conv2D(512, 3, activation='relu', padding='same', strides=1) 
    self.conv7 = Conv2D(512, 3, activation='relu', padding='same', strides=1) 
    self.conv8 = Conv2D(256, 3, activation='relu', padding='same', strides=1) 

    self.flatten = Flatten()



  @tf.function
  def __call__(self, x, training=False):
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = self.conv5(x)
    x = self.conv6(x)
    x = self.conv7(x)
    x = self.conv8(x)
    x = self.flatten(x)
    return x


# decoder
class Decoder(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.reshape = Reshape((32, 32, 256))
        
        self.conv1 = Conv2D(256, 3, activation="relu", padding="same", strides=1)
        self.conv2 = Conv2D(128, 3, activation="relu", padding="same")
        self.upsampling2 = UpSampling2D(2)
        self.conv3 = Conv2D(64, 3, activation="relu", padding="same")
        self.conv4 = Conv2D(64, 3, activation="tanh", padding="same")
        self.upsampling4 = UpSampling2D(2)
        self.conv5 = Conv2D(32, 3, activation="tanh", padding="same")
        self.conv5 = Conv2D(2, 3, activation="tanh", padding="same")


    @tf.function
    def __call__(self, x, training=False):
        x = self.reshape(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.upsampling2(x)        
        x = self.conv3(x)
        x = self.conv4(x)        
        x = self.upsampling4(x)
        x = self.conv5(x)
        return x

class Autoencoder(tf.keras.Model):
  def __init__(self, optimizer, loss_function):
    super().__init__()
    self.enc = Encoder()
    self.dec = Decoder()

    self.metrics_list = [
      tf.keras.metrics.Mean(name="loss")]

    self.optimizer = optimizer
    self.loss_function = loss_function

  @property
  def metrics(self):
    return self.metrics_list
  
  def get_encoder(self):
    return self.enc
   
  def get_decoder(self):
    return self.dec
    
  def reset_metrics(self):
     for metric in self.metrics:
        metric.reset_state()

  def call(self, input, training=False):
    embedding = self.enc(input)
    output = self.dec(embedding)
    return output

  @tf.function
  def train_step(self, data):
    images,  label = data
    grey_image, color_image = images
    with tf.GradientTape() as tape: 
      prediction = self(grey_image, training = True)
      loss = self.loss_function(color_image, prediction)

    gradients = tape.gradient(loss, self.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients,self.trainable_variables))
    self.metrics[0].update_state(loss)  
    return gradients

  @tf.function
  def test_step(self, data):
    images, label = data
    grey_image, color_image = images    
    prediction = self(grey_image, training = False)
    loss = self.loss_function(color_image, prediction)
    self.metrics[0].update_state(loss)
    return prediction, color_image


In [518]:
def visualize(predicted_color, predicted_label, data):
    images = data[0]
    labels = data[1]
    #print(data)
    #print(data[0])
    
    grey_images = images[0]
    grey_image = grey_images[0]
    color_images = images[1]
    color_image = color_images[0]
    label = labels[0]
    

    predicted_color_lab_scaled = tf.clip_by_value(((predicted_color + [0, 0]) * [128, 128]), -128, 127)
    color_image_lab_scaled = tf.clip_by_value(((color_image + [0, 0]) * [128, 128]), -128, 127)
    greyscale = tf.squeeze((grey_image + [0]) * [100], axis=-1)
        
    rgb_prediction = lab2rgb([greyscale, predicted_color_lab_scaled[:,:,0], predicted_color_lab_scaled[:,:,1]], channel_axis=0)
    rgb_original = lab2rgb([greyscale, color_image_lab_scaled[:,:,0], color_image_lab_scaled[:,:,1]], channel_axis=0)
    rgb_prediction = tf.transpose(rgb_prediction, perm=[1,2,0])
    rgb_original = tf.transpose(rgb_original, perm=[1,2,0])
    
    #print(predicted_color_lab_scaled)


    #labels

    predicted_label = inverse_lookup_table[tf.argmax(predicted_label).numpy()]
    true_label = inverse_lookup_table[tf.argmax(label).numpy()]

    print(predicted_label, true_label)


    fig, ax = plt.subplots(1, 2, figsize = (18, 30))
    ax[0].imshow(rgb_prediction) 
    ax[0].axis('off')
    ax[0].set_title('pred: ' + predicted_label)
    
    ax[1].imshow(rgb_original) 
    ax[1].axis('off')
    ax[1].set_title('orig: ' + true_label)
    plt.imshow(rgb_original)
    
   

In [519]:
# training loop

# log results with tensorboard 
# save model to be able to reuse it

def training_loop(model, train_ds, test_ds, epochs, train_summary_writer, test_summary_writer, save_path):
    for epoch in range(epochs):
        model.reset_metrics()

        
        for data in tqdm(train_ds, position=0, leave=True):
            predicted_color, predicted_label = model.train_step(data)


        with train_summary_writer.as_default():
            tf.summary.scalar(model.metrics[0].name, model.metrics[0].result(), step=epoch)
            tf.summary.scalar(model.metrics[1].name, model.metrics[1].result(), step=epoch)
        
        print("Epoch: ", epoch+1)
        print("Loss Color: ", model.metrics[0].result().numpy(), "(Train)")
        print("Loss Category: ", model.metrics[1].result().numpy(), "(Train)")
        model.reset_metrics()

        last_data = None
        for data in tqdm(test_ds, position=0, leave=True):
            predicted_color, predicted_label = model.test_step(data)
            last_data = data


        with test_summary_writer.as_default():
            tf.summary.scalar(model.metrics[0].name, model.metrics[0].result(), step=epoch)
            tf.summary.scalar(model.metrics[1].name, model.metrics[1].result(), step=epoch)
            
        print("Loss Color: ", model.metrics[0].result().numpy(), "(Test)")
        print("Loss Category: ", model.metrics[1].result().numpy(), "(Test)")

        if (epoch // 5) == 0:
            visualize(predicted_color[0], predicted_label[0], last_data)





In [520]:
def training_loop_colorization(model, train_ds, test_ds, epochs, train_summary_writer, test_summary_writer, save_path):
    for epoch in range(epochs):
        model.reset_metrics()

        
        for data in tqdm(train_ds, position=0, leave=True):
            predicted_color = model.train_step(data)


        with train_summary_writer.as_default():
            tf.summary.scalar(model.metrics[0].name, model.metrics[0].result(), step=epoch)
        
        print("Epoch: ", epoch+1)
        print("Loss Color: ", model.metrics[0].result().numpy(), "(Train)")
        model.reset_metrics()

        last_data = None
        for data in tqdm(test_ds, position=0, leave=True):
            predicted_color = model.test_step(data)
            last_data = data


        with test_summary_writer.as_default():
            tf.summary.scalar(model.metrics[0].name, model.metrics[0].result(), step=epoch)
            
        print("Loss Color: ", model.metrics[0].result().numpy(), "(Test)")

        if (epoch // 10) == 0:
            visualize(predicted_color[0], last_data[1][0], last_data)

In [528]:
def training_loop_classification(model, train_ds, test_ds, epochs, train_summary_writer, test_summary_writer, save_path):
    for epoch in range(epochs):
        model.reset_metrics()

        
        for data in tqdm(train_ds, position=0, leave=True):
            predicted_label = model.train_step(data)


        with train_summary_writer.as_default():
            tf.summary.scalar(model.metrics[0].name, model.metrics[0].result(), step=epoch)
        
        print("Epoch: ", epoch+1)
        print("Loss Category: ", model.metrics[0].result().numpy(), "(Train)")
        model.reset_metrics()

        last_data = None
        for data in tqdm(test_ds, position=0, leave=True):
            predicted_label = model.test_step(data)
            last_data = data

        with test_summary_writer.as_default():
            tf.summary.scalar(model.metrics[0].name, model.metrics[0].result(), step=epoch)
            
        print("Loss Category: ", model.metrics[0].result().numpy(), "(Test)")

        print(inverse_lookup_table[tf.argmax(predicted_label[0]).numpy()], 
              inverse_lookup_table[tf.argmax(last_data[1][0]).numpy()])


In [529]:
# train

epochs = 100
optimizer = tf.keras.optimizers.Adadelta()
loss_function_color = tf.keras.losses.MeanSquaredError()
loss_function_category = tf.keras.losses.CategoricalCrossentropy()
#autoencoder = Autoencoder(optimizer=optimizer, loss_function=loss_function)

model= Model(optimizer=optimizer, loss_function_color=loss_function_color, loss_function_category=loss_function_category)
only_colorization = Only_Colorization_Model(optimizer=optimizer, loss_function_color=loss_function_color)
only_classification = Only_Classification_Model(optimizer=optimizer, loss_function_category=loss_function_category)

current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
save_path = f"models/{current_time}"
train_log_path = f"logs/{current_time}/train"
test_log_path = f"logs/{current_time}/test"
train_summary_writer = tf.summary.create_file_writer(train_log_path)
test_summary_writer = tf.summary.create_file_writer(test_log_path)
#training_loop_colorization(only_colorization, train_dataset, test_dataset, epochs, train_summary_writer, test_summary_writer, save_path)
training_loop_classification(only_classification, train_dataset, test_dataset, epochs, train_summary_writer, test_summary_writer, save_path)
#training_loop(model, train_dataset, test_dataset, epochs, train_summary_writer, test_summary_writer, save_path)

100%|██████████| 282/282 [00:35<00:00,  7.94it/s]


Epoch:  1
Loss Category:  4.787528 (Train)


100%|██████████| 41/41 [00:01<00:00, 26.70it/s]


Loss Category:  4.7875037 (Test)
Airedale miniature_poodle


100%|██████████| 282/282 [00:35<00:00,  8.04it/s]


Epoch:  2
Loss Category:  4.78749 (Train)


100%|██████████| 41/41 [00:01<00:00, 27.97it/s]


Loss Category:  4.7875104 (Test)
Airedale bull_mastiff


100%|██████████| 282/282 [00:35<00:00,  8.05it/s]


Epoch:  3
Loss Category:  4.7874703 (Train)


100%|██████████| 41/41 [00:01<00:00, 27.77it/s]


Loss Category:  4.7875075 (Test)
Pembroke bluetick


100%|██████████| 282/282 [00:35<00:00,  8.03it/s]


Epoch:  4
Loss Category:  4.787457 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.08it/s]


Loss Category:  4.787512 (Test)
Airedale Irish_water_spaniel


100%|██████████| 282/282 [00:35<00:00,  8.05it/s]


Epoch:  5
Loss Category:  4.787447 (Train)


100%|██████████| 41/41 [00:01<00:00, 27.71it/s]


Loss Category:  4.7875104 (Test)
Saluki West_Highland_white_terrier


100%|██████████| 282/282 [00:35<00:00,  8.01it/s]


Epoch:  6
Loss Category:  4.7874336 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.08it/s]


Loss Category:  4.7875147 (Test)
Saluki flat-coated_retriever


100%|██████████| 282/282 [00:35<00:00,  8.02it/s]


Epoch:  7
Loss Category:  4.7874246 (Train)


100%|██████████| 41/41 [00:01<00:00, 27.97it/s]


Loss Category:  4.7875166 (Test)
Samoyed African_hunting_dog


100%|██████████| 282/282 [00:34<00:00,  8.09it/s]


Epoch:  8
Loss Category:  4.787413 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.99it/s]


Loss Category:  4.7875204 (Test)
Sealyham_terrier French_bulldog


100%|██████████| 282/282 [00:34<00:00,  8.11it/s]


Epoch:  9
Loss Category:  4.787401 (Train)


100%|██████████| 41/41 [00:02<00:00, 16.12it/s]


Loss Category:  4.7875237 (Test)
Saluki German_shepherd


100%|██████████| 282/282 [00:34<00:00,  8.06it/s]


Epoch:  10
Loss Category:  4.7873907 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.11it/s]


Loss Category:  4.787525 (Test)
Samoyed malamute


100%|██████████| 282/282 [00:34<00:00,  8.16it/s]


Epoch:  11
Loss Category:  4.7873807 (Train)


100%|██████████| 41/41 [00:01<00:00, 29.38it/s]


Loss Category:  4.7875204 (Test)
Airedale Irish_water_spaniel


100%|██████████| 282/282 [00:34<00:00,  8.07it/s]


Epoch:  12
Loss Category:  4.787372 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.22it/s]


Loss Category:  4.787539 (Test)
Maltese_dog wire-haired_fox_terrier


100%|██████████| 282/282 [00:34<00:00,  8.08it/s]


Epoch:  13
Loss Category:  4.787362 (Train)


100%|██████████| 41/41 [00:01<00:00, 27.84it/s]


Loss Category:  4.7875323 (Test)
Maltese_dog kelpie


100%|██████████| 282/282 [00:34<00:00,  8.14it/s]


Epoch:  14
Loss Category:  4.7873507 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.33it/s]


Loss Category:  4.7875423 (Test)
Maltese_dog Shetland_sheepdog


100%|██████████| 282/282 [00:34<00:00,  8.09it/s]


Epoch:  15
Loss Category:  4.787341 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.30it/s]


Loss Category:  4.7875376 (Test)
Maltese_dog Lakeland_terrier


100%|██████████| 282/282 [00:34<00:00,  8.10it/s]


Epoch:  16
Loss Category:  4.7873297 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.15it/s]


Loss Category:  4.787539 (Test)
Maltese_dog wire-haired_fox_terrier


100%|██████████| 282/282 [00:35<00:00,  8.05it/s]


Epoch:  17
Loss Category:  4.7873216 (Train)


100%|██████████| 41/41 [00:01<00:00, 27.94it/s]


Loss Category:  4.787546 (Test)
Maltese_dog schipperke


100%|██████████| 282/282 [00:35<00:00,  8.04it/s]


Epoch:  18
Loss Category:  4.787313 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.34it/s]


Loss Category:  4.7875476 (Test)
Maltese_dog Italian_greyhound


100%|██████████| 282/282 [00:35<00:00,  8.00it/s]


Epoch:  19
Loss Category:  4.7872996 (Train)


100%|██████████| 41/41 [00:01<00:00, 29.40it/s]


Loss Category:  4.787559 (Test)
Maltese_dog curly-coated_retriever


100%|██████████| 282/282 [00:35<00:00,  8.04it/s]


Epoch:  20
Loss Category:  4.7872906 (Train)


100%|██████████| 41/41 [00:01<00:00, 29.27it/s]


Loss Category:  4.787553 (Test)
Maltese_dog Border_terrier


100%|██████████| 282/282 [00:35<00:00,  8.05it/s]


Epoch:  21
Loss Category:  4.7872806 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.50it/s]


Loss Category:  4.7875524 (Test)
Maltese_dog Irish_wolfhound


100%|██████████| 282/282 [00:35<00:00,  8.00it/s]


Epoch:  22
Loss Category:  4.7872677 (Train)


100%|██████████| 41/41 [00:01<00:00, 29.12it/s]


Loss Category:  4.787554 (Test)
Maltese_dog collie


100%|██████████| 282/282 [00:34<00:00,  8.19it/s]


Epoch:  23
Loss Category:  4.78726 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.99it/s]


Loss Category:  4.787563 (Test)
Maltese_dog malamute


100%|██████████| 282/282 [00:34<00:00,  8.17it/s]


Epoch:  24
Loss Category:  4.78725 (Train)


100%|██████████| 41/41 [00:01<00:00, 29.69it/s]


Loss Category:  4.7875557 (Test)
Maltese_dog briard


100%|██████████| 282/282 [00:34<00:00,  8.06it/s]


Epoch:  25
Loss Category:  4.7872415 (Train)


100%|██████████| 41/41 [00:01<00:00, 29.17it/s]


Loss Category:  4.7875714 (Test)
Maltese_dog African_hunting_dog


100%|██████████| 282/282 [00:35<00:00,  8.03it/s]


Epoch:  26
Loss Category:  4.787234 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.07it/s]


Loss Category:  4.7875614 (Test)
Maltese_dog giant_schnauzer


100%|██████████| 282/282 [00:34<00:00,  8.08it/s]


Epoch:  27
Loss Category:  4.78722 (Train)


100%|██████████| 41/41 [00:01<00:00, 27.78it/s]


Loss Category:  4.7875843 (Test)
Maltese_dog standard_poodle


100%|██████████| 282/282 [00:35<00:00,  8.05it/s]


Epoch:  28
Loss Category:  4.7872157 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.05it/s]


Loss Category:  4.7875767 (Test)
Maltese_dog Old_English_sheepdog


100%|██████████| 282/282 [00:34<00:00,  8.07it/s]


Epoch:  29
Loss Category:  4.787201 (Train)


100%|██████████| 41/41 [00:01<00:00, 27.94it/s]


Loss Category:  4.7875776 (Test)
Maltese_dog Tibetan_terrier


100%|██████████| 282/282 [00:34<00:00,  8.09it/s]


Epoch:  30
Loss Category:  4.787194 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.12it/s]


Loss Category:  4.7875752 (Test)
Maltese_dog Bernese_mountain_dog


100%|██████████| 282/282 [00:35<00:00,  8.04it/s]


Epoch:  31
Loss Category:  4.7871814 (Train)


100%|██████████| 41/41 [00:01<00:00, 28.35it/s]


Loss Category:  4.787582 (Test)
Maltese_dog Blenheim_spaniel


100%|██████████| 282/282 [00:35<00:00,  8.03it/s]


Epoch:  32
Loss Category:  4.78717 (Train)


100%|██████████| 41/41 [00:01<00:00, 27.96it/s]


Loss Category:  4.7875943 (Test)
Maltese_dog Kerry_blue_terrier


100%|██████████| 282/282 [00:34<00:00,  8.06it/s]


Epoch:  33
Loss Category:  4.7871647 (Train)


100%|██████████| 41/41 [00:01<00:00, 29.17it/s]


Loss Category:  4.7875886 (Test)
Maltese_dog otterhound


100%|██████████| 282/282 [00:35<00:00,  7.99it/s]


Epoch:  34
Loss Category:  4.787154 (Train)


100%|██████████| 41/41 [00:01<00:00, 27.94it/s]


Loss Category:  4.7875953 (Test)
Maltese_dog Irish_wolfhound


  5%|▍         | 13/282 [00:01<00:34,  7.80it/s]


KeyboardInterrupt: 

In [531]:
%tensorboard --logdir logs