In [2]:
from pyspark.sql import Window
from pyspark.sql.functions import lag, col, lit, lead
from sparkstudy.deploy.demo_sessions import DemoSQLSessionFactory
from sparkstudy.libs.tools import create_random_data
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
# Column names applied, in order, to the generated rows.
COLUMNS = ["name", "age", "salary"]

# Build the demo Spark session through the project's session factory.
session_factory_normal = DemoSQLSessionFactory(name="normal")
# NOTE(review): 10 is presumably the row count (the recorded output has ten rows);
# no seed is visible here, so the values likely differ between runs -- confirm.
data = create_random_data(10)
spark_session = session_factory_normal.build_session()

# Cached because every following cell queries this same DataFrame.
df = spark_session.createDataFrame(data, schema=COLUMNS).cache()
df.show()

+----+---+--------------------+
|name|age|              salary|
+----+---+--------------------+
|   D|  3|  0.2883449937034951|
|   O|  1|  0.8952704771774317|
|   F|  2|  0.2622462757336257|
|   A|  6|  0.7930823582184781|
|   Y|  5| 0.16709710188277627|
|   M|  8|  0.8183381439211345|
|   H|  9|  0.8052992150655519|
|   L|  2|0.059200771680881004|
|   K|  2|   0.893548986244985|
|   B|  9|0.026260197849612976|
+----+---+--------------------+



本来想实验一下如何对列做 shift（把某一列整体向上或向下平移）。结果发现在 Spark 里，这类操作必须通过窗口（`Window`）配合 `lag` / `lead` 来实现。

In [4]:
# Shift "age" by one row inside each "name" partition.
# lag() has no previous row to read when a partition holds a single row, so
# with all-distinct names (as in the recorded output) every new_col is null.
w = Window.partitionBy("name").orderBy(col("name"))
prev_age = lag("age").over(w).alias("new_col")
df.select("*", prev_age).show()

+----+---+--------------------+-------+
|name|age|              salary|new_col|
+----+---+--------------------+-------+
|   K|  2|   0.893548986244985|   null|
|   F|  2|  0.2622462757336257|   null|
|   B|  9|0.026260197849612976|   null|
|   Y|  5| 0.16709710188277627|   null|
|   L|  2|0.059200771680881004|   null|
|   M|  8|  0.8183381439211345|   null|
|   D|  3|  0.2883449937034951|   null|
|   O|  1|  0.8952704771774317|   null|
|   A|  6|  0.7930823582184781|   null|
|   H|  9|  0.8052992150655519|   null|
+----+---+--------------------+-------+



下面不指定分区列，并用常量 `lit(1)` 排序，效果类似于禁用 orderBy：所有行落在同一个窗口里，排序键对每一行都相同（注意：此时 Spark 并不保证行的先后顺序）。

In [5]:
# No partition columns and a constant sort key: all rows share one window and
# the ordering expression does not distinguish between them.
w = Window.partitionBy().orderBy(lit(1))
# Offset 2: read "age" from two rows earlier; the first two rows have no such row -> null.
age_two_back = lag("age", 2).over(w).alias("new_col")
df.select("*", age_two_back).show()

+----+---+--------------------+-------+
|name|age|              salary|new_col|
+----+---+--------------------+-------+
|   D|  3|  0.2883449937034951|   null|
|   O|  1|  0.8952704771774317|   null|
|   F|  2|  0.2622462757336257|      3|
|   A|  6|  0.7930823582184781|      1|
|   Y|  5| 0.16709710188277627|      2|
|   M|  8|  0.8183381439211345|      6|
|   H|  9|  0.8052992150655519|      5|
|   L|  2|0.059200771680881004|      8|
|   K|  2|   0.893548986244985|      9|
|   B|  9|0.026260197849612976|      2|
+----+---+--------------------+-------+



In [6]:
# Same single, constant-ordered window as the lag example, rebuilt so this cell runs on its own.
w = Window.partitionBy().orderBy(lit(1))
# Offset 2: read "age" from two rows later; the last two rows have no such row -> null.
age_two_ahead = lead("age", 2).over(w).alias("new_col")
df.select("*", age_two_ahead).show()



+----+---+--------------------+-------+
|name|age|              salary|new_col|
+----+---+--------------------+-------+
|   D|  3|  0.2883449937034951|      2|
|   O|  1|  0.8952704771774317|      6|
|   F|  2|  0.2622462757336257|      5|
|   A|  6|  0.7930823582184781|      8|
|   Y|  5| 0.16709710188277627|      9|
|   M|  8|  0.8183381439211345|      2|
|   H|  9|  0.8052992150655519|      2|
|   L|  2|0.059200771680881004|      9|
|   K|  2|   0.893548986244985|   null|
|   B|  9|0.026260197849612976|   null|
+----+---+--------------------+-------+

