In [None]:
# Name of the bucket used to build the gs:// paths read by this notebook.
BUCKET = 'hohukelkazan'

# Prepare

In [None]:
# Load trainday.csv from the bucket, treating the first CSV line as the header.
traindays = (
    spark.read
    .option("header", "true")
    .csv('gs://{}/flights/trainday.csv'.format(BUCKET))
)

In [None]:
# Register the DataFrame as the Spark SQL temporary view 'traindays'
traindays.createOrReplaceTempView('traindays')

In [None]:
# Schema definition (TAXI_OUT is typed as numeric as well)
from pyspark.sql.types import StringType, FloatType, StructType, StructField

# Comma-separated column names; split further down to build the schema
# that is applied when the flights CSV files are read.
header = 'FL_DATE,UNIQUE_CARRIER,AIRLINE_ID,CARRIER,FL_NUM,ORIGIN_AIRPORT_ID,ORIGIN_AIRPORT_SEQ_ID,ORIGIN_CITY_MARKET_ID,ORIGIN,DEST_AIRPORT_ID,DEST_AIRPORT_SEQ_ID,DEST_CITY_MARKET_ID,DEST,CRS_DEP_TIME,DEP_TIME,DEP_DELAY,TAXI_OUT,WHEELS_OFF,WHEELS_ON,TAXI_IN,CRS_ARR_TIME,ARR_TIME,ARR_DELAY,CANCELLED,CANCELLATION_CODE,DIVERTED,DISTANCE,DEP_AIRPORT_LAT,DEP_AIRPORT_LON,DEP_AIRPORT_TZOFFSET,ARR_AIRPORT_LAT,ARR_AIRPORT_LON,ARR_AIRPORT_TZOFFSET,EVENT,NOTIFY_TIME'

def get_structfield(colname):
  """Return the nullable StructField describing one column.

  ARR_DELAY, DEP_DELAY, DISTANCE and TAXI_OUT are typed as FloatType;
  every other column name is typed as StringType.
  """
  numeric_columns = ['ARR_DELAY', 'DEP_DELAY', 'DISTANCE', 'TAXI_OUT']
  field_type = FloatType() if colname in numeric_columns else StringType()
  return StructField(colname, field_type, True)
  
# One StructField per column name, in the order they appear in `header`.
schema = StructType(list(map(get_structfield, header.split(','))))
print(schema)

In [None]:
# Only the files matching all_flights-00000-* are read (noted as 1/30th of the data).
# To read everything, use: 'gs://{}/flights/tzcorr/all_flights-*'.format(BUCKET)
inputs = 'gs://{}/flights/tzcorr/all_flights-00000-*'.format(BUCKET)

# Read the CSV files with the explicit schema instead of inferring types.
flights_csv = spark.read.schema(schema).csv(inputs)

# Register the DataFrame as the Spark SQL temporary view 'flights'.
flights_csv.createOrReplaceTempView('flights')

In [None]:
# Sanity check: show the first five rows of each temporary view
spark.sql("SELECT * from traindays LIMIT 5").show()
spark.sql("SELECT * from flights LIMIT 5").show()

In [None]:
# Create the training data: flights joined to traindays on FL_DATE,
# keeping only the dates whose is_train_day value is 'True'
traindayquery = """
select
  f.*
from flights f
join traindays t
on f.FL_DATE == t.FL_DATE
where
  t.is_train_day == 'True'
"""
traindata = spark.sql(traindayquery)

In [None]:
# check: 'ARR_DELAY', 'DEP_DELAY', 'DISTANCE', 'TAXI_OUT' should now be floats
print(traindata.head(1)) 

In [None]:
# describe() computes summary statistics for each column and show() prints the result.
# NULLs are not counted, so the count differs between columns
# (e.g. flights that were scheduled but never departed).
traindata.describe().show()

In [None]:
# Exclude rows whose dep_delay or arr_delay is NULL to drop those special-case flights
traindayquery = """
select
  f.*
from flights f
join traindays t
on f.FL_DATE == t.FL_DATE
where
  t.is_train_day == 'True' and
  f.dep_delay is not NULL and
  f.arr_delay is not NULL

"""
traindata = spark.sql(traindayquery)
traindata.describe().show()

 -> すべてのcount数が同じになった
 ただ根本解決にはなってない
 -> キャンセル、迂回が発生したかを表すカラムを見てレコードを絞った方が確実

In [None]:
# Inspect the cancelled / diverted columns three ways: unfiltered first,
# then only the rows where each column differs from '0.00'.
# Each query is run and shown in turn.
for query in (
  """
select
  cancelled, diverted
from flights
limit 10
""",
  """
select
  cancelled, diverted
from flights
where
  cancelled != '0.00'
limit 10
""",
  """
select
  cancelled, diverted
from flights
where
  diverted != '0.00'
limit 10
""",
):
  spark.sql(query).show()

どうやらキャンセルされたか、迂回されたかは0,1で格納されているみたい

In [None]:
# Filter more rigorously: keep only training-day rows whose cancelled and
# diverted columns both equal '0.00'
traindayquery = """
select
  f.*
from flights f
join traindays t
on f.FL_DATE == t.FL_DATE
where
  t.is_train_day == 'True' and
  f.cancelled == '0.00' and
  f.diverted == '0.00'
"""
traindata = spark.sql(traindayquery)
traindata.describe().show()

NULLを除外した時と同じ結果が得られた

# トレーニング

トレーニングデータの各レコードは、LabeledPointクラスに変換する必要がある
- https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html

In [None]:
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.mllib.regression import LabeledPoint

In [None]:
def to_example(fields):
  """Convert one flight record into a LabeledPoint for training.

  Args:
    fields: a record that can be indexed by column name and provides
      'ARR_DELAY', 'DEP_DELAY', 'TAXI_OUT' and 'DISTANCE'.

  Returns:
    A LabeledPoint whose label is 1.0 when ARR_DELAY is below 15 (on time)
    and 0.0 otherwise, with the features [DEP_DELAY, TAXI_OUT, DISTANCE].
  """
  # The label must read 'ARR_DELAY'. The previous spelling 'ARR_DALAY' is
  # not a column of the schema, so the lookup failed once the RDD was evaluated.
  on_time = float(fields['ARR_DELAY'] < 15)
  features = [
    fields['DEP_DELAY'],
    fields['TAXI_OUT'],
    fields['DISTANCE'],
  ]
  return LabeledPoint(on_time, features)

In [None]:
# Training data reduced to just the label and the input variables
examples = traindata.rdd.map(to_example)

### トレーニング実施

In [None]:
# intercept=True: also fit an intercept term; the learned weights and
# intercept are printed afterwards
lrmodel = LogisticRegressionWithLBFGS.train(examples, intercept=True)
print(lrmodel.weights, lrmodel.intercept)