In [1]:
import pandas as pd
import gzip
import json
from tqdm import tqdm

dataset = "Dunnhumby" # "TaFeng" or "Dunnhumby"

In [2]:
# 讀取CSV文件
def read_csv(file_name):
    df = pd.read_csv(file_name)
    df['CUSTOMER_ID'] = df['CUSTOMER_ID'].astype(str)  # 將 CUSTOMER_ID 轉換為字符串型
    return df

# 讀取JSON GZ文件
def read_json_gz(file_name):
    with gzip.open(file_name, 'rt', encoding='utf-8') as f:
        data = json.load(f)
    # 將用戶ID轉換為字符串型
    for pair in data:
        pair[0] = str(pair[0])
        pair[1] = [str(x) for x in pair[1]]
    return data

# 使用這些函數讀取您的文件
history = read_csv(f'data/{dataset}_history.csv')
test_neighbors = read_json_gz(f'data/{dataset}/test_neighbors_for_dlim.json.gz')
training_neighbors = read_json_gz(f'data/{dataset}/training_neighbors_for_dlim.json.gz')
validation_neighbors = read_json_gz(f'data/{dataset}/validation_neighbors_for_dlim.json.gz')

# 顯示讀取的數據（示例）
print(history.head())
print(test_neighbors[:1])

  CUSTOMER_ID  TRANSACTION_DT  ORDER_NUMBER  MATERIAL_NUMBER
0          31        20060416             0                0
1          31        20060416             0                1
2          31        20060416             0                2
3          31        20060416             0                3
4          31        20060416             0                4
[['804580', ['107077', '428025', '402253', '630499', '323812', '52783', '41465', '226027', '537635', '332270', '346608', '143474', '154266', '487479', '90354', '191377', '495345', '629767', '600406', '293802', '52863', '570839', '579253', '162857', '366454', '486759', '198783', '491839', '510368', '344042', '651259', '622257', '691843', '288706', '519963', '207064', '510937', '670018', '172346', '130456', '261575', '523641', '10665', '140492', '444051', '25882', '250864', '215512', '47760', '489192', '266989', '695739', '82732', '334736', '41099', '673586', '441867', '508807', '237637', '297486', '585322', '60968', '521610', '

In [3]:
# 使用集合來計算 validation_neighbors 中不重複的 UserID 數量
unique_user_ids = set()
for pair in test_neighbors:
    unique_user_ids.add(pair[0])
    unique_user_ids.update(pair[1])

# 計算不重複的 UserID 數量
unique_user_ids_count = len(unique_user_ids)
print("不重複的 UserID 數量:", unique_user_ids_count)

不重複的 UserID 數量: 8100


In [4]:
# 初始化一個空的字典來儲存結果
result_data = {}

# 遍歷每個用戶及其鄰居
for pair in tqdm(test_neighbors, desc="Processing Users"):
    customer_id = pair[0]  # 用戶ID
    neighbors = pair[1]    # 鄰居列表
    users = [customer_id] + neighbors

    # 對於每個用戶及其鄰居，從 history 中提取交易資訊
    for user in users:
        user_transactions = history[history['CUSTOMER_ID'] == user]

        # 對於相同的 ORDER_NUMBER，將 MATERIAL_NUMBER 組合在一起
        grouped_transactions = user_transactions.groupby('ORDER_NUMBER').agg({'MATERIAL_NUMBER': lambda x: tuple(x), 'TRANSACTION_DT': 'first'}).reset_index()

        # 將每筆交易的資料以 [(BasketItems, Date)] 的格式添加到字典
        transactions = [(row['MATERIAL_NUMBER'], row['TRANSACTION_DT']) for index, row in grouped_transactions.iterrows()]
        result_data[user] = transactions

Processing Users: 100%|██████████| 2565/2565 [10:27:12<00:00, 14.67s/it] 


In [5]:
print(result_data)  # 印出前幾行以檢查
print(len(result_data))

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [6]:
# 將字典轉換為 DataFrame
result_df = pd.DataFrame(list(result_data.items()), columns=['UserID', 'Transactions'])

# 顯示 DataFrame
print(result_df)

      UserID                                       Transactions
0     804580  [((436, 36, 2284, 186, 346, 642, 250), 2006041...
1     107077  [((77, 62), 20060410), ((608,), 20060411), ((2...
2     428025  [((466, 608, 290, 257, 2634, 1127, 2303, 658, ...
3     402253  [((1060,), 20060410), ((1559,), 20060412), ((7...
4     630499  [((36,), 20060410), ((2580, 2029, 982, 62, 303...
...      ...                                                ...
8095  999698  [((162, 1829, 1502, 1123, 1516, 778, 708), 200...
8096  999718  [((2083, 568, 290, 1489, 16, 274, 1081, 107, 2...
8097  624927  [((308,), 20060410), ((1113, 138, 1796), 20060...
8098  103139  [((885, 2426, 1477, 1590, 1110, 1652, 858, 172...
8099  999934  [((0, 1393, 1607, 540, 2153, 67, 2362, 2520, 9...

[8100 rows x 2 columns]


In [7]:
with gzip.open(f'data/{dataset}/{dataset}_test_users_transactions.json.gz', 'wt', encoding='UTF-8') as zipfile:
    json.dump(result_data, zipfile)

print("test_users_transactions.json.gz 文件已被成功創建並保存。")

test_users_transactions.json.gz 文件已被成功創建並保存。
