In [1]:
#spark
__author__ = 'vcoder'

import os
import sys

#change them when spark and pyspark path changes

#SPARK_HOME = "/home/vcoder/EDA/spark-1.5.0"
#SPARK_HOME_PYTHON = SPARK_HOME + "/python"

SPARK_HOME = "/home/worker/software/spark-1.5.0"
SPARK_HOME_PYTHON = SPARK_HOME + "/python"

os.environ['SPARK_HOME'] = SPARK_HOME
sys.path.append(SPARK_HOME_PYTHON)

from pyspark import SparkContext
from pyspark import SparkConf

import geoip2.database
reader = geoip2.database.Reader('GeoLite2-City.mmdb')

def ip2city(ip):
    try:
        city = reader.city(ip).city.name
    except:
        city = 'not found'
    return city

def ip2la(ip):
    try:
        la = reader.city(ip).location.latitude
    except:
        la = 0.
    return la

def ip2lo(ip):
    try:
        lo = reader.city(ip).location.longitude
    except:
        lo = 0.
    return lo


In [2]:
sc = SparkContext('local', 'ip2geo')

In [3]:
normalPath = 'csv/normal.csv'
attackPath = 'csv/attack.csv'

In [4]:
normalRaw = sc.textFile(normalPath).map(lambda x: x.split()).filter(lambda x: len(x) == 3) #delimiter may need to change
# print normalRaw.take(10)
attackRaw = sc.textFile(attackPath).map(lambda x: x.split()).filter(lambda x: len(x) == 3)
# print attackRaw.take(10)
normalMax = normalRaw.map(lambda x: float(x[0])).max()
# print normalMax
normalMin = normalRaw.map(lambda x: float(x[0])).min()
# print normalMin
attackMax = attackRaw.map(lambda x: float(x[0])).max()
# print attackMax
attackMin = attackRaw.map(lambda x: float(x[0])).min()
# print attackMin


In [5]:
normalIp = normalRaw.map(lambda x: [x[1], x[2], float(x[0])*100/(normalMax - normalMin)])
attackIp = attackRaw.map(lambda x: [x[1], x[2], float(x[0])*100/(attackMax - attackMin)])

# attackx = attackIp.map(lambda x: len(x))
# print attackx.filter(lambda x: x < 3).collect()

# normalx = normalIp.map(lambda x: len(x))
# print normalx.filter(lambda x: x < 3).collect()
# print normalIp.take(5)
# print normalRaw.takeOrdered(10, key = lambda x: -float(x[0]))
# print normalRaw.map(lambda x:len(x)).count()

In [7]:
# print normalIp.count()

In [6]:
#original un-optimized src and dst IP set
#it can be very slow when the data is very big
#so use the top Num instead
# normalSrc = normalRaw.map(lambda x: x[1])
# normalDst = normalRaw.map(lambda x: x[2])
# attackSrc = attackRaw.map(lambda x: x[1])
# attackDst = attackRaw.map(lambda x: x[2])

# print normalSrc.take(10)
# print normalDst.take(10)
# print attackSrc.take(10)
# print attackDst.take(10)
attackNum = 1000
normalNum = 1000

normalSrc = sc.parallelize(normalIp.takeOrdered(normalNum, key = lambda x: -x[2])).map(lambda x: x[0])
normalDst = sc.parallelize(normalIp.takeOrdered(normalNum, key = lambda x: -x[2])).map(lambda x: x[1])

attackSrc = sc.parallelize(normalIp.takeOrdered(normalNum, key = lambda x: -x[2])).map(lambda x: x[0])
attackDst = sc.parallelize(normalIp.takeOrdered(normalNum, key = lambda x: -x[2])).map(lambda x: x[1])


In [7]:
allIPs = normalSrc.union(normalDst).union(attackSrc).union(attackDst).distinct().collect()
# print allIPs[0:9]

In [8]:
geoData = []
for ip in allIPs:
    lo = ip2lo(ip)
    la = ip2la(ip)
    geoData.append({'ip' : ip,
                    'longitude' : lo,
                    'lantitude' : la})

In [9]:
attackNum = 1000
normalNum = 1000
victimNum = 1000

attackDS = []
for x in attackIp.takeOrdered(attackNum, key = lambda x: -x[2]):
    src = x[0]
    dst = x[1]
    val = x[2]
    attackDS .append({'source': src,
                      'destination' : dst,
                      'value' : val})
    
normalDS = []
for x in normalIp.takeOrdered(normalNum, key = lambda x: -x[2]):
    src = x[0]
    dst = x[1]
    val = x[2]
    normalDS .append({'source': src,
                      'destination' : dst,
                      'value' : val})

In [10]:
victimData = []
victimIPs = attackIp.takeOrdered(victimNum, key = lambda x: -x[2])
for x in victimIPs:
    dst = x[1]
    val = x[2]
    victimData.append({'destination' : dst,
                       'value' :val})

In [11]:
import sys
ipDistribution = {'recordID': 1,#sys.argv[1], #add arg!!!!!!!!
                  'geoData' : geoData,
                  'attackDS' : attackDS,
                  'normalDS' : normalDS,
                  'victimData' : victimData}

In [12]:
import pymongo
from pymongo import MongoClient
client = MongoClient()

dbName = 'test'
collectionName = 'maps'

db = client[dbName]
collection = db[collectionName]

collection.insert_one(ipDistribution)

<pymongo.results.InsertOneResult at 0x7f01146f9550>

In [14]:
sc.stop()