In [1]:
import pandas as pd
import gzip
import json
from tqdm import tqdm

dataset = "Dunnhumby" # "TaFeng" or "Dunnhumby"

In [2]:
# 讀取CSV文件
def read_csv(file_name):
    df = pd.read_csv(file_name)
    df['CUSTOMER_ID'] = df['CUSTOMER_ID'].astype(str)  # 將 CUSTOMER_ID 轉換為字符串型
    return df

# 讀取JSON GZ文件
def read_json_gz(file_name):
    with gzip.open(file_name, 'rt', encoding='utf-8') as f:
        data = json.load(f)
    # 將用戶ID轉換為字符串型
    for pair in data:
        pair[0] = str(pair[0])
        pair[1] = [str(x) for x in pair[1]]
    return data

# 使用這些函數讀取您的文件
history = read_csv(f'data/{dataset}_history.csv')
test_neighbors = read_json_gz(f'data/{dataset}/test_neighbors_for_dlim.json.gz')
training_neighbors = read_json_gz(f'data/{dataset}/training_neighbors_for_dlim.json.gz')
validation_neighbors = read_json_gz(f'data/{dataset}/validation_neighbors_for_dlim.json.gz')

# 顯示讀取的數據（示例）
print(history.head())
print(training_neighbors[:1])

  CUSTOMER_ID  TRANSACTION_DT  ORDER_NUMBER  MATERIAL_NUMBER
0          31        20060416             0                0
1          31        20060416             0                1
2          31        20060416             0                2
3          31        20060416             0                3
4          31        20060416             0                4
[['31', ['211052', '428025', '266989', '293802', '107077', '491839', '605802', '215512', '486759', '651259', '495345', '244650', '238158', '487479', '52783', '154266', '172346', '425722', '226027', '441867', '683525', '551172', '365721', '691843', '510937', '140492', '510368', '261575', '403281', '323812', '489192', '198783', '537635', '380801', '579253', '191377', '444051', '412206', '508807', '62328', '47760', '600406', '659359', '453093', '237637', '130456', '51022', '722116', '47742', '332270', '669592', '500478', '155225', '136130', '317474', '82736', '622732', '52749', '162857', '346608', '288706', '34878', '630499', '29

In [3]:
# 使用集合來計算 validation_neighbors 中不重複的 UserID 數量
unique_user_ids = set()
for pair in training_neighbors:
    unique_user_ids.add(pair[0])
    unique_user_ids.update(pair[1])

# 計算不重複的 UserID 數量
unique_user_ids_count = len(unique_user_ids)
print("不重複的 UserID 數量:", unique_user_ids_count)

不重複的 UserID 數量: 9234


In [4]:
# 初始化一個空的字典來儲存結果
result_data = {}

# 遍歷每個用戶及其鄰居
for pair in tqdm(training_neighbors, desc="Processing Users"):
    customer_id = pair[0]  # 用戶ID
    neighbors = pair[1]    # 鄰居列表
    users = [customer_id] + neighbors

    # 對於每個用戶及其鄰居，從 history 中提取交易資訊
    for user in users:
        user_transactions = history[history['CUSTOMER_ID'] == user]

        # 對於相同的 ORDER_NUMBER，將 MATERIAL_NUMBER 組合在一起
        grouped_transactions = user_transactions.groupby('ORDER_NUMBER').agg({'MATERIAL_NUMBER': lambda x: tuple(x), 'TRANSACTION_DT': 'first'}).reset_index()

        # 將每筆交易的資料以 [(BasketItems, Date)] 的格式添加到字典
        transactions = [(row['MATERIAL_NUMBER'], row['TRANSACTION_DT']) for index, row in grouped_transactions.iterrows()]
        result_data[user] = transactions

Processing Users: 100%|██████████| 9234/9234 [37:12:17<00:00, 14.50s/it]   


In [5]:
print(result_data)  # 印出前幾行以檢查
print(len(result_data))

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [6]:
# 將字典轉換為 DataFrame
result_df = pd.DataFrame(list(result_data.items()), columns=['UserID', 'Transactions'])

# 顯示 DataFrame
print(result_df)

      UserID                                       Transactions
0         31  [((0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 20060416...
1     211052  [((1063,), 20060412), ((2718, 68), 20060414), ...
2     428025  [((466, 608, 290, 257, 2634, 1127, 2303, 658, ...
3     266989  [((36, 1430, 1685, 248, 1692), 20060410), ((56...
4     293802  [((141, 1881, 62), 20060410), ((1463, 1463), 2...
...      ...                                                ...
9229  725702  [((2140, 885, 1207, 139, 2138), 20060411), ((8...
9230  725978  [((596, 2344, 1955, 479, 432, 19), 20060420), ...
9231  725996  [((1673,), 20060413), ((1070,), 20060414), ((3...
9232  726073  [((553,), 20060417), ((799, 319, 1051, 1175, 1...
9233  726134  [((64, 568, 181, 1591, 2014, 2022, 11, 2169, 1...

[9234 rows x 2 columns]


In [7]:
with gzip.open(f'data/{dataset}/{dataset}_training_users_transactions.json.gz', 'wt', encoding='UTF-8') as zipfile:
    json.dump(result_data, zipfile)

print("training_users_transactions.json.gz 文件已被成功創建並保存。")

training_users_transactions.json.gz 文件已被成功創建並保存。
