In [1]:
import calendar
import csv
import datetime
import operator
import os
import pickle
import time

from tqdm import tqdm

In [2]:
def example(data, n=6):
    """Return the first ``n`` lines of a text file for a quick preview.

    Args:
        data: Path of the file to read.
        n: Maximum number of lines to return (default 6, e.g. a header plus
            five data rows).

    Returns:
        A list of at most ``n`` strings with newline characters stripped.
    """
    Y = []
    with open(data, 'r') as fn:
        for ix, line in enumerate(fn):
            # Stop as soon as enough lines are collected instead of scanning
            # the rest of a potentially very large file.
            if ix >= n:
                break
            Y.append(line.strip("\n"))

    return Y

In [3]:
dataset_path = '/home/ec2-user/SageMaker/sequence-based-recommendation/YOOCHOOSE_data'
# yoochoose-clicks: write a copy of the raw file that starts with a header row,
# so it can be read by column name later on.
clicks_with_header = os.path.join(dataset_path, 'yoochoose-clicks-withHeader.dat')
with open(os.path.join(dataset_path, 'yoochoose-clicks.dat'), 'r') as f, open(clicks_with_header, 'w') as fn:
    fn.write('sessionId,timestamp,itemId,category' + '\n')
    for line in f:
        fn.write(line)

# Preview the file that was just written. Use its full path so the preview does
# not depend on the notebook's current working directory.
example(clicks_with_header)

['sessionId,timestamp,itemId,category',
 '1,2014-04-07T10:51:09.277Z,214536502,0',
 '1,2014-04-07T10:54:09.868Z,214536500,0',
 '1,2014-04-07T10:54:46.998Z,214536506,0',
 '1,2014-04-07T10:57:00.306Z,214577561,0',
 '2,2014-04-07T13:56:37.614Z,214662742,0']

In [4]:
## yoochoose-buys: write a copy of the raw file that starts with a header row.
buys_with_header = os.path.join(dataset_path, 'yoochoose-buys-withHeader.dat')
with open(os.path.join(dataset_path, 'yoochoose-buys.dat'), 'r') as f, open(buys_with_header, 'w') as fn:
    fn.write('sessionId,timestamp,itemId,price,quantity' + '\n')
    for line in f:
        fn.write(line)

# Preview the file that was just written. Use its full path so the preview does
# not depend on the notebook's current working directory.
example(buys_with_header)

['sessionId,timestamp,itemId,price,quantity',
 '420374,2014-04-06T18:44:58.314Z,214537888,12462,1',
 '420374,2014-04-06T18:44:58.325Z,214537850,10471,1',
 '281626,2014-04-06T09:40:13.032Z,214535653,1883,1',
 '420368,2014-04-04T06:13:28.848Z,214530572,6073,1',
 '420368,2014-04-04T06:13:28.858Z,214835025,2617,1']

In [5]:
## yoochoose-test: write a copy of the raw file that starts with a header row.
test_with_header = os.path.join(dataset_path, 'yoochoose-test-withHeader.dat')
with open(os.path.join(dataset_path, 'yoochoose-test.dat'), 'r') as f, open(test_with_header, 'w') as fn:
    fn.write('sessionId,timestamp,itemId,category' + '\n')
    for line in f:
        fn.write(line)

# Preview the file that was just written. Use its full path so the preview does
# not depend on the notebook's current working directory.
example(test_with_header)

['sessionId,timestamp,itemId,category',
 '5,2014-04-07T17:13:46.713Z,214530776,0',
 '5,2014-04-07T17:20:56.973Z,214530776,0',
 '5,2014-04-07T17:21:19.602Z,214530776,0',
 '10,2014-04-04T07:44:14.590Z,214820942,0',
 '10,2014-04-04T07:45:20.245Z,214826810,0']

In [6]:
dataset = os.path.join(dataset_path, 'yoochoose-clicks-withHeader.dat')


def _to_epoch(timestamp):
    """Convert a 'YYYY-MM-DDTHH:MM:SS...' timestamp string to epoch seconds.

    Only the first 19 characters are parsed, so sub-second digits and the
    trailing 'Z' are ignored. The value is interpreted as UTC (the timestamps
    in the data end in 'Z'); calendar.timegm is used instead of time.mktime so
    the result does not depend on the machine's local timezone.
    """
    return float(calendar.timegm(time.strptime(timestamp[:19], '%Y-%m-%dT%H:%M:%S')))


# First pass: count lines so the progress bar below has a total.
with open(dataset, "r") as f:
    total_rows = sum(1 for line in f)

# Second pass: collect the clicked items of every session and the timestamp of
# the last row seen for each session.
with open(dataset, "r", newline='') as f:
    reader = csv.DictReader(f, delimiter=',')
    sess_clicks = {}    # sessionId -> list of itemIds in file order
    sess_date = {}      # sessionId -> epoch seconds of the session's last row
    curid = -1
    curdate = None

    # total_rows includes the header line, which DictReader consumes itself and
    # never yields, so the number of data rows is one less.
    for data in tqdm(reader, total=max(total_rows - 1, 0), leave=True, position=0):
        sessid = data['sessionId']
        if curdate and not curid == sessid:
            # The session id changed: stamp the previous session with the
            # timestamp of its last row.
            sess_date[curid] = _to_epoch(curdate)
        curid = sessid
        item = data['itemId']
        curdate = data['timestamp']

        sess_clicks.setdefault(sessid, []).append(item)

    # Stamp the final session. Skipped when the file held no data rows, in
    # which case there is no session to stamp.
    if curdate:
        sess_date[curid] = _to_epoch(curdate)

100%|█████████▉| 33003944/33003945 [03:50<00:00, 142923.21it/s]


In [7]:
# Drop sessions that contain exactly one click
print("{:<30}{:<20,}".format("Before filtering out", len(sess_clicks)))

# Collect the ids first so the dicts are not resized while being iterated.
single_click_ids = [sid for sid, clicks in sess_clicks.items() if len(clicks) == 1]
for sid in single_click_ids:
    # Remove the session from both tables so they stay in sync.
    del sess_clicks[sid]
    del sess_date[sid]

print("{:<30}{:<20,}".format("After filtering out", len(sess_clicks)))

Before filtering out          9,249,729           
After filtering out           7,990,018           


In [8]:
# Tally how many times each item occurs across the remaining sessions
iid_counts = {}
for clicks in sess_clicks.values():
    for iid in clicks:
        iid_counts[iid] = iid_counts.get(iid, 0) + 1

# (itemId, count) pairs ordered from least to most frequent
sorted_counts = sorted(iid_counts.items(), key=lambda pair: pair[1])

length = len(sess_clicks)
print("{:<30}{:<20,}".format("Before filtering out", length))

# Keep only items that occur at least 5 times, then drop every session that is
# left with fewer than 2 clicks. Counts are the ones computed above and are not
# recomputed as sessions are removed.
for sid in list(sess_clicks):
    kept = [iid for iid in sess_clicks[sid] if iid_counts[iid] >= 5]
    if len(kept) >= 2:
        sess_clicks[sid] = kept
    else:
        del sess_clicks[sid]
        del sess_date[sid]

print("{:<30}{:<20,}".format("After filtering out", len(sess_clicks)))

Before filtering out          7,990,018           
After filtering out           7,981,581           


In [9]:
# Split out test set based on dates
dates = list(sess_date.items())     # [(sessionId, timestamp), ...]
maxdate = max(date for _, date in dates)

# Hold out the sessions of the final day (relative to the latest session) for testing
SECONDS_PER_DAY = 86400
TEST_DAYS = 1
splitdate = maxdate - SECONDS_PER_DAY * TEST_DAYS

print('Splitting date', splitdate)
# NOTE: both comparisons are strict, so a session stamped exactly at splitdate
# ends up in neither split.
tra_sess = filter(lambda x: x[1] < splitdate, dates)
tes_sess = filter(lambda x: x[1] > splitdate, dates)

# Sort sessions by date
tra_sess = sorted(tra_sess, key=operator.itemgetter(1))     # [(sessionId, timestamp), (), ]
tes_sess = sorted(tes_sess, key=operator.itemgetter(1))     # [(sessionId, timestamp), (), ]
print(len(tra_sess))    # number of training sessions
print(len(tes_sess))    # number of test sessions
print(tra_sess[:3])
print(tes_sess[:3])

Splitting date 1411959599.0
7966257
15324
[('171168', 1396321232.0), ('345618', 1396321275.0), ('263073', 1396321302.0)]
[('11532683', 1411959653.0), ('11464959', 1411959671.0), ('11296119', 1411959695.0)]


In [10]:
# Choosing item count >=5 gives approximately the same number of items as reported in paper
# Maps raw itemId -> consecutive integer id; filled by obtian_tra()
item_dict = {}


def obtian_tra():
    """Turn the training sessions into sequences of renumbered item ids.

    Walks ``tra_sess`` in order and assigns every previously unseen raw item
    the next integer id, starting at 1, recording the assignment in the
    module-level ``item_dict``. Sessions that map to fewer than two items are
    skipped. Prints the next unused id when finished.

    Returns:
        (train_ids, train_dates, train_seqs): parallel lists holding the
        session id, the session date and the renumbered item sequence.
    """
    train_ids, train_seqs, train_dates = [], [], []
    item_ctr = 1
    for s, date in tqdm(tra_sess, total=len(tra_sess), leave=True, position=0):
        outseq = []
        for raw_item in sess_clicks[s]:
            mapped = item_dict.get(raw_item)
            if mapped is None:
                # First occurrence of this item: hand out the next free id.
                mapped = item_ctr
                item_dict[raw_item] = mapped
                item_ctr += 1
            outseq.append(mapped)
        if len(outseq) >= 2:  # Shorter sequences doesn't occur here
            train_ids.append(s)
            train_dates.append(date)
            train_seqs.append(outseq)
    print(item_ctr)     # 43098, 37484
    return train_ids, train_dates, train_seqs

def obtian_tes():
    """Turn the test sessions into sequences of renumbered item ids.

    Items without an entry in the module-level ``item_dict`` are dropped from
    the sequence, and sessions left with fewer than two items are skipped.

    Returns:
        (test_ids, test_dates, test_seqs): parallel lists holding the session
        id, the session date and the renumbered item sequence.
    """
    test_ids, test_seqs, test_dates = [], [], []
    for s, date in tqdm(tes_sess, total=len(tes_sess), leave=True, position=0):
        # Keep only the items that already have an id in item_dict.
        outseq = [item_dict[raw_item] for raw_item in sess_clicks[s] if raw_item in item_dict]
        if len(outseq) >= 2:
            test_ids.append(s)
            test_dates.append(date)
            test_seqs.append(outseq)
    return test_ids, test_dates, test_seqs


# Order matters: obtian_tra() fills item_dict, which obtian_tes() then reads to
# map (and filter) the items of the test sessions.
tra_ids, tra_dates, tra_seqs = obtian_tra()
tes_ids, tes_dates, tes_seqs = obtian_tes()


100%|██████████| 7966257/7966257 [00:38<00:00, 209074.23it/s]


37484


100%|██████████| 15324/15324 [00:00<00:00, 332985.43it/s]


In [11]:
def process_seqs(iseqs, idates):
    """Expand each sequence into (prefix, next-item) samples.

    For a sequence of length n, n - 1 samples are produced, longest prefix
    first: the prefix ``seq[:-k]`` is paired with the label ``seq[-k]`` for
    k = 1 .. n - 1. Sequences of length 0 or 1 yield no samples.

    Args:
        iseqs: List of item sequences.
        idates: List of dates, parallel to ``iseqs``.

    Returns:
        (out_seqs, out_dates, labs, ids): parallel lists with the prefix, the
        date of the source sequence, the label item, and the position of the
        source sequence within ``iseqs``.
    """
    out_seqs, out_dates, labs, ids = [], [], [], []
    for position, (seq, date) in enumerate(zip(iseqs, idates)):
        for offset in range(1, len(seq)):
            cut = -offset
            labs.append(seq[cut])
            out_seqs.append(seq[:cut])
            out_dates.append(date)
            ids.append(position)
    return out_seqs, out_dates, labs, ids


# Expand train and test sessions into (prefix, label) samples
tr_seqs, tr_dates, tr_labs, tr_ids = process_seqs(tra_seqs, tra_dates)
te_seqs, te_dates, te_labs, te_ids = process_seqs(tes_seqs, tes_dates)

# Pair the prefixes with their labels for each split
tra, tes = (tr_seqs, tr_labs), (te_seqs, te_labs)

# Sample counts, then the first three samples of each split
for sample_seqs in (tr_seqs, te_seqs):
    print(len(sample_seqs))
for sample_seqs, sample_dates, sample_labs in ((tr_seqs, tr_dates, tr_labs), (te_seqs, te_dates, te_labs)):
    print(sample_seqs[:3], sample_dates[:3], sample_labs[:3])

23670982
55898
[[1], [3], [5, 5]] [1396321232.0, 1396321275.0, 1396321302.0] [2, 4, 5]
[[33611, 37169, 6409], [33611, 37169], [33611]] [1411959653.0, 1411959653.0, 1411959653.0] [33128, 6409, 37169]


In [12]:
# Average number of items per session over the train and test sessions.
# (Named total_items rather than `all` so the builtin all() is not shadowed
# for the rest of the kernel session.)
total_items = sum(len(seq) for seq in tra_seqs) + sum(len(seq) for seq in tes_seqs)
num_sessions = len(tra_seqs) + len(tes_seqs)
# With no sessions at all there is no meaningful average; report NaN instead of
# raising ZeroDivisionError.
print('avg length: ', total_items / num_sessions if num_sessions else float('nan'))

avg length:  3.9727042800167034


In [13]:
# The same test samples and test sessions are written into both output
# directories.
for out_dir in ('yoochoose1_4', 'yoochoose1_64'):
    os.makedirs(out_dir, exist_ok=True)
    # Context managers make sure each file is flushed and closed once written.
    with open(os.path.join(out_dir, 'test.txt'), 'wb') as fh:
        pickle.dump(tes, fh)
    with open(os.path.join(out_dir, 'all_test_seq.txt'), 'wb') as fh:
        pickle.dump(tes_seqs, fh)

In [14]:
# Number of training samples in the most recent 1/4 and 1/64 of the data
split4, split64 = len(tr_seqs) // 4, len(tr_seqs) // 64


def _tail_split(count):
    """Return ((prefixes, labels), sessions) for the last ``count`` training samples.

    ``sessions`` holds the full training sequences starting at the session that
    produced the first selected sample; tr_ids gives, for every sample, the
    position of its source session in tra_seqs.
    """
    if count <= 0:
        # x[-0:] would select the whole list, so an empty selection is handled
        # explicitly.
        return ([], []), []
    start = len(tr_seqs) - count
    return (tr_seqs[start:], tr_labs[start:]), tra_seqs[tr_ids[start]:]


tra4, seq4 = _tail_split(split4)
tra64, seq64 = _tail_split(split64)
print(len(tra4[0]))
print(len(tra64[0]))

for out_dir, samples, sessions in (('yoochoose1_4', tra4, seq4), ('yoochoose1_64', tra64, seq64)):
    os.makedirs(out_dir, exist_ok=True)
    # Context managers make sure each file is flushed and closed once written.
    with open(os.path.join(out_dir, 'train.txt'), 'wb') as fh:
        pickle.dump(samples, fh)
    with open(os.path.join(out_dir, 'all_train_seq.txt'), 'wb') as fh:
        pickle.dump(sessions, fh)

5917745
369859
