In [None]:
import ipaddress
import math
import random

import numpy as np

from modules import Rule


def load_ruleset(fname, except_zero=True, random_priority=0):
    """Load a ruleset from a ClassBench filter file.

    Expected line format (tab-separated)::

        @sip/len  dip/len  sp_low : sp_high  dp_low : dp_high  proto/mask  ...

    Args:
        fname: Path of the ClassBench filter file.
        except_zero: If true, skip every rule whose source or destination
            network address is 0.0.0.0 (any ``0.0.0.0/n``); such rules make
            the later space cutting very expensive.
        random_priority: If non-zero, each rule gets a random priority from
            ``random.randint(0, random_priority)`` (both ends inclusive).
            If 0, the rule's zero-based line number is its priority.

    Returns:
        A list of ``Rule`` objects with ``sip_low``/``sip_high``,
        ``dip_low``/``dip_high``, ``sp_low``/``sp_high``,
        ``dp_low``/``dp_high`` (inclusive integer bounds),
        ``protocol_val``/``protocol_mask`` and ``priority`` set.
    """
    ruleset = []
    with open(fname, 'r') as f:
        for n, line in enumerate(f):
            line = line.strip()
            if not line:
                # Tolerate blank (e.g. trailing) lines instead of failing
                # with IndexError on tok[1].
                continue
            tok = line.split('\t')
            rule = Rule()
            # tok[0] carries a leading '@'. strict=False masks off any host
            # bits (10.0.0.5/24 -> 10.0.0.0/24) instead of raising ValueError.
            sip = ipaddress.ip_network(tok[0][1:], strict=False)
            dip = ipaddress.ip_network(tok[1], strict=False)
            sp = tok[2].split(':')
            dp = tok[3].split(':')
            protocol = tok[4].split('/')

            if except_zero and (int(sip.network_address) == 0
                                or int(dip.network_address) == 0):
                continue

            # network_address is the LOWEST address covered by the prefix,
            # broadcast_address the HIGHEST.
            rule.sip_low = int(sip.network_address)
            rule.sip_high = int(sip.broadcast_address)
            rule.dip_low = int(dip.network_address)
            rule.dip_high = int(dip.broadcast_address)
            rule.sp_low, rule.sp_high = int(sp[0]), int(sp[1])
            rule.dp_low, rule.dp_high = int(dp[0]), int(dp[1])
            # Protocol value and mask are parsed as base-16 integers.
            rule.protocol_val, rule.protocol_mask = int(protocol[0], 16), int(protocol[1], 16)

            if random_priority:
                rule.priority = random.randint(0, random_priority)
            else:
                rule.priority = n
            ruleset.append(rule)
    return ruleset

def cut2points(ruleset):
    """Collect the sorted, de-duplicated cut points of all four dimensions.

    For the IP dimensions each rule contributes its low and high bound as-is.
    For the port dimensions a single-port rule (low == high) contributes
    ``low`` and ``high + 1`` so that it spans one non-empty interval; any
    other port range contributes ``low`` and ``high``.

    Returns:
        Four sorted lists ``(sip, dip, sp, dp)`` of unique cut points.
    """
    sip_points, dip_points = set(), set()
    sp_points, dp_points = set(), set()
    for rule in ruleset:
        sip_points.update((rule.sip_low, rule.sip_high))
        dip_points.update((rule.dip_low, rule.dip_high))

        sp_points.add(rule.sp_low)
        sp_points.add(rule.sp_high + 1 if rule.sp_low == rule.sp_high else rule.sp_high)

        dp_points.add(rule.dp_low)
        dp_points.add(rule.dp_high + 1 if rule.dp_low == rule.dp_high else rule.dp_high)
    return sorted(sip_points), sorted(dip_points), sorted(sp_points), sorted(dp_points)

def list2mapping(a):
    """Return a dict mapping each value in *a* to its position.

    If a value occurs more than once, its last position wins.
    """
    return {value: position for position, value in enumerate(a)}

def get_point_index(rule, sip_map, dip_map, sp_map, dp_map):
    """Translate a rule's eight range endpoints into cut-point indices.

    Each ``*_map`` is a dict from cut-point value to its position. A
    ``KeyError`` propagates if an endpoint is not a key of its map.

    Returns:
        Tuple ``(sip_low, sip_high, dip_low, dip_high, sp_low, sp_high,
        dp_low, dp_high)`` of indices.
    """
    # (rule attribute, lookup table), in the order of the returned tuple.
    lookup_plan = (
        ("sip_low", sip_map),
        ("sip_high", sip_map),
        ("dip_low", dip_map),
        ("dip_high", dip_map),
        ("sp_low", sp_map),
        ("sp_high", sp_map),
        ("dp_low", dp_map),
        ("dp_high", dp_map),
    )
    return tuple(table[getattr(rule, attr)] for attr, table in lookup_plan)

def reverse_map(dict_ori):
    """Return a new dict with the keys and values of *dict_ori* swapped.

    When several keys share a value, the key that comes last in iteration
    order wins.
    """
    return dict(zip(dict_ori.values(), dict_ori.keys()))

In [None]:
def doit(index, path_template="../data/fw filters/MyFilters{}k_1.txt"):
    """Load one ClassBench ruleset and score how heavily its rules overlap.

    The ruleset is read from ``path_template.format(index)``. All rules are
    kept (``except_zero=False``). Each of the four dimensions is cut at the
    rules' boundary points, and every rule is painted into the resulting 4-D
    grid of cells. Every cell covered by more than one rule then adds
    ``hits * log(cell volume)`` to the returned weight.

    Args:
        index: Substituted into ``path_template`` (used by the caller as the
            ruleset size in thousands).
        path_template: ``str.format`` template for the filter-file path.

    Returns:
        The overlap weight as a Python ``float``; 0.0 for an empty ruleset.
    """
    ruleset = load_ruleset(path_template.format(index), False)
    if not ruleset:
        # No cut points at all; np.zeros would otherwise see -1 dimensions.
        return 0.0

    sip, dip, sp, dp = cut2points(ruleset)  # four sorted lists of cut points
    print("current division of 4 dimension is: sip: {} , dip: {} , sp: {} ,dp: {}".format(
        len(sip), len(dip), len(sp), len(dp)))

    hit_mat = _count_rule_hits(ruleset, sip, dip, sp, dp)
    print("the space has shape of: {}".format(hit_mat.shape))

    return _overlap_weight(hit_mat, sip, dip, sp, dp)


def _count_rule_hits(ruleset, sip, dip, sp, dp):
    """Return a 4-D array counting how many rules cover each grid cell."""
    # Dict lookups give O(1) value -> cut-point-index translation per rule.
    sip_map, dip_map = list2mapping(sip), list2mapping(dip)
    sp_map, dp_map = list2mapping(sp), list2mapping(dp)

    # n cut points delimit n-1 intervals per dimension. No cell can be hit
    # more than len(ruleset) times, so this dtype can never wrap around
    # (a fixed uint8 silently overflows once >255 rules overlap a cell).
    counter_dtype = np.min_scalar_type(len(ruleset))
    hit_mat = np.zeros((len(sip) - 1, len(dip) - 1, len(sp) - 1, len(dp) - 1),
                       dtype=counter_dtype)

    for rule in ruleset:
        (sip_lo, sip_hi, dip_lo, dip_hi,
         sp_lo, sp_hi, dp_lo, dp_hi) = get_point_index(rule, sip_map, dip_map, sp_map, dp_map)
        # IP slices stop before the interval starting at the rule's upper
        # bound. NOTE(review): a rule with sip_low == sip_high (e.g. a /32)
        # therefore gets an empty slice and adds no hits -- confirm intended.
        # Port slices include the interval starting at sp_hi/dp_hi, so a
        # single-port rule p (cut points p and p+1) covers cell [p, p+1).
        hit_mat[sip_lo:sip_hi, dip_lo:dip_hi, sp_lo:sp_hi + 1, dp_lo:dp_hi + 1] += 1
    return hit_mat


def _overlap_weight(hit_mat, sip, dip, sp, dp):
    """Sum ``hits * log(cell volume)`` over the cells hit by more than one rule."""
    # Interval widths along each axis. Cut points are sorted and unique, so
    # every width is >= 1 and its log is finite and non-negative.
    log_sip = np.log(np.diff(np.asarray(sip, dtype=np.float64)))
    log_dip = np.log(np.diff(np.asarray(dip, dtype=np.float64)))
    log_sp = np.log(np.diff(np.asarray(sp, dtype=np.float64)))
    log_dp = np.log(np.diff(np.asarray(dp, dtype=np.float64)))

    # log(a*b*c*d) == log a + log b + log c + log d, so the (sip, dip, sp)
    # log-volume grid is built once by broadcasting. The dp axis is handled
    # one slice at a time to keep peak memory at a single 3-D grid.
    log_vol_3d = log_sip[:, None, None] + log_dip[None, :, None] + log_sp[None, None, :]

    weight = 0.0
    for dp_ind, log_dp_width in enumerate(log_dp):
        hit_split = hit_mat[:, :, :, dp_ind]
        overlapped = hit_split > 1
        weight += float(np.sum(hit_split[overlapped] * (log_vol_3d[overlapped] + log_dp_width)))
    return weight

In [None]:
import time

# Score rulesets MyFilters1k_1 .. MyFilters10k_1 and time each run.
for i in range(1, 11):
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # high-resolution clock intended for measuring elapsed intervals.
    start = time.perf_counter()
    weight = doit(i)
    print(i, weight)
    print("time elapsed: {}".format(time.perf_counter() - start))