In [1]:
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division

In [2]:
import numpy as np
import pandas as pd

In [3]:
import datetime as dt

In [4]:
from pyspark.sql import Row

In [5]:
spark

# Definition

In [6]:
start_ts = dt.datetime.strptime("2017-01-02 01:00:00", '%Y-%m-%d %H:%M:%S')
end_ts = dt.datetime.strptime("2017-01-04 00:30:00", '%Y-%m-%d %H:%M:%S')

In [7]:
def parse_slot(d):
    slot = 3
    week_day = d.weekday()
    
    if d < dt.datetime(d.year, d.month, d.day, 1, 0, 0):
        slot = 3
    elif d < dt.datetime(d.year, d.month, d.day, 9, 0, 0):
        slot = 0
    elif d < dt.datetime(d.year, d.month, d.day, 17, 0, 0):
        slot = 1
    elif d < dt.datetime(d.year, d.month, d.day, 21, 0, 0):
        slot = 2
  
    return slot

In [8]:
parse_slot(end_ts)

3

In [9]:
def get_dime(d):
    return (d - start_ts).days * 4 + parse_slot(d)

In [10]:
get_dime(end_ts)

7

# Load data

In [11]:
def parse_evt(p):
    evt_time = dt.datetime.strptime(p[4][:19], '%Y-%m-%d %H:%M:%S')
    time_diff = evt_time - start_ts 
    dim = get_dime(evt_time)
    
    r = Row(
        user_id=p[0],
        device_id=p[1],
        session_id=p[2],
        title_id=p[3],
        event_time=evt_time,
        played_duration=float(p[5]),
        action_trigger=p[6],
        platform=p[7],
        episode_number=int(p[8]),
        series_total_episodes_count=int(p[9]),
        internet_connection_type=p[10],
        is_trailer=bool(p[11]),
        month=evt_time.month,
        day=evt_time.day,
        hour=evt_time.hour,
        week = (evt_time - start_ts).days // 7,
        slot=parse_slot(evt_time),
        dime=dim
    )
    return r

In [16]:
data_header = sc.textFile('kktv-17-11/train').take(1)[0]

In [17]:
data_header

u'user_id,device_id,session_id,title_id,event_time,played_duration,action_trigger,platform,episode_number,series_total_episodes_count,internet_connection_type,is_trailer'

In [18]:
label_header = sc.textFile('kktv-17-11/label').take(1)[0] 

## Training  

In [24]:
train_label_df = sc.textFile('kktv-17-11/label') \
    .filter(lambda l: l != label_header) \
    .map(lambda l: l.split(',')) \
    .map(lambda p: [p[0]] + map(lambda d: float(d), p[1:])) \
    .toDF(label_header.split(','))

In [26]:
train_label_df.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- time_slot_0: double (nullable = true)
 |-- time_slot_1: double (nullable = true)
 |-- time_slot_2: double (nullable = true)
 |-- time_slot_3: double (nullable = true)
 |-- time_slot_4: double (nullable = true)
 |-- time_slot_5: double (nullable = true)
 |-- time_slot_6: double (nullable = true)
 |-- time_slot_7: double (nullable = true)
 |-- time_slot_8: double (nullable = true)
 |-- time_slot_9: double (nullable = true)
 |-- time_slot_10: double (nullable = true)
 |-- time_slot_11: double (nullable = true)
 |-- time_slot_12: double (nullable = true)
 |-- time_slot_13: double (nullable = true)
 |-- time_slot_14: double (nullable = true)
 |-- time_slot_15: double (nullable = true)
 |-- time_slot_16: double (nullable = true)
 |-- time_slot_17: double (nullable = true)
 |-- time_slot_18: double (nullable = true)
 |-- time_slot_19: double (nullable = true)
 |-- time_slot_20: double (nullable = true)
 |-- time_slot_21: double (nullable = true

In [27]:
train_label_df.select('user_id', 'time_slot_0', 'time_slot_15', 'time_slot_27').show()

+-------+-----------+------------+------------+
|user_id|time_slot_0|time_slot_15|time_slot_27|
+-------+-----------+------------+------------+
|      0|        0.0|         0.0|         0.0|
|      1|        0.0|         0.0|         0.0|
|      2|        0.0|         0.0|         0.0|
|      3|        0.0|         0.0|         0.0|
|      4|        0.0|         0.0|         0.0|
|      5|        0.0|         0.0|         0.0|
|      6|        0.0|         0.0|         0.0|
|      7|        1.0|         0.0|         0.0|
|      8|        0.0|         0.0|         0.0|
|      9|        0.0|         0.0|         0.0|
|     10|        0.0|         0.0|         0.0|
|     11|        1.0|         1.0|         0.0|
|     12|        0.0|         0.0|         0.0|
|     13|        1.0|         0.0|         0.0|
|     14|        0.0|         0.0|         0.0|
|     15|        0.0|         0.0|         0.0|
|     16|        0.0|         0.0|         0.0|
|     17|        0.0|         0.0|      

In [28]:
train_data_df = sc.textFile('kktv-17-11/train') \
    .filter(lambda e: e != header) \
    .map(lambda e: e.split(',')) \
    .map(parse_evt).toDF() \
    .where('event_time >= "2017-01-02 01:00:00" and event_time < "2017-08-14 01:00:00"')

In [30]:
train_data_df.select('user_id', 'event_time', 'month', 'day', 'hour', 'week', 'slot', 'dime').sample(False, 0.1).show()

+-------+-------------------+-----+---+----+----+----+----+
|user_id|         event_time|month|day|hour|week|slot|dime|
+-------+-------------------+-----+---+----+----+----+----+
|      0|2017-06-08 14:54:38|    6|  8|  14|  22|   1| 629|
|      0|2017-06-08 14:54:41|    6|  8|  14|  22|   1| 629|
|      0|2017-06-08 14:57:26|    6|  8|  14|  22|   1| 629|
|      0|2017-06-11 07:11:54|    6| 11|   7|  22|   0| 640|
|      0|2017-06-11 08:17:39|    6| 11|   8|  22|   0| 640|
|      0|2017-06-27 14:54:15|    6| 27|  14|  25|   1| 705|
|      2|2017-07-30 03:04:41|    7| 30|   3|  29|   0| 836|
|      2|2017-07-30 03:05:14|    7| 30|   3|  29|   0| 836|
|      3|2017-05-28 15:18:18|    5| 28|  15|  20|   1| 585|
|      3|2017-05-28 15:34:22|    5| 28|  15|  20|   1| 585|
|      3|2017-05-28 15:55:41|    5| 28|  15|  20|   1| 585|
|      3|2017-05-28 16:05:13|    5| 28|  16|  20|   1| 585|
|      3|2017-05-28 16:07:09|    5| 28|  16|  20|   1| 585|
|      3|2017-05-28 16:15:08|    5| 28| 

In [33]:
import pyspark.sql.functions as sqlf

In [36]:
train_dime_df = train_data_df \
    .groupBy('user_id') \
    .pivot("dime", range(896)) \
    .agg(sqlf.expr('count(*) as double'))

In [38]:
train_dime_df.select('user_id', '1', '2', '895').show()

+-------+----+----+----+
|user_id|   1|   2| 895|
+-------+----+----+----+
|  10096|   2|null|null|
|   1090|null|null|null|
|  12847|null|null|null|
|  13865|null|null|null|
|  14204|null|null|null|
|  14899|null|null|null|
|  15634|   2|null|null|
|  16320|null|null|null|
|  17686|  11|null|null|
|  18634|null|null|null|
|  19095|null|null|null|
|  20569|null|null|null|
|  21248|null|null|null|
|  21249| 109|null|null|
|  21259|null|null|null|
|  21452|null|null|null|
|  22728|null|null|null|
|  23918|null|null|null|
|  24114|null|null|null|
|  27108|   7|null|null|
+-------+----+----+----+
only showing top 20 rows



In [39]:
train_df = train_dime_df \
    .join(train_label_df, train_dime_df['user_id']==train_label_df['user_id']) \
    .drop(train_label_df['user_id'])

In [51]:
from pyspark.sql.functions import col

In [55]:
train_df = train_df.select(
    ['user_id'] + \
    [col(str(n)).cast('double').alias('{}'.format(n)) for n in range(896)] +
    ['time_slot_{}'.format(s) for s in range(28)]) \
    .na.fill(0.0)

In [56]:
train_df.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- 0: double (nullable = false)
 |-- 1: double (nullable = false)
 |-- 2: double (nullable = false)
 |-- 3: double (nullable = false)
 |-- 4: double (nullable = false)
 |-- 5: double (nullable = false)
 |-- 6: double (nullable = false)
 |-- 7: double (nullable = false)
 |-- 8: double (nullable = false)
 |-- 9: double (nullable = false)
 |-- 10: double (nullable = false)
 |-- 11: double (nullable = false)
 |-- 12: double (nullable = false)
 |-- 13: double (nullable = false)
 |-- 14: double (nullable = false)
 |-- 15: double (nullable = false)
 |-- 16: double (nullable = false)
 |-- 17: double (nullable = false)
 |-- 18: double (nullable = false)
 |-- 19: double (nullable = false)
 |-- 20: double (nullable = false)
 |-- 21: double (nullable = false)
 |-- 22: double (nullable = false)
 |-- 23: double (nullable = false)
 |-- 24: double (nullable = false)
 |-- 25: double (nullable = false)
 |-- 26: double (nullable = false)
 |-- 27: double (null

In [57]:
train_df.write.mode('overwrite').format('parquet').save('tmp/train')

In [60]:
spark.read.parquet('tmp/train').count()

57139

In [None]:
spark.read.parquet('tmp/train').show()

## Testing 

In [62]:
test_df = sc.textFile('kktv-17-11/test') \
    .filter(lambda e: e != data_header) \
    .map(lambda e: e.split(',')) \
    .map(parse_evt).toDF()

In [64]:
test_df.count()

KeyboardInterrupt: 