## Set Up

In [1]:
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import Row
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession \
    .builder \
    .master("local") \
    .appName("review_and_category_analytics") \
    .config("spark.executor.memory", '8g') \
    .config('spark.executor.cores', '4') \
    .config('spark.cores.max', '4') \
    .config("spark.driver.memory",'8g') \
    .getOrCreate()

sc = spark.sparkContext

sqlCtx = SQLContext(sc)

In [2]:
galaxy_df = sqlCtx.read.\
    format("csv").\
    option("header", "true").\
    option("inferSchema", "true").\
    load("data/training_solutions_rev1.csv")

## Summarize the Data

#### Structure of the Data

In [3]:
galaxy_df.printSchema()

root
 |-- GalaxyID: integer (nullable = true)
 |-- Class1.1: double (nullable = true)
 |-- Class1.2: double (nullable = true)
 |-- Class1.3: double (nullable = true)
 |-- Class2.1: double (nullable = true)
 |-- Class2.2: double (nullable = true)
 |-- Class3.1: double (nullable = true)
 |-- Class3.2: double (nullable = true)
 |-- Class4.1: double (nullable = true)
 |-- Class4.2: double (nullable = true)
 |-- Class5.1: double (nullable = true)
 |-- Class5.2: double (nullable = true)
 |-- Class5.3: double (nullable = true)
 |-- Class5.4: double (nullable = true)
 |-- Class6.1: double (nullable = true)
 |-- Class6.2: double (nullable = true)
 |-- Class7.1: double (nullable = true)
 |-- Class7.2: double (nullable = true)
 |-- Class7.3: double (nullable = true)
 |-- Class8.1: double (nullable = true)
 |-- Class8.2: double (nullable = true)
 |-- Class8.3: double (nullable = true)
 |-- Class8.4: double (nullable = true)
 |-- Class8.5: double (nullable = true)
 |-- Class8.6: double (nullable = 

#### First 5 Rows of the Data

In [4]:
galaxy_df.show(5)

+--------+--------+--------+--------+-----------+-----------+-----------+-----------+-----------+-----------+--------+-----------+-----------+-----------+--------+--------+-----------+-----------+-----------+--------+---------+-----------+---------+-----------+---------+--------+-----------+--------+-----------+-----------+-----------+-----------+-----------+-----------+---------+---------+---------+-----------+
|GalaxyID|Class1.1|Class1.2|Class1.3|   Class2.1|   Class2.2|   Class3.1|   Class3.2|   Class4.1|   Class4.2|Class5.1|   Class5.2|   Class5.3|   Class5.4|Class6.1|Class6.2|   Class7.1|   Class7.2|   Class7.3|Class8.1| Class8.2|   Class8.3| Class8.4|   Class8.5| Class8.6|Class8.7|   Class9.1|Class9.2|   Class9.3|  Class10.1|  Class10.2|  Class10.3|  Class11.1|  Class11.2|Class11.3|Class11.4|Class11.5|  Class11.6|
+--------+--------+--------+--------+-----------+-----------+-----------+-----------+-----------+-----------+--------+-----------+-----------+-----------+--------+-----

#### Breakdown of Objects in the Data

In [5]:
galaxy_df.createOrReplaceTempView("df") #allow us to use SQL statements

#How many objects are there
numTotal = galaxy_df.count()
print("There are", numTotal, "objects.")

#How many smooth galaxies
numSmooth = sqlCtx.sql("SELECT * FROM df WHERE `Class1.1` >= 0.5").count()
print("There are", numSmooth, "smooth galaxies.")

#How many edge-on glaxies
numEdge = sqlCtx.sql("SELECT * FROM df WHERE `Class1.2` >= 0.5 \
                                        AND `Class2.1` >= `Class2.2`").count()
print("There are", numEdge, "edge-on galaxies.")

#How many spiral glaxies
numSpiral = sqlCtx.sql("SELECT * FROM df WHERE `Class1.2` >= 0.5 AND \
                                        `Class2.1` < `Class2.2` \
                                        AND `Class4.1` >= `Class4.2`").count()
print("There are", numSpiral, "spiral galaxies.")

numOther = numTotal-numSmooth-numEdge-numSpiral
print("There are", numOther, "other objects.")

There are 61578 objects.
There are 25868 smooth galaxies.
There are 6628 edge-on galaxies.
There are 15074 spiral galaxies.
There are 14008 other objects.


## Create New DataFrame

In [6]:
gal_class = galaxy_df #make a copy of the full dataset

#Classify as Smooth
gal_class = gal_class.withColumn(
        "Smooth", 
        F.when(F.col("`Class1.1`") >= 0.5, 1).otherwise(0)
)

#Classify as Edge
gal_class = gal_class.withColumn(
        "Edge", 
        F.when((F.col("`Class1.2`") >= 0.5) & 
               (F.col("`Class2.1`") >= F.col("`Class2.2`")), 1).otherwise(0)
)

#Classify as Spiral
gal_class = gal_class.withColumn(
        "Spiral", 
        F.when((F.col("`Class1.2`") >= 0.5) & 
               (F.col("`Class2.1`") < F.col("`Class2.2`")) &
               (F.col("`Class4.1`") >= F.col("`Class4.2`")), 1).otherwise(0)
)

gal_class.select("Smooth", "Edge", "Spiral").show(20)

+------+----+------+
|Smooth|Edge|Spiral|
+------+----+------+
|     0|   0|     1|
|     0|   0|     1|
|     1|   0|     0|
|     1|   0|     0|
|     1|   0|     0|
|     1|   0|     0|
|     0|   0|     0|
|     1|   0|     0|
|     0|   0|     1|
|     0|   1|     0|
|     0|   0|     0|
|     0|   0|     0|
|     0|   0|     0|
|     0|   0|     0|
|     0|   0|     1|
|     0|   0|     0|
|     0|   1|     0|
|     1|   0|     0|
|     1|   0|     0|
|     0|   1|     0|
+------+----+------+
only showing top 20 rows



In [7]:
#Check for galaxies which were doubly classified.
gal_class.createOrReplaceTempView("df2")

sqlCtx.sql("SELECT GalaxyID, Smooth, Edge, Spiral FROM df2 WHERE smooth+edge+spiral > 1").show() #Galaxies 239928 and 874101

+--------+------+----+------+
|GalaxyID|Smooth|Edge|Spiral|
+--------+------+----+------+
|  239928|     1|   1|     0|
|  874101|     1|   1|     0|
+--------+------+----+------+



### Remove Unecessary Information

In [8]:
final_df = sqlCtx.sql("SELECT GalaxyID, Smooth, Edge, Spiral FROM df2 WHERE GalaxyID != 239928 AND GalaxyID != 874101")
final_df.show(5)

+--------+------+----+------+
|GalaxyID|Smooth|Edge|Spiral|
+--------+------+----+------+
|  100008|     0|   0|     1|
|  100023|     0|   0|     1|
|  100053|     1|   0|     0|
|  100078|     1|   0|     0|
|  100090|     1|   0|     0|
+--------+------+----+------+
only showing top 5 rows



In [9]:
##Create a single column of classifications
##This code was not run but is saved here in case it will become useful.

##merge classes columns {0: nothing, 1:smooth, 2:edge, 3:spiral}
#classes = class_df.withColumn("Class", 
#        F.when((F.col("smooth") == 1), 1)
#        .when((F.col("edge") == 1), 2)
#        .when((F.col("spiral") == 1), 3).otherwise(0)
#        ).select("galaxyID", "class")
#
#classes.show(10)

## Save Classifications Table as New CSV

In [10]:
#final_df.write.csv('galaxyClassifications.csv', header=True)
#This code is commented out because the csv doesn't need to be regenerated.