# Siamese Network Training for Cat Identification

This notebook contains the complete pipeline for training a Siamese network to identify individual cats from images. The process includes:

1.  **Dependency Checking**: Ensuring all required libraries are installed.
2.  **Dataset Analysis**: Exploring the dataset's structure and statistics.
3.  **Data Preparation**: Loading images and creating pairs/triplets for training.
4.  **Model Building**: Defining the Siamese network architecture with a pre-trained base model.
5.  **Training**: Training the model using both Contrastive and Triplet loss functions.
6.  **Evaluation**: Assessing the performance of the trained models.

## 1. Setup and Imports

In [None]:
import os
import sys
import subprocess
import random
from pathlib import Path
from collections import Counter
import json
from itertools import combinations

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')

# Data processing and ML
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import LabelEncoder

# Deep learning
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.applications import EfficientNetB0, VGG16, MobileNetV2

# Image processing
import cv2
from PIL import Image
from tqdm import tqdm

# Set random seeds for reproducibility.
# NumPy, TensorFlow and Python's built-in `random` module each keep their own
# generator state, so all three are seeded with the same fixed value.
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

## 2. Dependency Check

In [None]:
def check_dependencies():
    """Report which of the required third-party packages can be imported.

    Every package gets one status line (check mark or cross). When at least
    one import fails, a ready-to-copy ``pip install`` command is printed too.

    Returns:
        bool: True if every package imported, False otherwise.
    """
    print("Checking dependencies...")

    # (name to pass to pip, name to pass to ``import``) - the two differ for
    # several packages, e.g. scikit-learn is imported as ``sklearn``.
    required = (
        ('tensorflow', 'tensorflow'),
        ('numpy', 'numpy'),
        ('pandas', 'pandas'),
        ('matplotlib', 'matplotlib'),
        ('seaborn', 'seaborn'),
        ('scikit-learn', 'sklearn'),
        ('opencv-python', 'cv2'),
        ('Pillow', 'PIL'),
        ('tqdm', 'tqdm'),
    )

    missing_packages = []
    for pip_name, module_name in required:
        try:
            __import__(module_name)
        except ImportError:
            missing_packages.append(pip_name)
            print(f"✗ {pip_name} - MISSING")
        else:
            print(f"✓ {pip_name}")

    # Happy path first: nothing missing means nothing more to report.
    if not missing_packages:
        print("\nAll dependencies are installed!")
        return True

    print(f"\nMissing packages: {', '.join(missing_packages)}")
    print("Please install them using:")
    print(f"pip install {' '.join(missing_packages)}")
    return False


check_dependencies()

## 3. Dataset Analysis

In [None]:
def analyze_dataset(dataset_path):
    """Gather per-cat and overall image statistics for a dataset directory.

    Only sub-directories of ``dataset_path`` whose name starts with ``cat_``
    are inspected. Inside each one, files ending in .png/.jpg/.jpeg
    (case-insensitive) count as images, and the presence of ``info.json`` is
    recorded.

    Args:
        dataset_path: Directory that contains the ``cat_*`` folders.

    Returns:
        tuple: ``(cat_stats, summary)`` where ``cat_stats`` is a list with one
        dict per cat folder (keys ``cat_id``, ``num_images``, ``has_info``,
        ``extensions``) in sorted folder order, and ``summary`` is a dict of
        dataset-wide totals.
    """
    print("Analyzing dataset structure...")

    image_suffixes = ('.png', '.jpg', '.jpeg')

    cat_folders = sorted(
        entry for entry in os.listdir(dataset_path)
        if entry.startswith('cat_') and os.path.isdir(os.path.join(dataset_path, entry))
    )
    print(f"Found {len(cat_folders)} cat folders")

    cat_stats = []
    image_extensions = Counter()

    for cat_folder in cat_folders:
        cat_path = os.path.join(dataset_path, cat_folder)

        # One lower-cased extension per image file; its length doubles as the
        # image count for this cat.
        extensions = [
            os.path.splitext(name)[1].lower()
            for name in os.listdir(cat_path)
            if name.lower().endswith(image_suffixes)
        ]
        image_extensions.update(extensions)

        cat_stats.append({
            'cat_id': cat_folder,
            'num_images': len(extensions),
            'has_info': os.path.exists(os.path.join(cat_path, 'info.json')),
            'extensions': list(set(extensions))
        })

    total_images = sum(entry['num_images'] for entry in cat_stats)
    cats_with_info = sum(1 for entry in cat_stats if entry['has_info'])

    summary = {
        'total_cats': len(cat_folders),
        'total_images': total_images,
        # Guard against division by zero when no cat folder exists.
        'avg_images_per_cat': total_images / len(cat_folders) if cat_folders else 0,
        'image_extensions': dict(image_extensions),
        'cats_with_info': cats_with_info,
        'cats_without_info': len(cat_stats) - cats_with_info
    }

    return cat_stats, summary

def visualize_dataset_stats(cat_stats, summary):
    """Plot a 2x2 overview of the dataset statistics and save it to disk.

    The figure shows the histogram of images per cat, the 20 cats with the
    most images, the share of each image file type, and how many cats have an
    ``info.json``. It is written to ``dataset_analysis.png`` and displayed.

    Args:
        cat_stats: List of per-cat dicts with at least the keys ``cat_id``
            and ``num_images`` (as returned by ``analyze_dataset``).
        summary: Dict with the keys ``image_extensions`` (extension -> count),
            ``cats_with_info`` and ``cats_without_info``.

    Returns:
        pandas.DataFrame: ``cat_stats`` as a DataFrame. When ``cat_stats`` is
        empty the (empty) frame is returned and no figure is produced.
    """
    print("\nCreating dataset visualizations...")

    # Convert to DataFrame for easier plotting
    df = pd.DataFrame(cat_stats)
    if df.empty:
        # An empty list yields a frame without columns, so every plot below
        # would fail on the missing 'num_images' column.
        print("No cats found - skipping visualizations.")
        return df

    # Create figure with subplots
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Distribution of images per cat
    ax1.hist(df['num_images'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')
    ax1.set_xlabel('Number of Images per Cat')
    ax1.set_ylabel('Number of Cats')
    ax1.set_title('Distribution of Images per Cat')
    ax1.grid(True, alpha=0.3)

    # 2. Image count by cat (top 20)
    top_cats = df.nlargest(20, 'num_images')
    ax2.barh(range(len(top_cats)), top_cats['num_images'], color='lightcoral')
    ax2.set_yticks(range(len(top_cats)))
    # Long folder names are truncated so the tick labels stay readable.
    ax2.set_yticklabels([cat_id[:15] + '...' if len(cat_id) > 15 else cat_id for cat_id in top_cats['cat_id']])
    ax2.set_xlabel('Number of Images')
    ax2.set_title('Top 20 Cats by Image Count')
    ax2.grid(True, alpha=0.3)

    # 3. Image extensions distribution
    extensions = summary['image_extensions']
    if extensions:
        # Materialise the dict views as lists: the pie data is converted to a
        # NumPy array, which needs a real sequence rather than a view object.
        ax3.pie(list(extensions.values()), labels=list(extensions.keys()), autopct='%1.1f%%')
    else:
        # Cat folders exist but none of them contains an image file.
        ax3.text(0.5, 0.5, 'No images found', ha='center', va='center', transform=ax3.transAxes)
        ax3.axis('off')
    ax3.set_title('Distribution of Image File Types')

    # 4. Info.json availability
    info_counts = [summary['cats_with_info'], summary['cats_without_info']]
    info_labels = ['With Info', 'Without Info']
    ax4.bar(info_labels, info_counts, color=['lightgreen', 'lightcoral'])
    ax4.set_ylabel('Number of Cats')
    ax4.set_title('Info.json Availability')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('dataset_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    return df

dataset_path = 'post_processing'  # Changed from 'siamese_dataset'
if os.path.exists(dataset_path):
    cat_stats, summary = analyze_dataset(dataset_path)

    # Assemble the textual report first, then emit it line by line.
    report_lines = [
        "\nDataset Summary:",
        "=" * 50,
        f"Total cats: {summary['total_cats']}",
        f"Total images: {summary['total_images']}",
        f"Average images per cat: {summary['avg_images_per_cat']:.1f}",
        f"Image extensions: {summary['image_extensions']}",
        f"Cats with info.json: {summary['cats_with_info']}",
        f"Cats without info.json: {summary['cats_without_info']}",
    ]
    for report_line in report_lines:
        print(report_line)

    df_stats = visualize_dataset_stats(cat_stats, summary)
else:
    # Nothing to analyse: tell the user which path was looked up.
    print(f"Dataset path '{dataset_path}' not found!")

## 4. Configuration and Data Preparation

In [None]:
# Configuration
# Module-level constants used as defaults by the classes and the main
# execution cell further down in this notebook.

# Side length in pixels; images are resized to IMG_SIZE x IMG_SIZE and the
# model input shape is (IMG_SIZE, IMG_SIZE, 3).
IMG_SIZE = 224
# Default mini-batch size passed to model.fit by SiameseTrainer.train.
BATCH_SIZE = 32
# Default size of the embedding vector produced by the embedding network.
EMBEDDING_DIM = 128
# Default margin of the contrastive loss.
MARGIN = 1.0
# Learning rate handed to the Adam optimizer when the models are compiled.
LEARNING_RATE = 0.001
# Default upper bound on training epochs (early stopping may end sooner).
EPOCHS = 50
# Fraction of the non-test images held out for validation.
VALIDATION_SPLIT = 0.2
# Fraction of all loaded images held out for testing.
TEST_SPLIT = 0.1
# Directory that contains the cat_* image folders.
DATASET_PATH = 'post_processing' # Changed from 'siamese_dataset'

In [None]:
class SiameseDataset:
    """Load cat images into memory and sample training pairs/triplets.

    The dataset directory is expected to contain one ``cat_*`` sub-directory
    per individual cat, each holding that cat's image files.

    Attributes:
        dataset_path: Directory containing the ``cat_*`` folders.
        img_size: Side length (pixels) images are resized to.
        images: Loaded images (empty list until ``load_dataset`` is called,
            afterwards a float32 array).
        labels: Folder name of the cat for every loaded image.
        label_encoder: ``LabelEncoder`` fitted on ``labels`` by
            ``load_dataset``.
        encoded_labels: Integer class id per image (set by ``load_dataset``).
    """

    def __init__(self, dataset_path, img_size=IMG_SIZE):
        self.dataset_path = dataset_path
        self.img_size = img_size
        self.images = []
        self.labels = []
        self.label_encoder = LabelEncoder()

    def load_dataset(self, max_cats=None, min_images_per_cat=5):
        """Read, resize and normalise every image of the selected cats.

        Images are converted from OpenCV's BGR order to RGB, resized to
        ``img_size`` x ``img_size`` and scaled to float32 values in [0, 1].
        The method may be called repeatedly; each call replaces the
        previously loaded data.

        Args:
            max_cats: If truthy, only the first ``max_cats`` folders (in
                sorted order) are loaded.
            min_images_per_cat: Folders with fewer image files than this are
                skipped.

        Returns:
            tuple: ``(images, encoded_labels)`` as NumPy arrays.
        """
        print("Loading dataset...")

        cat_folders = sorted(
            f for f in os.listdir(self.dataset_path)
            if f.startswith('cat_') and os.path.isdir(os.path.join(self.dataset_path, f))
        )

        if max_cats:
            cat_folders = cat_folders[:max_cats]

        print(f"Found {len(cat_folders)} cat folders")

        # Accumulate into local lists. self.images/self.labels become NumPy
        # arrays at the end, so appending to them directly would fail on a
        # second call.
        loaded_images = []
        loaded_labels = []

        for cat_folder in tqdm(cat_folders, desc="Loading cats"):
            cat_path = os.path.join(self.dataset_path, cat_folder)
            # Sorted so the image order does not depend on the file system's
            # directory listing order.
            image_files = sorted(
                f for f in os.listdir(cat_path)
                if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            )

            if len(image_files) < min_images_per_cat:
                print(f"Skipping {cat_folder} - only {len(image_files)} images")
                continue

            for img_file in image_files:
                img_path = os.path.join(cat_path, img_file)
                try:
                    img = cv2.imread(img_path)
                    if img is None:
                        # imread signals an undecodable file by returning None.
                        print(f"Skipping unreadable image {img_path}")
                        continue
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = cv2.resize(img, (self.img_size, self.img_size))
                    img = img.astype(np.float32) / 255.0
                except Exception as e:
                    # Best effort: one broken file must not abort the load.
                    print(f"Error loading {img_path}: {e}")
                    continue
                loaded_images.append(img)
                loaded_labels.append(cat_folder)

        self.images = np.array(loaded_images)
        self.labels = np.array(loaded_labels)
        self.encoded_labels = self.label_encoder.fit_transform(self.labels)

        print(f"Loaded {len(self.images)} images from {len(np.unique(self.labels))} cats")
        print(f"Image shape: {self.images.shape}")

        return self.images, self.encoded_labels

    @staticmethod
    def _group_indices_by_label(labels):
        """Map every distinct label to the list of positions where it occurs."""
        groups = {}
        for idx, label in enumerate(labels):
            groups.setdefault(label, []).append(idx)
        return groups

    @staticmethod
    def _negative_label_pools(label_indices, num_samples):
        """Return, per label, the list of all *other* labels.

        Raises:
            ValueError: If there are samples but fewer than two distinct
                labels, because no negative example can be drawn then.
        """
        unique_labels = sorted(label_indices)
        if num_samples and len(unique_labels) < 2:
            raise ValueError(
                "At least two distinct labels are required to sample negative examples"
            )
        return {
            label: [other for other in unique_labels if other != label]
            for label in unique_labels
        }

    def create_pairs(self, images, labels, num_pairs_per_image=2):
        """Create positive and negative image pairs for contrastive training.

        For every image, ``num_pairs_per_image`` positive pairs (another
        image of the same label) and ``num_pairs_per_image`` negative pairs
        (an image of a uniformly chosen other label) are drawn with Python's
        ``random`` module. Positives are drawn only from the *other* images
        of the class, so the anchor is never paired with itself and the two
        pair types stay balanced. Images whose label occurs only once get
        negative pairs only.

        Args:
            images: Array of images, indexed along the first axis.
            labels: One label per image.
            num_pairs_per_image: Pairs of each type drawn per image.

        Returns:
            tuple: ``(pairs, pair_labels)``. ``pairs`` stacks the two images
            of each pair along axis 1; ``pair_labels`` is 0 for a same-cat
            pair and 1 for a different-cat pair.

        Raises:
            ValueError: If ``labels`` contains fewer than two distinct values.
        """
        print("Creating training pairs...")
        pair_images = []
        pair_labels = []

        # Plain Python lists of ints keep random.choice independent of how
        # it treats NumPy arrays.
        label_list = np.asarray(labels).tolist()
        label_indices = self._group_indices_by_label(label_list)
        other_labels = self._negative_label_pools(label_indices, len(label_list))

        samples = tqdm(zip(images, label_list), total=len(label_list), desc="Creating pairs")
        for i, (current_image, current_label) in enumerate(samples):
            # Positive pairs
            pos_candidates = [idx for idx in label_indices[current_label] if idx != i]
            if pos_candidates:
                for _ in range(num_pairs_per_image):
                    pos_idx = random.choice(pos_candidates)
                    pair_images.append([current_image, images[pos_idx]])
                    pair_labels.append(0)

            # Negative pairs
            for _ in range(num_pairs_per_image):
                neg_label = random.choice(other_labels[current_label])
                neg_idx = random.choice(label_indices[neg_label])
                pair_images.append([current_image, images[neg_idx]])
                pair_labels.append(1)

        return np.array(pair_images), np.array(pair_labels)

    def create_triplets(self, images, labels, num_triplets_per_image=1):
        """Create (anchor, positive, negative) triplets for triplet loss.

        Every image serves as anchor ``num_triplets_per_image`` times. The
        positive is another image of the same label (never the anchor
        itself); the negative is an image of a uniformly chosen other label.
        Anchors whose label occurs only once are skipped because no positive
        exists for them.

        Args:
            images: Array of images, indexed along the first axis.
            labels: One label per image.
            num_triplets_per_image: Triplets drawn per anchor image.

        Returns:
            tuple: ``(anchors, positives, negatives)`` as three arrays of
            equal length.

        Raises:
            ValueError: If ``labels`` contains fewer than two distinct values.
        """
        print("Creating training triplets...")
        anchor_images, positive_images, negative_images = [], [], []

        label_list = np.asarray(labels).tolist()
        label_indices = self._group_indices_by_label(label_list)
        other_labels = self._negative_label_pools(label_indices, len(label_list))

        samples = tqdm(zip(images, label_list), total=len(label_list), desc="Creating triplets")
        for i, (anchor_image, anchor_label) in enumerate(samples):
            pos_candidates = [idx for idx in label_indices[anchor_label] if idx != i]
            if not pos_candidates:
                continue

            for _ in range(num_triplets_per_image):
                pos_idx = random.choice(pos_candidates)
                neg_label = random.choice(other_labels[anchor_label])
                neg_idx = random.choice(label_indices[neg_label])
                anchor_images.append(anchor_image)
                positive_images.append(images[pos_idx])
                negative_images.append(images[neg_idx])

        return (np.array(anchor_images), np.array(positive_images), np.array(negative_images))

## 5. Model Architecture

In [None]:
class SiameseModel:
    """Builder for Siamese networks on top of a frozen ImageNet backbone.

    Input images are assumed to be RGB floats in [0, 1] (the format produced
    by ``SiameseDataset.load_dataset``); the embedding network converts them
    to the value range the selected backbone was trained with.

    Attributes:
        input_shape: Shape of one input image, e.g. ``(224, 224, 3)``.
        embedding_dim: Length of the embedding vector.
        base_model_name: ``'efficientnet'``, ``'vgg'`` or ``'mobilenet'``.
    """

    def __init__(self, input_shape, embedding_dim=EMBEDDING_DIM, base_model='efficientnet'):
        self.input_shape = input_shape
        self.embedding_dim = embedding_dim
        self.base_model_name = base_model

    def _preprocess_for_backbone(self, x):
        """Convert [0, 1] RGB pixels to what the selected backbone expects.

        Each Keras application documents its own input convention, so the
        pixels are first brought back to [0, 255] and then passed through
        the backbone's own ``preprocess_input``.
        """
        preprocess = {
            'efficientnet': tf.keras.applications.efficientnet.preprocess_input,
            'vgg': tf.keras.applications.vgg16.preprocess_input,
            'mobilenet': tf.keras.applications.mobilenet_v2.preprocess_input,
        }[self.base_model_name]
        return preprocess(x * 255.0)

    def create_embedding_model(self):
        """Create the base embedding model.

        The ImageNet-pretrained backbone is frozen; only the dense head on
        top of it is trainable.

        Returns:
            A Keras ``Model`` named ``'embedding_model'`` that maps an image
            to an ``embedding_dim``-dimensional vector (no final activation).

        Raises:
            ValueError: If ``base_model_name`` is not a supported backbone.
        """
        backbones = {
            'efficientnet': EfficientNetB0,
            'vgg': VGG16,
            'mobilenet': MobileNetV2,
        }
        if self.base_model_name not in backbones:
            raise ValueError(f"Unsupported base model: {self.base_model_name}")

        base_model = backbones[self.base_model_name](
            include_top=False, weights='imagenet', input_shape=self.input_shape
        )
        # Freeze the backbone (applies to all of its layers).
        base_model.trainable = False

        inputs = Input(shape=self.input_shape)
        x = tf.keras.layers.Lambda(
            self._preprocess_for_backbone, name='backbone_preprocessing'
        )(inputs)
        x = base_model(x)
        x = Flatten()(x)
        x = Dense(512, activation='relu')(x)
        x = Dropout(0.3)(x)
        x = Dense(256, activation='relu')(x)
        x = Dropout(0.2)(x)
        outputs = Dense(self.embedding_dim, activation=None)(x)

        return Model(inputs, outputs, name='embedding_model')

    def create_siamese_model(self, loss_type='contrastive'):
        """Create the complete, compiled Siamese model.

        A fresh embedding network is built on every call.

        Args:
            loss_type: ``'contrastive'`` (two inputs, outputs a distance) or
                ``'triplet'`` (three inputs, outputs the per-sample loss).

        Raises:
            ValueError: If ``loss_type`` is not supported.
        """
        embedding_model = self.create_embedding_model()

        if loss_type == 'contrastive':
            return self._create_contrastive_model(embedding_model)
        elif loss_type == 'triplet':
            return self._create_triplet_model(embedding_model)
        else:
            raise ValueError(f"Unsupported loss type: {loss_type}")

    def _create_contrastive_model(self, embedding_model):
        """Build the two-input model whose output is the embedding distance."""
        input_a = Input(shape=self.input_shape)
        input_b = Input(shape=self.input_shape)
        # Both branches call the same model object, so they share weights.
        embedding_a = embedding_model(input_a)
        embedding_b = embedding_model(input_b)
        # Raw tf ops are wrapped in a Lambda layer so they are applied as a
        # layer rather than called directly on symbolic Keras tensors.
        distance = tf.keras.layers.Lambda(
            self._euclidean_distance, name='euclidean_distance'
        )([embedding_a, embedding_b])
        model = Model(inputs=[input_a, input_b], outputs=distance)
        model.compile(loss=self._contrastive_loss, optimizer=Adam(learning_rate=LEARNING_RATE), metrics=['accuracy'])
        return model

    def _create_triplet_model(self, embedding_model):
        """Build the three-input model whose output is the triplet loss."""
        anchor_input = Input(shape=self.input_shape)
        positive_input = Input(shape=self.input_shape)
        negative_input = Input(shape=self.input_shape)
        anchor_embedding = embedding_model(anchor_input)
        positive_embedding = embedding_model(positive_input)
        negative_embedding = embedding_model(negative_input)
        loss = tf.keras.layers.Lambda(
            self._triplet_loss, name='triplet_loss'
        )([anchor_embedding, positive_embedding, negative_embedding])
        model = Model(inputs=[anchor_input, positive_input, negative_input], outputs=loss)
        # The network already outputs the loss, so the compiled loss simply
        # averages the prediction and ignores the targets.
        model.compile(loss=self._identity_loss, optimizer=Adam(learning_rate=LEARNING_RATE))
        return model

    def _euclidean_distance(self, vectors):
        """Return the row-wise Euclidean distance, shape ``(batch, 1)``."""
        (feats_a, feats_b) = vectors
        sum_squared = tf.reduce_sum(tf.square(feats_a - feats_b), axis=1, keepdims=True)
        # Clamp before sqrt: the gradient of sqrt is unbounded at zero.
        return tf.sqrt(tf.maximum(sum_squared, tf.keras.backend.epsilon()))

    def _contrastive_loss(self, y_true, y_pred, margin=MARGIN):
        """Contrastive loss on predicted distances.

        ``y_true == 0`` penalises the squared distance (pulls the pair
        together); ``y_true == 1`` penalises distances below ``margin``
        (pushes the pair apart).
        """
        y_true = tf.cast(y_true, y_pred.dtype)
        squared_preds = tf.square(y_pred)
        squared_margin = tf.square(tf.maximum(margin - y_pred, 0))
        loss = tf.reduce_mean((1 - y_true) * squared_preds + y_true * squared_margin)
        return loss

    def _triplet_loss(self, inputs, alpha=0.2):
        """Per-sample triplet loss ``max(d(a,p) - d(a,n) + alpha, 0)``.

        Distances are squared Euclidean distances between embeddings.
        """
        anchor, positive, negative = inputs
        pos_dist = tf.reduce_sum(tf.square(anchor - positive), axis=1)
        neg_dist = tf.reduce_sum(tf.square(anchor - negative), axis=1)
        basic_loss = pos_dist - neg_dist + alpha
        loss = tf.maximum(basic_loss, 0.0)
        return loss

    def _identity_loss(self, y_true, y_pred):
        """Return the mean of ``y_pred``; ``y_true`` is ignored."""
        return tf.reduce_mean(y_pred)

## 6. Training Pipeline

In [None]:
class SiameseTrainer:
    """Train, evaluate and plot a compiled Siamese model.

    Attributes:
        dataset: The dataset object the training data came from.
        model: Compiled Keras Siamese model (contrastive or triplet).
        loss_type: ``'contrastive'`` or ``'triplet'``; selects how the
            training data is fed to the model and names the output files.
        history: Keras ``History`` of the last ``train`` call, or None.
    """

    def __init__(self, dataset, model, loss_type='contrastive'):
        self.dataset = dataset
        self.model = model
        self.loss_type = loss_type
        self.history = None

    def train(self, train_data, val_data, epochs=EPOCHS, batch_size=BATCH_SIZE):
        """Fit the model with checkpointing, early stopping and LR decay.

        Args:
            train_data: For contrastive loss ``(pairs, pair_labels)`` where
                ``pairs[:, 0]`` and ``pairs[:, 1]`` are the two images of
                each pair; for triplet loss
                ``(anchors, positives, negatives)``.
            val_data: Validation data in the same layout as ``train_data``.
            epochs: Maximum number of epochs.
            batch_size: Mini-batch size.

        Returns:
            The Keras ``History`` object (also stored in ``self.history``).

        Raises:
            ValueError: If ``self.loss_type`` is not supported.
        """
        if self.loss_type not in ('contrastive', 'triplet'):
            # Fail loudly instead of silently training nothing.
            raise ValueError(f"Unsupported loss type: {self.loss_type}")

        print(f"\nTraining Siamese model with {self.loss_type} loss...")
        callbacks = [
            ModelCheckpoint(f'best_siamese_{self.loss_type}.h5', monitor='val_loss', save_best_only=True, verbose=1),
            EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1),
            ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7, verbose=1)
        ]

        if self.loss_type == 'contrastive':
            self.history = self.model.fit(
                [train_data[0][:, 0], train_data[0][:, 1]], train_data[1],
                validation_data=([val_data[0][:, 0], val_data[0][:, 1]], val_data[1]),
                batch_size=batch_size, epochs=epochs, callbacks=callbacks, verbose=1
            )
        else:
            # Keras requires targets; all-ones arrays are passed as
            # placeholders of matching length.
            self.history = self.model.fit(
                [train_data[0], train_data[1], train_data[2]], np.ones(len(train_data[0])),
                validation_data=([val_data[0], val_data[1], val_data[2]], np.ones(len(val_data[0]))),
                batch_size=batch_size, epochs=epochs, callbacks=callbacks, verbose=1
            )
        return self.history

    def _get_embedding_model(self):
        """Return the shared embedding sub-model nested in ``self.model``.

        The sub-model is located by type rather than by position because the
        number of input layers preceding it differs between the two-input
        contrastive model and the three-input triplet model.

        Raises:
            ValueError: If ``self.model`` contains no nested ``Model``.
        """
        for layer in self.model.layers:
            if isinstance(layer, Model):
                return layer
        raise ValueError("No embedding sub-model found in the Siamese model")

    def evaluate(self, test_data, test_labels):
        """Score the embeddings with leave-one-out 1-nearest-neighbour.

        Each test image is assigned the label of its closest *other* test
        image in embedding space (Euclidean distance); the assigned labels
        are compared with the true ones.

        Args:
            test_data: Array of test images.
            test_labels: One label per test image.

        Returns:
            dict: ``accuracy``, ``precision``, ``recall`` and ``f1_score``
            (the last three weighted by class support).

        Raises:
            ValueError: If fewer than two test samples are given, since no
                neighbour exists then.
        """
        print("\nEvaluating model...")
        test_labels = np.asarray(test_labels)
        if len(test_labels) < 2:
            raise ValueError("At least two test samples are required for nearest-neighbour evaluation")

        embedding_model = self._get_embedding_model()
        test_embeddings = embedding_model.predict(test_data)

        predictions = []
        for i, embedding in enumerate(test_embeddings):
            # Distance from sample i to every sample, computed in one shot.
            distances = np.linalg.norm(test_embeddings - embedding, axis=1)
            # A sample must not be matched with itself.
            distances[i] = np.inf
            predictions.append(test_labels[np.argmin(distances)])

        accuracy = accuracy_score(test_labels, predictions)
        precision = precision_score(test_labels, predictions, average='weighted')
        recall = recall_score(test_labels, predictions, average='weighted')
        f1 = f1_score(test_labels, predictions, average='weighted')

        print(f"Test Accuracy: {accuracy:.4f}")
        print(f"Test Precision: {precision:.4f}")
        print(f"Test Recall: {recall:.4f}")
        print(f"Test F1-Score: {f1:.4f}")

        return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}

    def plot_training_history(self):
        """Plot loss (and accuracy, if recorded) of the last training run.

        Does nothing when ``train`` has not been called yet. The figure is
        saved as ``training_history_<loss_type>.png`` and displayed.
        """
        if self.history is None:
            return

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        ax1.plot(self.history.history['loss'], label='Training Loss')
        ax1.plot(self.history.history['val_loss'], label='Validation Loss')
        ax1.set_title('Model Loss')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.legend()
        ax1.grid(True)

        if 'accuracy' in self.history.history:
            ax2.plot(self.history.history['accuracy'], label='Training Accuracy')
            ax2.plot(self.history.history['val_accuracy'], label='Validation Accuracy')
            ax2.set_title('Model Accuracy')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Accuracy')
            ax2.legend()
            ax2.grid(True)
        else:
            # No accuracy metric was recorded: hide the unused panel
            # instead of showing an empty set of axes.
            ax2.axis('off')

        plt.tight_layout()
        plt.savefig(f'training_history_{self.loss_type}.png', dpi=300, bbox_inches='tight')
        plt.show()

### Main Execution Block

In [None]:
# Step 1: Data Preparation
# Load at most 20 cats (folders with fewer than 5 image files are skipped)
# and split the images into train / validation / test sets.
print("Step 1: Data Preparation")
dataset = SiameseDataset(DATASET_PATH, img_size=IMG_SIZE)
images, labels = dataset.load_dataset(max_cats=20, min_images_per_cat=5)

# First hold out TEST_SPLIT of all images for testing, then VALIDATION_SPLIT
# of the remainder for validation. Both splits are stratified by cat label
# and use a fixed random_state.
X_temp, X_test, y_temp, y_test = train_test_split(images, labels, test_size=TEST_SPLIT, stratify=labels, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_temp, y_temp, test_size=VALIDATION_SPLIT, stratify=y_temp, random_state=42)

print(f"\nTraining set: {len(X_train)} images")
print(f"Validation set: {len(X_val)} images")
print(f"Test set: {len(X_test)} images")

# Step 2: Model Architecture
# One builder is shared by both experiments; each create_siamese_model call
# below builds its own embedding network.
print("\nStep 2: Model Architecture")
siamese_model_builder = SiameseModel(input_shape=(IMG_SIZE, IMG_SIZE, 3), embedding_dim=EMBEDDING_DIM, base_model='efficientnet')

# Step 3: Training Pipeline
print("\nStep 3: Training Pipeline")

# Train with contrastive loss
# Pairs are sampled separately for the training and the validation images.
contrastive_model = siamese_model_builder.create_siamese_model(loss_type='contrastive')
train_pairs, train_pair_labels = dataset.create_pairs(X_train, y_train)
val_pairs, val_pair_labels = dataset.create_pairs(X_val, y_val)
contrastive_trainer = SiameseTrainer(dataset, contrastive_model, loss_type='contrastive')
contrastive_history = contrastive_trainer.train((train_pairs, train_pair_labels), (val_pairs, val_pair_labels))
contrastive_metrics = contrastive_trainer.evaluate(X_test, y_test)
contrastive_trainer.plot_training_history()

# Train with triplet loss
# Same procedure with (anchor, positive, negative) triplets instead of pairs.
triplet_model = siamese_model_builder.create_siamese_model(loss_type='triplet')
train_triplets = dataset.create_triplets(X_train, y_train)
val_triplets = dataset.create_triplets(X_val, y_val)
triplet_trainer = SiameseTrainer(dataset, triplet_model, loss_type='triplet')
triplet_history = triplet_trainer.train(train_triplets, val_triplets)
triplet_metrics = triplet_trainer.evaluate(X_test, y_test)
triplet_trainer.plot_training_history()

# Save results
# One row per loss type; the metric dicts returned by evaluate() become the
# remaining columns.
results_df = pd.DataFrame([
    {'loss_type': 'contrastive', **contrastive_metrics},
    {'loss_type': 'triplet', **triplet_metrics}
])
results_df.to_csv('siamese_training_results.csv', index=False)
print(f"\nResults saved to siamese_training_results.csv")

# Print summary
print("\nTraining Summary:")
print("=" * 50)
print("Contrastive Loss Results:")
print(results_df[results_df['loss_type'] == 'contrastive'])
print("\nTriplet Loss Results:")
print(results_df[results_df['loss_type'] == 'triplet'])