In [None]:
# Clone the Triplet-Loss repository and make it the working directory.
# Later cells import `dataset`, `models` and `utils` and use relative paths,
# so they presumably rely on this directory change -- run this cell first.
!git clone https://github.com/fashni/Triplet-Loss.git
%cd Triplet-Loss

# Download dataset

In [None]:
# LFW-df: download the LFW deep-funneled archive.
!wget http://vis-www.cs.umass.edu/lfw/lfw-deepfunneled.tgz
# Download the train/valid JSON files; the notebook later hands these to the
# data generators as `train_path` and `valid_path`.
!wget https://media.fashni.space/pub/dataset/lfw-deepfunneled/train.json
!wget https://media.fashni.space/pub/dataset/lfw-deepfunneled/valid.json
# Unpack the archive into the current directory.
!tar -xzf lfw-deepfunneled.tgz

# Train

In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf

from dataset import BatchGenerator, TripletGenerator
from models import facenet, inception, inception_resnet, siamnet
from utils import compute_metrics, compute_preds, get_images_and_labels

# Registry of available model builders, keyed by the name that `model_name`
# selects in the configuration cell.
models = dict(
    facenet=facenet,
    inception=inception,
    inception_resnet=inception_resnet,
    siamnet=siamnet,
)

In [None]:
# Seed handed to the data generators.
seed = 69

# --- Dataset ---------------------------------------------------------------
train_path, valid_path = 'train.json', 'valid.json'
batch_size = 32
augment = False  # applies to the training generator only; validation never augments
dset_name = "lfw"  # used in the saved-weights file name

# --- Model -----------------------------------------------------------------
input_shape = (160, 160, 3)
embedding_size = 128
model_name = "facenet"  # must be a key of `models`

# --- Training --------------------------------------------------------------
learning_rate = 1e-4
epochs = 10

# --- Loss ------------------------------------------------------------------
strategy = "batch_all"  # siamese, batch_all, or batch_hard
margin = 0.5
squared = False

In [None]:
# Create data generators.
# The "siamese" strategy uses TripletGenerator; every other strategy uses
# BatchGenerator. Both are bound to the same pair of names so the code below
# works for every strategy (binding them to strategy-specific names left the
# lookups below undefined for "batch_all" / "batch_hard").
generator_cls = TripletGenerator if strategy == "siamese" else BatchGenerator
train_generator = generator_cls(train_path, batch_size=batch_size, input_shape=input_shape, augment=augment, seed=seed)
# Validation data is never augmented, regardless of the `augment` setting.
valid_generator = generator_cls(valid_path, batch_size=batch_size, input_shape=input_shape, augment=False, seed=seed)

# Get the total images in the dataset
# NOTE(review): assumes both generator classes expose `n_images` and
# `get_dataset()` -- confirm against dataset.py.
n_train = train_generator.n_images
n_valid = valid_generator.n_images

# Get datasets
train_dataset = train_generator.get_dataset()
valid_dataset = valid_generator.get_dataset()

In [None]:
# Build the model: look up the builder selected by `model_name`, call it with
# the input/embedding sizes and the loss parameters, then print its summary.
builder = models[model_name]
loss_options = dict(strategy=strategy, margin=margin, squared=squared)
model = builder(input_shape, embedding_size, **loss_options)
model.summary()

In [None]:
# Compile the model with Adam at the configured learning rate.
# No `loss` argument is passed here.
# NOTE(review): presumably the model computes its loss internally from the
# strategy/margin/squared options given to the builder -- confirm in models.py.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
compile_options = {"optimizer": optimizer, "weighted_metrics": []}
model.compile(**compile_options)

In [None]:
# Train and validation steps per epoch: the number of batches needed to cover
# every image once (a partial final batch counts as a whole one), capped at
# 200 for training and 50 for validation.
train_whole, train_leftover = divmod(n_train, batch_size)
valid_whole, valid_leftover = divmod(n_valid, batch_size)
train_steps = min(200, train_whole + (train_leftover != 0))
valid_steps = min(50, valid_whole + (valid_leftover != 0))

In [None]:
# Train the model. Each dataset is limited with `.take(...)` to the step
# counts computed above; the returned history object is kept in `hist`.
train_subset = train_dataset.take(train_steps)
valid_subset = valid_dataset.take(valid_steps)
fit_options = dict(
  epochs=epochs,
  validation_data=valid_subset,
  verbose=1,
  initial_epoch=0,
)
hist = model.fit(train_subset, **fit_options)

# Evaluate

In [None]:
# Evaluation data comes from a fresh BatchGenerator over the validation file.
# `batch_size` is not passed here, so the generator's own default applies.
test_generator = BatchGenerator(valid_path, input_shape=input_shape, augment=False, seed=seed)
test_dataset = test_generator.get_dataset()
# NOTE(review): whether 100 counts batches or images depends on
# get_images_and_labels -- confirm in utils.py.
images, labels = get_images_and_labels(test_dataset, 100)
# Last expression of the cell: display the shape of the collected images.
images.shape

In [None]:
# Run the model over the evaluation images, then derive the metric curves.
y_true, y_pred, embeddings = compute_preds(
  model, images, labels, batch_size, squared=squared, verbose=1
)
fpr, tpr, prc, acc, f1, thres, auc = compute_metrics(y_true, y_pred)
# Index at which TPR - FPR is largest; the next cell reports every metric at
# this position.
tpr_minus_fpr = tpr - fpr
j = tpr_minus_fpr.argmax()

In [None]:
# Report the AUC and the remaining metrics at index `j`, the position chosen
# in the previous cell where TPR - FPR is largest.
print(f"{auc = }")
print(f"{f1[j] = }")
print(f"{acc[j] = }")
print(f"{prc[j] = }")
print(f"{tpr[j] = }")
print(f"{fpr[j] = }")
print(f"{thres[j] = }")

# ROC curve: true-positive rate against false-positive rate. Axis labels and
# a title are set so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(fpr, tpr)
ax.set(xlabel="False positive rate", ylabel="True positive rate", title="ROC curve")
ax.grid(True)
plt.show()

# Save weights

In [None]:
# Write the trained weights to "<model_name>_<dset_name>.weights.h5",
# a relative path resolved against the current working directory.
weights_file = f"{model_name}_{dset_name}.weights.h5"
model.save_weights(weights_file)