Skip to content

Commit

Permalink
Merge pull request #149 from cglewis/master
Browse files Browse the repository at this point in the history
cleanup globals for tests
  • Loading branch information
cglewis authored Jul 18, 2018
2 parents baa939a + 2229ff7 commit cc17326
Show file tree
Hide file tree
Showing 11 changed files with 104 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
include:
- stage: test
script: docker build -f Dockerfile.base .
- script: make test && bash <(curl -s https://codecov.io/bash)
- script: find . -name requirements.txt -type f -exec pip3 install -r {} \; && pip3 uninstall -y poseidonml && pip3 install . && pytest -l -s -v --cov=tests/ --cov=utils/ --cov=DeviceClassifier/ --cov-report term-missing -c .coveragerc
before_install:
- sudo apt-get update
- sudo apt-get install docker-ce
Expand Down
54 changes: 35 additions & 19 deletions DeviceClassifier/OneLayer/eval_OneLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,6 @@
tf.logging.set_verbosity(tf.logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] ='3'

# Get time constant from config
with open('opts/config.json') as config_file:
config = json.load(config_file)
time_const = config['time constant']
state_size = config['state size']
duration = config['duration']
look_time = config['look time']
threshold = config['threshold']
batch_size = config['batch size']
rnn_size = config['rnn size']

def lookup_key(key):
'''
Expand All @@ -50,7 +40,7 @@ def lookup_key(key):

return address, None

def get_address_info(address, timestamp):
def get_address_info(address, timestamp, state_size):
'''
Look up address information prior to the timestamp
'''
Expand Down Expand Up @@ -142,6 +132,7 @@ def get_previous_state(source_ip, timestamp):
def average_representation(
representations,
timestamps,
time_const,
prev_representation=None,
last_update=None,
):
Expand Down Expand Up @@ -190,7 +181,8 @@ def update_data(
timestamps,
predictions,
other_ips,
model_hash
model_hash,
time_const
):
'''
Updates the stored data with the new information
Expand All @@ -213,12 +205,13 @@ def update_data(
last_update, prev_rep = get_previous_state(source_ip, timestamps[0])

# Compute current representation
time, current_rep = average_representation(representations, timestamps)
time, current_rep = average_representation(representations, timestamps, time_const)

# Compute moving average representation
time, avg_rep = average_representation(
representations,
timestamps,
time_const,
prev_representation=prev_rep,
last_update=last_update
)
Expand Down Expand Up @@ -276,7 +269,9 @@ def basic_decision(
timestamp,
labels,
confs,
abnormality
abnormality,
look_time,
threshold
):

valid = True
Expand Down Expand Up @@ -316,6 +311,22 @@ def basic_decision(
if __name__ == '__main__':
logger = logging.getLogger(__name__)

# Get time constant from config
try:
with open('opts/config.json') as config_file:
config = json.load(config_file)
time_const = config['time constant']
state_size = config['state size']
duration = config['duration']
look_time = config['look time']
threshold = config['threshold']
batch_size = config['batch size']
conf_labels = config['labels']
rnn_size = config['rnn size']
except Exception as e: # pragma: no cover
logger.error("unable to read 'opts/config.json' properly because: %s", str(e))
sys.exit(1)

# path to the pcap to get the update from
if len(sys.argv) < 2:
pcap_path = "/pcaps/eval.pcap"
Expand All @@ -333,7 +344,7 @@ def basic_decision(
else:
source_ip = None
except Exception as e:
logger.debug("Could not get address info beacuse %s", e)
logger.debug("Could not get address info beacuse %s", str(e))
logger.debug("Defaulting to inferring IP address from %s", pcap_path)
source_ip = None
key_address = None
Expand Down Expand Up @@ -371,6 +382,7 @@ def basic_decision(
_, mean_rep = average_representation(
reps,
timestamps,
time_const,
prev_representation=prev_rep,
last_update=last_update
)
Expand All @@ -388,7 +400,8 @@ def basic_decision(
timestamps,
preds,
others,
model_hash
model_hash,
time_const
)

# Get the sessions that the model looked at
Expand All @@ -415,10 +428,11 @@ def basic_decision(
logger.debug("Bypassing abnormality detection")
abnormality = 0
else:
abnormality = eval_pcap(pcap_path, label=labels[0])
abnormality = eval_pcap(pcap_path, conf_labels, time_const, label=labels[0], rnn_size=rnn_size)
repr_s, m_repr_s, _ , prev_s, _, _ = get_address_info(
source_ip,
timestamp
timestamp,
state_size
)
decision = basic_decision(
key,
Expand All @@ -427,7 +441,9 @@ def basic_decision(
timestamp,
labels,
confs,
abnormality
abnormality,
look_time,
threshold
)
logger.debug("Created message")
for i in range(3):
Expand Down
2 changes: 1 addition & 1 deletion DeviceClassifier/OneLayer/opts/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@
"time constant": 86400,
"threshold": 0.99,

"rnn size": 128,
"rnn size": 100,
"batch size": 32
}
54 changes: 35 additions & 19 deletions DeviceClassifier/RandomForest/eval_RandomForest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,6 @@
tf.logging.set_verbosity(tf.logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] ='3'

# Get time constant from config
with open('opts/config.json') as config_file:
config = json.load(config_file)
time_const = config['time constant']
state_size = config['state size']
duration = config['duration']
look_time = config['look time']
threshold = config['threshold']
batch_size = config['batch size']
rnn_size = config['rnn size']

def lookup_key(key):
'''
Expand All @@ -50,7 +40,7 @@ def lookup_key(key):

return address, None

def get_address_info(address, timestamp):
def get_address_info(address, timestamp, state_size):
'''
Look up address information prior to the timestamp
'''
Expand Down Expand Up @@ -142,6 +132,7 @@ def get_previous_state(source_ip, timestamp):
def average_representation(
representations,
timestamps,
time_const,
prev_representation=None,
last_update=None,
):
Expand Down Expand Up @@ -190,7 +181,8 @@ def update_data(
timestamps,
predictions,
other_ips,
model_hash
model_hash,
time_const
):
'''
Updates the stored data with the new information
Expand All @@ -213,12 +205,13 @@ def update_data(
last_update, prev_rep = get_previous_state(source_ip, timestamps[0])

# Compute current representation
time, current_rep = average_representation(representations, timestamps)
time, current_rep = average_representation(representations, timestamps, time_const)

# Compute moving average representation
time, avg_rep = average_representation(
representations,
timestamps,
time_const,
prev_representation=prev_rep,
last_update=last_update
)
Expand Down Expand Up @@ -276,7 +269,9 @@ def basic_decision(
timestamp,
labels,
confs,
abnormality
abnormality,
look_time,
threshold
):

valid = True
Expand Down Expand Up @@ -316,6 +311,22 @@ def basic_decision(
if __name__ == '__main__':
logger = logging.getLogger(__name__)

# Get time constant from config
try:
with open('opts/config.json') as config_file:
config = json.load(config_file)
time_const = config['time constant']
state_size = config['state size']
duration = config['duration']
look_time = config['look time']
threshold = config['threshold']
batch_size = config['batch size']
conf_labels = config['labels']
rnn_size = config['rnn size']
except Exception as e: # pragma: no cover
logger.error("unable to read 'opts/config.json' properly because: %s", str(e))
sys.exit(1)

# path to the pcap to get the update from
if len(sys.argv) < 2:
pcap_path = "/pcaps/eval.pcap"
Expand All @@ -333,7 +344,7 @@ def basic_decision(
else:
source_ip = None
except Exception as e:
logger.debug("Could not get address info beacuse %s", e)
logger.debug("Could not get address info beacuse %s", str(e))
logger.debug("Defaulting to inferring IP address from %s", pcap_path)
source_ip = None
key_address = None
Expand Down Expand Up @@ -370,6 +381,7 @@ def basic_decision(
_, mean_rep = average_representation(
reps,
timestamps,
time_const,
prev_representation=prev_rep,
last_update=last_update
)
Expand All @@ -387,7 +399,8 @@ def basic_decision(
timestamps,
preds,
others,
model_hash
model_hash,
time_const
)

# Get the sessions that the model looked at
Expand Down Expand Up @@ -415,11 +428,12 @@ def basic_decision(
logger.debug("Bypassing abnormality detection")
abnormality = 0
else:
abnormality = eval_pcap(pcap_path, label=labels[0])
abnormality = eval_pcap(pcap_path, conf_labels, time_const, label=labels[0], rnn_size=rnn_size)

repr_s, m_repr_s, _ , prev_s, _, _ = get_address_info(
source_ip,
timestamp
timestamp,
state_size
)
decision = basic_decision(
key,
Expand All @@ -428,7 +442,9 @@ def basic_decision(
timestamp,
labels,
confs,
abnormality
abnormality,
look_time,
threshold
)
logger.debug("Created message")
for i in range(3):
Expand Down
4 changes: 1 addition & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ build_onelayer: build_base
@pushd DeviceClassifier/OneLayer && docker build -t poseidonml:onelayer . && popd
build_randomforest: build_base
@pushd DeviceClassifier/RandomForest && docker build -t poseidonml:randomforest . && popd
test: install
pytest -l -s -v --cov=tests/ --cov=utils/ --cov=DeviceClassifier/ --cov-report term-missing
test-local: build_base
test: build_base
docker build -t poseidonml-test -f Dockerfile.test .
docker run -it --rm poseidonml-test
build_base:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_DeviceClassifier_OneLayer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from DeviceClassifier.OneLayer.test_OneLayer import calc_f1
from DeviceClassifier.OneLayer.eval_OneLayer import lookup_key
from DeviceClassifier.OneLayer.eval_OneLayer import get_address_info
from DeviceClassifier.OneLayer.eval_OneLayer import get_previous_state
from DeviceClassifier.OneLayer.eval_OneLayer import average_representation
from DeviceClassifier.OneLayer.eval_OneLayer import update_data
from DeviceClassifier.OneLayer.eval_OneLayer import basic_decision

def test_calc_f1():
    """Smoke test: calc_f1 must accept an empty results mapping without raising."""
    empty_results = dict()
    calc_f1(empty_results)
6 changes: 6 additions & 0 deletions tests/test_DeviceClassifier_RandomForest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from DeviceClassifier.RandomForest.test_RandomForest import calc_f1
from DeviceClassifier.RandomForest.eval_RandomForest import lookup_key
from DeviceClassifier.RandomForest.eval_RandomForest import get_address_info
from DeviceClassifier.RandomForest.eval_RandomForest import get_previous_state
from DeviceClassifier.RandomForest.eval_RandomForest import average_representation
from DeviceClassifier.RandomForest.eval_RandomForest import update_data
from DeviceClassifier.RandomForest.eval_RandomForest import basic_decision

def test_calc_f1():
    """Smoke test: calc_f1 must accept an empty results mapping without raising."""
    empty_results = dict()
    calc_f1(empty_results)
1 change: 1 addition & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from utils.eval_SoSModel import eval_pcap
from utils.featurizer import extract_features
from utils.iterator import BatchIterator
from utils.pcap_utils import is_private
Expand Down
21 changes: 12 additions & 9 deletions utils/eval_SoSModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,18 @@
logging.basicConfig(level=logging.INFO)
tf.logging.set_verbosity(tf.logging.ERROR)

# Load info from config
with open('opts/config.json') as config_file:
config = json.load(config_file)
rnn_size = config['rnn size']
labels = config['labels']

def eval_pcap(pcap, label=None):
def eval_pcap(pcap, labels, time_const, label=None, rnn_size=100):
logger = logging.getLogger(__name__)
data = create_dataset(pcap, label=label)
data = create_dataset(pcap, time_const, label=label)
# Create an iterator
iterator = BatchIterator(
data,
labels,
perturb_types=['random data']
)
logger.debug("Created iterator")
rnnmodel = SoSModel(rnn_size=100)
rnnmodel = SoSModel(rnn_size=rnn_size)
logger.debug("Created model")
rnnmodel.load(os.path.join(working_set.find(Requirement.parse('poseidonml')).location, 'poseidonml/models/SoSmodel'))
logger.debug("Loaded model")
Expand Down Expand Up @@ -73,5 +68,13 @@ def eval_pcap(pcap, label=None):
label = sys.argv[2]
else:
label = None
mean_score = eval_pcap(pcap,label=label)

# Load info from config
with open('opts/config.json') as config_file:
config = json.load(config_file)
rnn_size = config['rnn size']
labels = config['labels']
time_const = config['time constant']

mean_score = eval_pcap(pcap, labels, time_const, label=label, rnn_size=rnn_size)
print(mean_score)
Loading

0 comments on commit cc17326

Please sign in to comment.