In [39]:
import tensorflow as tf
from keras.datasets import mnist
from keras.layers import *
from keras.models import Sequential, Model

In [40]:
# mnist.load_data() returns ((train images, train labels), (test images, test labels)).
mnist_train, mnist_test = mnist.load_data()
X_train, y_train = mnist_train
X_test, y_test = mnist_test

# Data Preparation

In [73]:
img_a = Input((28,28), name='img_a')
img_b = Input((28,28), name='img_b')

def get_cnn_block(depth, kernel_size=3, strides=1):
    """Return a small Conv2D -> BatchNormalization -> ReLU stack.

    Args:
        depth: number of filters of the Conv2D layer.
        kernel_size: Conv2D kernel size (defaults to 3, the original value).
        strides: Conv2D strides (defaults to 1, the original value).
    """
    return Sequential([
        Conv2D(depth, kernel_size, strides),
        BatchNormalization(),
        ReLU()
    ])

DEPTH = 32

# Shared encoder: the SAME `cnn` instance is applied to both inputs below,
# so the two branches share all of their weights.
cnn = Sequential([
    # mnist.load_data() yields raw pixel intensities in [0, 255]; scale them
    # to [0, 1] so the first Conv2D does not see unnormalized inputs.
    Rescaling(1./255),
    Reshape((28,28,1)),  # add the single channel axis Conv2D expects
    get_cnn_block(DEPTH),
    get_cnn_block(DEPTH*2),
    get_cnn_block(DEPTH*3),
    GlobalAveragePooling2D(),
    Dense(64, activation='relu')
])

feature_vector_a = cnn(img_a)
feature_vector_b = cnn(img_b)

# Combine the two embeddings and classify the pair as same (1) / different (0).
concat = Concatenate()([feature_vector_a, feature_vector_b])

dense = Dense(64, activation='relu')(concat)

output = Dense(1, activation='sigmoid')(dense)

model = Model(inputs=[img_a, img_b], outputs=output)

model.summary()

Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 img_a (InputLayer)          [(None, 28, 28)]             0         []                            
                                                                                                  
 img_b (InputLayer)          [(None, 28, 28)]             0         []                            
                                                                                                  
 sequential_26 (Sequential)  (None, 64)                   81184     ['img_a[0][0]',               
                                                                     'img_b[0][0]']               
                                                                                                  
 concatenate_5 (Concatenate  (None, 128)                  0         ['sequential_26[0][0]', 

In [74]:
import numpy as np

# Fixed seed so "Restart & Run All" draws the same training subset every time.
TRAIN_SAMPLE_SEED = 0
N_TRAIN_SAMPLES = 200  # pairing is all-vs-all, so this yields N_TRAIN_SAMPLES**2 pairs

train_rng = np.random.default_rng(TRAIN_SAMPLE_SEED)
random_indices = train_rng.choice(X_train.shape[0], N_TRAIN_SAMPLES, replace=False)

X_train_sample, y_train_sample = X_train[random_indices], y_train[random_indices]

X_train_sample.shape, y_train_sample.shape

((200, 28, 28), (200,))

In [75]:
import itertools

def make_paired_dataset(X, y):
    """Build every ordered (A, B) pair from a set of labelled samples.

    Pairs are enumerated in the same order as
    ``itertools.product(range(n), range(n))``: A varies slowest, B fastest.
    Self-pairs (i, i) and both orderings (i, j) / (j, i) are included, so
    ``n`` samples yield ``n**2`` pairs.

    Args:
        X: array-like of ``n`` samples; the first axis indexes samples.
        y: array-like of ``n`` labels aligned with ``X``.

    Returns:
        X_pairs: array of shape ``(n*n, 2, *X.shape[1:])``; ``X_pairs[k, 0]``
            is sample A and ``X_pairs[k, 1]`` is sample B. Dtype follows ``X``.
        y_pairs: int array of shape ``(n*n,)``: 1 where the two labels are
            equal, 0 otherwise.

    Raises:
        ValueError: if ``X`` and ``y`` have different lengths (previously the
            extra items were silently dropped by ``zip``).
    """
    X = np.asarray(X)
    y = np.asarray(y)
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length, got {len(X)} and {len(y)}")

    n = len(X)
    # repeat -> X0, X0, ..., X1, X1, ...  (A, slow index)
    # tile   -> X0, X1, ..., X0, X1, ...  (B, fast index)
    # Together they reproduce the itertools.product ordering without a Python loop.
    first = np.repeat(X, n, axis=0)
    second = np.tile(X, (n,) + (1,) * (X.ndim - 1))
    X_pairs = np.stack([first, second], axis=1)

    # Pairwise label equality matrix, flattened row-major (same order as above).
    y_pairs = (y[:, None] == y[None, :]).astype(int).ravel()

    return X_pairs, y_pairs

In [76]:
# All-vs-all pairing of the sampled training images (self-pairs included).
X_train_pairs, y_train_pairs = make_paired_dataset(
    X_train_sample,
    y_train_sample,
)

(X_train_pairs.shape, y_train_pairs.shape)

((40000, 2, 28, 28), (40000,))

In [77]:
# Fixed seed so "Restart & Run All" draws the same test subset every time.
TEST_SAMPLE_SEED = 1
N_TEST_SAMPLES = 50  # pairing is all-vs-all, so this yields N_TEST_SAMPLES**2 pairs

test_rng = np.random.default_rng(TEST_SAMPLE_SEED)
random_indices = test_rng.choice(X_test.shape[0], N_TEST_SAMPLES, replace=False)

X_test_sample, y_test_sample = X_test[random_indices], y_test[random_indices]

X_test_sample.shape, y_test_sample.shape

((50, 28, 28), (50,))

In [78]:
# All-vs-all pairing of the sampled test images (self-pairs included).
X_test_pairs, y_test_pairs = make_paired_dataset(
    X_test_sample,
    y_test_sample,
)

(X_test_pairs.shape, y_test_pairs.shape)

((2500, 2, 28, 28), (2500,))

# Model Training

In [79]:
# Single sigmoid output predicting same (1) vs. different (0) label,
# so binary cross-entropy is the matching loss.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [80]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Stop training after 3 epochs without improvement of the monitored metric.
es = EarlyStopping(patience=3)

# Write only the weights (not the whole model) to model.cptk during training.
cp_callback = ModelCheckpoint(
    filepath='model.cptk',
    save_weights_only=True,
    verbose=1,
)

In [83]:
# Keep the returned History (per-epoch train/validation metrics) instead of
# discarding it, so learning curves can be inspected after training.
# NOTE(review): the test pairs double as validation data here, so EarlyStopping
# and checkpointing are driven by the test set — consider a separate
# validation split if the test score is meant to be unbiased.
history = model.fit(
    x=[X_train_pairs[:, 0], X_train_pairs[:, 1]],  # inputs img_a / img_b
    y=y_train_pairs,
    validation_data=([X_test_pairs[:, 0], X_test_pairs[:, 1]], y_test_pairs),
    epochs=4,
    batch_size=32,
    callbacks=[es, cp_callback],
)

Epoch 1/4
Epoch 1: saving model to model.cptk
Epoch 2/4
Epoch 2: saving model to model.cptk
Epoch 3/4
Epoch 3: saving model to model.cptk
Epoch 4/4
Epoch 4: saving model to model.cptk


<keras.src.callbacks.History at 0x7f18e787fb20>

In [84]:
# Store the weights as they stand after training under model_0.cptk,
# separate from the model.cptk file written by the checkpoint callback.
model.save_weights(filepath='model_0.cptk')