In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
from tqdm import tqdm

from imageio import imread

In [2]:
def load_and_process_data():
    data_path = r"D:\DATA\archive"
    train_folder = os.path.join(data_path, 'images_background')
    val_folder = os.path.join(data_path, 'images_evaluation')
    save_path = 'SN_fsl\SN_fsl2'

    # Load or create dataset
    def loadimgs(path):
        if not os.path.exists(path):
            print(f"Unzipping {os.path.basename(path)}")
            os.chdir(data_path)
            os.system(f"unzip {os.path.basename(path)}.zip")
            os.chdir("..")
            
        X, y = [], []
        for alphabet in os.listdir(path):
            alphabet_path = os.path.join(path, alphabet)
            for letter in os.listdir(alphabet_path):
                letter_path = os.path.join(alphabet_path, letter)
                images = [imread(os.path.join(letter_path, f)) for f in os.listdir(letter_path)]
                X.append(np.array(images))
        return np.array(X)

    # Load and process data
    X_train = loadimgs(train_folder)
    X_val = loadimgs(val_folder)

    # Preprocessing function
    def preprocess(img):
        img = 1.0 - img / 255.0  # Single image processing
        img = np.expand_dims(img, axis=-1)  # Add channel dimension
        return tf.image.resize(img, (28, 28)).numpy()  # Proper 3D input

    # Process and augment with rotations
    def process_and_augment(data):
        processed = []
        for char_class in data:
            for img in char_class:
                processed.append(preprocess(img))
                # Add rotated versions
                for k in range(1, 4):
                    rotated = np.rot90(img, k)
                    processed.append(preprocess(rotated))
        return np.array(processed).reshape(-1, 20*4, 28, 28, 1)

    X_train = process_and_augment(X_train)
    X_val = process_and_augment(X_val)
    
    return X_train, X_val


X_train, X_val = load_and_process_data()

# ----------------------------
# Triplet Generation
# ----------------------------
class TripletGenerator(tf.keras.utils.Sequence):
    def __init__(self, dataset, batch_size=32):
        self.dataset = dataset
        self.batch_size = batch_size
        self.classes = dataset.shape[0]
        self.samples_per_class = dataset.shape[1]
        
    def __len__(self):
        return int(np.ceil(len(self.dataset) / self.batch_size))
    
    def __getitem__(self, idx):
        batch = []
        for _ in range(self.batch_size):
            # Random anchor class
            anchor_class = np.random.randint(self.classes)
            # Random positive sample
            anchor_idx, positive_idx = np.random.choice(self.samples_per_class, 2, replace=False)
            anchor = self.dataset[anchor_class, anchor_idx]
            positive = self.dataset[anchor_class, positive_idx]
            
            # Random negative class
            negative_class = np.random.randint(self.classes)
            while negative_class == anchor_class:
                negative_class = np.random.randint(self.classes)
            negative = self.dataset[negative_class, np.random.randint(self.samples_per_class)]
            
            batch.append((anchor, positive, negative))
        
        anchors, positives, negatives = zip(*batch)
        return [np.array(anchors), np.array(positives), np.array(negatives)], np.zeros(len(batch))

# ----------------------------
# Model Architecture
# ----------------------------
def create_base_network():
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same', 
                              input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.BatchNormalization(),
        
        tf.keras.layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.BatchNormalization(),
        
        tf.keras.layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))
    ])

base_network = create_base_network()

# ----------------------------
# Triplet Loss
# ----------------------------
def triplet_loss(margin=0.5):
    def loss(_, y_pred):
        anchor, positive, negative = y_pred[:, 0], y_pred[:, 1], y_pred[:, 2]
        
        pos_dist = tf.reduce_sum(tf.square(anchor - positive), axis=1)
        neg_dist = tf.reduce_sum(tf.square(anchor - negative), axis=1)
        
        basic_loss = pos_dist - neg_dist + margin
        return tf.reduce_mean(tf.maximum(basic_loss, 0.0))
    return loss

# ----------------------------
# Training Setup
# ----------------------------
def create_siamese_model(base_network):
    anchor_input = tf.keras.Input(shape=(28, 28, 1), name='anchor')
    positive_input = tf.keras.Input(shape=(28, 28, 1), name='positive')
    negative_input = tf.keras.Input(shape=(28, 28, 1), name='negative')
    
    anchor_embedding = base_network(anchor_input)
    positive_embedding = base_network(positive_input)
    negative_embedding = base_network(negative_input)
    
    merged_output = tf.stack([anchor_embedding, positive_embedding, negative_embedding], axis=1)
    model = tf.keras.Model(inputs=[anchor_input, positive_input, negative_input], outputs=merged_output)
    
    return model

siamese_model = create_siamese_model(base_network)
siamese_model.compile(optimizer=tf.keras.optimizers.Adam(0.0001), loss=triplet_loss(0.5))


  images = [imread(os.path.join(letter_path, f)) for f in os.listdir(letter_path)]


In [3]:
train_generator = TripletGenerator(X_train, batch_size=32)
val_generator = TripletGenerator(X_val, batch_size=32)

history = siamese_model.fit(
    train_generator,
    epochs=50,
    validation_data=val_generator,
    callbacks=[
        tf.keras.callbacks.ReduceLROnPlateau(patience=3),
        tf.keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)])

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50


In [5]:
def evaluate_k_shot(encoder, dataset, k_shot=5, n_way=3, test_episodes=1000):
    accuracies = []
    precisions = []
    recalls = []
    f1_scores = []
    
    for _ in range(test_episodes):
        # Episode setup
        classes = np.random.choice(len(dataset), n_way, replace=False)
        support, query = [], []
        true_labels = []
        
        # Sample selection
        for i, cls in enumerate(classes):
            samples = dataset[cls]
            selected = np.random.choice(len(samples), k_shot + 15, replace=False)
            support.extend(samples[selected[:k_shot]])
            query.extend(samples[selected[k_shot:]])
            true_labels.extend([i] * 15)  # 15 query samples per class
            
        # Embedding calculation
        support_emb = encoder.predict(np.array(support), verbose=0)
        query_emb = encoder.predict(np.array(query), verbose=0)
        
        # Prototype calculation
        prototypes = [np.mean(support_emb[i*k_shot:(i+1)*k_shot], axis=0) 
                     for i in range(n_way)]
        
        # Prediction
        preds = []
        for q in query_emb:
            distances = [np.linalg.norm(q - p) for p in prototypes]
            preds.append(np.argmin(distances))
            
        # Convert to numpy arrays
        true_labels = np.array(true_labels)
        preds = np.array(preds)
        
        # Calculate metrics
        tp = np.zeros(n_way)
        fp = np.zeros(n_way)
        fn = np.zeros(n_way)
        
        for cls in range(n_way):
            tp[cls] = np.sum((preds == cls) & (true_labels == cls))
            fp[cls] = np.sum((preds == cls) & (true_labels != cls))
            fn[cls] = np.sum((true_labels == cls) & (preds != cls))
            
        # Avoid division by zero
        precision = np.mean([tp[cls] / (tp[cls] + fp[cls]) if (tp[cls] + fp[cls]) > 0 else 0 
                      for cls in range(n_way)])
        recall = np.mean([tp[cls] / (tp[cls] + fn[cls]) if (tp[cls] + fn[cls]) > 0 else 0 
                     for cls in range(n_way)])
        f1 = np.mean([2 * (precision * recall) / (precision + recall) 
                    if (precision + recall) > 0 else 0 
                    for cls in range(n_way)])
        
        # Store metrics
        accuracies.append(np.mean(preds == true_labels))
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)
    
    return {
        'accuracy': np.mean(accuracies),
        'precision': np.mean(precisions),
        'recall': np.mean(recalls),
        'f1': np.mean(f1_scores)
    }

# Usage
print("\nEvaluating k-shot performance...")
k_results = {}
for k in range(1, 7):
    metrics = evaluate_k_shot(base_network, X_val, k_shot=k)
    k_results[k] = metrics

# Display results
print("\nFinal Results:")
for k, metrics in k_results.items():
    print(f"k={k}:")
    print(f"  Accuracy:   {metrics['accuracy']:.4f}")
    print(f"  Precision:  {metrics['precision']:.4f}")
    print(f"  Recall:     {metrics['recall']:.4f}")
    print(f"  F1-Score:   {metrics['f1']:.4f}")
    print()


Evaluating k-shot performance...

Final Results:
k=1:
  Accuracy:   0.6545
  Precision:  0.6688
  Recall:     0.6545
  F1-Score:   0.6603

k=2:
  Accuracy:   0.6935
  Precision:  0.7067
  Recall:     0.6935
  F1-Score:   0.6995

k=3:
  Accuracy:   0.7044
  Precision:  0.7165
  Recall:     0.7044
  F1-Score:   0.7099

k=4:
  Accuracy:   0.7258
  Precision:  0.7347
  Recall:     0.7258
  F1-Score:   0.7298

k=5:
  Accuracy:   0.7325
  Precision:  0.7414
  Recall:     0.7325
  F1-Score:   0.7366

k=6:
  Accuracy:   0.7430
  Precision:  0.7530
  Recall:     0.7430
  F1-Score:   0.7476

