In [1]:
from models.pretrained_model_transfer import pretrained_model
import os
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets
from tllib.ranking.logme import log_maximum_evidence as LogME
import pprint
from datasets import load_dataset, DownloadConfig

import pickle
from torchvision.models import (
    MobileNet_V2_Weights, MNASNet1_0_Weights, DenseNet121_Weights,
    DenseNet169_Weights, DenseNet201_Weights, ResNet34_Weights,
    ResNet50_Weights, ResNet101_Weights, ResNet152_Weights,
    GoogLeNet_Weights, Inception_V3_Weights
)
import numpy as np
from gudhi.rips_complex import RipsComplex

In [37]:
def compute_persistence_diagram(features: np.ndarray, max_edge_length: float = 100):
    """
    Takes feature embeddings and computes the H0 persistence diagram.
    This function encapsulates the Gudhi TDA calculation.
    """
    if features.shape[0] == 0:
        print("Warning: Input features for PD are empty.")
        return None

    try:
        print("Computing Persistence Diagram...")
        features_contiguous = np.ascontiguousarray(features, dtype=np.float64)
        rips_complex = RipsComplex(points=features_contiguous, max_edge_length=max_edge_length)
        simplex_tree = rips_complex.create_simplex_tree(max_dimension=1)
        diag = simplex_tree.persistence()
        h0_diag = [(pair[1][0], pair[1][1]) for pair in diag if pair[0] == 0 and pair[1][1] != float('inf')]
        print(f"Computed H0 persistence diagram with {len(h0_diag)} finite bars.")
        return h0_diag
    except Exception as e:
        print(f"Error during Gudhi computation: {e}")
        return None

def intercd(features: np.ndarray, labels: np.ndarray, pd_max_edge_length: float = 50):
    """
    Computes the Inter-level Persistence Score from feature embeddings.

    This score is the average persistence of the k-1 most persistent topological
    features (connected components), where k is the number of unique classes.

    Args:
        features (np.ndarray): The feature embeddings from the model.
        labels (np.ndarray): The ground-truth labels for the features.
        pd_max_edge_length (float): Maximum edge length for the Rips complex.

    Returns:
        The Inter-level Persistence score as a float. Returns 0.0 if the
        computation cannot be completed.
    """
    # 1. Basic check for valid input
    if features.shape[0] == 0:
        print("Warning: Input 'features' is empty. Skipping InterCD calculation.")
        return 0.0

    # 2. Compute the H0 persistence diagram using Gudhi
    h0_diag = None  # Initialize to None
    print("Computing Persistence Diagram...")
    try:
        # Gudhi requires a C-contiguous array for performance
        features_contiguous = np.ascontiguousarray(features, dtype=np.float64)

        rips_complex = RipsComplex(points=features_contiguous, max_edge_length=pd_max_edge_length)
        simplex_tree = rips_complex.create_simplex_tree(max_dimension=1)
        diag = simplex_tree.persistence()

        # Filter for H0 (dimension 0) features and pairs with a finite death time
        h0_diag = [(pair[1][0], pair[1][1]) for pair in diag if pair[0] == 0 and pair[1][1] != float('inf')]
        print(f"Computed H0 persistence diagram with {len(h0_diag)} finite bars.")
    except Exception as e:
        print(f"Error during Gudhi computation: {e}")
        # h0_diag will remain None if an error occurs

    # 3. Check if the diagram was successfully computed
    if h0_diag is None or not h0_diag:
        print("Warning: Persistence diagram could not be computed or is empty.")
        return 0.0  # Return a default score of 0

    # 4. Calculate the score from the diagram
    k = len(np.unique(labels))
    if k <= 1:
        print("Warning: Cannot compute score with only one class (k=1).")
        return 0.0

    # Get a sorted list of all persistence values (death_time - birth_time)
    # For H0, birth_time is always 0, so this is just a list of death_times.
    persistences = sorted([death - birth for birth, death in h0_diag], reverse=True)

    # 5. Check if we have enough persistence bars for the calculation
    if len(persistences) < k - 1:
        print(f"Warning: Not enough persistence bars ({len(persistences)}) to compute score for k={k} classes. Returning 0.")
        return 0.0

    # The score is the average of the k-1 most persistent features
    persist_inter = np.array(persistences[:k - 1])
    score = np.mean(persist_inter)

    return score

In [38]:


def score_model(configs, score_loader, device):
    print(f'Calc Transferabilities of {configs.model} on {configs.dataset}')

    if configs.model == 'inception_v3':
        model = models.__dict__[configs.model](pretrained=True, aux_logits=False).cuda()
    else:
        model = models.__dict__[configs.model](pretrained=True).cuda()

    # different models has different linear projection names
    if configs.model in ['mobilenet_v2', 'mnasnet1_0']:
        fc_layer = model.classifier[-1]
    elif configs.model in ['densenet121', 'densenet169', 'densenet201']:
        fc_layer = model.classifier
    elif configs.model in ['resnet34', 'resnet50', 'resnet101', 'resnet152', 'googlenet', 'inception_v3']:
        fc_layer = model.fc
    else:
        # try your customized model
        raise NotImplementedError

    print('Conducting features extraction...')
    model = pretrained_model(model, fc_layer)
    features, outputs, targets = forward_pass_with_wrapper(score_loader, model, device)
    # predictions = F.softmax(outputs)

    print('Conducting transferability calculation...')
    logme_score = LogME(features.numpy(), targets.numpy(), regression=False)




    sample_size = 2048
    num_samples = features.shape[0]

    # Check if the dataset is large enough to sample from
    if num_samples > sample_size:
        print(f"Taking a random sample of {sample_size} from {num_samples} total samples for intercd score.")

        # 1. Generate a random permutation of indices from 0 to num_samples-1
        indices = torch.randperm(num_samples)

        # 2. Select the first `sample_size` indices from the random permutation
        sample_indices = indices[:sample_size]

        # 3. Use these indices to select a subset of features and targets
        sampled_features = features[sample_indices]
        sampled_targets = targets[sample_indices]
    else:
        # If the dataset is smaller than the desired sample size, just use all of it
        print(f"Dataset size ({num_samples}) is smaller than sample size ({sample_size}), using all data for intercd score.")
        sampled_features = features
        sampled_targets = targets
    # --- END: New code for random sampling ---

    # Calculate intercd_score on the (potentially smaller) sampled dataset
    intercd_score = intercd(sampled_features.numpy(), sampled_targets.numpy())
    print(f'Intercd score: {intercd_score}')

    # save calculated bayesian weight
#    torch.save(logme.ms, f'logme_{configs.dataset}/weight_{configs.model}.pth')

    print(f'LogME of {configs.model}: {logme_score}\n')
    return logme_score



In [39]:
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
models_hub = ['mobilenet_v2', 'mnasnet1_0', 'densenet121', 'densenet169', 'densenet201',
              'resnet34', 'resnet50', 'resnet101', 'resnet152', 'googlenet', 'inception_v3']
#models_hub = ['mobilenet_v2']

def collate_fn_skip_corrupt(batch):
    batch = [item for item in batch if item is not None]

    if not batch:
        return torch.tensor([]), torch.tensor([])

    images, labels = zip(*batch)

    # The 'labels' tuple now correctly contains integers.
    # The [int(lbl) for lbl in labels] line is no longer needed.

    images_batch = torch.stack(images, 0)
    labels_batch = torch.tensor(labels, dtype=torch.long)

    return images_batch, labels_batch

class HuggingFaceDatasetWrapper(Dataset):
    """
    The updated, more general wrapper.
    """
    def __init__(self, hf_dataset, transform=None, label_map=None):
        self.hf_dataset = hf_dataset
        self.transform = transform
        self.label_map = label_map

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        try:
            item = self.hf_dataset[idx]
            image = item['image']
            raw_label = item['label']

            # =================== THE FIX ===================
            # Use the label map only if one was provided
            if self.label_map:
                # This block runs for Birdsnap
                label = self.label_map[raw_label]
            else:
                # This block runs for SUN397
                label = raw_label
            # ===============================================

            # Minor bug fix: Convert to RGB *before* transforming
            if image.mode != 'RGB':
                image = image.convert('RGB')

            if self.transform:
                image = self.transform(image)

            return image, label
        except OSError as e:
            print(f"Warning: Skipping corrupt image at index {idx}: {e}")
            return None


class CustomSUN397Dataset(Dataset):
    """
    A custom PyTorch Dataset for the manually downloaded SUN397 data.
    It reads a specific split file (e.g., Training_01.txt) to load data.
    """
    def __init__(self, image_root_dir, split_file, class_name_file, transform=None):
        """
        Args:
            image_root_dir (string): Path to the SUN397 directory with all the images.
            split_file (string): Path to the text file defining the data split (e.g., Training_01.txt).
            class_name_file (string): Path to the ClassName.txt file.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.image_root_dir = image_root_dir
        self.transform = transform

        # 1. Create a mapping from class path (e.g., '/a/abbey') to an integer index (e.g., 0)
        with open(class_name_file, 'r') as f:
            # Reads '/a/abbey\n' and maps it to its line number (index)
            self.class_to_idx = {line.strip(): i for i, line in enumerate(f)}

        # 2. Read the image paths and create corresponding labels
        self.image_paths = []
        self.labels = []
        with open(split_file, 'r') as f:
            for line in f:
                relative_path = line.strip()
                self.image_paths.append(relative_path)

                # Extract class path from image path (e.g., '/a/abbey' from '/a/abbey/sun_...jpg')
                class_path = os.path.dirname(relative_path)
                self.labels.append(self.class_to_idx[class_path])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Construct the full image path
        # We use lstrip('/') to remove the leading '/' so os.path.join works correctly
        img_path = os.path.join(self.image_root_dir, self.image_paths[idx].lstrip('/'))

        # Load the image and ensure it's in RGB format
        image = Image.open(img_path).convert('RGB')

        # Get the integer label
        label = self.labels[idx]

        # Apply transformations if they exist
        if self.transform:
            image = self.transform(image)

        return image, label

def get_dataset(dataset_name: str, transform: transforms.Compose, data_path: str = 'data'):
    """
    Loads or downloads the specified dataset.

    :param dataset_name: The name of the dataset to load.
    :param transform: The torchvision transforms to apply to the dataset.
    :param data_path: The root directory to save downloaded datasets.
    :return: A torch.utils.data.Dataset object.
    """
    # Create the root data directory if it doesn't exist
    os.makedirs(data_path, exist_ok=True)
    name = dataset_name.lower()

    # --- Handle ImageNet with the hardcoded relative path ---
    if name == 'imagenet':
        print("Using pre-downloaded ImageNet from relative path...")
        # Go up 5 levels from /home/alpaca/Tesis/networks/resultsagg/analysis/transfer_benchmarks/
        # to /home/alpaca/, then into datasets/imagenet/train
        # os.path.join is used for cross-platform compatibility.
        imagenet_path = os.path.join('../../../../../datasets/imagenet', 'train')
        if not os.path.isdir(imagenet_path):
            raise FileNotFoundError(f"ImageNet path not found at relative location: {imagenet_path}")
        return datasets.ImageFolder(root=imagenet_path, transform=transform)

    # --- Handle standard torchvision datasets ---
    # The `download=True` flag will automatically download the data if not found in `root`.
    elif name == 'aircraft':
        return datasets.FGVCAircraft(root=data_path, split='train', transform=transform, download=True)
    elif name == 'caltech101':
        print("Loading Caltech-101 from local ImageFolder...")

        # Define the path directly to the folder containing the class subdirectories
        image_path = os.path.join('../../../../../datasets', 'caltech-101', '101_ObjectCategories')

        # ImageFolder will automatically find all class folders and their images
        return datasets.ImageFolder(root=image_path, transform=transform)

    elif name == 'cifar10':
        return datasets.CIFAR10(root=data_path, train=True, transform=transform, download=True)
    elif name == 'cifar100':
        return datasets.CIFAR100(root=data_path, train=True, transform=transform, download=True)
    elif name == 'dtd':
        return datasets.DTD(root=data_path, split='train', transform=transform, download=True)
    elif name == 'oxfordiiitpets':
        return datasets.OxfordIIITPet(root=data_path, split='trainval', transform=transform, download=True)

    elif name == 'stanfordcars':
        print("Loading Stanford Cars from local ImageFolder...")

        # Define the path directly to your training data folder
        train_path = os.path.join('../../../../../datasets', 'stanford_cars', 'train')

        # ImageFolder handles everything automatically!
        return datasets.ImageFolder(root=train_path, transform=transform)

    # ... (other cases) .
    elif name == 'sun397':
        print("Loading SUN397 from local custom path...")

        # Define the paths based on your file structure
        base_dir = "../../../../../datasets/SUN_dataset/data"
        image_root_dir = os.path.join(base_dir, "SUN397")
        class_name_file = os.path.join(base_dir, "ClassName.txt")

        # Select which split to use. Here we use the first training split.
        # You can easily change this to "Testing_01.txt" to load the test set.
        split_file = os.path.join(base_dir, "Training_01.txt")

        # Instantiate and return your custom dataset
        return CustomSUN397Dataset(
            image_root_dir=image_root_dir,
            split_file=split_file,
            class_name_file=class_name_file,
            transform=transform
        )
    elif name == 'birdsnap':
        print("Loading 'sasha/birdsnap' and creating label map...")
        # Load the training split

        download_config = DownloadConfig( # Set timeout to 300 seconds (5 minutes)
            max_retries=5,  # Allow up to 5 retries on failure
        )
        hf_train_dataset = load_dataset("sasha/birdsnap", split='train')

        # =================== THE FIX ===================
        # Get all unique string labels from the dataset
        all_string_labels = sorted(hf_train_dataset.unique("label"))
        # Create the mapping from string name to integer index
        label_map = {name: i for i, name in enumerate(all_string_labels)}
        print(f"Created map for {len(label_map)} unique classes.")
        # ===============================================

        # Pass the dataset AND the new map to the wrapper
        return HuggingFaceDatasetWrapper(
            hf_dataset=hf_train_dataset,
            transform=transform,
            label_map=label_map
        )
        # Use the wrapper to make it compatible with PyTorch's DataLoader

    else:
        raise NotImplementedError(f"Dataset '{dataset_name}' is not supported.")

def forward_pass_with_wrapper(score_loader, wrapped_model, device = 'cuda'):
    """
    A forward pass on the target dataset using the pretrained_model wrapper.

    :params score_loader: The dataloader for scoring transferability.
    :params wrapped_model: The pretrained_model wrapper instance, which handles hook management internally.
    :returns:
        features: Extracted features (input to the final linear layer).
        outputs: The final outputs of the model.
        targets: Ground-truth labels of the dataset.
    """
    features = []
    outputs = []
    targets = []

    # No need to define a local hook function or manually register/remove it.
    # The wrapper class handles this automatically.

    wrapped_model.eval()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(score_loader):
                        # --- START DEBUG ---
            # Print the type and value of target for each batch
           # print(f"Batch {batch_idx}: type(target) is {type(target)}")
            # --- END DEBUG ---
            if data.nelement() == 0:
                print('git gut, continue, something with 0 elements, not cool')
                continue

            # Move data to the same device as the model
            data = data.to(wrapped_model.device)
            targets.append(target)

            # 1. Run the forward pass to get the final model outputs.
            # As a side effect, the model's internal hook will automatically
            # capture the input to the final layer.
            batch_outputs = wrapped_model(data)

            # 2. Access the feature captured by the hook.
            # The `_captured_feature` attribute is updated during the forward pass.
            batch_features = wrapped_model._captured_feature

            # Append the results for the current batch
            features.append(batch_features.cpu())
            outputs.append(batch_outputs.cpu())

    # Concatenate all batch results into single tensors
    features = torch.cat(features)
    outputs = torch.cat(outputs)
    targets = torch.cat(targets)

    return features, outputs, targets

In [40]:
# This class replaces the get_configs() function and argparse
class NotebookConfigs:
    """
    A simple class to hold the configuration parameters for the notebook.
    This mimics the behavior of the object returned by parser.parse_args().
    """

    def __init__(self):
        # --- Execution Parameters ---
        self.gpu = 'cuda'
        self.batch_size = 32
        self.num_workers = 0

        # --- Dataset to be analyzed ---
        # YOU CAN CHANGE THIS VALUE IN YOUR NOTEBOOK
        # Options: 'aircraft', 'caltech101', 'cifar10', 'cifar100',
        #          'dtd', 'oxfordiiitpets', 'stanfordcars', 'sun397', 'imagenet'
        self.dataset = None

        # The 'model' attribute will be set inside the loop in main()
        self.model = None

# --- In a notebook cell, you would create the configs object like this ---
configs = NotebookConfigs()


In [41]:
def extract_and_save(configs, score_loader, device):
    """
    This function replaces the old 'score_model'. It runs the forward pass,
    pre-computes necessary items, and saves everything to disk.
    """
    print(f"--- Starting extraction for model: {configs.model} on dataset: {configs.dataset} ---")

    # =====================================================================
    # --- THIS IS THE MODIFIED SECTION ---
    # 1. Define the mapping from model name string to the correct V1 weights
    weights_mapping = {
        'mobilenet_v2': MobileNet_V2_Weights.IMAGENET1K_V1,
        'mnasnet1_0': MNASNet1_0_Weights.IMAGENET1K_V1,
        'densenet121': DenseNet121_Weights.IMAGENET1K_V1,
        'densenet169': DenseNet169_Weights.IMAGENET1K_V1,
        'densenet201': DenseNet201_Weights.IMAGENET1K_V1,
        'resnet34': ResNet34_Weights.IMAGENET1K_V1,
        'resnet50': ResNet50_Weights.IMAGENET1K_V1,
        'resnet101': ResNet101_Weights.IMAGENET1K_V1,
        'resnet152': ResNet152_Weights.IMAGENET1K_V1,
        'googlenet': GoogLeNet_Weights.IMAGENET1K_V1,
        'inception_v3': Inception_V3_Weights.IMAGENET1K_V1
    }

    # 2. Load Model using the explicit V1 weights
    weights = weights_mapping[configs.model]
    if configs.model == 'inception_v3':
        model = models.__dict__[configs.model](weights=weights, aux_logits=True).to(device)
    else:
        model = models.__dict__[configs.model](weights=weights).to(device)
    # --- END OF MODIFIED SECTION ---
    # =====================================================================


    # --- Find fc_layer (logic remains the same) ---
    if configs.model in ['mobilenet_v2', 'mnasnet1_0']:
        fc_layer = model.classifier[-1]
    elif 'densenet' in configs.model:
        fc_layer = model.classifier
    elif 'resnet' in configs.model or configs.model in ['googlenet', 'inception_v3']:
        fc_layer = model.fc
    else:
        raise NotImplementedError(f"Classifier layer logic not defined for model: {configs.model}")


    # 2. Extract Features, Outputs, and Targets
    print('Conducting features extraction...')
    wrapped_model = pretrained_model(model, fc_layer, device=device)
    print('Model instantiated, intializing forward pass')
    features, outputs, targets = forward_pass_with_wrapper(score_loader, wrapped_model)
    print('Forward pass successful')
    # 3. Pre-compute anything else needed for scoring
    #    - Predicted labels for NCE score
    predicted_labels = torch.argmax(outputs, dim=1)
    #    - Persistence diagram for InterCD score (using a sub-sample)
    print('subsampling for persistent diagram')
    sample_size = 4096
    if features.shape[0] > sample_size:
        indices = torch.randperm(features.shape[0])[:sample_size]
        sampled_features_for_pd = features[indices]
    else:
        sampled_features_for_pd = features
    print('computing persistent diagram')
    persistence_diagram = compute_persistence_diagram(sampled_features_for_pd.numpy())

    # 4. Define Save Directory and Save Everything
    save_dir = os.path.join('extraction_results', configs.dataset, configs.model)
    os.makedirs(save_dir, exist_ok=True)
    print(f"Saving results to: {save_dir}")

    # Move to CPU before saving to ensure portability
    torch.save(features.cpu(), os.path.join(save_dir, 'features.pt'))
    torch.save(outputs.cpu(), os.path.join(save_dir, 'outputs.pt'))
    torch.save(targets.cpu(), os.path.join(save_dir, 'targets.pt'))
    torch.save(predicted_labels.cpu(), os.path.join(save_dir, 'predicted_labels.pt'))

    if persistence_diagram is not None:
        with open(os.path.join(save_dir, 'persistence_diagram.pkl'), 'wb') as f:
            pickle.dump(persistence_diagram, f)

    print(f"--- Finished extraction for model: {configs.model} ---")

def main_extraction():
   # configs = NotebookConfigs() # Assuming you have your NotebookConfigs class
    device = "cuda" if torch.cuda.is_available() else "cpu"

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Loop through your models
    for model_name in models_hub:
        configs.model = model_name
        if model_name == 'inception_v3':
            transform = transforms.Compose([transforms.Resize((299, 299)), transforms.ToTensor(), normalize])
        else:
            transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), normalize])

        # --- Dataset Loading and Sub-sampling Logic ---
        print('getting dataset')
        full_dataset = get_dataset(dataset_name=configs.dataset, transform=transform)
        print('getting model')
        # Using a smaller sample size for demonstration
        num_samples_to_use = 250000
        if configs.dataset.lower() == 'imagenet' and len(full_dataset) > num_samples_to_use:
            # Using stratified sampling
            print('getting targets')
            targets = full_dataset.targets
            indices = np.arange(len(full_dataset))
            from sklearn.model_selection import train_test_split
            print('splitting dataset into training and test sets')
            subset_indices, _ = train_test_split(indices, train_size=num_samples_to_use, stratify=targets, random_state=42)
            score_dataset = torch.utils.data.Subset(full_dataset, subset_indices)

        else:
            score_dataset = full_dataset

        print(f"Initializing DataLoader with {len(score_dataset)} samples.")
        score_loader = DataLoader(score_dataset, batch_size=configs.batch_size, shuffle=False,
                                  num_workers=0, pin_memory=True, collate_fn=collate_fn_skip_corrupt) # Use num_workers=0 to avoid issues

        # Run the extraction and saving process
        print(f"Initializing model: {configs.model} on dataset: {configs.dataset} ---")
        extract_and_save(configs, score_loader, device)


def main():

    # You can then easily check the values
    print(f"Running analysis on dataset: {configs.dataset}")
    print(f"Using batch size: {configs.batch_size}")


    device = "cuda" if torch.cuda.is_available() else "cpu"

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    if not os.path.isdir(f'logme_{configs.dataset}'):
        os.mkdir(f'logme_{configs.dataset}')
    score_dict = {}
    for model in models_hub:
        configs.model = model
        if model == 'inception_v3':  # inception_v3 is pretrained on 299x299 images
            transform = transforms.Compose([
                transforms.Resize((299, 299)),
                transforms.ToTensor(),
                normalize
            ])
        else:
            transform = transforms.Compose([  # other models are pretrained on 224x224 images
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                normalize
            ])
            #############################ACA VER TEMA DE DATA####################################

        print('getting dataset')
        full_score_dataset = get_dataset(dataset_name=configs.dataset, transform=transform)
        print('getting datset success')
        # Define the desired number of samples
        num_samples_to_use = 250000

        # --- Sub-sampling Logic ---
        # Check if the current dataset is 'imagenet' and if it's larger than our target sample size
        if configs.dataset.lower() == 'imagenet' and len(full_score_dataset) > num_samples_to_use:
            print(f"Dataset is ImageNet. Sub-sampling from {len(full_score_dataset)} to {num_samples_to_use} images.")

            # 1. Generate a random permutation of indices for the full dataset
            indices = torch.randperm(len(full_score_dataset))

            # 2. Select the first N indices from the shuffled list
            subset_indices = indices[:num_samples_to_use]

            # 3. Create a Subset wrapper using the original dataset and the random indices
            score_dataset = torch.utils.data.Subset(full_score_dataset, subset_indices)

        else:
            # If the dataset is not ImageNet, or if it's smaller than the desired sample size,
            # we use the original, full dataset.
            score_dataset = full_score_dataset
        # --- End of Logic ---


        # Initialize the DataLoader with the (potentially smaller) dataset
        print(f"Initializing DataLoader with {len(score_dataset)} samples.")
        score_loader = DataLoader(score_dataset, batch_size=configs.batch_size, shuffle=False,
                                  num_workers=configs.num_workers, pin_memory=True)

        # or try your customized dataset
        score_loader = DataLoader(score_dataset, batch_size=configs.batch_size, shuffle=False,
                                  num_workers=configs.num_workers, pin_memory=True)
        score_dict[model] = score_model(configs, score_loader, device)
    results = sorted(score_dict.items(), key=lambda i: i[1], reverse=True)
    torch.save(score_dict, f'logme_{configs.dataset}/results.pth')
    print(f'Models ranking on {configs.dataset}: ')
    pprint.pprint(results)


In [43]:
%%time

datasets_to_run = [
    'aircraft',
    'cifar10',
    'cifar100',
    'dtd',
    'oxfordiiitpets',
    'sun397',
    'stanfordcars',
    'caltech101',
    'birdsnap',
    'imagenet'
]


#stanfordcars and caltech101 birdsnap, SUN missing

for current_dataset in datasets_to_run:
    configs.dataset = current_dataset
    main_extraction()

getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: mobilenet_v2 on dataset: birdsnap ---
--- Starting extraction for model: mobilenet_v2 on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4095 finite bars.
Saving results to: extraction_results/birdsnap/mobilenet_v2
--- Finished extraction for model: mobilenet_v2 ---
getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: mnasnet1_0 on dataset: birdsnap ---
--- Starting extraction for model: mnasnet1_0 on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4095 finite bars.
Saving results to: extraction_results/birdsnap/mnasnet1_0
--- Finished extraction for model: mnasnet1_0 ---
getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: densenet121 on dataset: birdsnap ---
--- Starting extraction for model: densenet121 on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4095 finite bars.
Saving results to: extraction_results/birdsnap/densenet121
--- Finished extraction for model: densenet121 ---
getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: densenet169 on dataset: birdsnap ---
--- Starting extraction for model: densenet169 on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4095 finite bars.
Saving results to: extraction_results/birdsnap/densenet169
--- Finished extraction for model: densenet169 ---
getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: densenet201 on dataset: birdsnap ---
--- Starting extraction for model: densenet201 on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4095 finite bars.
Saving results to: extraction_results/birdsnap/densenet201
--- Finished extraction for model: densenet201 ---
getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: resnet34 on dataset: birdsnap ---
--- Starting extraction for model: resnet34 on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4095 finite bars.
Saving results to: extraction_results/birdsnap/resnet34
--- Finished extraction for model: resnet34 ---
getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: resnet50 on dataset: birdsnap ---
--- Starting extraction for model: resnet50 on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4095 finite bars.
Saving results to: extraction_results/birdsnap/resnet50
--- Finished extraction for model: resnet50 ---
getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: resnet101 on dataset: birdsnap ---
--- Starting extraction for model: resnet101 on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4095 finite bars.
Saving results to: extraction_results/birdsnap/resnet101
--- Finished extraction for model: resnet101 ---
getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: resnet152 on dataset: birdsnap ---
--- Starting extraction for model: resnet152 on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4094 finite bars.
Saving results to: extraction_results/birdsnap/resnet152
--- Finished extraction for model: resnet152 ---
getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: googlenet on dataset: birdsnap ---
--- Starting extraction for model: googlenet on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4095 finite bars.
Saving results to: extraction_results/birdsnap/googlenet
--- Finished extraction for model: googlenet ---
getting dataset
Loading 'sasha/birdsnap' and creating label map...


Resolving data files:   0%|          | 0/127 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/100 [00:00<?, ?it/s]

Created map for 500 unique classes.
getting model
Initializing DataLoader with 39860 samples.
Initializing model: inception_v3 on dataset: birdsnap ---
--- Starting extraction for model: inception_v3 on dataset: birdsnap ---
Conducting features extraction...
Registering embedding hook on layer: Linear
Model instantiated, intializing forward pass
Forward pass successful
subsampling for persistent diagram
computing persistent diagram
Computing Persistence Diagram...
Computed H0 persistence diagram with 4095 finite bars.
Saving results to: extraction_results/birdsnap/inception_v3
--- Finished extraction for model: inception_v3 ---
getting dataset
Using pre-downloaded ImageNet from relative path...
getting model
getting targets
splitting dataset into training and test sets
Initializing DataLoader with 250000 samples.
Initializing model: mobilenet_v2 on dataset: imagenet ---
--- Starting extraction for model: mobilenet_v2 on dataset: imagenet ---
Conducting features extraction...
Registerin