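# MesoNet

Train a MesoNet (Meso-4-style) binary classifier that separates the `reals` and `fakes` images of the Celeb-DF dataset, using class weights to compensate for the strong imbalance between the two classes.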

## Config

In [1]:
# Load the TensorBoard notebook extension; presumably used to inspect the
# training runs that the TensorBoard callback below writes under LOGS_PATH.
%load_ext tensorboard

## Imports

In [2]:
from datetime import datetime
import pathlib
import typing as t

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import applications
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras.applications import xception
from tensorflow.keras import preprocessing
from tensorflow.keras import metrics
from tensorflow.keras import callbacks

## Constants

In [3]:
# Datasets and checkpoints live three directories above the notebook's
# working directory.
_ROOT = pathlib.Path("../../..")
DATASET_PATH = _ROOT / "datasets" / "celeb_df"
MODELS_PATH = _ROOT / "saved_models"
# TensorBoard logs are kept next to the notebook.
LOGS_PATH = pathlib.Path(".") / "logs"
# Images are resized to IMAGE_SIZE on load; the model input adds 3 colour channels.
IMAGE_SIZE = (256, 256)
INPUT_SHAPE = IMAGE_SIZE + (3,)

## Load datasets

In [4]:
# Fixed seed so the train/validation split (and hence every reported
# validation metric and saved checkpoint) is reproducible across kernel
# restarts. Both dataset cells must use this same value so the two subsets
# stay disjoint. Change it deliberately if a different split is wanted.
SEED = 1337
# Fraction of all images held out for validation.
VALIDATION_SPLIT = 0.05
BATCH_SIZE = 32

In [5]:
# Training subset of DATASET_PATH; class labels are inferred from the
# sub-directory names. SEED and VALIDATION_SPLIT must match the validation
# cell so the two subsets do not overlap.
train_ds = preprocessing.image_dataset_from_directory(
    DATASET_PATH,
    subset="training",
    validation_split=VALIDATION_SPLIT,
    seed=SEED,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
)

Found 2342158 files belonging to 2 classes.
Using 2225051 files for training.


In [6]:
# Held-out validation subset; uses the same SEED and VALIDATION_SPLIT as the
# training cell so the two subsets are complementary.
validation_ds = preprocessing.image_dataset_from_directory(
    DATASET_PATH,
    subset="validation",
    validation_split=VALIDATION_SPLIT,
    seed=SEED,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
)

Found 2342158 files belonging to 2 classes.
Using 117107 files for validation.


In [7]:
def _count_entries(directory: pathlib.Path) -> int:
    """Return the number of entries directly inside `directory` (non-recursive)."""
    return sum(1 for _ in directory.iterdir())


# Ratio of real to fake images; used below as the class weight of the fakes.
REALS_TO_FAKE_RATIO = _count_entries(DATASET_PATH / "reals") / _count_entries(
    DATASET_PATH / "fakes"
)
REALS_TO_FAKE_RATIO

0.10647836701990959

## Define model

In [8]:
def build_meso_net(
    input_shape: t.Tuple[int, int, int] = INPUT_SHAPE,
) -> keras.Sequential:
    """Build an uncompiled MesoNet (Meso-4-style) binary classifier.

    The network is four Conv2D -> BatchNormalization -> MaxPool2D blocks
    (8, 8, 16, 16 filters), followed by a dropout-regularised dense head that
    ends in a single sigmoid unit.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
            Defaults to ``INPUT_SHAPE``.

    Returns:
        The assembled ``keras.Sequential`` model. It still has to be compiled.
    """
    model = keras.Sequential()
    model.add(layers.InputLayer(input_shape))
    # NOTE(review): there is no rescaling layer, so the network sees pixel
    # values exactly as the input pipeline delivers them (presumably 0-255
    # floats from image_dataset_from_directory). Consider normalising inputs.

    # Feature extractor: (filters, kernel size, pool size) for each block.
    conv_blocks = (
        (8, (3, 3), (2, 2)),
        (8, (5, 5), (2, 2)),
        (16, (5, 5), (2, 2)),
        (16, (5, 5), (4, 4)),
    )
    for filters, kernel_size, pool_size in conv_blocks:
        model.add(layers.Conv2D(filters, kernel_size, padding="same", activation="relu"))
        model.add(layers.BatchNormalization())
        model.add(layers.MaxPool2D(pool_size=pool_size, padding="same"))

    # Classification head.
    model.add(layers.Flatten())
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(16))
    model.add(layers.LeakyReLU(alpha=0.1))
    model.add(layers.Dropout(0.5))
    # Single sigmoid unit: probability of label 1 for binary cross-entropy.
    model.add(layers.Dense(1, activation="sigmoid"))

    return model

In [9]:
# Build a fresh, uncompiled MesoNet; it is compiled in the Training section.
model = build_meso_net()

## Training

In [10]:
# Metrics reported in addition to the loss. Explicit names keep the keys in
# `history.history` and TensorBoard stable ("auc" rather than an auto-generated
# "auc_1", ...) when this cell is re-run in the same kernel.
METRICS = [
    metrics.BinaryAccuracy(name="binary_accuracy"),
    metrics.AUC(name="auc"),
]

In [11]:
# Save each run in its own timestamped directory so runs can be compared in TensorBoard.
log_dir = LOGS_PATH / "fit" / datetime.now().strftime("%Y%m%d-%H%M%S")

# ModelCheckpoint writes directly into MODELS_PATH; make sure the directory exists.
MODELS_PATH.mkdir(parents=True, exist_ok=True)

CALLBACKS = [
    # Stop after 2 epochs without val_loss improvement. restore_best_weights
    # rolls the in-memory `model` back to its best epoch; without it the kernel
    # keeps the weights of the last (worse) epoch.
    callbacks.EarlyStopping(
        monitor="val_loss",
        patience=2,
        restore_best_weights=True,
    ),
    # Only write a checkpoint when val_loss improves; epoch and val_loss are
    # encoded in the file name.
    callbacks.ModelCheckpoint(
        filepath=str(MODELS_PATH / "meso_net_{epoch:02d}_{val_loss:.2f}.h5"),
        monitor="val_loss",
        save_best_only=True,
    ),
    callbacks.TensorBoard(log_dir=str(log_dir)),
]

In [12]:
# Adam hyper-parameters following "FaceForensics++: Learning to Detect
# Manipulated Facial Images": learning rate 1e-3, epsilon 1e-8. All other
# Adam arguments keep their Keras defaults.
optimizer = optimizers.Adam(epsilon=1e-8, learning_rate=1e-3)

In [13]:
# Binary cross-entropy matches the model's single sigmoid output unit.
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=METRICS)

In [14]:
EPOCHS = 10

# Fakes outnumber reals (REALS_TO_FAKE_RATIO < 1), so each fake sample is
# weighted by REALS_TO_FAKE_RATIO. Both classes then contribute the same total
# weight to the loss. The label indices are looked up from the dataset rather
# than assumed from alphabetical order, so renamed directories fail loudly
# instead of silently swapping the weights.
CLASS_WEIGHT = {
    train_ds.class_names.index("fakes"): REALS_TO_FAKE_RATIO,
    train_ds.class_names.index("reals"): 1.0,
}

history = model.fit(
    train_ds,
    validation_data=validation_ds,
    class_weight=CLASS_WEIGHT,
    epochs=EPOCHS,
    callbacks=CALLBACKS,
)

Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10


In [15]:
# Per-epoch loss and metric values recorded by `model.fit`, keyed by metric name.
epoch_metrics = history.history
epoch_metrics

{'loss': [0.09156110882759094,
  0.0632546916604042,
  0.05672319605946541,
  0.05305468291044235],
 'binary_accuracy': [0.7606275081634521,
  0.8593434691429138,
  0.8768356442451477,
  0.8864551782608032],
 'auc': [0.8528405427932739,
  0.9363375306129456,
  0.9491717219352722,
  0.9557192921638489],
 'val_loss': [0.2676297724246979,
  0.1393992006778717,
  0.18926742672920227,
  0.25055092573165894],
 'val_binary_accuracy': [0.9169221520423889,
  0.9463226199150085,
  0.9346836805343628,
  0.8943017721176147],
 'val_auc': [0.948848307132721,
  0.9578119516372681,
  0.9777187705039978,
  0.9796455502510071]}