In [None]:
import os
import datasets
import numpy as np
import tensorflow as tf
from beyondml import tflow
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Load the saved Keras model. The custom objects supplied by beyondml are
# passed through `custom_objects` so the HDF5 file can be deserialized.
model = tf.keras.models.load_model(
    'o3dcon_model_pruned.h5',
    custom_objects=tflow.utils.get_custom_objects(),
)
# Freeze the weights: this notebook only runs predictions.
# (A stray bare `model.stop_training` expression used to follow here; it only
# read an attribute and had no effect, so it was removed.)
model.trainable = False
model.summary()

In [None]:
# Positions of the samples fed to the model, one index per dataset.
face_idx, cifar_idx, text_idx = 1200, 100, 2984

# Alternative sample set (swap in to try different inputs):
#   face_idx, cifar_idx, text_idx = 83, 28, 923

In [None]:
# Pick one image from the validation folder, resize it to 128x128, divide the
# pixel values by 255 and shape it as a single-item batch.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
image_dir = '/Users/jwrenn4/Documents/utkface-split/validation/'
# os.listdir returns entries in arbitrary order; sort so that `face_idx`
# selects the same file on every run.
image_files = sorted(os.listdir(image_dir))
# Derive the file name from `face_idx` (it used to be a hard-coded 1200 that
# could silently disagree with the file actually loaded below).
image_name = image_files[face_idx]
img = tf.keras.preprocessing.image.load_img(os.path.join(image_dir, image_name))
face_image = tf.image.resize(np.array(img), (128, 128)) / 255
face_image = np.array(face_image).reshape((1, 128, 128, 3))
# Display the image that was loaded as the cell output.
img

In [None]:
# Load CIFAR-10 and prepare the single test image selected by `cifar_idx`.
(cifar10_x_train, cifar10_y_train), (cifar10_x_test, cifar10_y_test) = tf.keras.datasets.cifar10.load_data()
# Resize only the selected image rather than the whole test split: upscaling
# all 10,000 images to 128x128 float32 costs roughly 2 GB for a single sample.
# `cifar10_x_test` therefore keeps the raw pixel values returned by load_data().
cifar_image = tf.image.resize(cifar10_x_test[cifar_idx], (128, 128)) / 255
plt.imshow(cifar_image)
plt.show()

In [None]:
# Load AG News, fit a 30,000-word tokenizer on the training texts, and turn
# every test text into a sequence padded/truncated to 128 tokens.
text_data = datasets.load_dataset('ag_news')
tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=30000)
tokenizer.fit_on_texts(text_data['train']['text'])

test_token_ids = tokenizer.texts_to_sequences(text_data['test']['text'])
sequences = tf.keras.preprocessing.sequence.pad_sequences(test_token_ids, maxlen=128)
# One row holding positions 0..127 for each padded sequence.
token_positions = np.tile(np.arange(128), (sequences.shape[0], 1))

# Select the sample chosen by `text_idx` together with its position row.
text_sequence, token_position = sequences[text_idx], token_positions[text_idx]

print(text_data['test']['text'][text_idx])

In [None]:
# Give every input a leading batch dimension of 1, then run one forward pass.
# Input order: face image, CIFAR image, token ids, token positions.
cifar_batch = np.array(cifar_image).reshape(1, 128, 128, 3)
text_batch = text_sequence.reshape(1, -1)
position_batch = token_position.reshape(1, -1)
preds = model.predict([face_image, cifar_batch, text_batch, position_batch])

In [None]:
# Map each class index of the first model output to its range label.
# Index 6 is '60-70'; it previously read '50-70', which broke the
# decade-by-decade sequence between '50-60' and '70-80'.
face_pred_mapper = dict(zip(range(10), ['0-10', '10-20', '20-30', '30-40', '40-50', '50-60', '60-70', '70-80', '80-90', '90+']))
# argmax over axis 1 picks the highest-scoring class for the single sample.
print(f'Facial Prediction: {face_pred_mapper[preds[0].argmax(axis = 1)[0]]}')
# Show the input image next to its prediction.
img

In [None]:
# Index -> class-name lookup for the second model output.
cifar10_class_mapper = dict(enumerate(
    ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
))
# Highest-scoring class for the single sample in the batch.
cifar_class_idx = preds[1].argmax(axis = 1)[0]
print(f'CIFAR10 Prediction: {cifar10_class_mapper[cifar_class_idx]}')
# Show the input image next to its prediction.
plt.imshow(cifar_image)
plt.show()

In [None]:
# Index -> topic-name lookup for the third model output.
ag_news_class_mapper = dict(enumerate(['World', 'Sports', 'Business', 'Sci/Tech']))
# Highest-scoring class for the single sample in the batch.
text_class_idx = preds[2].argmax(axis = 1)[0]
print(f'Text Prediction: {ag_news_class_mapper[text_class_idx]}')
print('\n')
# Echo the input text below its prediction.
print(text_data['test']['text'][text_idx])