In [143]:
import json
import os
import sys

sys.path.append('..')  # make the repo-local `advsber` package importable

import jsonlines
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from advsber.utils.data import write_jsonlines

In [144]:
n_weeks = 10
DATASET_NAME = 'gender'
DATASET_READ_PATH = "../data"
DATASET_SAVE_PATH = 'datasets'
index= n_weeks
NUM_WEEKS = 24//n_weeks
MIN_LEN = 3
MAX_LEN = 50*n_weeks
TEST_RATIO = 0.1
SUBST_RATIO = 0.3
VALID_RATIO = 0.2
LM_RATIO = 0.1
NUM_DAYS = 7*n_weeks


In [145]:
# Load the raw transactions and the per-client targets, normalise the column
# names to the 'age' schema (client_id / amount_rur / small_group / bins) and
# assign every transaction to a time bucket stored in the 'week' column.
if DATASET_NAME == 'age':
    data = pd.read_csv(f'{DATASET_READ_PATH}/{DATASET_NAME}/transactions_train.csv')
    target_data = pd.read_csv(f'{DATASET_READ_PATH}/{DATASET_NAME}/train_target.csv')
    data['week'] = data['trans_date'] // NUM_DAYS
else:
    transactions = pd.read_csv(f'{DATASET_READ_PATH}/{DATASET_NAME}/transactions.csv')
    target_data = pd.read_csv(f'{DATASET_READ_PATH}/{DATASET_NAME}/gender_train.csv')
    data = transactions.rename(columns={'customer_id': 'client_id', 'amount':'amount_rur'})
    target_data = target_data.rename(columns={'customer_id':'client_id', 'gender':'bins'})
    # The first space-separated token of 'tr_datetime' is parsed as an integer
    # (presumably a day number - confirm against the dataset description).
    day_number = data['tr_datetime'].str.split(' ').str[0].astype(int)
    # Bucket by NUM_DAYS, exactly like the 'age' branch. The previous hard-coded
    # `// 7` produced one-week buckets regardless of n_weeks, which disagreed
    # with NUM_WEEKS / MAX_LEN (both scaled by n_weeks).
    data['week'] = day_number // NUM_DAYS
    data['small_group'] = data['mcc_code']

# client_id -> label lookup. Relies on target_data having exactly two columns,
# in the order (client_id, label).
target_data_dict = dict(target_data.values)

In [None]:
# One row per (client_id, week) pair; every remaining column is collapsed into
# a Python list of that group's values, in their original row order.
bucket_keys = ['client_id', 'week']
transactions = data.groupby(bucket_keys).agg(list)

In [None]:
# Flatten the grouped frame into one plain record per (client_id, week) bucket.
# The MultiIndex entry of each row unpacks directly into (client_id, week).
# `total` is given explicitly because tqdm cannot infer the length of the
# iterrows() generator; without it the bar shows neither percentage nor ETA.
my_lovely_data_raw = [
    {
        'transactions': row['small_group'],
        'amounts': row['amount_rur'],
        'client_id': client_id,
        'week': week,
    }
    for (client_id, week), row in tqdm(transactions.iterrows(), total=len(transactions))
]

In [None]:
# Turn the list of per-bucket records into a DataFrame
# (columns: transactions, amounts, client_id, week).
my_lovely_data = pd.DataFrame(data=my_lovely_data_raw)

In [None]:
# Keep only the buckets whose number is strictly below NUM_WEEKS.
is_early_bucket = my_lovely_data['week'].lt(NUM_WEEKS)
my_lovely_data = my_lovely_data.loc[is_early_bucket]

In [None]:
# Attach the target label to every bucket, drop buckets whose client has no
# label, and keep only the columns that are exported.
# Written as one non-mutating chain: assigning a column onto a frame that came
# from boolean-mask filtering triggers pandas' SettingWithCopyWarning.
my_lovely_data = (
    my_lovely_data
    # Clients missing from target_data_dict get NaN here ...
    .assign(label=my_lovely_data['client_id'].map(target_data_dict))
    # ... and are removed before the integer cast, which cannot represent NaN.
    .dropna(subset=['label'])
    .astype({'label': int})
    [['transactions', 'amounts', 'client_id', 'label']]
)

In [None]:
# Number of transactions in each bucket.
lens = my_lovely_data['transactions'].map(len)

# Keep buckets whose length lies in [MIN_LEN, MAX_LEN]; both bounds are inclusive.
has_valid_length = lens.between(MIN_LEN, MAX_LEN)
my_lovely_data = my_lovely_data[has_valid_length]

In [None]:
def _stratified_split(frame, holdout_ratio, seed=123):
    """Return ``(rest, holdout)`` parts of ``frame``, stratified on its 'label' column.

    ``holdout_ratio`` is the fraction of rows placed in the second part and
    ``seed`` is forwarded as ``random_state`` so the split is reproducible.
    """
    return train_test_split(
        frame,
        test_size=holdout_ratio,
        random_state=seed,
        stratify=frame['label'],
    )


# Language-model split. It is drawn from the full frame with its own seed,
# independently of the classifier splits below, so its rows overlap with them
# (including test_data).
lm_train, lm_valid = _stratified_split(my_lovely_data, LM_RATIO, seed=126663)

# Classifier splits: hold out the test set first, divide the remainder between
# the target and the substitute classifier, then carve a validation part out
# of each of the two.
other_data, test_data = _stratified_split(my_lovely_data, TEST_RATIO)
target_data, subst_data = _stratified_split(other_data, SUBST_RATIO)
target_data_tr, target_data_val = _stratified_split(target_data, VALID_RATIO)
subst_data_tr, subst_data_val = _stratified_split(subst_data, VALID_RATIO)

In [None]:
# Root directory of this dataset variant: <save path>/<dataset name>_<index>.
NAME = DATASET_SAVE_PATH + '/' + DATASET_NAME + '_' + str(index)

# Create the directory tree from Python instead of `!mkdir`: exist_ok=True
# makes the cell silently re-runnable, missing parents are created
# automatically, and no shell is involved.
for subdir in ('target_clf', 'substitute_clf', 'lm'):
    os.makedirs(f'{NAME}/{subdir}', exist_ok=True)

# Output file (relative to NAME) -> frame written to it, one JSON record per row.
outputs = {
    'test.jsonl': test_data,
    'target_clf/train.jsonl': target_data_tr,
    'target_clf/valid.jsonl': target_data_val,
    'substitute_clf/train.jsonl': subst_data_tr,
    'substitute_clf/valid.jsonl': subst_data_val,
    'lm/train.jsonl': lm_train,
    'lm/valid.jsonl': lm_valid,
}
for relative_path, frame in outputs.items():
    write_jsonlines(frame.to_dict('records'), f'{NAME}/{relative_path}')