In [14]:
import pyspark
import pandas as pd
import numpy as np
import pydataset

In [3]:
spark = pyspark.sql.SparkSession.builder.getOrCreate()

In [4]:
spark

In [6]:
np.random.seed(456)

pandas_dataframe = pd.DataFrame(
    dict(n=np.arange(20), group=np.random.choice(list("abc"), 20))
)
pandas_dataframe

Unnamed: 0,n,group
0,0,b
1,1,b
2,2,c
3,3,a
4,4,c
5,5,c
6,6,a
7,7,b
8,8,a
9,9,b


In [7]:
df = spark.createDataFrame(pandas_dataframe)
df

DataFrame[n: bigint, group: string]

In [11]:
df.select(df.group)

DataFrame[group: string]

In [12]:
df.select(df.n + 1).show()

+-------+
|(n + 1)|
+-------+
|      1|
|      2|
|      3|
|      4|
|      5|
|      6|
|      7|
|      8|
|      9|
|     10|
|     11|
|     12|
|     13|
|     14|
|     15|
|     16|
|     17|
|     18|
|     19|
|     20|
+-------+



In [18]:
# turn dataset into a dataframe and then dataframe into a spark object
mpg = spark.createDataFrame(pydataset.data('mpg'))

In [24]:
# can use .column or quote format to select columns
mpg.select(mpg.model, 'manufacturer', mpg.hwy).show()

+------------------+------------+---+
|             model|manufacturer|hwy|
+------------------+------------+---+
|                a4|        audi| 29|
|                a4|        audi| 29|
|                a4|        audi| 31|
|                a4|        audi| 30|
|                a4|        audi| 26|
|                a4|        audi| 26|
|                a4|        audi| 27|
|        a4 quattro|        audi| 26|
|        a4 quattro|        audi| 25|
|        a4 quattro|        audi| 28|
|        a4 quattro|        audi| 27|
|        a4 quattro|        audi| 25|
|        a4 quattro|        audi| 25|
|        a4 quattro|        audi| 25|
|        a4 quattro|        audi| 25|
|        a6 quattro|        audi| 24|
|        a6 quattro|        audi| 25|
|        a6 quattro|        audi| 23|
|c1500 suburban 2wd|   chevrolet| 20|
|c1500 suburban 2wd|   chevrolet| 15|
+------------------+------------+---+
only showing top 20 rows



In [29]:
mpg.select((mpg.hwy + mpg.cty) / 2).alias('avg_mileage').show()

+-----------------+
|((hwy + cty) / 2)|
+-----------------+
|             23.5|
|             25.0|
|             25.5|
|             25.5|
|             21.0|
|             22.0|
|             22.5|
|             22.0|
|             20.5|
|             24.0|
|             23.0|
|             20.0|
|             21.0|
|             21.0|
|             20.0|
|             19.5|
|             21.0|
|             19.5|
|             17.0|
|             13.0|
+-----------------+
only showing top 20 rows



In [30]:
#can store transformation in a variable and use it in select
avg_mileage_column = ((mpg.cty + mpg.hwy)/2).alias('avg_mileage')

In [32]:
# can use '*' to select all
mpg.select('*', avg_mileage_column).show()

+------------+------------------+-----+----+---+----------+---+---+---+---+-------+-----------+
|manufacturer|             model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|avg_mileage|
+------------+------------------+-----+----+---+----------+---+---+---+---+-------+-----------+
|        audi|                a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|       23.5|
|        audi|                a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|       25.0|
|        audi|                a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|       25.5|
|        audi|                a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|       25.5|
|        audi|                a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|       21.0|
|        audi|                a4|  2.8|1999|  6|manual(m5)|  f| 18| 26|  p|compact|       22.0|
|        audi|                a4|  3.1|2008|  6|  auto(av)|  f| 18| 27|  p|compact|       22.5|
|        audi|        a4 quattro|  1.8|1

In [34]:
# shows data types
mpg.printSchema()

root
 |-- manufacturer: string (nullable = true)
 |-- model: string (nullable = true)
 |-- displ: double (nullable = true)
 |-- year: long (nullable = true)
 |-- cyl: long (nullable = true)
 |-- trans: string (nullable = true)
 |-- drv: string (nullable = true)
 |-- cty: long (nullable = true)
 |-- hwy: long (nullable = true)
 |-- fl: string (nullable = true)
 |-- class: string (nullable = true)



In [36]:
from pyspark.sql.functions import col
col

<function pyspark.sql.functions._create_function.<locals>._(col)>

In [42]:
(col('asldkjv')) * 2

Column<b'(asldkjv * 2)'>