In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import tensorflow.keras.backend as K
import scipy.ndimage as nd
import random
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow logs
import sys
import subprocess
# The notebook originally ran `pip install setuptools` unconditionally on
# every execution; presumably some dependency needs setuptools at runtime
# (NOTE(review): confirm which). Only invoke pip when it is actually missing,
# so normal runs need no network access and do not mutate the environment.
import importlib.util

if importlib.util.find_spec("setuptools") is None:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "setuptools"])

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

# ==============================
# Task 1: Dataset Preprocessing
# ==============================

def load_and_preprocess_mnist(val_fraction=0.2, random_state=42):
    """Load MNIST, scale pixels to [0, 1] and carve out a validation split.

    The official training images are split into train/validation with a
    stratified split (every digit keeps its share in both parts). The
    official test images are only reshaped and scaled.

    Args:
        val_fraction: Fraction of the official training set held out for
            validation (default 0.2, i.e. an 80/20 train/validation split).
        random_state: Seed for the stratified split.

    Returns:
        ``((x_train, y_train), (x_val, y_val), (x_test, y_test))`` where the
        images are float32 arrays of shape ``(N, 28, 28, 1)`` in [0, 1].
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    def _to_float_nhwc(images):
        # Add the trailing channel axis Conv2D expects, then scale to [0, 1].
        return images.reshape(-1, 28, 28, 1).astype('float32') / 255.0

    x_train = _to_float_nhwc(x_train)
    x_test = _to_float_nhwc(x_test)

    # Hold out a stratified validation split from the official training set.
    x_train, x_val, y_train, y_val = train_test_split(
        x_train, y_train,
        test_size=val_fraction,
        random_state=random_state,
        stratify=y_train,
    )

    return (x_train, y_train), (x_val, y_val), (x_test, y_test)

def elastic_deformation(image, alpha=36, sigma=4):
    """Apply a random elastic deformation to one image.

    A per-pixel displacement field is drawn uniformly from [-1, 1], smoothed
    with a Gaussian of standard deviation ``sigma`` and scaled by ``alpha``.
    The image is then resampled at the displaced coordinates with bilinear
    interpolation and reflective borders.

    Args:
        image: Array of shape (H, W) or (H, W, C). Every channel is warped
            with the same displacement field, so channels stay aligned.
        alpha: Scale applied to the smoothed displacement field (pixels).
        sigma: Standard deviation of the Gaussian that smooths the field.

    Returns:
        The deformed image, with the same shape as ``image``.

    Raises:
        ValueError: if ``image`` is not 2-D or 3-D.
    """
    image = np.asarray(image)
    if image.ndim not in (2, 3):
        raise ValueError(f"expected an (H, W) or (H, W, C) image, got shape {image.shape}")
    shape = image.shape
    height, width = shape[:2]

    # Draw the dx field before the dy field (same order as before) so seeded
    # runs reproduce earlier results exactly.
    dx = nd.gaussian_filter((np.random.rand(height, width) * 2 - 1), sigma) * alpha
    dy = nd.gaussian_filter((np.random.rand(height, width) * 2 - 1), sigma) * alpha

    rows, cols = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    coords = (np.reshape(rows + dx, (-1, 1)), np.reshape(cols + dy, (-1, 1)))

    def _warp(plane):
        # order=1 -> bilinear interpolation; 'reflect' mirrors samples that
        # fall outside the image.
        warped = nd.map_coordinates(plane, coords, order=1, mode='reflect')
        return warped.reshape(height, width)

    if image.ndim == 2:
        return _warp(image)
    # One shared field for all channels (previously C > 1 crashed because the
    # field was drawn with the full 3-D shape).
    return np.stack([_warp(image[:, :, c]) for c in range(shape[2])], axis=-1)

def augment_data(images, labels, n_augmented=5000):
    """Append ``n_augmented`` elastically deformed copies to a dataset.

    Source samples are drawn uniformly with replacement; each copy keeps the
    label of its source image.

    Args:
        images: Array of images indexed along axis 0.
        labels: Labels aligned with ``images``.
        n_augmented: Number of deformed samples to add; 0 adds none.

    Returns:
        ``(combined_images, combined_labels)``: the originals followed by the
        augmented samples (new arrays; the inputs are not modified).
    """
    if n_augmented == 0:
        # Concatenating an empty (0,) array with 4-D images would raise, so
        # return fresh copies, as the concatenating path does.
        return np.array(images, copy=True), np.array(labels, copy=True)

    # Choose which originals receive a deformed copy.
    indices = np.random.choice(len(images), n_augmented, replace=True)

    augmented_images = np.stack([elastic_deformation(images[idx]) for idx in indices])
    augmented_labels = np.asarray(labels)[indices]

    combined_images = np.concatenate([images, augmented_images])
    combined_labels = np.concatenate([labels, augmented_labels])
    return combined_images, combined_labels

# ==============================
# Task 2: Meta-Learning Framework
# ==============================

class MetaLearningDataGenerator:
    """Generate N-way K-shot episodes for meta-learning.

    Each episode samples ``n_way`` distinct classes, then ``k_shot`` support
    and ``q_query`` query examples per class with no overlap between the two.
    Inside an episode, labels are re-indexed to ``0 .. n_way-1`` in the order
    the classes were drawn.
    """

    def __init__(self, x, y, n_way=5, k_shot=5, q_query=5):
        """
        Initialize the meta-learning data generator.

        Args:
            x: Samples, indexed along the first axis.
            y: 1-D array of class labels aligned with ``x``.
            n_way: Number of classes in each episode.
            k_shot: Number of examples per class in the support set.
            q_query: Number of examples per class in the query set.

        Raises:
            ValueError: if ``y`` contains fewer than ``n_way`` distinct classes.
        """
        self.x = x
        self.y = y
        self.n_way = n_way
        self.k_shot = k_shot
        self.q_query = q_query

        labels = np.asarray(y)
        # Derive the class set from the data instead of assuming MNIST's 0..9;
        # for labels 0..9 episodes are drawn exactly as before.
        self.classes = np.unique(labels)
        if n_way > len(self.classes):
            raise ValueError(
                f"n_way={n_way} exceeds the {len(self.classes)} classes present in y"
            )

        # Group sample indices by class once so episodes are cheap to draw.
        self.class_indices = {cls: np.where(labels == cls)[0] for cls in self.classes}

    def generate_episode(self):
        """Draw one episode.

        Returns:
            ``(support_x, support_y_onehot, query_x, query_y_onehot, query_y,
            episode_classes)``. The one-hot arrays are float32 with ``n_way``
            columns; ``query_y`` holds the integer relative labels; relative
            label ``i`` corresponds to original class ``episode_classes[i]``.
        """
        # Sampling from the class array consumes the RNG exactly like
        # np.random.choice(len(self.classes), ...) would.
        episode_classes = np.random.choice(self.classes, self.n_way, replace=False)

        n_per_class = self.k_shot + self.q_query
        support_x = []
        query_x = []
        for cls in episode_classes:
            # Draw support and query together so they never share a sample.
            chosen = np.random.choice(self.class_indices[cls], n_per_class, replace=False)
            support_x.append(self.x[chosen[:self.k_shot]])
            query_x.append(self.x[chosen[self.k_shot:]])

        support_x = np.concatenate(support_x)
        query_x = np.concatenate(query_x)

        # Relative labels 0..n_way-1, grouped by class in draw order.
        support_y = np.repeat(np.arange(self.n_way), self.k_shot)
        query_y = np.repeat(np.arange(self.n_way), self.q_query)

        # float32 one-hot via an identity lookup; the dtype no longer depends
        # on the installed Keras version's to_categorical.
        eye = np.eye(self.n_way, dtype=np.float32)
        support_y_onehot = eye[support_y]
        query_y_onehot = eye[query_y]

        return support_x, support_y_onehot, query_x, query_y_onehot, query_y, episode_classes

# ==============================
# Task 3: Few-Shot Learning Models
# ==============================

def create_embedding_network(input_shape=(28, 28, 1)):
    """Build the convolutional feature extractor shared by the few-shot models.

    Three conv blocks (64 -> 128 -> 256 filters, the first two followed by
    2x2 max-pooling), global average pooling and a 256-unit dense projection,
    so each image maps to a 256-dimensional embedding.

    Args:
        input_shape: Shape (H, W, C) of one input image; defaults to MNIST.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    embedding_network = models.Sequential([
        # An explicit Input layer instead of `input_shape=` on Conv2D avoids
        # the constructor UserWarning that Keras 3 emits (the
        # `super().__init__(...)` line in the run log).
        layers.Input(shape=input_shape),

        layers.Conv2D(64, 3, padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(2),

        layers.Conv2D(128, 3, padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(2),

        layers.Conv2D(256, 3, padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.GlobalAveragePooling2D(),

        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(),
    ])

    return embedding_network

# 3.1 Prototypical Network
class PrototypicalNetwork(tf.keras.Model):
    """Prototypical Network implementation.

    Each class is represented by the mean embedding (prototype) of its
    support examples. Queries are scored by the negative squared Euclidean
    distance to every prototype, and those scores are used as logits.
    """

    def __init__(self, embedding_network):
        super(PrototypicalNetwork, self).__init__()
        self.embedding_network = embedding_network

    def compute_prototypes(self, support_x, support_y, training=None):
        """Return an (n_way, embed_dim) tensor of per-class mean embeddings.

        Args:
            support_x: Support images.
            support_y: One-hot support labels of shape (n_support, n_way).
            training: Forwarded to the embedding network (BatchNorm mode).
        """
        embeddings = self.embedding_network(support_x, training=training)
        support_y = tf.cast(support_y, embeddings.dtype)
        # One matmul sums the embeddings of each class (replacing the
        # per-class scatter loop); dividing by per-class counts gives the means.
        class_sums = tf.matmul(support_y, embeddings, transpose_a=True)
        class_counts = tf.reduce_sum(support_y, axis=0)
        return class_sums / tf.expand_dims(class_counts, axis=1)

    def compute_distances(self, prototypes, query_embeddings):
        """Return (n_queries, n_prototypes) squared Euclidean distances."""
        # expand_dims broadcasting works with dynamic batch sizes, unlike
        # reshapes built from static `.shape[0]` (None inside tf.function).
        diff = tf.expand_dims(query_embeddings, axis=1) - tf.expand_dims(prototypes, axis=0)
        return tf.reduce_sum(tf.square(diff), axis=2)

    def call(self, inputs, training=None):
        """Map ``(support_x, support_y, query_x)`` to (n_query, n_way) logits."""
        support_x, support_y, query_x = inputs

        prototypes = self.compute_prototypes(support_x, support_y, training=training)
        query_embeddings = self.embedding_network(query_x, training=training)

        # Closer prototypes -> larger logits.
        return -self.compute_distances(prototypes, query_embeddings)

def train_prototypical_network(train_data, val_data, n_way=5, k_shot=5, q_query=5,
                              n_episodes=1000, val_episodes=100, lr=0.001, val_every=50):
    """Train a prototypical network episodically and track validation accuracy.

    Args:
        train_data: ``(x_train, y_train)``.
        val_data: ``(x_val, y_val)``.
        n_way, k_shot, q_query: Episode configuration.
        n_episodes: Number of training episodes (one optimizer step each).
        val_episodes: Episodes averaged per validation round.
        lr: Adam learning rate.
        val_every: Validate and print progress every this many episodes
            (default 50); ``0`` or ``None`` disables validation.

    Returns:
        ``(proto_net, train_losses, val_accuracies)``.
    """
    x_train, y_train = train_data
    x_val, y_val = val_data

    train_generator = MetaLearningDataGenerator(x_train, y_train, n_way, k_shot, q_query)
    val_generator = MetaLearningDataGenerator(x_val, y_val, n_way, k_shot, q_query)

    embedding_network = create_embedding_network()
    proto_net = PrototypicalNetwork(embedding_network)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    # The network outputs negative distances, which serve as logits.
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

    train_losses = []
    val_accuracies = []

    for episode in range(n_episodes):
        support_x, support_y, query_x, query_y_onehot, _, _ = train_generator.generate_episode()

        with tf.GradientTape() as tape:
            # training=True lets BatchNormalization normalise with batch
            # statistics and update its moving averages. When a model is
            # called directly without it, Keras runs BN in inference mode.
            # (Takes effect when the model forwards `training` to its encoder.)
            logits = proto_net((support_x, support_y, query_x), training=True)
            loss = loss_fn(query_y_onehot, logits)

        gradients = tape.gradient(loss, proto_net.trainable_variables)
        optimizer.apply_gradients(zip(gradients, proto_net.trainable_variables))

        train_losses.append(loss.numpy())

        if val_every and (episode + 1) % val_every == 0:
            val_accuracies_batch = []

            for _ in range(val_episodes):
                support_x, support_y, query_x, _, query_y, _ = val_generator.generate_episode()
                logits = proto_net((support_x, support_y, query_x), training=False)
                predictions = tf.argmax(logits, axis=1).numpy()
                # Compare in NumPy to avoid int64-vs-int32 clashes in tf.equal.
                val_accuracies_batch.append(np.mean(predictions == query_y))

            mean_val_accuracy = np.mean(val_accuracies_batch)
            val_accuracies.append(mean_val_accuracy)

            print(f"Episode {episode+1}/{n_episodes}, Loss: {np.mean(train_losses[-val_every:]):.4f}, "
                  f"Validation Accuracy: {mean_val_accuracy:.4f}")

    return proto_net, train_losses, val_accuracies

# 3.2 Siamese Network
class SiameseNetwork(tf.keras.Model):
    """Twin-tower similarity model: shared encoder plus a learned L1 comparator.

    Both images pass through the same ``embedding_network``. A single sigmoid
    unit applied to the element-wise absolute difference of the two
    embeddings scores how likely the pair is to share a class.
    """

    def __init__(self, embedding_network):
        super().__init__()
        self.embedding_network = embedding_network
        # Built once here so its weights are tracked and reused on every call.
        self.classifier = layers.Dense(1, activation='sigmoid')

    def call(self, inputs):
        """Return a (batch, 1) same-class probability for a pair of image batches."""
        first, second = inputs
        encode = self.embedding_network
        return self.classifier(tf.abs(encode(first) - encode(second)))

def contrastive_loss(y_true, y_pred, margin=1.0):
    """Mean contrastive loss, treating ``y_pred`` as a pair distance ``d``.

    Pairs labelled 1 (same class) are penalised by ``d**2``. Pairs labelled 0
    are penalised by ``max(margin - d, 0)**2``, which pushes them at least
    ``margin`` apart. The training functions shown here use binary
    cross-entropy and do not call this loss.
    """
    same = tf.cast(y_true, tf.float32)
    pull = tf.square(y_pred)                             # similar pairs
    push = tf.square(tf.maximum(margin - y_pred, 0))     # dissimilar pairs
    return tf.reduce_mean(same * pull + (1 - same) * push)

def create_siamese_pairs(x, y, n_pairs=10000):
    """Create balanced positive/negative image pairs for Siamese training.

    The first ``n_pairs // 2`` pairs share a class (label 1); the next
    ``n_pairs // 2`` pairs come from two different classes (label 0).

    Args:
        x: Array of images indexed along axis 0.
        y: 1-D array of class labels aligned with ``x``.
        n_pairs: Requested number of pairs (an odd value yields n_pairs - 1).

    Returns:
        ``(pairs, labels)``: ``pairs`` has shape (P, 2, *x.shape[1:]) and
        ``labels`` has shape (P,).

    Raises:
        ValueError: if no class has two samples (positives impossible) or
            there are fewer than two classes (negatives impossible).
    """
    y = np.asarray(y)
    classes = np.unique(y)
    # Look up each class's indices once instead of re-scanning y for every
    # pair (previously O(n_pairs * len(y))).
    by_class = {cls: np.where(y == cls)[0] for cls in classes}
    half = n_pairs // 2

    pairs = []
    labels = []

    # Positive pairs (same class). With labels 0..9 the random draws match
    # the previous randint(0, 10) / choice(10, ...) stream exactly.
    positive_classes = [cls for cls in classes if len(by_class[cls]) >= 2]
    if half > 0 and not positive_classes:
        raise ValueError("need at least one class with two samples to build positive pairs")
    for _ in range(half):
        cls = positive_classes[np.random.randint(0, len(positive_classes))]
        idx1, idx2 = np.random.choice(by_class[cls], 2, replace=False)
        pairs.append([x[idx1], x[idx2]])
        labels.append(1)

    # Negative pairs (different classes).
    if half > 0 and len(classes) < 2:
        raise ValueError("need at least two classes to build negative pairs")
    for _ in range(half):
        pos1, pos2 = np.random.choice(len(classes), 2, replace=False)
        idx1 = np.random.choice(by_class[classes[pos1]])
        idx2 = np.random.choice(by_class[classes[pos2]])
        pairs.append([x[idx1], x[idx2]])
        labels.append(0)

    return np.array(pairs), np.array(labels)

def train_siamese_network(train_data, val_data, n_pairs=10000, val_pairs=1000,
                         batch_size=64, epochs=10, lr=0.001):
    """Fit a SiameseNetwork on randomly generated same/different image pairs.

    Args:
        train_data: ``(x_train, y_train)`` used to build ``n_pairs`` pairs.
        val_data: ``(x_val, y_val)`` used to build ``val_pairs`` pairs.
        batch_size, epochs, lr: Passed to Keras ``fit`` / Adam.

    Returns:
        ``(siamese_net, history)``: the trained subclassed model and the
        Keras ``History`` returned by ``fit``.
    """
    x_train, y_train = train_data
    x_val, y_val = val_data

    # `val_pairs` is a pair *count*; the generated arrays get distinct names
    # so the parameter is not shadowed.
    train_pair_arr, train_pair_labels = create_siamese_pairs(x_train, y_train, n_pairs)
    val_pair_arr, val_pair_labels = create_siamese_pairs(x_val, y_val, val_pairs)

    siamese_net = SiameseNetwork(create_embedding_network())

    # Wrap the subclassed model in a functional graph with explicit inputs.
    left_in = layers.Input((28, 28, 1))
    right_in = layers.Input((28, 28, 1))
    model = tf.keras.Model(inputs=[left_in, right_in], outputs=siamese_net([left_in, right_in]))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=['accuracy']
    )

    def _as_two_inputs(pair_arr):
        # (P, 2, H, W, C) -> [left images, right images]
        return [pair_arr[:, 0], pair_arr[:, 1]]

    history = model.fit(
        _as_two_inputs(train_pair_arr), train_pair_labels,
        validation_data=(_as_two_inputs(val_pair_arr), val_pair_labels),
        batch_size=batch_size,
        epochs=epochs,
        verbose=1
    )

    return siamese_net, history

def evaluate_siamese_network(siamese_net, test_data, n_way=5, k_shot=5, q_query=15, n_episodes=100):
    """Mean N-way K-shot accuracy of a Siamese encoder using 1-nearest-neighbour.

    In each episode, every query gets the class of its closest support
    example under squared Euclidean distance in the embedding space of
    ``siamese_net.embedding_network``.

    Returns:
        Accuracy averaged over ``n_episodes`` episodes.
    """
    x_test, y_test = test_data
    test_generator = MetaLearningDataGenerator(x_test, y_test, n_way, k_shot, q_query)
    encoder = siamese_net.embedding_network

    accuracies = []

    for _ in range(n_episodes):
        support_x, support_y_onehot, query_x, _, query_y, _ = test_generator.generate_episode()

        # verbose=0: predict() otherwise prints a progress bar on every call.
        support_embeddings = encoder.predict(support_x, verbose=0)
        query_embeddings = encoder.predict(query_x, verbose=0)

        # (n_query, n_support) squared distances in one broadcasted op.
        distances = np.sum(
            np.square(query_embeddings[:, None, :] - support_embeddings[None, :, :]), axis=2
        )
        nearest = np.argmin(distances, axis=1)
        predictions = np.argmax(support_y_onehot[nearest], axis=1)

        accuracies.append(np.mean(predictions == query_y))

    return np.mean(accuracies)

# ==============================
# Task 4: One-Shot Learning Models
# ==============================

# 4.1 Matching Network
class MatchingNetwork(tf.keras.Model):
    """Matching Network implementation (without full-context embeddings).

    Queries attend over the support set with a softmax of cosine
    similarities. The attention-weighted sum of the one-hot support labels
    gives a class distribution per query, and its log is returned as logits.
    """

    def __init__(self, embedding_network):
        super(MatchingNetwork, self).__init__()
        self.embedding_network = embedding_network

    def call(self, inputs, training=None):
        """Map ``(support_x, support_y, query_x)`` to (n_query, n_way) logits.

        ``support_y`` is the one-hot support label matrix (n_support, n_way).
        ``training`` is forwarded to the embedding network (BatchNorm mode).
        """
        support_x, support_y, query_x = inputs

        support_embeddings = self.embedding_network(support_x, training=training)
        query_embeddings = self.embedding_network(query_x, training=training)

        # Cosine similarity between every query and every support example.
        support_embeddings_norm = tf.nn.l2_normalize(support_embeddings, axis=1)
        query_embeddings_norm = tf.nn.l2_normalize(query_embeddings, axis=1)
        similarities = tf.matmul(query_embeddings_norm, support_embeddings_norm, transpose_b=True)

        attention = tf.nn.softmax(similarities, axis=1)

        # The attention-weighted label vote is already a probability
        # distribution (rows sum to 1). Returning it raw to a loss with
        # from_logits=True would apply a second softmax and cap the achievable
        # loss and gradients. Log-probabilities are valid logits
        # (softmax(log p) == p), and argmax predictions are unchanged.
        probabilities = tf.matmul(attention, tf.cast(support_y, attention.dtype))
        return tf.math.log(probabilities + 1e-8)

def train_matching_network(train_data, val_data, n_way=5, k_shot=1, q_query=5,
                          n_episodes=1000, val_episodes=100, lr=0.001, val_every=50):
    """Train a matching network episodically for one-shot learning.

    Args:
        train_data: ``(x_train, y_train)``.
        val_data: ``(x_val, y_val)``.
        n_way, k_shot, q_query: Episode configuration.
        n_episodes: Number of training episodes (one optimizer step each).
        val_episodes: Episodes averaged per validation round.
        lr: Adam learning rate.
        val_every: Validate and print progress every this many episodes
            (default 50); ``0`` or ``None`` disables validation.

    Returns:
        ``(matching_net, train_losses, val_accuracies)``.
    """
    x_train, y_train = train_data
    x_val, y_val = val_data

    train_generator = MetaLearningDataGenerator(x_train, y_train, n_way, k_shot, q_query)
    val_generator = MetaLearningDataGenerator(x_val, y_val, n_way, k_shot, q_query)

    embedding_network = create_embedding_network()
    matching_net = MatchingNetwork(embedding_network)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    # The network's output is treated as logits.
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

    train_losses = []
    val_accuracies = []

    for episode in range(n_episodes):
        support_x, support_y, query_x, query_y_onehot, _, _ = train_generator.generate_episode()

        with tf.GradientTape() as tape:
            # training=True lets BatchNormalization normalise with batch
            # statistics and update its moving averages. When a model is
            # called directly without it, Keras runs BN in inference mode.
            # (Takes effect when the model forwards `training` to its encoder.)
            logits = matching_net((support_x, support_y, query_x), training=True)
            loss = loss_fn(query_y_onehot, logits)

        gradients = tape.gradient(loss, matching_net.trainable_variables)
        optimizer.apply_gradients(zip(gradients, matching_net.trainable_variables))

        train_losses.append(loss.numpy())

        if val_every and (episode + 1) % val_every == 0:
            val_accuracies_batch = []

            for _ in range(val_episodes):
                support_x, support_y, query_x, _, query_y, _ = val_generator.generate_episode()
                logits = matching_net((support_x, support_y, query_x), training=False)
                predictions = tf.argmax(logits, axis=1).numpy()
                # Compare in NumPy to avoid int64-vs-int32 clashes in tf.equal.
                val_accuracies_batch.append(np.mean(predictions == query_y))

            mean_val_accuracy = np.mean(val_accuracies_batch)
            val_accuracies.append(mean_val_accuracy)

            print(f"Episode {episode+1}/{n_episodes}, Loss: {np.mean(train_losses[-val_every:]):.4f}, "
                  f"Validation Accuracy: {mean_val_accuracy:.4f}")

    return matching_net, train_losses, val_accuracies

# 4.2 Siamese Network for One-Shot Learning
def _build_oneshot_siamese_model(embedding_network, input_shape=(28, 28, 1)):
    """Wrap a shared encoder in a two-input functional model scoring pair similarity."""
    input1 = layers.Input(input_shape)
    input2 = layers.Input(input_shape)

    # Calling the same encoder on both inputs shares its weights.
    embedding1 = embedding_network(input1)
    embedding2 = embedding_network(input2)

    # Element-wise L1 distance between the two embeddings.
    l1_distance = layers.Lambda(lambda tensors: tf.abs(tensors[0] - tensors[1]))([embedding1, embedding2])

    prediction = layers.Dense(1, activation='sigmoid')(l1_distance)
    return tf.keras.Model(inputs=[input1, input2], outputs=prediction)


def _sample_digit_pairs(x, digit_indices, n_per_kind):
    """Return (pairs, labels): n_per_kind same-digit pairs (label 1), then n_per_kind different-digit pairs (label 0)."""
    n_classes = len(digit_indices)
    pairs = []
    labels = []

    # Positive pairs: two distinct samples of one random digit.
    for _ in range(n_per_kind):
        digit = np.random.randint(0, n_classes)
        idx1, idx2 = np.random.choice(digit_indices[digit], 2, replace=False)
        pairs.append([x[idx1], x[idx2]])
        labels.append(1)

    # Negative pairs: one sample each from two different random digits
    # (uniformly random, not mined hard negatives).
    for _ in range(n_per_kind):
        digit1, digit2 = np.random.choice(n_classes, 2, replace=False)
        idx1 = np.random.choice(digit_indices[digit1])
        idx2 = np.random.choice(digit_indices[digit2])
        pairs.append([x[idx1], x[idx2]])
        labels.append(0)

    return np.array(pairs), np.array(labels)


def _oneshot_nn_accuracy(embedding_network, episode):
    """Accuracy of 1-nearest-neighbour (L1 distance) classification for one generated episode."""
    support_x, support_y_onehot, query_x, _, query_y, _ = episode
    # verbose=0: predict() otherwise prints a progress bar on every call.
    support_embeddings = embedding_network.predict(support_x, verbose=0)
    query_embeddings = embedding_network.predict(query_x, verbose=0)
    distances = np.sum(
        np.abs(query_embeddings[:, None, :] - support_embeddings[None, :, :]), axis=2
    )
    predicted = np.argmax(support_y_onehot[np.argmin(distances, axis=1)], axis=1)
    return np.mean(predicted == query_y)


def train_oneshot_siamese_network(train_data, val_data, n_way=5, k_shot=1, q_query=5,
                                 n_episodes=1000, val_episodes=100, lr=0.0001, epochs=20):
    """Train a Siamese encoder on pairs and track N-way one-shot validation accuracy.

    Each epoch draws a fresh set of training pairs (half same-digit, half
    different-digit), fits the binary pair classifier for one pass, then
    reports the mean 1-nearest-neighbour (L1) accuracy over ``val_episodes``
    validation episodes.

    Args:
        train_data: ``(x_train, y_train)``; labels are assumed to be MNIST
            digits 0..9.
        val_data: ``(x_val, y_val)``.
        n_way, k_shot, q_query: Validation episode configuration.
        n_episodes: Number of training pairs generated per epoch (name kept
            for backward compatibility).
        val_episodes: Validation episodes per epoch.
        lr: Adam learning rate.
        epochs: Number of generate/fit/validate rounds (default 20).

    Returns:
        The trained shared embedding network.
    """
    x_train, y_train = train_data
    x_val, y_val = val_data

    embedding_network = create_embedding_network()
    siamese_model = _build_oneshot_siamese_model(embedding_network)
    siamese_model.compile(
        optimizer=tf.keras.optimizers.Adam(lr),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    # Index samples by digit once. Re-running np.where for every pair made
    # pair generation O(pairs * len(y_train)).
    n_classes = 10  # MNIST classes
    digit_indices = [np.where(y_train == digit)[0] for digit in range(n_classes)]

    # The generator draws no random numbers on construction, so one instance
    # serves every epoch.
    val_generator = MetaLearningDataGenerator(x_val, y_val, n_way, k_shot, q_query)

    batch_size = 32
    for epoch in range(epochs):
        pairs, labels = _sample_digit_pairs(x_train, digit_indices, n_episodes // 2)

        history = siamese_model.fit(
            [pairs[:, 0], pairs[:, 1]],
            labels,
            batch_size=batch_size,
            epochs=1,
            verbose=0
        )

        val_accuracies = [
            _oneshot_nn_accuracy(embedding_network, val_generator.generate_episode())
            for _ in range(val_episodes)
        ]

        print(f"Epoch {epoch+1}/{epochs}, Loss: {history.history['loss'][0]:.4f}, "
              f"One-shot Val Accuracy: {np.mean(val_accuracies):.4f}")

    return embedding_network

# ==============================
# Task 5: Performance Evaluation and Analysis
# ==============================

def evaluate_model(model, test_data, model_type, n_way=5, k_shot=1, n_episodes=100, q_query=15):
    """Evaluate a meta-learning model over test episodes and return pooled metrics.

    Args:
        model: For ``'prototypical'`` / ``'matching'``, a model called as
            ``model((support_x, support_y, query_x))`` that returns logits.
            For ``'siamese'``, an embedding network exposing ``predict``,
            evaluated by L1 nearest neighbour.
        test_data: ``(x_test, y_test)``.
        model_type: ``'prototypical'``, ``'matching'`` or ``'siamese'``.
        n_way, k_shot: Episode configuration.
        n_episodes: Number of episodes pooled into the metrics.
        q_query: Query examples per class (default 15).

    Returns:
        Dict with ``accuracy`` and macro-averaged ``precision``, ``recall`` and
        ``f1``. Labels are the episode-relative indices 0..n_way-1, so the
        macro averages are over episode slots, not over digit identities.

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    if model_type not in ('prototypical', 'matching', 'siamese'):
        raise ValueError(
            f"model_type must be 'prototypical', 'matching' or 'siamese', got {model_type!r}"
        )

    x_test, y_test = test_data
    test_generator = MetaLearningDataGenerator(x_test, y_test, n_way, k_shot, q_query)

    y_true_all = []
    y_pred_all = []

    for _ in range(n_episodes):
        support_x, support_y, query_x, _, query_y, _ = test_generator.generate_episode()

        if model_type == 'siamese':
            # `model` is a bare embedding network: give each query the class
            # of its L1-nearest support example. verbose=0 silences predict()'s
            # per-call progress bar.
            support_embeddings = model.predict(support_x, verbose=0)
            query_embeddings = model.predict(query_x, verbose=0)
            distances = np.sum(
                np.abs(query_embeddings[:, None, :] - support_embeddings[None, :, :]), axis=2
            )
            predictions = np.argmax(support_y[np.argmin(distances, axis=1)], axis=1)
        else:
            # Prototypical and matching networks both map an episode to logits.
            logits = model((support_x, support_y, query_x))
            predictions = tf.argmax(logits, axis=1).numpy()

        y_true_all.extend(query_y)
        y_pred_all.extend(predictions)

    y_true_all = np.array(y_true_all)
    y_pred_all = np.array(y_pred_all)

    # zero_division=0 keeps sklearn's default value for never-predicted
    # classes while silencing its UndefinedMetricWarning.
    return {
        'accuracy': accuracy_score(y_true_all, y_pred_all),
        'precision': precision_score(y_true_all, y_pred_all, average='macro', zero_division=0),
        'recall': recall_score(y_true_all, y_pred_all, average='macro', zero_division=0),
        'f1': f1_score(y_true_all, y_pred_all, average='macro', zero_division=0),
    }

def compare_models(proto_metrics, matching_metrics, siamese_metrics):
    """Print a metrics table for the three models and save an accuracy bar chart.

    Each argument is a dict with 'accuracy', 'precision', 'recall' and 'f1'
    keys. The chart is written to ``model_comparison.png`` in the current
    working directory.
    """
    # Local name avoids shadowing the imported `models` module.
    model_names = ['Prototypical', 'Matching', 'Siamese']
    rows = list(zip(model_names, [proto_metrics, matching_metrics, siamese_metrics]))
    double_rule = "=" * 60

    print("\nModel Comparison:")
    print(double_rule)
    print(f"{'Model':<15} {'Accuracy':<10} {'Precision':<10} {'Recall':<10} {'F1 Score':<10}")
    print("-" * 60)
    for label, scores in rows:
        print(f"{label:<15} {scores['accuracy']:.4f}     {scores['precision']:.4f}     "
              f"{scores['recall']:.4f}     {scores['f1']:.4f}")
    print(double_rule)

    accuracies = [scores['accuracy'] for _, scores in rows]
    plt.figure(figsize=(10, 6))
    plt.bar(model_names, accuracies)
    plt.xlabel('Model')
    plt.ylabel('Accuracy')
    plt.title('Accuracy Comparison of Meta-Learning Models')
    plt.ylim(0, 1.0)
    plt.savefig('model_comparison.png')
    plt.close()

def analyze_shot_impact(model, test_data, model_type, n_way=5, shots_range=(1, 2, 5, 10),
                        n_episodes=50, q_query=15):
    """Measure mean episode accuracy for each support-set size K and plot it.

    Args:
        model: Same conventions as ``evaluate_model`` (logit model, or an
            embedding network for ``'siamese'``).
        test_data: ``(x_test, y_test)``.
        model_type: ``'prototypical'``, ``'matching'`` or ``'siamese'``.
        n_way: Classes per episode.
        shots_range: Iterable of K values to test (tuple default avoids the
            shared-mutable-default pitfall; lists are still accepted).
        n_episodes: Episodes averaged per K.
        q_query: Query examples per class (default 15).

    Returns:
        List of mean accuracies, one per entry of ``shots_range``. A plot is
        saved to ``{model_type}_shot_impact.png``.

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    if model_type not in ('prototypical', 'matching', 'siamese'):
        raise ValueError(
            f"model_type must be 'prototypical', 'matching' or 'siamese', got {model_type!r}"
        )
    shots_range = list(shots_range)
    accuracies = []

    for k_shot in shots_range:
        test_generator = MetaLearningDataGenerator(test_data[0], test_data[1], n_way, k_shot, q_query)

        episode_accuracies = []
        for _ in range(n_episodes):
            support_x, support_y, query_x, _, query_y, _ = test_generator.generate_episode()

            if model_type == 'siamese':
                # Embedding network: L1 nearest support example decides the class.
                support_embeddings = model.predict(support_x, verbose=0)
                query_embeddings = model.predict(query_x, verbose=0)
                distances = np.sum(
                    np.abs(query_embeddings[:, None, :] - support_embeddings[None, :, :]), axis=2
                )
                predictions = np.argmax(support_y[np.argmin(distances, axis=1)], axis=1)
            else:
                logits = model((support_x, support_y, query_x))
                predictions = tf.argmax(logits, axis=1).numpy()

            episode_accuracies.append(np.mean(predictions == query_y))

        mean_accuracy = np.mean(episode_accuracies)
        accuracies.append(mean_accuracy)
        print(f"{k_shot}-shot accuracy: {mean_accuracy:.4f}")

    plt.figure(figsize=(10, 6))
    plt.plot(shots_range, accuracies, marker='o')
    plt.xlabel('Number of Shots (K)')
    plt.ylabel('Accuracy')
    plt.title(f'Impact of Shot Count on {model_type.capitalize()} Network Performance')
    plt.xticks(shots_range)
    plt.grid(True)
    plt.savefig(f'{model_type}_shot_impact.png')
    plt.close()

    return accuracies

# ==============================
# Main Execution
# ==============================

def main():
    """Run the whole pipeline: data prep, training, evaluation and analysis.

    The models run, in order:
    - a 5-way 5-shot Prototypical Network;
    - a pair-trained Siamese network, evaluated 5-shot;
    - a 5-way 1-shot Matching Network;
    - a one-shot Siamese encoder.

    It then reports test metrics, saves a comparison chart and analyses the
    Prototypical Network's accuracy as a function of shot count.
    """
    print("Loading and preprocessing MNIST dataset...")
    (x_train, y_train), val_split, test_split = load_and_preprocess_mnist()

    print("Augmenting training data with elastic deformations...")
    train_split = augment_data(x_train, y_train)

    print(f"Dataset sizes: Train: {train_split[0].shape}, Validation: {val_split[0].shape}, Test: {test_split[0].shape}")

    # --- Few-shot: Prototypical Network -------------------------------------
    print("\n===== Training Prototypical Network (5-way, 5-shot) =====")
    proto_net, _proto_losses, _proto_val_accs = train_prototypical_network(
        train_split, val_split, n_way=5, k_shot=5, n_episodes=500
    )

    # --- Few-shot: Siamese Network ------------------------------------------
    print("\n===== Training Siamese Network for Few-Shot Learning =====")
    siamese_net, _siamese_history = train_siamese_network(
        train_split, val_split, n_pairs=10000, epochs=5
    )

    print("\n===== Evaluating Siamese Network (5-way, 5-shot) =====")
    few_shot_siamese_acc = evaluate_siamese_network(
        siamese_net, test_split, n_way=5, k_shot=5, n_episodes=50
    )
    print(f"Siamese Network Few-Shot Test Accuracy: {few_shot_siamese_acc:.4f}")

    # --- One-shot: Matching Network and Siamese encoder ---------------------
    print("\n===== Training Matching Network (5-way, 1-shot) =====")
    matching_net, _matching_losses, _matching_val_accs = train_matching_network(
        train_split, val_split, n_way=5, k_shot=1, n_episodes=500
    )

    print("\n===== Training Siamese Network for One-Shot Learning =====")
    oneshot_siamese_embedding = train_oneshot_siamese_network(
        train_split, val_split, n_way=5, k_shot=1, n_episodes=500
    )

    # --- Test-set evaluation (5-way) ----------------------------------------
    print("\n===== Evaluating Models on Test Set =====")
    evaluation_plan = [
        ('prototypical', proto_net, 5, "Prototypical Network (5-way, 5-shot)"),
        ('matching', matching_net, 1, "Matching Network (5-way, 1-shot)"),
        ('siamese', oneshot_siamese_embedding, 1, "Siamese Network (5-way, 1-shot)"),
    ]
    results = {}
    for model_type, model, shots, title in evaluation_plan:
        results[model_type] = evaluate_model(
            model, test_split, model_type=model_type, n_way=5, k_shot=shots
        )
        print(f"{title} - Accuracy: {results[model_type]['accuracy']:.4f}")

    compare_models(results['prototypical'], results['matching'], results['siamese'])

    # --- Shot-count analysis ------------------------------------------------
    print("\n===== Analyzing Impact of Shot Count on Prototypical Network =====")
    analyze_shot_impact(
        proto_net, test_split, model_type='prototypical', shots_range=[1, 2, 5, 10]
    )

    closing_notes = (
        "1. One-shot learning is challenging due to limited data, showing lower accuracy.",
        "2. Prototypical networks perform better with more shots due to better prototype estimation.",
        "3. Future improvements could include:",
        "   - Deeper embedding networks",
        "   - Attention mechanisms for better feature extraction",
        "   - Meta-optimizers like MAML",
        "   - Task-specific adaptation",
        "   - More sophisticated data augmentation techniques",
    )
    print("\n===== Analysis of challenges and improvements =====")
    for note in closing_notes:
        print(note)

if __name__ == "__main__":
    main()

Loading and preprocessing MNIST dataset...
Augmenting training data with elastic deformations...
Dataset sizes: Train: (53000, 28, 28, 1), Validation: (12000, 28, 28, 1), Test: (10000, 28, 28, 1)

===== Training Prototypical Network (5-way, 5-shot) =====


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Episode 50/500, Loss: 1.3625, Validation Accuracy: 0.4764
Episode 100/500, Loss: 1.0014, Validation Accuracy: 0.7896
Episode 150/500, Loss: 0.5492, Validation Accuracy: 0.8440
Episode 200/500, Loss: 0.3844, Validation Accuracy: 0.8744
Episode 250/500, Loss: 0.3531, Validation Accuracy: 0.9136
Episode 300/500, Loss: 0.2661, Validation Accuracy: 0.9304
Episode 350/500, Loss: 0.2502, Validation Accuracy: 0.9304
Episode 400/500, Loss: 0.2606, Validation Accuracy: 0.9440
Episode 450/500, Loss: 0.1600, Validation Accuracy: 0.9412
Episode 500/500, Loss: 0.2012, Validation Accuracy: 0.9608

===== Training Siamese Network for Few-Shot Learning =====
Epoch 1/5
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 95ms/step - accuracy: 0.6871 - loss: 0.6046 - val_accuracy: 0.5760 - val_loss: 0.6601
Epoch 2/5
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 95ms/step - accuracy: 0.8782 - loss: 0.3135 - val_accuracy: 0.6790 - val_loss: 0.5979
Epoch 3/5
[1m157/157[0m