Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CIFAR models #6

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-notebooks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ on: push
jobs:
build-n-deploy:
name: Test 🔧
runs-on: ubuntu-18.04
runs-on: ubuntu-22.04
steps:
- name: Checkout code 🛎️
uses: actions/checkout@v2.3.1
Expand Down
166 changes: 166 additions & 0 deletions surrogates_overview/scripts/cifar_label_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""
CIFAR 10 & 100 Labels Map
=========================

This module provides a map from a class id to label.
Two maps are available:

* ``CIFAR10_LABEL_MAP`` -- maps class ids to labels for CIFAR10; and
* ``CIFAR100_LABEL_MAP`` -- maps class ids to labels for CIFAR100.

The data set files needed to regenerate the label maps are available at
<https://www.cs.toronto.edu/~kriz/cifar.html>.

See <https://github.com/fat-forensics/resources/tree/master/surrogates_overview>
for more details.
"""
# Author: Kacper Sokol <k.sokol@bristol.ac.uk>
# License: new BSD


import pickle


def _load_cifar10_labels(data_folder):
"""Generates the label map for CIFAR10."""
with open(f'{data_folder}/cifar-10-batches-py/batches.meta', 'rb') as fo:
cf10meta = pickle.load(fo, encoding='bytes')

cf10labels = {i: j.decode()
for i, j in enumerate(cf10meta.get(b'label_names'))}

return cf10labels


def _load_cifar100_labels(data_folder, fine_labels=True):
"""Generates the label map for CIFAR100."""
with open(f'{data_folder}/cifar-100-python/meta', 'rb') as fo:
cf100meta = pickle.load(fo, encoding='bytes')

if fine_labels:
cf100_labels_type = b'fine_label_names'
else:
cf100_labels_type = b'coarse_label_names'

cf100labels = {i: j.decode()
for i, j in enumerate(cf100meta.get(cf100_labels_type))}

return cf100labels


CIFAR10_LABEL_MAP = {
0: 'airplane',
1: 'automobile',
2: 'bird',
3: 'cat',
4: 'deer',
5: 'dog',
6: 'frog',
7: 'horse',
8: 'ship',
9: 'truck'
}


CIFAR100_LABEL_MAP = {
0: 'apple',
1: 'aquarium_fish',
2: 'baby',
3: 'bear',
4: 'beaver',
5: 'bed',
6: 'bee',
7: 'beetle',
8: 'bicycle',
9: 'bottle',
10: 'bowl',
11: 'boy',
12: 'bridge',
13: 'bus',
14: 'butterfly',
15: 'camel',
16: 'can',
17: 'castle',
18: 'caterpillar',
19: 'cattle',
20: 'chair',
21: 'chimpanzee',
22: 'clock',
23: 'cloud',
24: 'cockroach',
25: 'couch',
26: 'crab',
27: 'crocodile',
28: 'cup',
29: 'dinosaur',
30: 'dolphin',
31: 'elephant',
32: 'flatfish',
33: 'forest',
34: 'fox',
35: 'girl',
36: 'hamster',
37: 'house',
38: 'kangaroo',
39: 'keyboard',
40: 'lamp',
41: 'lawn_mower',
42: 'leopard',
43: 'lion',
44: 'lizard',
45: 'lobster',
46: 'man',
47: 'maple_tree',
48: 'motorcycle',
49: 'mountain',
50: 'mouse',
51: 'mushroom',
52: 'oak_tree',
53: 'orange',
54: 'orchid',
55: 'otter',
56: 'palm_tree',
57: 'pear',
58: 'pickup_truck',
59: 'pine_tree',
60: 'plain',
61: 'plate',
62: 'poppy',
63: 'porcupine',
64: 'possum',
65: 'rabbit',
66: 'raccoon',
67: 'ray',
68: 'road',
69: 'rocket',
70: 'rose',
71: 'sea',
72: 'seal',
73: 'shark',
74: 'shrew',
75: 'skunk',
76: 'skyscraper',
77: 'snail',
78: 'snake',
79: 'spider',
80: 'squirrel',
81: 'streetcar',
82: 'sunflower',
83: 'sweet_pepper',
84: 'table',
85: 'tank',
86: 'telephone',
87: 'television',
88: 'tiger',
89: 'tractor',
90: 'train',
91: 'trout',
92: 'tulip',
93: 'turtle',
94: 'wardrobe',
95: 'whale',
96: 'willow_tree',
97: 'wolf',
98: 'woman',
99: 'worm'
}
102 changes: 100 additions & 2 deletions surrogates_overview/scripts/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@
Image Classifier
================

This module implements an image classifier based on PyTorch.
Inception v3 and AlexNet are availabel.
This module implements image classifiers based on PyTorch.
Inception v3 and AlexNet are available for ImageNet;
ResNet56 is available for CIFAR10; and
RepVGG (a2) is available for CIFAR100.

See <https://github.com/fat-forensics/resources/tree/master/surrogates_overview>
for more details.
"""
# Author: Kacper Sokol <k.sokol@bristol.ac.uk>
# License: new BSD

from scripts.imagenet_label_map import IMAGENET_LABEL_MAP
from scripts.cifar_label_map import CIFAR10_LABEL_MAP, CIFAR100_LABEL_MAP

import numpy as np

Expand All @@ -33,6 +37,26 @@ def _get_preprocess_transform():
return transf


def _get_preprocess_transform_cifar10():
# https://github.com/chenyaofo/pytorch-cifar-models/issues/4
# https://github.com/chenyaofo/image-classification-codebase/blob/master/conf/cifar10.conf
normalize = transforms.Normalize(
mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
transf = transforms.Compose([transforms.ToTensor(), normalize])

return transf


def _get_preprocess_transform_cifar100():
# https://github.com/chenyaofo/pytorch-cifar-models/issues/4
# https://github.com/chenyaofo/image-classification-codebase/blob/master/conf/cifar100.conf
normalize = transforms.Normalize(
mean=[0.5070, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761])
transf = transforms.Compose([transforms.ToTensor(), normalize])

return transf


class ImageClassifier(object):
"""Image classifier based on PyTorch."""

Expand Down Expand Up @@ -128,3 +152,77 @@ def proba2tuple(self, Y, labels_no=5):
tuples_.append((lab, Y[idx, cls], cls))
tuples.append(tuples_)
return tuples


class ImageNetClassifier(ImageClassifier):
"""ImageNet classifiers -- Inception v3 & AlexNet -- based on PyTorch."""


class Cifar10Classifier(ImageClassifier):
"""CIFAR10 classifiers -- ResNet56 -- based on PyTorch."""

def __init__(self, use_gpu=False):
"""Initialises the image classifier."""
# Get class labels
self.class_idx = CIFAR10_LABEL_MAP

# Get the model
# https://github.com/huyvnphan/PyTorch_CIFAR10
clf = torch.hub.load(
'chenyaofo/pytorch-cifar-models',
'cifar10_resnet56',
pretrained=True)

if use_gpu:
if CUDA_AVAILABLE:
clf = clf.to(DEVICE)
# clf.cuda()
predict_proba = self._predict_proba_gpu
else:
logger.warning('GPU was requested but it is not available. '
'Using CPU instead.')
predict_proba = self._predict_proba_cpu
else:
predict_proba = self._predict_proba_cpu
self.predict_proba = predict_proba

self.clf = clf
self.clf.eval()

# Get transformation
self.preprocess_transform = _get_preprocess_transform_cifar10()


class Cifar100Classifier(ImageClassifier):
"""CIFAR100 classifiers -- RepVGG (a2) -- based on PyTorch."""

def __init__(self, use_gpu=False):
"""Initialises the image classifier."""
# Get class labels
self.class_idx = CIFAR100_LABEL_MAP

# Get the model
# https://github.com/huyvnphan/PyTorch_CIFAR10
clf = torch.hub.load(
'chenyaofo/pytorch-cifar-models',
'cifar100_repvgg_a2',
pretrained=True)

if use_gpu:
if CUDA_AVAILABLE:
clf = clf.to(DEVICE)
# clf.cuda()
predict_proba = self._predict_proba_gpu
else:
logger.warning('GPU was requested but it is not available. '
'Using CPU instead.')
predict_proba = self._predict_proba_cpu
else:
predict_proba = self._predict_proba_cpu
self.predict_proba = predict_proba

self.clf = clf
self.clf.eval()

# Get transformation
self.preprocess_transform = _get_preprocess_transform_cifar100()
Loading