In [1]:
import pandas as pd
import gzip
import json
from tqdm import tqdm

dataset = "Dunnhumby" # "TaFeng" or "Dunnhumby"

In [2]:
# 讀取CSV文件
def read_csv(file_name):
    df = pd.read_csv(file_name)
    df['CUSTOMER_ID'] = df['CUSTOMER_ID'].astype(str)  # 將 CUSTOMER_ID 轉換為字符串型
    return df

# 讀取JSON GZ文件
def read_json_gz(file_name):
    with gzip.open(file_name, 'rt', encoding='utf-8') as f:
        data = json.load(f)
    # 將用戶ID轉換為字符串型
    for pair in data:
        pair[0] = str(pair[0])
        pair[1] = [str(x) for x in pair[1]]
    return data

# 使用這些函數讀取您的文件
history = read_csv(f'data/{dataset}_history.csv')
test_neighbors = read_json_gz(f'data/{dataset}/test_neighbors_for_dlim.json.gz')
training_neighbors = read_json_gz(f'data/{dataset}/training_neighbors_for_dlim.json.gz')
validation_neighbors = read_json_gz(f'data/{dataset}/validation_neighbors_for_dlim.json.gz')

# 顯示讀取的數據（示例）
print(history.head())
print(validation_neighbors[:1])

  CUSTOMER_ID  TRANSACTION_DT  ORDER_NUMBER  MATERIAL_NUMBER
0          31        20060416             0                0
1          31        20060416             0                1
2          31        20060416             0                2
3          31        20060416             0                3
4          31        20060416             0                4
[['726826', ['215512', '428025', '362775', '293802', '651259', '154266', '486759', '491839', '489192', '261575', '508807', '487479', '673586', '107077', '29945', '691843', '441867', '52783', '722116', '47760', '51022', '403281', '510937', '62328', '172346', '669592', '659359', '60320', '365721', '584197', '82736', '140492', '34878', '155225', '339747', '226027', '47742', '495345', '168928', '622732', '244650', '162857', '288706', '380801', '11216', '540205', '564240', '168687', '623887', '398826', '444051', '551172', '94689', '413620', '537635', '305179', '361919', '59994', '136130', '510368', '266989', '671602', '191377', '51

In [3]:
# 使用集合來計算 validation_neighbors 中不重複的 UserID 數量
unique_user_ids = set()
for pair in validation_neighbors:
    unique_user_ids.add(pair[0])
    unique_user_ids.update(pair[1])

# 計算不重複的 UserID 數量
unique_user_ids_count = len(unique_user_ids)
print("不重複的 UserID 數量:", unique_user_ids_count)

不重複的 UserID 數量: 5662


In [4]:
# 初始化一個空的字典來儲存結果
result_data = {}

# 遍歷每個用戶及其鄰居
for pair in tqdm(validation_neighbors, desc="Processing Users"):
    customer_id = pair[0]  # 用戶ID
    neighbors = pair[1]    # 鄰居列表
    users = [customer_id] + neighbors

    # 對於每個用戶及其鄰居，從 history 中提取交易資訊
    for user in users:
        user_transactions = history[history['CUSTOMER_ID'] == user]

        # 對於相同的 ORDER_NUMBER，將 MATERIAL_NUMBER 組合在一起
        grouped_transactions = user_transactions.groupby('ORDER_NUMBER').agg({'MATERIAL_NUMBER': lambda x: tuple(x), 'TRANSACTION_DT': 'first'}).reset_index()

        # 將每筆交易的資料以 [(BasketItems, Date)] 的格式添加到字典
        transactions = [(row['MATERIAL_NUMBER'], row['TRANSACTION_DT']) for index, row in grouped_transactions.iterrows()]
        result_data[user] = transactions

Processing Users: 100%|██████████| 1026/1026 [4:16:28<00:00, 15.00s/it] 


In [5]:
print(result_data)  # 印出前幾行以檢查
print(len(result_data))

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [6]:
# 將字典轉換為 DataFrame
result_df = pd.DataFrame(list(result_data.items()), columns=['UserID', 'Transactions'])

# 顯示 DataFrame
print(result_df)

      UserID                                       Transactions
0     726826  [((290, 677, 274), 20060415), ((290, 274), 200...
1     215512  [((2542,), 20060411), ((1302, 1937, 1990), 200...
2     428025  [((466, 608, 290, 257, 2634, 1127, 2303, 658, ...
3     362775  [((956, 2971), 20060410), ((349, 1082), 200604...
4     293802  [((141, 1881, 62), 20060410), ((1463, 1463), 2...
...      ...                                                ...
5657  673855  [((2347, 84, 37, 671, 1459, 122, 1898, 1567, 2...
5658  330227  [((1536,), 20060511), ((1536, 578), 20060513),...
5659  804366  [((1837, 534, 2291, 2736, 2989, 1654, 734, 125...
5660  804385  [((36, 144, 305, 2147, 909, 224, 108, 329, 121...
5661  804471  [((736, 1141, 11), 20060412), ((0,), 20060413)...

[5662 rows x 2 columns]


In [7]:
with gzip.open(f'data/{dataset}/{dataset}_validation_users_transactions.json.gz', 'wt', encoding='UTF-8') as zipfile:
    json.dump(result_data, zipfile)

print("validation_users_transactions.json.gz 文件已被成功創建並保存。")

validation_users_transactions.json.gz 文件已被成功創建並保存。
