In [1]:
sc

In [10]:
data_file = "file:///home/lygbug666/workdir/spark-py-notebooks/kddcup.data_10_percent.gz"
raw_data = sc.textFile(data_file)

# Getting a Data Frame

A Spark DataFrame is a distributed collection of data organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R or Pandas. 

In Spark 1.x, the entry point into all SQL functionality is the SQLContext class. Since Spark 2.0, SparkSession takes over this role, and SQLContext is kept for backward compatibility. To create a basic instance, all we need is a SparkContext reference. Since we are running Spark in shell mode (using PySpark), we can use the global context object sc for this purpose.

In [11]:
from pyspark.sql import SQLContext

sqlContext = SQLContext(sc)


## Inferring the schema

With a SQLContext, we are ready to create a DataFrame from our existing RDD. But first we need to tell Spark SQL the schema of our data.

Spark SQL can convert an RDD of Row objects to a DataFrame. Rows are constructed by passing a list of key/value pairs as kwargs to the Row class. The keys define the column names, and the types are inferred by looking at the first row.
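To make the parsing and inference steps concrete without a Spark cluster, here is a plain-Python sketch (not Spark code). It parses one comma-separated line into named fields the way the Row(...) lambda below does, then derives a column type from each Python value. The sample line is synthetic (placeholder values padded to the 42 fields of a KDD Cup 99 record), and the type-name table is an assumption that mirrors what printSchema shows later. Fields are sorted by name because Row(**kwargs) sorted its fields alphabetically before Spark 3.0, which is why printSchema lists dst_bytes before duration.

```python
# Plain-Python sketch, not Spark: parse a CSV line into named fields, then
# infer a type name per column from the Python values of that single record.

# Synthetic record: 6 meaningful fields, 35 placeholder zeros, then the label (42 fields).
SAMPLE_LINE = ",".join(["0", "tcp", "http", "SF", "181", "5450"] + ["0"] * 35 + ["normal."])

def parse(line):
    """Mirror the Row(...) lambda: split on commas and convert numeric fields."""
    p = line.split(",")
    return {
        "duration": int(p[0]),
        "protocol_type": p[1],
        "service": p[2],
        "flag": p[3],
        "src_bytes": int(p[4]),
        "dst_bytes": int(p[5]),
    }

# Assumed mapping from Python types to the Spark SQL type names printSchema uses.
TYPE_NAMES = {int: "long", str: "string", float: "double", bool: "boolean"}

def infer_schema(record):
    """Return (column, type) pairs sorted by column name, as pre-3.0 Row objects were."""
    return [(name, TYPE_NAMES[type(value)]) for name, value in sorted(record.items())]

schema = infer_schema(parse(SAMPLE_LINE))
for name, type_name in schema:
    print(f" |-- {name}: {type_name}")
```

Spark's real inference also records nullability and can sample more than one row, so treat this only as an illustration of the idea.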

In [12]:
from pyspark.sql import Row
csv_data = raw_data.map(lambda l: l.split(","))
row_data = csv_data.map(lambda p: Row(
    duration=int(p[0]), 
    protocol_type=p[1],
    service=p[2],
    flag=p[3],
    src_bytes=int(p[4]),
    dst_bytes=int(p[5])
    )
)


Once we have our RDD of Row objects, we can infer and register the schema: we create a DataFrame from the RDD (Spark infers the column names and types from the Row fields) and register it as a temporary table.

Now we can run SQL queries over our DataFrame, which has been registered as a table.
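The "register a table, then query it with SQL" pattern is not specific to Spark. As an analogy only, the sketch below uses Python's built-in sqlite3 module as a stand-in for Spark's temporary-table registry and runs the same filter that the tcp query in this notebook uses. The rows are hypothetical, made up for illustration, and are not taken from the KDD data set.

```python
import sqlite3

# In-memory SQLite database standing in for Spark's table registry (analogy only).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE interactions (duration INTEGER, protocol_type TEXT, dst_bytes INTEGER)")

# Hypothetical rows for illustration.
rows = [
    (5057, "tcp", 0),    # matches: tcp, long, no destination bytes
    (12, "tcp", 0),      # too short
    (2000, "udp", 0),    # wrong protocol
    (1500, "tcp", 320),  # has destination bytes
    (5039, "tcp", 0),    # matches
]
conn.executemany("INSERT INTO interactions VALUES (?, ?, ?)", rows)

# Same WHERE clause as the Spark SQL query over the registered "interactions" table.
result = conn.execute(
    "select duration, dst_bytes from interactions "
    "where protocol_type = 'tcp' and duration > 1000 and dst_bytes = 0"
).fetchall()
print(result)
conn.close()
```

Without an ORDER BY clause, neither SQLite nor Spark guarantees the row order of the result.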


In [13]:
interactions_df = sqlContext.createDataFrame(row_data)

# registerTempTable is deprecated since Spark 2.0; createOrReplaceTempView is its replacement
interactions_df.registerTempTable("interactions")

interactions_df.printSchema()

In [19]:
# Select tcp network interactions lasting more than 1000 seconds (duration is in seconds)
# with no data transferred from the destination
tcp_interactions = sqlContext.sql("""
    select duration, dst_bytes from interactions
    where protocol_type = 'tcp' and duration > 1000 and dst_bytes = 0
""")
tcp_interactions.show(10)


+--------+---------+
|duration|dst_bytes|
+--------+---------+
|    5057|        0|
|    5059|        0|
|    5051|        0|
|    5056|        0|
|    5051|        0|
|    5039|        0|
|    5062|        0|
|    5041|        0|
|    5056|        0|
|    5064|        0|
+--------+---------+
only showing top 10 rows



The results of SQL queries are DataFrames. Their underlying RDD, available through the .rdd attribute, supports all the normal RDD operations.

In [24]:
tcp_interactions_out = tcp_interactions \
                    .rdd.map(lambda p: "Duration: {}, Dest. bytes: {}".format(p.duration, p.dst_bytes)).collect()

# Without .collect(), the result is a PipelinedRDD, which cannot be iterated over locally:
# TypeError: 'PipelinedRDD' object is not iterable

for ti_out in tcp_interactions_out:
    print (ti_out)

Duration: 5057, Dest. bytes: 0
Duration: 5059, Dest. bytes: 0
Duration: 5051, Dest. bytes: 0
Duration: 5056, Dest. bytes: 0
Duration: 5051, Dest. bytes: 0
Duration: 5039, Dest. bytes: 0
Duration: 5062, Dest. bytes: 0
Duration: 5041, Dest. bytes: 0
Duration: 5056, Dest. bytes: 0
Duration: 5064, Dest. bytes: 0
Duration: 5043, Dest. bytes: 0
Duration: 5061, Dest. bytes: 0
Duration: 5049, Dest. bytes: 0
Duration: 5061, Dest. bytes: 0
Duration: 5048, Dest. bytes: 0
Duration: 5047, Dest. bytes: 0
Duration: 5044, Dest. bytes: 0
Duration: 5063, Dest. bytes: 0
Duration: 5068, Dest. bytes: 0
Duration: 5062, Dest. bytes: 0
Duration: 5046, Dest. bytes: 0
Duration: 5052, Dest. bytes: 0
Duration: 5044, Dest. bytes: 0
Duration: 5054, Dest. bytes: 0
Duration: 5039, Dest. bytes: 0
Duration: 5058, Dest. bytes: 0
Duration: 5051, Dest. bytes: 0
Duration: 5032, Dest. bytes: 0
Duration: 5063, Dest. bytes: 0
Duration: 5040, Dest. bytes: 0
Duration: 5051, Dest. bytes: 0
Duration: 5066, Dest. bytes: 0
...

We can easily have a look at our data frame schema using printSchema.

In [27]:
interactions_df.printSchema()

root
 |-- dst_bytes: long (nullable = true)
 |-- duration: long (nullable = true)
 |-- flag: string (nullable = true)
 |-- protocol_type: string (nullable = true)
 |-- service: string (nullable = true)
 |-- src_bytes: long (nullable = true)



## Queries as DataFrame operations

Let's say we want to count how many interactions there are for each protocol type.
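On a single machine, groupBy("protocol_type").count() amounts to counting occurrences of each distinct value. The sketch below shows that with collections.Counter over a short, hypothetical list of protocol values; Spark computes the same result, but distributed across partitions.

```python
from collections import Counter

# Hypothetical protocol_type values standing in for the real column.
protocol_types = ["tcp", "udp", "icmp", "tcp", "icmp", "icmp", "tcp"]

# Local equivalent of groupBy("protocol_type").count().
counts = Counter(protocol_types)
for protocol, count in counts.most_common():
    print(protocol, count)
```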

In [28]:
from time import time

t0 = time()
# interactions_df.select("protocol_type", "duration", "dst_bytes").groupBy("protocol_type").count().show()
interactions_df.select("protocol_type", "duration").groupBy("protocol_type").count().show()
tt = time() - t0

print ("Query performed in {} seconds".format(round(tt,3)))

+-------------+------+
|protocol_type| count|
+-------------+------+
|          tcp|190065|
|          udp| 20354|
|         icmp|283602|
+-------------+------+

Query performed in 8.454 seconds


In [31]:
interaction_time = sqlContext.sql("select protocol_type, count(duration) from interactions \
                                  group by protocol_type")
interaction_time.show()

+-------------+---------------+
|protocol_type|count(duration)|
+-------------+---------------+
|          tcp|         190065|
|          udp|          20354|
|         icmp|         283602|
+-------------+---------------+



Now imagine that we want to count how many interactions last more than 1000 seconds with no data transfer from the destination, grouped by protocol type.
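The three cells that follow express the same pipeline: filter on duration, filter on dst_bytes, then group by protocol and count. As a plain-Python sketch on hypothetical records (not Spark), the chained generator expressions below mirror the two .filter() calls. Like Spark transformations, they do no work until something consumes them; here that is Counter, playing the role of an action such as show().

```python
from collections import Counter
from typing import NamedTuple

class Interaction(NamedTuple):
    protocol_type: str
    duration: int
    dst_bytes: int

# Hypothetical records for illustration only.
interactions = [
    Interaction("tcp", 5057, 0),
    Interaction("tcp", 5039, 0),
    Interaction("tcp", 5063, 1024),  # has destination bytes -> filtered out
    Interaction("udp", 3, 0),        # too short -> filtered out
    Interaction("icmp", 0, 0),       # too short -> filtered out
]

# filter(duration > 1000) -> filter(dst_bytes == 0) -> groupBy("protocol_type").count()
filtered = (i for i in interactions if i.duration > 1000)
filtered = (i for i in filtered if i.dst_bytes == 0)
counts = Counter(i.protocol_type for i in filtered)  # consuming the generators runs the filters
print(dict(counts))
```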

In [38]:
t0 = time()
interactions_df.select("protocol_type", "duration", "dst_bytes")\
.filter(interactions_df.duration > 1000).filter(interactions_df.dst_bytes == 0).groupBy("protocol_type").count().show()
tt = time() - t0

print ("Query performed in {} seconds".format(round(tt,3)))

+-------------+-----+
|protocol_type|count|
+-------------+-----+
|          tcp|  139|
+-------------+-----+

Query performed in 7.549 seconds


In [37]:
t0 = time()
interactions_df.select("protocol_type", "duration", "dst_bytes")\
.filter("duration > 1000").filter("dst_bytes == 0").groupBy("protocol_type").count().show()
tt = time() - t0

print ("Query performed in {} seconds".format(round(tt,3)))

+-------------+-----+
|protocol_type|count|
+-------------+-----+
|          tcp|  139|
+-------------+-----+

Query performed in 7.44 seconds


In [39]:
t0 = time()
sqlContext.sql("select protocol_type, count(duration) from interactions where duration > 1000 and dst_bytes == 0 \
                                  group by protocol_type").show()
tt = time() - t0

print ("Query performed in {} seconds".format(round(tt,3)))

+-------------+---------------+
|protocol_type|count(duration)|
+-------------+---------------+
|          tcp|            139|
+-------------+---------------+

Query performed in 7.752 seconds


Let's count how many attack and normal interactions we have. First we need to add the label column to our data.

In [40]:
def get_label_type(label):
    if label!="normal.":
        return "attack"
    else:
        return "normal"


In [43]:
from pyspark.sql import Row
csv_data = raw_data.map(lambda l: l.split(","))
row_labeled_data = csv_data.map(lambda p: Row(
    duration=int(p[0]), 
    protocol_type=p[1],
    service=p[2],
    flag=p[3],
    src_bytes=int(p[4]),
    dst_bytes=int(p[5]),
    label=get_label_type(p[41])
    )
)
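The labeling step can be checked locally. This plain-Python sketch reuses the same rule as get_label_type (anything other than "normal." is an attack) and reads the label from field index 41, the last of the 42 comma-separated fields. The make_line helper and its lines are hypothetical: the feature fields are placeholders, and only the label varies.

```python
from collections import Counter

def get_label_type(label):
    # Same rule as the notebook: anything other than "normal." counts as an attack.
    return "normal" if label == "normal." else "attack"

def parse_labeled(line):
    p = line.split(",")
    return {
        "duration": int(p[0]),
        "protocol_type": p[1],
        "service": p[2],
        "flag": p[3],
        "src_bytes": int(p[4]),
        "dst_bytes": int(p[5]),
        "label": get_label_type(p[41]),  # index 41 is the 42nd field, the raw label
    }

def make_line(label):
    # Hypothetical helper: 41 placeholder feature fields followed by the label.
    return ",".join(["0", "tcp", "http", "SF", "0", "0"] + ["0"] * 35 + [label])

lines = [make_line(raw) for raw in ["normal.", "smurf.", "neptune.", "normal.", "smurf."]]
labels = Counter(parse_labeled(line)["label"] for line in lines)
print(dict(labels))
```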

In [44]:
interactions_labeled_df = sqlContext.createDataFrame(row_labeled_data)

interactions_labeled_df.registerTempTable("labled_interactions")

interactions_labeled_df.printSchema()

root
 |-- dst_bytes: long (nullable = true)
 |-- duration: long (nullable = true)
 |-- flag: string (nullable = true)
 |-- label: string (nullable = true)
 |-- protocol_type: string (nullable = true)
 |-- service: string (nullable = true)
 |-- src_bytes: long (nullable = true)



In [50]:
t0 = time()
interactions_labeled_df.select("label").groupBy("label").count().show()
tt = time() - t0

print ("Query performed in {} seconds".format(round(tt,3)))

+------+------+
| label| count|
+------+------+
|normal| 97278|
|attack|396743|
+------+------+

Query performed in 7.925 seconds


In [54]:
t0 = time()
sqlContext.sql("select label, count(*) as count from labled_interactions group by label").show()
tt = time() - t0

print ("Query performed in {} seconds".format(round(tt,3)))

+------+------+
| label| count|
+------+------+
|normal| 97278|
|attack|396743|
+------+------+

Query performed in 7.977 seconds


Now we want to count them by label and protocol type, to see how useful the protocol type is for detecting whether an interaction is an attack.

In [55]:
t0 = time()
interactions_labeled_df.select("label","protocol_type").groupBy("label","protocol_type").count().show()
tt = time() - t0

print ("Query performed in {} seconds".format(round(tt,3)))

+------+-------------+------+
| label|protocol_type| count|
+------+-------------+------+
|normal|          udp| 19177|
|normal|         icmp|  1288|
|normal|          tcp| 76813|
|attack|         icmp|282314|
|attack|          tcp|113252|
|attack|          udp|  1177|
+------+-------------+------+

Query performed in 8.182 seconds


At first sight, UDP interactions appear to be attacks far less often than interactions using the other protocol types.
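To put numbers on that impression, here is a short worked computation using the counts from the label and protocol_type table above. For each protocol, it divides attacks by all interactions of that protocol. The results are about 6% for UDP, compared with about 60% for TCP and over 99% for ICMP.

```python
# Counts copied from the label/protocol_type table above.
counts = {
    ("normal", "udp"): 19177,
    ("normal", "icmp"): 1288,
    ("normal", "tcp"): 76813,
    ("attack", "icmp"): 282314,
    ("attack", "tcp"): 113252,
    ("attack", "udp"): 1177,
}

# Fraction of each protocol's interactions that are labeled "attack".
attack_rate = {}
for protocol in ("tcp", "udp", "icmp"):
    attacks = counts[("attack", protocol)]
    total = attacks + counts[("normal", protocol)]
    attack_rate[protocol] = attacks / total

for protocol, rate in attack_rate.items():
    print(f"{protocol}: {rate:.1%} attacks")
```

The per-protocol totals (190065, 20354, 283602) match the protocol counts computed earlier, which is a quick consistency check on the grouped table.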

In [56]:
t0 = time()
sqlContext.sql("select label, protocol_type, count(*) as count from labled_interactions group by label, protocol_type").show()
tt = time() - t0

print ("Query performed in {} seconds".format(round(tt,3)))

+------+-------------+------+
| label|protocol_type| count|
+------+-------------+------+
|normal|          udp| 19177|
|normal|         icmp|  1288|
|normal|          tcp| 76813|
|attack|         icmp|282314|
|attack|          tcp|113252|
|attack|          udp|  1177|
+------+-------------+------+

Query performed in 7.71 seconds


In [58]:
t0 = time()
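# Does not work: groupBy treats string arguments as column names, so "dst_bytes==0" is not parsed as an expression
# interactions_labeled_df.select("label", "protocol_type", "dst_bytes").groupBy("label", "protocol_type", "dst_bytes==0").count().show()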

interactions_labeled_df.select("label", "protocol_type", "dst_bytes").groupBy("label", "protocol_type", interactions_labeled_df.dst_bytes==0).count().show()
tt = time() - t0

print ("Query performed in {} seconds".format(round(tt,3)))

+------+-------------+---------------+------+
| label|protocol_type|(dst_bytes = 0)| count|
+------+-------------+---------------+------+
|normal|          udp|          false| 15583|
|attack|          udp|          false|    11|
|attack|          tcp|           true|110583|
|normal|          tcp|          false| 67500|
|attack|         icmp|           true|282314|
|attack|          tcp|          false|  2669|
|normal|          tcp|           true|  9313|
|normal|          udp|           true|  3594|
|normal|         icmp|           true|  1288|
|attack|          udp|           true|  1166|
+------+-------------+---------------+------+

Query performed in 8.041 seconds


We can see how relevant this new split is for determining whether a network interaction is an attack.
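A worked computation makes the relevance concrete. Using the counts from the table above, this sketch computes the attack rate within each protocol, separately for interactions with and without destination bytes. For TCP, about 92% of interactions with dst_bytes = 0 are attacks, versus under 4% of those with destination bytes.

```python
# Counts copied from the label/protocol_type/(dst_bytes = 0) table above.
counts = {
    ("normal", "udp", False): 15583,
    ("attack", "udp", False): 11,
    ("attack", "tcp", True): 110583,
    ("normal", "tcp", False): 67500,
    ("attack", "icmp", True): 282314,
    ("attack", "tcp", False): 2669,
    ("normal", "tcp", True): 9313,
    ("normal", "udp", True): 3594,
    ("normal", "icmp", True): 1288,
    ("attack", "udp", True): 1166,
}

def attack_rate(protocol, no_dst_bytes):
    # .get(..., 0) because some combinations (icmp with dst_bytes != 0) never occur.
    attacks = counts.get(("attack", protocol, no_dst_bytes), 0)
    normals = counts.get(("normal", protocol, no_dst_bytes), 0)
    total = attacks + normals
    return attacks / total if total else None

for protocol in ("tcp", "udp", "icmp"):
    for no_dst in (True, False):
        rate = attack_rate(protocol, no_dst)
        split = "dst_bytes = 0 " if no_dst else "dst_bytes != 0"
        shown = "n/a" if rate is None else f"{rate:.1%}"
        print(f"{protocol:4} {split}: {shown} attacks")
```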

## Join and aggregation

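In Scala (spark-shell), two JSON data sets can be joined on a shared id column like this:

import org.apache.spark.sql.SQLContext

val sql = new SQLContext(sc)

val pInfo = sql.read.json("people.txt")
val pSalar = sql.read.json("salary.txt")

val info_salary = pInfo.join(pSalar, "id")                 // inner join on a single column
val info_salary1 = pInfo.join(pSalar, Seq("id", "name"))   // inner join on multiple columns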
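In PySpark, DataFrame.join accepts either a single column name or a list of names for its on argument, for example p_info.join(p_salary, "id") or p_info.join(p_salary, ["id", "name"]). The contents of people.txt and salary.txt are not shown, so the plain-Python sketch below uses hypothetical records, including a hypothetical dept field. It performs an inner join on one key and on two keys, then aggregates average salary per department. This is a hash-join sketch to illustrate the semantics, not how Spark executes joins internally.

```python
from collections import defaultdict

# Hypothetical contents for people.txt and salary.txt (the real files are not shown).
people = [
    {"id": 1, "name": "Ann", "dept": "eng"},
    {"id": 2, "name": "Bob", "dept": "eng"},
    {"id": 3, "name": "Cid", "dept": "ops"},
]
salaries = [
    {"id": 1, "name": "Ann", "salary": 100},
    {"id": 2, "name": "Bob", "salary": 80},
    {"id": 3, "name": "Carl", "salary": 70},  # same id, different name
    {"id": 4, "name": "Dan", "salary": 60},   # no matching person
]

def inner_join(left, right, keys):
    """Index the right side by the join key(s), then probe it with each left row."""
    index = defaultdict(list)
    for right_row in right:
        index[tuple(right_row[k] for k in keys)].append(right_row)
    joined = []
    for left_row in left:
        for right_row in index.get(tuple(left_row[k] for k in keys), []):
            # Simplification: right-side values win for duplicate non-key names;
            # Spark would keep both columns.
            joined.append({**left_row, **right_row})
    return joined

by_id = inner_join(people, salaries, ["id"])               # like join(..., "id")
by_id_name = inner_join(people, salaries, ["id", "name"])  # like join(..., Seq("id", "name"))

# Aggregation: average salary per department over the id-based join.
totals = defaultdict(lambda: [0, 0])  # dept -> [salary sum, row count]
for row in by_id:
    totals[row["dept"]][0] += row["salary"]
    totals[row["dept"]][1] += 1
avg_salary = {dept: s / n for dept, (s, n) in totals.items()}
print(len(by_id), len(by_id_name), avg_salary)
```

Joining on both id and name drops the id 3 pair because the names differ, which is the practical difference between the two Scala join calls.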