In [None]:
import os
# Configure Theano (float32 on the GPU) through the THEANO_FLAGS environment variable.
# NOTE(review): this is set before the blocks/theano imports below, presumably because
# Theano reads the flags at import time -- keep it ahead of those imports.
os.environ['THEANO_FLAGS'] = 'floatX=float32,device=gpu'

import json
import subprocess
import cPickle
import sys
import binascii
import multiprocessing as mp
from itertools import chain
from collections import OrderedDict
import logging

import numpy as np
import random
from copy import copy

import blocks
from blocks.bricks import Linear, Softmax, Softplus, NDimensionalSoftmax,\
                            BatchNormalizedMLP,Rectifier, Logistic, Tanh, MLP
from blocks.bricks.recurrent import GatedRecurrent, Fork, LSTM
from blocks.initialization import Constant, IsotropicGaussian, Identity, Uniform
from blocks.bricks.cost import BinaryCrossEntropy, CategoricalCrossEntropy
from blocks.filter import VariableFilter
from blocks.roles import PARAMETER
from blocks.graph import ComputationGraph

import theano
from theano import tensor as T

# Raise the interpreter recursion limit far above its default.
# NOTE(review): presumably required to pickle the compiled Theano functions that
# training() saves with pickleFile -- confirm before lowering it.
sys.setrecursionlimit(100000)

In [None]:
def parse_header(line):
    """
    Extract source/destination address and port from one tcpdump header line.

    The line is split on whitespace; token 2 is the protocol tag, token 3 the
    source endpoint and token 5 the destination endpoint (token 4 being the
    '>' separator in the lines this notebook feeds in).

    Formatting depends on the protocol tag:
      * 'IP6': endpoint looks like 0:0:0:0:0:0.port, so the text before the
        first '.' is the address and the text after the last '.' is the port.
      * anything else: endpoint looks like 0.0.0.0.port; a port is only
        present when the endpoint has more than four dot-separated fields,
        otherwise the port is reported as ''.
    A trailing ':' after the destination is stripped.

    Returns a dict with keys 'src_ip', 'src_port', 'dest_ip', 'dest_port'
    (all strings). Raises IndexError when the line has fewer than six tokens.
    """
    fields = line.split()
    isIPv6 = fields[2] == 'IP6'
    srcParts = fields[3].split('.')
    dstParts = fields[5].split('.')

    if isIPv6:
        srcIp, srcPort = srcParts[0], srcParts[-1]
        dstIp, dstPort = dstParts[0], dstParts[-1].split(':')[0]
    else:
        if len(srcParts) > 4:
            srcIp, srcPort = '.'.join(srcParts[:-1]), srcParts[-1]
        else:
            srcIp, srcPort = fields[3], ''

        if len(dstParts) > 4:
            dstIp, dstPort = '.'.join(dstParts[:-1]), dstParts[-1].split(':')[0]
        else:
            dstIp, dstPort = fields[5].split(':')[0], ''

    return {'src_ip': srcIp,
            'src_port': srcPort,
            'dest_ip': dstIp,
            'dest_port': dstPort}

def parse_data(line):
    """
    Return the hex payload of one tcpdump hex-dump line as a compact string.

    Everything up to the first ':' (the offset column) is discarded; the rest
    is stripped of surrounding whitespace and all inner spaces are removed.
    Raises ValueError when the line contains no ':'.
    """
    _offset, payload = line.split(':', 1)
    return payload.strip().replace(' ', '')

def process_packet(output):
    """
    Group tcpdump output lines into packets.

    `output` is an iterable of text lines. A line that does not start with
    '0x' is treated as a packet header and parsed with parse_header; lines
    starting with '0x' are hex-dump lines whose payload (parse_data) is
    appended to the current packet. Blank lines are ignored.

    Yields one dict per packet holding the parse_header keys plus 'data'
    (the concatenated hex string). Only packets that have both a parsable
    header and at least one hex line are yielded.

    Fixes relative to the previous version:
      * the last packet of the stream is now yielded (it used to be dropped
        because packets were only emitted when the *next* header arrived);
      * every yielded dict is a fresh object (the old code cleared and reused
        one dict, so collecting the generator into a list gave empty dicts);
      * header-parse failures are still skipped best-effort, but only
        Exception subclasses are swallowed.
    """
    ret_dict = {}
    ret_data = ''
    hasHeader = False

    for line in output:
        line = line.strip()
        if not line:
            continue

        if line.startswith('0x'):
            # hex data line: only meaningful after a successfully parsed header
            if hasHeader:
                ret_data = ret_data + parse_data(line)
            continue

        # header line
        if ret_dict and ret_data:
            # about to start a new header, finished with the previous hex dump
            ret_dict['data'] = ret_data
            yield ret_dict
            ret_dict = {}
            ret_data = ''
            hasHeader = False

        # parse next header
        try:
            ret_dict.update(parse_header(line))
            hasHeader = True
        except Exception:
            # malformed header: drop it and ignore hex lines until the next header
            ret_dict = {}
            ret_data = ''
            hasHeader = False

    # flush the final packet, which has no following header to trigger it
    if ret_dict and ret_data:
        ret_dict['data'] = ret_data
        yield ret_dict
    
def is_clean_packet(packet):
    """
    Return True when a parsed packet looks usable, False otherwise.

    Checks, in order:
      * 'src_port' and 'dest_port' consist only of digits (so packets
        without a port are rejected);
      * neither 'src_ip' nor 'dest_ip' is a purely alphabetic word (a cheap
        filter for non-IP header lines -- this is NOT a full address-format
        validation);
      * when 'data' is present, it parses as a hexadecimal number.

    Raises KeyError when one of the four header keys is missing.
    """
    if not (packet['src_port'].isdigit() and packet['dest_port'].isdigit()):
        return False

    if packet['src_ip'].isalpha() or packet['dest_ip'].isalpha():
        return False

    if 'data' in packet:
        try:
            int(packet['data'], 16)
        except (TypeError, ValueError):
            # not a hex string (includes the empty string)
            return False

    return True

def order_keys(hexSessionDict):
    """
    Return the keys of hexSessionDict as a list sorted by the second element
    of each value (the insertion counter stored by read_pcap).
    """
    def insertionRank(sessionKey):
        return hexSessionDict[sessionKey][1]

    return sorted(hexSessionDict.keys(), key=insertionRank)

def read_pcap(path):
    """
    Read a pcap file with tcpdump and group its packets into sessions.

    Runs `tcpdump -nn -tttt -xx -r <path>` and feeds its stdout through
    process_packet. Packets that fail is_clean_packet or carry no 'data' are
    skipped. A session is keyed by ("src_ip:src_port", "dest_ip:dest_port");
    a packet travelling in the opposite direction is appended to the existing
    session for the reversed key.

    Returns a dict mapping session key -> (list of packet hex strings,
    insertion index), where the insertion index records the order in which
    sessions were first seen.

    The command is passed as an argument list without a shell, so paths
    containing spaces or shell metacharacters are handled safely, and the
    child process is always reaped.
    """
    hex_sessions = {}
    proc = subprocess.Popen(['tcpdump', '-nn', '-tttt', '-xx', '-r', path],
                            stdout=subprocess.PIPE)
    insert_num = 0  # keeps track of insertion order into dict
    try:
        for packet in process_packet(proc.stdout):
            if not is_clean_packet(packet) or 'data' not in packet:
                continue

            key = (packet['src_ip']+":"+packet['src_port'], packet['dest_ip']+":"+packet['dest_port'])
            rev_key = (key[1], key[0])
            if key in hex_sessions:
                hex_sessions[key][0].append(packet['data'])
            elif rev_key in hex_sessions:
                hex_sessions[rev_key][0].append(packet['data'])
            else:
                hex_sessions[key] = ([packet['data']], insert_num)
                insert_num += 1
    finally:
        # close the pipe and reap tcpdump even if parsing raised
        proc.stdout.close()
        proc.wait()

    return hex_sessions

def pickleFile(thing2save, file2save2 = None, filePath='/work/notebooks/drawModels/', fileName = 'myModels'):
    """
    Pickle thing2save to disk with the highest cPickle protocol.

    The target is filePath + fileName + '.pickle' unless file2save2 is given,
    in which case it is filePath + file2save2 (no extension is appended).
    filePath is concatenated as-is, so it must end with a path separator.

    NOTE: this function is defined a second time in the configuration cell
    further down; keep the two in sync or remove one.
    """
    if file2save2 is None:
        target = filePath+fileName+'.pickle'
    else:
        target = filePath+file2save2

    # `with` closes the handle even when dump() raises
    with open(target, 'wb') as f:
        cPickle.dump(thing2save, f, protocol=cPickle.HIGHEST_PROTOCOL)
    
def loadFile(filePath):
    """
    Load and return the object pickled at filePath.

    SECURITY: unpickling can execute arbitrary code; only use this on files
    written by this notebook (e.g. via pickleFile), never on untrusted input.

    NOTE: this function is defined a second time in the configuration cell
    further down; keep the two in sync or remove one.
    """
    # `with` closes the handle even when load() raises
    with open(filePath, 'rb') as file2open:
        return cPickle.load(file2open)

def removeBadSessionizer(hexSessionDict, saveFile=False, dataPath=None, fileName=None):
    """
    Drop, in place, every session containing a packet shorter than 80 hex
    characters, and return the same (mutated) dictionary.

    hexSessionDict maps session key -> (list of packet hex strings, ...), as
    produced by read_pcap. A session with an empty packet list is dropped as
    well (previously this raised ValueError).

    When saveFile is true the filtered dictionary is pickled with pickleFile.
    dataPath / fileName are forwarded only when given, so leaving them as
    None falls back to pickleFile's own defaults instead of failing on
    None + str.

    NOTE: this function is defined a second time in the configuration cell
    further down; keep the two in sync or remove one.
    """
    # iterate over a snapshot of the keys because entries are deleted in the loop
    for ses in list(hexSessionDict.keys()):
        packets = hexSessionDict[ses][0]
        if not packets or min(len(pac) for pac in packets) < 80:
            del hexSessionDict[ses]

    if saveFile:
        print('pickling sessions dictionary... mmm')
        saveArgs = {}
        if dataPath is not None:
            saveArgs['filePath'] = dataPath
        if fileName is not None:
            saveArgs['fileName'] = fileName
        pickleFile(hexSessionDict, **saveArgs)

    return hexSessionDict

dataPath = '/data/fs4/datasets/pcaps/gregPcaps/'
dirList = os.listdir('/data/fs4/datasets/pcaps/gregPcaps/')
dirList

complicated = {}

for capture in dirList:
    dictName = capture.split('.')[0]
    
    start = time.time()
    
    hexSessions = read_pcap(dataPath+capture)
    hexSessions = removeBadSessionizer(hexSessions)
    complicated[dictName] = hexSessions
    
    end = time.time()
    
    print dictName + '  is done'
    print 'time to run (secs): ', (end - start)
    
    
pickleFile(complicated, filePath='/data/fs4/home/bradh/', fileName='complicated')

In [None]:
#big file
# Reload the per-capture session dictionary (capture name -> sessions) into compDict.
# NOTE(review): the processing cell pickles 'complicated' under '/data/fs4/home/bradh/',
# while this path is relative -- presumably the notebook runs from that directory; confirm.
# cPickle.load can execute arbitrary code: only load pickles written by this notebook.
with open('complicated.pickle', 'rb') as unhandle:
     compDict= cPickle.load(unhandle)

In [None]:
sess = 0
for di in compDict.keys():
    sess += len(compDict[di].keys())
    print di, "  ", len(compDict[di].keys())

print sess

In [None]:
#%matplotlib inline

# ---- Run configuration (passed to training() / predictClass() in later cells) ----

# Sessions are cropped / zero-padded to this many packets by oneSessionEncoder.
maxPackets = 2
# Each packet is cropped / zero-padded to this many time steps by oneSessionEncoder.
packetTimeSteps = 28
# NOTE(review): loadPrepedData and this dataPath are not read by any code visible in
# this notebook (dataPath here overwrites the pcap directory set in the first cell).
loadPrepedData = True
dataPath = '/data/fs4/home/bradh/bigFlows.pickle'

# Reverse the time order of each packet before padding.
packetReverse = False
# True: pad zeros before the packet data; False: pad zeros after it.
padOldTimeSteps = True

# Prefix of the CSV / pickle files written by training().
runname = 'bakeoff_nomac_noip_noport'
rnnType = 'gru' #gru or lstm

# Standard deviation of the IsotropicGaussian weight initialisers.
wtstd = 0.2
dimIn = 257 #hex has 256 characters + the <EOP> character
dim = 100 #dimension reduction size
# NOTE(review): batch_size is passed to training() but never used there; each step
# samples 5 sessions per entry of sampleList instead.
batch_size = 20
# Length of the one-hot target vector (one class per entry of sampleList).
numClasses = 6
# Gradient-norm clipping threshold handed to Adam as `c`.
clippings = 1

# Number of training steps; training() draws one fresh minibatch per "epoch".
epochs = 5000
lr = 0.0001
# NOTE(review): decay is only turned into an array inside training() and never applied;
# trainPercent is passed through to predictClass but not used there.
decay = 0.9
trainPercent = 0.8

module_logger = logging.getLogger(__name__)

import ast
import json
import subprocess
import sys


def pickleFile(thing2save, file2save2 = None, filePath='/work/notebooks/drawModels/', fileName = 'myModels'):
    """
    Pickle thing2save to disk with the highest cPickle protocol.

    The target is filePath + fileName + '.pickle' unless file2save2 is given,
    in which case it is filePath + file2save2 (no extension is appended).
    filePath is concatenated as-is, so it must end with a path separator.

    NOTE: this repeats the definition from the first cell and shadows it;
    keep the two in sync or remove one.
    """
    if file2save2 is None:
        target = filePath+fileName+'.pickle'
    else:
        target = filePath+file2save2

    # `with` closes the handle even when dump() raises
    with open(target, 'wb') as f:
        cPickle.dump(thing2save, f, protocol=cPickle.HIGHEST_PROTOCOL)
    
def loadFile(filePath):
    """
    Load and return the object pickled at filePath.

    SECURITY: unpickling can execute arbitrary code; only use this on files
    written by this notebook (e.g. via pickleFile), never on untrusted input.

    NOTE: this repeats the definition from the first cell and shadows it;
    keep the two in sync or remove one.
    """
    # `with` closes the handle even when load() raises
    with open(filePath, 'rb') as file2open:
        return cPickle.load(file2open)


def removeBadSessionizer(hexSessionDict, saveFile=False, dataPath=None, fileName=None):
    """
    Drop, in place, every session containing a packet shorter than 80 hex
    characters, and return the same (mutated) dictionary.

    hexSessionDict maps session key -> (list of packet hex strings, ...). A
    session with an empty packet list is dropped as well (previously this
    raised ValueError).

    When saveFile is true the filtered dictionary is pickled with pickleFile.
    dataPath / fileName are forwarded only when given, so leaving them as
    None falls back to pickleFile's own defaults instead of failing on
    None + str.

    NOTE: this repeats the definition from the first cell and shadows it;
    keep the two in sync or remove one.
    """
    # iterate over a snapshot of the keys because entries are deleted in the loop
    for ses in list(hexSessionDict.keys()):
        packets = hexSessionDict[ses][0]
        if not packets or min(len(pac) for pac in packets) < 80:
            del hexSessionDict[ses]

    if saveFile:
        print('pickling sessions')
        saveArgs = {}
        if dataPath is not None:
            saveArgs['filePath'] = dataPath
        if fileName is not None:
            saveArgs['fileName'] = fileName
        pickleFile(hexSessionDict, **saveArgs)

    return hexSessionDict


#Making the hex dictionary

#def dstPortSwapOneOut(hexSessionList):
    #THINK THROUGH    

def oneHot(index, granular = 'hex'):
    """
    Return a float numpy vector of zeros with a 1.0 at position `index`.

    The vector has 257 entries when granular == 'hex' and 17 entries for any
    other value of `granular`.
    """
    vecLen = 257 if granular == 'hex' else 17

    encoded = np.zeros(vecLen)
    encoded[index] = 1.0
    return encoded


def oneSessionEncoder(sessionPackets, hexDict, maxPackets = 2, packetTimeSteps = 100,
                       packetReverse = False, charLevel = False, padOldTimeSteps = True):    
    """
    One-hot encode the first packets of a session into a fixed-size array.

    Parameters
    ----------
    sessionPackets : sequence of packet hex strings (two hex characters per byte).
    hexDict : mapping from two-character hex string to integer token.
    maxPackets : number of packets kept; shorter sessions are zero-padded.
    packetTimeSteps : number of time steps per packet after cropping / padding.
    packetReverse : reverse the time order of each packet before padding.
    charLevel : only selects the width (17 instead of 257) of the zero padding.
        NOTE(review): oneHot is always called with its default 257-wide
        encoding, so charLevel=True looks unsupported -- confirm before use.
    padOldTimeSteps : True pads zeros before the packet, False after it.

    Returns
    -------
    (sessionCollect, packetCollect): sessionCollect is an array in Theano's
    floatX built from maxPackets matrices of packetTimeSteps one-hot rows;
    packetCollect is the list of integer token lists (including the <EOP>
    token) for the packets that were kept.
    """
            
    sessionCollect = []
    packetCollect = []
    
    if charLevel:
        vecLen = 17
    else:
        vecLen = 257
    
    if len(sessionPackets) > maxPackets: #crop the number of packets in the session to maxPackets
        sessionList = copy(sessionPackets[:maxPackets])
    else:
        sessionList = copy(sessionPackets)

    for rawpacket in sessionList:
        packet = copy(rawpacket)
        # keep only hex characters 24..51 of the packet
        # NOTE(review): presumably this skips the leading MAC addresses and stops before the
        # IP addresses (cf. run name 'bakeoff_nomac_noip_noport') -- confirm the offsets.
        packet = packet[24:52]
        #packet = packet[32:36]+packet[44:46]+packet[46:48]+packet[52:60]+packet[60:68]\
        #+packet[68:70]+packet[70:72]+packet[72:74]
        # translate each complete pair of hex characters into its integer token
        packet = [hexDict[packet[i:i+2]] for i in xrange(0,len(packet)-2+1,2)]
            
        if len(packet) >= packetTimeSteps: #crop packet to length packetTimeSteps
            packet = packet[:packetTimeSteps]
            packet = packet+[256] #add <EOP> end of packet token
        else:
            packet = packet+[256] #add <EOP> end of packet token
        
        packetCollect.append(packet)
        
        pacMat = np.array([oneHot(x) for x in packet]) #one hot encoding of packet into a matrix
        pacMatLen = len(pacMat)
        
        #padding packet
        if packetReverse:
            pacMat = pacMat[::-1]

        if pacMatLen < packetTimeSteps:
            #padOldTimeSteps=True: stack zeros on top of the data so that earlier timesteps carry no information
            #padOldTimeSteps=False: stack zeros after the actual info
            if padOldTimeSteps:
                pacMat = np.vstack( ( np.zeros((packetTimeSteps-pacMatLen,vecLen)), pacMat) ) 
            else:
                pacMat = np.vstack( (pacMat, np.zeros((packetTimeSteps-pacMatLen,vecLen))) ) 

        # a packet cropped above has packetTimeSteps+1 rows once <EOP> is added; this trims the
        # last row again (the <EOP> row unless the packet was reversed)
        if pacMatLen > packetTimeSteps:
            pacMat = pacMat[:packetTimeSteps, :]

        sessionCollect.append(pacMat)

    #padding session
    sessionCollect = np.asarray(sessionCollect, dtype=theano.config.floatX)
    numPacketsInSession = sessionCollect.shape[0]
    if numPacketsInSession < maxPackets:
        #pad the session with all-zero packets up to maxPackets
        sessionCollect = np.vstack( (sessionCollect,np.zeros((maxPackets-numPacketsInSession, 
                                                             packetTimeSteps, vecLen))) )
    
    return sessionCollect, packetCollect


# # Learning functions

# In[14]:

def floatX(X):
    """Return X as a numpy array in Theano's configured float dtype."""
    targetDtype = theano.config.floatX
    return np.asarray(X, dtype=targetDtype)

def dropout(X, p=0.):
    """
    Apply inverted dropout to X: drop each element with probability p and
    scale the survivors by 1 / (1 - p). Returns X unchanged when p == 0.

    The random mask comes from the notebook-level stream `srng`. No cell of
    this notebook defines `srng`, so previously any call with p != 0 raised
    NameError; a default MRG_RandomStreams is now created on first use when
    the notebook has not defined one.
    """
    if p == 0:
        return X

    global srng
    try:
        stream = srng
    except NameError:
        from theano.sandbox.rng_mrg import MRG_RandomStreams
        srng = stream = MRG_RandomStreams()

    retain_prob = 1 - p
    return X / retain_prob * stream.binomial(X.shape, p=retain_prob, dtype=theano.config.floatX)

# Gradient clipping
def clip_norm(g, c, n): 
    '''Rescale gradient g by c/n when the norm n reaches the threshold c.

    g is the gradient, c the threshold and n the norm. When c is not
    positive, clipping is disabled and g is returned untouched.
    '''
    if not (c > 0):
        return g
    return T.switch(T.ge(n, c), g*c/n, g)

def clip_norms(gs, c):
    """
    Clip a list of gradients by their joint norm: the norm is the square root
    of the summed squares of all gradients in gs, and each gradient is passed
    through clip_norm with that norm and threshold c.
    """
    squaredTotal = sum([T.sum(grad**2) for grad in gs])
    jointNorm = T.sqrt(squaredTotal)

    clipped = []
    for grad in gs:
        clipped.append(clip_norm(grad, c, jointNorm))
    return clipped

# Regularizers
def max_norm(p, maxnorm = 0.):
    """
    Max-norm constraint: rescale p so that its norms taken along axis 0 do
    not exceed maxnorm. A non-positive maxnorm disables the constraint and
    returns p unchanged.
    """
    if not (maxnorm > 0):
        return p

    currentNorms = T.sqrt(T.sum(T.sqr(p), axis=0))
    allowedNorms = T.clip(currentNorms, 0, maxnorm)
    # 1e-7 guards against division by zero for all-zero columns
    return p * (allowedNorms/ (1e-7 + currentNorms))

def gradient_regularize(p, g, l1 = 0., l2 = 0.):
    """
    Add L2 (p * l2) and L1 (sign(p) * l1) penalty terms to gradient g and
    return the result. With the default l1 = l2 = 0 the added terms are zero.

    NOTE(review): `+=` is used, so if g is an object that supports in-place
    addition (e.g. a numpy array) the caller's g is modified as well.
    """
    g += p * l2
    g += T.sgn(p) * l1
    return g

def weight_regularize(p, maxnorm = 0.):
    """Apply the max_norm constraint to p (no-op for a non-positive maxnorm)."""
    return max_norm(p, maxnorm)

def Adam(params, cost, lr=0.0002, b1=0.1, b2=0.001, e=1e-8, l1 = 0., l2 = 0., maxnorm = 0., c = 8):
    """
    Build the list of Theano updates for an Adam-style optimiser.

    Note the argument order (params, cost) -- RMSprop below takes (cost, params).

    b1 and b2 are the weights given to the *new* gradient and squared
    gradient in the moving averages (m_t = b1*g + (1-b1)*m), i.e. they play
    the role of (1 - beta) in the usual Adam notation. Gradients are first
    clipped by their joint norm with threshold c; l1 / l2 penalties are added
    to the scaled step and maxnorm is applied to the updated parameter.

    Returns a list of (shared variable, new value) pairs covering the moment
    estimates, the parameters and the step counter.
    """
    
    updates = []
    grads = T.grad(cost, params)
    grads = clip_norms(grads, c)
    
    # shared step counter used for the bias-correction terms
    i = theano.shared(floatX(0.))
    i_t = i + 1.
    fix1 = 1. - b1**(i_t)
    fix2 = 1. - b2**(i_t)
    lr_t = lr * (T.sqrt(fix2) / fix1)
    
    for p, g in zip(params, grads):
        # first / second moment accumulators, initialised to zeros shaped like p
        m = theano.shared(p.get_value() * 0.)
        v = theano.shared(p.get_value() * 0.)
        m_t = (b1 * g) + ((1. - b1) * m)
        v_t = (b2 * T.sqr(g)) + ((1. - b2) * v)
        g_t = m_t / (T.sqrt(v_t) + e)
        # regularisation is added to the already-scaled step, not to the raw gradient
        g_t = gradient_regularize(p, g_t, l1=l1, l2=l2)
        p_t = p - (lr_t * g_t)
        p_t = weight_regularize(p_t, maxnorm=maxnorm)
        
        updates.append((m, m_t))
        updates.append((v, v_t))
        updates.append((p, p_t))
    
    updates.append((i, i_t))
    
    return updates

def RMSprop(cost, params, lr = 0.001, l1 = 0., l2 = 0., maxnorm = 0., rho=0.9, epsilon=1e-6, c = 8):
    """
    Build the list of Theano updates for RMSprop.

    Gradients of cost w.r.t. params are clipped by their joint norm with
    threshold c, regularised with l1 / l2, and divided by the square root of
    a running average (decay rho) of their squares plus epsilon. maxnorm is
    applied to each updated parameter.

    Returns a list of (shared variable, new value) pairs: for every parameter
    first its running average, then the parameter itself.
    """
    clippedGrads = clip_norms(T.grad(cost, params), c)

    updates = []
    for param, grad in zip(params, clippedGrads):
        grad = gradient_regularize(param, grad, l1 = l1, l2 = l2)

        runningAvg = theano.shared(param.get_value() * 0.)
        newAvg = rho * runningAvg + (1 - rho) * grad ** 2

        newParam = param - lr * (grad / T.sqrt(newAvg + epsilon))
        newParam = weight_regularize(newParam, maxnorm = maxnorm)

        updates.append((runningAvg, newAvg))
        updates.append((param, newParam))
    return updates


# # Training functions
def predictClass(predictFun, sampleList, compDict, hexDict,
                 numClasses = 6, trainPercent = 0.8, dimIn=257, maxPackets=2,
                 packetTimeSteps = 16, padOldTimeSteps=True):
    
    testCollect = []
    predtargets = []
    actualtargets = []
        
    trainingSessions = []
    trainingTargets = []
    
    for d in range(len(sampleList)):
        sampleLen = len(compDict[sampleList[d]].keys()) #num sessions in a dictionary
        sampleKeys = compDict[sampleList[d]].keys()[-400:]
        for key in sampleKeys:
            oneEncoded = oneSessionEncoder(compDict[sampleList[d]][key][0],
                                                      hexDict = hexDict,
                                                      packetReverse=packetReverse, 
                                                      padOldTimeSteps = padOldTimeSteps, 
                                                      maxPackets = maxPackets, 
                                                      packetTimeSteps = packetTimeSteps)
            trainingSessions.append(oneEncoded[0])
            trainIndex = [0]*numClasses
            trainIndex[d] = 1
            trainingTargets.append(trainIndex)

    sessionsMinibatch = np.asarray(trainingSessions, dtype=theano.config.floatX).reshape((-1, packetTimeSteps, 1, dimIn))
    targetsMinibatch = np.asarray(trainingTargets, dtype=theano.config.floatX)

    predcostfun = predictFun(sessionsMinibatch)
    testCollect.append(np.mean(np.argmax(predcostfun,axis=1) == np.argmax(targetsMinibatch, axis=1)))

    predtargets = np.argmax(predcostfun,axis=1)
    actualtargets = np.argmax(targetsMinibatch, axis=1)

    print "TEST accuracy:         ", np.mean(testCollect)
    print

    return actualtargets, predtargets, np.mean(testCollect)


def binaryPrecisionRecall(predictions, targets, numClasses = 6):
    """
    Print one-vs-rest precision, recall and F1 for classes 0..numClasses-1.

    predictions and targets are array-likes of class indices with the same
    number of elements (they are flattened before comparison). A small
    constant (0.00001) is added to every denominator, so an absent class
    yields 0.0 instead of a division by zero. Nothing is returned.
    """
    # flatten once; the inputs do not change between classes
    predictions = np.asarray(predictions).flatten()
    targets = np.asarray(targets).flatten()

    for cla in range(numClasses):
        predIsCla = predictions == cla
        targetIsCla = targets == cla

        truePos = int(np.sum(predIsCla & targetIsCla))
        falsePos = int(np.sum(predIsCla & ~targetIsCla))
        falseNeg = int(np.sum(~predIsCla & targetIsCla))

        precision  = float(truePos)/(truePos + falsePos + 0.00001) #1 - (how much junk did we give user)
        recall = float(truePos)/(truePos + falseNeg + 0.00001) #1 - (how much good stuff did we miss)
        f1 = 2*((precision*recall)/(precision+recall+0.00001))

        # single-argument print(...) produces the same text as the former print statements
        print('class '+str(cla)+' precision: ' + ' ' + str(precision))
        print('class '+str(cla)+' recall:    ' + ' ' + str(recall))
        print('class '+str(cla)+' f1:        ' + ' ' + str(f1))
        print('')
    

In [None]:
sampleList = ['NESTthermostat-nf-10days-96bytes',
 'a-printers-24hrs-96bytes-E-VA-SRV-FW1A-2016-08-09_14-17-vlan34',
 'SonySmartTV-nf-10days-96bytes',
 'a-fs-24hrs-96bytes-E-ASH-SRV-FW1A-2016-08-09_18-17-vlan40',
 'TiVoSeries4-nf-10days-96bytes',
 'b-dc-24hrs-96bytes-E-QD-SRV-FW1A-2016-08-09_18-16-vlan210']

hexDict = hexTokenizer()
trainingTargets = []
trainingSessions = []
for d in range(len(sampleList)):
    sampleLen = len(compDict[sampleList[d]].keys())
    sampleKeys = random.sample(compDict[sampleList[d]].keys()[:sampleLen], 5)

    for key in sampleKeys:
        oneEncoded = oneSessionEncoder(compDict[sampleList[d]][key][0],
                                                  hexDict = hexDict,
                                                  packetReverse=packetReverse, 
                                                  padOldTimeSteps = padOldTimeSteps, 
                                                  maxPackets = maxPackets, 
                                                  packetTimeSteps = packetTimeSteps)
        trainIndex = [0]*numClasses
        trainIndex[d] = 1
        trainingTargets.append(trainIndex)
        trainingSessions.append(oneEncoded[0])


In [None]:
#Making the hex dictionary
def hexTokenizer():
    hexstring = '0,	1,	2,	3,	4,	5,	6,	7,	8,	9,	A,	B,	C,	D,	E,	F,	10,	11,	12,	13,	14,	15,	16,	17,	18,	19\
    ,	1A,	1B,	1C,	1D,	1E,	1F,	20,	21,	22,	23,	24,	25,	26,	27,	28,	29,	2A,	2B,	2C,	2D,	2E,	2F,	30,	31,	32,	33,	34,	35\
    ,	36,	37,	38,	39,	3A,	3B,	3C,	3D,	3E,	3F,	40,	41,	42,	43,	44,	45,	46,	47,	48,	49,	4A,	4B,	4C,	4D,	4E,	4F,	50,	51\
    ,	52,	53,	54,	55,	56,	57,	58,	59,	5A,	5B,	5C,	5D,	5E,	5F,	60,	61,	62,	63,	64,	65,	66,	67,	68,	69,	6A,	6B,	6C,	6D\
    ,	6E,	6F,	70,	71,	72,	73,	74,	75,	76,	77,	78,	79,	7A,	7B,	7C,	7D,	7E,	7F,	80,	81,	82,	83,	84,	85,	86,	87,	88,	89\
    ,	8A,	8B,	8C,	8D,	8E,	8F,	90,	91,	92,	93,	94,	95,	96,	97,	98,	99,	9A,	9B,	9C,	9D,	9E,	9F,	A0,	A1,	A2,	A3,	A4,	A5\
    ,	A6,	A7,	A8,	A9,	AA,	AB,	AC,	AD,	AE,	AF,	B0,	B1,	B2,	B3,	B4,	B5,	B6,	B7,	B8,	B9,	BA,	BB,	BC,	BD,	BE,	BF,	C0,	C1\
    ,	C2,	C3,	C4,	C5,	C6,	C7,	C8,	C9,	CA,	CB,	CC,	CD,	CE,	CF,	D0,	D1,	D2,	D3,	D4,	D5,	D6,	D7,	D8,	D9,	DA,	DB,	DC,	DD\
    ,	DE,	DF,	E0,	E1,	E2,	E3,	E4,	E5,	E6,	E7,	E8,	E9,	EA,	EB,	EC,	ED,	EE,	EF,	F0,	F1,	F2,	F3,	F4,	F5,	F6,	F7,	F8,	F9\
    ,	FA,	FB,	FC,	FD,	FE,	FF'.replace('\t', '')

    hexList = [x.strip() for x in hexstring.lower().split(',')]
    hexList.append('<EOP>') #End Of Packet token
    #EOS token??????
    hexDict = {}

    for key, val in enumerate(hexList):
        if len(val) == 1:
            val = '0'+val
        hexDict[val] = key  #dictionary k=hex, v=int  
    
    return hexDict

In [None]:
def training(runname, rnnType, maxPackets, packetTimeSteps, packetReverse, padOldTimeSteps, wtstd, 
             lr, decay, clippings, dimIn, dim, numClasses, batch_size, epochs, 
             trainPercent):
    print locals()
    print
    
    X = T.tensor4('inputs')
    Y = T.matrix('targets')
    linewt_init = IsotropicGaussian(wtstd)
    line_bias = Constant(1.0)
    rnnwt_init = IsotropicGaussian(wtstd)
    rnnbias_init = Constant(0.0)
    classifierWts = IsotropicGaussian(wtstd)

    learning_rateClass = theano.shared(np.array(lr, dtype=theano.config.floatX))
    learning_decay = np.array(decay, dtype=theano.config.floatX)
    
    hexDict = hexTokenizer()
    ###DATA PREP
     
    print 'initializing network graph'
    ###ENCODER
    if rnnType == 'gru':
        rnn = GatedRecurrent(dim=dim, weights_init = rnnwt_init, biases_init = rnnbias_init, name = 'gru')
        dimMultiplier = 2
    else:
        rnn = LSTM(dim=dim, weights_init = rnnwt_init, biases_init = rnnbias_init, name = 'lstm')
        dimMultiplier = 4

    fork = Fork(output_names=['linear', 'gates'],
                name='fork', input_dim=dimIn, output_dims=[dim, dim * dimMultiplier], 
                weights_init = linewt_init, biases_init = line_bias)

    ###CONTEXT
    if rnnType == 'gru':
        rnnContext = GatedRecurrent(dim=dim, weights_init = rnnwt_init, 
                                    biases_init = rnnbias_init, name = 'gruContext')
    else:
        rnnContext = LSTM(dim=dim, weights_init = rnnwt_init, biases_init = rnnbias_init, 
                          name = 'lstmContext')

    forkContext = Fork(output_names=['linearContext', 'gatesContext'],
                name='forkContext', input_dim=dim, output_dims=[dim, dim * dimMultiplier], 
                weights_init = linewt_init, biases_init = line_bias)

    forkDec = Fork(output_names=['linear', 'gates'],
                name='forkDec', input_dim=dim, output_dims=[dim, dim*dimMultiplier], 
                weights_init = linewt_init, biases_init = line_bias)

    #CLASSIFIER
    bmlp = BatchNormalizedMLP( activations=[Logistic(),Logistic()], 
               dims=[dim, dim, numClasses],
               weights_init=classifierWts,
               biases_init=Constant(0.0001) )

    #initialize the weights in all the functions
    fork.initialize()
    rnn.initialize()
    forkContext.initialize()
    rnnContext.initialize()
    forkDec.initialize()
    bmlp.initialize()

    def onestepEnc(X):
        data1, data2 = fork.apply(X) 

        if rnnType == 'gru':
            hEnc = rnn.apply(data1, data2) 
        else:
            hEnc, _ = rnn.apply(data2)

        return hEnc

    hEnc, _ = theano.scan(onestepEnc, X) #(mini*numPackets, packetLen, 1, hexdictLen)
    hEncReshape = T.reshape(hEnc[:,-1], (-1, maxPackets, 1, dim)) #[:,-1] takes the last rep for each packet
                                                                 #(mini, numPackets, 1, dimReduced)
    def onestepContext(hEncReshape):

        data3, data4 = forkContext.apply(hEncReshape)

        if rnnType == 'gru':
            hContext = rnnContext.apply(data3, data4)
        else:
            hContext, _ = rnnContext.apply(data4)

        return hContext

    hContext, _ = theano.scan(onestepContext, hEncReshape)
    hContextReshape = T.reshape(hContext[:,-1], (-1,dim))

    data5, _ = forkDec.apply(hContextReshape)

    pyx = bmlp.apply(data5)
    softmax = Softmax()
    softoutClass = softmax.apply(pyx)
    costClass = T.mean(CategoricalCrossEntropy().apply(Y, softoutClass))

    #CREATE GRAPH
    cgClass = ComputationGraph([costClass])
    paramsClass = VariableFilter(roles = [PARAMETER])(cgClass.variables)
    updatesClass = Adam(paramsClass, costClass, learning_rateClass, c=clippings) 
    #updatesClass = RMSprop(costClass, paramsClass, learning_rateClass, c=clippings)

    #print 'grad compiling'
    #gradients = T.grad(costClass, paramsClass)
    #gradients = clip_norms(gradients, clippings)
    #gradientFun = theano.function([X,Y], gradients, allow_input_downcast=True)
    #print 'finish with grads'

    print 'compiling graph you talented soul'
    classifierTrain = theano.function([X,Y], [costClass, hEnc, hContext, pyx, softoutClass], 
                                      updates=updatesClass, allow_input_downcast=True)
    classifierPredict = theano.function([X], softoutClass, allow_input_downcast=True)
    print 'finished compiling'

    epochCost = []
    gradNorms = []
    trainAcc = []
    testAcc = []

    costCollect = []
    trainCollect = []

    print 'training begins'
    iteration = 0
    #epoch
    for epoch in xrange(epochs):

        #iteration/minibatch
        #for start, end in zip(range(0, trainIndex,batch_size),
        #                      range(batch_size, trainIndex, batch_size)):
        
        trainingTargets = []
        trainingSessions = []
        
        for d in range(len(sampleList)):
            sampleLen = len(compDict[sampleList[d]].keys())
            sampleKeys = random.sample(compDict[sampleList[d]].keys()[:sampleLen], 5)
                        
            for key in sampleKeys:
                oneEncoded = oneSessionEncoder(compDict[sampleList[d]][key][0],
                                                          hexDict = hexDict,
                                                          packetReverse=packetReverse, 
                                                          padOldTimeSteps = padOldTimeSteps, 
                                                          maxPackets = maxPackets, 
                                                          packetTimeSteps = packetTimeSteps)
                trainIndex = [0]*numClasses
                trainIndex[d] = 1
                trainingTargets.append(trainIndex)
                trainingSessions.append(oneEncoded[0])

        sessionsMinibatch = np.asarray(trainingSessions).reshape((-1, packetTimeSteps, 1, dimIn))
        targetsMinibatch = np.asarray(trainingTargets)

        costfun = classifierTrain(sessionsMinibatch, targetsMinibatch)

        costCollect.append(costfun[0])
        trainCollect.append(np.mean(np.argmax(costfun[-1],axis=1) == np.argmax(targetsMinibatch, axis=1)))

        iteration+=1

        if iteration == 1:
            print 'you are amazing'


        if iteration%200 == 0:
            print
            print '   Iteration: ', iteration
            print '   Cost: ', np.mean(costCollect[-20:])
            print '   TRAIN accuracy: ', np.mean(trainCollect[-20:])
            print

            #grads = gradientFun(sessionsMinibatch, targetsMinibatch)
            #for gra in grads:
            #    print '  gradient norms: ', np.linalg.norm(gra)

            np.savetxt('/data/fs4/home/bradh/outputs/'+runname+"_TRAIN.csv", trainCollect[::50], delimiter=",")
            np.savetxt('/data/fs4/home/bradh/outputs/'+runname+"_COST.csv", costCollect[::50], delimiter=",")

        #testing accuracy
        if iteration%500 == 0:
            predtar, acttar, testCollect = predictClass(classifierPredict, sampleList, compDict, hexDict,
                                                        numClasses, trainPercent, dimIn, 
                                                        maxPackets,packetTimeSteps, padOldTimeSteps)

            binaryPrecisionRecall(predtar, acttar)

            testAcc.append(testCollect)
            np.savetxt('/data/fs4/home/bradh/outputs/'+runname+"_TEST.csv", testAcc, delimiter=",")

        #save the models
        if iteration%500 == 0:
            #pickleFile(classifierTrain, filePath='/data/fs4/home/bradh/outputs/',
            #            fileName=runname+'TRAIN'+str(iteration))
            pickleFile(classifierPredict, filePath='/data/fs4/home/bradh/outputs/',
                        fileName=runname+'PREDICT'+str(iteration))

    #epochCost.append(np.mean(costCollect[-50:]))
    #trainAcc.append(np.mean(trainCollect[-50:]))

    #print 'Epoch: ', epoch
    #module_logger.debug('Epoch:%r',epoch)
    #print 'Epoch cost average: ', epochCost[-1]
    #print 'Epoch TRAIN accuracy: ', trainAcc[-1]
    
    return classifierPredict, classifierTrain

In [None]:
#TODO: expose classifier dim
# training() returns (classifierPredict, classifierTrain), in that order.
classifierPredict, classifierTrain = training(runname, rnnType, maxPackets, packetTimeSteps, packetReverse, padOldTimeSteps, wtstd, 
             lr, decay, clippings, dimIn, dim, numClasses, batch_size, epochs, 
             trainPercent)
# Backward-compatible aliases: this cell used to unpack the result as `train, predict`,
# so `train` is the *prediction* function and `predict` the *training* function.
# The cells below rely on that binding; prefer the classifier* names in new code.
train, predict = classifierPredict, classifierTrain

In [None]:
hexDict = hexTokenizer()
# `train` is bound to the compiled prediction function (training() returns the
# prediction function first), which is what predictClass needs.
# predictClass returns (actual, predicted, accuracy): unpack in that order so that
# predtar really holds the predictions and acttar the true classes.
acttar, predtar, testCollect = predictClass(train,sampleList, compDict, hexDict,
            numClasses, trainPercent, dimIn, 
            maxPackets,packetTimeSteps, padOldTimeSteps)

In [None]:
# Per-class precision / recall / F1; binaryPrecisionRecall expects (predictions, targets).
# NOTE(review): predictClass returns (actual, predicted, accuracy) -- make sure the
# unpacking in the cell above leaves the predictions in predtar.
binaryPrecisionRecall(predtar, acttar)

In [None]:
# Spot check: predicted class indices for 10 random sessions of class d.
d=5
sampleLen = len(compDict[sampleList[d]].keys())
sampleKeys = random.sample(compDict[sampleList[d]].keys()[:sampleLen], 10)                      
trainingTargets = []
trainingSessions = []
for key in sampleKeys:
    oneEncoded = oneSessionEncoder(compDict[sampleList[d]][key][0],
                                              hexDict = hexDict,
                                              packetReverse=packetReverse, 
                                              padOldTimeSteps = padOldTimeSteps, 
                                              maxPackets = maxPackets, 
                                              packetTimeSteps = packetTimeSteps)
    trainIndex = [0]*numClasses
    trainIndex[d] = 1
    trainingTargets.append(trainIndex)
    trainingSessions.append(oneEncoded[0])
# reshape with the configured sizes instead of the literals 28 / 257 so this cell
# keeps working when packetTimeSteps or dimIn change
sessionsMinibatch = np.asarray(trainingSessions).reshape((-1, packetTimeSteps, 1, dimIn))    
# `train` is bound to the compiled prediction function (first value returned by training())
np.argmax(train(sessionsMinibatch), axis = 1)

In [None]:
# Display the (encoded array, token lists) pair returned for a single session.
# NOTE: relies on `d` and `key` left over from the loop in the previous cell, so that
# cell must have been run first.
oneSessionEncoder(compDict[sampleList[d]][key][0],
                                                          hexDict = hexDict,
                                                          packetReverse=packetReverse, 
                                                          padOldTimeSteps = padOldTimeSteps, 
                                                          maxPackets = maxPackets, 
                                                          packetTimeSteps = packetTimeSteps)