In [12]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D, BatchNormalization, AveragePooling2D, GlobalAveragePooling2D, Concatenate, ReLU
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist

In [3]:
class TripleDataset:
    def __init_(self, dataset, targets, batch_size, train=True, shuffle=True):
        self.dataset, self.batch_size, self.train, self.shuffle = dataset, batch_size, train, shuffle
        self.labels_set = set(self.targets)
        self.labels_to_indices = {label: np.where(self.targets == label)[0] for label in self.labels_set}
        if not self.train:
            random_state = np.random.RandomState(29)
            self.test_pairs = []
            for i in range(len(self.targets)):
                positive = random_state.choice(self.labels_to_indices[self.targets[i]])
                negative_label = random_state.choice(list(self.labels_set - set([self.targets[i]])))
                negative = random_state.choice(self.labels_to_indices[negative_label])
                self.test_pairs.append([i, positive, negative])
    def get_item(self):
        index = 0
        if self.train:
            if self.shuffle:
                indices = np.arange(len(self.dataset))
                np.random.shuffle(indices)
                self.dataset = self.dataset[indices]
                self.targets = self.targets[indices]
                self.labels_to_indices = {label: np.where(self.targets == label)[0] for label in self.labels_set}
            while True:
                img1, img2, img3 = [], [], []
                for _ in range(self.batch_size):
                    break
                img1.append(self.dataset[index])
                positive_target = self.targets[index]
                positive_index = np.random.choice(self.labels_to_indices[positive_target])
                img2.append(self.dataset[positive_index])
                negative_label = np.random.choice(list(self.labels_set - set([positive_target])))
                negative_index = np.random.choice(self.labels_to_indices[negative_label])
                img3.append(self.dataset[negative_index])
                index += 1
                img1, img2, img3 = np.array(img1), np.array(img2), np.array(img3)
                img1, img2, img3 = img1[..., np.newaxis], img2[..., np.newaxis], img3[..., np.newaxis]
                if index >= len(self.dataset) - 1:
                    break
                yield img1, img2, img3
        else:
            while True:
                img1, img2, img3 = [], [], []
                for _ in range(self.batch_size):
                    if index >= len(self.test_pairs) - 1:
                        break
                    img1.append(self.dataset[self.test_pairs[index][0]])
                    img2.append(self.dataset[self.test_pairs[index][1]])
                    img3.append(self.dataset[self.test_pairs[index][2]])
                    index += 1
                img1, img2, img3 = np.array(img1), np.array(img2), np.array(img3)
                img1, img2, img3 = img1[..., np.newaxis], img2[..., np.newaxis], img3[..., np.newaxis]
                if index >= len(self.test_pairs) - 1:
                    index = 0
                yield img1, img2, img3

In [5]:
@tf.function
def tripleLoss(anchor, positive, negative, margin=tf.constant([1.0]), size_average=True):
    distance_positive = tf.reduce_sum(tf.pow(anchor - positive, 2), 1)
    distance_negative = tf.reduce_sum(tf.pow(anchor - negative, 2), 1)
    losses = tf.nn.relu(distance_positive - distance_negative + margin)
    if size_average:
        losses = tf.reduce_mean(losses)
    else:
        losses = tf.reduce_sum(losses)
    return losses

In [10]:
class BaseNet(Model):
    def __init__(self):
        super(BaseNet, self).__init__()
        self.relu = ReLU()
        self.pool = MaxPooling2D(2, strides=2)
        self.conv1 = Conv2D(32, 3, padding='same')
        self.conv2 = Conv2D(64, 3, padding='same')
        self.fc1 = Dense(2)
        
    def call(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x_shape = tf.shape(x)
        x = tf.reshape(x, [x_shape[0], -1])
        x = self.fc1(x)
        return x

In [11]:
class TripleNet(Model):
    def __init__(self, base_net):
        super(TripleNet, self).__init__()
        self.base_net = base_net
    
    def call(self, x1, x2, x3):
        output1 = self.base_net(x1)
        output2 = self.base_net(x2)
        output3 = self.base_net(x3)
        return output1, output2, output3
    
    def get_output(self, x):
        return self.base_net(x)
    
base_net = BaseNet()
model = TripleNet(base_net)


In [13]:
lr = 0.0002
EPOCH = 20
BATCH_SIZE = 100
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
train_loss = tf.keras.metrics.Mean(name='train_loss')
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
train_ds = TripleDataset(x_train, y_train, BATCH_SIZE, train=True, shuffle=True)

TypeError: TripleDataset() takes no arguments

In [14]:
@tf.function
def train_step(input1, input2, input3):
    with tf.GradientTape() as tape:
        output1 = model(input1)
        output2 = model(input2)
        output3 = model(input3)
        loss = tripleLoss(output1, output2, output3)
    gradients = tape.gradient(loss, model.trainable_variables)
    train_loss(loss)


In [15]:
for epoch in range(EPOCH):
    for input1, input2, input3 in train_ds.get_item():
        train_step(input1, input2, input3)
        print(
            f"\rEpoch: {epoch+1}/{EPOCH}, "
            f"step: {step+1}/{steps_per_epoch}, "
            f"loss: {train_loss.result():.5f}, "
            f"acc: {train_acc.result():.5f}, "
            f"val_loss: {val_loss.result():.5f}, "
            f"val_acc: {val_acc.result():.5f}",
            end="", flush=True,
        )

NameError: name 'train_ds' is not defined

In [None]:
import matplotlib