8 changes: 5 additions & 3 deletions README.md
@@ -65,12 +65,14 @@ Code for algorithms, applications and tools contributed by:
- [Don Dennis](https://dkdennis.xyz)
- [Yash Gaurkar](https://github.com/mr-yamraj/)
- [Sridhar Gopinath](http://www.sridhargopinath.in/)
- [Sachin Goyal](https://saching007.github.io/)
- [Chirag Gupta](https://aigen.github.io/)
- [Moksh Jain](https://github.com/MJ10)
- [Ashish Kumar](https://ashishkumar1993.github.io/)
- [Aditya Kusupati](https://adityakusupati.github.io/)
- [Chris Lovett](https://github.com/lovettchris)
- [Shishir Patil](https://shishirpatil.github.io/)
- [Oindrila Saha](https://github.com/oindrilasaha)
- [Harsha Vardhan Simhadri](http://harsha-simhadri.org)

[Contributors](https://microsoft.github.io/EdgeML/People) to this project. New contributors welcome.
@@ -81,9 +83,9 @@ If you use software from this library in your work, please use the BibTex entry

```
@software{edgeml03,
-  author = {{Dennis, Don Kurian and Gaurkar, Yash and Gopinath, Sridhar and Gupta, Chirag and
-             Jain, Moksh and Kumar, Ashish and Kusupati, Aditya and Lovett, Chris
-             and Patil, Shishir G and Simhadri, Harsha Vardhan}},
+  author = {{Dennis, Don Kurian and Gaurkar, Yash and Gopinath, Sridhar and Goyal, Sachin
+             and Gupta, Chirag and Jain, Moksh and Kumar, Ashish and Kusupati, Aditya and
+             Lovett, Chris and Patil, Shishir G and Saha, Oindrila and Simhadri, Harsha Vardhan}},
title = {{EdgeML: Machine Learning for resource-constrained edge devices}},
url = {https://github.com/Microsoft/EdgeML},
version = {0.3},
}
```
26 changes: 16 additions & 10 deletions examples/pytorch/DROCC/README.md
@@ -1,7 +1,7 @@
# Deep Robust One-Class Classification
-In this directory we present examples of how to use the `DROCCTrainer` to replicate results in [paper](https://proceedings.icml.cc/book/4293.pdf).
+In this directory we present examples of how to use the `DROCCTrainer` and `DROCCLFTrainer` to replicate the results in the [paper](https://proceedings.icml.cc/book/4293.pdf).

-`DROCCTrainer` is part of the `edgeml_pytorch` package. Please install the `edgeml_pytorch` package as follows:
+`DROCCTrainer` and `DROCCLFTrainer` are part of the `edgeml_pytorch` package. Please install it as follows:
```
git clone https://github.com/microsoft/EdgeML
cd EdgeML/pytorch
```
@@ -38,17 +38,17 @@ The output path is referred to as "root_data" in the following section.
### Command to run experiments to reproduce results
#### Arrhythmia
```
-python3 main_tabular.py --hd 128 --lr 0.0001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 16 --batch_size 256 --epochs 200 --optim 0 --restore 0 --metric F1 -d "root_data"
+python3 main_tabular.py --hd 128 --lr 0.0001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 16 --batch_size 256 --epochs 200 --optim 0 --metric F1 -d "root_data"
```

#### Thyroid
```
-python3 main_tabular.py --hd 128 --lr 0.001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 2.5 --batch_size 256 --epochs 100 --optim 0 --restore 0 --metric F1 -d "root_data"
+python3 main_tabular.py --hd 128 --lr 0.001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 2.5 --batch_size 256 --epochs 100 --optim 0 --metric F1 -d "root_data"
```

#### Abalone
```
-python3 main_tabular.py --hd 128 --lr 0.001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 3 --batch_size 256 --epochs 200 --optim 0 --restore 0 --metric F1 -d "root_data"
+python3 main_tabular.py --hd 128 --lr 0.001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 3 --batch_size 256 --epochs 200 --optim 0 --metric F1 -d "root_data"
```


@@ -67,20 +67,26 @@ The output path is referred to as "root_data" in the following section.

### Example Usage for Epilepsy Dataset
```
-python3 main_timeseries.py --hd 128 --lr 0.00001 --lamda 0.5 --gamma 2 --ascent_step_size 0.1 --radius 10 --batch_size 256 --epochs 200 --optim 0 --restore 0 --metric AUC -d "root_data"
+python3 main_timeseries.py --hd 128 --lr 0.00001 --lamda 0.5 --gamma 2 --ascent_step_size 0.1 --radius 10 --batch_size 256 --epochs 200 --optim 0 --metric AUC -d "root_data"
```

## CIFAR Experiments
```
-python3 main_cifar.py --lamda 1 --radius 8 --lr 0.001 --gamma 1 --ascent_step_size 0.001 --batch_size 256 --epochs 40 --optim 0 --normal_class 0
+python3 main_cifar.py --lamda 1 --radius 8 --lr 0.001 --gamma 1 --ascent_step_size 0.001 --batch_size 256 --epochs 100 --optim 0 --normal_class 0
```

## DROCC-LF MNIST Experiment
An MNIST Digit 0 vs. Digit 1 experiment in which close negatives are generated by randomly masking pixels; see the sketch after the command below.
```
python3 main_drocclf_mnist.py --lamda 1 --radius 16 --lr 0.0001 --batch_size 256 --epochs 40 --one_class_adv 1 --optim 0 -oce 10 --ascent_num_steps 100 --ascent_step_size 0.1 --normal_class 0
```
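
To make the close-negative idea concrete, here is a minimal sketch of random pixel masking. This is our illustration only, not code from `main_drocclf_mnist.py`, and the masking fraction of 0.5 is an assumed value:
```
import torch

def random_pixel_mask(x, mask_frac=0.5):
    # Zero out a random subset of pixels per image to create "close negative"
    # samples; mask_frac is an assumed knob, not a flag of the actual script.
    mask = (torch.rand_like(x) > mask_frac).float()  # 1 = keep pixel, 0 = drop it
    return x * mask

# Example: a batch of MNIST-shaped images [B, 1, 28, 28]
batch = torch.randn(8, 1, 28, 28)
close_negatives = random_pixel_mask(batch)
print(close_negatives.shape)  # torch.Size([8, 1, 28, 28])
```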

### Arguments Detail
normal_class => CIFAR10 class to be considered as normal
lamda => Weightage to the loss from adversarially sampled negative points (\mu in the paper)
-radius => radius corresponding to the definition of set N_i(r)
+radius => Radius corresponding to the definition of set N_i(r)
hd => LSTM Hidden Dimension
optim => 0: Adam 1: SGD(M)
-ascent_step_size => step size for gradient ascent to generate adversarial anomalies
+ascent_step_size => Step size for gradient ascent to generate adversarial anomalies
+ascent_num_steps => Number of gradient ascent steps
+oce => Only Cross Entropy Epochs (no adversarial loss is calculated)
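
The `lamda` weighting above is a scalar on the adversarial part of the classification loss. For reference, here is a minimal, hedged sketch of that combination; it illustrates the weighting described in the paper, not the actual `DROCCTrainer` internals:
```
import torch
import torch.nn.functional as F

def drocc_style_loss(scores_normal, scores_adv, lamda=1.0):
    # Illustrative only: CE on normal points (labeled 1) plus lamda (mu in the
    # paper) times CE on adversarially sampled negatives (labeled 0).
    ce_normal = F.binary_cross_entropy_with_logits(
        scores_normal, torch.ones_like(scores_normal))
    ce_adv = F.binary_cross_entropy_with_logits(
        scores_adv, torch.zeros_like(scores_adv))
    return ce_normal + lamda * ce_adv

# Toy usage with random classifier scores
print(drocc_style_loss(torch.randn(8), torch.randn(8), lamda=1.0).item())
```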
12 changes: 0 additions & 12 deletions examples/pytorch/DROCC/data_process_scripts/process_cifar.py
@@ -58,18 +58,6 @@ def __init__(self, root: str, normal_class=5):
self.outlier_classes = list(range(0, 10))
self.outlier_classes.remove(normal_class)

-        # Pre-computed min and max values (after applying GCN) from train data per class
-        # min_max = [(-28.94083453598571, 13.802961825439636),
-        # (-6.681770233365245, 9.158067708230273),
-        # (-34.924463588638204, 14.419298165027628),
-        # (-10.599172931391799, 11.093187820377565),
-        # (-11.945022995801637, 10.628045447867583),
-        # (-9.691969487694928, 8.948326776180823),
-        # (-9.174940012342555, 13.847014686472365),
-        # (-6.876682005899029, 12.282371383343161),
-        # (-15.603507135507172, 15.2464923804279),
-        # (-6.132882973622672, 8.046098172351265)]
-        # CIFAR-10 preprocessing: GCN (with L1 norm) and min-max feature scaling to [0,1]
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
std=[0.247, 0.243, 0.261])])
139 changes: 139 additions & 0 deletions examples/pytorch/DROCC/data_process_scripts/process_mnist.py
@@ -0,0 +1,139 @@
'''
Code borrowed from https://github.com/lukasruff/Deep-SVDD-PyTorch
'''
from PIL import Image
import numpy as np
from abc import ABC, abstractmethod
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
import torchvision.transforms as transforms


class BaseADDataset(ABC):
    """Anomaly detection dataset base class."""

    def __init__(self, root: str):
        super().__init__()
        self.root = root  # root path to data

        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = None  # tuple with original class labels that define the normal class
        self.outlier_classes = None  # tuple with original class labels that define the outlier class

        self.train_set = None  # must be of type torch.utils.data.Dataset
        self.test_set = None  # must be of type torch.utils.data.Dataset

    @abstractmethod
    def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
            DataLoader, DataLoader):
        """Implement data loaders of type torch.utils.data.DataLoader for train_set and test_set."""
        pass

    def __repr__(self):
        return self.__class__.__name__


class TorchvisionDataset(BaseADDataset):
    """TorchvisionDataset class for datasets already implemented in torchvision.datasets."""

    def __init__(self, root: str):
        super().__init__(root)

    def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
            DataLoader, DataLoader):
        train_loader = DataLoader(dataset=self.train_set, batch_size=batch_size, shuffle=shuffle_train,
                                  num_workers=num_workers)
        test_loader = DataLoader(dataset=self.test_set, batch_size=batch_size, shuffle=shuffle_test,
                                 num_workers=num_workers)
        return train_loader, test_loader


class MNIST_Dataset(TorchvisionDataset):

    def __init__(self, root: str, normal_class=0):
        super().__init__(root)
        # Load only the digit 0 and digit 1 data, for both train and test
        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = tuple([0])
        self.train_classes = tuple([0, 1])
        self.test_class = tuple([0, 1])

        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.1307],
                                                             std=[0.3081])])

        target_transform = transforms.Lambda(lambda x: int(x in self.normal_classes))

        train_set = MyMNIST(root=self.root, train=True, download=True,
                            transform=transform, target_transform=target_transform)
        # Subset train_set to the digit 0 and digit 1 classes
        train_idx_normal = get_target_label_idx(train_set.targets, self.train_classes)
        self.train_set = Subset(train_set, train_idx_normal)

        test_set = MyMNIST(root=self.root, train=False, download=True,
                           transform=transform, target_transform=target_transform)
        test_idx_normal = get_target_label_idx(test_set.targets, self.test_class)
        self.test_set = Subset(test_set, test_idx_normal)


class MyMNIST(MNIST):
    """Torchvision MNIST class with patch of __getitem__ method to also return the index of a data sample."""

    def __init__(self, *args, **kwargs):
        super(MyMNIST, self).__init__(*args, **kwargs)

    def __getitem__(self, index):
        """Override the original method of the MNIST class.
        Args:
            index (int): Index
        Returns:
            triple: (image, target, index) where target is the index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # Convert to a PIL Image so behaviour is consistent with all other datasets
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index  # only line changed from the torchvision implementation


def get_target_label_idx(labels, targets):
    """
    Get the indices of labels that are included in targets.
    :param labels: array of labels
    :param targets: list/tuple of target labels
    :return: list with indices of target labels
    """
    return np.argwhere(np.isin(labels, targets)).flatten().tolist()


def global_contrast_normalization(x: torch.Tensor, scale='l2'):
    """
    Apply global contrast normalization to a tensor, i.e. subtract the mean across features (pixels) and normalize
    by the L1- or L2-norm across features (pixels).
    Note this is a *per sample* normalization globally across features (and not across the dataset).
    """
    assert scale in ('l1', 'l2')

    n_features = int(np.prod(x.shape))

    mean = torch.mean(x)  # mean over all features (pixels) per sample
    x -= mean

    if scale == 'l1':
        x_scale = torch.mean(torch.abs(x))

    if scale == 'l2':
        x_scale = torch.sqrt(torch.sum(x ** 2)) / n_features

    x /= x_scale

    return x
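
For orientation, here is a minimal usage sketch run alongside the classes above; the `./data` root and the batch size are illustrative assumptions, not values fixed by this file:
```
# Assumed usage sketch: build the 0-vs-1 MNIST anomaly-detection dataset and
# inspect one batch. MyMNIST returns (image, target, index) triples.
dataset = MNIST_Dataset(root='./data', normal_class=0)  # './data' is an assumed path
train_loader, test_loader = dataset.loaders(batch_size=256)

imgs, targets, idxs = next(iter(train_loader))
print(imgs.shape)    # torch.Size([256, 1, 28, 28])
print(targets[:5])   # 1 = normal class (digit 0), 0 = outlier (digit 1)
print(idxs[:5])      # original dataset indices from MyMNIST.__getitem__
```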
31 changes: 18 additions & 13 deletions examples/pytorch/DROCC/main_cifar.py
@@ -85,19 +85,24 @@ def main():
lr=args.lr)
print("using Adam")

# Training the model
trainer = DROCCTrainer(model, optimizer, args.lamda, args.radius, args.gamma, device)

-# Restore from checkpoint
-if args.restore == 1:
+if args.eval == 0:
+    # Training the model
+    trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs,
+                  metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs=0)
+    trainer.save(args.model_dir)
+else:
    if os.path.exists(os.path.join(args.model_dir, 'model.pt')):
        trainer.load(args.model_dir)
        print("Saved Model Loaded")
-
-trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs,
-              metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs=0)
-
-trainer.save(args.model_dir)
+    else:
+        print('Saved model not found. Cannot run evaluation.')
+        exit()
score = trainer.test(test_loader, 'AUC')
print('Test AUC: {}'.format(score))

if __name__ == '__main__':
torch.set_printoptions(precision=5)
@@ -111,15 +116,15 @@ def main():
help='number of epochs to train')
parser.add_argument('-oce,', '--only_ce_epochs', type=int, default=50, metavar='N',
help='number of epochs to train with only CE loss')
-parser.add_argument('--ascent_num_steps', type=int, default=50, metavar='N',
+parser.add_argument('--ascent_num_steps', type=int, default=100, metavar='N',
help='Number of gradient ascent steps')
parser.add_argument('--hd', type=int, default=128, metavar='N',
help='Num hidden nodes for LSTM model')
parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
help='learning rate')
parser.add_argument('--ascent_step_size', type=float, default=0.001, metavar='LR',
help='step size of gradient ascent')
-parser.add_argument('--mom', type=float, default=0.99, metavar='M',
+parser.add_argument('--mom', type=float, default=0.0, metavar='M',
help='momentum')
parser.add_argument('--model_dir', default='log',
help='path where to save checkpoint')
Expand All @@ -131,8 +136,8 @@ def main():
help='Weight to the adversarial loss')
parser.add_argument('--reg', type=float, default=0, metavar='N',
help='weight reg')
-parser.add_argument('--restore', type=int, default=0, metavar='N',
-                    help='whether to load a pretrained model, 1: load 0: train from scratch ')
+parser.add_argument('--eval', type=int, default=0, metavar='N',
+                    help='whether to load a saved model and evaluate (0/1)')
parser.add_argument('--optim', type=int, default=0, metavar='N',
help='0 : Adam 1: SGD')
parser.add_argument('--gamma', type=float, default=2.0, metavar='N',