<a href="https://colab.research.google.com/github/dorobat-diana/LicentaAi/blob/main/OneShotSeamese.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Lambda, Dense, Dropout, BatchNormalization
import numpy as np
import os
import random
import matplotlib.pyplot as plt
from google.colab import drive

# Mount Google Drive so the model and data paths below resolve. This is best
# effort: if Drive is genuinely unavailable, the later file accesses fail.
try:
    drive.mount('/content/drive')
except Exception as e:  # not a bare `except:` so Ctrl-C / SystemExit still propagate
    print(f"Google Drive already mounted or mount failed: {e}")

# Saved Siamese model whose weights are loaded into the rebuilt graph.
SIAMESE_MODEL_PATH = '/content/drive/MyDrive/ColabNotebooks/results/siamese_mobilenetv2_landmark_model_phase2_toplayers.keras'

# Base MobileNetV2 model used only to recreate the feature-extractor
# architecture; the output of FEATURE_LAYER_NAME_FOR_STRUCTURE is the embedding.
BASE_MODEL_PATH_FOR_STRUCTURE = '/content/drive/MyDrive/ColabNotebooks/results/MobileNetV2_FamousPlaces/phase3_with_dropout_l2.keras'
FEATURE_LAYER_NAME_FOR_STRUCTURE = 'global_average_pooling2d_1'

# Number of trailing feature-extractor layers marked trainable when rebuilding.
# NOTE(review): presumably must mirror the setup used when the Siamese weights
# were saved so the weights line up on load — confirm.
N_LAYERS_UNFROZEN_IN_SAVED_FEATURE_EXTRACTOR = 10

# Directory whose class sub-folders are sampled for one-shot trials.
DATA_DIR_FOR_ONESHOT = '/content/drive/MyDrive/ColabNotebooks/data/famous_places/split/test'
IMG_WIDTH, IMG_HEIGHT = 224, 224
# Keras inputs and tf.image.resize use (height, width, channels) ordering.
IMG_SHAPE = (IMG_HEIGHT, IMG_WIDTH, 3)
N_WAY = 5       # classes (one support image each) per one-shot task
N_TRIALS = 500  # number of random one-shot tasks to evaluate

def load_and_preprocess_image_tf(path_tensor, img_shape=(IMG_HEIGHT, IMG_WIDTH, 3)):
    """Read an image file and return a MobileNetV2-ready float32 tensor.

    Meant for eager execution: the Python ``try``/``except`` around reading
    and decoding only catches errors when the ops run eagerly.

    Args:
        path_tensor: scalar string tensor holding the image file path.
        img_shape: ``(height, width, channels)`` of the returned tensor.

    Returns:
        A float32 tensor of shape ``img_shape`` passed through
        ``mobilenet_v2.preprocess_input``. Returns an all-zero tensor of that
        shape when the file is missing, cannot be decoded, or has a channel
        layout that cannot be converted. Callers treat all-zeros as a load
        failure.
    """
    target_channels = img_shape[2]
    try:
        raw = tf.io.read_file(path_tensor)
        img = tf.image.decode_image(raw, channels=target_channels, expand_animations=False)
    except (tf.errors.NotFoundError, tf.errors.InvalidArgumentError):
        tf.print(f"Warning: Could not decode image {path_tensor}. Returning zeros.")
        return tf.zeros(img_shape, dtype=tf.float32)

    # Check the rank first so shape[2] is never indexed on a non-3-D tensor.
    if img.shape.rank != 3:
        tf.print(f"Warning: Image {path_tensor} has unexpected shape {img.shape}. Converting or returning zeros.")
        return tf.zeros(img_shape, dtype=tf.float32)

    channels = img.shape[2]
    if channels != target_channels:
        tf.print(f"Warning: Image {path_tensor} has unexpected shape {img.shape}. Converting or returning zeros.")
        if channels == 1 and target_channels == 3:
            img = tf.image.grayscale_to_rgb(img)
        elif channels == 4 and target_channels == 3:
            img = img[:, :, :3]  # drop the alpha channel
        else:
            return tf.zeros(img_shape, dtype=tf.float32)

    img = tf.image.resize(img, [img_shape[0], img_shape[1]])  # [height, width]
    img = tf.cast(img, tf.float32)
    img = tf.keras.applications.mobilenet_v2.preprocess_input(img)
    return img

def predict_similarity_oneshot(img_path1, img_path2, model, img_shape):
    """Return the Siamese model's similarity score for a pair of image files.

    Args:
        img_path1: path of the image fed to the model's first input.
        img_path2: path of the image fed to the model's second input.
        model: two-input Keras model, or None.
        img_shape: shape forwarded to ``load_and_preprocess_image_tf``.

    Returns:
        ``None`` if ``model`` is None. ``0.0`` if either image failed to load
        or inference raised. Otherwise the model's first output value as a
        Python float.
    """
    if model is None:
        return None
    img1_tensor = load_and_preprocess_image_tf(tf.constant(img_path1), img_shape)
    img2_tensor = load_and_preprocess_image_tf(tf.constant(img_path2), img_shape)
    # The loader signals failure with an all-zero tensor. Check for "no
    # non-zero element" rather than "sum == 0": MobileNetV2 preprocessing
    # produces signed values, so a zero sum does not imply an all-zero tensor.
    img1_failed = int(tf.math.count_nonzero(img1_tensor)) == 0
    img2_failed = int(tf.math.count_nonzero(img2_tensor)) == 0
    if img1_failed or img2_failed:
        print(f"Warning: One or both images for similarity check might not have loaded correctly.")
        return 0.0
    img1_batch = tf.expand_dims(img1_tensor, axis=0)
    img2_batch = tf.expand_dims(img2_tensor, axis=0)
    try:
        # Calling the model directly avoids predict()'s per-call setup cost,
        # which dominates for a single pair. training=False keeps Dropout off.
        prediction = model([img1_batch, img2_batch], training=False)
    except Exception as e:
        # Keep the best-effort 0.0 score, but report why instead of hiding it.
        print(f"Warning: similarity prediction failed for {img_path1} vs {img_path2}: {e}")
        return 0.0
    return float(prediction[0][0])

def build_siamese_model_for_weight_loading(input_shape, feature_extractor_model_instance):
    """Assemble the two-branch Siamese graph so its layer names match the saved model.

    Both inputs go through the same ``feature_extractor_model_instance``
    object, so the two branches share weights. The absolute difference of the
    two embeddings is summed into one distance value per pair. A small
    Dense/Dropout head then maps that distance to a sigmoid similarity score.
    Every layer is named explicitly so saved weights can be loaded into this
    rebuilt graph. The distance Lambda uses tf.* ops.

    Args:
        input_shape: per-image shape given to each ``Input`` layer.
        feature_extractor_model_instance: Keras model applied to both inputs.

    Returns:
        A ``Model`` with inputs ``[input_image_A, input_image_B]`` and a single
        sigmoid output named ``similarity_prediction``.
    """
    image_inputs = [
        Input(shape=input_shape, name="input_image_A"),
        Input(shape=input_shape, name="input_image_B"),
    ]
    # Applying the same model object to both inputs shares its weights.
    embeddings = [feature_extractor_model_instance(image) for image in image_inputs]

    # Sum of |a - b| along axis 1; assumes 2-D (batch, features) embeddings — confirm.
    l1_distance = Lambda(
        lambda pair: tf.reduce_sum(tf.abs(pair[0] - pair[1]), axis=1, keepdims=True),
        output_shape=(1,),
        name="L1_distance",
    )(embeddings)

    head = l1_distance
    head_specs = (
        (128, "head_dense_1", "head_dropout_1"),
        (64, "head_dense_2", "head_dropout_2"),
    )
    for units, dense_name, dropout_name in head_specs:
        head = Dense(units, activation='relu', name=dense_name)(head)
        head = Dropout(0.3, name=dropout_name)(head)

    similarity = Dense(1, activation='sigmoid', name="similarity_prediction")(head)
    return Model(inputs=image_inputs, outputs=similarity)

# Rebuild the Siamese architecture in code and load the saved weights into it.
# Any failure aborts the notebook, since nothing below can run without a model.
print("Rebuilding Siamese model structure...")
siamese_model_rebuilt = None
try:
    print(f"Loading base MobileNetV2 model from: {BASE_MODEL_PATH_FOR_STRUCTURE} for feature extractor structure...")
    base_mobilenet_model = load_model(BASE_MODEL_PATH_FOR_STRUCTURE)
    print("Base MobileNetV2 model loaded.")

    # Cut the base model at the configured layer; its output is the embedding.
    feature_output_structure = base_mobilenet_model.get_layer(FEATURE_LAYER_NAME_FOR_STRUCTURE).output
    feature_extractor_structure = Model(
        inputs=base_mobilenet_model.input,
        outputs=feature_output_structure,
        name="feature_extractor_functional"
    )
    print(f"Feature extractor structure ('{feature_extractor_structure.name}') created.")

    print(f"Setting trainability for layers in '{feature_extractor_structure.name}'...")
    for layer in feature_extractor_structure.layers:
        layer.trainable = False

    # The > 0 guard matters: a slice of [-0:] would select every layer.
    if N_LAYERS_UNFROZEN_IN_SAVED_FEATURE_EXTRACTOR > 0:
        unfrozen_count = 0
        for layer in feature_extractor_structure.layers[-N_LAYERS_UNFROZEN_IN_SAVED_FEATURE_EXTRACTOR:]:
            # BatchNormalization layers stay frozen even inside the unfrozen tail.
            if not isinstance(layer, BatchNormalization):
                layer.trainable = True
                unfrozen_count += 1
        print(f"  Set {unfrozen_count} non-BN layers in feature extractor to trainable.")
    else:
        print("  All layers in feature extractor kept frozen.")
    siamese_model_rebuilt = build_siamese_model_for_weight_loading(IMG_SHAPE, feature_extractor_structure)
    print("Full Siamese model structure rebuilt.")
    print(f"Loading weights from: {SIAMESE_MODEL_PATH}")
    siamese_model_rebuilt.load_weights(SIAMESE_MODEL_PATH)
    print("Weights loaded successfully into rebuilt model.")
    siamese_model = siamese_model_rebuilt

except Exception as e:
    print(f"Error during model rebuilding or weight loading: {e}")
    # Chain the original error so `%tb` / tracebacks still show where it failed.
    raise SystemExit("Model construction/weight loading failed.") from e

def get_image_paths_by_class(directory, extensions=('.png', '.jpg', '.jpeg'), min_images=2):
    """Map each class sub-folder of ``directory`` to its sorted image file paths.

    Args:
        directory: folder whose immediate sub-directories are the classes.
        extensions: accepted file suffixes, matched case-insensitively.
        min_images: classes with fewer matching images are excluded with a
            warning. Two are needed so one image can serve as the support
            example and another as the query.

    Returns:
        Dict of class name -> list of image paths. Classes are in sorted order
        and paths within a class are sorted by filename, so results (and any
        seeded random sampling over them) are reproducible across file systems.

    Raises:
        ValueError: if ``directory`` contains no sub-directories.
    """
    class_names = sorted(
        name for name in os.listdir(directory) if os.path.isdir(os.path.join(directory, name))
    )
    if not class_names:
        raise ValueError(f"No subdirectories (classes) found in {directory}")

    suffixes = tuple(ext.lower() for ext in extensions)
    image_paths = {}
    for class_name in class_names:
        class_dir = os.path.join(directory, class_name)
        # os.listdir order is arbitrary; sort for deterministic sampling.
        paths_in_class = [
            os.path.join(class_dir, fname)
            for fname in sorted(os.listdir(class_dir))
            if fname.lower().endswith(suffixes) and os.path.isfile(os.path.join(class_dir, fname))
        ]
        if len(paths_in_class) >= min_images:
            image_paths[class_name] = paths_in_class
        else:
            # Also reached for classes with no images, which used to be dropped silently.
            print(f"Warning: Class '{class_name}' has < {min_images} images, excluding.")
    return image_paths

# Gather the eligible classes and make sure there are enough for an N_WAY task.
print(f"\nLoading image paths from: {DATA_DIR_FOR_ONESHOT}")
image_paths_by_class = get_image_paths_by_class(DATA_DIR_FOR_ONESHOT)
num_eligible_classes = len(image_paths_by_class)
if num_eligible_classes < N_WAY:
    raise ValueError(f"Not enough classes for {N_WAY}-way one-shot. Found {num_eligible_classes}.")
print(f"Found {num_eligible_classes} eligible classes.")

# Run N_TRIALS random N_WAY one-shot tasks. Each task has one support image
# per class and one query image from a held-out image of one of those classes.
print(f"\nStarting {N_WAY}-way one-shot learning evaluation for {N_TRIALS} trials...")
correct_predictions = 0
all_trial_details = []
# Report progress about every 10% of the run (every trial for very short runs).
progress_every = max(N_TRIALS // 10, 1)

for trial in range(N_TRIALS):
    selected_class_names = random.sample(list(image_paths_by_class.keys()), N_WAY)
    # Choose the query class uniformly among the selected classes. `max` below
    # breaks ties in favour of the first-inserted support class, so a query
    # that always sat at position 0 would count every tie (e.g. all scores 0.0
    # after load failures, or saturated sigmoids) as a correct prediction.
    true_query_class = random.choice(selected_class_names)
    support_set_images = {}
    query_image_path = None

    for class_name in selected_class_names:
        # Two distinct images: one is the support example. For the query class,
        # the other becomes the query, so the query never equals its support image.
        support_img_path, held_out_img_path = random.sample(image_paths_by_class[class_name], 2)
        support_set_images[class_name] = support_img_path
        if class_name == true_query_class:
            query_image_path = held_out_img_path

    if query_image_path is None:
        continue

    similarity_scores = {
        class_name: predict_similarity_oneshot(query_image_path, support_path, siamese_model, IMG_SHAPE)
        for class_name, support_path in support_set_images.items()
    }

    if similarity_scores:
        predicted_class = max(similarity_scores, key=similarity_scores.get)
    else:
        predicted_class = "N/A"

    is_correct = (predicted_class == true_query_class)
    if is_correct:
        correct_predictions += 1

    all_trial_details.append({
        'trial_num': trial + 1,
        'query_path': query_image_path,
        'support_set': dict(support_set_images),
        'true_class': true_query_class,
        'predicted_class': predicted_class,
        'scores': dict(similarity_scores)
    })
    if (trial + 1) % progress_every == 0:
        print(f"Completed Trial {trial+1}/{N_TRIALS}")

# Divide by the trials actually evaluated so skipped trials cannot deflate accuracy.
evaluated_trials = len(all_trial_details)
one_shot_accuracy = (correct_predictions / evaluated_trials) * 100 if evaluated_trials > 0 else 0
print(f"\n--- One-Shot Learning Evaluation Summary ---")
print(f"Total Trials: {evaluated_trials}")
print(f"Correct Predictions: {correct_predictions}")
print(f"Accuracy: {one_shot_accuracy:.2f}%")

def _show_image_or_placeholder(path, label):
    """Draw the image at ``path`` on the current axes, or a red text placeholder if it cannot be read."""
    try:
        plt.imshow(plt.imread(path))
    except (OSError, ValueError):
        # FileNotFoundError (and, e.g., Pillow's UnidentifiedImageError) are
        # OSError subclasses. Catching more than FileNotFoundError keeps one
        # bad file from aborting the whole visualization.
        plt.text(0.5, 0.5, f'{label}\nCould Not Load', ha='center', va='center', color='red')


def display_one_shot_task(trial_detail):
    """Plot one one-shot trial: the query image followed by every support image.

    Support titles are coloured as follows:
    - blue: predicted and correct
    - red: predicted but wrong
    - green: the true class that was missed
    - black: other classes

    Args:
        trial_detail: dict with keys ``'query_path'``, ``'support_set'``
            (class -> path), ``'true_class'``, ``'predicted_class'``,
            ``'scores'`` (class -> score) and ``'trial_num'``.
    """
    query_path = trial_detail['query_path']
    support_set = trial_detail['support_set']
    true_class = trial_detail['true_class']
    predicted_class = trial_detail['predicted_class']
    scores = trial_detail['scores']
    trial_num = trial_detail['trial_num']

    n_panels = len(support_set) + 1  # query panel + one per support image
    plt.figure(figsize=(3 * n_panels, 4))

    plt.subplot(1, n_panels, 1)
    _show_image_or_placeholder(query_path, 'Query Img')
    plt.title(f"Query (True: {true_class})\nPred: {predicted_class}")
    plt.axis('off')

    for panel_index, (cls, path) in enumerate(support_set.items(), start=2):
        plt.subplot(1, n_panels, panel_index)
        _show_image_or_placeholder(path, f'{cls}\nSupport Img')

        if cls == true_class and cls == predicted_class:
            title_color = 'blue'
        elif cls == predicted_class:
            title_color = 'red'
        elif cls == true_class:
            title_color = 'green'
        else:
            title_color = 'black'

        score = scores.get(cls, 0)
        # A None score (no model) cannot be formatted with :.2f.
        score_text = "N/A" if score is None else f"{score:.2f}"
        plt.title(f"Support: {cls}\nScore: {score_text}", color=title_color)
        plt.axis('off')

    plt.suptitle(f"One-Shot Task Visualization - Trial {trial_num}", fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the suptitle
    plt.show()

# Show the first trial, then the first misclassified trial if it is a
# different one. If nothing was misclassified, show the second trial instead.
if all_trial_details:
    print("\nVisualizing a few one-shot task examples:")
    first_trial = all_trial_details[0]
    display_one_shot_task(first_trial)
    incorrect_trials = (
        td for td in all_trial_details if td['predicted_class'] != td['true_class']
    )
    first_incorrect = next(incorrect_trials, None)
    if first_incorrect is None:
        if len(all_trial_details) > 1:
            display_one_shot_task(all_trial_details[1])
    elif first_incorrect['trial_num'] != first_trial['trial_num']:
        display_one_shot_task(first_incorrect)


print("\n--- One-Shot Evaluation Script Finished ---")