<a href="https://colab.research.google.com/github/cmikke97/Automatic-Malware-Signature-Generation/blob/main/src/DetectionBase/DetectionBase_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Train and Evaluate ML detection model**

# **Needed packages**

In [1]:
!pip install -U logzero

Collecting logzero
  Downloading https://files.pythonhosted.org/packages/b3/68/aa714515d65090fcbcc9a1f3debd5a644b14aad11e59238f42f00bd4b298/logzero-1.7.0-py2.py3-none-any.whl
Installing collected packages: logzero
Successfully installed logzero-1.7.0


# **Set up Drive**

In [2]:
from google.colab import drive

# path at which Google Drive gets mounted; it is reused further down to build
# the dataset and checkpoint locations passed to Config
drive_path = "/content/drive"

# mount Google Drive at drive_path
drive.mount(drive_path)

Mounted at /content/drive


# **Import needed modules**

In [3]:
import torch  # Tensor library like NumPy, with strong GPU support
from torch import nn  # a neural networks library deeply integrated with autograd designed for maximum flexibility
import torch.nn.functional as F # pytorch neural network functional interface
from torch.utils import data  # used to import data.Dataset -> we will subclass it; it will then be passed to data.Dataloader which is at the heart of PyTorch data loading utility
import numpy as np # The fundamental package for scientific computing with Python
import pandas as pd # Pandas is a fast, powerful, flexible and easy to use open source data analysis and manipulation tool

from copy import deepcopy # Used to construct a new compound object and then, recursively, insert copies into it of the objects found in the original
from collections import defaultdict # Imports defaultdict from collections (which implements specialized container datatypes providing alternatives to Python’s general purpose built-in containers)
from sklearn.metrics import roc_auc_score # Used to compute the Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores
from sklearn.metrics import roc_curve # Used to compute the Receiver operating characteristic (ROC) curve

import matplotlib # Comprehensive library for creating static, animated, and interactive visualizations in Python
matplotlib.use('Agg') # Select 'Agg' as the backend used for rendering and GUI integration
from matplotlib import pyplot as plt  # State-based interface to matplotlib, provides a MATLAB-like way of plotting

import lmdb # Python binding for the LMDB ‘Lightning’ Database
import sqlite3  # Provides a SQL interface compliant with the DB-API 2.0 specification
import msgpack # Efficient binary serialization format
import zlib # Allows compression and decompression, using the zlib library
import json # JSON encoder and decoder
import pickle # Implements binary protocols for serializing and de-serializing a Python object structure

import warnings # Warning control module
import sys  # System-specific parameters and functions
import os # Provides a portable way of using operating system dependent functionality
from multiprocessing import cpu_count # Used to get the number of CPUs in the system

import tqdm # Instantly makes loops show a smart progress meter
from logzero import logger # Robust and effective logging for Python

# **Configuration**

In [4]:
class Config(object):
    """Plain container for the settings shared by the rest of the notebook.

    NOTE -- if you change the "validation_test_split" and/or "train_validation_split" values, your results
            will not be comparable with those from other users of this data set.
    """

    def __init__(self,
                 db_path,
                 checkpoint_dir,
                 runs = 5,
                 device = 'cuda:0',
                 validation_test_split = 1547279640.0,
                 train_validation_split = 1543542570.0,
                 batch_size = 8192):
        """Store the settings and make sure the checkpoint directory exists.

        :param db_path: path to the directory that contains the meta_db
        :param checkpoint_dir: path where to save the model training checkpoints to
        :param runs: how many times to run the model (training + evaluation)
        :param device: desired device, e.g. 'cuda:0' if a GPU is available, 'cpu' otherwise
        :param validation_test_split: timestamp that divides the validation data from the test data
        :param train_validation_split: timestamp that splits training data from validation data
        :param batch_size: Dataloader batch size (change as needed given memory/bus constraints)
        """
        # the directory tree is created up front; exist_ok makes this a no-op when it is already there
        os.makedirs(checkpoint_dir, exist_ok=True)

        settings = {
            'db_path': db_path,
            'checkpoint_dir': checkpoint_dir,
            'runs': runs,
            'device': device,
            'validation_test_split': validation_test_split,
            'train_validation_split': train_validation_split,
            'batch_size': batch_size,
        }

        # expose every setting as an attribute with the same name as the constructor argument
        for attribute_name, attribute_value in settings.items():
            setattr(self, attribute_name, attribute_value)

# instantiate the configuration object: both the dataset directory and the checkpoint directory
# are located under the mounted drive (drive_path); runs = 2 overrides the Config default of 5
config = Config(db_path = drive_path + "/MyDrive/thesis/Dataset/09-DEC-2020/processed-data",
                checkpoint_dir = drive_path + "/MyDrive/thesis/Checkpoints/",
                runs = 2)

# **Load Dataset**

In [5]:
class LMDBReader(object):
    """Read-only accessor for an LMDB (Lightning database) holding compressed, msgpack-encoded values.

    Calling the instance with a key returns the decoded (and optionally post-processed) value
    stored under that key, or None when the key is not present.
    """

    def __init__(self,
                 path,  # Location of lmdb database
                 postproc_func = None): # post processing function to apply to data points
        """Open the LMDB environment located at `path`.

        :param path: location of the lmdb database directory
        :param postproc_func: optional callable applied to every decoded data point before it is returned
        """
        # open the lmdb (lightning database) -> the result is an open lmdb environment
        self.env = lmdb.open(path, # Location of directory
                             readonly = True, # Disallow any write operations
                             # map_size is a byte count: pass a real integer, since a float (1e13)
                             # is not accepted as an integer argument by recent Python versions
                             map_size = int(1e13), # Maximum size database may grow to; used to size the memory mapping
                             max_readers = 1024)  # Maximum number of simultaneous read transactions

        # set self data post processing function
        self.postproc_func = postproc_func

    def __call__(self,
                 key):  # key (sha256) of the data point to retrieve
        """Fetch, decompress and decode the value stored under `key`.

        :param key: key (sha256 string) of the data point to retrieve; it is encoded as ascii for the lookup
        :return: the decoded data point (after `postproc_func`, if one was given) or None if the key is missing
        """
        # Execute a (read) transaction on the database
        with self.env.begin() as txn:
            x = txn.get(key.encode('ascii')) # Fetch the value matching key (encoded in ascii)

        # if no value was found matching key then return None
        if x is None:
            return None

        # otherwise decompress the (x) bytes, returning a bytes object containing
        # the uncompressed data (x) and unpack it (from msgpack's array) to Python's list
        x = msgpack.loads(zlib.decompress(x), strict_map_key=False)

        if self.postproc_func is not None: # if the data post processing function was defined
            x = self.postproc_func(x) # apply post processing function on the data point

        return x  # return the data point


def features_postproc_func(x):  # data point to apply the post processing function to
    """Apply a sign-preserving logarithmic squashing to the feature vector stored in x[0].

    Negative entries v become -log(1 - v), positive entries v become log(1 + v), zeros stay zero.

    :param x: data point; only its first element (the feature values) is used
    :return: numpy float32 array with the transformed features
    """
    features = np.asarray(x[0], dtype=np.float32) # feature values as a numpy array of float32

    # both masks are computed on the untransformed values, before anything is overwritten
    is_negative = features < 0
    is_positive = features > 0

    features[is_positive] = np.log(1 + features[is_positive])     # v > 0 -> log(1 + v)
    features[is_negative] = - np.log(1 - features[is_negative])   # v < 0 -> -log(1 - v)

    return features


def tags_postproc_func(x):  # data point to apply the post processing function to
    """Extract the label values of a data point as a numpy array.

    :param x: data point, a mapping whose b'labels' entry is itself a mapping of labels
    :return: numpy array with the values of x[b'labels'], in the mapping's iteration order
    """
    label_values = [value for value in x[b'labels'].values()]  # the labels, as a plain list
    return np.asarray(label_values)  # the same labels as a numpy array


class Dataset(data.Dataset):
    """Dataset backed by a sqlite3 meta database (labels, tags, counts) and an LMDB of features.

    The rows to use are selected from the `meta` table according to `mode`, using the timestamp
    splits stored in the module-level `config` object. Indexing the dataset returns the features
    read from the LMDB for the sample's sha256 together with a dictionary of the requested labels.
    """

    # list of malware tags
    tags = ["adware", "flooder", "ransomware", "dropper", "spyware", "packed",
            "crypto_miner", "file_infector", "installer", "worm", "downloader"]

    def __init__(self,
                 metadb_path, 	# path to the metadb (sqlite3 database containing index, labels, tags, and counts for the data)
                 features_lmdb_path,  # path to the features lmdb (database containing the data features)
                 return_malicious = True, # whether to return the malicious label for the data point or not
                 return_counts = True,  # whether to return the counts for the data point or not
                 return_tags = True,  # whether to return the tags for the data points or not
                 return_shas = False, # whether to return the sha256 of the data points or not
                 mode = 'train',  # mode of use of the dataset object (may be 'train', 'validation' or 'test')
                 binarize_tag_labels = True,  # whether to binarize or not the tag values
                 n_samples = None,  # maximum number of data points to consider (None if you want to consider them all)
                 remove_missing_features = True,  # whether to remove data points with missing features or not; it can be True/False/None/'scan'/filepath
                                                # in case it is 'scan' (or True) a scan will be performed on the database in order to remove the data points with missing features
                                                # in case it is a filepath then a file (in Json format) will be used to determine the data points with missing features
                 postprocess_function = features_postproc_func):  # post processing function to use on each data point
        """Select the samples for `mode` from the meta database and prepare the label lists.

        :raises ValueError: if `mode` is not one of 'train', 'validation', 'test'
        """
        # set some attributes
        self.return_counts = return_counts
        self.return_tags = return_tags
        self.return_malicious = return_malicious
        self.return_shas = return_shas

        # define a lmdb reader with the features lmdb path (LMDB directory with baseline features) and post processing function
        self.features_lmdb_reader = LMDBReader(features_lmdb_path,
                                               postproc_func = postprocess_function)

        # columns to retrieve from the meta table; "sha256" is always needed since it is the features key
        retrieve = ["sha256"]

        if return_malicious:
            retrieve.append("is_malware")

        if return_counts:
            retrieve.append("rl_ls_const_positives")

        if return_tags:
            retrieve.extend(Dataset.tags)

        # create SQL query
        query = 'select ' + ','.join(retrieve)
        query += " from meta"

        # if in training select all data points before train_validation_split timestamp
        if mode == 'train':
            query += ' where(rl_fs_t <= {})'.format(config.train_validation_split)

        # if in validation select all data points between two timestamps (train_validation_split and validation_test_split)
        elif mode == 'validation':
            query += ' where((rl_fs_t >= {}) and (rl_fs_t < {}))'.format(config.train_validation_split, config.validation_test_split)

        # if in test select all data points after validation_test_split timestamp
        elif mode == 'test':
            query += ' where(rl_fs_t >= {})'.format(config.validation_test_split)

        # else provide an error
        else:
            raise ValueError('invalid mode: {}'.format(mode))

        # log info
        logger.info('Opening Dataset at {} in {} mode.'.format(metadb_path, mode))

        # if n_samples is not None then limit the query to output a maximum of n_samples rows
        if n_samples is not None:
            query += ' limit {}'.format(n_samples)

        # the connection is opened only once the query is known to be valid and is
        # always closed, even when executing the query fails
        conn = sqlite3.connect(metadb_path)
        try:
            vals = conn.cursor().execute(query).fetchall()  # execute the SQL query and fetch all results as a list
        finally:
            conn.close()

        # log info
        logger.info(f"{len(vals)} samples loaded.")

        # map the items we're retrieving to an index (e.g. {'sha256': 0, 'is_malware': 1, ...})
        retrieve_ind = {column: position for position, column in enumerate(retrieve)}
        sha_col = retrieve_ind['sha256']

        # True is not a usable file path (open(True) would open file descriptor 1):
        # interpret it as "remove the missing samples by scanning the features database"
        if remove_missing_features is True:
            remove_missing_features = 'scan'

        if remove_missing_features == 'scan': # if remove_missing_features is equal to the keyword 'scan'
            # log info
            logger.info("Removing samples with missing features...")
            logger.info("Checking dataset for keys with missing features.")

            # reuse the environment already opened by the features reader: opening the same LMDB
            # path a second time in the same process is unsafe and the extra handle was never closed
            with self.features_lmdb_reader.env.begin() as txn:
                # keep only the samples (from metadb) whose sha256 has an entry in the features lmdb
                kept = [item
                        for item in tqdm.tqdm(vals, # Iterable to decorate with a progressbar
                                              total = len(vals), # The number of expected iterations
                                              mininterval = .5,  # Minimum progress display update interval seconds
                                              smoothing = 0.)  # Exponential moving average smoothing factor for speed estimates
                        if txn.get(item[sha_col].encode('ascii')) is not None]

            n_removed = len(vals) - len(kept)
            vals = kept

            # log info
            logger.info(f"{n_removed} samples had no associated feature and were removed.")
            logger.info(f"Dataset now has {len(vals)} samples.")

        elif (remove_missing_features is False) or (remove_missing_features is None):
            pass  # no filtering requested

        else:
            # assume remove_missing_features is a filepath

            # log info
            logger.info(f"Trying to load shas to ignore from {remove_missing_features}...")

            # open file in read mode
            with open(remove_missing_features, 'r') as f:
                shas_to_remove = set(json.load(f))  # Json list of sha256 -> set for fast membership tests

            # remove from vals all the items whose sha256 is in the shas_to_remove set
            vals = [value for value in vals if value[sha_col] not in shas_to_remove]

            #log info
            logger.info(f"Dataset now has {len(vals)} samples.")

        # create a list of keys (sha256) from vals
        self.keylist = [row[sha_col] for row in vals]

        if self.return_malicious:
            # create a list of labels from vals
            malware_col = retrieve_ind['is_malware']
            self.labels = [row[malware_col] for row in vals]

        if self.return_counts:
            # retrieve the list of counts from vals
            count_col = retrieve_ind['rl_ls_const_positives']
            self.count_labels = [row[count_col] for row in vals]

        if self.return_tags:
            # one list of values per tag, converted to a numpy array and transposed (.T)
            # so that each row holds the tag values of one sample
            self.tag_labels = np.asarray([[row[retrieve_ind[t]] for row in vals] for t in Dataset.tags]).T

            if binarize_tag_labels:
                # binarize the tag labels -> if the tag is different from 0 then it is set 1, otherwise it is set to 0
                self.tag_labels = (self.tag_labels != 0).astype(int)

    def __len__(self):
        """Return the total number of samples."""
        return len(self.keylist)

    def __getitem__(self,
                    index): # index of the item to get
        """Return (features, labels) for the sample at `index`, preceded by its sha256 if `return_shas` is set."""
        labels = {} # initialize labels dictionary for this particular sample
        key = self.keylist[index] # get sha256 key associated to this index
        features = self.features_lmdb_reader(key) # get feature vector associated to this sample sha256

        if self.return_malicious:
            labels['malware'] = self.labels[index]  # get malware label for this sample through the index

        if self.return_counts:
            labels['count'] = self.count_labels[index]  # get count for this sample through the index

        if self.return_tags:
            labels['tags'] = self.tag_labels[index] # get tags list for this sample through the index

        if self.return_shas:
            return key, features, labels  # return sha256, features and labels associated to the sample with index 'index'
        else:
            return features, labels # return features and labels associated to the sample with index 'index'


# **Define Generator (Dataloader)**

In [6]:
# default number of DataLoader worker subprocesses: the number of CPUs reported by the current system
max_workers = cpu_count()


class GeneratorFactory(object):
    """Build a Dataset for the requested mode and wrap it in a torch DataLoader.

    Calling the instance returns the DataLoader that was created by the constructor.
    """

    def __init__(self,
                 ds_root, # path of the directory where to find the meta.db and ember_features files
                 batch_size = None, # how many samples per batch to load
                 mode = 'train',  # mode of use of the dataset object (may be 'train', 'validation' or 'test')
                 num_workers = max_workers, # how many subprocesses to use for data loading by the Dataloader
                 use_malicious_labels = False,  # whether to return the malicious label for the data points or not
                 use_count_labels = False,  # whether to return the counts for the data points or not
                 use_tag_labels = False,  # whether to return the tags for the data points or not
                 return_shas = False, # whether to return the sha256 of the data points or not
                 features_lmdb = 'ember_features',  # name of the file containing the ember_features for the data
                 remove_missing_features = 'scan',  # whether to remove data points with missing features or not; it can be False/None/'scan'/filepath
                                                  # in case it is 'scan' a scan will be performed on the database in order to remove the data points with missing features
                                                  # in case it is a filepath then a file (in Json format) will be used to determine the data points with missing features
                 shuffle = None): # set to True to have the data reshuffled at every epoch
        """Create the Dataset and its DataLoader.

        :raises ValueError: if `mode` is not 'train', 'validation' or 'test', or if `shuffle`
            is neither None nor one of the booleans True / False
        """
        # reject unexpected modes before doing any work
        if mode not in {'train', 'validation', 'test'}:
            raise ValueError('invalid mode {}'.format(mode))

        # locations of the two databases inside the dataset root directory
        metadb_location = os.path.join(ds_root, 'meta.db')
        features_location = os.path.join(ds_root, features_lmdb)

        # Dataset object pointing to the dataset databases (meta.db and ember_features)
        dataset = Dataset(metadb_path = metadb_location,
                          features_lmdb_path = features_location,
                          return_malicious = use_malicious_labels,
                          return_counts = use_count_labels,
                          return_tags = use_tag_labels,
                          return_shas = return_shas,
                          mode = mode,
                          remove_missing_features = remove_missing_features)

        if shuffle is None:
            # no explicit choice: shuffle only when training
            shuffle = (mode == 'train')
        elif shuffle is not True and shuffle is not False:
            # an explicit choice has to be exactly one of the two booleans
            raise ValueError(f"'shuffle' should be either True or False, got {shuffle}")

        # Dataloader over the dataset; the batch size falls back to 1024 when it was not given
        self.generator = data.DataLoader(dataset,
                                         batch_size = 1024 if batch_size is None else batch_size,
                                         shuffle = shuffle,
                                         num_workers = num_workers)

    def __call__(self):
        """Return the DataLoader built by the constructor."""
        return self.generator


def get_generator(mode, # mode of use of the dataset object (may be 'train', 'validation' or 'test')
                  path = config.db_path,  # path of the directory where to find the meta.db and ember_features files
                  use_malicious_labels = True,  # whether to return the malicious label for the data points or not
                  use_count_labels = True,  # whether to return the counts for the data points or not
                  use_tag_labels = True,  # whether to return the tags for the data points or not
                  batch_size = config.batch_size, # how many samples per batch to load
                  return_shas = False,  # whether to return the sha256 of the data points or not
                  remove_missing_features = 'scan', # whether to remove data points with missing features or not; it can be False/None/'scan'/filepath
                                                  # in case it is 'scan' a scan will be performed on the database in order to remove the data points with missing features
                                                  # in case it is a filepath then a file (in Json format) will be used to determine the data points with missing features
                  num_workers = None, # how many subprocesses to use for data loading by the Dataloader
                  shuffle = None, # set to True to have the data reshuffled at every epoch
                  feature_lmdb = 'ember_features'): # name of the file containing the ember_features for the data
    """Convenience wrapper: build a GeneratorFactory with the given options and return its DataLoader.

    Note that the defaults of `path` and `batch_size` are read from `config` once, when this
    function is defined.
    """
    # when no worker count is given, fall back to max_workers (the current system cpu count)
    if num_workers is None:
        worker_count = max_workers
    else:
        worker_count = num_workers

    factory = GeneratorFactory(path,
                               batch_size = batch_size,
                               mode = mode,
                               num_workers = worker_count,
                               use_malicious_labels = use_malicious_labels,
                               use_count_labels = use_count_labels,
                               use_tag_labels = use_tag_labels,
                               return_shas = return_shas,
                               remove_missing_features = remove_missing_features,
                               shuffle = shuffle,
                               features_lmdb = feature_lmdb)

    # calling the factory returns the Generator (a.k.a. Dataloader)
    return factory()

# **Define Network**

In [7]:
class PENetwork(nn.Module):
    """
    This is a simple network loosely based on the one used in ALOHA: Auxiliary Loss Optimization for Hypothesis Augmentation (https://arxiv.org/abs/1903.05700)
    Note that it uses fewer (and smaller) layers, as well as a single layer for all tag predictions, performance will suffer accordingly.
    """
    def __init__(self,
                 use_malware = True,  # whether to use the malicious label for the data points or not
                 use_counts = True, # whether to use the counts for the data points or not
                 use_tags = True, # whether to use the tags for the data points or not
                 n_tags = None, # number of tags to predict
                 feature_dimension = 1024,  # dimension of the input data feature vector
                 layer_sizes = None): # layer sizes (array of sizes)
        """Build the shared base and the output heads.

        :param use_malware: whether forward() returns the malware/benign prediction
        :param use_counts: whether forward() returns the count prediction
        :param use_tags: whether forward() returns the tag predictions (requires n_tags)
        :param n_tags: number of tags to predict; the tag head is built only when this is given
        :param feature_dimension: dimension of the input data feature vector
        :param layer_sizes: sizes of the base layers; defaults to [512, 512, 128]
        :raises ValueError: if use_tags is set but n_tags is None
        """
        # if we set to use tags but n_tags was None raise an exception
        if use_tags and n_tags is None:
            raise ValueError("n_tags was None but we're trying to predict tags. Please include n_tags")

        # initialize nn.Module before any attribute is assigned on the instance
        super().__init__()

        # set some attributes
        self.use_malware = use_malware
        self.use_counts = use_counts
        self.use_tags = use_tags
        self.n_tags = n_tags

        # set dropout probability
        p = 0.05

        # if layer_sizes was not defined (it is None) then initialize it to a default of [512, 512, 128]
        if layer_sizes is None:
            layer_sizes = [512, 512, 128]

        layers = [] # initialize layers array
        in_features = feature_dimension  # input size of the next Linear layer

        # for each layer size in layer_sizes
        for ls in layer_sizes:
            layers.append(nn.Linear(in_features, ls)) # Linear Layer with dimensions (previous size) x ls
            layers.append(nn.LayerNorm(ls)) # append a Norm layer of size ls
            layers.append(nn.ELU()) # append an ELU activation function module
            layers.append(nn.Dropout(p))  # append a dropout layer with probability of dropout p
            in_features = ls

        # sequential container with all the base layers -> this will be the model base
        self.model_base = nn.Sequential(*layers)

        # create malware/benign labeling head
        self.malware_head = nn.Sequential(nn.Linear(layer_sizes[-1], 1),  # Linear Layer with size layer_sizes[-1] x 1
                                          nn.Sigmoid()) # sigmoid activation function module

        # create count poisson regression head
        self.count_head = nn.Linear(layer_sizes[-1], 1) # Linear Layer with size layer_sizes[-1] x 1

        # sigmoid activation function
        self.sigmoid = nn.Sigmoid()

        # create a tag multi-label classifying head; it needs n_tags as its output size, so it
        # can only be built when n_tags was provided (n_tags is None only if use_tags is False).
        # When n_tags is given the head is always created, so the set of parameters does not
        # depend on use_tags.
        if n_tags is not None:
            self.tag_head = nn.Sequential(nn.Linear(layer_sizes[-1], 64),  # Linear Layer with size layer_sizes[-1] x 64
                                          nn.ELU(), # ELU activation function module
                                          nn.Linear(64, 64),  # Linear Layer with size 64 x 64
                                          nn.ELU(), # ELU activation function module
                                          nn.Linear(64, n_tags), # Linear Layer with size 64 x n_tags
                                          nn.Sigmoid()) # sigmoid activation function module
        else:
            self.tag_head = None

    def forward(self,
                data):  # current batch of data (features)
        """Run the batch through the base and return a dictionary with the outputs of the enabled heads.

        :param data: current batch of data (features)
        :return: dictionary with the keys 'malware', 'count' and/or 'tags', depending on the enabled heads
        """
        rv = {} # initialize return value

        # call the module (rather than its .forward) so that registered hooks are honoured
        base_result = self.model_base(data) # get base result forwarding the data through the base model

        if self.use_malware:
            rv['malware'] = self.malware_head(base_result)  # add the result of the malware head

        if self.use_counts:
            rv['count'] = self.count_head(base_result)  # add the result of the count head

        if self.use_tags:
            rv['tags'] = self.tag_head(base_result) # add the result of the tag head

        return rv # return the return value

# **Train Network**

## **Define training function**

In [8]:
def compute_loss(predictions, # a dictionary of results from a PENetwork model
                 labels,  # a dictionary of labels
                 loss_wts = {'malware': 1.0,
                             'count': 0.1,
                             'tags': 0.1}): # weights to assign to each head of the network (if it exists)
    """
    Compute losses for a malware feed-forward neural network (optionally with SMART tags 
    and vendor detection count auxiliary losses).
    :param predictions: a dictionary of results from a PENetwork model
    :param labels: a dictionary of labels 
    :param loss_wts: weights to assign to each head of the network (if it exists); defaults to 
        values used in the ALOHA paper (1.0 for malware, 0.1 for count and each tag).
        A head that has no entry in loss_wts gets a weight of 1.0. The dictionary is only read.
    :return: dictionary with the (unweighted) loss value of every head found in labels, as a python
        float, plus the weighted sum under 'total' (a tensor, or the float 0. when labels is empty)
    """
    loss_dict = {'total':0.}  # initialize dictionary of losses

    if 'malware' in labels: # if the malware head is enabled
        malware_predictions = predictions['malware']

        # extract ground truth malware label, convert it to float and move it to the device
        # that holds the predictions it is going to be compared with
        malware_labels = labels['malware'].float().to(malware_predictions.device)

        # reshape the predicted malware label to the same shape of malware_labels
        # then calculate binary cross entropy loss with respect to the ground truth malware labels
        malware_loss = F.binary_cross_entropy(malware_predictions.reshape(malware_labels.shape),
                                              malware_labels)

        # .item() already yields an independent python float
        loss_dict['malware'] = malware_loss.item()

        # update total loss with the weighted loss (weight defaults to 1.0 if not provided)
        loss_dict['total'] += malware_loss * loss_wts.get('malware', 1.0)

    if 'count' in labels: # if the count head is enabled
        count_predictions = predictions['count']

        # extract ground truth count, convert it to float and move it to the predictions' device
        count_labels = labels['count'].float().to(count_predictions.device)

        # reshape the predicted count to the same shape of count_labels
        # then calculate poisson loss with respect to the ground truth count
        count_loss = torch.nn.PoissonNLLLoss()(count_predictions.reshape(count_labels.shape),
                                               count_labels)

        # store the calculated count loss into the loss dictionary
        loss_dict['count'] = count_loss.item()

        # update total loss with the weighted loss (weight defaults to 1.0 if not provided)
        loss_dict['total'] += count_loss * loss_wts.get('count', 1.0)

    if 'tags' in labels:  # if the tags head is enabled
        tag_predictions = predictions['tags']

        # extract ground truth tags, convert them to float and move them to the predictions' device
        tag_labels = labels['tags'].float().to(tag_predictions.device)

        # calculate binary cross entropy loss of the predicted tags with respect to the ground truth tags
        tags_loss = F.binary_cross_entropy(tag_predictions,
                                           tag_labels)

        # store the calculated tags loss into the loss dictionary
        loss_dict['tags'] = tags_loss.item()

        # update total loss with the weighted loss (weight defaults to 1.0 if not provided)
        loss_dict['total'] += tags_loss * loss_wts.get('tags', 1.0)

    return loss_dict  # return the losses


def _record_and_format_losses(loss_dict, loss_histories):
    """
    Append the losses of the current batch to the running histories and build the progress string.

    :param loss_dict: losses of the current batch as returned by compute_loss (a 'total' entry plus one entry
        for each head whose loss was computed)
    :param loss_histories: dictionary-like object mapping each loss name to the list of values seen so far in the
        current epoch; it is updated in place
    :return: string containing the current value of each loss followed by the running mean of each loss
    """
    # float() turns the 'total' loss tensor into a plain python number (so that no tensor is kept alive in the
    # history) and leaves the per-head losses, which compute_loss already stores as python floats, unchanged
    current_losses = {name: float(value) for name, value in loss_dict.items()}

    # append each current loss to its history
    for name, value in current_losses.items():
        loss_histories[name].append(value)

    # create loss string with the current losses and the running mean of each loss
    loss_str = " ".join([f"{name} loss:{value:7.3f}" for name, value in current_losses.items()])
    loss_str += " | "
    loss_str += " ".join([f"{name} mean:{np.mean(values):7.3f}" for name, values in loss_histories.items()])
    return loss_str


def train_network(train_db_path = config.db_path, # Path in which the meta.db is stored
                  checkpoint_dir = config.checkpoint_dir, # Directory in which to save model checkpoints
                  max_epochs = 10,  # How many epochs to train for
                  use_malicious_labels = True,  # Whether or not to use malware/benignware labels as a target
                  use_count_labels = True,  # Whether or not to use the counts as an additional target
                  use_tag_labels = True,  # Whether or not to use SMART tags as additional targets
                  feature_dimension = 2381, # The input dimension of the model
                  random_seed = None, # if provided, seed random number generation with this value (defaults None, no seeding)
                  workers = None, # How many worker processes should the dataloader use (if None use multiprocessing.cpu_count())
                  remove_missing_features = 'scan'):  # Strategy for removing missing samples (see docstring)
    """
    Train a feed-forward neural network on EMBER 2.0 features, optionally with additional targets as
    described in the ALOHA paper (https://arxiv.org/abs/1903.05700).  SMART tags based on
    (https://arxiv.org/abs/1905.06262)

    One checkpoint ("epoch_<n>.pt") is written to `checkpoint_dir` at the end of every training epoch, after
    which the losses on the validation set are computed and displayed.

    :param train_db_path: Path in which the meta.db is stored; defaults to the value specified in `config.py`
    :param checkpoint_dir: Directory in which to save model checkpoints; WARNING -- this will overwrite any existing checkpoints without warning.
    :param max_epochs: How many epochs to train for; defaults to 10
    :param use_malicious_labels: Whether or not to use malware/benignware labels as a target; defaults to True
    :param use_count_labels: Whether or not to use the counts as an additional target; defaults to True
    :param use_tag_labels: Whether or not to use SMART tags as additional targets; defaults to True
    :param feature_dimension: The input dimension of the model; defaults to 2381 (EMBER 2.0 feature size)
    :param random_seed: if provided, seed random number generation with this value (defaults None, no seeding)
    :param workers: How many worker processes should the dataloader use (default None, use multiprocessing.cpu_count())
    :param remove_missing_features: Strategy for removing missing samples, with meta.db entries but no associated features,
        from the data (e.g. feature extraction failures).
        Must be one of: 'scan', 'none', or path to a missing keys file.
        Setting to 'scan' (default) will check all entries in the LMDB and remove any keys that are missing -- safe but slow.
        Setting to 'none' will not perform a check, but may lead to a run failure if any features are missing.  Setting to
        a path will attempt to load a json-serialized list of SHA256 values from the specified file, indicating which
        keys are missing and should be removed from the dataloader.
    """
    # if workers has a value (it is not None) then convert it to int
    workers = workers if workers is None else int(workers)

    # create checkpoint directory (no shell involved, so paths containing spaces or shell metacharacters are safe)
    os.makedirs(checkpoint_dir, exist_ok=True)

    if random_seed is not None: # if a seed was provided
        # log info
        logger.info(f"Setting random seed to {int(random_seed)}.")
        # set the seed for generating random numbers
        torch.manual_seed(int(random_seed))

    # log info
    logger.info('...instantiating network')

    # create malware-NN model and allocate it to the selected device (CPU or GPU)
    # NOTE: the network is always built with all three heads; the use_* flags are only forwarded to the data
    # generators (presumably they control which labels are yielded, and therefore which losses compute_loss
    # adds to the total -- verify against get_generator)
    model = PENetwork(use_malware = True,
                      use_counts = True,
                      use_tags = True,
                      n_tags = len(Dataset.tags), # get n_tags counting tags from the dataset
                      feature_dimension = feature_dimension).to(config.device)

    # use Adam optimizer on all the model parameters
    opt = torch.optim.Adam(model.parameters())

    # create generator (a.k.a. Dataloader)
    generator = get_generator(path = train_db_path,
                              mode = 'train', # select train mode
                              use_malicious_labels = use_malicious_labels,
                              use_count_labels = use_count_labels,
                              use_tag_labels = use_tag_labels,
                              num_workers = workers,
                              remove_missing_features = remove_missing_features)

    # create validation generator (a.k.a. validation Dataloader)
    val_generator = get_generator(path = train_db_path,
                                  mode = 'validation',  # select validation mode
                                  use_malicious_labels = use_malicious_labels,
                                  use_count_labels = use_count_labels,
                                  use_tag_labels = use_tag_labels,
                                  num_workers = workers,
                                  remove_missing_features = remove_missing_features)

    # get number of steps per epoch (# of total batches) from generator
    steps_per_epoch = len(generator)
    # get number of validation steps per epoch (# of total validation batches) from validation generator
    val_steps_per_epoch = len(val_generator)

    # loop for the selected number of epochs
    for epoch in range(1, max_epochs + 1):
        # instantiate a new dictionary-like object holding the training losses of this epoch
        loss_histories = defaultdict(list)
        # set the model mode to 'train'
        model.train()

        # for all the training batches
        for i, (features, labels) in enumerate(generator):
            opt.zero_grad() # clear old gradients from the last step

            # copy current features and allocate them on the selected device (CPU or GPU)
            features = deepcopy(features).to(config.device)

            # perform a forward pass through the network
            out = model(features)

            # compute loss given the predicted output from the model
            loss_dict = compute_loss(out,
                                     deepcopy(labels))  # copy the ground truth labels

            # compute gradients of the total loss
            loss_dict['total'].backward()

            # update model parameters
            opt.step()

            # record the current losses and create the loss string
            loss_str = _record_and_format_losses(loss_dict, loss_histories)
            # write on standard out the loss string + other information
            sys.stdout.write('\r Epoch: {}/{} {}/{} '.format(epoch, max_epochs, i + 1, steps_per_epoch) + loss_str)
            # flush standard output
            sys.stdout.flush()
            del features, labels # to avoid weird references that lead to generator errors

        # save model in checkpoint dir
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, "epoch_{}.pt".format(str(epoch))))
        print()

        # instantiate a new dictionary-like object holding the validation losses of this epoch
        loss_histories = defaultdict(list)
        # set the model mode to 'eval'
        model.eval()

        # for all the validation batches
        for i, (features, labels) in enumerate(val_generator):
            # copy current features and allocate them on the selected device (CPU or GPU)
            features = deepcopy(features).to(config.device)

            # disable gradient calculation for both the forward pass and the loss:
            # nothing is back-propagated during validation
            with torch.no_grad():
                # perform a forward pass through the network
                out = model(features)

                # compute loss given the predicted output from the model
                loss_dict = compute_loss(out,
                                         deepcopy(labels)) # copy the ground truth labels

            # record the current losses and create the loss string
            loss_str = _record_and_format_losses(loss_dict, loss_histories)
            # write on standard out the loss string + other information
            sys.stdout.write('\r   Val: {}/{} {}/{} '.format(epoch, max_epochs, i + 1, val_steps_per_epoch) + loss_str)
            # flush standard output
            sys.stdout.flush()
            del features, labels # to avoid weird references that lead to generator errors
        print()
    print('...done')


## **Start training**

In [None]:
# json file listing the sha256 keys of the samples without features; passing its path
# to train_network makes it load this list instead of scanning the whole LMDB
missing_shas_path = drive_path + "/MyDrive/thesis/Dataset/09-DEC-2020/processed-data/shas_missing_ember_features.json"

# train one network per configured run, each one saving its checkpoints in its own sub-directory
for i in range(config.runs):
    train_network(checkpoint_dir = config.checkpoint_dir + "/" + str(i),
                  remove_missing_features = missing_shas_path)

[I 210330 17:21:24 <ipython-input-8-faa596046a07>:122] ...instantiating network
[I 210330 17:21:33 <ipython-input-5-eb915f5a92e0>:118] Opening Dataset at /content/drive/MyDrive/thesis/Dataset/09-DEC-2020/processed-data/meta.db in train mode.
[I 210330 17:22:30 <ipython-input-5-eb915f5a92e0>:128] 12908755 samples loaded.
[I 210330 17:22:30 <ipython-input-5-eb915f5a92e0>:176] Trying to load shas to ignore from /content/drive/MyDrive/thesis/Dataset/09-DEC-2020/processed-data/shas_missing_ember_features.json...
[I 210330 17:22:34 <ipython-input-5-eb915f5a92e0>:187] Dataset now has 12699013 samples.
[I 210330 17:23:06 <ipython-input-5-eb915f5a92e0>:118] Opening Dataset at /content/drive/MyDrive/thesis/Dataset/09-DEC-2020/processed-data/meta.db in validation mode.
[I 210330 17:23:17 <ipython-input-5-eb915f5a92e0>:128] 2544786 samples loaded.
[I 210330 17:23:17 <ipython-input-5-eb915f5a92e0>:176] Trying to load shas to ignore from /content/drive/MyDrive/thesis/Dataset/09-DEC-2020/processed-da

# **Evaluate Network**

## **Define evaluation function**

In [None]:
# get tags from the dataset; normalize_results iterates over them, pairing the tag at
# position n with column n of the 'tags' labels/predictions to name the result columns
all_tags = Dataset.tags


def detach_and_copy_array(array):
    """
    Utility function returning an independent, flattened numpy copy of the provided array.

    :param array: a torch.Tensor or a np.ndarray
    :return: 1-d np.ndarray holding a (deep) copy of the data of `array`
    :raises ValueError: if `array` is neither a torch.Tensor nor a np.ndarray
    """
    if isinstance(array, torch.Tensor):
        # tensors are first moved to the cpu, detached and converted to numpy
        source = array.cpu().detach().numpy()
    elif isinstance(array, np.ndarray):
        # numpy arrays are used as they are
        source = array
    else:
        # any other type is not supported
        raise ValueError("Got array of unknown type {}".format(type(array)))

    # copy the data and then flatten the copy
    return deepcopy(source).ravel()


def normalize_results(labels_dict,
                      results_dict,
                      use_malware=True,
                      use_count=True,
                      use_tags=True):
    """
    Flatten a (ground truth) labels dictionary and a results (predictions) dictionary into a single
    dictionary of 1-d arrays, keyed by column names, that pandas can convert to a DataFrame.

    :param labels_dict: (ground truth) labels dictionary
    :param results_dict: results (predicted labels) dictionary
    :param use_malware: whether to output the 'label_malware'/'pred_malware' columns; defaults to True
    :param use_count: whether to output the 'label_count'/'pred_count' columns; defaults to True
    :param use_tags: whether to output one 'label_<tag>_tag'/'pred_<tag>_tag' pair of columns for
        each tag in all_tags; defaults to True
    :return: dictionary mapping each column name to a 1-d numpy array
    """
    # every array is (deep) copied through detach_and_copy_array to avoid a FD "leak" in the dataset generator
    # see here: https://github.com/pytorch/pytorch/issues/973#issuecomment-459398189

    normalized = {}

    # malware/benign target: ground truth first, then prediction
    if use_malware:
        normalized.update(label_malware=detach_and_copy_array(labels_dict['malware']),
                          pred_malware=detach_and_copy_array(results_dict['malware']))

    # count additional target: ground truth first, then prediction
    if use_count:
        normalized.update(label_count=detach_and_copy_array(labels_dict['count']),
                          pred_count=detach_and_copy_array(results_dict['count']))

    # SMART tags additional targets: the n-th tag of all_tags is read from the n-th column
    if use_tags:
        for position, tag_name in enumerate(all_tags):
            truth_column = labels_dict['tags'][:, position]
            normalized[f'label_{tag_name}_tag'] = detach_and_copy_array(truth_column)
            predicted_column = results_dict['tags'][:, position]
            normalized[f'pred_{tag_name}_tag'] = detach_and_copy_array(predicted_column)

    return normalized


def evaluate_network(results_dir, # The directory to which to write the 'results.csv' file
                     checkpoint_file, # The checkpoint file containing the weights to evaluate
                     db_path = config.db_path,  # The path to the directory containing the meta.db file
                     evaluate_malware = True, # Whether or not to record malware labels and predictions
                     evaluate_count = True, # Whether or not to record count labels and predictions
                     evaluate_tags = True,  # Whether or not to record individual tag labels and predictions
                     remove_missing_features = 'scan'): # Strategy for removing missing samples (see docstring)
    """
    Take a trained feedforward neural network model and output evaluation results to a csv in the specified location.

    The network is run in evaluation mode with gradient calculation disabled, and the results of each test batch
    are appended to '<results_dir>/results.csv' (one row per sample, indexed by sha256).

    :param results_dir: The directory to which to write the 'results.csv' file; WARNING -- this will overwrite any
        existing results in that location
    :param checkpoint_file: The checkpoint file containing the weights to evaluate
    :param db_path: the path to the directory containing the meta.db file; defaults to the value in config.py
    :param evaluate_malware: defaults to True; whether or not to record malware labels and predictions
    :param evaluate_count: defaults to True; whether or not to record count labels and predictions
    :param evaluate_tags: defaults to True; whether or not to record individual tag labels and predictions
    :param remove_missing_features: See help for remove_missing_features in train.py / train_network
    """

    # create result directory (no shell involved, so paths containing spaces or shell metacharacters are safe)
    os.makedirs(results_dir, exist_ok=True)

    # create malware-NN model
    model = PENetwork(use_malware = True,
                      use_counts = True,
                      use_tags = True,
                      n_tags = len(Dataset.tags), # get n_tags counting tags from the dataset
                      feature_dimension = 2381)

    # load model parameters from checkpoint; map_location remaps the stored tensors to the selected device,
    # so that a checkpoint saved on GPU can also be loaded on a CPU-only runtime
    model.load_state_dict(torch.load(checkpoint_file, map_location = config.device))

    # allocate model to selected device (CPU or GPU)
    model.to(config.device)

    # set the model mode to 'eval' so that train-only behaviours of its layers (e.g. dropout) are disabled
    model.eval()

    # create test generator (a.k.a. test Dataloader)
    generator = get_generator(mode = 'test', # select test mode
                              path = db_path,
                              use_malicious_labels = evaluate_malware,
                              use_count_labels = evaluate_count,
                              use_tag_labels = evaluate_tags,
                              return_shas = True, # return sha256 keys
                              remove_missing_features = remove_missing_features)

    #log info
    logger.info('...running network evaluation')

    # create and open the results file in write mode (it is closed even if the evaluation fails midway)
    with open(os.path.join(results_dir,'results.csv'),'w') as f:
        first_batch = True
        # for all the batches in the generator (Dataloader)
        for shas, features, labels in tqdm.tqdm(generator):
            features = features.to(config.device)  # transfer features to selected device

            with torch.no_grad(): # disable gradient calculation: nothing is back-propagated during evaluation
                predictions = model(features) # perform a forward pass through the network and get predictions

            # normalize the results, recording only the targets selected through the evaluate_* flags
            results = normalize_results(labels,
                                        predictions,
                                        use_malware = evaluate_malware,
                                        use_count = evaluate_count,
                                        use_tags = evaluate_tags)

            # store results into a pandas dataframe (indexed by the sha256 keys)
            # and then save it as csv into file f (inserting the header only if this is the first batch in the loop)
            pd.DataFrame(results, index=shas).to_csv(f, header=first_batch)

            first_batch = False
    print('...done')



## **Start evaluation**

In [None]:
# dictionary mapping each run ID to the path of the results.csv produced for that run (used for plotting results)
results_files = {}

# evaluate the final checkpoint of every configured run
for i in range(config.runs):
    # directory in which the results of the current run are saved
    run_results_dir = drive_path + "/MyDrive/thesis/Results/" + str(i)

    # register the results file of the current run
    results_files["run_id_" + str(i)] = run_results_dir + "/results.csv"

    # evaluate network removing missing features using the 'shas_missing_ember_features.json' file;
    # "epoch_10.pt" is the checkpoint written after the last epoch when train_network runs with its default max_epochs (10)
    evaluate_network(
        results_dir = run_results_dir,
        checkpoint_file = config.checkpoint_dir + "/" + str(i) + "/epoch_10.pt",
        remove_missing_features = drive_path + "/MyDrive/thesis/Dataset/09-DEC-2020/processed-data/shas_missing_ember_features.json")

# save the results_files dictionary as a json file (results.json), opened in write mode
with open(drive_path + "/MyDrive/thesis/Results/results.json", "w") as output_file:
    json.dump(results_files, output_file)

# **Plot Results**

## **Define plotting functions**

In [None]:
# default tags: the three lists below are parallel, i.e. the n-th color and
# the n-th linestyle are the ones used to draw the curve of the n-th tag
default_tags = [
    'adware_tag', 'flooder_tag', 'ransomware_tag',
    'dropper_tag', 'spyware_tag',
    'packed_tag', 'crypto_miner_tag',
    'file_infector_tag', 'installer_tag',
    'worm_tag', 'downloader_tag',
]

# default tag colors to be used in the graph (same order as default_tags)
default_tag_colors = [
    'r', 'r', 'r',
    'g', 'g',
    'b', 'b',
    'm', 'm',
    'c', 'c',
]

# default tag linestyles to be used in the graph (same order as default_tags)
default_tag_linestyles = [
    ':', '--', '-.',
    ':', '--',
    ':', '--',
    ':', '--',
    ':', '--',
]

# "style" dictionary mapping each tag to its (color, linestyle) pair (e.g. {'adware_tag': ('r', ':'), ..})
style_dict = dict(zip(default_tags, zip(default_tag_colors, default_tag_linestyles)))

# style information for label 'malware'
style_dict['malware'] = ('k', '-')


def collect_dataframes(run_id_to_filename_dictionary):
    """
    Load the results file of each run into a pandas DataFrame.

    :param run_id_to_filename_dictionary: run ID - filename dictionary; each filename is the path
        of a comma-separated values (csv) file
    :return: run ID - DataFrame dictionary (same keys, in the same order, as the input dictionary)
    """
    return {run_id: pd.read_csv(csv_path)
            for run_id, csv_path in run_id_to_filename_dictionary.items()}


def get_tprs_at_fpr(result_dataframe,
                    key,
                    target_fprs = None):
    """
    Estimate, by linear interpolation of the ROC curve, the True Positive Rate of a
    dataframe/key combination at specific False Positive Rates of interest.

    :param result_dataframe: a pandas dataframe with the results of a certain run
    :param key: the name of the result to get the curve for; if (e.g.) the key 'malware' is provided
        the dataframe is expected to have columns named `pred_malware` and `label_malware`
    :param target_fprs: the FPRs at which to estimate the TPRs; either a 1-d numpy array or
        None (uses the default np.array([1e-5, 1e-4, 1e-3, 1e-2, 1e-1]))
    :return: target_fprs, the corresponding TPRs
    """
    # fall back to the default FPRs of interest when none were provided
    fprs_of_interest = np.array([1e-5, 1e-4, 1e-3, 1e-2, 1e-1]) if target_fprs is None else target_fprs

    # ROC curve of the requested result (the thresholds are not needed here)
    roc_fprs, roc_tprs, _ = get_roc_curve(result_dataframe, key)

    # interpolate the ROC curve (tpr as a function of fpr) at the FPRs of interest
    tprs_of_interest = np.interp(fprs_of_interest, roc_fprs, roc_tprs)
    return fprs_of_interest, tprs_of_interest


def get_roc_curve(result_dataframe,
                  key):
    """
    Compute the ROC curve of a single result stored in a dataframe.

    :param result_dataframe: a dataframe with the results of a certain run
    :param key: the name of the result to get the curve for; if (e.g.) the key 'malware' is provided
        the dataframe is expected to have columns named `pred_malware` and `label_malware`
    :return: false positive rates, true positive rates, and thresholds (all np.arrays)
    """
    # names of the columns holding the ground truth labels and the predictions of the requested result
    label_column = 'label_{}'.format(key)
    prediction_column = 'pred_{}'.format(key)

    # ROC curve calculated from the labels and the predictions
    return roc_curve(result_dataframe[label_column], result_dataframe[prediction_column])


def get_auc_score(result_dataframe,
                  key):
    """
    Compute the Area Under the (ROC) Curve of a single result stored in a dataframe.

    :param result_dataframe: a dataframe with the results of a certain run
    :param key: the name of the result to get the score for; if (e.g.) the key 'malware' is provided
        the dataframe is expected to have columns named `pred_malware` and `label_malware`
    :return: the AUC for the ROC generated for the provided key
    """
    # names of the columns holding the ground truth labels and the predictions of the requested result
    label_column = 'label_{}'.format(key)
    prediction_column = 'pred_{}'.format(key)

    # ROC AUC score calculated from the labels and the predictions
    return roc_auc_score(result_dataframe[label_column], result_dataframe[prediction_column])


def interpolate_rocs(id_to_roc_dictionary,
                     eval_fpr_points = None):
    """
    Interpolate several ROC curves to a common set of evaluation (FPR) values, to allow for computing
    e.g. a mean ROC or the pointwise variance of the curve across multiple model fittings.

    :param id_to_roc_dictionary: dictionary mapping an ID (e.g. a run ID) to a ROC curve, as returned by
        get_roc_curve (or sklearn.metrics.roc_curve), of the form
        {id_1: (fpr_1, tpr_1, threshold_1), id_2: (fpr_2, tpr_2, threshold_2), ...}
    :param eval_fpr_points: the set of FPR values at which to interpolate the results; defaults to
        `np.logspace(-6, 0, 1000)`
    :return:
        eval_fpr_points  -- the set of common points to which TPRs have been interpolated
        interpolated_tprs -- a dictionary with one entry for each ROC provided (same keys as the input),
        giving the array of the TPRs of that ROC interpolated at the corresponding points of eval_fpr_points
    """
    # use the default evaluation false positive rate points when none were provided
    fpr_grid = np.logspace(-6, 0, 1000) if eval_fpr_points is None else eval_fpr_points

    # interpolate each ROC curve (tpr as a function of fpr) on the common grid; thresholds are not used
    tprs_on_grid = {roc_id: np.interp(fpr_grid, fprs, tprs)
                    for roc_id, (fprs, tprs, _thresholds) in id_to_roc_dictionary.items()}

    return fpr_grid, tprs_on_grid


def plot_roc_with_confidence(id_to_dataframe_dictionary,  # run ID - result dataframe dictionary
                             key, # the name of the result to get the curve for
                             filename,  # The filename to save the resulting figure to
                             include_range = False, # plot the min/max value as well
                             style = None,  # (linestyle, color) to use in the plot
                             std_alpha = .2,  # the alpha value for the shading for standard deviation range
                             range_alpha = .1): # the alpha value for the shading for range, if plotted
    """
    Compute the mean and standard deviation of the ROC curve from a sequence of results
    and plot it with shading; the resulting figure is saved to `filename`.

    :param id_to_dataframe_dictionary: run ID - result dataframe dictionary; at least 2 entries are needed
    :param key: the name of the result to get the curve for; if (e.g.) the key 'malware' is provided
        each dataframe is expected to have columns named `pred_malware` and `label_malware`
    :param filename: the filename to save the resulting figure to
    :param include_range: whether to also shade the area between the min and max TPR; defaults to False
    :param style: None (use the default style defined for `key` in style_dict) or a (linestyle, color) pair;
        note that this order is the reverse of the (color, linestyle) pairs stored in style_dict
    :param std_alpha: the alpha value for the shading for standard deviation range; defaults to 0.2
    :param range_alpha: the alpha value for the shading for range, if plotted; defaults to 0.1
    :raises ValueError: if fewer than 2 result sets are provided, or if `style` is None and no default
        style is available for `key`
    """

    # if the length of the run ID - result dataframe dictionary is not greater than 1
    if not len(id_to_dataframe_dictionary) > 1:
        # raise an exception
        raise ValueError("Need a minimum of 2 result sets to plot confidence region; found {}".format(
            len(id_to_dataframe_dictionary)
        ))

    # if the style was not defined (it is None)
    if style is None:
        # if the key is present inside style_dict then use a default style
        if key in style_dict:
            color, linestyle = style_dict[key]
        else: # otherwise raise an exception
            raise ValueError("No default style information is available for key {}; please provide (linestyle, color)".format(key))

    else: # otherwise (the style was defined)
        linestyle, color = style  # get linestyle and color from style

    # calculate ROC curve for each run and create a run ID - ROC curve dictionary
    id_to_roc_dictionary = {k: get_roc_curve(df, key) for k, df in id_to_dataframe_dictionary.items()}

    # interpolate ROC curves and get fpr (false positive rate) points and interpolated tprs (true positive rates)
    fpr_points, interpolated_tprs = interpolate_rocs(id_to_roc_dictionary)

    # stack the interpolated tprs vertically -> one row per run, one column per fpr point
    tpr_array = np.vstack([v for v in interpolated_tprs.values()])

    # calculate mean tpr along dim 0 -> (for each fpr point under examination calculate the mean along all runs)
    mean_tpr = tpr_array.mean(0)

    # calculate tpr standard deviation by calculating the tpr variance along dim 0 and then taking the square root
    # -> (for each fpr point under examination calculate the standard deviation along all runs)
    std_tpr = np.sqrt(tpr_array.var(0))

    # calculate AUC (area under (ROC) curve) score for each run and store them into a numpy array
    aucs = np.array([get_auc_score(v, key) for v in id_to_dataframe_dictionary.values()])

    # calculate the mean, min, max and standard deviation of the ROC AUC score along all runs
    # (the standard deviation is obtained as the square root of the variance)
    mean_auc = aucs.mean()
    min_auc = aucs.min()
    max_auc = aucs.max()
    std_auc = np.sqrt(aucs.var())

    # create a new figure of size 12 x 12
    fig = plt.figure(figsize = (12, 12))

    try:
        # plot ROC curve
        plt.semilogx( # make a plot with log scaling on the x axis
            fpr_points, # false positive rate points as 'x' values
            mean_tpr,  # mean true positive rates as 'y' values
            color + linestyle, # format string, e.g. 'ro' for red circles
            linewidth = 2.0, # line width in points
            # label that will be displayed in the legend; the backslash of the '\pm' mathtext
            # symbol is escaped because '\p' is not a valid python escape sequence
            label = f"{key}: {mean_auc:5.3f}$\\pm${std_auc:5.3f} [{min_auc:5.3f}-{max_auc:5.3f}]")

        # fill uncertainty area around ROC curve
        plt.fill_between( # fill the area between two horizontal curves
            fpr_points,  # false positive rate points as 'x' values
            mean_tpr - std_tpr,  # mean - standard deviation of true positive rates as 'y' coordinates of the first curve
            mean_tpr + std_tpr,  # mean + standard deviation of true positive rates as 'y' coordinates of the second curve
            color = color, # set both the edgecolor and the facecolor
            alpha = std_alpha) # set the alpha value used for blending

        # if the user wants to plot the min/max value as well
        if include_range:
            # fill area between min and max ROC curve values
            plt.fill_between(# fill the area between two horizontal curves
                fpr_points, # false positive rate points as 'x' values
                tpr_array.min(0), # min true positive rates as 'y' coordinates of the first curve
                tpr_array.max(0), # max true positive rates as 'y' coordinates of the second curve
                color = color,  # set both the edgecolor and the facecolor
                alpha = range_alpha)  # set the alpha value used for blending

        plt.legend()  # place legend on the axes
        plt.xlim(1e-6, 1.0) # set the x plot limits
        plt.ylim([0., 1.])  # set the y plot limits
        plt.xlabel('False Positive Rate (FPR)') # set the label for the x-axis
        plt.ylabel('True Positive Rate (TPR)')  # set the label for the y-axis
        plt.savefig(filename) # save the current figure to file
    finally:
        # close the figure (instead of just clearing it) so that it does not stay
        # registered in pyplot, accumulating memory over repeated calls
        plt.close(fig)


def plot_tag_results(dataframe, # result dataframe of a single run (forwarded as-is to get_roc_curve / get_auc_score)
                     filename): # the name of the file in which to save the resulting plot
    """
    Draw one ROC curve per tag in default_tags on a single log-x figure and save it to file.

    :param dataframe: results of one run; it is only forwarded to get_roc_curve and
        get_auc_score (presumably a pandas DataFrame as produced by collect_dataframes -
        TODO confirm against those helpers)
    :param filename: the name of the file in which to save the resulting plot
    """

    # calculate ROC curve for each tag of the current (single) run and create a tag - ROC curve dictionary
    all_tag_rocs = {tag: get_roc_curve(dataframe, tag) for tag in default_tags}

    # interpolate ROC curves and get fpr (false positive rate) points and interpolated tprs (true positive rates)
    eval_fpr_pts, interpolated_rocs = interpolate_rocs(all_tag_rocs)

    # create a new figure of size 12 x 12 and keep a handle on it so that it can be released afterwards
    fig = plt.figure(figsize=(12, 12))

    try:
        # for each tag
        for tag in default_tags:
            # use a default style
            color, linestyle = style_dict[tag]

            # calculate AUC (area under (ROC) curve) score
            auc = get_auc_score(dataframe, tag)

            # plot ROC curve
            plt.semilogx( # make a plot with log scaling on the x axis
                eval_fpr_pts, # false positive rate points as 'x' values
                interpolated_rocs[tag], # interpolated true positive rates for the current tag as 'y' values
                color + linestyle,  # format string, e.g. 'ro' for red circles
                linewidth = 2.0,  # line width in points
                label = f"{tag}:{auc:5.3f}")  # label that will be displayed in the legend

        plt.legend(loc = 'best')  # place legend in the location with the minimum overlap with other drawn objects
        plt.xlim(1e-6, 1.0) # set the x plot limits
        plt.ylim([0., 1.])  # set the y plot limits
        plt.xlabel('False Positive Rate (FPR)') # set the label for the x-axis
        plt.ylabel('True Positive Rate (TPR)')  # set the label for the y-axis
        plt.savefig(filename) # save the current figure to file
    finally:
        # close (not just clear) the figure: this function creates a new figure on every call, and
        # figures that are only cleared stay registered in pyplot, piling up when called once per run
        plt.close(fig)


def plot_tag_result(results_file, # complete path to a results.csv file that contains the output of a model run
                    output_filename): # the name of the file in which to save the resulting plot
    """
    Takes a result file from a feedforward neural network model that includes all
    tags, and produces multiple overlaid ROC plots for each tag individually.
    :param results_file: complete path to a results.csv file that contains the output of 
        a model run.  Note that the model must have been trained with --use_tag_labels=True
        and evaluated using --evaluate_tags=True
    :param output_filename: the name of the file in which to save the resulting plot.
    """

    # placeholder run ID under which the single result file is registered
    run_id = 'run'

    # create run ID - filename correspondence dictionary (containing just one result file)
    id_to_resultfile_dict = {run_id: results_file}

    # read csv result file and obtain a run ID - result dataframe dictionary
    id_to_dataframe_dict = collect_dataframes(id_to_resultfile_dict)

    # produce multiple overlaid ROC plots (one for each tag individually) and save the overall figure to file.
    # The loaded dataframe has to be passed here - not the path of the results file it was read from
    plot_tag_results(id_to_dataframe_dict[run_id], output_filename)


def plot_roc_distribution_for_tag(run_to_filename_json, #  A json file that contains a key-value map that links run IDs to the full path to a results file (including the file name)
                                  output_filename,  # The filename to save the resulting figure to
                                  tag_to_plot = 'malware',  # the tag from the results to plot
                                  linestyle = None, # the linestyle to use in the plot (if None use some defaults)
                                  color = None, # the color to use in the plot (if None use some defaults)
                                  include_range = False,  # plot the min/max value as well
                                  std_alpha = .2, # the alpha value for the shading for standard deviation range
                                  range_alpha = .1):  # the alpha value for the shading for range, if plotted
    """
    Compute the mean and standard deviation of the TPR at a range of FPRS (the ROC curve)
    over several sets of results (at least 2 runs) for a given tag.  The run_to_filename_json file must have
    the following format:
    {"run_id_0": "/full/path/to/results.csv/for/run/0/results.csv",
     "run_id_1": "/full/path/to/results.csv/for/run/1/results.csv",
      ...
    }
    
    :param run_to_filename_json: A json file that contains a key-value map that links run IDs to
        the full path to a results file (including the file name)
    :param output_filename: The filename to save the resulting figure to
    :param tag_to_plot: the tag from the results to plot; defaults to "malware"
    :param linestyle: the linestyle to use in the plot; when both linestyle and color are None,
        style=None is forwarded and the default is left to plot_roc_with_confidence
    :param color: the color to use in the plot; must be given together with linestyle
    :param include_range: plot the min/max value as well (default False)
    :param std_alpha: the alpha value for the shading for standard deviation range
        (default 0.2)
    :param range_alpha: the alpha value for the shading for range, if plotted
        (default 0.1)
    :raises ValueError: if exactly one of color and linestyle is specified
    """

    # open json containing run ID - filename correspondences and decode it as json object;
    # the context manager guarantees the file handle is closed even if decoding fails
    with open(run_to_filename_json, 'r') as json_file:
        id_to_resultfile_dict = json.load(json_file)

    # read csv result files and obtain a run ID - result dataframe dictionary
    id_to_dataframe_dict = collect_dataframes(id_to_resultfile_dict)

    # color and linestyle only make sense as a pair: reject the case where just one of them is given
    if (color is None) != (linestyle is None):
        raise ValueError("both color and linestyle should either be specified or None") # raise an exception

    # no explicit style -> None; otherwise define the style as a tuple of color and linestyle
    style = None if color is None else (color, linestyle)

    # plot roc curve with confidence
    plot_roc_with_confidence(id_to_dataframe_dict,
                             tag_to_plot,
                             output_filename,
                             include_range = include_range,
                             style = style,
                             std_alpha = std_alpha,
                             range_alpha = range_alpha)



## **Start plotting results**

In [None]:
# one per-tag ROC figure for every configured run: "<run index>results.csv" -> "<run index>results.png"
for run_idx in range(config.runs):
    plot_tag_result(
        results_file = f"{drive_path}/MyDrive/thesis/Results/{run_idx}results.csv",
        output_filename = f"{drive_path}/MyDrive/thesis/Results/{run_idx}results.png",
    )

# ROC curve distribution for the 'malware' tag over the runs listed in results.json
plot_roc_distribution_for_tag(
    run_to_filename_json = f"{drive_path}/MyDrive/thesis/Results/results.json",
    output_filename = f"{drive_path}/MyDrive/thesis/Results/results.png",
    tag_to_plot = 'malware',
)