In [1]:
# Cloud Storage bucket name that the gs:// paths in this notebook are built from.
BUCKET = 'hohukelkazan'

# Prepare

以下の流れはベイズ分類の時と変わらない

1. トレーニング日に関するデータのcsvの読み込み(スキーマ: header)
2. 読み込んだトレーニング日に関するデータをSpark SQLビューに変換
3. 飛行データのスキーマ定義
4. 飛行データのcsvの読み込み(スキーマ: 3で定義したもの)
5. 読み込んだ飛行データをSpark SQLビューに変換
6. 2,5の結果をjoinさせてトレーニングデータ作成


In [2]:
# Load the training-day CSV; the "header" option makes the first row supply
# the column names.
traindays = (
    spark.read
    .option("header", "true")
    .csv('gs://{}/flights/trainday.csv'.format(BUCKET))
)

In [3]:
# Register the DataFrame as a Spark SQL temporary view named 'traindays'
# so it can be referenced from SQL queries.
traindays.createOrReplaceTempView('traindays')

In [4]:
# スキーマ定義(TAXI_OUTも数値型にする)
from pyspark.sql.types import StringType, FloatType, StructType, StructField

# Comma-separated column names for the flights CSV; split on ',' to build the
# read schema. Order presumably mirrors the column order of the CSV files —
# TODO confirm against the data.
header = ','.join([
    'FL_DATE', 'UNIQUE_CARRIER', 'AIRLINE_ID', 'CARRIER', 'FL_NUM',
    'ORIGIN_AIRPORT_ID', 'ORIGIN_AIRPORT_SEQ_ID', 'ORIGIN_CITY_MARKET_ID', 'ORIGIN',
    'DEST_AIRPORT_ID', 'DEST_AIRPORT_SEQ_ID', 'DEST_CITY_MARKET_ID', 'DEST',
    'CRS_DEP_TIME', 'DEP_TIME', 'DEP_DELAY', 'TAXI_OUT',
    'WHEELS_OFF', 'WHEELS_ON', 'TAXI_IN',
    'CRS_ARR_TIME', 'ARR_TIME', 'ARR_DELAY',
    'CANCELLED', 'CANCELLATION_CODE', 'DIVERTED', 'DISTANCE',
    'DEP_AIRPORT_LAT', 'DEP_AIRPORT_LON', 'DEP_AIRPORT_TZOFFSET',
    'ARR_AIRPORT_LAT', 'ARR_AIRPORT_LON', 'ARR_AIRPORT_TZOFFSET',
    'EVENT', 'NOTIFY_TIME',
])

def get_structfield(colname):
  """Return a nullable StructField describing the column ``colname``.

  ARR_DELAY, DEP_DELAY, DISTANCE and TAXI_OUT are typed as FloatType;
  every other column is typed as StringType.
  """
  is_numeric = colname in ('ARR_DELAY', 'DEP_DELAY', 'DISTANCE', 'TAXI_OUT')
  field_type = FloatType() if is_numeric else StringType()
  return StructField(colname, field_type, True)
  
# One StructField per column name, in the order the names appear in `header`.
schema = StructType(list(map(get_structfield, header.split(','))))
print(schema)

StructType(List(StructField(FL_DATE,StringType,true),StructField(UNIQUE_CARRIER,StringType,true),StructField(AIRLINE_ID,StringType,true),StructField(CARRIER,StringType,true),StructField(FL_NUM,StringType,true),StructField(ORIGIN_AIRPORT_ID,StringType,true),StructField(ORIGIN_AIRPORT_SEQ_ID,StringType,true),StructField(ORIGIN_CITY_MARKET_ID,StringType,true),StructField(ORIGIN,StringType,true),StructField(DEST_AIRPORT_ID,StringType,true),StructField(DEST_AIRPORT_SEQ_ID,StringType,true),StructField(DEST_CITY_MARKET_ID,StringType,true),StructField(DEST,StringType,true),StructField(CRS_DEP_TIME,StringType,true),StructField(DEP_TIME,StringType,true),StructField(DEP_DELAY,FloatType,true),StructField(TAXI_OUT,FloatType,true),StructField(WHEELS_OFF,StringType,true),StructField(WHEELS_ON,StringType,true),StructField(TAXI_IN,StringType,true),StructField(CRS_ARR_TIME,StringType,true),StructField(ARR_TIME,StringType,true),StructField(ARR_DELAY,FloatType,true),StructField(CANCELLED,StringType,true),

In [5]:
# Read only the shards matching all_flights-00000-* (noted as 1/30th of the data).
# To read every shard, use 'gs://{}/flights/tzcorr/all_flights-*'.format(BUCKET) instead.
inputs = 'gs://{}/flights/tzcorr/all_flights-00000-*'.format(BUCKET)
flights_csv = spark.read.schema(schema).csv(inputs)
# Expose the flights DataFrame to Spark SQL as the temporary view 'flights'.
flights_csv.createOrReplaceTempView('flights')

In [6]:
# Create the training data: join flights to the traindays view on FL_DATE and
# keep only the rows whose is_train_day value is the string 'True'.
traindayquery = """
select
  f.*
from flights f
join traindays t
on f.FL_DATE == t.FL_DATE
where
  t.is_train_day == 'True'
"""
traindata = spark.sql(traindayquery)

In [7]:
# describe() computes summary statistics for each column and show() prints them.
# NULLs are not counted, so the `count` row can differ between columns
# (e.g. flights that were scheduled but never departed).
traindata.describe().show()

+-------+----------+--------------+------------------+-------+------------------+------------------+---------------------+---------------------+------+------------------+-------------------+-------------------+-----+-------------------+-------------------+------------------+------------------+-------------------+-------------------+------------------+-------------------+-------------------+-------------------+--------------------+-----------------+--------------------+------------------+------------------+------------------+--------------------+-----------------+------------------+--------------------+-----+-----------+
|summary|   FL_DATE|UNIQUE_CARRIER|        AIRLINE_ID|CARRIER|            FL_NUM| ORIGIN_AIRPORT_ID|ORIGIN_AIRPORT_SEQ_ID|ORIGIN_CITY_MARKET_ID|ORIGIN|   DEST_AIRPORT_ID|DEST_AIRPORT_SEQ_ID|DEST_CITY_MARKET_ID| DEST|       CRS_DEP_TIME|           DEP_TIME|         DEP_DELAY|          TAXI_OUT|         WHEELS_OFF|          WHEELS_ON|           TAXI_IN|       CRS_ARR_TIME

 -> すべてのcount数が同じになった  
 ただ根本解決にはなってない
 -> キャンセル、迂回が発生したかを表すカラムを見てレコードを絞った方が確実

In [8]:
def _cancel_divert_query(where_clause=None, limit=10):
  """Build the SQL text that samples the cancelled/diverted columns of `flights`.

  Args:
    where_clause: Optional SQL predicate placed under ``where``. When None,
      the query has no filter.
    limit: Maximum number of rows the query returns.

  Returns:
    The query as a string, with the same line layout (including the leading
    and trailing newline) as a hand-written triple-quoted query.
  """
  lines = ['', 'select', '  cancelled, diverted', 'from flights']
  if where_clause is not None:
    lines.extend(['where', '  ' + where_clause])
  lines.extend(['limit {}'.format(limit), ''])
  return '\n'.join(lines)


# Sample the two flag columns three ways: unfiltered, rows whose cancelled
# flag is not '0.00', and rows whose diverted flag is not '0.00'.
# `query` is left bound to the last query that was run.
for where_clause in (None, "cancelled != '0.00'", "diverted != '0.00'"):
  query = _cancel_divert_query(where_clause)
  spark.sql(query).show()

+---------+--------+
|cancelled|diverted|
+---------+--------+
|     0.00|    0.00|
|     0.00|    0.00|
|     0.00|    0.00|
|     0.00|    0.00|
|     0.00|    0.00|
|     0.00|    0.00|
|     0.00|    0.00|
|     0.00|    0.00|
|     0.00|    0.00|
|     0.00|    0.00|
+---------+--------+

+---------+--------+
|cancelled|diverted|
+---------+--------+
|     1.00|    0.00|
|     1.00|    0.00|
|     1.00|    0.00|
|     1.00|    0.00|
|     1.00|    0.00|
|     1.00|    0.00|
|     1.00|    0.00|
|     1.00|    0.00|
|     1.00|    0.00|
|     1.00|    0.00|
+---------+--------+

+---------+--------+
|cancelled|diverted|
+---------+--------+
|     0.00|    1.00|
|     0.00|    1.00|
|     0.00|    1.00|
|     0.00|    1.00|
|     0.00|    1.00|
|     0.00|    1.00|
|     0.00|    1.00|
|     0.00|    1.00|
|     0.00|    1.00|
+---------+--------+



どうやらキャンセルされたか、迂回されたかは '0.00' / '1.00' という文字列(意味としては 0, 1 のフラグ)で格納されているみたい

In [9]:
# Filter more rigorously: besides restricting to training days, keep only the
# flights whose cancelled and diverted flags are both '0.00'.
# NOTE: this rebinds `traindayquery` and `traindata` defined in an earlier cell.
traindayquery = """
select
  f.*
from flights f
join traindays t
on f.FL_DATE == t.FL_DATE
where
  t.is_train_day == 'True' and
  f.cancelled == '0.00' and
  f.diverted == '0.00'
"""
traindata = spark.sql(traindayquery)
traindata.describe().show()

+-------+----------+--------------+------------------+-------+------------------+------------------+---------------------+---------------------+------+------------------+-------------------+-------------------+-----+-------------------+-------------------+------------------+------------------+-------------------+-------------------+-----------------+-------------------+-------------------+-------------------+---------+-----------------+--------+------------------+------------------+------------------+--------------------+------------------+------------------+--------------------+-----+-----------+
|summary|   FL_DATE|UNIQUE_CARRIER|        AIRLINE_ID|CARRIER|            FL_NUM| ORIGIN_AIRPORT_ID|ORIGIN_AIRPORT_SEQ_ID|ORIGIN_CITY_MARKET_ID|ORIGIN|   DEST_AIRPORT_ID|DEST_AIRPORT_SEQ_ID|DEST_CITY_MARKET_ID| DEST|       CRS_DEP_TIME|           DEP_TIME|         DEP_DELAY|          TAXI_OUT|         WHEELS_OFF|          WHEELS_ON|          TAXI_IN|       CRS_ARR_TIME|           ARR_TIME|   

NULLを除外した時と同じ結果が得られた

# トレーニング

トレーニングデータの各レコードは、LabeledPointクラスに変換する必要がある
- https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html

In [10]:
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.mllib.regression import LabeledPoint

In [11]:
def to_example(fields):
  """Convert one training-data row into a LabeledPoint.

  Args:
    fields: A row that supports lookup by column name and provides
      ARR_DELAY, DEP_DELAY, TAXI_OUT and DISTANCE.

  Returns:
    A LabeledPoint whose label is 1.0 when ARR_DELAY is below 15 (on time)
    and 0.0 otherwise, with features [DEP_DELAY, TAXI_OUT, DISTANCE].
  """
  # The column is spelled ARR_DELAY in the schema; the previous 'ARR_DALAY'
  # key named a column that does not exist.
  on_time = float(fields['ARR_DELAY'] < 15)
  features = [
    fields['DEP_DELAY'],
    fields['TAXI_OUT'],
    fields['DISTANCE'],
  ]
  return LabeledPoint(on_time, features)

In [12]:
# Training examples: each row of traindata reduced to just its label and input features
examples = traindata.rdd.map(to_example)

### トレーニング実施

In [13]:
# intercept=True: also fit an intercept term; it is printed next to the weights.
# NOTE(review): the recorded run of this cell failed because the workers ran
# Python 2.7 while the driver ran Python 3.5. PYSPARK_PYTHON and
# PYSPARK_DRIVER_PYTHON must point at matching versions before it can succeed.
lrmodel = LogisticRegressionWithLBFGS.train(examples, intercept=True)
print(lrmodel.weights, lrmodel.intercept)

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.runJob.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 11.0 failed 4 times, most recent failure: Lost task 0.3 in stage 11.0 (TID 14, ch6cluster-w-1.us-central1-a.c.hohukelkazan.internal, executor 1): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 181, in main
    ("%d.%d" % sys.version_info[:2], version))
Exception: Python in worker has different version 2.7 than that in driver 3.5, PySpark cannot run with different minor versions.Please check environment variables PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON are correctly set.

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:336)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:475)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:458)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:290)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$class.foreach(Iterator.scala:893)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
	at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.api.python.PythonRDD$$anonfun$3.apply(PythonRDD.scala:152)
	at org.apache.spark.api.python.PythonRDD$$anonfun$3.apply(PythonRDD.scala:152)
	at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:2074)
	at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:2074)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1661)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1649)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1648)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1648)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:831)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1882)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1831)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1820)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:642)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2034)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2055)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2074)
	at org.apache.spark.api.python.PythonRDD$.runJob(PythonRDD.scala:152)
	at org.apache.spark.api.python.PythonRDD.runJob(PythonRDD.scala)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/worker.py", line 181, in main
    ("%d.%d" % sys.version_info[:2], version))
Exception: Python in worker has different version 2.7 than that in driver 3.5, PySpark cannot run with different minor versions.Please check environment variables PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON are correctly set.

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:336)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:475)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:458)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:290)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$class.foreach(Iterator.scala:893)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
	at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.api.python.PythonRDD$$anonfun$3.apply(PythonRDD.scala:152)
	at org.apache.spark.api.python.PythonRDD$$anonfun$3.apply(PythonRDD.scala:152)
	at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:2074)
	at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:2074)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more
