In [6]:
import pandas as pd
from tqdm import tqdm
import json
from collections import defaultdict

In [7]:
class LabelField:
    def __init__(self):
        self.label2id = dict()
        self.label_num = 0

    def get_id(self, label):

        if label in self.label2id:
            return self.label2id[label]

        self.label2id[label] = self.label_num
        self.label_num += 1

        return self.label2id[label]

In [8]:

data = pd.read_csv('../data/online_retail/data.csv', encoding='ISO-8859-1')
data.head()

Unnamed: 0,InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country
0,536365,85123A,WHITE HANGING HEART T-LIGHT HOLDER,6,12/1/2010 8:26,2.55,17850.0,United Kingdom
1,536365,71053,WHITE METAL LANTERN,6,12/1/2010 8:26,3.39,17850.0,United Kingdom
2,536365,84406B,CREAM CUPID HEARTS COAT HANGER,8,12/1/2010 8:26,2.75,17850.0,United Kingdom
3,536365,84029G,KNITTED UNION FLAG HOT WATER BOTTLE,6,12/1/2010 8:26,3.39,17850.0,United Kingdom
4,536365,84029E,RED WOOLLY HOTTIE WHITE HEART.,6,12/1/2010 8:26,3.39,17850.0,United Kingdom


In [9]:
data['InvoiceNo'].nunique()

25900

In [10]:
import datetime
# to unix timestamp
data['InvoiceDate'] = data['InvoiceDate'].apply(lambda x: pd.to_datetime(x).timestamp())
data.head()

Unnamed: 0,InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country
0,536365,85123A,WHITE HANGING HEART T-LIGHT HOLDER,6,1291192000.0,2.55,17850.0,United Kingdom
1,536365,71053,WHITE METAL LANTERN,6,1291192000.0,3.39,17850.0,United Kingdom
2,536365,84406B,CREAM CUPID HEARTS COAT HANGER,8,1291192000.0,2.75,17850.0,United Kingdom
3,536365,84029G,KNITTED UNION FLAG HOT WATER BOTTLE,6,1291192000.0,3.39,17850.0,United Kingdom
4,536365,84029E,RED WOOLLY HOTTIE WHITE HEART.,6,1291192000.0,3.39,17850.0,United Kingdom


In [11]:
data = data[['Description', 'InvoiceDate', 'InvoiceNo', 'StockCode']]
data.head()

Unnamed: 0,Description,InvoiceDate,InvoiceNo,StockCode
0,WHITE HANGING HEART T-LIGHT HOLDER,1291192000.0,536365,85123A
1,WHITE METAL LANTERN,1291192000.0,536365,71053
2,CREAM CUPID HEARTS COAT HANGER,1291192000.0,536365,84406B
3,KNITTED UNION FLAG HOT WATER BOTTLE,1291192000.0,536365,84029G
4,RED WOOLLY HOTTIE WHITE HEART.,1291192000.0,536365,84029E


In [12]:
data.info()
# 16,520 3,469 519,906

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 541909 entries, 0 to 541908
Data columns (total 4 columns):
 #   Column       Non-Null Count   Dtype  
---  ------       --------------   -----  
 0   Description  540455 non-null  object 
 1   InvoiceDate  541909 non-null  float64
 2   InvoiceNo    541909 non-null  object 
 3   StockCode    541909 non-null  object 
dtypes: float64(1), object(3)
memory usage: 16.5+ MB


In [13]:
data.groupby('InvoiceNo').count().sort_values(by='InvoiceDate', ascending=False)

Unnamed: 0_level_0,Description,InvoiceDate,StockCode
InvoiceNo,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
573585,1114,1114,1114
581219,749,749,749
581492,731,731,731
580729,721,721,721
558475,705,705,705
...,...,...,...
557509,1,1,1
540264,1,1,1
540272,0,1,1
557501,1,1,1


In [14]:
user_len = data['InvoiceNo'].value_counts()
item_len = data['StockCode'].value_counts()
invocie_len = len(data)

print(f"User: {len(user_len)}, Item: {len(item_len)}, Invoice: {data['InvoiceDate'].count()}")

User: 25900, Item: 4070, Invoice: 541909


In [15]:
user_field = LabelField()
s_field = LabelField()
sequences = defaultdict(list)
raw_sequences = defaultdict(list)

In [21]:
# Create interactions
inters = []
for row in tqdm(data.itertuples(), total=len(data)):
    user_id = row.InvoiceNo
    item_id = row.StockCode
    time = row.InvoiceDate
    if row.Description:
        inters.append((user_id, item_id, time))

len(inters)

100%|██████████| 541909/541909 [00:00<00:00, 1502052.31it/s]


541909

In [51]:
# Create Item2Seq
item2seq = {}
for row in tqdm(data.itertuples(), total=len(data)):
    if isinstance(row.Description, str) and len(row.Description) > 0:
        item_id = row.StockCode
        item2seq[item_id] = row.Description

len(item2seq)

100%|██████████| 541909/541909 [00:00<00:00, 1469452.24it/s]


3958

In [23]:
# Filter K core

def get_user2count(inters):
    user2count = defaultdict(int)
    for unit in inters:
        user2count[unit[0]] += 1
    return user2count


def get_item2count(inters):
    item2count = defaultdict(int)
    for unit in inters:
        item2count[unit[1]] += 1
    return item2count


def generate_candidates(unit2count, threshold):
    cans = set()
    for unit, count in unit2count.items():
        if count >= threshold:
            cans.add(unit)
    return cans, len(unit2count) - len(cans)

user_k_core_threshold = 5
item_k_core_threshold = 5

if user_k_core_threshold or item_k_core_threshold:
    new_inters = []
    print('\nFiltering by k-core:')
    idx = 0
    user2count = get_user2count(inters)
    item2count = get_item2count(inters)

    while True:
        new_user2count = defaultdict(int)
        new_item2count = defaultdict(int)
        users, n_filtered_users = generate_candidates(
            user2count, user_k_core_threshold)
        items, n_filtered_items = generate_candidates(
            item2count, item_k_core_threshold)
        if n_filtered_users == 0 and n_filtered_items == 0:
            break
        for unit in inters:
            if unit[0] in users and unit[1] in items:
                new_inters.append(unit)
                new_user2count[unit[0]] += 1
                new_item2count[unit[1]] += 1
        idx += 1
        inters, new_inters = new_inters, []
        user2count, item2count = new_user2count, new_item2count
        print('    Epoch %d The number of inters: %d, users: %d, items: %d'
                % (idx, len(inters), len(user2count), len(item2count)))


Filtering by k-core:
    Epoch 1 The number of inters: 525514, users: 16537, items: 3524
    Epoch 2 The number of inters: 525309, users: 16531, items: 3472
    Epoch 3 The number of inters: 525301, users: 16529, items: 3472
    Epoch 4 The number of inters: 525297, users: 16529, items: 3471


In [30]:
# Check
raw_sequences = defaultdict(list)
usercount = defaultdict(int)
itemcount = defaultdict(int)
titlelen = []

for unit in tqdm(inters):
    user_id = unit[0]
    item_id = unit[1]
    time = unit[2]
    usercount[user_id] += 1
    itemcount[item_id] += 1
    titlelen.append(len(item2seq[item_id]))
    raw_sequences[user_id].append((item_id, time))

100%|██████████| 525297/525297 [00:00<00:00, 1230027.80it/s]


In [31]:
len(raw_sequences)

16529

In [37]:
sequences = defaultdict(list)

In [38]:

for k, v in raw_sequences.items():
    if len(v) >= 2:
        sequences[user_field.get_id(k)] = [(s_field.get_id(ele[0]), ele[1]) for ele in v]

train_dict = dict()
dev_dict = dict()
test_dict = dict()

intersections = 0

for k, v in tqdm(sequences.items()):
    sequences[k] = sorted(v, key=lambda x: x[1])
    sequences[k] = [ele[0] for ele in sequences[k]]

    length = len(sequences[k])
    intersections += length
    if length < 4:
        train_dict[k] = sequences[k]
    else:
        train_dict[k] = sequences[k][: length - 2]
        dev_dict[k] = [sequences[k][length - 2]]
        test_dict[k] = [sequences[k][length - 1]]

print(f"Users: {len(user_field.label2id)}, Items: {len(s_field.label2id)}, Intersects: {intersections}")

100%|██████████| 16529/16529 [00:00<00:00, 301126.06it/s]

Users: 20060, Items: 3927, Intersects: 525297





In [47]:
item2seq['85123A']

'CREAM HANGING HEART T-LIGHT HOLDER'

In [52]:
for item, des in item2seq.items():
    item2seq[item] = {
        'description': des,
    }

item2seq

{'85123A': {'description': 'CREAM HANGING HEART T-LIGHT HOLDER'},
 '71053': {'description': 'WHITE MOROCCAN METAL LANTERN'},
 '84406B': {'description': 'CREAM CUPID HEARTS COAT HANGER'},
 '84029G': {'description': 'KNITTED UNION FLAG HOT WATER BOTTLE'},
 '84029E': {'description': 'RED WOOLLY HOTTIE WHITE HEART.'},
 '22752': {'description': 'SET 7 BABUSHKA NESTING BOXES'},
 '21730': {'description': 'GLASS STAR FROSTED T-LIGHT HOLDER'},
 '22633': {'description': 'HAND WARMER UNION JACK'},
 '22632': {'description': 'HAND WARMER RED RETROSPOT'},
 '84879': {'description': 'ASSORTED COLOUR BIRD ORNAMENT'},
 '22745': {'description': "POPPY'S PLAYHOUSE BEDROOM "},
 '22748': {'description': "POPPY'S PLAYHOUSE KITCHEN"},
 '22749': {'description': 'FELTCRAFT PRINCESS CHARLOTTE DOLL'},
 '22310': {'description': 'IVORY KNITTED MUG COSY '},
 '84969': {'description': 'BOX OF 6 ASSORTED COLOUR TEASPOONS'},
 '22623': {'description': 'BOX OF VINTAGE JIGSAW BLOCKS '},
 '22622': {'description': 'BOX OF VI

In [53]:
import os

output_path = 'online_retail_ours'
train_file = os.path.join(output_path, "train.json")
dev_file = os.path.join(output_path, "val.json")
test_file = os.path.join(output_path, "test.json")
umap_file = os.path.join(output_path, "umap.json")
smap_file = os.path.join(output_path, "smap.json")
meta_file = os.path.join(output_path, "meta_data.json")

print(f"Users: {len(user_field.label2id)}, Items: {len(s_field.label2id)}, Intersects: {intersections}")

f_u = open(umap_file, "w", encoding="utf8")
json.dump(user_field.label2id, f_u, indent=1, ensure_ascii=False)
f_u.close()

f_s = open(smap_file, "w", encoding="utf8")
json.dump(s_field.label2id, f_s, indent=1, ensure_ascii=False)
f_s.close()

train_f = open(train_file, "w", encoding="utf8")
json.dump(train_dict, train_f, indent=1, ensure_ascii=False)
train_f.close()

dev_f = open(dev_file, "w", encoding="utf8")
json.dump(dev_dict, dev_f, indent=1, ensure_ascii=False)
dev_f.close()

test_f = open(test_file, "w", encoding="utf8")
json.dump(test_dict, test_f, indent=1, ensure_ascii=False)
test_f.close()


meta_f = open(meta_file, "w", encoding="utf8")
json.dump(item2seq, meta_f, indent=1, ensure_ascii=False)
meta_f.close()

Users: 20060, Items: 3927, Intersects: 525297
