Goal: visualize which points each batch selection method chooses at each stage of a model's training.

Label each point with its loss value, its batch number, and whether the selection method included it.

# Set Up

## Import dependencies

In [2]:
import fiftyone.zoo as foz
import cv2
import numpy as np
import fiftyone.brain as fob
import torch
import torch.nn as nn
import models
import yaml
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import fiftyone as fo # Import fiftyone
from PIL import Image # Import Image
import detectors
import timm
#import transformers


# Pretrained CIFAR-10 backbones loaded through timm (the names are presumably
# registered by the `detectors` import above). Other candidates that can be added
# to `pretrained_models`: resnet18_cifar10, resnet34_cifar10, resnet50_cifar10,
# resnet34_supcon_cifar10, resnet50_simclr_cifar10.
resnet50_supcon_cifar10 = timm.create_model(
    "resnet50_supcon_cifar10",
    pretrained=True,
)

# Model name -> model; every entry gets its own UMAP visualization below.
pretrained_models = dict(resnet50_supcon_cifar10=resnet50_supcon_cifar10)

## Set up ResNet model checkpoints

In [3]:
# Read the experiment config: it names the network class in `models` and its
# constructor arguments. The `with` block closes the file, so no explicit close().
config_file = "cfg/cifar10.yaml"
with open(config_file, 'r', encoding='utf-8') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

model_type = config['networks']['type']
model_args = config['networks']['params']
empty_model = getattr(models, model_type)(**model_args)

# Checkpoints saved every 25 epochs: epochs 0, 25, ..., 175.
checkpoints = [
    f'visualization_checkpoints/checkpoint_epoch_{epoch}.pth.tar'
    for epoch in range(0, 200, 25)
]
# NOTE: weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
model_params = torch.load(checkpoints[0], map_location=torch.device('cpu'), weights_only=False)
empty_model.load_state_dict(model_params['state_dict'])

<All keys matched successfully>

# FiftyOne Analysis


In [4]:
def prepare_embeddings(model, model_name, dataset, batch_size=128, target_size=(32, 32)):
    """Embed every image of a FiftyOne dataset with `model` and compute a 2-D UMAP view.

    Each sample's ``filepath`` is read with OpenCV, converted to RGB, resized to
    `target_size`, scaled to [0, 1] and laid out as (C, H, W). Images are fed to
    `model` in batches of `batch_size` with gradients disabled. The outputs
    (flattened to one row per image) are passed as embeddings to
    ``fob.compute_visualization`` (UMAP, 2 dims, seed 51) under the brain key
    ``model_name + "_test"``.

    Args:
        model: torch module; it is switched to eval mode.
        model_name: prefix for the brain key.
        dataset: FiftyOne dataset/collection whose images are embedded.
        batch_size: number of images per forward pass (default 128).
        target_size: ``(width, height)`` passed to ``cv2.resize`` (default (32, 32)).

    Returns:
        The results object returned by ``fob.compute_visualization``.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the images.
        ValueError: if the dataset contains no images.
    """
    # Inference only: BatchNorm uses running statistics, dropout is disabled.
    model.eval()

    image_arrays = []
    for filepath in dataset.values("filepath"):
        img = cv2.imread(filepath, cv2.IMREAD_COLOR)  # 3-channel, OpenCV's BGR order
        if img is None:
            raise FileNotFoundError(f"Could not read image: {filepath}")
        # OpenCV decodes to BGR; torchvision/timm image models expect RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, target_size)
        img = img.astype(np.float32) / 255.0  # scale to [0, 1]
        # NOTE(review): no mean/std normalisation is applied — confirm this matches
        # the input distribution the model was trained on.
        image_arrays.append(img.transpose(2, 0, 1))  # (H, W, C) -> (C, H, W)

    if not image_arrays:
        raise ValueError("dataset contains no images to embed")

    embeddings = []
    with torch.no_grad():
        for start in range(0, len(image_arrays), batch_size):
            batch_tensor = torch.from_numpy(np.stack(image_arrays[start:start + batch_size]))
            out = model(batch_tensor)
            # UMAP needs a 2-D (num_images, dim) matrix; flatten any trailing dims.
            embeddings.append(out.reshape(out.shape[0], -1).cpu().numpy())

    embedding_outputs = np.concatenate(embeddings, axis=0)

    return fob.compute_visualization(
        dataset,
        embeddings=embedding_outputs,
        num_dims=2,
        method="umap",
        brain_key=model_name + "_test",
        verbose=True,
        seed=51,
    )

In [5]:
# CIFAR-10 test split from the FiftyOne zoo (reused if it already exists); this
# is the dataset displayed in the FiftyOne App below. One UMAP view per model.
test_split = foz.load_zoo_dataset("cifar10", split="test")

for model_name, model in pretrained_models.items():
    prepare_embeddings(model=model, model_name=model_name, dataset=test_split)

Split 'test' already downloaded
Loading existing dataset 'cifar10-test'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use
Generating visualization...


  warn(


UMAP(n_jobs=1, random_state=51, verbose=True)
Wed Sep 24 13:20:41 2025 Construct fuzzy simplicial set
Wed Sep 24 13:20:41 2025 Finding Nearest Neighbors
Wed Sep 24 13:20:41 2025 Building RP forest with 10 trees
Wed Sep 24 13:20:48 2025 NN descent for 13 iterations
	 1  /  13
	 2  /  13
	 3  /  13
	 4  /  13
	Stopping threshold met -- exiting after 4 iterations
Wed Sep 24 13:21:00 2025 Finished Nearest Neighbor Search
Wed Sep 24 13:21:03 2025 Construct embedding


Epochs completed:   2%| ▏          8/500 [00:00]

	completed  0  /  500 epochs


Epochs completed:  12%| █▏         58/500 [00:01]

	completed  50  /  500 epochs


Epochs completed:  22%| ██▏        108/500 [00:03]

	completed  100  /  500 epochs


Epochs completed:  32%| ███▏       158/500 [00:04]

	completed  150  /  500 epochs


Epochs completed:  42%| ████▏      208/500 [00:05]

	completed  200  /  500 epochs


Epochs completed:  52%| █████▏     258/500 [00:06]

	completed  250  /  500 epochs


Epochs completed:  62%| ██████▏    308/500 [00:07]

	completed  300  /  500 epochs


Epochs completed:  72%| ███████▏   358/500 [00:08]

	completed  350  /  500 epochs


Epochs completed:  82%| ████████▏  408/500 [00:10]

	completed  400  /  500 epochs


Epochs completed:  92%| █████████▏ 458/500 [00:11]

	completed  450  /  500 epochs


Epochs completed: 100%| ██████████ 500/500 [00:12]


Wed Sep 24 13:21:15 2025 Finished embedding


In [6]:
# Open the FiftyOne App on the test split. Inside a notebook the App is embedded
# and `session.wait()` only prints "Notebook sessions cannot wait", so it is not called.
session = fo.launch_app(test_split)

Notebook sessions cannot wait


# Add tags to dataset

## Set up consistent dataset

In [None]:
class FiftyOneCIFARDataset(torch.utils.data.Dataset):
    """Torch dataset view of a FiftyOne CIFAR-10 collection.

    Item ``idx`` is ``(image, label, sample_id)`` for ``sample_ids[idx]``:
    ``image`` is the sample's file opened as an RGB PIL image (then passed through
    ``transform`` if one is given), ``label`` is ``class_to_idx[ground_truth.label]``
    as a ``torch.long`` scalar, and ``sample_id`` is the FiftyOne sample ID as a string.
    Items follow the collection's order as returned by ``values()``.
    """

    def __init__(self, fiftyone_dataset, class_to_idx, transform=None):
        self.dataset = fiftyone_dataset
        self.transform = transform
        self.class_to_idx = class_to_idx  # label string -> integer class index
        # Fetch IDs, file paths and labels once with bulk `values()` calls instead of
        # issuing a database lookup per item in __getitem__; this dataset is iterated
        # repeatedly (once per checkpoint and per selection method).
        self.sample_ids = list(self.dataset.values("_id"))
        self._filepaths = list(self.dataset.values("filepath"))
        self._labels = list(self.dataset.values("ground_truth.label"))
        if not (len(self.sample_ids) == len(self._filepaths) == len(self._labels)):
            raise ValueError("FiftyOne values() returned fields of different lengths")

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, idx):
        # Context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(self._filepaths[idx]) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        label_tensor = torch.tensor(self.class_to_idx[self._labels[idx]], dtype=torch.long)
        return img, label_tensor, str(self.sample_ids[idx])

# Preprocessing for the custom FiftyOne dataset: PIL image -> float tensor, then
# per-channel (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# The torchvision training set is only used here for its label -> index mapping.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
class_to_idx = train_dataset.class_to_idx

# FiftyOne test split, its torch wrapper, and an unshuffled loader so batch
# membership is identical every time the loader is iterated.
test_split = foz.load_zoo_dataset("cifar10", split="test")
custom_fiftyone_dataset = FiftyOneCIFARDataset(test_split, class_to_idx=class_to_idx, transform=transform)
custom_fiftyone_dataloader = DataLoader(
    custom_fiftyone_dataset,
    batch_size=64,
    shuffle=False,
)

In [None]:
def tag_class_probabilities(fo_dataset, custom_dataset, class_index, epoch, probabilities):
    """Write one class's probability for one epoch onto every sample.

    Row ``i`` of `probabilities` is stored on the sample whose ID is
    ``custom_dataset.sample_ids[i]`` in the float field
    ``epoch_{epoch}_class_{class_index}_prob``. The rows must therefore be in the
    same order as ``custom_dataset.sample_ids``.

    Args:
        fo_dataset: FiftyOne dataset to write to.
        custom_dataset: object exposing ``sample_ids`` (row order of `probabilities`).
        class_index: column of `probabilities` to store.
        epoch: epoch label used in the field name.
        probabilities: 2-D array/tensor indexable as ``probabilities[row, class_index]``.

    Raises:
        ValueError: if the number of rows does not match the number of sample IDs.
    """
    sample_ids = custom_dataset.sample_ids
    if len(probabilities) != len(sample_ids):
        raise ValueError(
            f"probabilities has {len(probabilities)} rows but there are {len(sample_ids)} sample IDs"
        )

    # Declare exactly the field that is written below, so the schema has no unused fields.
    field_name = f"epoch_{epoch}_class_{class_index}_prob"
    if not fo_dataset.has_field(field_name):
        fo_dataset.add_sample_field(field_name, fo.FloatField)

    for row, sample_id in enumerate(sample_ids):
        sample = fo_dataset[sample_id]
        sample[field_name] = float(probabilities[row, class_index])
        sample.save()  # persist each modified sample

# Tag per-class softmax probabilities for the first and last checkpoints.
for checkpoint in [checkpoints[0], checkpoints[-1]]:
    model = empty_model
    model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=False)['state_dict'])
    # Inference mode: BatchNorm must use (and not update) its running statistics.
    model.eval()

    # Reuse the unshuffled dataloader so probability rows follow
    # `custom_fiftyone_dataset.sample_ids` and the inputs get the same RGB +
    # Normalize(0.5, 0.5) preprocessing as the loss tagging.
    batch_probabilities = []
    with torch.no_grad():
        for inputs, _labels, _sample_ids in custom_fiftyone_dataloader:
            batch_probabilities.append(nn.Softmax(dim=1)(model(inputs)))
    probabilities_softmax = torch.cat(batch_probabilities, dim=0)
    assert len(probabilities_softmax) == len(custom_fiftyone_dataset.sample_ids)

    epoch = checkpoint.split('_')[-1].split('.')[0]

    for class_index in range(probabilities_softmax.shape[1]):
        tag_class_probabilities(test_split, custom_fiftyone_dataset, class_index, epoch, probabilities_softmax)

## Tag batch number and loss values

In [None]:
# Losses are keyed by the sample ID each batch carries, so they land on the right
# FiftyOne sample regardless of the order the dataloader yields them in.

def tag_losses(dataloader, dataset, checkpoints):
    """Record each sample's cross-entropy loss under every checkpoint.

    For each checkpoint path, the weights are loaded into the shared `empty_model`.
    The model is switched to eval mode and run over `dataloader` without
    gradients. Each sample then gets its unreduced loss (rounded to 4 decimals)
    stored twice:
    as ``losses["model_epoch_<N>"]`` and as the field ``model_epoch_<N>_loss``, where
    ``<N>`` is the text after the last ``_`` of the path, up to the first ``.``.
    Checkpoints that fail to load are reported and skipped.

    Args:
        dataloader: yields ``(inputs, labels, sample_ids)`` batches.
        dataset: FiftyOne dataset that receives the fields.
        checkpoints: checkpoint file paths containing a ``'state_dict'`` entry.
    """
    criterion_none = torch.nn.CrossEntropyLoss(reduction="none")  # per-sample losses

    for checkpoint in checkpoints:
        model = empty_model
        print(f"Loading model from path: {checkpoint}")
        try:
            model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=False)['state_dict'])
            print(f"Successfully loaded model from {checkpoint}")
        except Exception as e:
            print(f"Error loading model from {checkpoint}: {e}")
            continue  # best-effort: skip checkpoints that cannot be loaded
        # Without eval(), BatchNorm would normalise with batch statistics and keep
        # updating its running statistics even under no_grad.
        model.eval()

        checkpoint_num = checkpoint.split('_')[-1].split('.')[0]
        model_name = f'model_epoch_{checkpoint_num}'
        losses_by_id = {}

        with torch.no_grad():
            for inputs, labels, sample_ids in dataloader:
                individual_losses = criterion_none(model(inputs), labels).tolist()
                for sample_id, loss in zip(sample_ids, individual_losses):
                    losses_by_id[sample_id] = round(loss, 4)

        if not dataset.has_field("losses"):
            dataset.add_sample_field("losses", fo.DictField)

        for sample_id, loss in losses_by_id.items():
            sample = dataset[sample_id]
            # Build a new dict and reassign it rather than relying on in-place
            # mutation of the stored value being tracked.
            losses = dict(sample.losses or {})
            losses[model_name] = loss
            sample.losses = losses
            sample[model_name + "_loss"] = loss
            sample.save()


def tag_batches(dataloader, fo_dataset, custom_dataset):
    """Store on every sample the index of the batch it falls into.

    The sample at position ``i`` of ``custom_dataset.sample_ids`` gets
    ``batch_num = i // dataloader.batch_size``; this assumes the dataloader walks
    `custom_dataset` in order (shuffle=False).
    """
    if not fo_dataset.has_field("batch_num"):
        fo_dataset.add_sample_field("batch_num", fo.IntField)

    per_batch = dataloader.batch_size
    for position, sid in enumerate(custom_dataset.sample_ids):
        record = fo_dataset[sid]
        record["batch_num"] = position // per_batch
        record.save()  # persist each modified sample




# Mean-reduced criterion; tag_losses builds its own reduction="none" loss internally.
criterion = nn.CrossEntropyLoss()

# Per-sample losses for every checkpoint, then each sample's batch index.
tag_losses(dataloader=custom_fiftyone_dataloader, dataset=test_split, checkpoints=checkpoints)
tag_batches(
    dataloader=custom_fiftyone_dataloader,
    fo_dataset=test_split,
    custom_dataset=custom_fiftyone_dataset,
)

## Get Selected Points

In [None]:
from methods.SelectionMethod import SelectionMethod
from methods.DivBS import DivBS
from methods.Uniform import Uniform
from methods.Bayesian import Bayesian
import numpy as np
import yaml
from utils import custom_logger
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

### Get selected points (DivBS & Uniform)



In [None]:
class Uniform_Selection(Uniform):
    """Uniform selector that reports which indexes it keeps for every batch of a dataset."""

    def get_indicies(self, dataset, epoch, batch_size=64):
        """Run ``before_batch`` over `dataset` in order and collect the kept indexes.

        Batches are built with ``DataLoader(dataset, batch_size, shuffle=False)``. Each
        batch is passed to ``before_batch`` with ``indexes = 0..len(batch)-1``, so the
        returned indexes are presumably positions within that batch.

        Args:
            dataset: dataset whose items start with ``(input, target, ...)``.
            epoch: forwarded to ``before_batch``.
            batch_size: must equal the batch size used when tagging results, so
                batch numbers line up (default 64).

        Returns:
            list with one entry per batch: the ``indexes`` returned by ``before_batch``.
        """
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        all_indexes = []
        for i, datas in enumerate(data_loader):
            # Uniform is handed numpy arrays (unlike the DivBS/Bayesian wrappers,
            # which pass tensors); kept as-is since before_batch is defined elsewhere.
            inputs = np.array(datas[0])
            targets = np.array(datas[1])
            indexes = np.arange(inputs.shape[0])

            # Only the kept indexes are needed; filtered inputs/targets are discarded.
            inputs, targets, indexes = self.before_batch(i, inputs, targets, indexes, epoch)
            all_indexes.append(indexes)

        return all_indexes


class DivBS_Selection(DivBS):
    """DivBS selector that reports which indexes it keeps for every batch of a dataset."""

    def get_indicies(self, dataset, epoch, batch_size=64):
        """Run ``before_batch`` over `dataset` in order and collect the kept indexes.

        Batches are built with ``DataLoader(dataset, batch_size, shuffle=False)``. Each
        batch is passed (as tensors) to ``before_batch`` with
        ``indexes = 0..len(batch)-1``, so the returned indexes are presumably
        positions within that batch.

        Args:
            dataset: dataset whose items start with ``(input, target, ...)``.
            epoch: forwarded to ``before_batch``.
            batch_size: must equal the batch size used when tagging results, so
                batch numbers line up (default 64).

        Returns:
            list with one entry per batch: the ``indexes`` returned by ``before_batch``.
        """
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        all_indexes = []
        for i, datas in enumerate(data_loader):
            inputs = datas[0]
            targets = datas[1]
            indexes = np.arange(inputs.shape[0])

            # Only the kept indexes are needed; filtered inputs/targets are discarded.
            inputs, targets, indexes = self.before_batch(i, inputs, targets, indexes, epoch)
            all_indexes.append(indexes)

        return all_indexes

class Bayesian_Selection(Bayesian):
    """Bayesian selector that reports which indexes it keeps for every batch of a dataset."""

    def get_indicies(self, dataset, epoch, batch_size=64):
        """Run ``before_batch`` over `dataset` in order and collect the kept indexes.

        Batches are built with ``DataLoader(dataset, batch_size, shuffle=False)``. Each
        batch is passed (as tensors) to ``before_batch`` with
        ``indexes = 0..len(batch)-1``, so the returned indexes are presumably
        positions within that batch.

        Args:
            dataset: dataset whose items start with ``(input, target, ...)``.
            epoch: forwarded to ``before_batch``.
            batch_size: must equal the batch size used when tagging results, so
                batch numbers line up (default 64).

        Returns:
            list with one entry per batch: the ``indexes`` returned by ``before_batch``.
        """
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        all_indexes = []
        for i, datas in enumerate(data_loader):
            inputs = datas[0]
            targets = datas[1]
            indexes = np.arange(inputs.shape[0])

            # Only the kept indexes are needed; filtered inputs/targets are discarded.
            inputs, targets, indexes = self.before_batch(i, inputs, targets, indexes, epoch)
            all_indexes.append(indexes)

        return all_indexes

In [None]:

import logging

# Logger and device for the selection methods. If `custom_logger` fails, fall back
# to a standard-library logger so `logger` is still defined when the selectors are
# constructed below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    logger = custom_logger("./exp")
except Exception as e:
    print(f"Error setting up logger: {e}")
    # NOTE(review): assumes the selectors only call standard logging methods on it.
    logger = logging.getLogger("batch_selection")
logger.info('device: ' + str(device))

# Set up dataset
test_split = foz.load_zoo_dataset("cifar10", split="test")
custom_fiftyone_dataset = FiftyOneCIFARDataset(test_split, class_to_idx=class_to_idx, transform=transform)

# selected_indices_<method>[k] holds the per-batch selections for checkpoints[k];
# tag_selected_points looks entries up by that position.
selected_indices_DivBS = []
selected_indices_Uniform = []
#selected_indices_Bayesian = []

# NOTE(review): checkpoint 0 was reported to raise an error ("feature tensor
# becoming 0") in the selectors, yet it is still included in this loop — confirm.
for checkpoint in checkpoints:

    model = empty_model
    print(f"Loading model from path: {checkpoint}")

    try:
        model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=False)['state_dict'])
        print(f"Successfully loaded model from {checkpoint}")
    except Exception as e:
        print(f"Error loading model from {checkpoint}: {e}")
        # Placeholder keeps the lists aligned with `checkpoints`; otherwise every
        # later checkpoint's selections would be tagged under the wrong epoch.
        selected_indices_DivBS.append(None)
        selected_indices_Uniform.append(None)
        #selected_indices_Bayesian.append(None)
        continue

    checkpoint_num = int(checkpoint.split('_')[-1].split('.')[0])

    DivBS_selector = DivBS_Selection(config, logger)
    Uniform_selector = Uniform_Selection(config, logger)
    #Bayesian_selector = Bayesian_Selection(config, logger)
    DivBS_selector.model = model
    Uniform_selector.model = model
    #Bayesian_selector.model = model
    run_indicies_DivBS = DivBS_selector.get_indicies(custom_fiftyone_dataset, checkpoint_num)
    run_indicies_Uniform = Uniform_selector.get_indicies(custom_fiftyone_dataset, checkpoint_num)
    #run_indicies_Bayesian = Bayesian_selector.get_indicies(custom_fiftyone_dataset, checkpoint_num)
    selected_indices_DivBS.append(run_indicies_DivBS)
    selected_indices_Uniform.append(run_indicies_Uniform)
    #selected_indices_Bayesian.append(run_indicies_Bayesian)

## Tag selected points

In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import fiftyone as fo # Import fiftyone
from PIL import Image # Import Image

class FiftyOneCIFARDataset(torch.utils.data.Dataset):
    """Torch dataset view of a FiftyOne CIFAR-10 collection.

    This class is also defined in the "Set up consistent dataset" section; keep the
    two definitions in sync.

    Item ``idx`` is ``(image, label, sample_id)`` for ``sample_ids[idx]``:
    ``image`` is the sample's file opened as an RGB PIL image (then passed through
    ``transform`` if one is given), ``label`` is ``class_to_idx[ground_truth.label]``
    as a ``torch.long`` scalar, and ``sample_id`` is the FiftyOne sample ID as a string.
    """

    def __init__(self, fiftyone_dataset, class_to_idx, transform=None):
        self.dataset = fiftyone_dataset
        self.transform = transform
        self.class_to_idx = class_to_idx  # label string -> integer class index
        # Fetch IDs, file paths and labels once with bulk `values()` calls instead of
        # issuing a database lookup per item in __getitem__.
        self.sample_ids = list(self.dataset.values("_id"))
        self._filepaths = list(self.dataset.values("filepath"))
        self._labels = list(self.dataset.values("ground_truth.label"))
        if not (len(self.sample_ids) == len(self._filepaths) == len(self._labels)):
            raise ValueError("FiftyOne values() returned fields of different lengths")

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, idx):
        # Context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(self._filepaths[idx]) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)
        label_tensor = torch.tensor(self.class_to_idx[self._labels[idx]], dtype=torch.long)
        return img, label_tensor, str(self.sample_ids[idx])

# Same preprocessing as the earlier setup cell: PIL image -> float tensor, then
# per-channel (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# The torchvision training set is only used here for its label -> index mapping.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
class_to_idx = train_dataset.class_to_idx

# FiftyOne test split, its torch wrapper, and an unshuffled loader; batch numbers
# computed from this loader must match the ones used during selection.
test_split = foz.load_zoo_dataset("cifar10", split="test")
custom_fiftyone_dataset = FiftyOneCIFARDataset(test_split, class_to_idx=class_to_idx, transform=transform)
custom_fiftyone_dataloader = DataLoader(
    custom_fiftyone_dataset,
    batch_size=64,
    shuffle=False,
)


def tag_selected_points(dataloader, fo_dataset, custom_dataset, checkpoints, selected_indices, method_name):
    """Mark, per checkpoint, which samples a batch selection method selected.

    For each ``checkpoints[k]`` a boolean field
    ``model_epoch_<N>_<method_name>_selected`` is written to every sample. ``<N>``
    is the text after the last ``_`` of the checkpoint path, up to the first ``.``.
    ``selected_indices[k]`` must hold one entry per batch, each containing the
    within-batch positions that the method kept for that batch.

    A sample's batch and in-batch position come from its position in
    ``custom_dataset.sample_ids`` and ``dataloader.batch_size`` (the only thing read
    from `dataloader`). The selections must therefore have been computed with the
    same unshuffled batching.

    Checkpoints whose entry in `selected_indices` is missing or ``None`` are
    skipped with a warning.

    Args:
        dataloader: provides ``batch_size``.
        fo_dataset: FiftyOne dataset that receives the fields.
        custom_dataset: provides ``sample_ids`` in dataloader order.
        checkpoints: checkpoint paths, aligned by position with `selected_indices`.
        selected_indices: per-checkpoint list of per-batch selected positions.
        method_name: inserted into the field name.
    """
    sample_ids = custom_dataset.sample_ids
    batch_size = dataloader.batch_size

    for ckpt_idx, checkpoint in enumerate(checkpoints):
        checkpoint_num = checkpoint.split('_')[-1].split('.')[0]
        model_name = f'model_epoch_{checkpoint_num}'
        field_name = model_name + "_" + method_name + "_selected"

        if ckpt_idx >= len(selected_indices) or selected_indices[ckpt_idx] is None:
            print(f"Warning: no selected indices for checkpoint {checkpoint}; skipping {field_name}.")
            continue

        # Build one set of ints per batch: O(1) membership tests, and lists, numpy
        # arrays and tensors of indexes are all handled the same way.
        selected_per_batch = [
            {int(position) for position in batch}
            for batch in selected_indices[ckpt_idx]
        ]

        for i, sample_id in enumerate(sample_ids):
            batch_index, in_batch_index = divmod(i, batch_size)
            sample = fo_dataset[sample_id]
            sample[field_name] = in_batch_index in selected_per_batch[batch_index]
            sample.save()


# Example usage:
# Tag, for every checkpoint, which samples each selection method kept.
tag_selected_points(
    custom_fiftyone_dataloader,
    test_split,
    custom_fiftyone_dataset,
    checkpoints,
    selected_indices=selected_indices_DivBS,
    method_name="DivBS",
)
tag_selected_points(
    custom_fiftyone_dataloader,
    test_split,
    custom_fiftyone_dataset,
    checkpoints,
    selected_indices=selected_indices_Uniform,
    method_name="Uniform",
)
# Bayesian (enable together with the Bayesian selector in the selection cell):
# tag_selected_points(custom_fiftyone_dataloader, test_split, custom_fiftyone_dataset, checkpoints, selected_indices_Bayesian, "Bayesian")

# Show Data

In [None]:
# Open the FiftyOne App on the tagged test split. Inside a notebook the App is
# embedded and `session.wait()` only prints "Notebook sessions cannot wait", so it
# is not called.
session = fo.launch_app(test_split)