# Class-Aware OOD with CIFAR-100
This walkthrough configures MetaLoRA for class-aware sampling on CIFAR-100 and evaluates out-of-distribution robustness.

- Configure paths, dependencies, and experiment settings.
3- Train on all CIFAR-100 classes with a class-aware sampler.
- Use SVHN as a truly OOD benchmark while reusing the CIFAR-trained head.
- Report both in-distribution accuracy and SVHN OOD metrics.

In [3]:
import os
import random
import subprocess
import sys
from pathlib import Path

import numpy as np
import torch
from omegaconf import OmegaConf
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR100 as TorchvisionCIFAR100, SVHN as TorchvisionSVHN


def find_repo_root(start: Path, marker: str = "main.py", max_depth: int = 10) -> Path:
    """Locate the repository root when running locally or inside Colab."""
    env_root = os.environ.get("METALORA_ROOT") or os.environ.get("REPO_ROOT")
    if env_root:
        candidate = Path(env_root).expanduser().resolve()
        if (candidate / marker).exists():
            return candidate
    current = start.resolve()
    for _ in range(max_depth):
        if (current / marker).exists():
            return current
        if current.parent == current:
            break
        current = current.parent
    colab_candidate = Path("/content/metalora").resolve()
    if (colab_candidate / marker).exists():
        return colab_candidate
    if Path("/content").exists():
        print("Repository not found; attempting to clone into /content/metalora ...")
        subprocess.run(
            [
                "git",
                "clone",
                "https://github.com/doem97/metalora.git",
                str(colab_candidate),
            ],
            check=True,
        )
        if (colab_candidate / marker).exists():
            return colab_candidate
    raise FileNotFoundError(
        f"Could not locate {marker}. Set METALORA_ROOT to the repo path or clone it under /content/metalora."
    )


REPO_ROOT = find_repo_root(Path.cwd())
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
if "datasets" in sys.modules:
    del sys.modules["datasets"]
os.chdir(REPO_ROOT)

import datasets
from trainer import (
    CLASS_MEAN_FNAME,
    TEXT_FEAT_FNAME,
    Trainer,
    load_clip_to_cpu,
    load_vit_to_cpu,
    )
from models import PeftModelFromCLIP, PeftModelFromViT, ZeroShotCLIP
from models.satmae_vit import MAEViTAdapter
from utils.config_omega import cfg as base_cfg
from utils.evaluator import Evaluator
from utils.logger import logger
from utils.samplers import ClassAwareSampler, DownSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
def load_experiment_config(repo_root, dataset_name, model_name, tuner_name=None, overrides=None):
    config = OmegaConf.create(OmegaConf.to_container(base_cfg, resolve=True))
    config_paths = [
        repo_root / "configs" / "data" / f"{dataset_name}.yaml",
        repo_root / "configs" / "model" / f"{model_name}.yaml",
    ]
    if tuner_name:
        config_paths.append(repo_root / "configs" / "tuner" / f"{tuner_name}.yaml")
    for path in config_paths:
        if not path.exists():
            raise FileNotFoundError(path)
        config = OmegaConf.merge(config, OmegaConf.load(path))
    if overrides:
        config = OmegaConf.merge(config, OmegaConf.create(overrides))
    return config


def make_cifar_subset(dataset_cls, root, train, transform, class_indices, remap=True):
    keep = sorted(class_indices)
    dataset = dataset_cls(root, train=train, transform=transform)
    targets = list(dataset.targets)
    indices = [idx for idx, label in enumerate(targets) if label in keep]
    if len(indices) == 0:
        raise ValueError("No samples found for the provided classes.")
    index_array = np.array(indices, dtype=np.int64)
    dataset.data = dataset.data[index_array]
    selected_targets = [targets[idx] for idx in indices]
    dataset.original_targets = selected_targets.copy()
    subset_classnames = [dataset.classes[idx] for idx in keep]
    if remap:
        label_map = {orig: new_idx for new_idx, orig in enumerate(keep)}
        remapped_targets = [label_map[label] for label in selected_targets]
        dataset.targets = remapped_targets
        dataset.labels = remapped_targets
        dataset.classes = subset_classnames
        dataset.classnames = subset_classnames
        dataset.class_to_idx = {name: idx for idx, name in enumerate(subset_classnames)}
        dataset.label_map = label_map
        dataset.inverse_label_map = {v: k for k, v in label_map.items()}
    else:
        dataset.targets = selected_targets
        dataset.labels = selected_targets
        dataset.classnames = subset_classnames
        dataset.label_map = None
        dataset.inverse_label_map = None
    if hasattr(dataset, "get_cls_num_list"):
        dataset.cls_num_list = dataset.get_cls_num_list()
        dataset.num_classes = len(dataset.cls_num_list)
    else:
        dataset.num_classes = len(subset_classnames)
        dataset.cls_num_list = [dataset.targets.count(i) for i in range(dataset.num_classes)]
    dataset.keep_classes = keep
    return dataset


def compute_ood_metrics(id_scores, ood_scores, tpr=0.95):
    id_scores = np.asarray(id_scores, dtype=np.float32)
    ood_scores = np.asarray(ood_scores, dtype=np.float32)
    if id_scores.size == 0 or ood_scores.size == 0:
        raise ValueError("Need non-empty ID and OOD score arrays.")
    labels = np.concatenate([np.ones_like(id_scores), np.zeros_like(ood_scores)])
    scores = np.concatenate([id_scores, ood_scores])
    threshold = np.percentile(id_scores, (1.0 - tpr) * 100.0)
    metrics = {
        "auroc": float(roc_auc_score(labels, scores)),
        "aupr": float(average_precision_score(labels, scores)),
        "fpr@95tpr": float(np.mean(ood_scores >= threshold)),
        "threshold@95tpr": float(threshold),
        "id_mean": float(id_scores.mean()),
        "id_std": float(id_scores.std()),
        "ood_mean": float(ood_scores.mean()),
        "ood_std": float(ood_scores.std()),
    }
    return metrics

In [5]:
class ClassAwareOODTrainer(Trainer):
    def __init__(self, cfg, device, id_classes, ood_classes=None, class_aware_k=4):
        self.id_classes = sorted(set(id_classes))
        self.ood_classes = sorted(set(ood_classes or []))
        self.external_ood_name = getattr(cfg, "ood_dataset", None)
        self.external_ood_name = (
            self.external_ood_name.lower() if self.external_ood_name else None
        )
        overlap = set(self.id_classes) & set(self.ood_classes)
        if overlap and not self.external_ood_name:
            raise ValueError(
                f"In-distribution and OOD classes overlap: {sorted(overlap)}"
            )
        self.class_aware_k = class_aware_k
        super().__init__(cfg, device)
        self.local_rank = 0
        self.world_size = 1
        self.ood_test_loader = None
        root_hint = Path(cfg.root or os.environ.get("CIFAR100_ROOT", "./data")).expanduser()
        class_names = getattr(TorchvisionCIFAR100, "classes", None)
        if class_names is None:
            preview_dataset = TorchvisionCIFAR100(
                root=str(root_hint), train=True, download=True
            )
            class_names = preview_dataset.classes
        self.global_classnames = class_names
        self.id_classnames = [class_names[idx] for idx in self.id_classes]
        if self.external_ood_name:
            self.ood_classnames = None
        else:
            self.ood_classnames = [class_names[idx] for idx in self.ood_classes]
        self.last_ood_scores = None

    def build_data_loader(self):
        cfg = self.cfg
        root = cfg.root
        resolution = cfg.resolution

        if cfg.backbone.startswith("CLIP"):
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        transform_train = transforms.Compose(
            [
                transforms.RandomResizedCrop(resolution),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
)

        transform_plain = transforms.Compose(
            [
                transforms.Resize(resolution),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
)

        transform_test = transforms.Compose(
            [
                transforms.Resize(resolution * 8 // 7),
                transforms.CenterCrop(resolution),
                transforms.Lambda(
                    lambda crop: torch.stack([transforms.ToTensor()(crop)])
                ),
                transforms.Normalize(mean, std),
            ]
)

        dataset_cls = getattr(datasets, cfg.dataset)
        train_dataset = make_cifar_subset(
            dataset_cls, root, True, transform_train, self.id_classes, remap=True
        )
        train_init_dataset = make_cifar_subset(
            dataset_cls, root, True, transform_plain, self.id_classes, remap=True
        )
        train_test_dataset = make_cifar_subset(
            dataset_cls, root, True, transform_test, self.id_classes, remap=True
        )
        id_test_dataset = make_cifar_subset(
            dataset_cls, root, False, transform_test, self.id_classes, remap=True
        )

        if self.external_ood_name:
            ood_test_dataset = self._build_external_ood_dataset(transform_test)
        else:
            if not self.ood_classes:
                raise ValueError(
                    "No OOD classes specified and external OOD dataset not provided."
                )
            ood_test_dataset = make_cifar_subset(
                dataset_cls, root, False, transform_test, self.ood_classes, remap=False
            )

        self.num_classes = train_dataset.num_classes
        self.cls_num_list = train_dataset.cls_num_list
        self.classnames = train_dataset.classnames

        freq = np.array(self.cls_num_list)
        self.many_idxs = np.where(freq > 100)[0]
        self.med_idxs = np.where((freq >= 20) & (freq <= 100))[0]
        self.few_idxs = np.where(freq < 20)[0]

        if cfg.init_head == "1_shot":
            init_sampler = DownSampler(train_init_dataset, n_max=1)
        elif cfg.init_head == "10_shot":
            init_sampler = DownSampler(train_init_dataset, n_max=10)
        elif cfg.init_head == "100_shot":
            init_sampler = DownSampler(train_init_dataset, n_max=100)
        else:
            init_sampler = None

        self.accum_step = cfg.accum_step or 1
        self.eff_batch_size = cfg.batch_size
        denom = self.accum_step * self.world_size
        if self.eff_batch_size % denom != 0:
            raise ValueError(
                f"batch_size ({cfg.batch_size}) must be divisible by accum_step ({self.accum_step})."
            )
        self.per_gpu_batch_size = self.eff_batch_size // denom

        train_sampler = ClassAwareSampler(train_dataset, num_samples_cls=self.class_aware_k)

        pin = self.device.type == "cuda"
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.per_gpu_batch_size,
            sampler=train_sampler,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=pin,
        )

        self.train_init_loader = DataLoader(
            train_init_dataset,
            batch_size=min(64, len(train_init_dataset)),
            sampler=init_sampler,
            shuffle=init_sampler is None,
            num_workers=cfg.num_workers,
            pin_memory=pin,
        )

        self.train_test_loader = DataLoader(
            train_test_dataset,
            batch_size=64,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=pin,
        )

        self.test_loader = DataLoader(
            id_test_dataset,
            batch_size=64,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=pin,
        )

        self.ood_test_loader = DataLoader(
            ood_test_dataset,
            batch_size=64,
            shuffle=False,
            num_workers=cfg.num_workers,
            pin_memory=pin,
        )

        ood_desc = (
            f"OOD dataset ({self.external_ood_name.upper()}): {len(ood_test_dataset)} samples"
            if self.external_ood_name
            else f"OOD samples: {len(ood_test_dataset)}"
        )
        print(
            f"Train samples: {len(train_dataset)} | ID classes: {len(self.id_classes)} | {ood_desc}"
        )

    def build_model(self):
        cfg = self.cfg
        classnames = self.classnames
        num_classes = len(classnames)

        if cfg.backbone.startswith("CLIP"):
            clip_model = load_clip_to_cpu(cfg.backbone, cfg.prec)
            if cfg.zero_shot:
                self.model = ZeroShotCLIP(clip_model)
                self.model.to(self.device)
                self.tuner = None
                self.head = None
                template = "a photo of a {}."
                prompts = self.get_tokenized_prompts(classnames, template)
                self.model.init_text_features(prompts)
                return
            self.model = PeftModelFromCLIP(cfg, clip_model, num_classes)
        elif cfg.backbone.startswith("IN21K-ViT"):
            vit_model = load_vit_to_cpu(cfg.backbone, cfg.prec)
            self.model = PeftModelFromViT(cfg, vit_model, num_classes)
        elif cfg.backbone.startswith("SatMAE-ViT"):
            vit_model = load_vit_to_cpu(cfg.backbone, cfg.prec)
            self.model = PeftModelFromViT(cfg, vit_model, num_classes)
        else:
            raise ValueError(f"Unsupported backbone: {cfg.backbone}")

        self.model.to(self.device)
        self.tuner = getattr(self.model, "tuner", None)
        self.head = getattr(self.model, "head", None)

        if cfg.init_head == "text_feat":
            if not cfg.backbone.startswith("CLIP"):
                print("text_feat head init is only available for CLIP backbones.")
            else:
                text_feat_fname = TEXT_FEAT_FNAME.get(cfg.backbone)
                if text_feat_fname is None:
                    raise ValueError(
                        f"No text feature file registered for {cfg.backbone}"
                    )
                if cfg.head_init_folder is None:
                    raise ValueError(
                        "head_init_folder must be set for text feature initialization."
                    )
                text_feat_path = os.path.join(cfg.head_init_folder, text_feat_fname)
                self.init_head_text_feat(text_feat_path)
        elif cfg.init_head in ["class_mean", "1_shot", "10_shot", "100_shot"]:
            class_mean_fname = CLASS_MEAN_FNAME.get(cfg.backbone)
            if class_mean_fname is None:
                raise ValueError(
                    f"No class mean file registered for {cfg.backbone}"
                )
            if cfg.head_init_folder is None:
                raise ValueError(
                    "head_init_folder must be set for class mean initialization."
                )
            class_mean_path = os.path.join(cfg.head_init_folder, class_mean_fname)
            self.init_head_class_mean(class_mean_path)
        elif cfg.init_head == "linear_probe":
            self.init_head_linear_probe()

        if not (cfg.zero_shot or cfg.test_train or cfg.test_only):
            self.build_optimizer()
            self.build_criterion()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    def _build_external_ood_dataset(self, transform):
        dataset_name = self.external_ood_name
        if dataset_name == "svhn":
            root = getattr(self.cfg, "ood_root", None)
            if root is None:
                root = Path(self.cfg.root).expanduser() / "svhn"
            root = Path(root).expanduser()
            split = getattr(self.cfg, "ood_split", "test")
            return TorchvisionSVHN(
                root=str(root),
                split=split,
                download=True,
                transform=transform,
            )
        raise ValueError(
            f"Unsupported external OOD dataset: {self.external_ood_name}"
        )

    @torch.no_grad()
    def evaluate_ood(self):
        if self.ood_test_loader is None:
            raise RuntimeError("OOD loader not initialized.")
        self.model.eval()
        if self.tuner is not None:
            self.tuner.eval()
        if self.head is not None:
            self.head.eval()

        amp_enabled = self.cfg.prec == "amp" and self.device.type == "cuda"

        def collect_scores(loader):
            scores = []
            for images, _ in loader:
                images = images.to(self.device)
                batch_size, ncrops, c, h, w = images.size()
                images = images.view(batch_size * ncrops, c, h, w)
                with autocast(enabled=amp_enabled):
                    logits = self.model(images)
                logits = logits.view(batch_size, ncrops, -1).mean(dim=1)
                probs = torch.softmax(logits, dim=1)
                scores.extend(probs.max(dim=1)[0].cpu().numpy())
            return scores

        id_scores = collect_scores(self.test_loader)
        ood_scores = collect_scores(self.ood_test_loader)
        metrics = compute_ood_metrics(id_scores, ood_scores)
        self.last_ood_scores = {"id": id_scores, "ood": ood_scores, "metrics": metrics}
        return metrics

In [10]:
ID_CLASSES = list(range(100))

default_data_root = Path(
    os.environ.get("CIFAR100_ROOT") or (REPO_ROOT / "data")
).expanduser().resolve()
svhn_root = Path(
    os.environ.get("SVHN_ROOT") or (default_data_root / "svhn")
).expanduser().resolve()

class_names = getattr(TorchvisionCIFAR100, "classes", None)
if class_names is None:
    preview_dataset = TorchvisionCIFAR100(
        root=str(default_data_root), train=True, download=True
    )
    class_names = preview_dataset.classes

id_names_preview = [class_names[idx] for idx in ID_CLASSES[:10]]
print(
    f"Training on all {len(ID_CLASSES)} CIFAR-100 classes (first 10): {id_names_preview}"
)
print(f"Using SVHN at {svhn_root} as the OOD dataset.")

cfg = load_experiment_config(
    REPO_ROOT,
    dataset_name="cifar100",
    model_name="clip_vit_b16",
    tuner_name=None,
 )

cfg.use_meta = True
cfg.output_dir = str(REPO_ROOT / "output" / "notebooks" / "cifar100_class_aware_svhn")
cfg.root = str(default_data_root)
cfg.num_epochs = 5
cfg.batch_size = 64
cfg.accum_step = 1
cfg.loss_type = "CE"
cfg.head_only = False
cfg.init_head = "text_feat"
cfg.tte = False
cfg.lr = 0.01
cfg.print_freq = 20
cfg.seed = 0
cfg.deterministic = True
cfg.num_workers = min(4, os.cpu_count() or 4)
cfg.prec = "amp" if device.type == "cuda" else "fp32"
cfg.head_init_folder = cfg.output_dir
cfg.ood_dataset = "svhn"
cfg.ood_root = str(svhn_root)
cfg.ood_split = "test"

os.makedirs(cfg.output_dir, exist_ok=True)

if cfg.seed is not None:
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    os.environ["PYTHONHASHSEED"] = str(cfg.seed)
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)

if cfg.deterministic and torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
elif torch.cuda.is_available():
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

logger.init(cfg.output_dir)
cfg

[37mTraining on all 100 CIFAR-100 classes (first 10): ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle'][0m
[37mUsing SVHN at /content/metalora/data/svhn as the OOD dataset.[0m
[37mUsing SVHN at /content/metalora/data/svhn as the OOD dataset.[0m


{'dataset': 'CIFAR100', 'root': '/content/metalora/data', 'imb_factor': None, 'head_init_folder': '/content/metalora/output/notebooks/cifar100_class_aware_svhn', 'backbone': 'CLIP-ViT-B/16', 'resolution': 224, 'output_dir': '/content/metalora/output/notebooks/cifar100_class_aware_svhn', 'print_freq': 20, 'seed': 0, 'deterministic': True, 'num_workers': 2, 'prec': 'amp', 'num_epochs': 5, 'batch_size': 64, 'accum_step': 1, 'lr': 0.01, 'scheduler': 'CosineAnnealingLR', 'weight_decay': 0.0005, 'momentum': 0.9, 'loss_type': 'CE', 'classifier': 'CosineClassifier', 'scale': 25, 'fine_tuning': False, 'head_only': False, 'full_tuning': False, 'bias_tuning': False, 'ln_tuning': False, 'bn_tuning': False, 'vpt_shallow': False, 'vpt_deep': False, 'adapter': False, 'adaptformer': False, 'lora': False, 'lora_mlp': False, 'scale_alpha': 1, 'ssf_attn': False, 'ssf_mlp': False, 'ssf_ln': False, 'mask': False, 'partial': None, 'vpt_len': None, 'adapter_dim': None, 'adaptformer_scale': 'learnable', 'mask

In [11]:
trainer = ClassAwareOODTrainer(
    cfg, device, ID_CLASSES, ood_classes=None, class_aware_k=4
)
trainer.initialize()

[37mTrain samples: 50000 | ID classes: 100 | OOD dataset (SVHN): 26032 samples[0m
[37mFLoRA not used.[0m
[37mViT_Tuner initialization complete[0m
[37mFLoRA not used.[0m
[37mViT_Tuner initialization complete[0m
[37mLoading text features from /content/metalora/output/notebooks/cifar100_class_aware_svhn/txtfeat_clip_vit_b16.pth[0m
[37m
                               Building Optimizer                               
[0m
[37mTuner mode: Only tuning the tuner and head[0m
[37mTurning off gradients in the model[0m
[37mTurning on gradients in the tuner and head[0m
[37mLoading text features from /content/metalora/output/notebooks/cifar100_class_aware_svhn/txtfeat_clip_vit_b16.pth[0m
[37m
                               Building Optimizer                               
[0m
[37mTuner mode: Only tuning the tuner and head[0m
[37mTurning off gradients in the model[0m
[37mTurning on gradients in the tuner and head[0m
[37mTotal params: 149697536[0m
[37mTuner params: 0[0

In [12]:
RUN_TRAINING = True

if RUN_TRAINING:
    trainer.train()
else:
    print("Skipping training; existing weights will be evaluated.")

[37m
                                 Training model                                 
[0m
  with autocast():[0m
  with autocast():[0m
  return _methods._mean(a, axis=axis, dtype=dtype,[0m
  ret = ret.dtype.type(ret / rcount)[0m
[37mepoch [1/5] batch [20/782] time 0.294 (0.317) data 0.000 (0.022) loss 3.3974 (3.7254) acc 39.0625 (29.0672) (mean 28.7039 many 28.7039 med nan few nan) lr 1.0000e-02 elapsed 0:00:06 eta 0:20:32[0m
  return _methods._mean(a, axis=axis, dtype=dtype,[0m
  ret = ret.dtype.type(ret / rcount)[0m
[37mepoch [1/5] batch [20/782] time 0.294 (0.317) data 0.000 (0.022) loss 3.3974 (3.7254) acc 39.0625 (29.0672) (mean 28.7039 many 28.7039 med nan few nan) lr 1.0000e-02 elapsed 0:00:06 eta 0:20:32[0m
[37mepoch [1/5] batch [40/782] time 0.683 (0.343) data 0.000 (0.011) loss 3.2723 (3.4661) acc 42.1875 (31.6744) (mean 31.5586 many 31.5586 med nan few nan) lr 1.0000e-02 elapsed 0:00:13 eta 0:22:06[0m
[37mepoch [1/5] batch [40/782] time 0.683 (0.343) data 0.000

In [13]:
id_acc_scalar = float(trainer.test())
id_metrics = trainer.evaluator.evaluate()
ood_metrics = trainer.evaluate_ood()

summary = {
    "id_accuracy": id_acc_scalar,
    "many_acc": id_metrics.get("many_acc"),
    "med_acc": id_metrics.get("med_acc"),
    "few_acc": id_metrics.get("few_acc"),
}
summary.update(ood_metrics)

for key, value in summary.items():
    if isinstance(value, float):
        print(f"{key}: {value:.4f}")
    elif value is None:
        print(f"{key}: N/A")
    else:
        print(f"{key}: {value}")

[37m
                                Evaluating model                                
[0m
[37mEvaluate on the test set[0m
  0%|          | 0/157 [00:00<?][0m
[37mEvaluate on the test set[0m
  0%|          | 0/157 [00:00<?][0m
 10%|9         | 15/157 [00:11<01:50][0m
 10%|9         | 15/157 [00:11<01:50][0m
 19%|#9        | 30/157 [00:23<01:37][0m
 19%|#9        | 30/157 [00:23<01:37][0m
 29%|##8       | 45/157 [00:34<01:25][0m
 29%|##8       | 45/157 [00:34<01:25][0m
 38%|###8      | 60/157 [00:45<01:13][0m
 38%|###8      | 60/157 [00:45<01:13][0m
 48%|####7     | 75/157 [00:57<01:02][0m
 48%|####7     | 75/157 [00:57<01:02][0m
 57%|#####7    | 90/157 [01:08<00:51][0m
 57%|#####7    | 90/157 [01:08<00:51][0m
 67%|######6   | 105/157 [01:20<00:39][0m
 67%|######6   | 105/157 [01:20<00:39][0m
 76%|#######6  | 120/157 [01:31<00:28][0m
 76%|#######6  | 120/157 [01:31<00:28][0m
 86%|########5 | 135/157 [01:42<00:16][0m
 86%|########5 | 135/157 [01:42<00:16][0m
 96%

In [14]:
from sklearn.metrics import precision_recall_curve, auc

if trainer.last_ood_scores is None:
    raise RuntimeError("Run trainer.evaluate_ood() before computing precision curves.")

id_scores = np.asarray(trainer.last_ood_scores["id"], dtype=np.float32)
ood_scores = np.asarray(trainer.last_ood_scores["ood"], dtype=np.float32)
labels_id = np.concatenate([np.ones_like(id_scores), np.zeros_like(ood_scores)])
scores = np.concatenate([id_scores, ood_scores])

prec_id, rec_id, _ = precision_recall_curve(labels_id, scores)
aupr_id = auc(rec_id, prec_id)

labels_ood = 1 - labels_id
prec_ood, rec_ood, _ = precision_recall_curve(labels_ood, -scores)
aupr_ood = auc(rec_ood, prec_ood)

print(f"AUPR (ID positive): {aupr_id:.4f}")
print(f"AUPR (OOD positive): {aupr_ood:.4f}")

[37mAUPR (ID positive): 0.8272[0m
[37mAUPR (OOD positive): 0.9531[0m
[37mAUPR (OOD positive): 0.9531[0m
