diff --git a/README.md b/README.md index 21d8a6090..27f898379 100644 --- a/README.md +++ b/README.md @@ -65,12 +65,14 @@ Code for algorithms, applications and tools contributed by: - [Don Dennis](https://dkdennis.xyz) - [Yash Gaurkar](https://github.com/mr-yamraj/) - [Sridhar Gopinath](http://www.sridhargopinath.in/) + - [Sachin Goyal](https://saching007.github.io/) - [Chirag Gupta](https://aigen.github.io/) - [Moksh Jain](https://github.com/MJ10) - [Ashish Kumar](https://ashishkumar1993.github.io/) - [Aditya Kusupati](https://adityakusupati.github.io/) - [Chris Lovett](https://github.com/lovettchris) - [Shishir Patil](https://shishirpatil.github.io/) + - [Oindrila Saha](https://github.com/oindrilasaha) - [Harsha Vardhan Simhadri](http://harsha-simhadri.org) [Contributors](https://microsoft.github.io/EdgeML/People) to this project. New contributors welcome. @@ -81,9 +83,9 @@ If you use software from this library in your work, please use the BibTex entry ``` @software{edgeml03, - author = {{Dennis, Don Kurian and Gaurkar, Yash and Gopinath, Sridhar and Gupta, Chirag and - Jain, Moksh and Kumar, Ashish and Kusupati, Aditya and Lovett, Chris - and Patil, Shishir G and Simhadri, Harsha Vardhan}}, + author = {{Dennis, Don Kurian and Gaurkar, Yash and Gopinath, Sridhar and Goyal, Sachin + and Gupta, Chirag and Jain, Moksh and Kumar, Ashish and Kusupati, Aditya and + Lovett, Chris and Patil, Shishir G and Saha, Oindrila and Simhadri, Harsha Vardhan}}, title = {{EdgeML: Machine Learning for resource-constrained edge devices}}, url = {https://github.com/Microsoft/EdgeML}, version = {0.3}, diff --git a/examples/pytorch/DROCC/README.md b/examples/pytorch/DROCC/README.md index f538e3c33..3dcb42d23 100644 --- a/examples/pytorch/DROCC/README.md +++ b/examples/pytorch/DROCC/README.md @@ -1,7 +1,7 @@ # Deep Robust One-Class Classification -In this directory we present examples of how to use the `DROCCTrainer` to replicate results in [paper](https://proceedings.icml.cc/book/4293.pdf). +In this directory we present examples of how to use the `DROCCTrainer` and `DROCCLFTrainer` to replicate results in [paper](https://proceedings.icml.cc/book/4293.pdf). -`DROCCTrainer` is part of the `edgeml_pytorch` package. Please install the `edgeml_pytorch` package as follows: +`DROCCTrainer` and `DROCCLFTrainer` are part of the `edgeml_pytorch` package. Please install the `edgeml_pytorch` package as follows: ``` git clone https://github.com/microsoft/EdgeML cd EdgeML/pytorch @@ -38,17 +38,17 @@ The output path is referred to as "root_data" in the following section. 
### Command to run experiments to reproduce results #### Arrhythmia ``` -python3 main_tabular.py --hd 128 --lr 0.0001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 16 --batch_size 256 --epochs 200 --optim 0 --restore 0 --metric F1 -d "root_data" +python3 main_tabular.py --hd 128 --lr 0.0001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 16 --batch_size 256 --epochs 200 --optim 0 --metric F1 -d "root_data" ``` #### Thyroid ``` -python3 main_tabular.py --hd 128 --lr 0.001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 2.5 --batch_size 256 --epochs 100 --optim 0 --restore 0 --metric F1 -d "root_data" +python3 main_tabular.py --hd 128 --lr 0.001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 2.5 --batch_size 256 --epochs 100 --optim 0 --metric F1 -d "root_data" ``` #### Abalone ``` -python3 main_tabular.py --hd 128 --lr 0.001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 3 --batch_size 256 --epochs 200 --optim 0 --restore 0 --metric F1 -d "root_data" +python3 main_tabular.py --hd 128 --lr 0.001 --lamda 1 --gamma 2 --ascent_step_size 0.001 --radius 3 --batch_size 256 --epochs 200 --optim 0 --metric F1 -d "root_data" ``` @@ -67,20 +67,26 @@ The output path is referred to as "root_data" in the following section. ### Example Usage for Epilepsy Dataset ``` -python3 main_timeseries.py --hd 128 --lr 0.00001 --lamda 0.5 --gamma 2 --ascent_step_size 0.1 --radius 10 --batch_size 256 --epochs 200 --optim 0 --restore 0 --metric AUC -d "root_data" +python3 main_timeseries.py --hd 128 --lr 0.00001 --lamda 0.5 --gamma 2 --ascent_step_size 0.1 --radius 10 --batch_size 256 --epochs 200 --optim 0 --metric AUC -d "root_data" ``` ## CIFAR Experiments ``` -python3 main_cifar.py --lamda 1 --radius 8 --lr 0.001 --gamma 1 --ascent_step_size 0.001 --batch_size 256 --epochs 40 --optim 0 --normal_class 0 +python3 main_cifar.py --lamda 1 --radius 8 --lr 0.001 --gamma 1 --ascent_step_size 0.001 --batch_size 256 --epochs 100 --optim 0 --normal_class 0 ``` +## DROCC-LF MNIST Experiment +MNIST Digit 0 vs Digit 1 experiment where close negatives are generated by randomly masking the pixels. 
+``` +python3 main_drocclf_mnist.py --lamda 1 --radius 16 --lr 0.0001 --batch_size 256 --epochs 40 --one_class_adv 1 --optim 0 -oce 10 --ascent_num_steps 100 --ascent_step_size 0.1 --normal_class 0 +``` ### Arguments Detail normal_class => CIFAR10 class to be considered as normal lamda => Weightage to the loss from adversarially sampled negative points (\mu in the paper) -radius => radius corresponding to the definition of set N_i(r) +radius => Radius corresponding to the definition of set N_i(r) hd => LSTM Hidden Dimension optim => 0: Adam 1: SGD(M) -ascent_step_size => step size for gradient ascent to generate adversarial anomalies - +ascent_step_size => Step size for gradient ascent to generate adversarial anomalies +ascent_num_steps => Number of gradient ascent steps +oce => Only Cross Entropy Steps (No adversarial loss is calculated) diff --git a/examples/pytorch/DROCC/data_process_scripts/process_cifar.py b/examples/pytorch/DROCC/data_process_scripts/process_cifar.py index 61f9397be..6579a63d8 100644 --- a/examples/pytorch/DROCC/data_process_scripts/process_cifar.py +++ b/examples/pytorch/DROCC/data_process_scripts/process_cifar.py @@ -58,18 +58,6 @@ def __init__(self, root: str, normal_class=5): self.outlier_classes = list(range(0, 10)) self.outlier_classes.remove(normal_class) - # Pre-computed min and max values (after applying GCN) from train data per class - # min_max = [(-28.94083453598571, 13.802961825439636), - # (-6.681770233365245, 9.158067708230273), - # (-34.924463588638204, 14.419298165027628), - # (-10.599172931391799, 11.093187820377565), - # (-11.945022995801637, 10.628045447867583), - # (-9.691969487694928, 8.948326776180823), - # (-9.174940012342555, 13.847014686472365), - # (-6.876682005899029, 12.282371383343161), - # (-15.603507135507172, 15.2464923804279), - # (-6.132882973622672, 8.046098172351265)] - # CIFAR-10 preprocessing: GCN (with L1 norm) and min-max feature scaling to [0,1] transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])]) diff --git a/examples/pytorch/DROCC/data_process_scripts/process_mnist.py b/examples/pytorch/DROCC/data_process_scripts/process_mnist.py new file mode 100644 index 000000000..f9c553fba --- /dev/null +++ b/examples/pytorch/DROCC/data_process_scripts/process_mnist.py @@ -0,0 +1,139 @@ +''' +Code borrowed from https://github.com/lukasruff/Deep-SVDD-PyTorch +''' +from PIL import Image +import numpy as np +from random import sample +from abc import ABC, abstractmethod +import torch +from torch.utils.data import Subset +from torchvision.datasets import MNIST +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +class BaseADDataset(ABC): + """Anomaly detection dataset base class.""" + + def __init__(self, root: str): + super().__init__() + self.root = root # root path to data + + self.n_classes = 2 # 0: normal, 1: outlier + self.normal_classes = None # tuple with original class labels that define the normal class + self.outlier_classes = None # tuple with original class labels that define the outlier class + + self.train_set = None # must be of type torch.utils.data.Dataset + self.test_set = None # must be of type torch.utils.data.Dataset + + @abstractmethod + def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> ( + DataLoader, DataLoader): + """Implement data loaders of type torch.utils.data.DataLoader for train_set and test_set.""" + pass + + def __repr__(self): + return 
self.__class__.__name__ + +class TorchvisionDataset(BaseADDataset): + """TorchvisionDataset class for datasets already implemented in torchvision.datasets.""" + + def __init__(self, root: str): + super().__init__(root) + + def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> ( + DataLoader, DataLoader): + train_loader = DataLoader(dataset=self.train_set, batch_size=batch_size, shuffle=shuffle_train, + num_workers=num_workers) + test_loader = DataLoader(dataset=self.test_set, batch_size=batch_size, shuffle=shuffle_test, + num_workers=num_workers) + return train_loader, test_loader + +class MNIST_Dataset(TorchvisionDataset): + + def __init__(self, root: str, normal_class=0): + super().__init__(root) + #Loads only the digit 0 and digit 1 data + # for both train and test + self.n_classes = 2 # 0: normal, 1: outlier + self.normal_classes = tuple([0]) + self.train_classes = tuple([0,1]) + self.test_class = tuple([0,1]) + + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize(mean=[0.1307], + std=[0.3081])]) + + target_transform = transforms.Lambda(lambda x: int(x in self.normal_classes)) + + train_set = MyMNIST(root=self.root, train=True, download=True, + transform=transform, target_transform=target_transform) + # Subset train_set to normal class + train_idx_normal = get_target_label_idx(train_set.targets, self.train_classes) + self.train_set = Subset(train_set, train_idx_normal) + + test_set = MyMNIST(root=self.root, train=False, download=True, + transform=transform, target_transform=target_transform) + test_idx_normal = get_target_label_idx(test_set.targets, self.test_class) + self.test_set = Subset(test_set, test_idx_normal) + +class MyMNIST(MNIST): + """Torchvision MNIST class with patch of __getitem__ method to also return the index of a data sample.""" + + def __init__(self, *args, **kwargs): + super(MyMNIST, self).__init__(*args, **kwargs) + + def __getitem__(self, index): + """Override the original method of the MNIST class. + Args: + index (int): Index + Returns: + triple: (image, target, index) where target is index of the target class. + """ + img, target = self.data[index], self.targets[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img.numpy(), mode='L') + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target, index # only line changed + + +def get_target_label_idx(labels, targets): + """ + Get the indices of labels that are included in targets. + :param labels: array of labels + :param targets: list/tuple of target labels + :return: list with indices of target labels + """ + return np.argwhere(np.isin(labels, targets)).flatten().tolist() + + +def global_contrast_normalization(x: torch.tensor, scale='l2'): + """ + Apply global contrast normalization to tensor, i.e. subtract mean across features (pixels) and normalize by scale, + which is either the standard deviation, L1- or L2-norm across features (pixels). + Note this is a *per sample* normalization globally across features (and not across the dataset). 
+ """ + + assert scale in ('l1', 'l2') + + n_features = int(np.prod(x.shape)) + + mean = torch.mean(x) # mean over all features (pixels) per sample + x -= mean + + if scale == 'l1': + x_scale = torch.mean(torch.abs(x)) + + if scale == 'l2': + x_scale = torch.sqrt(torch.sum(x ** 2)) / n_features + + x /= x_scale + + return x diff --git a/examples/pytorch/DROCC/main_cifar.py b/examples/pytorch/DROCC/main_cifar.py index 505aa4f83..51622fdc4 100644 --- a/examples/pytorch/DROCC/main_cifar.py +++ b/examples/pytorch/DROCC/main_cifar.py @@ -85,19 +85,24 @@ def main(): lr=args.lr) print("using Adam") - # Training the model trainer = DROCCTrainer(model, optimizer, args.lamda, args.radius, args.gamma, device) - - # Restore from checkpoint - if args.restore == 1: + + if args.eval == 0: + # Training the model + trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs, + metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs = 0) + + trainer.save(args.model_dir) + + else: if os.path.exists(os.path.join(args.model_dir, 'model.pt')): trainer.load(args.model_dir) print("Saved Model Loaded") - - trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs, - metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs = 0) - - trainer.save(args.model_dir) + else: + print('Saved model not found. Cannot run evaluation.') + exit() + score = trainer.test(test_loader, 'AUC') + print('Test AUC: {}'.format(score)) if __name__ == '__main__': torch.set_printoptions(precision=5) @@ -111,7 +116,7 @@ def main(): help='number of epochs to train') parser.add_argument('-oce,', '--only_ce_epochs', type=int, default=50, metavar='N', help='number of epochs to train with only CE loss') - parser.add_argument('--ascent_num_steps', type=int, default=50, metavar='N', + parser.add_argument('--ascent_num_steps', type=int, default=100, metavar='N', help='Number of gradient ascent steps') parser.add_argument('--hd', type=int, default=128, metavar='N', help='Num hidden nodes for LSTM model') @@ -119,7 +124,7 @@ def main(): help='learning rate') parser.add_argument('--ascent_step_size', type=float, default=0.001, metavar='LR', help='step size of gradient ascent') - parser.add_argument('--mom', type=float, default=0.99, metavar='M', + parser.add_argument('--mom', type=float, default=0.0, metavar='M', help='momentum') parser.add_argument('--model_dir', default='log', help='path where to save checkpoint') @@ -131,8 +136,8 @@ def main(): help='Weight to the adversarial loss') parser.add_argument('--reg', type=float, default=0, metavar='N', help='weight reg') - parser.add_argument('--restore', type=int, default=0, metavar='N', - help='whether to load a pretrained model, 1: load 0: train from scratch ') + parser.add_argument('--eval', type=int, default=0, metavar='N', + help='whether to load a saved model and evaluate (0/1)') parser.add_argument('--optim', type=int, default=0, metavar='N', help='0 : Adam 1: SGD') parser.add_argument('--gamma', type=float, default=2.0, metavar='N', diff --git a/examples/pytorch/DROCC/main_drocclf_mnist.py b/examples/pytorch/DROCC/main_drocclf_mnist.py new file mode 100644 index 000000000..2144c3b60 --- /dev/null +++ b/examples/pytorch/DROCC/main_drocclf_mnist.py @@ -0,0 +1,200 @@ +from __future__ import print_function +import os +import numpy as np +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, Dataset +from 
collections import OrderedDict +from data_process_scripts.process_mnist import MNIST_Dataset +from edgeml_pytorch.trainer.drocclf_trainer import DROCCLFTrainer, cal_precision_recall + +class MNIST_LeNet(nn.Module): + + def __init__(self): + super().__init__() + + self.rep_dim = 64 + self.pool = nn.MaxPool2d(2, 2) + self.conv1 = nn.Conv2d(1, 8, 5, bias=False, padding=2) + self.bn1 = nn.BatchNorm2d(8, eps=1e-04, affine=False) + self.conv2 = nn.Conv2d(8, 4, 5, bias=False, padding=2) + self.bn2 = nn.BatchNorm2d(4, eps=1e-04, affine=False) + self.fc1 = nn.Linear(4 * 7 * 7, self.rep_dim, bias=False) + self.fc2 = nn.Linear(self.rep_dim, 1, bias=False) + + def forward(self, x): + x = x.view(x.shape[0],1,28,28) + x = self.conv1(x) + x = self.pool(F.leaky_relu(self.bn1(x))) + x = self.conv2(x) + x = self.pool(F.leaky_relu(self.bn2(x))) + x = x.view(x.size(0), -1) + x = self.fc1(x) + x = self.fc2(x) + return x + + +def adjust_learning_rate(epoch, total_epochs, only_ce_epochs, learning_rate, optimizer): + + """Adjust learning rate during training. + + Parameters + ---------- + epoch: Current training epoch. + total_epochs: Total number of epochs for training. + only_ce_epochs: Number of epochs for initial pretraining. + learning_rate: Initial learning rate for training. + """ + #We dont want to consider the only ce + #based epochs for the lr scheduler + epoch = epoch - only_ce_epochs + drocc_epochs = total_epochs - only_ce_epochs + # lr = learning_rate + if epoch <= drocc_epochs: + lr = learning_rate * 0.01 + if epoch <= 0.80 * drocc_epochs: + lr = learning_rate * 0.1 + if epoch <= 0.40 * drocc_epochs: + lr = learning_rate + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + return optimizer + +class CustomDataset(Dataset): + def __init__(self, data, labels): + self.data = data + self.labels = labels + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + return torch.from_numpy(self.data[idx]), (self.labels[idx]), torch.tensor([0]) + +def get_close_negs(test_loader): + # Getting the Close Negs data for MNIST Digit 0 by + # randomly masking 40% of the digit 0 test image pixels + batch_idx = -1 + for data, target, _ in test_loader: + batch_idx += 1 + data, target = data.to(device), target.to(device) + data = data.to(torch.float) + target = target.to(torch.float) + #In our case, label of digit 0 is 1 + data_0 = data[target==1] + aug1 = data_0.clone() + indices = np.random.choice(np.arange(torch.numel(aug1)), replace=False, + size=int(torch.numel(aug1) * 0.4)) + aug1[np.unravel_index(indices, np.shape(aug1))] = torch.min(data) + if batch_idx==0: + close_neg_data = aug1 + else: + close_neg_data = torch.cat((close_neg_data,aug1), dim=0) + + close_neg_data = close_neg_data.detach().cpu().numpy() + close_neg_labels = np.zeros((close_neg_data.shape[0])) + + return CustomDataset(close_neg_data, close_neg_labels) + +def main(): + #Load digit 0 and digit 1 data from MNIST + dataset = MNIST_Dataset("data") + train_loader, test_loader = dataset.loaders(batch_size=args.batch_size) + closeneg_test_data = get_close_negs(test_loader) + closeneg_test_loader = DataLoader(closeneg_test_data, args.batch_size, shuffle=True) + + model = MNIST_LeNet().to(device) + model = nn.DataParallel(model) + + if args.optim == 1: + optimizer = optim.SGD(model.parameters(), + lr=args.lr, + momentum=args.mom) + print("using SGD") + else: + optimizer = optim.Adam(model.parameters(), + lr=args.lr) + print("using Adam") + + trainer = 
DROCCLFTrainer(model, optimizer, args.lamda, args.radius, args.gamma, device) + + if args.eval==0: + # Training the model + trainer.train(train_loader, test_loader, closeneg_test_loader, args.lr, adjust_learning_rate, args.epochs, + ascent_step_size=args.ascent_step_size, only_ce_epochs = args.only_ce_epochs) + + trainer.save(args.model_dir) + + else: + if os.path.exists(os.path.join(args.model_dir, 'model.pt')): + trainer.load(args.model_dir) + print("Saved Model Loaded") + else: + print('Saved model not found. Cannot run evaluation.') + exit() + _, pos_scores, far_neg_scores = trainer.test(test_loader, get_auc=False) + _, _, close_neg_scores = trainer.test(closeneg_test_loader, get_auc=False) + + precision_fpr03, recall_fpr03 = cal_precision_recall(pos_scores, far_neg_scores, close_neg_scores, 0.03) + precision_fpr05, recall_fpr05 = cal_precision_recall(pos_scores, far_neg_scores, close_neg_scores, 0.05) + print('Test Precision @ FPR 3% : {}, Recall @ FPR 3%: {}'.format( + precision_fpr03, recall_fpr03)) + print('Test Precision @ FPR 5% : {}, Recall @ FPR 5%: {}'.format( + precision_fpr05, recall_fpr05)) + +if __name__ == '__main__': + torch.set_printoptions(precision=5) + + parser = argparse.ArgumentParser(description='PyTorch Simple Training') + parser.add_argument('--normal_class', type=int, default=0, metavar='N', + help='CIFAR10 normal class index') + parser.add_argument('--batch_size', type=int, default=128, metavar='N', + help='batch size for training') + parser.add_argument('--epochs', type=int, default=100, metavar='N', + help='number of epochs to train') + parser.add_argument('-oce,', '--only_ce_epochs', type=int, default=50, metavar='N', + help='number of epochs to train with only CE loss') + parser.add_argument('--ascent_num_steps', type=int, default=50, metavar='N', + help='Number of gradient ascent steps') + parser.add_argument('--hd', type=int, default=128, metavar='N', + help='Num hidden nodes for LSTM model') + parser.add_argument('--lr', type=float, default=0.001, metavar='LR', + help='learning rate') + parser.add_argument('--ascent_step_size', type=float, default=0.001, metavar='LR', + help='step size of gradient ascent') + parser.add_argument('--mom', type=float, default=0.99, metavar='M', + help='momentum') + parser.add_argument('--model_dir', default='log', + help='path where to save checkpoint') + parser.add_argument('--one_class_adv', type=int, default=1, metavar='N', + help='adv loss to be used or not, 1:use 0:not use(only CE)') + parser.add_argument('--radius', type=float, default=0.2, metavar='N', + help='radius corresponding to the definition of set N_i(r)') + parser.add_argument('--lamda', type=float, default=1, metavar='N', + help='Weight to the adversarial loss') + parser.add_argument('--reg', type=float, default=0, metavar='N', + help='weight reg') + parser.add_argument('--eval', type=int, default=0, metavar='N', + help='whether to load a saved model and evaluate (0/1)') + parser.add_argument('--optim', type=int, default=0, metavar='N', + help='0 : Adam 1: SGD') + parser.add_argument('--gamma', type=float, default=2.0, metavar='N', + help='r to gamma * r projection for the set N_i(r)') + parser.add_argument('-d', '--data_path', type=str, default='.') + args = parser. 
parse_args() + + # settings + #Checkpoint store path + model_dir = args.model_dir + if not os.path.exists(model_dir): + os.makedirs(model_dir) + use_cuda = torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + + main() diff --git a/examples/pytorch/DROCC/main_tabular.py b/examples/pytorch/DROCC/main_tabular.py index 9bd4563be..3728ab8b7 100644 --- a/examples/pytorch/DROCC/main_tabular.py +++ b/examples/pytorch/DROCC/main_tabular.py @@ -114,19 +114,24 @@ def main(): lr=args.lr) print("using Adam") - # Training the model trainer = DROCCTrainer(model, optimizer, args.lamda, args.radius, args.gamma, device) - - # Restore from checkpoint - if args.restore == 1: + + if args.eval == 0: + # Training the model + trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs, + metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs = args.only_ce_epochs) + + trainer.save(args.model_dir) + + else: if os.path.exists(os.path.join(args.model_dir, 'model.pt')): trainer.load(args.model_dir) print("Saved Model Loaded") - - trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs, - metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs = args.only_ce_epochs) - - trainer.save(args.model_dir) + else: + print('Saved model not found. Cannot run evaluation.') + exit() + score = trainer.test(test_loader, 'F1') + print('Test F1: {}'.format(score)) if __name__ == '__main__': torch.set_printoptions(precision=5) @@ -158,8 +163,8 @@ def main(): help='Weight to the adversarial loss') parser.add_argument('--reg', type=float, default=0, metavar='N', help='weight reg') - parser.add_argument('--restore', type=int, default=0, metavar='N', - help='whether to load a pretrained model, 1: load 0: train from scratch') + parser.add_argument('--eval', type=int, default=0, metavar='N', + help='whether to load a saved model and evaluate (0/1)') parser.add_argument('--optim', type=int, default=0, metavar='N', help='0 : Adam 1: SGD') parser.add_argument('--gamma', type=float, default=2.0, metavar='N', diff --git a/examples/pytorch/DROCC/main_timeseries.py b/examples/pytorch/DROCC/main_timeseries.py index 42c685717..de9939e4f 100644 --- a/examples/pytorch/DROCC/main_timeseries.py +++ b/examples/pytorch/DROCC/main_timeseries.py @@ -114,19 +114,25 @@ def main(): lr=args.lr) print("using Adam") - # Training the model trainer = DROCCTrainer(model, optimizer, args.lamda, args.radius, args.gamma, device) - - # Restore from checkpoint - if args.restore == 1: + + if args.eval == 0: + # Training the model + trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs, + metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs = args.only_ce_epochs) + + trainer.save(args.model_dir) + + else: if os.path.exists(os.path.join(args.model_dir, 'model.pt')): trainer.load(args.model_dir) print("Saved Model Loaded") - - trainer.train(train_loader, test_loader, args.lr, adjust_learning_rate, args.epochs, - metric=args.metric, ascent_step_size=args.ascent_step_size, only_ce_epochs = args.only_ce_epochs) + else: + print('Saved model not found. 
Cannot run evaluation.') + exit() + score = trainer.test(test_loader, 'AUC') + print('Test AUC: {}'.format(score)) - trainer.save(args.model_dir) if __name__ == '__main__': torch.set_printoptions(precision=5) @@ -158,8 +164,8 @@ def main(): help='Weight to the adversarial loss') parser.add_argument('--reg', type=float, default=0, metavar='N', help='weight reg') - parser.add_argument('--restore', type=int, default=1, metavar='N', - help='whether to load a pretrained model, 1: load 0: train from scratch ') + parser.add_argument('--eval', type=int, default=0, metavar='N', + help='whether to load a saved model and evaluate (0/1)') parser.add_argument('--optim', type=int, default=0, metavar='N', help='0 : Adam 1: SGD') parser.add_argument('--gamma', type=float, default=2.0, metavar='N', diff --git a/pytorch/README.md b/pytorch/README.md index 9a3909e44..a82fd34d2 100644 --- a/pytorch/README.md +++ b/pytorch/README.md @@ -28,8 +28,9 @@ for these algorithms are in `edgeml_pytorch.trainer`. 4. [S-RNN](https://github.com/microsoft/EdgeML/blob/master/docs/publications/SRNN.pdf): `edgeml_pytorch.graph.rnn.SRNN2` implements a 2 layer SRNN network which can be instantied with a choice of RNN cell. The training routine for SRNN is in `edgeml_pytorch.trainer.srnnTrainer`. -5. DROCC: `edgeml_pytorch.trainer.drocc_trainer` implements a meta-trainer for training any given model architecture - for one-class classification on the supplied dataset. +5. DROCC & DROCC-LF: `edgeml_pytorch.trainer.drocc_trainer` implements a DROCC meta-trainer for training any given model architecture + for one-class classification on the supplied dataset. `edgeml_pytorch.trainer.drocclf_trainer` implements the DROCC-LF variant + for training models for one-class classification with limited negatives. Usage directions and examples notebooks for this package are provided [here](https://github.com/microsoft/EdgeML/blobl/master/examples/pytorch). diff --git a/pytorch/edgeml_pytorch/trainer/drocc_trainer.py b/pytorch/edgeml_pytorch/trainer/drocc_trainer.py index 8e40a2ba3..554e08ca8 100644 --- a/pytorch/edgeml_pytorch/trainer/drocc_trainer.py +++ b/pytorch/edgeml_pytorch/trainer/drocc_trainer.py @@ -1,4 +1,5 @@ import os +import copy import numpy as np import torch import torch.optim as optim @@ -20,7 +21,7 @@ def __init__(self, model, optimizer, lamda, radius, gamma, device): ---------- model: Torch neural network object optimizer: Total number of epochs for training. - lamda: Adversarial loss weight for input layer + lamda: Weight given to the adversarial loss radius: Radius of hypersphere to sample points from. gamma: Parameter to vary projection. device: torch.device object for device to use. @@ -51,6 +52,8 @@ def train(self, train_loader, val_loader, learning_rate, lr_scheduler, total_epo generation of negative points. metric: Metric used for evaluation (AUC / F1). 
""" + best_score = -np.inf + best_model = None self.ascent_num_steps = ascent_num_steps self.ascent_step_size = ascent_step_size for epoch in range(total_epochs): @@ -59,7 +62,7 @@ def train(self, train_loader, val_loader, learning_rate, lr_scheduler, total_epo lr_scheduler(epoch, total_epochs, only_ce_epochs, learning_rate, self.optimizer) #Placeholder for the respective 2 loss values - epoch_adv_loss = torch.tensor([0]).type(torch.float32).detach() #AdvLoss @ Input Layer + epoch_adv_loss = torch.tensor([0]).type(torch.float32).to(self.device) #AdvLoss epoch_ce_loss = 0 #Cross entropy Loss batch_idx = -1 @@ -86,10 +89,10 @@ def train(self, train_loader, val_loader, learning_rate, lr_scheduler, total_epo if epoch >= only_ce_epochs: data = data[target == 1] # AdvLoss - adv_loss_inp = self.one_class_adv_loss(data) - epoch_adv_loss += adv_loss_inp + adv_loss = self.one_class_adv_loss(data) + epoch_adv_loss += adv_loss - loss = ce_loss + adv_loss_inp * self.lamda + loss = ce_loss + adv_loss * self.lamda else: # If only CE based training has to be done loss = ce_loss @@ -99,13 +102,19 @@ def train(self, train_loader, val_loader, learning_rate, lr_scheduler, total_epo self.optimizer.step() epoch_ce_loss = epoch_ce_loss/(batch_idx + 1) #Average CE Loss - epoch_adv_loss = epoch_adv_loss/(batch_idx + 1) #Average AdvLoss @Input Layer + epoch_adv_loss = epoch_adv_loss/(batch_idx + 1) #Average AdvLoss test_score = self.test(val_loader, metric) - + if test_score > best_score: + best_score = test_score + best_model = copy.deepcopy(self.model) print('Epoch: {}, CE Loss: {}, AdvLoss: {}, {}: {}'.format( epoch, epoch_ce_loss.item(), epoch_adv_loss.item(), metric, test_score)) + self.model = copy.deepcopy(best_model) + print('\nBest test {}: {}'.format( + metric, best_score + )) def test(self, test_loader, metric): """Evaluate the model on the given test dataset. @@ -128,7 +137,7 @@ def test(self, test_loader, metric): logits = self.model(data) logits = torch.squeeze(logits, dim = 1) sigmoid_logits = torch.sigmoid(logits) - scores = sigmoid_logits + scores = logits label_score += list(zip(target.cpu().data.numpy().tolist(), scores.cpu().data.numpy().tolist())) # Compute test score @@ -208,4 +217,4 @@ def save(self, path): torch.save(self.model.state_dict(),os.path.join(path, 'model.pt')) def load(self, path): - self.model.load_state_dict(torch.load(os.path.join(path, 'model.pt'))) \ No newline at end of file + self.model.load_state_dict(torch.load(os.path.join(path, 'model.pt'))) diff --git a/pytorch/edgeml_pytorch/trainer/drocclf_trainer.py b/pytorch/edgeml_pytorch/trainer/drocclf_trainer.py new file mode 100644 index 000000000..8264b9b63 --- /dev/null +++ b/pytorch/edgeml_pytorch/trainer/drocclf_trainer.py @@ -0,0 +1,410 @@ +import os +import copy +import numpy as np +import torch +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F +from sklearn.metrics import roc_auc_score + +def cal_precision_recall(positive_scores, far_neg_scores, close_neg_scores, fpr): + """ + Computes the precision and recall for the given false positive rate. 
+ """ + #combine the far and close negative scores + all_neg_scores = np.concatenate((far_neg_scores, close_neg_scores), axis = 0) + num_neg = all_neg_scores.shape[0] + idx = int((1-fpr) * num_neg) + #sort scores in ascending order + all_neg_scores.sort() + thresh = all_neg_scores[idx] + tp = np.sum(positive_scores > thresh) + recall = tp/positive_scores.shape[0] + fp = int(fpr * num_neg) + precision = tp/(tp+fp) + return precision, recall + + +def normalize_grads(grad): + """ + Utility function to normalize the gradients. + grad: (batch, -1) + """ + # make sum equal to the size of second dim + grad_norm = torch.sum(torch.abs(grad), dim=1) + grad_norm = torch.unsqueeze(grad_norm, dim = 1) + grad_norm = grad_norm.repeat(1, grad.shape[1]) + grad = grad/grad_norm * grad.shape[1] + return grad + +def compute_mahalanobis_distance(grad, diff, radius, device, gamma): + """ + Compute the mahalanobis distance. + grad: (batch,-1) + diff: (batch,-1) + """ + mhlnbs_dis = torch.sqrt(torch.sum(grad*diff**2, dim=1)) + #Categorize the batches based on mahalanobis distance + #lamda = 1 : mahalanobis distance < radius + #lamda = 2 : mahalanobis distance > gamma * radius + lamda = torch.zeros((grad.shape[0],1)) + lamda[mhlnbs_dis < radius] = 1 + lamda[mhlnbs_dis > (gamma * radius)] = 2 + return lamda, mhlnbs_dis + + +# The following are utitlity functions for checking the conditions in +# Proposition 1 in https://arxiv.org/abs/2002.12718 + +def check_left_part1(lam, grad, diff, radius, device): + #Part 1 condition value + n1 = diff**2 * lam**2 * grad**2 + d1 = (1 + lam * grad)**2 + 1e-10 + term = n1/d1 + term_sum = torch.sum(term) + return term_sum + +def check_left_part2(nu, grad, diff, radius, device, gamma): + #Part 2 condition value + n1 = diff**2 * grad**2 + d1 = (nu + grad)**2 + 1e-10 + term = n1/d1 + term_sum = torch.sum(term) + return term_sum + +def check_right_part1(lam, grad, diff, radius, device): + #Check if 'such that' condition is true in proposition 1 part 1 + n1 = grad + d1 = (1 + lam * grad)**2 + 1e-10 + term = diff**2 * n1/d1 + term_sum = torch.sum(term) + if term_sum > radius**2: + return check_left_part1(lam, grad, diff, radius, device) + else: + return np.inf + +def check_right_part2(nu, grad, diff, radius, device, gamma): + #Check if 'such that' condition is true in proposition 1 part 2 + n1 = grad*nu**2 + d1 = (nu + grad)**2 + 1e-10 + term = diff**2 * n1/d1 + term_sum = torch.sum(term) + if term_sum < (gamma*radius)**2: + return check_left_part2(nu, grad, diff, radius, device, gamma) + else: + # return torch.tensor(float('inf')) + return np.inf + +def range_lamda_lower(grad): + #Gridsearch range for lamda + lam, _ = torch.max(grad, dim=1) + eps, _ = torch.min(grad, dim=1) + lam = -1 / lam + eps*0.0001 + return lam + +def range_nu_upper(grad, mhlnbs_dis, radius, gamma): + #Gridsearch range for nu + alpha = (gamma*radius)/mhlnbs_dis + max_sigma, _ = torch.max(grad, dim=1) + nu = (alpha/(1-alpha))*max_sigma + return nu + +def optim_solver(grad, diff, radius, device, gamma=2): + """ + Solver for the optimization problem presented in Proposition 1 in + https://arxiv.org/abs/2002.12718 + """ + lamda, mhlnbs_dis = compute_mahalanobis_distance(grad, diff, radius, device, gamma) + lamda_lower_limit = range_lamda_lower(grad).detach().cpu().numpy() + nu_upper_limit = range_nu_upper(grad, mhlnbs_dis, radius, gamma).detach().cpu().numpy() + + #num of values of lamda and nu samples in the allowed range + num_rand_samples = 40 + final_lamda = torch.zeros((grad.shape[0],1)) + + #Solve optim for each 
example in the batch + for idx in range(lamda.shape[0]): + #Optim corresponding to mahalanobis dis < radius + if lamda[idx] == 1: + min_left = np.inf + best_lam = 0 + for k in range(num_rand_samples): + val = np.random.uniform(low = lamda_lower_limit[idx], high = 0) + left_val = check_right_part1(val, grad[idx], diff[idx], radius, device) + if left_val < min_left: + min_left = left_val + best_lam = val + + final_lamda[idx] = best_lam + + #Optim corresponding to mahalanobis dis > gamma * radius + elif lamda[idx] == 2: + min_left = np.inf + best_lam = np.inf + for k in range(num_rand_samples): + val = np.random.uniform(low = 0, high = nu_upper_limit[idx]) + left_val = check_right_part2(val, grad[idx], diff[idx], radius, device, gamma) + if left_val < min_left: + min_left = left_val + best_lam = val + + final_lamda[idx] = 1.0/best_lam + + else: + final_lamda[idx] = 0 + + final_lamda = final_lamda.to(device) + for j in range(diff.shape[0]): + diff[j,:] = diff[j,:]/(1+final_lamda[j]*grad[j,:]) + + return diff + +def get_gradients(model, device, data, target): + """ + Utility function to compute the gradients of the model on the + given data. + """ + total_train_pts = len(data) + data = data.to(torch.float) + target = target.to(torch.float) + target = torch.squeeze(target) + + #Extract the logits for cross entropy loss + data_copy = data + data_copy = data_copy.detach().requires_grad_() + # logits = model(data_copy) + logits = model(data_copy) + logits = torch.squeeze(logits, dim = 1) + ce_loss = F.binary_cross_entropy_with_logits(logits, target) + + grad = torch.autograd.grad(ce_loss, data_copy)[0] + + return torch.abs(grad) + +#trainer class for DROCC-LF +class DROCCLFTrainer: + """ + Trainer class that implements the DROCC-LF algorithm proposed for + one-class classification with limited negative data presented in + https://arxiv.org/abs/2002.12718 + """ + + def __init__(self, model, optimizer, lamda, radius, gamma, device): + """Initialize the DROCC-LF Trainer class + + Parameters + ---------- + model: Torch neural network object + optimizer: Optimizer object to be used for training. + lamda: Weight given to the adversarial loss + radius: Radius of hypersphere to sample points from. + gamma: Parameter to vary projection. + device: torch.device object for device to use. + """ + self.model = model + self.optimizer = optimizer + self.lamda = lamda + self.radius = radius + self.gamma = gamma + self.device = device + + def train(self, train_loader, val_loader, closeneg_val_loader, learning_rate, lr_scheduler, total_epochs, + only_ce_epochs=50, ascent_step_size=0.001, ascent_num_steps=50): + """Trains the model on the given training dataset with periodic + evaluation on the validation dataset. + + Parameters + ---------- + train_loader: Dataloader object for the training dataset. + val_loader: Dataloader object for the validation dataset with far negatives. + closeneg_val_loader: Dataloader object for the validation dataset with close negatives. + learning_rate: Initial learning rate for training. + total_epochs: Total number of epochs for training. + only_ce_epochs: Number of epochs for initial pretraining. + ascent_step_size: Step size for gradient ascent for adversarial + generation of negative points. + ascent_num_steps: Number of gradient ascent steps for adversarial + generation of negative points. 
+ """ + best_recall_fpr03 = -np.inf + best_precision_fpr03 = -np.inf + best_recall_fpr05 = -np.inf + best_precision_fpr05 = -np.inf + best_model = None + self.ascent_num_steps = ascent_num_steps + self.ascent_step_size = ascent_step_size + for epoch in range(total_epochs): + #Make the weights trainable + self.model.train() + lr_scheduler(epoch, total_epochs, only_ce_epochs, learning_rate, self.optimizer) + + #Placeholder for the respective 2 loss values + epoch_adv_loss = torch.tensor([0]).type(torch.float32).to(self.device) #AdvLoss + epoch_ce_loss = 0 #Cross entropy Loss + + batch_idx = -1 + for data, target, _ in train_loader: + batch_idx += 1 + data, target = data.to(self.device), target.to(self.device) + # Data Processing + data = data.to(torch.float) + target = target.to(torch.float) + target = torch.squeeze(target) + + self.optimizer.zero_grad() + + # Extract the logits for cross entropy loss + logits = self.model(data) + logits = torch.squeeze(logits, dim = 1) + ce_loss = F.binary_cross_entropy_with_logits(logits, target) + # Add to the epoch variable for printing average CE Loss + epoch_ce_loss += ce_loss + + ''' + Adversarial Loss is calculated only for the positive data points (label==1). + ''' + if epoch >= only_ce_epochs: + data = data[target == 1] + target = torch.ones(data.shape[0]).to(self.device) + gradients = get_gradients(self.model, self.device, data, target) + # AdvLoss + adv_loss = self.one_class_adv_loss(data, gradients) + epoch_adv_loss += adv_loss + + loss = ce_loss + adv_loss * self.lamda + else: + # If only CE based training has to be done + loss = ce_loss + + # Backprop + loss.backward() + self.optimizer.step() + + epoch_ce_loss = epoch_ce_loss/(batch_idx + 1) #Average CE Loss + epoch_adv_loss = epoch_adv_loss/(batch_idx + 1) #Average AdvLoss + + #normal val loader has the positive data and the far negative data + auc, pos_scores, far_neg_scores = self.test(val_loader, get_auc=True) + _, _, close_neg_scores = self.test(closeneg_val_loader, get_auc=False) + + precision_fpr03 , recall_fpr03 = cal_precision_recall(pos_scores, far_neg_scores, close_neg_scores, 0.03) + precision_fpr05 , recall_fpr05 = cal_precision_recall(pos_scores, far_neg_scores, close_neg_scores, 0.05) + if recall_fpr03 > best_recall_fpr03: + best_recall_fpr03 = recall_fpr03 + best_precision_fpr03 = precision_fpr03 + best_recall_fpr05 = recall_fpr05 + best_precision_fpr05 = precision_fpr05 + best_model = copy.deepcopy(self.model) + print('Epoch: {}, CE Loss: {}, AdvLoss: {}'.format( + epoch, epoch_ce_loss.item(), epoch_adv_loss.item())) + print('Precision @ FPR 3% : {}, Recall @ FPR 3%: {}'.format( + precision_fpr03, recall_fpr03)) + print('Precision @ FPR 5% : {}, Recall @ FPR 5%: {}'.format( + precision_fpr05, recall_fpr05)) + self.model = copy.deepcopy(best_model) + print('\nBest test Precision @ FPR 3% : {}, Recall @ FPR 3%: {}'.format( + best_precision_fpr03, best_recall_fpr03 + )) + print('\nBest test Precision @ FPR 5% : {}, Recall @ FPR 5%: {}'.format( + best_precision_fpr05, best_recall_fpr05 + )) + + def test(self, test_loader, get_auc = True): + """Evaluate the model on the given test dataset. + + Parameters + ---------- + test_loader: Dataloader object for the test dataset. 
+ """ + label_score = [] + batch_idx = -1 + for data, target, _ in test_loader: + batch_idx += 1 + data, target = data.to(self.device), target.to(self.device) + data = data.to(torch.float) + target = target.to(torch.float) + target = torch.squeeze(target) + + logits = self.model(data) + logits = torch.squeeze(logits, dim = 1) + sigmoid_logits = torch.sigmoid(logits) + scores = sigmoid_logits + label_score += list(zip(target.cpu().data.numpy().tolist(), + scores.cpu().data.numpy().tolist())) + # Compute test score + labels, scores = zip(*label_score) + labels = np.array(labels) + scores = np.array(scores) + pos_scores = scores[labels==1] + neg_scores = scores[labels==0] + auc = -1 + if get_auc: + auc = roc_auc_score(labels, scores) + return auc, pos_scores, neg_scores + + + def one_class_adv_loss(self, x_train_data, gradients): + """Computes the adversarial loss: + 1) Sample points initially at random around the positive training + data points + 2) Gradient ascent to find the most optimal point in set N_i(r) + classified as +ve (label=0). This is done by maximizing + the CE loss wrt label 0 + 3) Project the points between spheres of radius R and gamma * R + (set N_i(r) with mahalanobis distance as a distance measure), + by solving the optimization problem + 4) Pass the calculated adversarial points through the model, + and calculate the CE loss wrt target class 0 + + Parameters + ---------- + x_train_data: Batch of data to compute loss on. + gradients: gradients of the model for the given data. + """ + batch_size = len(x_train_data) + # Randomly sample points around the training data + # We will perform SGD on these to find the adversarial points + x_adv = torch.randn(x_train_data.shape).to(self.device).detach().requires_grad_() + x_adv_sampled = x_adv + x_train_data + + for step in range(self.ascent_num_steps): + with torch.enable_grad(): + + new_targets = torch.zeros(batch_size, 1).to(self.device) + new_targets = torch.squeeze(new_targets) + new_targets = new_targets.to(torch.float) + + logits = self.model(x_adv_sampled) + logits = torch.squeeze(logits, dim = 1) + new_loss = F.binary_cross_entropy_with_logits(logits, new_targets) + + grad = torch.autograd.grad(new_loss, [x_adv_sampled])[0] + grad_norm = torch.norm(grad, p=2, dim = tuple(range(1, grad.dim()))) + grad_norm = grad_norm.view(-1, *[1]*(grad.dim()-1)) + grad_normalized = grad/grad_norm + with torch.no_grad(): + x_adv_sampled.add_(self.ascent_step_size * grad_normalized) + + if (step + 1) % 5==0: + # Project the normal points to the set N_i(r) based on mahalanobis distance + h = x_adv_sampled - x_train_data + h_flat = torch.reshape(h, (h.shape[0], -1)) + gradients_flat = torch.reshape(gradients, (gradients.shape[0], -1)) + #Normalize the gradients + gradients_normalized = normalize_grads(gradients_flat) + #Solve the non-convex 1D optimization + h_flat = optim_solver(gradients_normalized, h_flat, self.radius, self.device, self.gamma) + h = torch.reshape(h_flat, h.shape) + x_adv_sampled = x_train_data + h #These adv_points are now on the surface of hyper-sphere + + adv_pred = self.model(x_adv_sampled) + adv_pred = torch.squeeze(adv_pred, dim=1) + adv_loss = F.binary_cross_entropy_with_logits(adv_pred, (new_targets * 0)) + + return adv_loss + + def save(self, path): + torch.save(self.model.state_dict(),os.path.join(path, 'model.pt')) + + def load(self, path): + self.model.load_state_dict(torch.load(os.path.join(path, 'model.pt')))