In [None]:
# ---------------------------------------------------------------------------
# Notebook configuration. These module-level names are used as the default
# argument values of predict() and are passed to it again in the last cell.
# NOTE(review): both paths are absolute and machine-specific; adjust them
# before running this notebook on another host.
# ---------------------------------------------------------------------------
# pcap capture that read_pcap() hands to tcpdump
dataPath = '/data/fs4/datasets/pcaps/smallFlows.pcap'
# pickled model object that loadFile() restores inside predict()
modelPath = '/data/fs4/home/bradh/outputs/cpuVersion.pickle'
# substituted into the THEANO_FLAGS "device=" setting below
modelType = 'cpu' # 'gpu' or 'cpu'
# last axis of the minibatch built in predict(); equals the one-hot width
# used by oneHot() (256 byte values plus the <EOP> token)
dimIn=257
# number of packets kept per session (shorter sessions are zero-padded)
maxPackets=2
# when True each packet's rows are reversed before padding/cropping
packetReverse = False
# number of rows (time steps) each encoded packet is padded/cropped to
packetTimeSteps = 16
# True: zero rows are stacked before the packet rows; False: after them
padOldTimeSteps=True


import os
# Set before `import theano` further down -- presumably because Theano reads
# THEANO_FLAGS at import time (TODO confirm against the Theano config docs).
os.environ['THEANO_FLAGS'] = 'floatX=float32,device=%s' % modelType

import json
import subprocess
import cPickle
import sys
import binascii
import logging
#import ipaddress

import numpy as np
import random
from copy import copy

import blocks
from blocks.bricks import Linear, Softmax, Softplus, NDimensionalSoftmax, BatchNormalizedMLP,\
                          Rectifier, Logistic, Tanh, MLP
from blocks.bricks.recurrent import GatedRecurrent, Fork, LSTM
from blocks.initialization import Constant, IsotropicGaussian, Identity, Uniform
from blocks.bricks.cost import BinaryCrossEntropy, CategoricalCrossEntropy
from blocks.filter import VariableFilter
from blocks.roles import PARAMETER
from blocks.graph import ComputationGraph

import theano
from theano import tensor as T

In [None]:
def parse_header(line):  # pragma: no cover
    """
    Pull the source/destination endpoints out of one tcpdump header line.

    The line is split on whitespace: field 2 is the protocol tag, field 3
    the source endpoint and field 5 the destination endpoint (which may
    carry a trailing ':').

    Formatting depends on the ethernet type:
      IPv4 format: 0.0.0.0.port
      IPv6 format (one of many): 0:0:0:0:0:0.port

    Returns a dict with the string values 'src_ip', 'src_port', 'dest_ip'
    and 'dest_port'. For an IPv4 endpoint without a port the port is ''.
    Raises IndexError when the line has fewer than six fields.
    """
    fields = line.split()
    protocol = fields[2]
    src_parts = fields[3].split('.')
    dst_parts = fields[5].split('.')

    if protocol == 'IP6':
        # Address is everything before the first dot, port is whatever
        # follows the last one (minus the trailing ':' on the destination).
        return {
            'src_ip': src_parts[0],
            'src_port': src_parts[-1],
            'dest_ip': dst_parts[0],
            'dest_port': dst_parts[-1].split(':')[0],
        }

    # Four dotted pieces are a bare IPv4 address; more than four means the
    # last piece is the port.
    if len(src_parts) > 4:
        src_ip, src_port = '.'.join(src_parts[:-1]), src_parts[-1]
    else:
        src_ip, src_port = fields[3], ''

    if len(dst_parts) > 4:
        dest_ip = '.'.join(dst_parts[:-1])
        dest_port = dst_parts[-1].split(':')[0]
    else:
        dest_ip, dest_port = fields[5].split(':')[0], ''

    return {
        'src_ip': src_ip,
        'src_port': src_port,
        'dest_ip': dest_ip,
        'dest_port': dest_port,
    }


def parse_data(line):  # pragma: no cover
    """
    Return the hex payload of one tcpdump hex-dump line.

    Everything up to the first ':' (the offset column) is dropped, then the
    remainder is stripped and all spaces are removed.
    Raises ValueError when the line contains no ':'.
    """
    _offset, payload = line.split(':', 1)
    return payload.strip().replace(' ', '')


def process_packet(output):  # pragma: no cover
    """
    Group tcpdump text output into packets.

    `output` is an iterable of lines. A non-blank line that does not start
    with '0x' is treated as a packet header and parsed with parse_header();
    lines starting with '0x' are hex-dump lines whose payload (via
    parse_data()) is appended to the current packet.

    Yields one dict per packet holding the parse_header() fields plus
    'data', the concatenated hex string. Packets without any hex data are
    not yielded, and hex lines that follow an unparseable header are
    skipped.

    Each yielded dict is a fresh object, so callers may keep references
    (e.g. list(process_packet(...))). The last packet of the stream is
    emitted once the input is exhausted; previously it was silently dropped
    because a packet was only emitted when the *next* header arrived.
    """
    header = {}   # fields of the header currently being filled; {} if none
    hex_data = ''

    for line in output:
        line = line.strip()
        if not line:
            continue

        if line.startswith('0x'):
            # hex data line: only meaningful under a successfully parsed header
            if header:
                hex_data += parse_data(line)
            continue

        # header line: the previous packet (if any) is now complete
        if header and hex_data:
            packet = dict(header)
            packet['data'] = hex_data
            yield packet

        hex_data = ''
        try:
            header = parse_header(line)
        except Exception:
            # Best effort: an unparseable header drops that packet's hex lines.
            header = {}

    # flush the packet that was still being collected when the input ended
    if header and hex_data:
        packet = dict(header)
        packet['data'] = hex_data
        yield packet


def is_clean_packet(packet):  # pragma: no cover
    """
    Return True when a parsed packet passes a set of cheap sanity checks.

    Checks performed, in order:
      * 'src_port' and 'dest_port' consist of digits only (an empty port
        therefore fails);
      * neither 'src_ip' nor 'dest_ip' is purely alphabetic (this is only
        a coarse filter, not a real address validation);
      * if 'data' is present, it can be parsed by int(..., 16).

    Raises KeyError when one of the four address keys is missing.
    """
    if not (packet['src_port'].isdigit() and packet['dest_port'].isdigit()):
        return False

    if packet['src_ip'].isalpha() or packet['dest_ip'].isalpha():
        return False

    if 'data' in packet:
        try:
            int(packet['data'], 16)
        except (ValueError, TypeError):
            # Only conversion failures mean "not hex"; anything else
            # (e.g. KeyboardInterrupt) must propagate to the caller.
            return False

    return True


def order_keys(hexSessionDict):
    """
    List the session keys sorted by the second element of each session's
    value (the insertion counter stored by read_pcap), i.e. in (rough)
    time order.
    """
    def insertion_rank(session_key):
        return hexSessionDict[session_key][1]

    return sorted(hexSessionDict, key=insertion_rank)


def read_pcap(path):  # pragma: no cover
    """
    Read a pcap file through tcpdump and group its packets into sessions.

    Runs `tcpdump -nn -tttt -xx -r <path>` and feeds its stdout to
    process_packet(). Packets rejected by is_clean_packet() are skipped.

    Returns a dict mapping (src "ip:port", dest "ip:port") to a tuple
    (list of packet hex strings, insertion number). A packet travelling in
    the reverse direction of an existing session is appended to that
    session rather than starting a new one.

    The command is passed as an argument list (no shell), so paths with
    spaces or shell metacharacters are handled literally. As a consequence
    a missing tcpdump binary raises OSError instead of producing an empty
    result.
    """
    print('starting reading pcap file')
    hex_sessions = {}
    proc = subprocess.Popen(['tcpdump', '-nn', '-tttt', '-xx', '-r', path],
                            stdout=subprocess.PIPE)
    insert_num = 0  # keeps track of insertion order into dict
    try:
        for packet in process_packet(proc.stdout):
            if not is_clean_packet(packet):
                continue
            if 'data' in packet:
                key = (packet['src_ip']+":"+packet['src_port'],
                       packet['dest_ip']+":"+packet['dest_port'])
                rev_key = (key[1], key[0])
                if key in hex_sessions:
                    hex_sessions[key][0].append(packet['data'])
                elif rev_key in hex_sessions:
                    hex_sessions[rev_key][0].append(packet['data'])
                else:
                    hex_sessions[key] = ([packet['data']], insert_num)
                    insert_num += 1
    finally:
        # Release the pipe and reap the child so no zombie process is left.
        proc.stdout.close()
        proc.wait()

    print('finished reading pcap file')
    return hex_sessions

In [None]:
def removeBadSessionizer(hexSessionDict, saveFile=False, dataPath=None, fileName=None,
                         minPacketLen=80):  # pragma: no cover
    """
    Drop sessions that contain a too-short packet, modifying the dict in place.

    hexSessionDict maps a session key to a tuple whose first element is the
    list of packet hex strings. A session is removed when it has no packets
    at all or when its shortest packet has fewer than `minPacketLen`
    characters (default 80, the previously hard-coded threshold).

    When `saveFile` is true the filtered dict is handed to pickleFile()
    with `dataPath` / `fileName`.
    NOTE(review): pickleFile is not defined in the visible part of this
    notebook -- confirm it is available before calling with saveFile=True.

    Returns the same (mutated) hexSessionDict object.
    """
    # Snapshot the keys so entries can be deleted while looping.
    for ses in list(hexSessionDict):
        packets = hexSessionDict[ses][0]
        if not packets or min(len(pac) for pac in packets) < minPacketLen:
            del hexSessionDict[ses]

    if saveFile:
        print('pickling sessions')
        pickleFile(hexSessionDict, filePath=dataPath, fileName=fileName)

    return hexSessionDict

In [None]:
def loadFile(filePath):  # pragma: no cover
    """
    Unpickle and return the object stored at `filePath`.

    The file is opened in binary mode and is closed even when unpickling
    fails.

    SECURITY: unpickling executes arbitrary code embedded in the file; only
    load pickles that come from a trusted source.
    """
    with open(filePath, 'rb') as pickled:
        return cPickle.load(pickled)

In [None]:
def hexTokenizer():  # pragma: no cover
    """
    Build the token dictionary used to encode packets.

    Returns a dict mapping every two-character lowercase hex byte string
    ('00' .. 'ff') to its integer value (0 .. 255), plus the End Of Packet
    token '<EOP>' mapped to 256 -- 257 entries in total.
    """
    hexDict = {}
    for byteValue in range(256):
        hexDict['%02x' % byteValue] = byteValue  # dictionary k=hex, v=int
    hexDict['<EOP>'] = 256  # End Of Packet token
    return hexDict

In [None]:
def oneHot(index, granular = 'hex'):  # pragma: no cover
    """
    Return a 1-D numpy vector of zeros with a single 1.0 at `index`.

    The vector has 257 entries when `granular` is 'hex' and 17 entries for
    any other value.
    """
    size = 257 if granular == 'hex' else 17
    encoded = np.zeros(size)
    encoded[index] = 1.0
    return encoded


def oneSessionEncoder(sessionPackets, hexDict, maxPackets=2, packetTimeSteps=100,
                      packetReverse=False, charLevel=False, padOldTimeSteps=True):  # pragma: no cover
    """
    One-hot encode the first `maxPackets` packets of a session.

    sessionPackets  -- sequence of packet hex strings
    hexDict         -- maps two-character hex strings to token ids
    maxPackets      -- packets kept per session; shorter sessions are
                       padded with all-zero packets
    packetTimeSteps -- number of rows each packet matrix is padded/cropped to
    packetReverse   -- reverse the rows of each packet before padding/cropping
    charLevel       -- selects a padding width of 17 instead of 257
                       NOTE(review): oneHot() is still called with its
                       default 257-wide vectors, so charLevel=True looks
                       unfinished -- confirm before using it
    padOldTimeSteps -- True: zero rows go before the packet rows,
                       False: after them

    Returns (sessionArray, tokenLists): the stacked packet matrices and, per
    packet, the list of token ids (cropped to packetTimeSteps, then
    terminated by the <EOP> id 256).
    """
    width = 17 if charLevel else 257

    encodedPackets = []
    tokenLists = []

    # Slicing keeps at most maxPackets packets (all of them when fewer).
    for hexString in copy(sessionPackets[:maxPackets]):
        # two hex characters -> one token
        tokens = [hexDict[hexString[pos:pos+2]]
                  for pos in range(0, len(hexString) - 1, 2)]
        # crop to packetTimeSteps tokens, then add <EOP> end of packet token
        tokens = tokens[:packetTimeSteps] + [256]
        tokenLists.append(tokens)

        # one hot encoding of packet into a matrix
        matrix = np.array([oneHot(tok) for tok in tokens])
        rowCount = len(matrix)

        if packetReverse:
            matrix = matrix[::-1]

        shortfall = packetTimeSteps - rowCount
        if shortfall > 0:
            filler = np.zeros((shortfall, width))
            if padOldTimeSteps:
                # zeros first, so earlier timesteps carry no information
                matrix = np.vstack((filler, matrix))
            else:
                # zeros after the actual info
                matrix = np.vstack((matrix, filler))
        elif shortfall < 0:
            # Over-long packet: keep the first packetTimeSteps rows (for an
            # unreversed packet this drops the trailing <EOP> row).
            matrix = matrix[:packetTimeSteps, :]

        encodedPackets.append(matrix)

    # padding session up to maxPackets packets
    session = np.asarray(encodedPackets, dtype=theano.config.floatX)
    missingPackets = maxPackets - session.shape[0]
    if missingPackets > 0:
        emptyPackets = np.zeros((missingPackets, packetTimeSteps, width))
        session = np.vstack((session, emptyPackets))

    return session, tokenLists

In [None]:
def predict(dataPath=dataPath, modelPath=modelPath, 
            dimIn=dimIn, maxPackets=maxPackets, packetReverse = packetReverse,
            packetTimeSteps = packetTimeSteps, padOldTimeSteps=padOldTimeSteps):
    """
    Run the pickled model over every session found in a pcap file.

    Steps: sessionize the pcap with read_pcap(), encode each session with
    oneSessionEncoder(), stack the encodings into one array reshaped to
    (-1, packetTimeSteps, 1, dimIn) and call the object restored by
    loadFile(modelPath) on it (presumably a compiled prediction function --
    confirm against the code that produced the pickle).

    Returns (predprobs, predtargets): the model output and its argmax
    along axis 1.
    """
    print('sessionizing pcap file')
    hexSessions = read_pcap(dataPath)
    hexDict = hexTokenizer()

    print('loading model')
    prediction = loadFile(modelPath)

    print('predicting pcap file')
    # Keep only the encoded array (element 0) of each encoder result.
    encodedSessions = [
        oneSessionEncoder(hexSessions[sessionKey][0],
                          hexDict = hexDict,
                          packetReverse = packetReverse,
                          padOldTimeSteps = padOldTimeSteps,
                          maxPackets = maxPackets,
                          packetTimeSteps = packetTimeSteps)[0]
        for sessionKey in hexSessions
    ]

    stacked = np.asarray(encodedSessions, dtype=theano.config.floatX)
    sessionsMinibatch = stacked.reshape((-1, packetTimeSteps, 1, dimIn))

    predprobs = prediction(sessionsMinibatch)
    predtargets = np.argmax(predprobs, axis=1)

    return predprobs, predtargets

In [None]:
# Run the whole pipeline with the configuration from the first cell:
# `probs` is the model output, `tars` its argmax along axis 1.
probs, tars = predict(
    dataPath=dataPath,
    modelPath=modelPath,
    dimIn=dimIn,
    maxPackets=maxPackets,
    packetReverse=packetReverse,
    packetTimeSteps=packetTimeSteps,
    padOldTimeSteps=padOldTimeSteps,
)