# Training of Siamese Networks

In [None]:
import numpy as np
import os

from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D, BatchNormalization
from tensorflow.keras.models import Sequential

from Siamese_model import SiameseTrainer

In [None]:
def create_pairs(x, digit_indices, num_classes, rng=None):
    """Create a balanced dataset of positive and negative pairs for a Siamese network.

    For every class ``d`` and every ``i`` in ``range(n)`` (``n`` defined below),
    two pairs are emitted, in this order:

    * a positive pair ``(x[digit_indices[d][i]], x[digit_indices[d][i + 1]])``
      labelled ``1`` (two consecutive samples of the same class), and
    * a negative pair ``(x[digit_indices[d][i]], x[digit_indices[dn][i]])``
      labelled ``0``, where ``dn`` is a randomly chosen class different from ``d``.

    Args:
        x: Indexable collection of samples (e.g. a NumPy array of images).
        digit_indices: Sequence of length ``num_classes``; ``digit_indices[c]``
            holds the indices into ``x`` of the samples belonging to class ``c``.
        num_classes: Number of classes; must be at least 2 so that negative
            pairs can be formed.
        rng: Optional ``numpy.random.Generator`` used to pick the negative
            class. When ``None`` (the default) the global ``np.random`` state is
            used, exactly as before; pass a seeded generator for reproducible
            pairs.

    Returns:
        ``(pairs, labels)`` as NumPy arrays. When all samples share one shape,
        ``pairs`` has shape ``(2 * n * num_classes, 2, *sample_shape)`` and
        ``labels`` has shape ``(2 * n * num_classes,)`` with alternating ``1``
        (positive) and ``0`` (negative) entries. If some class has fewer than
        two samples, ``n <= 0`` and both arrays are empty.

    Raises:
        ValueError: If ``num_classes < 2``.
    """
    if num_classes < 2:
        raise ValueError(
            f"create_pairs needs at least 2 classes to form negative pairs, got {num_classes}"
        )

    def draw_offset():
        # Offset in [1, num_classes - 1] (high is exclusive), so
        # (d + offset) % num_classes is always a class other than d.
        if rng is None:
            return np.random.randint(1, num_classes)
        return rng.integers(1, num_classes)

    pairs = []
    labels = []

    # n is one less than the size of the smallest class: positive pairs use
    # indices i and i + 1, so stopping at n keeps i + 1 in range for every
    # class and gives every class the same number of pairs.
    n = min(len(digit_indices[d]) for d in range(num_classes)) - 1

    for d in range(num_classes):
        for i in range(n):
            z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
            pairs.append([x[z1], x[z2]])
            dn = (d + draw_offset()) % num_classes
            z1, z2 = digit_indices[d][i], digit_indices[dn][i]
            pairs.append([x[z1], x[z2]])
            labels += [1, 0]
    return np.array(pairs), np.array(labels)

In [None]:
def train(self, x_train, y_train, x_val, y_val, batch_size=32, epochs=10):
    """Build positive/negative pairs with ``create_pairs`` and fit ``self.model`` on them.

    NOTE(review): this is written as a method (it takes ``self``) but is
    defined at module level and is never attached to a class in the visible
    code; the training cell calls ``SiameseTrainer.train`` instead. Confirm
    which implementation is intended.

    Args:
        self: Object providing ``num_classes`` and a compiled ``model`` whose
            ``fit`` accepts a list of two input arrays.
        x_train: Training samples, indexable by integer position.
        y_train: Integer class labels for ``x_train`` (array or list).
        x_val: Validation samples, indexable by integer position.
        y_val: Integer class labels for ``x_val`` (array or list).
        batch_size: Batch size passed to ``self.model.fit``.
        epochs: Number of epochs passed to ``self.model.fit``.

    Returns:
        Whatever ``self.model.fit`` returns (presumably a Keras ``History``).
        It is also stored on ``self.history``.
    """
    # Coerce to arrays: with a plain Python list, ``y == i`` is one bool, so
    # np.where would find no samples and silently produce no pairs at all.
    y_train = np.asarray(y_train)
    y_val = np.asarray(y_val)

    # Per-class sample indices for the training and validation sets.
    digit_indices_train = [np.where(y_train == i)[0] for i in range(self.num_classes)]
    digit_indices_val = [np.where(y_val == i)[0] for i in range(self.num_classes)]

    tr_pairs, tr_y = create_pairs(x_train, digit_indices_train, self.num_classes)
    val_pairs, val_y = create_pairs(x_val, digit_indices_val, self.num_classes)

    # Each pair array is split into the two inputs of the Siamese model.
    self.history = self.model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
                                  batch_size=batch_size,
                                  epochs=epochs,
                                  validation_data=([val_pairs[:, 0], val_pairs[:, 1]], val_y))
    return self.history


In [None]:
# Base CNN architecture (passed to SiameseTrainer as `base_model_func` below)
def simple_cnn(input_shape, num_classes):
    """Return an uncompiled Keras ``Sequential`` convolutional network.

    Architecture: three stages of ``Conv2D(3x3, ReLU)`` followed by
    ``MaxPooling2D(2x2)``, with 32, 64 and 128 filters. These are followed by
    ``Flatten``, two ReLU ``Dense`` layers (1024 and 128 units) and a final
    ``num_classes``-unit softmax ``Dense`` layer.

    NOTE(review): this function is handed to SiameseTrainer as the base
    model; confirm that a softmax class-probability output is the intended
    representation for the Siamese branches.

    Args:
        input_shape: Shape of a single input sample, e.g. ``(96, 96, 1)``.
        num_classes: Number of units in the final softmax layer.
    """
    feature_layers = []
    for stage, n_filters in enumerate((32, 64, 128)):
        # Only the very first layer declares the input shape.
        shape_kwargs = {'input_shape': input_shape} if stage == 0 else {}
        feature_layers.append(Conv2D(n_filters, (3, 3), activation='relu', **shape_kwargs))
        feature_layers.append(MaxPooling2D((2, 2)))

    classifier_layers = [
        Flatten(),
        Dense(1024, activation='relu'),
        Dense(128, activation='relu'),
        Dense(num_classes, activation='softmax'),
    ]
    return Sequential(feature_layers + classifier_layers)

In [None]:
# Shape of one input sample passed to simple_cnn.
input_shape = (96, 96, 1)
# Number of classes (presumably Tanker / Bulk / Container, per the split folder name).
num_classes = 3
epochs = 10

# Build paths with os.path.join instead of "+" string concatenation so the
# separator is correct on every platform.
split_dir = os.path.join(os.getcwd(), "Split_Tanker_Bulk_Container_frugal_vv")
train_path = os.path.join(split_dir, "train.csv")
validation_path = os.path.join(split_dir, "validation.csv")
test_path = os.path.join(split_dir, "test.csv")

image_dir = '../OpenSARShip/Categories/'

In [None]:
# Standalone instance of the base CNN (e.g. for inspection). The SiameseTrainer
# cell below is given the `simple_cnn` function itself, not this object.
model = simple_cnn(input_shape=input_shape, num_classes=num_classes)

In [None]:
# Build the Siamese trainer around the `simple_cnn` base network.
# NOTE(review): x_train, y_train, x_val and y_val are not defined in any
# visible cell. train_path / validation_path / test_path and image_dir are set
# above but never read. Load the dataset into these names before running this cell.
siamese_network = SiameseTrainer(base_model_func=simple_cnn, input_shape=input_shape, num_classes=num_classes)

# Compile with the Adam optimizer and binary cross-entropy loss (no extra metrics).
siamese_network.compile_model(optimizer='adam', loss='binary_crossentropy')

# Train the model, using the `epochs` value from the configuration cell
# rather than a separately hard-coded number.
siamese_network.train(x_train, y_train, x_val, y_val, batch_size=32, epochs=epochs)

# Plot training history
siamese_network.plot_training_history()

# Optionally, evaluate the model on a test set prepared in a similar way
# siamese_network.evaluate_model(x_test, y_test)
