In [39]:
from pyspark.sql import Window
from pyspark.sql.functions import lag, col, lit, lead
from sparkstudy.deploy.demo_sessions import DemoSQLSessionFactory
from sparkstudy.libs.tools import create_random_data
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [40]:
# Column names applied, in order, to the generated rows when the DataFrame is built.
COLUMNS = ["name","age","salary"]
# NOTE(review): DemoSQLSessionFactory is a project helper; presumably
# build_session() below returns a SparkSession — confirm against sparkstudy.deploy.
session_factory_normal = DemoSQLSessionFactory(name="normal")
# Ten random demo rows (the table printed below has ten); the generator is not
# seeded here, so values differ between runs — TODO confirm whether a seed is supported.
data = create_random_data(10)
spark_session = session_factory_normal.build_session()
# Cached because the window-function cells below all query this same frame.
df = spark_session.createDataFrame(data,COLUMNS).cache()
df.show()

+----+---+--------------------+
|name|age|              salary|
+----+---+--------------------+
|   K|  3| 0.02161606520352055|
|   Q|  3|  0.6326711203749021|
|   S|  8| 0.43131405222650354|
|   J| 10|  0.5438762312917514|
|   L|  1|  0.7336759510839976|
|   X|  9| 0.43158267183230636|
|   B|  9|0.030009665330478774|
|   G|  1| 0.42449551735018554|
|   T|  9|  0.7102168420016252|
|   G|  6|  0.2548856167145863|
+----+---+--------------------+



本来想实验一下怎么做 shift（把一列按行上下平移）。结果发现在 Spark 里，这类操作必须通过窗口（Window）配合 `lag` / `lead` 来实现。

In [41]:
# Previous age within each name group.
# The window is partitioned by name, so it has to be ordered by something other
# than name: sorting on the partition key leaves every row in a partition tied,
# and which row lag() treats as "previous" would then be arbitrary. Ordering by
# age gives each partition a defined sequence.
# Window's builders are called on the class; no instance is needed.
w = Window.partitionBy("name").orderBy(col("age"))
# lag("age") defaults to an offset of 1; the first row of every partition has no
# predecessor and therefore gets null.
df.select("*", lag("age").over(w).alias("new_col")).show()

+----+---+--------------------+-------+
|name|age|              salary|new_col|
+----+---+--------------------+-------+
|   K|  3| 0.02161606520352055|   null|
|   Q|  3|  0.6326711203749021|   null|
|   T|  9|  0.7102168420016252|   null|
|   B|  9|0.030009665330478774|   null|
|   L|  1|  0.7336759510839976|   null|
|   J| 10|  0.5438762312917514|   null|
|   X|  9| 0.43158267183230636|   null|
|   S|  8| 0.43131405222650354|   null|
|   G|  1| 0.42449551735018554|   null|
|   G|  6|  0.2548856167145863|      1|
+----+---+--------------------+-------+



不分区，并用常量 `lit(1)` 作为排序键：所有行的排序值相同，效果类似于禁用 `orderBy`，`lag` / `lead` 就按 DataFrame 当前的行顺序整体平移。注意：这种行顺序 Spark 并不保证。

In [42]:
# One window over the whole frame (no partition columns), sorted on a constant
# so that no real column drives the ordering.
# NOTE(review): every row ties on lit(1), so the sequence lag() walks is simply
# the order Spark happens to produce — confirm that is acceptable here.
w = Window.orderBy(lit(1))
# Age taken from two rows earlier; the first two rows have nothing to look back to.
age_two_rows_back = lag("age", 2).over(w).alias("new_col")
df.select("*", age_two_rows_back).show()

+----+---+--------------------+-------+
|name|age|              salary|new_col|
+----+---+--------------------+-------+
|   K|  3| 0.02161606520352055|   null|
|   Q|  3|  0.6326711203749021|   null|
|   S|  8| 0.43131405222650354|      3|
|   J| 10|  0.5438762312917514|      3|
|   L|  1|  0.7336759510839976|      8|
|   X|  9| 0.43158267183230636|     10|
|   B|  9|0.030009665330478774|      1|
|   G|  1| 0.42449551735018554|      9|
|   T|  9|  0.7102168420016252|      9|
|   G|  6|  0.2548856167145863|      1|
+----+---+--------------------+-------+



In [43]:
# One window over the whole frame (no partition columns), sorted on a constant
# so that no real column drives the ordering.
# NOTE(review): every row ties on lit(1), so the sequence lead() walks is simply
# the order Spark happens to produce — confirm that is acceptable here.
w = Window.orderBy(lit(1))
# Age taken from two rows later; the last two rows have nothing to look ahead to.
age_two_rows_ahead = lead("age", 2).over(w).alias("new_col")
df.select("*", age_two_rows_ahead).show()

+----+---+--------------------+-------+
|name|age|              salary|new_col|
+----+---+--------------------+-------+
|   K|  3| 0.02161606520352055|      8|
|   Q|  3|  0.6326711203749021|     10|
|   S|  8| 0.43131405222650354|      1|
|   J| 10|  0.5438762312917514|      9|
|   L|  1|  0.7336759510839976|      9|
|   X|  9| 0.43158267183230636|      1|
|   B|  9|0.030009665330478774|      9|
|   G|  1| 0.42449551735018554|      6|
|   T|  9|  0.7102168420016252|   null|
|   G|  6|  0.2548856167145863|   null|
+----+---+--------------------+-------+

