Skip to content

Commit

Permalink
CIFAR classifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
So-Cool committed Apr 26, 2024
1 parent 4ca33b3 commit 7ed7497
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 2 deletions.
166 changes: 166 additions & 0 deletions surrogates_overview/scripts/cifar_label_map.py
@@ -0,0 +1,166 @@
"""
CIFAR 10 & 100 Labels Map
=========================
This module provides a map from a class id to label.
Two maps are available:
* ``CIFAR10_LABEL_MAP`` -- maps class ids to labels for CIFAR10; and
* ``CIFAR100_LABEL_MAP`` -- maps class ids to labels for CIFAR100.
The data set files needed to regenerate the label maps are available at
<https://www.cs.toronto.edu/~kriz/cifar.html>.
See <https://github.com/fat-forensics/resources/tree/master/surrogates_overview>
for more details.
"""
# Author: Kacper Sokol <k.sokol@bristol.ac.uk>
# License: new BSD


import pickle


def _load_cifar10_labels(data_folder):
"""Generates the label map for CIFAR10."""
with open(f'{data_folder}/cifar-10-batches-py/batches.meta', 'rb') as fo:
cf10meta = pickle.load(fo, encoding='bytes')

cf10labels = {i: j.decode()
for i, j in enumerate(cf10meta.get(b'label_names'))}

return cf10labels


def _load_cifar100_labels(data_folder, fine_labels=True):
"""Generates the label map for CIFAR100."""
with open(f'{data_folder}/cifar-100-python/meta', 'rb') as fo:
cf100meta = pickle.load(fo, encoding='bytes')

if fine_labels:
cf100_labels_type = b'fine_label_names'
else:
cf100_labels_type = b'coarse_label_names'

cf100labels = {i: j.decode()
for i, j in enumerate(cf100meta.get(cf100_labels_type))}

return cf100labels


CIFAR10_LABEL_MAP = {
0: 'airplane',
1: 'automobile',
2: 'bird',
3: 'cat',
4: 'deer',
5: 'dog',
6: 'frog',
7: 'horse',
8: 'ship',
9: 'truck'
}


CIFAR100_LABEL_MAP = {
0: 'apple',
1: 'aquarium_fish',
2: 'baby',
3: 'bear',
4: 'beaver',
5: 'bed',
6: 'bee',
7: 'beetle',
8: 'bicycle',
9: 'bottle',
10: 'bowl',
11: 'boy',
12: 'bridge',
13: 'bus',
14: 'butterfly',
15: 'camel',
16: 'can',
17: 'castle',
18: 'caterpillar',
19: 'cattle',
20: 'chair',
21: 'chimpanzee',
22: 'clock',
23: 'cloud',
24: 'cockroach',
25: 'couch',
26: 'crab',
27: 'crocodile',
28: 'cup',
29: 'dinosaur',
30: 'dolphin',
31: 'elephant',
32: 'flatfish',
33: 'forest',
34: 'fox',
35: 'girl',
36: 'hamster',
37: 'house',
38: 'kangaroo',
39: 'keyboard',
40: 'lamp',
41: 'lawn_mower',
42: 'leopard',
43: 'lion',
44: 'lizard',
45: 'lobster',
46: 'man',
47: 'maple_tree',
48: 'motorcycle',
49: 'mountain',
50: 'mouse',
51: 'mushroom',
52: 'oak_tree',
53: 'orange',
54: 'orchid',
55: 'otter',
56: 'palm_tree',
57: 'pear',
58: 'pickup_truck',
59: 'pine_tree',
60: 'plain',
61: 'plate',
62: 'poppy',
63: 'porcupine',
64: 'possum',
65: 'rabbit',
66: 'raccoon',
67: 'ray',
68: 'road',
69: 'rocket',
70: 'rose',
71: 'sea',
72: 'seal',
73: 'shark',
74: 'shrew',
75: 'skunk',
76: 'skyscraper',
77: 'snail',
78: 'snake',
79: 'spider',
80: 'squirrel',
81: 'streetcar',
82: 'sunflower',
83: 'sweet_pepper',
84: 'table',
85: 'tank',
86: 'telephone',
87: 'television',
88: 'tiger',
89: 'tractor',
90: 'train',
91: 'trout',
92: 'tulip',
93: 'turtle',
94: 'wardrobe',
95: 'whale',
96: 'willow_tree',
97: 'wolf',
98: 'woman',
99: 'worm'
}
102 changes: 100 additions & 2 deletions surrogates_overview/scripts/image_classifier.py
Expand Up @@ -2,15 +2,19 @@
Image Classifier
================
This module implements an image classifier based on PyTorch.
Inception v3 and AlexNet are availabel.
This module implements image classifiers based on PyTorch.
Inception v3 and AlexNet are available for ImageNet;
ResNet56 is available for CIFAR10; and
RepVGG (a2) is available for CIFAR100.
See <https://github.com/fat-forensics/resources/tree/master/surrogates_overview>
for more details.
"""
# Author: Kacper Sokol <k.sokol@bristol.ac.uk>
# License: new BSD

from scripts.imagenet_label_map import IMAGENET_LABEL_MAP
from scripts.cifar_label_map import CIFAR10_LABEL_MAP, CIFAR100_LABEL_MAP

import numpy as np

Expand All @@ -33,6 +37,26 @@ def _get_preprocess_transform():
return transf


def _get_preprocess_transform_cifar10():
# https://github.com/chenyaofo/pytorch-cifar-models/issues/4
# https://github.com/chenyaofo/image-classification-codebase/blob/master/conf/cifar10.conf
normalize = transforms.Normalize(
mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
transf = transforms.Compose([transforms.ToTensor(), normalize])

return transf


def _get_preprocess_transform_cifar100():
# https://github.com/chenyaofo/pytorch-cifar-models/issues/4
# https://github.com/chenyaofo/image-classification-codebase/blob/master/conf/cifar100.conf
normalize = transforms.Normalize(
mean=[0.5070, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761])
transf = transforms.Compose([transforms.ToTensor(), normalize])

return transf


class ImageClassifier(object):
"""Image classifier based on PyTorch."""

Expand Down Expand Up @@ -128,3 +152,77 @@ def proba2tuple(self, Y, labels_no=5):
tuples_.append((lab, Y[idx, cls], cls))
tuples.append(tuples_)
return tuples


class ImageNetClassifier(ImageClassifier):
"""ImageNet classifiers -- Inception v3 & AlexNet -- based on PyTorch."""


class Cifar10Classifier(ImageClassifier):
"""CIFAR10 classifiers -- ResNet56 -- based on PyTorch."""

def __init__(self, use_gpu=False):
"""Initialises the image classifier."""
# Get class labels
self.class_idx = CIFAR10_LABEL_MAP

# Get the model
# https://github.com/huyvnphan/PyTorch_CIFAR10
clf = torch.hub.load(
'chenyaofo/pytorch-cifar-models',
'cifar10_resnet56',
pretrained=True)

if use_gpu:
if CUDA_AVAILABLE:
clf = clf.to(DEVICE)
# clf.cuda()
predict_proba = self._predict_proba_gpu
else:
logger.warning('GPU was requested but it is not available. '
'Using CPU instead.')
predict_proba = self._predict_proba_cpu
else:
predict_proba = self._predict_proba_cpu
self.predict_proba = predict_proba

self.clf = clf
self.clf.eval()

# Get transformation
self.preprocess_transform = _get_preprocess_transform_cifar10()


class Cifar100Classifier(ImageClassifier):
"""CIFAR100 classifiers -- RepVGG (a2) -- based on PyTorch."""

def __init__(self, use_gpu=False):
"""Initialises the image classifier."""
# Get class labels
self.class_idx = CIFAR100_LABEL_MAP

# Get the model
# https://github.com/huyvnphan/PyTorch_CIFAR10
clf = torch.hub.load(
'chenyaofo/pytorch-cifar-models',
'cifar100_repvgg_a2',
pretrained=True)

if use_gpu:
if CUDA_AVAILABLE:
clf = clf.to(DEVICE)
# clf.cuda()
predict_proba = self._predict_proba_gpu
else:
logger.warning('GPU was requested but it is not available. '
'Using CPU instead.')
predict_proba = self._predict_proba_cpu
else:
predict_proba = self._predict_proba_cpu
self.predict_proba = predict_proba

self.clf = clf
self.clf.eval()

# Get transformation
self.preprocess_transform = _get_preprocess_transform_cifar100()

0 comments on commit 7ed7497

Please sign in to comment.