# ProtoBeam (Prototypical Networks for Beam Classification)

This project is based on the tutorial notebook available at:
[Meta Learning Tutorial](https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial16/Meta_Learning.ipynb)


Prototypical Networks operate on the principle that data points in an embedding space cluster around a central prototype for each class, as shown in the figure below. Using an encoder for the non-linear mapping, the network clusters the features of each class around that class's prototype in the embedding space. Classification then becomes a nearest-neighbour problem: the class of a query point is determined by its proximity to these prototypes.
ProtoBeam applies this concept of Prototypical Networks to beam classification, where the prototypes are created during training from the training set of a particular antenna.


In [None]:
## Standard libraries
import os
import numpy as np
import random
import json
from PIL import Image
from collections import defaultdict
from statistics import mean, stdev
from copy import deepcopy

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## tqdm for loading bars
from tqdm.auto import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    %pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Import tensorboard
%load_ext tensorboard

# Folder where datasets are stored (or should be downloaded to).
DATASET_PATH = "../data"
# Folder where pretrained models and training checkpoints are saved.
CHECKPOINT_PATH = "../saved_models/tutorial16"

# Seed the random number generators so that runs are repeatable.
pl.seed_everything(42)

# Disable the cuDNN auto-tuner and request deterministic kernels so that GPU runs
# (if a GPU is used) are reproducible.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

In [None]:
import numpy as np
import h5py
from sklearn.preprocessing import LabelEncoder
import torch.utils.data as data
import random

class DeepBeamDataset(data.Dataset):
    """
    Dataset that serves fixed-length blocks of IQ samples read from a DeepBeam HDF5 file.

    One item is a block of ``num_samples`` consecutive rows of the ``'iq'`` dataset,
    optionally augmented (random amplitude scaling and random phase rotation), then
    min-max normalised per channel to approximately [-1, 1], together with the label
    stored for the first row of the block.

    Attributes:
        data_file (str): Path to the HDF5 file containing the data.
        label (str): Key of the label dataset inside the HDF5 file.
        classes (list): Class labels; the label encoder ``le`` is fitted on them.
        num_samples (int): Number of consecutive rows that make up one item.
        indices (list or None): Block indices belonging to this dataset. They are only
            used to report the dataset length; ``__getitem__`` expects block indices
            directly (the samplers in this notebook yield them).
        augment (bool): Whether the random augmentation is applied in ``__getitem__``.
        le (LabelEncoder): Label encoder fitted on ``classes``.
    """

    def __init__(self, data_file, label, classes, num_samples, indices=None, augment=True):
        self.data_file = data_file
        self.label = label
        self.classes = classes
        self.num_samples = num_samples
        self.indices = indices    # Blocks that belong to this dataset (None = all blocks)
        self.augment = augment    # Switches the random augmentation on or off

        self.le = LabelEncoder()
        self.le.fit(classes)

    def __len__(self):
        """
        Returns the number of items in the dataset.

        Returns:
            int: ``len(indices)`` when indices were given, otherwise the number of complete
                 ``num_samples``-row blocks in the label dataset of the file.
        """
        if self.indices is None:
            with h5py.File(self.data_file, 'r') as h5:
                return len(h5[self.label]) // self.num_samples
        return len(self.indices)

    def __getitem__(self, idx):
        """
        Loads, augments (optionally) and normalises one block of IQ samples.

        Args:
            idx (int): Block index; the block covers rows
                       ``[num_samples * idx, num_samples * (idx + 1))`` of the file.

        Returns:
            tuple: (data, label) where data is a float32 numpy array with a leading singleton
                   axis and label is the int64 label stored for the first row of the block.
        """
        first_row = self.num_samples * idx
        last_row = self.num_samples * (idx + 1)

        with h5py.File(self.data_file, 'r') as h5:
            block = np.float32(h5['iq'][first_row:last_row][:, :2048])
            block = block[np.newaxis, :, :]  # Add a leading singleton axis

            if self.augment:
                # Random amplitude scaling with a factor drawn from [0.5, 1.4]
                gain = random.uniform(0.5, 1.4)
                block *= gain

                # Random phase rotation: treat channels 0 and 1 as I and Q, rotate the
                # complex signal by a random angle and write the result back
                phase = np.random.uniform(0, 2*np.pi)
                iq_complex = block[:, :, 0] + 1j * block[:, :, 1]
                iq_rotated = iq_complex * np.exp(1j * phase)
                block[:, :, 0], block[:, :, 1] = iq_rotated.real, iq_rotated.imag

            # Min-max normalise I (channel 0) and Q (channel 1) independently;
            # epsilon guards against division by zero for a constant channel
            epsilon = 1e-9
            for channel in (0, 1):
                lo, hi = block[:, :, channel].min(), block[:, :, channel].max()
                block[:, :, channel] = 2 * ((block[:, :, channel] - lo) / (hi - lo + epsilon)) - 1

            # The label of the block is the label of its first row
            target = np.int64([h5[self.label][first_row]])[0]

        return block, target



In [None]:
def deepbeam_dataset_from_labels(data_file, label, num_samples, class_set, num_blocks, augment=True,
                                 frames_per_class=150000, num_classes_all=24, num_gains=3, **kwargs):
    """
    Creates a DeepBeamDataset restricted to a strided subset of the blocks of the given classes.

    The index arithmetic assumes the file is organised as ``num_gains`` consecutive groups,
    each holding ``num_classes_all`` classes in class order with ``frames_per_class`` blocks
    per class. For every class in ``class_set`` and every group, every ``num_blocks``-th block
    index of that class is selected.

    Args:
        data_file (str): Path to the HDF5 file containing the data.
        label (str): The key for the label data in the HDF5 file.
        num_samples (int): Number of samples to be included in each data instance.
        class_set (list or Tensor): Class labels to include in the dataset.
        num_blocks (int): Stride between selected block indices (larger values select fewer blocks).
        augment (bool, optional): Forwarded to DeepBeamDataset to enable augmentation. Defaults to True.
        frames_per_class (int, optional): Blocks per class within one group. Defaults to 150000
            (the 24-beam dataset); other datasets need a different value.
        num_classes_all (int, optional): Total number of classes stored per group. Defaults to 24.
        num_gains (int, optional): Number of consecutive groups in the file. Defaults to 3.
        **kwargs: Additional arguments passed to the DeepBeamDataset constructor.

    Returns:
        DeepBeamDataset: An instance of the DeepBeamDataset class with the selected indices.
    """
    # Distance (in block indices) between the same class in two consecutive groups
    gain_step = frames_per_class * num_classes_all

    index_chunks = []
    for class_label in class_set:
        # int() accepts Python ints, NumPy integers and 0-dim tensors alike
        start_idx = int(class_label) * frames_per_class
        end_idx = start_idx + frames_per_class
        for gain in range(num_gains):
            offset = gain * gain_step
            index_chunks.append(np.arange(start_idx + offset, end_idx + offset, num_blocks))

    # Concatenate once instead of growing the array inside the loop
    selected_indices = np.concatenate(index_chunks) if index_chunks else np.array([], dtype=int)

    # torch.tensor() on an existing tensor emits a copy-construct warning, so copy it explicitly
    if isinstance(class_set, torch.Tensor):
        classes = class_set.clone().detach()
    else:
        classes = torch.tensor(class_set)

    return DeepBeamDataset(data_file, label, classes, num_samples, selected_indices, augment=augment, **kwargs)

In [None]:
# Fix the seed so that the class permutation (and therefore the splits) is reproducible
torch.manual_seed(0)
# Random permutation of the 24 class ids 0..23
classes = torch.randperm(24)

# NOTE: the splits overlap. Training and test both use all 24 classes, and validation
# uses the last 11 entries of the permutation.
train_classes = classes[:24]
val_classes = classes[13:24]
test_classes = classes[0:24]

print(train_classes, val_classes, test_classes, sep="\n")

In [None]:
# HDF5 recording from which the training, validation and test sets are drawn
data_file =  "/mnt/c/Users/ict_3/deepBeam_matlab/neu_ww72bk37k.h5"
# Key of the label dataset inside the HDF5 file
label = 'tx_beam'
# Number of IQ samples that form one data point
num_samples = 2048

# Stride between selected block indices for each split (larger stride = fewer blocks).
# According to the author's experiments, training on part of the dataset is sufficient
# to achieve results similar to training on the full dataset.
train_n_blocks = 125
val_n_blocks = 1400
test_n_blocks = 450

# Training split
train_set = deepbeam_dataset_from_labels(
    data_file=data_file,
    label=label,
    num_samples=num_samples,
    class_set=train_classes,
    num_blocks=train_n_blocks,
)

# Validation split
val_set = deepbeam_dataset_from_labels(
    data_file=data_file,
    label=label,
    num_samples=num_samples,
    class_set=val_classes,
    num_blocks=val_n_blocks,
)

# Test split
test_set = deepbeam_dataset_from_labels(
    data_file=data_file,
    label=label,
    num_samples=num_samples,
    class_set=test_classes,
    num_blocks=test_n_blocks,
)

# Number of data points selected for training
print(len(train_set))


In [None]:
import numpy as np
import torch
from collections import defaultdict
import random

class FewShotBatchSampler(object):
    """
    Batch sampler that yields lists of block indices forming N-way K-shot batches.

    The indices of each class are computed with the same arithmetic as the dataset helper:
    the file is assumed to consist of ``num_gains`` consecutive groups, each holding
    ``num_classes_all`` classes in class order with ``frames_per_class`` blocks per class,
    and every ``blocks``-th block index of a class is used.
    """

    def __init__(self, total_frames, frames_per_class, N_way, K_shot, num_classes=24, class_set=None, include_query=False, shuffle=True, shuffle_once=False, blocks=50000, num_classes_all=24, num_gains=3):
        """
        Args:
            total_frames (int): Total number of frames in the dataset (stored, not used for sampling).
            frames_per_class (int): Number of frames per class before repeating. Used to compute
                the index range of every class.
            N_way (int): Number of classes to sample per batch.
            K_shot (int): Number of examples to sample per class in the batch.
            num_classes (int, optional): Number of classes handled by this sampler (stored, not
                used for the index computation). Defaults to 24.
            class_set (list, optional): Classes to include in the sampling. Defaults to None,
                which is treated as an empty list.
            include_query (bool, optional): If True, doubles the K_shot for support and query sets. Defaults to False.
            shuffle (bool, optional): If True, examples and classes are newly shuffled in each iteration (for training). Defaults to True.
            shuffle_once (bool, optional): If True, examples and classes are shuffled once at the beginning (for validation). Defaults to False.
            blocks (int, optional): Stride between selected block indices. Defaults to 50000.
            num_classes_all (int, optional): Total number of classes stored per group in the
                file; determines the offset between groups. Defaults to 24.
            num_gains (int, optional): Number of consecutive groups in the file (at least 1). Defaults to 3.

        Raises:
            ValueError: If N_way is larger than the number of classes in class_set.
        """
        self.total_frames = total_frames
        self.frames_per_class = frames_per_class
        self.N_way = N_way
        self.K_shot = K_shot * 2 if include_query else K_shot
        self.num_classes = num_classes
        self.shuffle = shuffle
        self.shuffle_once = shuffle_once
        self.include_query = include_query
        # A fresh list per instance instead of a shared mutable default argument
        self.classes = [] if class_set is None else class_set
        self.num_blocks = blocks

        # Validate N_way against the number of available classes
        if N_way > len(self.classes):
            raise ValueError("Number of sampled classes (N_way) cannot be larger than the number of available classes")

        self.indices_per_class = {}
        self.batches_per_class = {}

        # Offset (in block indices) between the same class in two consecutive groups.
        # The frames_per_class argument is honoured here rather than a hard-coded constant.
        gain_step = frames_per_class * num_classes_all

        for c in self.classes:
            start_idx = int(c) * frames_per_class
            end_idx = start_idx + frames_per_class

            # Strided indices of this class in every group, in group order
            class_indices = np.concatenate([
                np.arange(start_idx + gain * gain_step, end_idx + gain * gain_step, self.num_blocks)
                for gain in range(num_gains)
            ])
            self.indices_per_class[c] = class_indices
            self.batches_per_class[c] = class_indices.shape[0] // self.K_shot

        self.batch_size = self.N_way * self.K_shot
        self.iterations = sum(self.batches_per_class.values()) // self.N_way
        # One entry per (class, K-shot group); consecutive runs of N_way entries form a batch
        self.class_list = [c for c in self.classes for _ in range(self.batches_per_class[c])]

        # Shuffle data once if shuffle_once is set
        if shuffle_once:
            self.shuffle_data()

    def shuffle_data(self):
        """
        Shuffles the indices per class and the class list.

        Note that this way of shuffling the class list does not prevent the same class from
        being chosen twice in a batch. For training and validation, this is not a problem.
        """
        for c in self.classes:
            perm = torch.randperm(self.indices_per_class[c].shape[0])
            self.indices_per_class[c] = self.indices_per_class[c][perm]
        random.shuffle(self.class_list)

    def __iter__(self):
        """Yields batches of indices for the few-shot learning task."""
        if self.shuffle:
            self.shuffle_data()

        # Position of the next unused example of every class
        start_index = defaultdict(int)
        for it in range(self.iterations):
            class_batch = self.class_list[it * self.N_way : (it + 1) * self.N_way]  # Select N classes for the batch
            index_batch = []
            for c in class_batch:  # For each class, select the next K examples and add them to the batch
                index_batch.extend(self.indices_per_class[c][start_index[c] : start_index[c] + self.K_shot])
                start_index[c] += self.K_shot

            if self.include_query:
                # Put every second example into the second half so that the batch can be
                # split into support (first half) and query (second half) by chunking
                index_batch = index_batch[::2] + index_batch[1::2]

            yield index_batch

    def __len__(self):
        """Returns the number of iterations (batches) per epoch."""
        return self.iterations


In [None]:
import random

class CustomSampler(object):
    """Sampler that yields a fixed collection of indices, optionally shuffled once at construction."""

    def __init__(self, indices, shuffle=False):
        """
        Custom sampler for iterating over a list of indices.

        Args:
            indices (array-like): List, array or other iterable of indices to sample from.
                The input is copied, so shuffling never modifies the caller's object.
            shuffle (bool, optional): If True, shuffles the indices before sampling. Defaults to False.
        """
        super().__init__()
        # Arrays and tensors expose tolist(), which also converts their elements to plain
        # Python numbers; any other iterable (e.g. a list) is copied into a new list.
        self.indices = indices.tolist() if hasattr(indices, "tolist") else list(indices)
        if shuffle:
            random.shuffle(self.indices)

    def __iter__(self):
        """
        Yields indices one by one.

        Yields:
            int: The next index from the list of indices.
        """
        yield from self.indices

    def __len__(self):
        """
        Returns the number of indices.

        Returns:
            int: The length of the indices list.
        """
        return len(self.indices)

# Example usage:
# custom_sampler = CustomSampler(indices=test_indi, shuffle=False)
# data_loader = data.DataLoader(test_set, sampler=custom_sampler, batch_size=100)


In [None]:
# Few-shot batch sampler for training: 8 classes per batch, 4 support + 4 query examples each
train_sampler = FewShotBatchSampler(
    total_frames=10800000,           # Total number of frames in the dataset
    frames_per_class=150000,         # Number of frames per class before repeating
    N_way=8,                         # Classes per batch
    K_shot=4,                        # Examples per class (doubled because include_query=True)
    num_classes=len(train_classes),  # Number of classes in the training split
    class_set=train_classes.numpy(), # Classes to sample from
    include_query=True,              # Batches contain support and query examples
    shuffle=True,                    # Reshuffle examples and classes on every iteration
    blocks=train_n_blocks            # Same stride as used to build train_set
)
train_data_loader = data.DataLoader(train_set, batch_sampler=train_sampler, num_workers=8)

# Few-shot batch sampler for validation: 11 classes per batch, 3 support + 3 query examples each.
# Note that shuffle=True means the validation batches are also reshuffled on every iteration.
val_sampler = FewShotBatchSampler(
    total_frames=10800000,           # Total number of frames in the dataset
    frames_per_class=150000,         # Number of frames per class before repeating
    N_way=11,                        # Classes per batch
    K_shot=3,                        # Examples per class (doubled because include_query=True)
    num_classes=len(val_classes),    # Number of classes in the validation split
    class_set=val_classes.numpy(),   # Classes to sample from
    include_query=True,              # Batches contain support and query examples
    shuffle=True,                    # Reshuffle examples and classes on every iteration
    blocks=val_n_blocks              # Same stride as used to build val_set
)
val_data_loader = data.DataLoader(val_set, batch_sampler=val_sampler, num_workers=8)

# Number of batches per epoch in the training DataLoader
print(len(train_data_loader))

# Number of batches per epoch in the validation DataLoader
print(len(val_data_loader))


In [None]:
def split_batch(IQ, targets):
    """
    Halves a batch along its first dimension into a support part and a query part.

    Args:
        IQ (Tensor): Batch of IQ samples (or features); the first half along dim 0 is the
            support set and the second half the query set.
        targets (Tensor): Targets belonging to ``IQ``, split in the same way.

    Returns:
        Tuple[Tensor, Tensor, Tensor, Tensor]:
            (support_IQ, query_IQ, support_targets, query_targets)
    """
    # Chunk both tensors into two pieces along dim 0, then unpack (support, query) pairs
    halves = [tensor.chunk(2, dim=0) for tensor in (IQ, targets)]
    (support_IQ, query_IQ), (support_targets, query_targets) = halves
    return support_IQ, query_IQ, support_targets, query_targets


In [None]:
def get_convnet(output_size):
    """
    Builds the DenseNet used as feature encoder.

    Args:
        output_size (int): Size of the network output (passed to DenseNet as ``num_classes``).

    Returns:
        torchvision.models.DenseNet: A DenseNet with three dense blocks of five layers each.
    """
    densenet_config = dict(
        growth_rate=32,           # Number of filters added by each layer
        block_config=(5, 5, 5),   # Number of layers in each dense block
        bn_size=1,                # Multiplicative factor for the bottleneck layers
        num_init_features=64,     # Number of filters in the initial convolution layer
        num_classes=output_size,  # Output dimensionality
    )
    return torchvision.models.DenseNet(**densenet_config)


In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch
import pytorch_lightning as pl

class ProtoNet(pl.LightningModule):
    """Prototypical network: a DenseNet encoder trained so that queries lie close to their class prototype."""

    def __init__(self, proto_dim, lr):
        """
        Builds the encoder network.

        Args:
            proto_dim (int): Dimensionality of the prototype feature space.
            lr (float): Learning rate for the optimizer.
        """
        super().__init__()
        self.save_hyperparameters()
        encoder = get_convnet(output_size=self.hparams.proto_dim)
        # Replace the first convolution so that the encoder accepts single-channel input
        encoder.features.conv0 = nn.Conv2d(1, 64, kernel_size=(2, 2), stride=(1, 1), padding=(3, 3), bias=False)
        self.model = encoder

    def configure_optimizers(self):
        """
        Sets up AdamW with a step-wise learning rate decay.

        Returns:
            list: List of optimizers.
            list: List of learning rate schedulers.
        """
        adamw = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        # Multiply the learning rate by 0.1 at milestones 50 and 80
        lr_decay = optim.lr_scheduler.MultiStepLR(adamw, milestones=[50, 80], gamma=0.1)
        return [adamw], [lr_decay]

    @staticmethod
    def calculate_prototypes(features, targets):
        """
        Computes one L2-normalised prototype per class present in ``targets``.

        Args:
            features (Tensor): Tensor of shape [N, proto_dim] containing feature vectors.
            targets (Tensor): Tensor of shape [N] containing class labels.

        Returns:
            Tensor: Prototypes, one row per class, in ascending class order.
            Tensor: The sorted unique classes corresponding to the rows.
        """
        classes, _ = torch.unique(targets).sort()  # Classes present in this set

        def class_prototype(c):
            members = (targets == c).nonzero(as_tuple=True)[0]  # Rows belonging to class c
            centroid = features[members].mean(dim=0)            # Average their feature vectors
            return centroid / centroid.norm(dim=0, keepdim=True)  # Scale to unit length

        prototypes = torch.stack([class_prototype(c) for c in classes], dim=0)
        return prototypes, classes

    def classify_feats(self, prototypes, classes, feats, targets):
        """
        Classifies feature vectors by their distance to the prototypes.

        Args:
            prototypes (Tensor): Tensor containing class prototypes.
            classes (Tensor): Class label belonging to each prototype row.
            feats (Tensor): Tensor containing feature vectors to classify.
            targets (Tensor): Tensor containing true class labels for the features.

        Returns:
            Tensor: Log-softmax predictions over the prototypes.
            Tensor: True labels expressed as indices into ``classes``.
            Tensor: Accuracy of the predictions.
        """
        # Pairwise squared Euclidean distance between every feature and every prototype
        diff = prototypes[None, :] - feats[:, None]
        dist = diff.pow(2).sum(dim=2)
        # Closer prototypes get a higher (log-)probability
        preds = F.log_softmax(-dist, dim=1)
        # Map every target to the position of its class in `classes`
        labels = (classes[None, :] == targets[:, None]).long().argmax(dim=-1)
        acc = (preds.argmax(dim=1) == labels).float().mean()
        return preds, labels, acc

    def calculate_loss(self, batch, mode):
        """
        Computes and logs the loss and accuracy for one support/query batch.

        Args:
            batch (tuple): Tuple (IQ samples, targets) whose first half is the support set
                and whose second half is the query set.
            mode (str): Prefix for the logged metrics, e.g., "train" or "val".

        Returns:
            Tensor: Calculated loss.
        """
        iqs, targets = batch
        embedded = self.model(iqs)  # Encode support and query examples in one pass
        support_feats, query_feats, support_targets, query_targets = split_batch(embedded, targets)
        prototypes, classes = ProtoNet.calculate_prototypes(support_feats, support_targets)
        preds, labels, acc = self.classify_feats(prototypes, classes, query_feats, query_targets)
        loss = F.cross_entropy(preds, labels)
        self.log(f"{mode}_loss", loss)
        self.log(f"{mode}_acc", acc)
        return loss

    def training_step(self, batch, batch_idx):
        """
        Performs a single training step.

        Args:
            batch (tuple): Tuple containing IQ samples and targets.
            batch_idx (int): Index of the current batch.

        Returns:
            Tensor: Calculated loss for the batch.
        """
        return self.calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        """
        Performs a single validation step; the metrics are only logged.

        Args:
            batch (tuple): Tuple containing IQ samples and targets.
            batch_idx (int): Index of the current batch.

        Returns:
            None
        """
        _ = self.calculate_loss(batch, mode="val")


In [None]:
import os
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

def train_model(model_class, train_loader, val_loader, **kwargs):
    """
    Trains a model with PyTorch Lightning, or loads a pretrained checkpoint if one exists.

    Args:
        model_class (pl.LightningModule): The class of the model to be trained.
        train_loader (DataLoader): DataLoader for the training data.
        val_loader (DataLoader): DataLoader for the validation data.
        **kwargs: Additional keyword arguments for the model class initialization.

    Returns:
        pl.LightningModule: The pretrained model if "<CHECKPOINT_PATH>/<model name>.ckpt"
        exists, otherwise the checkpoint with the best "val_acc" of the new training run.
    """
    run_dir = os.path.join(CHECKPOINT_PATH, model_class.__name__)
    on_gpu = str(device).startswith("cuda")

    # Keep the weights of the epoch with the highest validation accuracy
    best_checkpoint = ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")
    # Log the learning rate once per epoch
    lr_monitor = LearningRateMonitor("epoch")

    trainer = pl.Trainer(
        default_root_dir=run_dir,
        accelerator="gpu" if on_gpu else "cpu",
        devices=1,
        max_epochs=110,
        callbacks=[best_checkpoint, lr_monitor],
        enable_progress_bar=True
    )
    trainer.logger._default_hp_metric = None

    # A pretrained model takes precedence over training from scratch
    pretrained_filename = os.path.join(CHECKPOINT_PATH, model_class.__name__ + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        # The hyperparameters are restored from the checkpoint
        return model_class.load_from_checkpoint(pretrained_filename)

    pl.seed_everything(42)  # Ensure reproducibility
    model = model_class(**kwargs)
    trainer.fit(model, train_loader, val_loader)
    # Return the best checkpoint rather than the weights of the last epoch
    return model_class.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)


In [None]:
# Train ProtoNet (or load it if a pretrained checkpoint exists).
# proto_dim and lr are forwarded to the ProtoNet constructor; the learning rate
# can be adjusted as needed, e.g., lr=2e-4.
protonet_model = train_model(
    model_class=ProtoNet,
    train_loader=train_data_loader,
    val_loader=val_data_loader,
    proto_dim=128,
    lr=1e-3,
)


In [None]:
%tensorboard --logdir /mnt/c/Users/ict_3/saved_models/tutorial16/ProtoNet/lightning_logs/version_160

In [None]:
@torch.no_grad()
def test_proto_net(model, dataset, data_feats=None, k_shot=4):
    """
    Tests a pretrained ProtoNet model on a given dataset.

    The encoded examples are arranged per class and cut into batches of ``k_shot`` examples
    per class. Every batch is used once as support set to build the prototypes, which are
    then evaluated on all other batches.

    Args:
        model (ProtoNet): Pretrained ProtoNet model.
        dataset (DeepBeamDataset): The dataset on which the test should be performed. Its
            ``indices`` and ``classes`` attributes are used, and the examples are assumed
            to be uniformly distributed over the classes.
        data_feats (tuple, optional): The encoded features and targets of all examples in the
            dataset. If None, they will be newly calculated and returned for later usage.
        k_shot (int, optional): Number of examples per class in the support set. Defaults to 4.

    Returns:
        tuple: ((mean accuracy, standard deviation), (features, targets), prototypes of the
        last support batch).

    Raises:
        ValueError: If ``k_shot`` leaves fewer than two batches, in which case there is no
            batch left to evaluate on and no standard deviation can be computed.
    """
    model = model.to(device)
    model.eval()

    # The encoder does not change across k-shot settings, so the features of all examples
    # only need to be extracted once.
    if data_feats is None:
        # Take the class count from the dataset itself instead of a notebook global
        num_classes = len(dataset.classes)
        exmps_per_class = len(dataset) // num_classes  # Assumes uniform example distribution

        custom_sampler = CustomSampler(indices=dataset.indices, shuffle=True)
        dataloader = data.DataLoader(dataset, sampler=custom_sampler, batch_size=32, num_workers=6)

        iq_features = []
        iq_targets = []
        for iqs, targets in tqdm(dataloader, "Extracting IQ samples features", leave=False):
            iqs = iqs.to(device)
            feats = model.model(iqs)
            iq_features.append(feats.detach().cpu())
            iq_targets.append(targets)
        iq_features = torch.cat(iq_features, dim=0)
        iq_targets = torch.cat(iq_targets, dim=0)

        # Sort by class and arrange as [exmps_per_class, num_classes, ...] so that a slice
        # along the first dimension contains the same number of examples of every class
        iq_targets, sort_idx = iq_targets.sort()
        iq_targets = iq_targets.reshape(num_classes, exmps_per_class).transpose(0, 1)
        iq_features = iq_features[sort_idx].reshape(num_classes, exmps_per_class, -1).transpose(0, 1)
    else:
        iq_features, iq_targets = data_feats

    # Cut the examples into k-shot batches once; each batch serves as support set once
    # and as evaluation data for all other support sets
    batches = [
        (iq_features[start:start + k_shot].flatten(0, 1), iq_targets[start:start + k_shot].flatten(0, 1))
        for start in range(0, iq_features.shape[0], k_shot)
    ]
    if len(batches) < 2:
        raise ValueError(
            f"k_shot={k_shot} yields {len(batches)} batch(es) from {iq_features.shape[0]} "
            "examples per class; at least two batches are required for evaluation"
        )

    accuracies = []
    for k_idx, (k_iq_feats, k_targets) in enumerate(tqdm(batches, "Evaluating prototype classification", leave=False)):
        # Build the prototypes from the support batch
        prototypes, proto_classes = model.calculate_prototypes(k_iq_feats, k_targets)

        # Evaluate accuracy on the rest of the dataset
        batch_acc = 0
        for e_idx, (e_iq_feats, e_targets) in enumerate(batches):
            if k_idx == e_idx:  # Do not evaluate on the support set examples
                continue
            _, _, acc = model.classify_feats(prototypes, proto_classes, e_iq_feats, e_targets)
            batch_acc += acc.item()

        accuracies.append(batch_acc / (len(batches) - 1))

    return (mean(accuracies), stdev(accuracies)), (iq_features, iq_targets), prototypes



In [None]:
# (mean, standard deviation) of the accuracy for every evaluated k-shot value
protonet_accuracies = {}

# Encoded features of the test set; computed on the first call and reused afterwards
data_feats = None

# Evaluate the ProtoNet model for increasing support-set sizes
for k in (2, 8, 16, 32, 64, 128):
    protonet_accuracies[k], data_feats, Proto = test_proto_net(
        protonet_model, test_set, data_feats=data_feats, k_shot=k
    )

    # Report mean accuracy and standard deviation in percent
    acc_mean, acc_std = protonet_accuracies[k]
    print(f"Accuracy for k={k}: {100.0 * acc_mean:4.2f}% (+-{100 * acc_std:4.2f}%)")


In [None]:
# Plot the k-shot accuracies collected above.
# NOTE(review): plot_few_shot is not defined in the visible part of this notebook; it
# presumably comes from the original tutorial and returns a matplotlib axis - confirm.
ax = plot_few_shot(protonet_accuracies, name="ProtoNet", color="C1")
# NOTE(review): the same accuracies are plotted a second time under the name "ProtoNet3" -
# verify whether a different result dictionary was intended here.
ax2 = plot_few_shot(protonet_accuracies, name="ProtoNet3", color="C8")

# Render the figure(s) and release them afterwards
plt.show()
plt.close()

In [None]:
# Recordings used for the additional test sets.
data_file2 =  "/mnt/c/Users/ict_3/deepBeam_matlab/neu_ww72bk483.h5"
# These two paths were commented out (one of them with a typo, "#ata_file4") although
# test_set3 and test_set4 below still use them, which raised a NameError. They are
# restored here; building the datasets does not open the files.
data_file3 = "/mnt/c/Users/ict_3/Documents/Downloads/neu_ww72bk46j.h5"
data_file4 = "/mnt/c/Users/ict_3/Documents/Downloads/neu_ww72bk458.h5"

# NOTE(review): test_classes2 is not defined in the visible part of this notebook; it must
# be defined (as the list/tensor of class labels to test on) before this cell is run.
test_set2 = deepbeam_dataset_from_labels(
    data_file2, label, num_samples, test_classes2, 200)

test_set3 = deepbeam_dataset_from_labels(
    data_file3, label, num_samples, test_classes2, 200)

test_set4 = deepbeam_dataset_from_labels(
    data_file4, label, num_samples, test_classes2, 200)

# Number of selected data points and the resulting number of examples per class
# (assuming a uniform distribution of examples over the classes)
print(len(test_set2))
num_classes = len(test_classes2)
exmps_per_class = len(test_set2) / num_classes
print(exmps_per_class)

# Optional sanity check of the label distribution (slow: reads every selected block):
# label_counts = defaultdict(int)
# for i in test_set2.indices:
#     x, y = test_set2[i]
#     label_counts[y] += 1
# print(dict(label_counts))

In [None]:
@torch.no_grad()
def test_proto_net2(model, dataset, data_feats=None, k_shot=4, pr=0):
    """
    Tests a pretrained ProtoNet model on a given dataset.

    The encoded examples are arranged per class and cut into batches of ``k_shot`` examples
    per class. Every batch is used once as support set to build the prototypes, which are
    then evaluated on all other batches.

    Inputs
        model - Pretrained ProtoNet model
        dataset - The dataset on which the test should be performed.
                  Its ``indices`` and ``classes`` attributes are used, and the examples
                  are assumed to be uniformly distributed over the classes.
        data_feats - The encoded features and targets of all examples in the dataset.
                     If None, they will be newly calculated, and returned
                     for later usage.
        k_shot - Number of examples per class in the support set.
        pr - Currently unused; kept so that existing calls remain valid.

    Returns
        ((mean accuracy, standard deviation), (features, targets))

    Raises
        ValueError - If ``k_shot`` leaves fewer than two batches, in which case there is no
                     batch left to evaluate on and no standard deviation can be computed.
    """
    model = model.to(device)
    model.eval()

    # The encoder network remains unchanged across k-shot settings. Hence, we only need
    # to extract the features for all examples once.
    if data_feats is None:
        # Take the class count from the dataset itself instead of a notebook global
        num_classes = len(dataset.classes)
        exmps_per_class = len(dataset) // num_classes  # We assume uniform example distribution here

        custom_sampler = CustomSampler(indices=dataset.indices, shuffle=True)
        dataloader = data.DataLoader(dataset, sampler=custom_sampler, batch_size=64, num_workers=6)

        iq_features = []
        iq_targets = []
        for iqs, targets in tqdm(dataloader, "Extracting IQ samples features", leave=False):
            iqs = iqs.to(device)
            feats = model.model(iqs)
            iq_features.append(feats.detach().cpu())
            iq_targets.append(targets)
        iq_features = torch.cat(iq_features, dim=0)
        iq_targets = torch.cat(iq_targets, dim=0)

        # Sort by class and arrange as [exmps_per_class, num_classes, ...] so that a slice
        # along the first dimension contains the same number of examples of every class
        iq_targets, sort_idx = iq_targets.sort()
        iq_targets = iq_targets.reshape(num_classes, exmps_per_class).transpose(0, 1)
        iq_features = iq_features[sort_idx].reshape(num_classes, exmps_per_class, -1).transpose(0, 1)
    else:
        iq_features, iq_targets = data_feats

    # Cut the examples into k-shot batches once; each batch serves as support set once
    # and as evaluation data for all other support sets
    batches = [
        (iq_features[start:start + k_shot].flatten(0, 1), iq_targets[start:start + k_shot].flatten(0, 1))
        for start in range(0, iq_features.shape[0], k_shot)
    ]
    if len(batches) < 2:
        raise ValueError(
            f"k_shot={k_shot} yields {len(batches)} batch(es) from {iq_features.shape[0]} "
            "examples per class; at least two batches are required for evaluation"
        )

    accuracies = []
    for k_idx, (k_iq_feats, k_targets) in enumerate(tqdm(batches, "Evaluating prototype classification", leave=False)):
        # Build the prototypes from the support batch
        prototypes, proto_classes = model.calculate_prototypes(k_iq_feats, k_targets)

        # Evaluate accuracy on the rest of the dataset
        batch_acc = 0
        for e_idx, (e_iq_feats, e_targets) in enumerate(batches):
            if k_idx == e_idx:  # Do not evaluate on the support set examples
                continue
            _, _, acc = model.classify_feats(prototypes, proto_classes, e_iq_feats, e_targets)
            batch_acc += acc.item()

        accuracies.append(batch_acc / (len(batches) - 1))

    return (mean(accuracies), stdev(accuracies)), (iq_features, iq_targets)



In [None]:

# (mean, standard deviation) of the accuracy on the second test set for every k-shot value
protonet_DP2_accuracies = {}
# Encoded features of test_set2; computed on the first call and reused afterwards
data_feats2 = None
for k in (2, 8, 16, 32, 64, 128):
    protonet_DP2_accuracies[k], data_feats2 = test_proto_net2(
        protonet_model, test_set2, data_feats=data_feats2, k_shot=k, pr=0
    )
    acc_mean, acc_std = protonet_DP2_accuracies[k]
    print(f"Accuracy for k={k}: {100.0*acc_mean:4.2f}% (+-{100*acc_std:4.2f}%)")
