Goal: We want to be able to visualize what points each batch selection method is choosing at each part of the training process of a model.

Label each point with the following - loss value, batch #, if included in selection method

# Set Up

#### Install dependencies and load pretrained models

In [15]:
import fiftyone.zoo as foz
import cv2
import numpy as np
import fiftyone.brain as fob
from fiftyone import ViewField as F
import torch
import torch.nn as nn
import models
import yaml
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import fiftyone as fo # Import fiftyone
from PIL import Image # Import Image
import detectors
import timm
from methods.DivBS import DivBS
from methods.Uniform import Uniform
from methods.Bayesian import Bayesian
from methods.TrainLoss import TrainLoss
from methods.RhoLoss import RhoLoss
from utils import custom_logger

# resnet18_cifar10 = timm.create_model("resnet18_cifar10", pretrained=True)
# resnet34_cifar10 = timm.create_model("resnet34_cifar10", pretrained=True)
# resnet50_cifar10 = timm.create_model("resnet50_cifar10", pretrained=True)
# resnet34_supcon_cifar10 = timm.create_model("resnet34_supcon_cifar10", pretrained=True)
resnet50_supcon_cifar10 = timm.create_model("resnet50_supcon_cifar10", pretrained=True)
#resnet50_simclr_cifar10 = timm.create_model("resnet50_simclr_cifar10", pretrained=True)
pretrained_models = {
    # "resnet18_cifar10": resnet18_cifar10,
    # "resnet34_cifar10": resnet34_cifar10,
    # "resnet50_cifar10": resnet50_cifar10,
    # "resnet34_supcon_cifar10": resnet34_supcon_cifar10,
    "resnet50_supcon_cifar10": resnet50_supcon_cifar10,
    #"resnet50_simclr_cifar10": resnet50_simclr_cifar10
}

#### Set up consistent dataset

In [16]:
# Define the FiftyOneCIFARDataset class (copied from cell_id: -zonHH4HkINl)
class FiftyOneCIFARDataset(torch.utils.data.Dataset):
    def __init__(self, fiftyone_dataset, class_to_idx, transform=None):
        self.dataset = fiftyone_dataset
        self.transform = transform
        self.sample_ids = list(self.dataset.values("_id"))  # List of sample IDs
        self.class_to_idx = class_to_idx # Store the class_to_idx mapping

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, idx):
        sample_id = self.sample_ids[idx]
        sample = self.dataset[sample_id]  # Access by sample ID
        img = Image.open(sample.filepath).convert("RGB")
        if self.transform:
            img = self.transform(img)
        # Convert string label to integer index using the mapping
        label_str = sample.ground_truth.label
        label_int = self.class_to_idx[label_str]
        label_tensor = torch.tensor(label_int, dtype=torch.long) # Ensure label is a LongTensor

        return img, label_tensor, str(sample_id)  # Return tensor label and sample_id as string

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Use train_dataset to get class_to_idx function
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
class_to_idx = train_dataset.class_to_idx

# Instantiate the original Fiftyone dataset, custum FiftyOne dataset, and dataloader
fo_dataset = foz.load_zoo_dataset("cifar10", split="test")
custom_fiftyone_dataset = FiftyOneCIFARDataset(fo_dataset, class_to_idx=class_to_idx, transform=transform)
custom_fiftyone_dataloader = DataLoader(custom_fiftyone_dataset, batch_size=64, shuffle=False) # Use shuffle=False for consistent batching

Split 'test' already downloaded
Loading existing dataset 'cifar10-test'. To reload from disk, either delete the existing dataset or provide a custom `dataset_name` to use


#### Set up ResNet model checkpoints

In [17]:
# Read config file
config_file = "cfg/cifar10.yaml"
with open(config_file, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
        f.close()

model_type = config['networks']['type']
model_args = config['networks']['params']
empty_model = getattr(models, model_type)(**model_args)

checkpoints = ['visualization_checkpoints/checkpoint_epoch_' + str(checkpoint) + '.pth.tar' for checkpoint in [3, 5, 7]] # or range(0, 200, 25)
model_params = torch.load(checkpoints[0], map_location=torch.device('cpu'), weights_only=False)
empty_model.load_state_dict(model_params['state_dict'])

# Temporary to make everything run faster
# first_checkpoint = checkpoints[0]
# last_checkpoint = checkpoints[-1]
# checkpoints = [first_checkpoint, last_checkpoint]

<All keys matched successfully>

#### FiftyOne Analysis Setup


In [18]:
def prepare_embeddings(model, model_name, fo_dataset):
    # Set the model to evaluation mode
    model.eval()

    batch_size = 128
    embeddings = []

    # Resize, normalize, and convert images to (C, H, W)
    target_size = (32, 32)  # required input size for the CNN

    image_tensors = []
    for f in fo_dataset.values("filepath"):
        img = cv2.imread(f, cv2.IMREAD_COLOR)  # force 3-channel BGR
        img = cv2.resize(img, target_size)
        img = img.astype(np.float32) / 255.0  # normalize to [0,1]
        img = img.transpose(2, 0, 1)  # convert from (H, W, C) to (C, H, W)
        image_tensors.append(img)

    for i in range(0, len(image_tensors), batch_size):
        batch_np = np.stack(image_tensors[i:i+batch_size])
        batch_tensor = torch.tensor(batch_np)

        with torch.no_grad():
            out = model(batch_tensor)
            embeddings.append(out.cpu().numpy())

    embedding_outputs = np.concatenate(embeddings, axis=0)

    results = fob.compute_visualization(
        fo_dataset,
        embeddings=embedding_outputs,
        num_dims=2,
        method="umap",
        brain_key=model_name+"_test",
        verbose=True,
        seed=51,
    )
    fo_dataset.load_brain_results(model_name+"_test")
    return results

In [19]:
## Optional to reset FiftyOne database
# fo_dataset.delete()

# Load embeddings for each model
for model_name, model in pretrained_models.items():
    prepare_embeddings(model, model_name, fo_dataset)

Generating visualization...
UMAP(n_jobs=1, random_state=51, verbose=True)
Thu Oct  2 13:48:08 2025 Construct fuzzy simplicial set
Thu Oct  2 13:48:08 2025 Finding Nearest Neighbors
Thu Oct  2 13:48:08 2025 Building RP forest with 10 trees


  warn(


Thu Oct  2 13:48:08 2025 NN descent for 13 iterations
	 1  /  13
	 2  /  13
	 3  /  13
	 4  /  13
	Stopping threshold met -- exiting after 4 iterations
Thu Oct  2 13:48:10 2025 Finished Nearest Neighbor Search
Thu Oct  2 13:48:10 2025 Construct embedding


Epochs completed:   1%|            6/500 [00:00]

	completed  0  /  500 epochs


Epochs completed:  11%| █▏         57/500 [00:01]

	completed  50  /  500 epochs


Epochs completed:  21%| ██▏        107/500 [00:02]

	completed  100  /  500 epochs


Epochs completed:  31%| ███▏       157/500 [00:03]

	completed  150  /  500 epochs


Epochs completed:  41%| ████▏      207/500 [00:04]

	completed  200  /  500 epochs


Epochs completed:  51%| █████▏     257/500 [00:05]

	completed  250  /  500 epochs


Epochs completed:  61%| ██████▏    307/500 [00:07]

	completed  300  /  500 epochs


Epochs completed:  71%| ███████▏   357/500 [00:08]

	completed  350  /  500 epochs


Epochs completed:  81%| ████████▏  407/500 [00:09]

	completed  400  /  500 epochs


Epochs completed:  91%| █████████▏ 457/500 [00:10]

	completed  450  /  500 epochs


Epochs completed: 100%| ██████████ 500/500 [00:11]


Thu Oct  2 13:48:22 2025 Finished embedding


# Add tags to dataset

#### Tag batch numbers

In [20]:
def tag_batches(fo_dataset, custom_dataset, fo_dataloader):
    if not fo_dataset.has_field("batch_num"):
        fo_dataset.add_sample_field("batch_num", fo.IntField)

    sample_ids = custom_dataset.sample_ids
    batch_size = fo_dataloader.batch_size

    for i, sample_id in enumerate(sample_ids):
        batch_index = i // batch_size
        sample = fo_dataset[sample_id]
        sample["batch_num"] = batch_index
        sample.save()  # <-- Save each modified sample

tag_batches(fo_dataset, custom_fiftyone_dataset, custom_fiftyone_dataloader)

#### Tag class probabilities

In [21]:
# Tag class probabilities to each sample in the fiftyone dataset
def tag_class_probabilities(checkpoints, fo_dataset, custom_dataset):
    
    image_tensors = []
    for f in fo_dataset.values("filepath"):
        img = cv2.imread(f, cv2.IMREAD_COLOR)  # force 3-channel BGR
        img = cv2.resize(img, (32, 32))
        img = img.astype(np.float32) / 255.0  # normalize to [0,1]
        img = img.transpose(2, 0, 1)  # convert from (H, W, C) to (C, H, W)
        image_tensors.append(img)

    for checkpoint in checkpoints:
        # Load model from checkpoint
        model = empty_model
        model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=False)['state_dict'])

        # Get outputs and probabilities
        output = model(torch.tensor(image_tensors))
        probabilities = nn.Softmax(dim=1)(output) # or nn.LogSoftmax(dim=1)(output)
        epoch = checkpoint.split('_')[-1].split('.')[0]
        sample_ids = custom_dataset.sample_ids

        for i, sample_id in enumerate(sample_ids):
            sample = fo_dataset[sample_id]
            for class_index in range(10):
                sample[f"epoch_{epoch}_class_{class_index}_prob"] = float(probabilities[i, class_index].item())
            sample.save()  # <-- Save each modified sample

# Tag probabilities for each class
tag_class_probabilities(checkpoints, fo_dataset, custom_fiftyone_dataset)

#### Tag loss values

In [22]:
# Might be wrong since the torch cifar10 dataset and fiftyone cifar10 dataset might differ in the ordering which would result in incorrect tagging

def tag_losses(checkpoints, dataset, dataloader):
    for checkpoint in checkpoints:
        model = empty_model
        print(f"Loading model from path: {checkpoint}") # Added print statement
        # Load the model state dictionary, ensuring map_location is set if needed (e.g., if trained on GPU and loading on CPU)
        try: # Added try-except block for loading model
            model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=False)['state_dict'])
            print(f"Successfully loaded model from {checkpoint}") # Added print statement
        except Exception as e:
            print(f"Error loading model from {checkpoint}: {e}")
            continue # Skip to the next checkpoint if loading fails
        
        checkpoint_num = checkpoint.split('_')[-1].split('.')[0]
        model_name = f'model_epoch_{checkpoint_num}'
        criterion_none = torch.nn.CrossEntropyLoss(reduction="none")
        losses_by_id = {}

        with torch.no_grad():
            for inputs, labels, sample_ids in dataloader:
                outputs = model(inputs)
                individual_losses = criterion_none(outputs, labels)

                for i, sample_id in enumerate(sample_ids):
                    losses_by_id[sample_id] = round(individual_losses[i].item(), 4)

        # Ensure the 'losses' field exists only once
        if not dataset.has_field("losses"):
            dataset.add_sample_field("losses", fo.DictField)

        for sample_id, loss in losses_by_id.items():
            sample = dataset[sample_id]
            if sample.losses is None:
                sample.losses = {}
            sample.losses[model_name] = loss
            sample[model_name + "_loss"] = loss
            sample.save()

# Call the tag_data function with the dataloader, dataset, criterion, and checkpoint paths
tag_losses(checkpoints, fo_dataset, custom_fiftyone_dataloader)

Loading model from path: visualization_checkpoints/checkpoint_epoch_3.pth.tar
Successfully loaded model from visualization_checkpoints/checkpoint_epoch_3.pth.tar
Loading model from path: visualization_checkpoints/checkpoint_epoch_5.pth.tar
Successfully loaded model from visualization_checkpoints/checkpoint_epoch_5.pth.tar
Loading model from path: visualization_checkpoints/checkpoint_epoch_7.pth.tar
Successfully loaded model from visualization_checkpoints/checkpoint_epoch_7.pth.tar


#### Tag method selected points

In [23]:
class BaseSelectionMixin:
    """Mixin providing the get_indices method for any selector."""
    def get_indices(self, dataset, epoch, batch_size=64):
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        all_indexes = []

        for i, datas in enumerate(data_loader):
            inputs = datas[0]
            targets = datas[1]
            indexes = np.arange(len(inputs))

            # Call batch selection logic
            inputs, targets, indexes = self.before_batch(i, inputs, targets, indexes, epoch)
            all_indexes.append(indexes)

        return all_indexes


class Uniform_Selection(BaseSelectionMixin, Uniform):
    pass


class DivBS_Selection(BaseSelectionMixin, DivBS):
    pass


class TrainLoss_Selection(BaseSelectionMixin, TrainLoss):
    pass


class RhoLoss_Selection(BaseSelectionMixin, RhoLoss):
    pass

class Bayesian_Selection(Bayesian):
    def get_indices(self, dataset, epoch):

        data_loader = DataLoader(dataset, batch_size=64, shuffle=False)

        all_indexes = []

        for i, datas in enumerate(data_loader):
            inputs = datas[0]
            targets = datas[1]
            indexes = np.arange(inputs.shape[0])

            # Call batch selection logic
            inputs, targets, indexes = self.before_batch(i, inputs, targets, indexes, epoch)
            all_indexes.append(indexes)

        return all_indexes

In [24]:
# Set up logger
try: # Added try-except block for setting up logger
    logger = custom_logger("./exp")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info('device: ' + str(device))
except Exception as e:
    print(f"Error setting up logger: {e}")

selected_points = {}
methods = ['DivBS', 'Uniform', 'TrainLoss', 'RhoLoss'] #'Bayesian']

for method in methods:
    selected_points[method] = []

selector_classes = {
    "DivBS": DivBS_Selection,
    "Uniform": Uniform_Selection,
    "TrainLoss": TrainLoss_Selection,
    "RhoLoss": RhoLoss_Selection,
    #"Bayesian": Bayesian_Selection,
}

# NOTE: Can't start at checkpoint 0 because it raises an error with feature tensor becoming 0
for checkpoint in checkpoints:

    model = empty_model
    print(f"Loading model from path: {checkpoint}") 

    try: # try-except block for loading model
        model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=False)['state_dict'])
        print(f"Successfully loaded model from {checkpoint}")
    except Exception as e:
        print(f"Error loading model from {checkpoint}: {e}")
        continue # Skip to the next checkpoint if loading fails
    
    checkpoint_num = int(checkpoint.split('_')[-1].split('.')[0])
    # Create and run selectors
    for method, cls in selector_classes.items():
        selector = cls(config, logger)
        selector.model = model
        
        run_indices = selector.get_indices(custom_fiftyone_dataset, checkpoint_num)
        selected_points[method].append(run_indices)

device: cuda
Loading model from path: visualization_checkpoints/checkpoint_epoch_3.pth.tar
Successfully loaded model from visualization_checkpoints/checkpoint_epoch_3.pth.tar
Creating DivBS...
balance: False
selecting samples for epoch 3, ratio 0.1
Creating Uniform...
selecting samples for epoch 3
balance: False
ratio: 0.1
Creating TrainLoss...
balance: False
selecting samples for epoch 3, ratio 0.1
Creating RhoLoss...
Loading holdout model from /home/phancock/Online-Batch-Selection/exp/test_rholoss/RhoLoss/best_holdout.pth.tar
Cached irreducible losses for 50000 samples in dataset.
balance: False
selecting samples for epoch 3, ratio 0.1
Loading model from path: visualization_checkpoints/checkpoint_epoch_5.pth.tar
Successfully loaded model from visualization_checkpoints/checkpoint_epoch_5.pth.tar
Creating DivBS...
balance: False
selecting samples for epoch 5, ratio 0.1
Creating Uniform...
selecting samples for epoch 5
balance: False
ratio: 0.1
Creating TrainLoss...
balance: False
selec

In [25]:
def tag_selected_points(checkpoints, fo_dataset, custom_dataset, dataloader, selected_points, method_name):
    # Iterate over each checkpoint
    for ckpt_idx, checkpoint in enumerate(checkpoints):
        # Added try-except block for loading model
        
        checkpoint_num = checkpoint.split('_')[-1].split('.')[0]
        model_name = f'model_epoch_{checkpoint_num}'

        sample_ids = custom_dataset.sample_ids
        batch_size = dataloader.batch_size

        # Get selected indices for this checkpoint
        batches_of_selected_points = selected_points[ckpt_idx]

        for i, sample_id in enumerate(sample_ids):
            sample_batch_index = i // batch_size
            in_batch_index = i % batch_size

            # Get selected indices for this batch
            batch_selected_points = batches_of_selected_points[sample_batch_index]
            # Normalize to list
            if isinstance(batch_selected_points, np.ndarray):
                batch_selected_points = batch_selected_points.tolist()
            elif isinstance(batch_selected_points, (np.int64, int)):
                batch_selected_points = [int(batch_selected_points)]

            # Convert to set for faster lookup
            batch_selected_points = set(batch_selected_points)

            # if sample is in the selected points for this batch, tag it as True, else False
            selected = in_batch_index in batch_selected_points
            sample = fo_dataset[sample_id]
            sample[model_name + "_" + method_name + "_selected"] = selected
            sample.save()

for method in methods:
    tag_selected_points(checkpoints, fo_dataset, custom_fiftyone_dataset, custom_fiftyone_dataloader, selected_points[method], method)

# Show Data

Potential Improvements:
* Add linked plot for even better visualization (which plots would be best for that?)
* Find way to include Bayesian selection

In [26]:
image_tensors = []
for f in fo_dataset.values("filepath"):
    img = cv2.imread(f, cv2.IMREAD_COLOR)  # force 3-channel BGR
    img = cv2.resize(img, (32, 32))
    img = img.astype(np.float32) / 255.0  # normalize to [0,1]
    img = img.transpose(2, 0, 1)  # convert from (H, W, C) to (C, H, W)
    image_tensors.append(img)

for checkpoint in checkpoints:
    # Load model from checkpoint
    model = empty_model
    model.load_state_dict(torch.load(checkpoint, map_location=torch.device('cpu'), weights_only=False)['state_dict'])

    # Get outputs and probabilities
    output = model(torch.tensor(image_tensors))
    probabilities = nn.Softmax(dim=1)(output) # or nn.LogSoftmax(dim=1)(output)
    epoch = checkpoint.split('_')[-1].split('.')[0]
    sample_ids = custom_fiftyone_dataset.sample_ids

    print(probabilities.argmax(dim=1))

    for i, sample_id in enumerate(sample_ids):
        pred_class = int(probabilities.argmax(dim=1)[i].item())
        sample = fo_dataset[sample_id]
        # Store as a Classification label (which supports colors)
        sample[f"epoch_{epoch}_predictions"] = fo.Classification(
            label=str(pred_class),  # use string name or class id
            confidence=float(probabilities[i][pred_class].item())
        )
        sample.save()
        # sample[f"epoch_{epoch}_predictions"] = (probabilities.argmax(dim=1)[i].item())
        # sample.save()  # <-- Save each modified sample

tensor([7, 8, 4,  ..., 4, 4, 4])
tensor([5, 1, 5,  ..., 5, 1, 7])
tensor([5, 1, 6,  ..., 5, 1, 7])


In [27]:
# import fiftyone as fo
# from fiftyone.core.odm.dataset import ColorScheme
# from fiftyone import Classification

# # -----------------------------
# # 1. Extract ground truth field
# # -----------------------------
# gt_scheme = fo_dataset.app_config.color_scheme
# gt_field_cfg = None
# for f in gt_scheme.fields:
#     if f["path"] == "ground_truth":
#         gt_field_cfg = f
#         break

# if gt_field_cfg is None:
#     raise ValueError("No color scheme found for 'ground_truth'!")

# # Auto-detect classes from ground truth
# classes = [vc["value"] for vc in gt_field_cfg["valueColors"]]
# string_to_color = {vc["value"]: vc["color"] for vc in gt_field_cfg["valueColors"]}

# # -----------------------------
# # 2. Detect all numeric prediction fields
# # -----------------------------

# pred_fields = [
#     f for f in fo_dataset.get_field_schema().keys()
#     if f.startswith("epoch_") and f.endswith("_predictions")
# ]

# # -----------------------------
# # 3. Create color scheme for predictions
# # -----------------------------
# pred_value_colors = []
# for c in classes:
#     color = string_to_color.get(c, "#000000")  # default to black if not found
#     pred_value_colors.append({"value": c, "color": color})
# pred_scheme = ColorScheme(fields=[{
#     "path": f,
#     "type": "classification",
#     "valueColors": pred_value_colors
# } for f in pred_fields])  

# fo.app_config.color_scheme = pred_scheme
# fo.app_config.color_by = "value"
# fo.app_config.field 

# # Show fiftyone app
# session = fo.launch_app(fo_dataset)


In [28]:
session = fo.launch_app(fo_dataset)