In [1]:
import os
import re
import cv2
import numpy as np
import tensorflow as tf
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score
)
from collections import defaultdict
from tqdm import tqdm

# Make float64 the Keras default float type; the evaluator below converts
# every frame to float64 before handing it to the encoder.
tf.keras.backend.set_floatx('float64')

class PrototypicalEvaluator:
    """Episode-based, frame-level evaluation of an encoder using class prototypes.

    Data is read from ``data_root/<class name>/<video folder>/frame*.jpg``.
    Each video is kept as a list of 28x28 grayscale frames. In every episode
    ``num_support`` videos per class are used to build a prototype and one
    random frame of each remaining video is classified by nearest prototype
    (Euclidean distance in embedding space).

    Args:
        encoder: Object exposing ``predict(batch)`` that returns one embedding
            per input frame; the batch passed in has shape ``(n, 28, 28, 1)``.
        data_root: Directory containing one sub-directory per class name.
        class_labels: Mapping of label index -> class (directory) name.
        num_support: Number of support videos per class in each episode.
    """

    def __init__(self, encoder, data_root, class_labels, num_support=8):
        self.encoder = encoder
        self.class_labels = class_labels
        self.num_support = num_support
        self.data_root = data_root
        # Reverse lookup: class name -> label index.
        self.class_map = {v: k for k, v in class_labels.items()}
        self.dataset = self._load_dataset()
        self.classes = list(self.dataset.keys())
        print(f"\nSuccessfully loaded {len(self.classes)} classes:")
        print("\n".join(self.classes))

    def _load_dataset(self):
        """Load every video of every configured class.

        Returns:
            ``defaultdict(list)`` mapping class name -> list of videos, where a
            video is a list of frame arrays. Missing class directories are
            reported and skipped; videos without any readable frame are dropped.
        """
        dataset = defaultdict(list)
        for class_name in self.class_labels.values():
            class_path = os.path.join(self.data_root, class_name)
            if not os.path.isdir(class_path):
                print(f"⚠️ Skipping missing class: {class_path}")
                continue
            for video_folder in os.listdir(class_path):
                video_path = os.path.join(class_path, video_folder)
                if not os.path.isdir(video_path):
                    continue
                frames = self._load_video_frames(video_path)
                if frames:
                    dataset[class_name].append(frames)
        return dataset

    def _load_video_frames(self, video_path):
        """Read the ``frame*.jpg`` files of one video in numeric order.

        Args:
            video_path: Directory holding the frame images of a single video.

        Returns:
            List of 28x28 float arrays with values in [0, 1]. Returns an empty
            list when the directory cannot be listed. Individual files that
            carry no number or cannot be decoded are reported and skipped.
        """
        try:
            entries = os.listdir(video_path)
        except OSError as e:
            print(f"Error sorting frames: {e}")
            return []

        # Pair each frame file with the first integer in its name. A file
        # without any digit is skipped on its own instead of discarding the
        # whole video.
        numbered = []
        for name in entries:
            lowered = name.lower()
            if not (lowered.startswith('frame') and lowered.endswith('.jpg')):
                continue
            match = re.search(r'\d+', name)
            if match is None:
                print(f"Skipping {name}: no frame number in file name")
                continue
            numbered.append((int(match.group()), name))
        # Sort on the number only; ties keep directory-listing order.
        numbered.sort(key=lambda item: item[0])

        video_frames = []
        for _, frame_file in numbered:
            try:
                image = cv2.imread(os.path.join(video_path, frame_file), cv2.IMREAD_GRAYSCALE)
                if image is None:
                    raise ValueError("Failed to read image")
                image = image.astype(np.float64)
                resized = tf.image.resize(image[..., None], (28, 28)).numpy().squeeze()
                # Invert intensities and scale to [0, 1] (0 -> 1.0, 255 -> 0.0).
                resized = (255.0 - resized) / 255.0
                video_frames.append(resized)
            except Exception as e:
                # Best effort: one unreadable frame must not drop the video.
                print(f"Skipping {frame_file}: {e}")
        return video_frames

    def evaluate(self, num_episodes=100, n_way=8):
        """Evaluate model performance with full class coverage using single frame queries.

        Args:
            num_episodes: Number of random support/query splits to evaluate.
            n_way: Number of classes per episode; must equal
                ``len(class_labels)`` because every episode uses all classes.

        Returns:
            Dict with ``frame_accuracy``, macro ``precision``, ``recall``,
            ``f1_score``, ``avg_confidence`` and the per-query ``distances``.

        Raises:
            ValueError: If ``n_way`` does not match the number of configured
                classes, if no class was loaded, or if a class has fewer than
                ``num_support + 1`` videos.
        """
        if n_way != len(self.class_labels):
            raise ValueError(f"n_way must match number of classes ({len(self.class_labels)})")
        if not self.classes:
            # Without data every metric below would silently come out as NaN.
            raise ValueError("No classes were loaded; nothing to evaluate")
        for cls in self.classes:
            if len(self.dataset[cls]) < self.num_support + 1:
                raise ValueError(f"Class {cls} has only {len(self.dataset[cls])} videos. Need at least {self.num_support + 1}")

        # Initialize metrics storage
        frame_metrics = {'true': [], 'pred': [], 'confidence': [], 'distances': []}

        # Create balanced episodes with all classes
        for _ in tqdm(range(num_episodes)):
            support_set, query_set = self._create_episode(self.classes)
            prototypes = self._create_prototypes(support_set)

            for video, true_class in query_set:
                if len(video) == 0:
                    continue
                # Select a random frame from the video as query
                query_frame = video[np.random.randint(0, len(video))]
                pred_result = self._predict_frame(query_frame, prototypes)
                frame_metrics['true'].append(true_class)
                frame_metrics['pred'].append(pred_result['class'])
                frame_metrics['confidence'].append(pred_result['confidence'])
                frame_metrics['distances'].append(pred_result['distances'])

        return self._compile_results(frame_metrics)

    def _create_episode(self, selected_classes):
        """Randomly split each class's videos into support and query sets.

        The stored dataset is left untouched: a shuffled copy is built from a
        random index permutation instead of shuffling ``self.dataset`` in place.

        Args:
            selected_classes: Iterable of class names present in ``self.dataset``.

        Returns:
            Tuple ``(support_set, query_set)`` where ``support_set`` maps class
            name -> list of ``num_support`` videos and ``query_set`` is a list
            of ``(video, class name)`` pairs for all remaining videos.
        """
        support_set = defaultdict(list)
        query_set = []
        for cls in selected_classes:
            videos = self.dataset[cls]
            order = np.random.permutation(len(videos))
            shuffled = [videos[i] for i in order]
            support_set[cls] = shuffled[:self.num_support]
            query_set.extend((video, cls) for video in shuffled[self.num_support:])
        return support_set, query_set

    def _create_prototypes(self, support_set):
        """Create class prototypes by max-pooling frame embeddings and averaging across support videos.

        Args:
            support_set: Mapping class name -> list of support videos.

        Returns:
            Dict mapping class name -> prototype embedding vector.

        Raises:
            ValueError: If a class has no support video (its prototype would
                otherwise be NaN).
        """
        prototypes = {}
        for cls, videos in support_set.items():
            if len(videos) == 0:
                raise ValueError(f"Class {cls} has no support videos")
            class_video_embeddings = []
            for video in videos:
                # (frames, 28, 28) -> (frames, 28, 28, 1); predict accepts NumPy input.
                video_batch = np.array(video, dtype=np.float64)[..., None]
                embeddings = np.asarray(self.encoder.predict(video_batch))
                if embeddings.ndim == 1:
                    embeddings = np.expand_dims(embeddings, 0)
                # Max-pool across frames
                class_video_embeddings.append(np.max(embeddings, axis=0))
            # Prototype is mean of support video embeddings
            prototypes[cls] = np.mean(class_video_embeddings, axis=0)
        return prototypes

    def _predict_frame(self, frame, prototypes):
        """Predict class for a single frame by comparing to class prototypes.

        Args:
            frame: Single 2-D frame array.
            prototypes: Mapping class name -> prototype embedding vector.

        Returns:
            Dict with the predicted ``class``, a ``confidence`` score and the
            ``distances`` to every prototype. The score is
            ``1 - d_min / sum(d)``; it is a relative-distance heuristic, not a
            calibrated probability.

        Raises:
            ValueError: If ``prototypes`` is empty.
        """
        if not prototypes:
            raise ValueError("No class prototypes available for prediction")

        frame_batch = np.array([frame], dtype=np.float64)[..., None]
        embedding = np.asarray(self.encoder.predict(frame_batch))
        if embedding.ndim == 1:
            embedding = embedding.reshape(1, -1)
        frame_emb = embedding[0]  # Get the embedding for the single frame

        # Compute distances to each class prototype
        distances = {cls: np.linalg.norm(frame_emb - proto) for cls, proto in prototypes.items()}
        predicted_class = min(distances, key=distances.get)
        total_distance = sum(distances.values())
        if total_distance > 0:
            confidence = 1 - (distances[predicted_class] / total_distance)
        else:
            # All distances are zero: use the value the formula yields for
            # equal distances instead of dividing by zero.
            confidence = 1 - 1 / len(distances)

        return {
            'class': predicted_class,
            'confidence': confidence,
            'distances': distances
        }

    def _compile_results(self, frame_metrics):
        """Compute accuracy, precision, recall, f1 and average confidence.

        Args:
            frame_metrics: Dict with parallel lists under ``'true'``,
                ``'pred'``, ``'confidence'`` and ``'distances'``.

        Returns:
            Dict with ``frame_accuracy``, macro-averaged ``precision``,
            ``recall`` and ``f1_score``, ``avg_confidence`` and the unchanged
            list of per-query ``distances``.
        """
        true = frame_metrics['true']
        pred = frame_metrics['pred']

        results = {}
        results['frame_accuracy'] = np.mean(np.array(true) == np.array(pred))
        # Macro-averaged precision, recall, f1
        results['precision'] = precision_score(true, pred, average='macro', zero_division=0)
        results['recall'] = recall_score(true, pred, average='macro', zero_division=0)
        results['f1_score'] = f1_score(true, pred, average='macro', zero_division=0)
        results['avg_confidence'] = np.mean(frame_metrics['confidence'])
        results['distances'] = frame_metrics['distances']

        return results

# Usage example
if __name__ == "__main__":
    # Load the saved encoder for inference only (no compile step).
    encoder = tf.keras.models.load_model(
        r"C:\Users\Mehdi\Desktop\work\SEMESTRE 4\P2M\CODE\SN_fsl\SN_fsl\data_saved\prototypical_net",
        compile=False,
    )
    if isinstance(encoder, tf.keras.Sequential):
        # Rebuild a Sequential model as a functional Model over the same tensors.
        encoder = tf.keras.Model(inputs=encoder.inputs, outputs=encoder.outputs)

    # Put an explicit float64 input and cast in front of the loaded encoder.
    inputs = tf.keras.Input(shape=(28, 28, 1), dtype=tf.float64)
    x = tf.keras.layers.Lambda(lambda t: tf.cast(t, tf.float64))(inputs)
    outputs = encoder(x)
    encoder = tf.keras.Model(inputs, outputs)

    # Label index -> class directory name, in label order.
    CLASS_LABELS = dict(enumerate([
        'Carcinoma',
        'Extreme_polipoid',
        'Laryngitis',
        'leukoplacia',
        'papilloma',
        'scar',
        'vocal_fold_cyst',
        'Vocal_insufficiency',
    ]))

    evaluator = PrototypicalEvaluator(
        encoder=encoder,
        data_root=r"C:\Users\Mehdi\Desktop\work\SEMESTRE 4\P2M\data_frames\max_frames_fsl",
        class_labels=CLASS_LABELS,
        num_support=1,
    )
    results = evaluator.evaluate(num_episodes=100)

    # Headline metrics, each printed as a percentage with two decimals.
    summary_rows = (
        ("Frame-level Accuracy", 'frame_accuracy'),
        ("Precision (macro)", 'precision'),
        ("Recall (macro)", 'recall'),
        ("F1-score (macro)", 'f1_score'),
        ("Average Confidence", 'avg_confidence'),
    )
    for title, metric_key in summary_rows:
        print(f"{title}: {results[metric_key]:.2%}")

    # Print distances for the first few queries, nearest prototype first.
    print("\nSample distances to each class prototype:")
    for query_no, per_class in enumerate(results['distances'][:5], start=1):
        print(f"\nQuery {query_no}:")
        ranked = sorted(per_class.items(), key=lambda pair: pair[1])
        for cls, dist in ranked:
            print(f"  {cls}: {dist:.4f}")



Successfully loaded 8 classes:
Carcinoma
Extreme_polipoid
Laryngitis
leukoplacia
papilloma
scar
vocal_fold_cyst
Vocal_insufficiency


  0%|          | 0/100 [00:00<?, ?it/s]



  1%|          | 1/100 [00:09<15:07,  9.17s/it]



  2%|▏         | 2/100 [00:10<07:10,  4.40s/it]



  3%|▎         | 3/100 [00:11<04:39,  2.89s/it]



  4%|▍         | 4/100 [00:12<03:31,  2.20s/it]



  5%|▌         | 5/100 [00:13<02:56,  1.86s/it]



  6%|▌         | 6/100 [00:14<02:30,  1.60s/it]



  7%|▋         | 7/100 [00:15<02:11,  1.41s/it]



  8%|▊         | 8/100 [00:16<01:59,  1.30s/it]



  9%|▉         | 9/100 [00:17<01:51,  1.22s/it]



 10%|█         | 10/100 [00:19<01:48,  1.21s/it]



 11%|█         | 11/100 [00:20<01:43,  1.16s/it]



 12%|█▏        | 12/100 [00:21<01:39,  1.13s/it]



 13%|█▎        | 13/100 [00:22<01:36,  1.11s/it]



 14%|█▍        | 14/100 [00:23<01:34,  1.09s/it]



 15%|█▌        | 15/100 [00:24<01:32,  1.08s/it]



 16%|█▌        | 16/100 [00:25<01:31,  1.09s/it]



 17%|█▋        | 17/100 [00:26<01:29,  1.08s/it]



 18%|█▊        | 18/100 [00:27<01:28,  1.08s/it]



 19%|█▉        | 19/100 [00:28<01:26,  1.07s/it]



 20%|██        | 20/100 [00:29<01:26,  1.09s/it]



 21%|██        | 21/100 [00:30<01:25,  1.08s/it]



 22%|██▏       | 22/100 [00:31<01:24,  1.08s/it]



 23%|██▎       | 23/100 [00:33<01:22,  1.07s/it]



 24%|██▍       | 24/100 [00:34<01:21,  1.07s/it]



 25%|██▌       | 25/100 [00:35<01:19,  1.07s/it]



 26%|██▌       | 26/100 [00:36<01:18,  1.07s/it]



 27%|██▋       | 27/100 [00:37<01:17,  1.06s/it]



 28%|██▊       | 28/100 [00:38<01:16,  1.06s/it]



 29%|██▉       | 29/100 [00:39<01:15,  1.06s/it]



 30%|███       | 30/100 [00:40<01:14,  1.06s/it]



 31%|███       | 31/100 [00:41<01:13,  1.06s/it]



 32%|███▏      | 32/100 [00:42<01:12,  1.06s/it]



 33%|███▎      | 33/100 [00:43<01:11,  1.07s/it]



 34%|███▍      | 34/100 [00:44<01:10,  1.07s/it]



 35%|███▌      | 35/100 [00:45<01:09,  1.07s/it]



 36%|███▌      | 36/100 [00:46<01:08,  1.07s/it]



 37%|███▋      | 37/100 [00:47<01:07,  1.07s/it]



 38%|███▊      | 38/100 [00:49<01:06,  1.08s/it]



 39%|███▉      | 39/100 [00:50<01:07,  1.10s/it]



 40%|████      | 40/100 [00:51<01:07,  1.13s/it]



 41%|████      | 41/100 [00:52<01:07,  1.14s/it]



 42%|████▏     | 42/100 [00:53<01:05,  1.12s/it]



 43%|████▎     | 43/100 [00:54<01:03,  1.12s/it]



 44%|████▍     | 44/100 [00:55<01:02,  1.11s/it]



 45%|████▌     | 45/100 [00:57<01:04,  1.18s/it]



 46%|████▌     | 46/100 [00:58<01:06,  1.23s/it]



 47%|████▋     | 47/100 [00:59<01:06,  1.25s/it]



 48%|████▊     | 48/100 [01:01<01:08,  1.31s/it]



 49%|████▉     | 49/100 [01:02<01:06,  1.31s/it]



 50%|█████     | 50/100 [01:03<01:04,  1.28s/it]



 51%|█████     | 51/100 [01:04<01:01,  1.26s/it]



 52%|█████▏    | 52/100 [01:06<00:59,  1.25s/it]



 53%|█████▎    | 53/100 [01:07<00:57,  1.23s/it]



 54%|█████▍    | 54/100 [01:08<00:55,  1.22s/it]



 55%|█████▌    | 55/100 [01:09<00:54,  1.21s/it]



 56%|█████▌    | 56/100 [01:10<00:53,  1.21s/it]



 57%|█████▋    | 57/100 [01:12<00:51,  1.21s/it]



 58%|█████▊    | 58/100 [01:13<00:50,  1.20s/it]



 59%|█████▉    | 59/100 [01:14<00:47,  1.17s/it]



 60%|██████    | 60/100 [01:15<00:46,  1.17s/it]



 61%|██████    | 61/100 [01:16<00:46,  1.18s/it]



 62%|██████▏   | 62/100 [01:18<00:45,  1.20s/it]



 63%|██████▎   | 63/100 [01:19<00:43,  1.19s/it]



 64%|██████▍   | 64/100 [01:20<00:41,  1.16s/it]



 65%|██████▌   | 65/100 [01:21<00:40,  1.16s/it]



 66%|██████▌   | 66/100 [01:22<00:39,  1.16s/it]



 67%|██████▋   | 67/100 [01:23<00:37,  1.15s/it]



 68%|██████▊   | 68/100 [01:24<00:36,  1.16s/it]



 69%|██████▉   | 69/100 [01:26<00:35,  1.15s/it]



 70%|███████   | 70/100 [01:27<00:34,  1.14s/it]



 71%|███████   | 71/100 [01:28<00:32,  1.13s/it]



 72%|███████▏  | 72/100 [01:29<00:31,  1.12s/it]



 73%|███████▎  | 73/100 [01:30<00:30,  1.11s/it]



 74%|███████▍  | 74/100 [01:31<00:28,  1.11s/it]



 75%|███████▌  | 75/100 [01:32<00:27,  1.10s/it]



 76%|███████▌  | 76/100 [01:33<00:26,  1.11s/it]



 77%|███████▋  | 77/100 [01:34<00:25,  1.10s/it]



 78%|███████▊  | 78/100 [01:35<00:24,  1.09s/it]



 79%|███████▉  | 79/100 [01:37<00:22,  1.09s/it]



 80%|████████  | 80/100 [01:38<00:21,  1.09s/it]



 81%|████████  | 81/100 [01:39<00:20,  1.08s/it]



 82%|████████▏ | 82/100 [01:40<00:20,  1.13s/it]



 83%|████████▎ | 83/100 [01:41<00:19,  1.12s/it]



 84%|████████▍ | 84/100 [01:42<00:17,  1.11s/it]



 85%|████████▌ | 85/100 [01:43<00:16,  1.11s/it]



 86%|████████▌ | 86/100 [01:44<00:15,  1.11s/it]



 87%|████████▋ | 87/100 [01:45<00:14,  1.10s/it]



 88%|████████▊ | 88/100 [01:47<00:13,  1.10s/it]



 89%|████████▉ | 89/100 [01:48<00:12,  1.09s/it]



 90%|█████████ | 90/100 [01:49<00:10,  1.09s/it]



 91%|█████████ | 91/100 [01:50<00:09,  1.09s/it]



 92%|█████████▏| 92/100 [01:51<00:08,  1.09s/it]



 93%|█████████▎| 93/100 [01:52<00:07,  1.09s/it]



 94%|█████████▍| 94/100 [01:53<00:06,  1.09s/it]



 95%|█████████▌| 95/100 [01:54<00:05,  1.10s/it]



 96%|█████████▌| 96/100 [01:55<00:04,  1.10s/it]



 97%|█████████▋| 97/100 [01:56<00:03,  1.10s/it]



 98%|█████████▊| 98/100 [01:58<00:02,  1.11s/it]



 99%|█████████▉| 99/100 [01:59<00:01,  1.10s/it]



100%|██████████| 100/100 [02:00<00:00,  1.20s/it]

Frame-level Accuracy: 39.15%
Precision (macro): 42.12%
Recall (macro): 50.53%
F1-score (macro): 41.45%
Average Confidence: 94.19%

Sample distances to each class prototype:

Query 1:
  Carcinoma: 30.6926
  Laryngitis: 69.8523
  Vocal_insufficiency: 71.7355
  papilloma: 82.2965
  scar: 98.7242
  Extreme_polipoid: 100.1532
  vocal_fold_cyst: 107.5108
  leukoplacia: 179.1025

Query 2:
  scar: 54.7668
  Carcinoma: 55.2546
  Vocal_insufficiency: 68.1865
  Laryngitis: 73.2242
  Extreme_polipoid: 86.5398
  papilloma: 103.5758
  vocal_fold_cyst: 114.8860
  leukoplacia: 178.1703

Query 3:
  Vocal_insufficiency: 88.1023
  scar: 112.6448
  vocal_fold_cyst: 120.0259
  Carcinoma: 148.2629
  papilloma: 150.7821
  Extreme_polipoid: 161.8651
  Laryngitis: 161.9180
  leukoplacia: 234.9709

Query 4:
  leukoplacia: 2.0442
  Carcinoma: 178.7517
  Laryngitis: 184.7585
  scar: 188.5101
  Vocal_insufficiency: 189.4187
  Extreme_polipoid: 202.2728
  papilloma: 206.3253
  vocal_fold_cyst: 214.7074

Query 5:
  


