In [9]:
import pandas as pd
import random

In [89]:
def refine_log_data(data_frame: pd.DataFrame,
                    random_seed: int,
                    upper_bound: int) -> tuple:
    """Split per-user log data into a train history and a test row per user.

    Each user gets a 1-based cutoff position from `derive_random_cutoff_data`.
    Within a user, rows are ranked by `row_id`. Rows ranked strictly before the
    cutoff go to the train set, and the row ranked at the cutoff goes to the
    test set. Rows after the cutoff are discarded.

    Parameters
    ----------
    data_frame : pd.DataFrame
        The content spend history log data. Must contain `user_id` and
        `row_id` columns.
    random_seed : int
        Seed passed to `random.seed` before cutoffs are drawn, so the split is
        reproducible. Note that this reseeds the module-level `random`
        generator as a side effect.
    upper_bound : int
        Despite its name, this is the *smallest* cutoff position a user can
        receive. A user with at least `upper_bound` rows gets a random cutoff
        in `[upper_bound, row_count]`. A user with fewer rows gets
        `row_count`, so their last row becomes the test row.

    Returns
    -------
    tuple of (pd.DataFrame, pd.DataFrame)
        `(train_log_data, test_log_data)`. Both keep the original columns
        plus the helper columns `cutoff_position` and `row_id_by_user`.
    """
    random.seed(random_seed)
    cutoff_data = derive_random_cutoff_data(data_frame, upper_bound)

    merged = pd.merge(data_frame, cutoff_data, on='user_id')
    # method='first' guarantees distinct integer ranks even if a user has
    # duplicate row_id values; the default 'average' would yield fractional
    # ranks (e.g. 2.5) that can never equal the integer cutoff_position.
    merged['row_id_by_user'] = merged.groupby('user_id')['row_id'].rank(method='first')

    position = merged['row_id_by_user']
    cutoff = merged['cutoff_position']
    train_log_data = merged[position < cutoff]
    test_log_data = merged[position == cutoff]

    return train_log_data, test_log_data

def derive_random_cutoff_data(log_data: pd.DataFrame,
                              upper_bound: int) -> pd.DataFrame:
    """Return a randomly drawn 1-based cutoff position for every user.

    A user with at least `upper_bound` rows gets a cutoff drawn uniformly from
    `[upper_bound, total_log_count]`, with both ends inclusive. If the user's
    total log count is less than `upper_bound`, the cutoff is the log count
    itself, i.e. the user's last row.

    Draws come from the module-level `random` generator, one draw per
    qualifying user in ascending `user_id` order. Seed it with `random.seed`
    beforehand for a reproducible result.

    Parameters
    ----------
    log_data : pd.DataFrame
        Log data with `user_id` and `row_id` columns. Rows are counted per
        user as the number of non-null `row_id` values.
    upper_bound : int
        Smallest cutoff position a user with enough rows can receive. Must be
        >= 1, because cutoff positions are 1-based.

    Returns
    -------
    pd.DataFrame
        Indexed by `user_id`, with a single `cutoff_position` column.

    Raises
    ------
    ValueError
        If `upper_bound` is less than 1.

    Examples
    --------
    For `upper_bound = 10` (intermediate total_log_count shown for clarity):

    | user_id | total_log_count | cutoff_position |
    |---------|-----------------|-----------------|
    |    1    |      400        |       314       |
    |    2    |       2         |        2        |
    """
    if upper_bound < 1:
        # A cutoff of 0 matches no 1-based rank, which would silently drop
        # the user from both the train and the test split.
        raise ValueError(f"upper_bound must be >= 1, got {upper_bound}")

    def _get_cutoff_position(row_count: int) -> int:
        # Too few rows to honour upper_bound: fall back to the last row.
        if row_count < upper_bound:
            return row_count
        return random.randrange(upper_bound, row_count + 1)

    total_log_count = log_data.groupby('user_id')['row_id'].count()
    cutoff_position = total_log_count.apply(_get_cutoff_position)
    return cutoff_position.to_frame(name='cutoff_position')



In [4]:
# Directory holding the input CSV files (a path relative to the notebook's working directory).
INPUT_DIR = "../resources/input"

In [87]:
# Load the sample log data from INPUT_DIR.
log_data = pd.read_csv(INPUT_DIR + "/sample_train.csv")

In [90]:
# Reproducible split with seed 1000. Users with >= 10 rows get a random cutoff
# at position 10 or later. refine is the tuple (train_log_data, test_log_data).
refine = refine_log_data(log_data, random_seed=1000, upper_bound=10)

In [92]:
# Train split: each user's rows ranked strictly before their cutoff_position.
refine[0]

Unnamed: 0.1,Unnamed: 0,row_id,timestamp,user_id,content_id,content_type_id,task_container_id,user_answer,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation,cutoff_position,row_id_by_user
0,624,624,0,13134,3926,False,0,3,1,,,888,1.0
1,625,625,23840,13134,564,False,1,1,0,22000.0,False,888,2.0
2,626,626,46834,13134,3865,False,2,1,0,18000.0,False,888,3.0
3,627,627,64749,13134,4231,False,3,1,1,19000.0,False,888,4.0
4,628,628,113000,13134,3684,False,4,1,1,13000.0,False,888,5.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
12904702,101223416,101223416,26776674687,2147331033,1113,False,72,3,1,16000.0,True,77,72.0
12904703,101223417,101223417,26776698199,2147331033,1315,False,70,1,0,15000.0,True,77,73.0
12904704,101223418,101223418,26776926698,2147331033,6282,False,73,1,0,18000.0,False,77,74.0
12904705,101223419,101223419,26776992718,2147331033,13480,False,74,2,1,10000.0,True,77,75.0


In [93]:
# Test split: rows whose per-user rank equals cutoff_position
# (one row per user when row_id is unique within each user).
refine[1]

Unnamed: 0.1,Unnamed: 0,row_id,timestamp,user_id,content_id,content_type_id,task_container_id,user_answer,answered_correctly,prior_question_elapsed_time,prior_question_had_explanation,cutoff_position,row_id_by_user
887,1511,1511,16582811868,13134,6502,False,733,0,1,31000.0,True,888,888.0
1309,8570,8570,3778888369,44331,6916,False,44,3,0,54750.0,True,60,60.0
1578,13041,13041,1042673,137455,894,False,20,3,1,18000.0,True,35,35.0
3048,15656,15656,3694573675,141455,10287,False,836,0,1,23750.0,True,1455,1455.0
3916,16524,16524,1125743,142896,4568,False,9,0,0,17000.0,True,10,10.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
12904518,101216958,101216958,727081,2147128056,6370,False,13,0,0,20000.0,False,22,22.0
12904545,101219045,101219045,615312,2147175651,3854,False,26,2,1,25000.0,True,27,27.0
12904560,101223175,101223175,13117599,2147307933,4221,False,13,1,0,1000.0,True,14,14.0
12904595,101223210,101223210,714348,2147308237,4420,False,20,1,1,10000.0,True,35,35.0
