In [None]:
# Install the Kaggle command-line client (-q keeps pip quiet).
!pip install -q kaggle
# Copy the API token into ~/.kaggle; kaggle.json must already be present in the
# current working directory for the cp to succeed.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# Restrict the token file to owner read/write only.
!chmod 600 ~/.kaggle/kaggle.json

# Download the stoicstatic/face-recognition-dataset archive into the working directory.
!kaggle datasets download -d stoicstatic/face-recognition-dataset

# Extract the archive quietly into the working directory.
!unzip -q face-recognition-dataset.zip

# Delete the archive once it has been extracted.
!rm -rf face-recognition-dataset.zip

In [None]:
import os
import cv2
import time
import random
import numpy as np

import tensorflow as tf
from tensorflow.keras.applications.inception_v3 import preprocess_input

import seaborn as sns
import matplotlib.pyplot as plt



In [None]:
# Model that can be used to crop student images, e.g. selfies or photos they take with their phones.

# Fix the seed of every random number generator used in this notebook
# (Python's `random`, NumPy's global generator and TensorFlow).
SEED = 5
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Root directory of the extracted dataset, relative to the working directory.
dataset_path = 'Extracted Faces/Extracted Faces'

# Names of all entries directly under the dataset root, in the order the
# filesystem reports them.
folders = os.listdir(dataset_path)
print(folders)

"""
output:
['708', '1084', '662', '1217', '211', '840', '1322', '7', '142', '1261', '1352', '676', '522', '60', '194', '1057', '856', '200', '185', '1386', '964', '201', '1511', '335', '871', '1467', '1238', '775', '1581', '12', '539', '253', '810', '1095', '1063', '1302', '332', '1240', '386', '528', '533', '1245', '1676', '183', '340', '483', '1072', '1026', '966', '891', '980', '1309', '573', '571', '449', '59', '928', '1195', '1455', '353', '38', '1482', '1456', '1494', '237', '634', '1336', '865', '1628', '665', '663', '103', '1014', '283', '787', '1402', '1614', '328', '1300', '195', '540', '440', '594', '346', '1093', '150', '535', '904', '151', '1522', '748', '184', '446', '347', '1629', '686', '1451', '1346', '1231', '1113', '1194', '1177', '570', '254', '939', '1675', '674', '1025', '468', '1667', '630', '906', '1449', '1432', '1313', '790', '551', '303', '1024', '1506', '585', '472', '858', '1383', '1327', '1218', '643', '626', '50', '455', '1423', '359', '1570', '1173', '795', '158', '1320', '1303', '360', '1047', '735', '1595', '789', '1527', '403', '758', '656', '345', '495', '794', '1398', '1620', '1312', '1341', '1605', '991', '1576', '1626', '127', '848', '357', '1504', '1424', '140', '1564', '613', '1239', '1068', '776', '1291', '917', '607', '1163', '82', '482', '1146', '948', '1228', '1418', '1587', '462', '339', '1155', '444', '1002', '1531', '1119', '1211', '1366', '957', '1374', '1431', '92', '997', '1617', '293', '805', '306', '1355', '1427', '916', '970', '136', '109', '1275', '1087', '214', '1642', '1664', '351', '301', '1453', '1615', '792', '476', '74', '536', '1586', '1299', '101', '157', '1189', '314', '267', '615', '263', '408', '192', '1220', '1364', '508', '946', '177', '350', '644', '1073', '884', '769', '1662', '1230', '520', '1361', '1454', '326', '399', '807', '595', '1059', '128', '304', '1172', '1597', '890', '922', '895', '896', '1553', '1213', '1094', '669', '1199', '1133', '1466', '814', '1397', '668', '869', '1148', '1256', '1563', 
'507', '1665', '821', '1641', '1090', '600', '824', '788', '317', '1486', '1656', '902', '193', '233', '728', '783', '652', '1109', '1067', '1333', '1156', '827', '1524', '39', '311', '818', '1039', '894', '1616', '400', '802', '1515', '742', '550', '609', '734', '1330', '1088', '70', '57', '1219', '473', '112', '72', '1426', '1433', '765', '1558', '78', '205', '43', '923', '309', '1489', '1235', '66', '464', '422', '69', '1650', '62', '944', '1225', '828', '396', '1187', '1227', '619', '617', '1575', '793', '941', '500', '1599', '1544', '154', '1061', '628', '950', '563', '318', '1120', '29', '65', '1101', '406', '456', '709', '90', '861', '135', '122', '338', '437', '115', '1477', '575', '163', '1296', '1099', '1400', '423', '565', '1128', '1203', '1474', '1419', '1110', '694', '1079', '531', '111', '625', '759', '963', '1497', '658', '93', '1387', '44', '341', '178', '757', '330', '46', '1338', '1033', '1359', '352', '197', '1131', '582', '1508', '298', '544', '236', '91', '1516', '395', '548', '1420', '690', '971', '242', '173', '762', '1584', '275', '637', '1328', '1469', '10', '30', '629', '1343', '1202', '255', '316', '639', '801', '63', '264', '257', '1636', '1015', '1257', '1081', '995', '749', '546', '171', '755', '1224', '1593', '389', '469', '320', '25', '688', '978', '874', '768', '578', '1669', '1436', '1408', '1152', '126', '373', '1521', '817', '1104', '753', '1608', '1577', '542', '1000', '958', '704', '488', '1568', '107', '764', '1670', '83', '883', '1610', '1430', '1007', '1668', '972', '1149', '88', '959', '1188', '167', '1491', '1096', '809', '374', '329', '919', '733', '1632', '1123', '731', '1365', '276', '1349', '898', '899', '667', '580', '1150', '1212', '19', '1305', '567', '1473', '1510', '1512', '1264', '660', '825', '591', '266', '246', '900', '695', '1102', '1677', '179', '1029', '677', '1562', '248', '280', '1376', '49', '621', '513', '942', '1316', '1372', '1107', '428', '1157', '1411', '384', '1175', '202', '988', '1602', '1547', 
'1421', '343', '1', '868', '313', '1181', '120', '673', '300', '512', '543', '549', '1307', '1647', '1301', '1117', '914', '532', '73', '718', '506', '1153', '189', '1086', '139', '1633', '1673', '1447', '1452', '182', '409', '1324', '421', '588', '1030', '8', '864', '1289', '1092', '1403', '710', '388', '510', '1580', '901', '1098', '933', '1362', '1534', '605', '553', '918', '134', '829', '105', '679', '1209', '277', '1461', '693', '427', '876', '834', '1406', '649', '1600', '1631', '180', '1277', '1051', '268', '952', '1038', '926', '851', '1470', '1500', '1085', '104', '618', '724', '968', '986', '1265', '35', '1260', '1176', '270', '1192', '1200', '921', '64', '1373', '1075', '411', '750', '261', '220', '116', '1251', '404', '604', '219', '633', '678', '1381', '1407', '4', '1111', '1091', '1481', '294', '910', '1612', '1382', '786', '1179', '833', '949', '743', '1115', '1539', '1363', '1182', '310', '597', '1561', '215', '529', '1542', '855', '1319', '956', '393', '654', '990', '199', '1035', '1401', '642', '1027', '650', '1060', '36', '1487', '26', '747', '545', '1596', '932', '727', '751', '1142', '908', '608', '707', '342', '877', '1370', '554', '823', '982', '584', '509', '392', '1412', '191', '1549', '1554', '800', '1342', '1050', '683', '1280', '1127', '21', '732', '803', '37', '336', '450', '1116', '804', '287', '1267', '1190', '700', '1344', '1529', '447', '560', '1627', '418', '94', '915', '429', '1019', '1191', '1281', '994', '497', '1282', '413', '1151', '491', '47', '505', '773', '559', '781', '1174', '170', '680', '1244', '1022', '441', '41', '596', '1001', '172', '1476', '278', '1164', '1310', '436', '402', '1679', '1556', '416', '417', '132', '1648', '1468', '355', '835', '653', '1100', '1479', '1666', '889', '612', '1348', '1368', '1105', '592', '897', '1106', '1545', '1583', '1448', '813', '1159', '888', '1606', '960', '1272', '736', '912', '321', '6', '558', '1578', '1528', '798', '1136', '752', '797', '741', '1053', '1229', '581', '1323', 
'1170', '1185', '880', '841', '1450', '903', '1246', '556', '478', '148', '1043', '938', '1535', '209', '1082', '1391', '886', '162', '1201', '557', '356', '106', '467', '1440', '290', '1269', '1645', '265', '746', '245', '207', '1253', '439', '1062', '1112', '675', '692', '1392', '397', '84', '187', '1357', '1574', '1622', '203', '1513', '1619', '1337', '606', '125', '523', '720', '337', '33', '1294', '362', '1652', '1603', '822', '655', '1278', '610', '1208', '1611', '1588', '433', '562', '281', '1446', '1460', '905', '975', '1241', '284', '1186', '696', '1288', '945', '715', '1162', '1565', '1533', '67', '631', '1013', '1405', '138', '319', '1167', '862', '382', '1356', '1049', '492', '1114', '1147', '251', '1502', '1045', '740', '480', '1266', '1569', '96', '566', '1210', '1623', '1557', '1124', '1171', '738', '442', '940', '1154', '837', '312', '1637', '269', '547', '369', '961', '1036', '1070', '538', '68', '1140', '499', '99', '852', '16', '1571', '1137', '1198', '354', '117', '616', '1523', '1380', '1607', '124', '881', '484', '224', '763', '719', '1414', '1032', '1052', '1618', '118', '1247', '860', '1525', '13', '1259', '739', '992', '1018', '1183', '811', '1413', '937', '1609', '586', '1066', '602', '1360', '924', '1304', '778', '1653', '515', '1443', '723', '1016', '1144', '1384', '1321', '227', '1396', '196', '1353', '1624', '1184', '1258', '1298', '206', '1297', '387', '188', '756', '85', '572', '983', '1592', '785', '1248', '638', '141', '217', '54', '996', '1638', '1290', '20', '691', '843', '1031', '1439', '844', '1490', '1021', '451', '1552', '297', '770', '231', '1243', '475', '1651', '1287', '1242', '1193', '1125', '1076', '1635', '999', '58', '872', '1138', '210', '1055', '420', '1311', '305', '1017', '1625', '726', '1582', '1640', '1389', '1519', '1306', '954', '1532', '414', '977', '161', '1613', '929', '772', '1130', '657', '143', '1089', '717', '168', '1538', '1143', '1250', '1644', '457', '1434', '1168', '181', '1459', '1273', '1141', 
'166', '1369', '1214', '244', '223', '1350', '1678', '579', '155', '1639', '564', '156', '853', '1654', '854', '1077', '1649', '1415', '452', '230', '931', '97', '555', '1393', '23', '133', '466', '198', '1496', '624', '925', '716', '1226', '226', '459', '344', '1271', '863', '589', '1222', '1326', '1069', '1074', '243', '1042', '1598', '1232', '1498', '1314', '27', '976', '1270', '435', '524', '234', '379', '176', '149', '784', '174', '598', '299', '1409', '165', '985', '1205', '1180', '71', '601', '1445', '89', '771', '1495', '361', '1492', '780', '722', '521', '1020', '1484', '1488', '729', '241', '1097', '1340', '1441', '238', '1671', '1646', '434', '1621', '1034', '839', '1283', '486', '1254', '721', '1499', '1325', '1377', '1560', '1659', '1471', '836', '98', '537', '1536', '40', '75', '1546', '487', '1503', '574', '850', '410', '460', '258', '830', '808', '620', '366', '1463', '882', '1660', '368', '296', '1674', '689', '1071', '1526', '1661', '684', '1108', '1221', '870', '86', '517', '81', '661', '365', '815', '623', '1464', '1135', '1308', '119', '113', '1122', '1657', '394', '169', '1604', '568', '1006', '1262', '527', '419', '1428', '77', '160', '526', '415', '378', '1480', '967', '987', '1332', '779', '1335', '323', '873', '816', '911', '632', '45', '846', '1478', '1591', '363', '875', '17', '377', '659', '240', '32', '164', '744', '95', '703', '430', '1514', '1358', '102', '1429', '56', '832', '969', '1379', '1520', '1139', '110', '1458', '1040', '713', '812', '0', '687', '1567', '646', '1329', '920', '1263', '390', '146', '1630', '706', '1540', '80', '1274', '936', '52', '541', '249', '1410', '490', '989', '1161', '1046', '322', '502', '847', '1236', '208', '228', '648', '670', '272', '782', '927', '973']

"""

# function to read image
def read_image(index):
    """Load one image from disk and return it in RGB channel order.

    Args:
        index: either a ``(folder, filename)`` pair, resolved relative to the
            module-level ``dataset_path``, or a single path (``str`` or
            ``os.PathLike``) pointing directly at an image file.

    Returns:
        The array produced by ``cv2.imread`` with its channels converted from
        BGR to RGB.

    Raises:
        FileNotFoundError: if OpenCV could not read an image from the path.
    """
    # A bare path must not be indexed like a (folder, filename) pair: index[0]
    # and index[1] would be its first two characters.
    if isinstance(index, (str, os.PathLike)):
        image_path = os.fspath(index)
    else:
        image_path = os.path.join(dataset_path, index[0], index[1])

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising, which
        # would otherwise surface as an opaque error inside cvtColor.
        raise FileNotFoundError(
            'Could not read image (missing or not a readable image file): {}'.format(image_path)
        )
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# function to split data into train and test
def split_data(dir, train_size):
    """Randomly assign every image under ``dir`` to a train or a test list.

    ``dir`` is expected to contain one sub-directory per identity, each holding
    that identity's image files. Every file is assigned independently using
    NumPy's global random generator, so seed it beforehand for a repeatable
    split.

    NOTE(review): because the draw is per image, images of the same identity
    can land in both lists.

    Args:
        dir: path of the dataset root. (The name shadows the builtin ``dir``
            but is kept so keyword callers keep working.)
        train_size: probability, in [0, 1], that a file goes to the train list.

    Returns:
        ``(train, test)``: two lists of ``[folder_name, file_name]`` pairs.
    """
    train = []
    test = []
    # os.listdir returns entries in arbitrary order; sorting makes the sequence
    # of random draws line up with the same files on every machine and run.
    for folder_name in sorted(os.listdir(dir)):
        folder_path = os.path.join(dir, folder_name)
        if not os.path.isdir(folder_path):
            # Stray files at the root (e.g. a README) are not identities.
            continue
        for file_name in sorted(os.listdir(folder_path)):
            target = train if np.random.rand() < train_size else test
            target.append([folder_name, file_name])
    return train, test

# Split the image index: each file goes to the training list with probability
# 0.8 and to the test list otherwise.
train_data, test_data = split_data(dataset_path, 0.8)

# Size of each list, training first, one number per line.
print(len(train_data), len(test_data), sep='\n')

# First five [folder, filename] entries of each list, training first.
print(train_data[:5], test_data[:5], sep='\n')


# create triplet data
def create_triplets(data_dir, dir_list, max_files=10):
    """Build (anchor, positive, negative) image triplets from an image index.

    For every usable anchor image the positive is a *different* image of the
    same identity (same folder) and the negative is an image of a *different*
    identity. Identities with a single image are skipped because no positive
    exists for them. Random choices use NumPy's global generator.

    Args:
        data_dir: unused. Images are loaded through ``read_image``, which
            resolves paths on its own; the parameter is kept so existing calls
            keep working.
        dir_list: iterable of ``[folder, filename]`` pairs, as produced by
            ``split_data``.
        max_files: maximum number of images kept per identity (the first ones
            encountered in ``dir_list``). ``None`` disables the cap.

    Returns:
        A list of ``[anchor, positive, negative]`` lists of images as returned
        by ``read_image``. Empty when fewer than two identities are available.
    """
    # Group file names by identity folder, honouring the per-identity cap.
    files_by_identity = {}
    for folder, filename in dir_list:
        kept = files_by_identity.setdefault(folder, [])
        if max_files is None or len(kept) < max_files:
            kept.append(filename)

    identities = list(files_by_identity)
    if len(identities) < 2:
        # Without a second identity there is nothing to draw negatives from.
        return []

    triplets = []
    for folder, filenames in files_by_identity.items():
        if len(filenames) < 2:
            continue  # no second image of this identity to act as the positive
        for position, anchor_file in enumerate(filenames):
            # Positive: any image of the same identity except the anchor itself.
            candidates = filenames[:position] + filenames[position + 1:]
            positive_file = candidates[np.random.randint(0, len(candidates))]

            # Negative: redraw until the folder differs from the anchor's;
            # this terminates because at least two identities exist.
            negative_folder = folder
            while negative_folder == folder:
                negative_folder = identities[np.random.randint(0, len(identities))]
            negative_files = files_by_identity[negative_folder]
            negative_file = negative_files[np.random.randint(0, len(negative_files))]

            triplets.append([
                read_image([folder, anchor_file]),
                read_image([folder, positive_file]),
                read_image([negative_folder, negative_file]),
            ])
    return triplets

# Build the image triplets for each partition; the training partition is
# processed first, then the test partition.
train_triplets, test_triplets = (
    create_triplets(dataset_path, partition)
    for partition in (train_data, test_data)
)

# Number of triplets produced for each partition, training first, one per line.
print(len(train_triplets), len(test_triplets), sep='\n')

def get_batch(batch_size=64, triplets=train_triplets):
    """Draw ``batch_size`` triplets at random, with replacement.

    Args:
        batch_size: number of triplets to return.
        triplets: sequence to sample from. The default is the ``train_triplets``
            object that existed when this function was defined.

    Returns:
        A list of ``batch_size`` elements of ``triplets``; the same element may
        appear more than once. Indices come from NumPy's global generator.
    """
    return [
        triplets[np.random.randint(0, len(triplets))]
        for _ in range(batch_size)
    ]

# plot sample triplets: one row per triplet, columns are anchor / positive / negative
# Draw the three triplets once. Sampling a fresh batch for every subplot would
# mix images from unrelated triplets within a row.
sample_triplets = get_batch(3)
column_titles = ('anchor', 'positive', 'negative')

fig, ax = plt.subplots(3, 3, figsize=(10, 10))
for row, triplet in enumerate(sample_triplets):
    for col, title in enumerate(column_titles):
        ax[row, col].imshow(triplet[col])
        ax[row, col].set_title(title)
        ax[row, col].axis('off')
plt.show()


from tensorflow import keras
from tensorflow.keras import backend, layers, metrics

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import Xception
from tensorflow.keras.models import Model, Sequential

from tensorflow.keras.utils import plot_model
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report


# get encoder model

# The siamese inputs defined further down are 128x128 RGB and images are fed
# without resizing, so the backbone must be built for that same size.
# (Assumes the face crops on disk are 128x128 — TODO confirm.)
input_shape = (128, 128, 3)

# pooling='avg' makes the backbone emit a flat (batch, channels) feature vector,
# so no further pooling layer is needed (or valid) on its output.
base_model = Xception(weights='imagenet', include_top=False, input_shape=input_shape, pooling='avg')

encoder = Sequential([
    layers.Input(shape=input_shape),
    # The ImageNet Xception weights expect pixels scaled to [-1, 1]; doing the
    # scaling inside the model keeps training, evaluation and the saved encoder
    # consistent when they are fed raw 0-255 images.
    layers.Rescaling(1.0 / 127.5, offset=-1.0),
    base_model,
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.BatchNormalization(),
    layers.Dense(64, activation='relu'),
    # Unit-length embeddings, so squared distances between them are bounded.
    layers.Lambda(lambda x: backend.l2_normalize(x, axis=1))
], name='encoder')

encoder.summary()

# model plot
keras.utils.plot_model(encoder, show_shapes=True, dpi=64)

# siamese network

input_shape = (128, 128, 3)

# One input per triplet role; all three go through the same encoder instance,
# so the weights are shared.
anchor = layers.Input(input_shape, name='anchor')
positive = layers.Input(input_shape, name='positive')
negative = layers.Input(input_shape, name='negative')

encoded_anchor = encoder(anchor)
encoded_positive = encoder(positive)
encoded_negative = encoder(negative)

def distance_layer(vectors):
    """Return the squared Euclidean anchor-positive and anchor-negative distances.

    Args:
        vectors: sequence of three embedding tensors ``(anchor, positive, negative)``.

    Returns:
        ``(positive_distance, negative_distance)``, each summed over the last axis.
    """
    anchor, positive, negative = vectors
    positive_distance = backend.sum(backend.square(anchor - positive), axis=-1)
    negative_distance = backend.sum(backend.square(anchor - negative), axis=-1)
    return positive_distance, negative_distance

positive_distance, negative_distance = layers.Lambda(distance_layer, name='distance')([encoded_anchor, encoded_positive, encoded_negative])

# Stack both distances into a single (batch, 2) output so that one loss function
# sees them together and the model needs only one target array per batch.
distances = layers.Lambda(lambda pair: backend.stack(pair, axis=1), name='distances')([positive_distance, negative_distance])

siamese_network = Model(inputs=[anchor, positive, negative], outputs=distances, name='siamese_network')

siamese_network.summary()

# model plot
plot_model(siamese_network, show_shapes=True, dpi=64)

# loss function

def loss(y_true, y_pred):
    """Triplet margin loss over the stacked distances.

    Pushes the anchor-positive distance to be at least ``margin`` smaller than
    the anchor-negative distance: ``mean(max(d_pos - d_neg + margin, 0))``.

    Args:
        y_true: ignored; Keras requires a target, so any placeholder works.
        y_pred: tensor whose column 0 is the anchor-positive distance and whose
            column 1 is the anchor-negative distance.

    Returns:
        Scalar mean loss over the batch.
    """
    margin = 1.0
    anchor_positive = y_pred[:, 0]
    anchor_negative = y_pred[:, 1]
    return backend.mean(backend.maximum(anchor_positive - anchor_negative + margin, 0.0))

# compile model

siamese_network.compile(loss=loss, optimizer=Adam(0.0001))

# train model

def train_model(model, epochs=10, batch_size=64):
    """Run ``epochs`` single-batch training steps on randomly drawn triplets.

    Every iteration draws one fresh batch via ``get_batch(batch_size)`` (using
    that function's default triplet pool), calls ``model.train_on_batch`` once
    and prints the value it returns. An "epoch" here is therefore one batch,
    not a pass over the whole data.

    Args:
        model: object exposing ``train_on_batch(inputs, targets)``; inputs are a
            dict keyed ``'anchor'``, ``'positive'`` and ``'negative'``.
        epochs: number of single-batch steps to run.
        batch_size: number of triplets drawn per step.
    """
    roles = ('anchor', 'positive', 'negative')
    for epoch in range(epochs):
        # Only one batch is turned into arrays at a time to limit memory use.
        triplet_batch = get_batch(batch_size)
        inputs = {
            role: np.array([triplet[slot] for triplet in triplet_batch])
            for slot, role in enumerate(roles)
        }
        # All-ones target array with one row per sample.
        targets = np.ones((batch_size, 1))
        step_loss = model.train_on_batch(inputs, targets)
        print('Epoch: {}, Loss: {}'.format(epoch, step_loss))

train_model(siamese_network, epochs=2, batch_size=64)

# test model

def test_model(model, batch_size=64):
    """Evaluate the model on one random batch of test triplets and print the loss.

    Args:
        model: object exposing ``test_on_batch(inputs, targets)``; inputs are a
            dict keyed ``'anchor'``, ``'positive'`` and ``'negative'``.
        batch_size: number of triplets drawn from ``test_triplets``.
    """
    triplet_batch = get_batch(batch_size, test_triplets)
    roles = ('anchor', 'positive', 'negative')
    inputs = {
        role: np.array([triplet[slot] for triplet in triplet_batch])
        for slot, role in enumerate(roles)
    }
    # All-ones target array with one row per sample.
    targets = np.ones((batch_size, 1))
    batch_loss = model.test_on_batch(inputs, targets)
    print('Loss: {}'.format(batch_loss))

test_model(siamese_network, batch_size=64)

# save model

encoder.save('encoder.h5')

def classify_image(image, model):
    """Run a single image through ``model`` and return its prediction.

    Args:
        image: path (``str`` or ``os.PathLike``) of an image file, or a
            ``(folder, filename)`` pair resolved under ``dataset_path``.
        model: object exposing ``predict``. When it reports a fixed 4-D
            ``input_shape`` the image is resized to that height and width.

    Returns:
        Whatever ``model.predict`` returns for a batch holding this one image.

    Raises:
        FileNotFoundError: if OpenCV could not read an image from the path.
    """
    # A path string must be used as-is; indexing it like a (folder, filename)
    # pair would pick out its first two characters.
    if isinstance(image, (str, os.PathLike)):
        image_path = os.fspath(image)
    else:
        image_path = os.path.join(dataset_path, image[0], image[1])

    pixels = cv2.imread(image_path)
    if pixels is None:
        # cv2.imread returns None on failure instead of raising.
        raise FileNotFoundError(
            'Could not read image (missing or not a readable image file): {}'.format(image_path)
        )
    pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)

    # Arbitrary photos rarely match the network's input size; resize when the
    # model declares a fixed (batch, height, width, channels) input.
    expected = getattr(model, 'input_shape', None)
    if isinstance(expected, tuple) and len(expected) == 4 and None not in expected[1:3]:
        height, width = expected[1:3]
        if pixels.shape[:2] != (height, width):
            # cv2.resize takes the target size as (width, height).
            pixels = cv2.resize(pixels, (width, height))

    return model.predict(np.expand_dims(pixels, axis=0))

# load model
# The encoder's Lambda layer calls `backend.l2_normalize`; pass `backend` so the
# restored function can resolve that name. (NOTE(review): harmless if the
# installed Keras version does not need it.)
encoder = keras.models.load_model('encoder.h5', custom_objects={'backend': backend})

# get prediction
prediction = classify_image('test.jpg', encoder)

print(prediction)