In [1]:
import re
import time
import json
import hashlib
import requests
import numpy as np
import tensorflow as tf
from absl import logging
from notifier import notify
from alive_progress import alive_bar
from keras.preprocessing.image import load_img
from tensorflow.keras.applications.vgg19 import VGG19
from keras.applications.imagenet_utils import preprocess_input
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.callbacks import (
  ReduceLROnPlateau,
  EarlyStopping,
  ModelCheckpoint,
  TensorBoard
)

import os
import pandas as pd
import random as rd
import pickle as pkl
from selenium.webdriver import Chrome
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC

DATASET_PATH='./data/animes'
DATASET_JSON_PATH='./data/anime_data.json'
DATASET_JSON_RANK='./data/anime_rank.json'
TFRECORD_PATH='./data/anime_data.tfrecord'

# Notification credentials come from the environment instead of being committed
# with the notebook. The bot token and webhook URL that used to be hard-coded here
# were exposed in the notebook file and should be revoked/rotated.
# NOTE(review): when the variables are unset these are empty strings — confirm
# how `notifier.notify` behaves with an empty token before relying on it.
TG_ID = os.environ.get('TG_CHAT_ID', '293701727')
TG_TOKEN = os.environ.get('TG_TOKEN', '')
WEBHOOK_URL = os.environ.get('DISCORD_WEBHOOK_URL', '')

# Seed every RNG the notebook uses (stdlib random, NumPy, TensorFlow) for reproducibility.
SEED = 42
rd.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

### Anime dataset

In [2]:
def dowload_image(url, anime_name, idx, timeout=30):
  """Download one image to ``./data/animes/{anime_name}____{idx}.jpg``.

  Does nothing if the target file already exists, so re-running the download
  resumes instead of re-fetching everything.

  Args:
    url: image URL.
    anime_name: sanitized class name; becomes the filename prefix.
    idx: per-class image index; becomes the filename suffix.
    timeout: seconds to wait for the server (previously unbounded, so one
      stalled connection could hang the whole download loop).

  Raises:
    requests.HTTPError: on a non-2xx response.
    requests.RequestException: on connection errors or timeouts.
  """
  file_path = f'./data/animes/{anime_name}____{idx}.jpg'
  if os.path.exists(file_path):
    return

  response = requests.get(url, timeout=timeout)
  # Fail before touching the disk: previously an error page (404/500 HTML) was
  # written out as a .jpg, which later breaks tf.image.decode_jpeg and is never
  # retried because the file now "exists".
  response.raise_for_status()
  with open(file_path, 'wb') as handler:
    handler.write(response.content)

@notify(
  chat_id=TG_ID,
  api_token=TG_TOKEN,
  title='Anime images',
  msg='Finished downloading anime images'
)
def get_images(data, max_images=400):
  """Download up to ``max_images`` images for every anime in ``data``.

  Args:
    data: mapping of anime name -> list of post dicts, each holding an
      ``'image'`` URL.
    max_images: per-anime cap (400 was the original hard-coded limit).

  A failed download is reported and skipped so one bad URL does not abort the run.
  """
  total = len(data)
  for idx_a, (anime_name, urls) in enumerate(data.items()):
    # The sanitized name depends only on the anime, so compute it once per class
    # instead of once per URL.
    name_clean = re.sub(r'_+', r'_', re.sub(r'[\W\s]', r'_', anime_name))
    for idx, url in enumerate(urls[:max_images]):
      try:
        dowload_image(url['image'], name_clean, idx)
      except Exception as e:
        # Best-effort: keep going, but say why the download failed.
        print(f'Error on download image {idx + 1} of {anime_name}: {e}')
    print(f'Progress: {idx_a + 1}/{total} - {round((idx_a + 1) / total * 100, 2)}%')

def get_classes_anime(path):
  """Return the sorted list of class names found in the image directory ``path``.

  Image files are named ``{class_name}____{idx}.jpg``; the class is the part
  before the first ``____`` separator. Files without the separator (e.g.
  ``.DS_Store``) are ignored instead of raising ``ValueError``.

  The result is sorted because a class's position in this list becomes its
  integer label downstream. Iterating a ``set`` of strings gives a different
  order in every kernel session (per-process hash randomization).
  """
  classes = set()
  for filename in os.listdir(path):
    class_name, sep, _ = filename.partition('____')
    if sep:
      classes.add(class_name)
  return sorted(classes)

def wait_for_it(driver, xpath, timeout=3):
  """Wait up to ``timeout`` seconds for an element matching ``xpath``.

  Returns the located element, or ``None`` when it does not show up in time
  or the lookup fails for any other reason (every exception is swallowed).
  """
  try:
    waiter = WebDriverWait(driver, timeout)
    element = waiter.until(EC.presence_of_element_located((By.XPATH, xpath)))
  except Exception:
    element = None
  return element

def iter_post(driver):
  """Scrape post links from the current result page and every following page.

  Each ``<li>`` in ``ul#post-list-posts`` yields
  ``{'video': <href of ./a>, 'image': <src of ./div/a/img>}``. Pages are
  followed through the ``next_page`` link until there is none, or until more
  than 400 posts have been collected. The limit is checked once per page, so
  the result can exceed 400.
  """
  posts = []
  next_xpath = '//a[@class="next_page"]'
  next_link = True  # sentinel: anything "not None" so the first page is visited

  while next_link is not None and len(posts) <= 400:
    post_list = wait_for_it(driver, '//ul[@id="post-list-posts"]')
    found_list = post_list is not None
    if found_list:
      for item in post_list.find_elements(By.TAG_NAME, 'li'):
        posts.append({
          'video': item.find_element(By.XPATH, './a').get_attribute('href'),
          'image': item.find_element(By.XPATH, './div/a/img').get_attribute('src'),
        })
    # Advance to the next page. A page without a post list waits a fixed 1s;
    # a populated one waits a random 1-2s.
    next_link = wait_for_it(driver, next_xpath)
    if next_link is not None:
      next_link.click()
      time.sleep(rd.randint(1, 2) if found_list else 1)
  return posts

def scrape_anime_posts(url, driver, anime_name):
  """Open the search page ``url + anime_name`` and scrape its posts.

  Returns the list of ``{'video': ..., 'image': ...}`` link dicts collected by
  iter_post.
  """
  # Renamed from `get_images`. This was a second definition of `get_images` in the
  # same cell, replacing the downloader `get_images(data)` defined earlier.
  # As a result, the later `get_images(anime_data)` call raised TypeError
  # (missing `driver` and `anime_name`).
  driver.get(url + anime_name)
  return iter_post(driver)

def get_names(driver, min_posts=10):
  """Collect names from a paginated ``table.highlightable``, page by page.

  Scrapes the page currently loaded in ``driver``, then follows the
  ``next_page`` link until there is none.

  Args:
    driver: Selenium WebDriver already showing the first page of the table.
    min_posts: keep a row only if its first column (parsed as int) is at least this.

  Returns:
    Text of the second link in column 2 of every qualifying row, in page order.
  """
  names = []
  xpath_next = '//a[@class="next_page"]'
  xpath_rows = '//table[@class="highlightable"]/tbody/tr'

  while True:
    # Give the table time to render before reading it.
    wait_for_it(driver, xpath_rows)
    for tr_element in driver.find_elements(By.XPATH, xpath_rows):
      try:
        amount_post = int(tr_element.find_element(By.XPATH, './td[1]').text)
        if amount_post >= min_posts:
          names.append(tr_element.find_element(By.XPATH, './td[2]/a[2]').text)
      except Exception as e:
        # Best-effort: rows without the expected cells or with a non-numeric
        # count are reported and skipped.
        print(e)

    # Scrape first, then advance. The original loop only ran while a next link
    # existed, so the final page (or a single-page table) was never read.
    next_button = wait_for_it(driver, xpath_next)
    if next_button is None:
      break
    next_button.click()
    time.sleep(rd.randint(1, 2))
  return names

def get_score(anime_name, driver):
  """Placeholder for looking up ``anime_name``'s score on MyAnimeList (unfinished).

  The previous body loaded the MAL search page and then iterated over
  ``os.listdir(path)`` with ``path`` undefined. Every call therefore raised
  ``NameError`` after a wasted page load. Even with ``path`` defined, it would
  have returned a file count unrelated to ``anime_name``.

  Raises:
    NotImplementedError: always, before touching ``driver``.
  """
  # TODO: load f'https://myanimelist.net/anime.php?cat=anime&q={anime_name}'
  # (URL-encode the name) and parse the score once the page selectors are known.
  raise NotImplementedError(f'get_score({anime_name!r}) is not implemented yet')

def relevant_anime(anime_name, df_anime, max_words=3):
  """Return True if ``anime_name`` occurs as a substring of any ``df_anime['name']``.

  ``anime_name`` is a sanitized class name, with underscores standing in for
  spaces. It is first matched whole. If nothing matches, only its first
  ``max_words`` words are tried.

  Args:
    anime_name: class name, e.g. ``'Sword_Art_Online'``.
    df_anime: DataFrame with a ``'name'`` column; missing (NaN/None) names never match.
    max_words: number of leading words used for the fallback match.
  """
  query = anime_name.replace('_', ' ')
  names = df_anime['name']

  # regex=False: the query is a literal title, so characters like '(' or '+'
  # must not be interpreted as regex syntax.
  # na=False: missing names would otherwise produce a non-boolean mask, and
  # indexing with it raises ValueError.
  if names.str.contains(query, regex=False, na=False).any():
    return True
  short_query = ' '.join(query.split(' ')[:max_words])
  return bool(names.str.contains(short_query, regex=False, na=False).any())

#### Get anime images

In [None]:
# Load the scraped {anime name: [post dicts]} mapping. The `with` block closes the
# file; JSON is UTF-8 by spec, so don't depend on the platform's default encoding.
with open(DATASET_JSON_PATH, encoding='utf-8') as json_file:
  anime_data = json.load(json_file)
get_images(anime_data)

#### Filter anime classes

In [3]:
df = pd.read_pickle('./data/df_anime_rank.pkl')
all_class_array = get_classes_anime(DATASET_PATH)

# Keep only classes whose name matches an entry in the ranking table.
# Sorted on purpose: a class's index in class_array is the integer label written
# to the TFRecord (via get_class_id). Iterating a set of strings gives a
# different order in every kernel session, which would silently scramble the labels.
class_array = sorted({name for name in all_class_array if relevant_anime(name, df)})
print(f'All classes: {len(all_class_array)} - Filtered {len(class_array)}')
del all_class_array

All classes: 4401 - Filtered 290


### TFRecord helpers

In [4]:
def get_class_id(class_name, classes=None):
  """Return the integer label of ``class_name``: its index in ``classes``.

  Args:
    class_name: class name to look up.
    classes: ordered list of class names. Defaults to the global ``class_array``,
      which is looked up at call time, so it reflects the latest filter cell.

  Raises:
    ValueError: if ``class_name`` is not in the list.
  """
  if classes is None:
    classes = class_array
  return classes.index(class_name)

def build_example(path_file, class_name):
  """Wrap one image file and its label as a ``tf.train.Example``.

  Features:
    image: the file's raw, still-encoded bytes (decoded later in parse_tfrecord).
    class_id: int64 label from ``get_class_id(class_name)``.
  """
  # Close the file right away. The original `open(...).read()` left one handle
  # per image open until garbage collection, which adds up over a large directory.
  with open(path_file, 'rb') as image_file:
    img_bytes = image_file.read()
  feature = {
    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_bytes])),
    'class_id': tf.train.Feature(int64_list=tf.train.Int64List(value=[get_class_id(class_name)])),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature))

def create_tfrecord(data_path, withe_list, output_path=None, seed=42):
  """Write every image in ``data_path`` whose class is allowed to one TFRecord file.

  Args:
    data_path: directory of ``{class_name}____{idx}.jpg`` files.
    withe_list: allowed class names (a white list; parameter spelling kept for callers).
    output_path: destination file; defaults to ``TFRECORD_PATH``.
    seed: seed for the record-order shuffle.

  Records are written in a seeded random order. ``os.listdir`` order is
  filesystem-dependent and, when name-sorted, groups files by class. The
  training cell splits the dataset positionally with ``take``/``skip``, so
  class-grouped records would give train and test subsets covering mostly
  different classes.
  """
  if output_path is None:
    output_path = TFRECORD_PATH
  allowed = set(withe_list)  # O(1) membership instead of scanning a list per file

  # Sort before shuffling so the result does not depend on listdir order, and use
  # a private Random so the notebook's global random state is left untouched.
  files = sorted(os.listdir(data_path))
  rd.Random(seed).shuffle(files)

  print('Started creating tfrecord')
  # The context manager flushes and closes the writer even if an image fails.
  with tf.io.TFRecordWriter(output_path) as writer:
    for filename in files:
      class_name, sep, _ = filename.partition('____')
      # Skip stray files without the separator instead of crashing on unpacking.
      if sep and class_name in allowed:
        path_file = os.path.join(data_path, filename)
        writer.write(build_example(path_file, class_name).SerializeToString())
    print('Finished creating tfrecord')

def parse_tfrecord(tfrecord, size):
  """Decode one serialized Example into an ``(image, class_id)`` pair.

  Args:
    tfrecord: scalar string tensor holding a serialized ``tf.train.Example``
      with the features in IMAGE_FEATURE_MAP.
    size: output height and width; the image is resized to ``(size, size)``.

  Returns:
    image: 3-channel tensor of shape (size, size, 3) as produced by
      tf.image.resize; no normalization is applied.
    class_id: scalar int64 label.
  """
  parsed = tf.io.parse_single_example(tfrecord, IMAGE_FEATURE_MAP)
  image = tf.image.decode_jpeg(parsed['image'], channels=3)
  image = tf.image.resize(image, (size, size))
  # NOTE(review): pixels are fed to the ImageNet backbones without their matching
  # preprocess_input scaling — confirm this is intended.

  # 'class_id' is declared as a FixedLenFeature of int64 without a default, so a
  # missing label fails parsing instead of producing None. The old
  # `if class_id is None` fallback and the int64->int64 cast were dead code.
  class_id = parsed['class_id']
  return image, class_id

def load_tfrecord_dataset(file_pattern, size=224):
  """Build a dataset of ``(image, class_id)`` pairs from the TFRecord file(s).

  Args:
    file_pattern: path or glob matching one or more TFRecord files.
    size: square side length passed to parse_tfrecord.
  """
  # shuffle=False: list_files shuffles file names by default and reshuffles on
  # every iteration. With several shards, the positional take()/skip()
  # train/test split would see a different file order on each pass and leak
  # examples between the splits.
  files = tf.data.Dataset.list_files(file_pattern, shuffle=False)
  dataset = files.flat_map(tf.data.TFRecordDataset)
  # Decode/resize in parallel; map keeps element order deterministic by default.
  return dataset.map(lambda x: parse_tfrecord(x, size),
                     num_parallel_calls=tf.data.AUTOTUNE)

# Parsing schema for each serialized Example; mirrors the two features written by build_example.
IMAGE_FEATURE_MAP = dict(
  image=tf.io.FixedLenFeature([], tf.string),    # encoded image bytes
  class_id=tf.io.FixedLenFeature([], tf.int64),  # integer class label
)

# Build the TFRecord only if it is missing, so Restart & Run All reuses the cached
# file instead of depending on a hand-edited `if False:`.
# Set REBUILD_TFRECORD = True (or delete TFRECORD_PATH) after class_array changes.
REBUILD_TFRECORD = False
if REBUILD_TFRECORD or not os.path.exists(TFRECORD_PATH):
  create_tfrecord(DATASET_PATH, class_array)

### Train and evaluate models

In [16]:
def create_model(num_classes, input_shape, type_extractor = 'vgg',
                 hidden_units=512, num_hidden=3, dropout=0.5,
                 freeze_extractor=False) -> tf.keras.Model:
  """Build an image classifier: ImageNet backbone + Flatten + dense head.

  Args:
    num_classes: size of the softmax output layer.
    input_shape: (height, width, channels) of the input images.
    type_extractor: backbone name, one of 'vgg' (VGG19), 'inception'
      (InceptionV3) or 'resnet' (ResNet50V2).
    hidden_units: width of each hidden Dense layer.
    num_hidden: number of Dense(relu) + Dropout blocks in the head.
    dropout: dropout rate after each hidden Dense layer.
    freeze_extractor: if True, the backbone weights are not trained.

  Returns:
    An uncompiled tf.keras.Sequential model.

  Raises:
    ValueError: for an unknown ``type_extractor``.
  """
  extractors = {
    'vgg': VGG19,
    'inception': InceptionV3,
    'resnet': ResNet50V2,
  }
  if type_extractor not in extractors:
    raise ValueError('type_extractor must be vgg, inception or resnet')
  # input_shape used to be passed only to VGG19. Without it, the Inception/ResNet
  # backbones have undefined spatial dimensions, Flatten produces an undefined
  # feature size, and the first Dense layer cannot be built.
  feature_extractor = extractors[type_extractor](
    weights='imagenet', include_top=False, input_shape=input_shape)
  feature_extractor.trainable = not freeze_extractor

  model = tf.keras.Sequential()
  model.add(feature_extractor)
  model.add(tf.keras.layers.Flatten())
  for _ in range(num_hidden):
    model.add(tf.keras.layers.Dense(hidden_units, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dropout(dropout))
  model.add(tf.keras.layers.Dense(num_classes, activation=tf.nn.softmax))
  return model

def _train_eager(model, dataset, epochs):
  """Custom GradientTape training loop using the loss/optimizer set by model.compile()."""
  loss_fn = tf.keras.losses.get(model.loss)
  optimizer = model.optimizer
  avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)

  for epoch in range(1, epochs + 1):
    avg_loss.reset_state()
    for batch, (images, labels) in enumerate(dataset):
      with tf.GradientTape() as tape:
        outputs = model(images, training=True)
        regularization_loss = tf.reduce_sum(model.losses)
        pred_loss = loss_fn(labels, outputs)
        total_loss = pred_loss + regularization_loss
      grads = tape.gradient(total_loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      print(f'{epoch}_train_{batch}, {total_loss.numpy()}, {np.sum(pred_loss.numpy())}')
      avg_loss.update_state(total_loss)
    print(f'Epoch {epoch}: mean loss {avg_loss.result().numpy()}')


def _train_fit(model, dataset, epochs, type_model, logdir):
  """Keras fit() with per-epoch weight checkpoints and TensorBoard logging; prints wall time."""
  callbacks = [
    ModelCheckpoint(f'checkpoints/{type_model}.tf', verbose=1, save_weights_only=True),
    TensorBoard(log_dir=logdir, histogram_freq=1),
  ]
  start_time = time.time()
  model.fit(dataset, epochs=epochs, callbacks=callbacks)
  end_time = time.time() - start_time
  print(f'Total Training Time: {end_time}')


def train(model, tf_dataset, epochs=50, mode=None, type_model='vgg',
          logdir='logs/scalars/test_replicated_seed_4'):
  """Train a compiled ``model`` on ``tf_dataset``.

  Args:
    model: tf.keras model already compiled (compile() supplies loss and optimizer).
    tf_dataset: batched dataset of (images, labels).
    epochs: number of passes over ``tf_dataset``.
    mode: 'fit' for model.fit with checkpoint/TensorBoard callbacks, or
      'eager_tf' for a hand-written GradientTape loop.
    type_model: checkpoint file stem, saved as ``checkpoints/{type_model}.tf``.
    logdir: TensorBoard log directory used in 'fit' mode.

  Raises:
    ValueError: for any other ``mode``. Previously the default ``mode=None``
      silently trained nothing.
  """
  # The old eager branch iterated the global `train_dataset` instead of
  # `tf_dataset`, and zipped over a single loss object (TypeError). It now uses
  # the argument and the compiled loss/optimizer.
  if mode == 'eager_tf':
    _train_eager(model, tf_dataset, epochs)
  elif mode == 'fit':
    _train_fit(model, tf_dataset, epochs, type_model, logdir)
  else:
    raise ValueError(f"mode must be 'fit' or 'eager_tf', got {mode!r}")

In [17]:
tf_record = load_tfrecord_dataset(TFRECORD_PATH)

# Work on a small subset while iterating on hyperparameters.
len_mini = 2000
train_size = int(0.70 * len_mini)  # 70/30 train/test split of the subset

mini_tf_record = tf_record.take(len_mini)
# NOTE(review): take()/skip() split by record position, so both halves only cover
# the classes present in the first len_mini records. This assumes the TFRecord
# was written in shuffled order; if not, train and test see different classes.
train_dataset = mini_tf_record.take(train_size)
test_dataset = mini_tf_record.skip(train_size)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
# Earlier runs (learning rate - run - batch size - ranking):
#   0.0005  - 110312 - batch 8  - best
#   0.00075 - 111326 - batch 12 - 2nd best

loss = tf.keras.losses.SparseCategoricalCrossentropy()
model = create_model(num_classes=len(class_array), input_shape=(224, 224, 3), type_extractor='vgg')
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

In [18]:
# Fit the VGG model on the training subset in batches of 8 for 10 epochs.
train(
  model=model,
  tf_dataset=train_dataset.batch(8),
  epochs=10,
  mode='fit',
  type_model='vgg',
)

Epoch 1/10
    175/Unknown - 17s 89ms/step - loss: 13.5632 - accuracy: 0.2179
Epoch 1: saving model to checkpoints\vgg.tf
Epoch 2/10
Epoch 2: saving model to checkpoints\vgg.tf
Epoch 3/10
Epoch 3: saving model to checkpoints\vgg.tf
Epoch 4/10
Epoch 4: saving model to checkpoints\vgg.tf
Epoch 5/10
Epoch 5: saving model to checkpoints\vgg.tf
Epoch 6/10
Epoch 6: saving model to checkpoints\vgg.tf
Epoch 7/10
Epoch 7: saving model to checkpoints\vgg.tf
Epoch 8/10
Epoch 8: saving model to checkpoints\vgg.tf
Epoch 9/10
Epoch 9: saving model to checkpoints\vgg.tf
Epoch 10/10
Epoch 10: saving model to checkpoints\vgg.tf
Total Training Time: 209.47185635566711


In [None]:
%load_ext tensorboard


In [None]:
# Peek at the first (image, label) pair to check tensor shapes.
for x, y in tf_record.take(1):
  print(x.shape, y.shape)

In [None]:
tf_record = load_tfrecord_dataset(TFRECORD_PATH)
# Count records on the raw TFRecord stream. len(list(tf_record)) decoded and
# resized every image and held all of them in memory at once just to count them.
num_records = sum(1 for _ in tf.data.TFRecordDataset(TFRECORD_PATH))
num_records