In [None]:
import hashlib
import json
import os
import re
import time

import numpy as np
import requests
import tensorflow as tf
from absl import logging
from notifier import notify
from alive_progress import alive_bar
from keras.preprocessing.image import load_img
from tensorflow.keras.applications.vgg19 import VGG19
from keras.applications.imagenet_utils import preprocess_input
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.callbacks import (
  ReduceLROnPlateau,
  EarlyStopping,
  ModelCheckpoint,
  TensorBoard
)

# Filesystem locations used throughout the notebook.
DATASET_PATH='./data/animes'
DATASET_JSON_PATH='./data/anime_data.json'
TFRECORD_PATH='./data/anime_data.tfrecord'

# Credentials are read from the environment so they are never committed with
# the notebook. Missing variables fall back to an empty string so this cell
# still runs; set TG_ID / TG_TOKEN / WEBHOOK_URL before starting the kernel.
# NOTE(review): the token and webhook that used to be hardcoded here have been
# published with the notebook and should be revoked.
TG_ID=os.environ.get('TG_ID', '')
TG_TOKEN=os.environ.get('TG_TOKEN', '')
WEBHOOK_URL=os.environ.get('WEBHOOK_URL', '')

### Anime dataset

In [None]:
def dowload_image(url, anime_name, idx):
  """Download one image into the dataset directory unless it already exists.

  The file is stored as ``<DATASET_PATH>/<anime_name>____<idx>.jpg``; the
  ``____`` separator is what the rest of the notebook splits on to recover
  the class name.

  Args:
    url: Address of the image to fetch.
    anime_name: Sanitised anime name used as the file-name prefix.
    idx: Index of the image within its anime, used as the file-name suffix.

  Raises:
    requests.RequestException: On connection errors, timeouts or a non-2xx
      HTTP status.
  """
  file_path = os.path.join(DATASET_PATH, f'{anime_name}____{idx}.jpg')
  if os.path.exists(file_path):
    return

  # A timeout keeps a stalled server from hanging the whole download loop.
  response = requests.get(url, timeout=30)
  # Do not save an HTML error page under a .jpg name.
  response.raise_for_status()

  os.makedirs(DATASET_PATH, exist_ok=True)
  with open(file_path, 'wb') as handler:
    handler.write(response.content)

@notify(
  chat_id=TG_ID,
  api_token=TG_TOKEN,
  title='Anime images',
  msg='Finished downloading anime images'
)
def get_images(data):
  """Download the images listed for every anime in ``data``.

  Args:
    data: Mapping of anime name to a sequence of entries, each of which is
      indexed with ``'image'`` to obtain the download URL.

  Download failures are reported and skipped so that one bad URL does not
  abort the whole run. Progress is printed after each anime.
  """
  # Upper bound on how many images are fetched per anime.
  max_images_per_anime = 400
  total = len(data)

  for idx_a, (anime_name, urls) in enumerate(data.items()):
    # Replace every non-word character with '_' and collapse runs of '_', so
    # the '____' separator in the file name can only come from dowload_image.
    # Computed once per anime instead of once per image.
    name_clean = re.sub(r'_+', r'_', re.sub(r'[\W\s]', r'_', anime_name))

    for idx, url in enumerate(urls):
      if idx >= max_images_per_anime:
        break
      try:
        dowload_image(url['image'], name_clean, idx)
      except Exception as e:
        # Best effort: keep going, but say why the download failed.
        print(f'Error on download image {idx + 1} of {anime_name}: {e}')

    done = idx_a + 1
    print(f'Progress: {done}/{total} - {round(done/total*100, 2)}%')

def get_classes_anime(path):
  """Return the sorted list of class names found in a dataset directory.

  A class name is the part of a file name before the ``____`` separator.
  Entries without the separator are ignored.

  The result is sorted because the position of a name in this list is used
  as its integer class id; iterating a set of strings directly can yield a
  different order in every interpreter session, which would silently change
  the ids between runs.

  Args:
    path: Directory containing files named ``<class>____<index>.jpg``.

  Returns:
    Alphabetically sorted list of unique class names.
  """
  classes = set()
  for filename in os.listdir(path):
    class_name, separator, _ = filename.partition('____')
    if not separator:
      # Not a dataset image (no '____' in the name).
      continue
    classes.add(class_name)
  return sorted(classes)
      

In [None]:
# Read the anime -> image-entry mapping and fetch the images it lists.
# The context manager closes the file handle once the JSON has been parsed.
with open(DATASET_JSON_PATH, encoding='utf-8') as dataset_file:
  anime_data = json.load(dataset_file)
get_images(anime_data)

### TF functions

In [None]:
def get_class_id(class_name):
  """Translate a class name into its integer id.

  The id is the position of ``class_name`` in the module-level
  ``class_array`` list, which therefore has to be defined before this
  function is called. ``list.index`` raises ``ValueError`` when the name is
  not in the list.
  """
  label_id = class_array.index(class_name)
  return label_id

def build_example(path_file, class_name):
  """Build a ``tf.train.Example`` for one image file.

  The image is stored as its raw, still-encoded file bytes; decoding happens
  later in ``parse_tfrecord``.

  Args:
    path_file: Path of the image file to embed.
    class_name: Class name of the image, converted to an integer id with
      ``get_class_id``.

  Returns:
    A ``tf.train.Example`` with the features ``image`` (bytes) and
    ``class_id`` (int64).
  """
  # Close the handle deterministically; this runs once per dataset image.
  with open(path_file, 'rb') as image_file:
    img_bytes = image_file.read()

  feature = {
    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_bytes])),
    'class_id': tf.train.Feature(int64_list=tf.train.Int64List(value=[get_class_id(class_name)])),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))

def create_tfrecord(data_path):
  """Serialise every dataset image in ``data_path`` into ``TFRECORD_PATH``.

  Files are expected to be named ``<class>____<index>.jpg``; entries without
  the ``____`` separator are skipped. Each image becomes one record built by
  ``build_example``.

  Args:
    data_path: Directory containing the downloaded images.
  """
  files = os.listdir(data_path)

  print('Started creating tfrecord')
  # The context manager closes the writer even if an image fails to
  # serialise, so the records written so far are not left in an open file.
  with tf.io.TFRecordWriter(TFRECORD_PATH) as writer:
    for filename in files:
      class_name, separator, _ = filename.partition('____')
      if not separator:
        # Not a dataset image (no '____' in the name).
        continue
      path_file = os.path.join(data_path, filename)
      tf_example = build_example(path_file, class_name)
      writer.write(tf_example.SerializeToString())
  print('Finished creating tfrecord')

def parse_tfrecord(tfrecord, size):
  """Decode one serialised record into an ``(image, label)`` pair.

  Args:
    tfrecord: Serialised ``tf.train.Example`` following ``IMAGE_FEATURE_MAP``.
    size: Side length the image is resized to (output is ``size`` x ``size``).

  Returns:
    Tuple ``(x_train, y_train)``: the decoded 3-channel image resized to
    ``(size, size)`` and the class id as a scalar int64 tensor. Pixel values
    are not rescaled or normalised here.
  """
  x = tf.io.parse_single_example(tfrecord, IMAGE_FEATURE_MAP)

  # NOTE(review): decode_jpeg assumes every stored file really is a JPEG;
  # the downloader names files .jpg without checking the content - confirm.
  x_train = tf.image.decode_jpeg(x['image'], channels=3)
  x_train = tf.image.resize(x_train, (size, size))

  # 'class_id' is a FixedLenFeature with no default, so parsing either yields
  # a value or raises; there is no "missing label" case to handle here.
  y_train = tf.cast(x['class_id'], tf.int64)
  return x_train, y_train

def load_tfrecord_dataset(file_pattern, size=224):
  """Build a dataset of ``(image, label)`` pairs from TFRecord files.

  Args:
    file_pattern: Path or glob pattern of the TFRecord file(s) to read.
    size: Side length passed on to ``parse_tfrecord`` for resizing.

  Returns:
    An unbatched ``tf.data.Dataset`` whose elements are produced by
    ``parse_tfrecord``.
  """
  def _decode(serialized):
    return parse_tfrecord(serialized, size)

  record_files = tf.data.Dataset.list_files(file_pattern)
  records = record_files.flat_map(tf.data.TFRecordDataset)
  return records.map(_decode)

# Schema used to parse a serialised example: the encoded image bytes and the
# integer class id, both scalar features (mirrors what build_example writes).
IMAGE_FEATURE_MAP = dict(
  image=tf.io.FixedLenFeature([], tf.string),
  class_id=tf.io.FixedLenFeature([], tf.int64),
)

### TF dataset

In [None]:
# create_tfrecord() labels every image through get_class_id(), which reads the
# module-level class_array. Define it here so this cell also works after
# Restart Kernel -> Run All instead of relying on a later cell having run.
class_array = get_classes_anime(DATASET_PATH)
create_tfrecord(DATASET_PATH)

In [None]:
# Unbatched dataset of (image, class_id) pairs read back from the TFRecord file.
tf_record = load_tfrecord_dataset(file_pattern=TFRECORD_PATH)

### Test models

In [None]:
def create_model(num_classes, input_shape, type_extractor = 'vgg') -> tf.keras.Model:
  """Build an image classifier on top of an ImageNet-pretrained backbone.

  Args:
    num_classes: Number of output classes (size of the softmax layer).
    input_shape: Shape of one input image, e.g. ``(224, 224, 3)``.
    type_extractor: Backbone to use: ``'vgg'`` (VGG19), ``'inception'``
      (InceptionV3) or ``'resnet'`` (ResNet50V2).

  Returns:
    An uncompiled ``tf.keras.Sequential`` model: backbone without its top,
    flatten, two 4096-unit ReLU layers each followed by 50% dropout, and a
    softmax output layer.

  Raises:
    ValueError: If ``type_extractor`` is not one of the supported names.
  """
  extractors = {
    'vgg': VGG19,
    'inception': InceptionV3,
    'resnet': ResNet50V2,
  }
  if type_extractor not in extractors:
    raise ValueError('type_extractor must be vgg, inception or resnet')

  # input_shape is passed to every backbone. Without it the spatial output
  # dimensions are undefined, and the Flatten + Dense head below cannot be
  # built for the inception and resnet variants.
  feature_extractor = extractors[type_extractor](
    weights='imagenet', include_top=False, input_shape=input_shape
  )

  return tf.keras.Sequential([
    feature_extractor,
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(4096, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(4096, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_classes, activation=tf.nn.softmax),
  ])

def train(model, tf_dataset, epochs=50, mode=None, type_model='vgg'):
  """Train ``model`` on ``tf_dataset`` with a custom loop or ``model.fit``.

  Args:
    model: Keras model to train. It must already be compiled for
      ``mode='fit'``.
    tf_dataset: Batched dataset yielding ``(images, labels)``.
    epochs: Number of passes over ``tf_dataset``.
    mode: ``'eager_tf'`` runs a manual ``tf.GradientTape`` loop using the
      module-level ``loss`` and ``optimizer``; ``'fit'`` uses ``model.fit``
      with learning-rate reduction, early stopping, checkpointing and
      TensorBoard logging. Any other value trains nothing.
    type_model: Name used for the checkpoint file in ``'fit'`` mode.
  """
  if mode == 'eager_tf':
    for epoch in range(1, epochs + 1):
      # Fresh running mean per epoch so the reported average is per-epoch.
      avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)

      # Iterate the dataset that was passed in, not an undefined global.
      for batch, (images, labels) in enumerate(tf_dataset):
        with tf.GradientTape() as tape:
          outputs = model(images, training=True)
          regularization_loss = tf.reduce_sum(model.losses)
          # The model has a single output and `loss` is a single loss
          # object, so it is applied to the whole batch directly.
          pred_loss = loss(labels, outputs)
          total_loss = pred_loss + regularization_loss
        grads = tape.gradient(total_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        print("{}_train_{}, {}, {}".format(
          epoch, batch, total_loss.numpy(), pred_loss.numpy()
        ))
        avg_loss.update_state(total_loss)

      print("{}, train: {}".format(epoch, avg_loss.result().numpy()))
  elif mode == 'fit':
    # No validation data is passed to fit(), so the callbacks watch the
    # training loss; their default 'val_loss' would never be available and
    # they would never trigger.
    callbacks = [
      ReduceLROnPlateau(monitor='loss', verbose=1),
      EarlyStopping(monitor='loss', patience=4, verbose=1),
      ModelCheckpoint(f'checkpoints/{type_model}.tf',
        verbose=1, save_weights_only=True
      ),
      TensorBoard(log_dir='logs')
    ]

    start_time = time.time()
    model.fit(
      tf_dataset,
      epochs=epochs,
      callbacks=callbacks
    )
    elapsed = time.time() - start_time
    print(f'Total Training Time: {elapsed}')
  else:
    # Previously an unknown mode (including the default None) did nothing
    # without any feedback.
    print(f"Unsupported mode {mode!r}: expected 'eager_tf' or 'fit'; nothing was trained")

In [None]:
# Sanity check: print the image and label shapes of the first parsed example.
for x, y in tf_record.take(1):
  print(x.shape, y.shape)

In [None]:
# Class names; a label id is the index of its class name in this list.
class_array = get_classes_anime(DATASET_PATH)

# Labels are integer class ids, hence the sparse categorical cross-entropy.
loss = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# VGG19-based classifier with one output unit per class.
model = create_model(
  num_classes=len(class_array),
  input_shape=(224, 224, 3),
  type_extractor='vgg',
)
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

In [None]:
# Inspect the first example again, this time also printing the label tensor.
for x, y in tf_record.take(1):
  print(x.shape, y.shape)
  print(y)

In [None]:
# Show the dataset object's repr as the cell output.
tf_record

In [None]:
# Train the VGG-based model with model.fit for 10 epochs on batches of 6 images.
train(
  model=model,
  tf_dataset=tf_record.batch(6),
  epochs=10,
  mode='fit',
  type_model='vgg',
)

In [None]:
# Load the TensorBoard notebook extension; train(..., mode='fit') writes its
# TensorBoard logs to the 'logs' directory.
%load_ext tensorboard