# 載入套件

In [1]:
import gzip
import json
import os
import pickle

import gensim
import numpy as np
from tqdm import tqdm

# 載入 Item2vec 模型

In [2]:
# Item2vec configuration: which dataset to use and the embedding size.
dataset = "Dunnhumby"  # "TaFeng" or "Dunnhumby"
embed_dim = 64

# Load the saved gensim Word2Vec (Item2vec) model for that dataset / dimension.
model_path = f'data/item2vec_models/item2vec_{dataset}.{embed_dim}d-testttttt.model'
model = gensim.models.Word2Vec.load(model_path)

In [3]:
# Build item2vector_dict: every key in the model's vocabulary mapped to its
# embedding vector, so later cells can look vectors up with plain dict access.
item2vector_dict = {
    item_key: model.wv[item_key]
    for item_key in model.wv.key_to_index
}

# 驗證集

In [4]:
# Location of the gzip-compressed JSON file with the validation transactions.
file_path = f'data/{dataset}/{dataset}_validation_users_transactions.json.gz'

# Open the archive in text mode and parse the JSON payload; the result is a dict.
with gzip.open(file_path, mode='rt', encoding='utf-8') as json_stream:
    validation_data = json.loads(json_stream.read())

# Preview the first five entries.
preview_users = list(validation_data.items())[:5]
for key, value in preview_users:
    print(f"{key}: {value}")

726826: [[[290, 677, 274], 20060415], [[290, 274], 20060418], [[909], 20060425], [[534, 231, 117, 2333], 20060429], [[1285], 20060507], [[236, 677, 1226, 117], 20060514], [[534, 280, 628], 20060517], [[1292], 20060519], [[1845, 505], 20060524], [[1212, 1821, 71, 1754], 20060526], [[117], 20060603], [[2468, 1814], 20060619], [[236], 20060620], [[1273, 1481, 2494, 1212, 1915], 20060630]]
215512: [[[2542], 20060411], [[1302, 1937, 1990], 20060413], [[1988, 2169, 2898], 20060414], [[162], 20060416], [[767], 20060417], [[2505, 446, 2169], 20060418], [[2543, 106, 657, 1520, 708], 20060419], [[1857, 2023], 20060422], [[1123, 418], 20060423], [[1599, 494], 20060426], [[486], 20060429], [[2212, 899, 869, 1908], 20060430], [[280, 707], 20060501], [[1689], 20060502], [[1098, 767], 20060504], [[455], 20060507], [[208, 2169, 1276, 139, 1671], 20060510], [[1969, 121], 20060512], [[2497], 20060513], [[2515, 2840, 1375], 20060514], [[121, 519, 765, 2141], 20060517], [[2278], 20060518], [[2169, 996, 76

In [5]:
# Number of top-level keys (one per user) in validation_data.
number_of_keys = len(validation_data.keys())
print(f"validation_data 中共有 {number_of_keys} 個用戶。")

validation_data 中共有 5662 個用戶。


In [6]:
# Replace each basket's item list with the mean Item2vec vector of its items.
# Each entry of validation_data[key] is an [items, date] pair; after this cell
# it is [basket_vector, date]. The update is done in place so the following
# cells keep reading validation_data.
for key in tqdm(validation_data.keys(), desc="處理用戶"):
    for basket in validation_data[key]:
        items = basket[0]

        # Already converted by an earlier run of this cell: leave it untouched.
        # Without this guard a re-run would look the vector's floats up as item
        # ids and overwrite the embedding (typically with an all-zero vector).
        if isinstance(items, np.ndarray):
            continue

        # Accumulate in float64 (np.zeros default) over the items the model knows.
        basket_vector = np.zeros(model.vector_size)
        valid_items = 0
        for item in items:
            if item in item2vector_dict:
                basket_vector += item2vector_dict[item]
                valid_items += 1

        # A basket with no in-vocabulary item keeps the all-zero vector.
        if valid_items > 0:
            basket_vector /= valid_items

        # Swap the item list for its embedding.
        basket[0] = basket_vector

處理用戶: 100%|██████████| 5662/5662 [00:00<00:00, 9333.34it/s]


In [7]:
# Show the first five users of validation_data.
preview_users = list(validation_data.items())[:5]
for key, value in preview_users:
    print(f"{key}: {value}")

726826: [[array([ 0.11570897,  0.35562299,  0.18679793, -0.00291243, -0.12404273,
       -0.25115974,  0.01409973,  0.15234672, -0.44351171,  0.10672879,
       -0.0716867 , -0.07000306, -0.07444386, -0.02488467,  0.12649005,
        0.13271777, -0.28592011,  0.08067834,  0.2203838 ,  0.11038253,
        0.08540111,  0.1243097 ,  0.19031221, -0.07464011, -0.16142947,
        0.0113234 , -0.0475686 ,  0.05539724, -0.29216947,  0.03150849,
        0.00284046, -0.0306042 , -0.01595666, -0.02402539, -0.07324991,
       -0.17608824,  0.1428161 , -0.12531867, -0.00262701,  0.02655361,
       -0.01828408,  0.16463482,  0.06214355,  0.05629523,  0.33729303,
       -0.26041424, -0.08009868, -0.08217811, -0.14256224, -0.07664439,
        0.13538381, -0.10073714, -0.28350307,  0.14240684,  0.24422638,
        0.19515768, -0.02469419, -0.31280575, -0.07842512,  0.09821545,
       -0.05878391, -0.05630253, -0.02850737, -0.04499354]), 20060415], [array([ 0.09734304,  0.26996374,  0.17953847,  0.1017

In [8]:
# Keys of the last five users in validation_data (dict iteration yields keys).
last_five_keys = list(validation_data)[-5:]

# Print the stored data for each of those users.
for key in last_five_keys:
    print(f"User {key}: {validation_data[key]}")

User 673855: [[array([ 0.02343979,  0.089033  ,  0.21647452, -0.03053533,  0.02448245,
       -0.14099798, -0.08057041,  0.00284137, -0.36726348,  0.05504642,
        0.13905954, -0.00568086, -0.20944212, -0.05047025,  0.1810787 ,
        0.14251   , -0.27277323,  0.04308807, -0.00833897,  0.15743341,
        0.21755079,  0.28730449,  0.19810878, -0.15150554, -0.14352398,
        0.09641397,  0.00604784,  0.10256571, -0.32234014,  0.2370261 ,
        0.00797756,  0.02666744, -0.16356801, -0.14443193, -0.08804946,
       -0.05840606,  0.18989928, -0.14502324, -0.0010011 ,  0.02465862,
       -0.11687137,  0.16279786, -0.01779317,  0.08253833,  0.17322156,
       -0.28478047,  0.07404995, -0.3512427 ,  0.06328432, -0.02149858,
        0.07760195, -0.1685454 , -0.10300412,  0.12337424,  0.1289049 ,
        0.24848049,  0.02108581, -0.10825182, -0.07847923,  0.05960697,
       -0.11536851, -0.11864179,  0.0368417 ,  0.06400411]), 20060412], [array([-0.17689262,  0.00206959,  0.09609048, -0

In [9]:
# Destination of the gzip-compressed pickle holding the validation basket embeddings.
output_file = f'data/{dataset}/basketembedding/validation_basketembedding_{embed_dim}.pkl.gz'

# Make sure the target directory exists: gzip.open does not create parent
# directories and would raise FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Pickle validation_data straight into the gzip stream.
with gzip.open(output_file, 'wb') as file:
    pickle.dump(validation_data, file)

print("Save !")

Save !


# 測試集

In [10]:
# Location of the gzip-compressed JSON file with the test transactions.
file_path2 = f'data/{dataset}/{dataset}_test_users_transactions.json.gz'

# Open the archive in text mode and parse the JSON payload; the result is a dict.
with gzip.open(file_path2, mode='rt', encoding='utf-8') as json_stream:
    test_data = json.loads(json_stream.read())

# Preview the first five entries.
preview_users = list(test_data.items())[:5]
for key, value in preview_users:
    print(f"{key}: {value}")

804580: [[[436, 36, 2284, 186, 346, 642, 250], 20060415], [[141, 2109, 1344, 734, 1986], 20060416], [[36, 522, 429], 20060421], [[36, 653, 1303, 52, 94, 186, 642, 99, 1518, 1986, 2108], 20060503], [[36, 2687, 186, 597, 1683, 1379, 1053, 1518, 1986, 1069], 20060505], [[36, 1266, 1085, 549, 186, 680, 1954, 2031, 1216, 943], 20060511], [[36, 52, 2616], 20060515], [[36, 2011, 2216, 1508, 1272, 250, 1068], 20060518], [[77], 20060522], [[2901, 582, 486], 20060523], [[36, 20, 379, 1128], 20060525], [[1023, 1347, 1500], 20060530], [[1829, 94, 1739, 1272, 1327], 20060602], [[2109, 1805, 702, 1085, 296, 1193, 1511, 274, 418, 77, 1442, 434], 20060604], [[1266, 2319, 117, 2978, 830], 20060610]]
107077: [[[77, 62], 20060410], [[608], 20060411], [[2035, 1909, 1353, 932, 264, 26], 20060412], [[852], 20060415], [[1193, 19], 20060420], [[1114, 1597, 264, 413], 20060422], [[2158, 26, 2046], 20060423], [[2963, 1908], 20060424], [[1635, 1567], 20060428], [[25, 992], 20060505], [[2464, 139, 2898], 20060508

In [11]:
# Number of top-level keys (one per user) in test_data.
number_of_keys = len(test_data.keys())
print(f"test_data 中共有 {number_of_keys} 個用戶。")

test_data 中共有 8100 個用戶。


In [12]:
# Replace each basket's item list with the mean Item2vec vector of its items.
# Each entry of test_data[key] is an [items, date] pair; after this cell it is
# [basket_vector, date]. The update is done in place so the following cells
# keep reading test_data.
for key in tqdm(test_data.keys(), desc="處理用戶"):
    for basket in test_data[key]:
        items = basket[0]

        # Already converted by an earlier run of this cell: leave it untouched.
        # Without this guard a re-run would look the vector's floats up as item
        # ids and overwrite the embedding (typically with an all-zero vector).
        if isinstance(items, np.ndarray):
            continue

        # Accumulate in float64 (np.zeros default) over the items the model knows.
        basket_vector = np.zeros(model.vector_size)
        valid_items = 0
        for item in items:
            if item in item2vector_dict:
                basket_vector += item2vector_dict[item]
                valid_items += 1

        # A basket with no in-vocabulary item keeps the all-zero vector.
        if valid_items > 0:
            basket_vector /= valid_items

        # Swap the item list for its embedding.
        basket[0] = basket_vector

處理用戶: 100%|██████████| 8100/8100 [00:00<00:00, 9151.91it/s]


In [13]:
# Show the first five users of test_data.
preview_users = list(test_data.items())[:5]
for key, value in preview_users:
    print(f"{key}: {value}")

804580: [[array([ 0.09842637,  0.10050048,  0.17682004,  0.14659084,  0.00107484,
       -0.16106878, -0.06388833,  0.07959905, -0.42698528,  0.04746684,
       -0.06469189, -0.10808727, -0.13423673, -0.15523093,  0.18439122,
        0.08388541, -0.37113295,  0.06069645, -0.01169967, -0.02041237,
        0.15581904,  0.18318415,  0.28000718, -0.04063343, -0.11379407,
       -0.03850038,  0.02652616, -0.01791437, -0.28018115,  0.29649149,
       -0.03163617,  0.10575764, -0.05312626, -0.17802444, -0.08077464,
       -0.09164395,  0.21006748, -0.09682735,  0.02487142, -0.05757419,
       -0.01369406,  0.25358997, -0.02417743, -0.06672286,  0.12464702,
       -0.26214565,  0.04023578, -0.12543005, -0.12643201, -0.06608088,
        0.15887193, -0.08498412, -0.15237555,  0.12241038,  0.16265915,
        0.15246562,  0.04957027, -0.19630749, -0.1235533 ,  0.03740637,
       -0.13480779,  0.02565187,  0.00203559, -0.03830073]), 20060415], [array([-0.19909386,  0.15322317,  0.15918106,  0.0420

In [14]:
# Keys of the last five users in test_data (dict iteration yields keys).
last_five_keys = list(test_data)[-5:]

# Print the stored data for each of those users.
for key in last_five_keys:
    print(f"User {key}: {test_data[key]}")

User 999698: [[array([-0.05468503, -0.07476687,  0.07333573,  0.04469131,  0.09172967,
       -0.11852491, -0.11581124,  0.03549588, -0.40760792, -0.17547816,
        0.06392509, -0.08657783, -0.06473359, -0.27895441,  0.2379324 ,
        0.25728918, -0.46560574, -0.15926692, -0.03511316,  0.0283878 ,
        0.02823632,  0.20484126,  0.29560382, -0.22869377, -0.08734651,
        0.06308707,  0.12082296,  0.16184719, -0.27010022,  0.29664915,
        0.14376365, -0.01125115,  0.00763831, -0.15192798, -0.08690886,
       -0.2862703 ,  0.11004401, -0.03600266, -0.02362515, -0.22548434,
       -0.09305341,  0.06297785, -0.05248974,  0.12981599,  0.10129718,
       -0.18924849,  0.00117399, -0.32008204,  0.00964794,  0.10026959,
        0.20625437, -0.07159929, -0.23173945,  0.11231091,  0.14752642,
        0.28797658, -0.0107626 , -0.34387979, -0.13758729,  0.00818177,
       -0.02196804, -0.03919753, -0.0152014 ,  0.09235983]), 20060419], [array([ 0.01779619,  0.03188376,  0.18249825,  0

In [15]:
# Destination of the gzip-compressed pickle holding the test basket embeddings.
output_file = f'data/{dataset}/basketembedding/test_basketembedding_{embed_dim}.pkl.gz'

# Make sure the target directory exists: gzip.open does not create parent
# directories and would raise FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Pickle test_data straight into the gzip stream.
with gzip.open(output_file, 'wb') as file:
    pickle.dump(test_data, file)

print("Save !")

Save !


# 訓練集

In [16]:
# Location of the gzip-compressed JSON file with the training transactions.
file_path3 = f'data/{dataset}/{dataset}_training_users_transactions.json.gz'

# Open the archive in text mode and parse the JSON payload; the result is a dict.
with gzip.open(file_path3, mode='rt', encoding='utf-8') as json_stream:
    training_data = json.loads(json_stream.read())

# Preview the first five entries.
preview_users = list(training_data.items())[:5]
for key, value in preview_users:
    print(f"{key}: {value}")

31: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], 20060416], [[11, 12, 13], 20060420], [[14, 15, 16, 17, 4, 18, 19], 20060425], [[20, 21, 22], 20060428], [[1, 23, 24, 6, 25, 26, 27], 20060503], [[28], 20060507], [[29, 2, 11, 30, 6, 9, 31, 19], 20060508], [[32, 22], 20060515], [[33, 34, 35], 20060520], [[36, 23, 37, 38, 39, 18, 40, 27], 20060526], [[29, 41], 20060527], [[42, 43, 44, 45, 35, 46], 20060605], [[47, 48, 49], 20060620]]
211052: [[[1063], 20060412], [[2718, 68], 20060414], [[1033, 1063], 20060421], [[885, 46, 436, 144, 121, 117], 20060424], [[2244, 1474, 1115, 462, 484, 1812, 62, 1687], 20060501], [[46], 20060506], [[1986], 20060511], [[1063], 20060512], [[46], 20060516], [[1189, 409, 655, 147, 239, 231, 280, 2404, 110, 649], 20060517], [[1695], 20060520], [[2011, 39, 71, 1812, 49], 20060522], [[2184], 20060523], [[2443, 2133], 20060525], [[46], 20060527], [[46, 1063], 20060528], [[46], 20060530], [[577], 20060612], [[2968, 2443], 20060619], [[47], 20060621], [[36, 1033], 20060630]]


In [17]:
# Number of top-level keys (one per user) in training_data.
number_of_keys = len(training_data.keys())
print(f"training_data 中共有 {number_of_keys} 個用戶。")

training_data 中共有 9234 個用戶。


In [18]:
# Replace each basket's item list with the mean Item2vec vector of its items.
# Each entry of training_data[key] is an [items, date] pair; after this cell it
# is [basket_vector, date]. The update is done in place so the following cells
# keep reading training_data.
for key in tqdm(training_data.keys(), desc="處理用戶"):
    for basket in training_data[key]:
        items = basket[0]

        # Already converted by an earlier run of this cell: leave it untouched.
        # Without this guard a re-run would look the vector's floats up as item
        # ids and overwrite the embedding (typically with an all-zero vector).
        if isinstance(items, np.ndarray):
            continue

        # Accumulate in float64 (np.zeros default) over the items the model knows.
        basket_vector = np.zeros(model.vector_size)
        valid_items = 0
        for item in items:
            if item in item2vector_dict:
                basket_vector += item2vector_dict[item]
                valid_items += 1

        # A basket with no in-vocabulary item keeps the all-zero vector.
        if valid_items > 0:
            basket_vector /= valid_items

        # Swap the item list for its embedding.
        basket[0] = basket_vector

處理用戶: 100%|██████████| 9234/9234 [00:01<00:00, 9002.32it/s]


In [19]:
# Show the first five users of training_data.
preview_users = list(training_data.items())[:5]
for key, value in preview_users:
    print(f"{key}: {value}")

31: [[array([-0.00049729,  0.06168059,  0.10455597,  0.20515477,  0.09664868,
        0.00283992, -0.02029442, -0.09829571, -0.41760387, -0.08572544,
       -0.04872751, -0.19857934, -0.19121668, -0.14027289,  0.20996786,
        0.24561161, -0.41659566,  0.06678989, -0.06043612, -0.11300815,
        0.19849182,  0.06182647,  0.20212506, -0.21120315, -0.06437187,
        0.2021687 ,  0.07955647,  0.0030656 , -0.41425942,  0.22721912,
       -0.02865861,  0.01258214, -0.07507645, -0.09472915, -0.13448263,
       -0.05111551,  0.13380793,  0.00594578, -0.02777612, -0.10857102,
       -0.1770484 ,  0.1198671 ,  0.03028231,  0.13965301,  0.12989231,
       -0.29405129,  0.00288554, -0.18278277, -0.09983561,  0.05383913,
       -0.00355893, -0.09367005, -0.08966408,  0.26519052,  0.23960696,
        0.16322382,  0.08506418, -0.27139351, -0.2904999 ,  0.03326201,
       -0.16879427, -0.05730915,  0.00283261, -0.01841557]), 20060416], [array([ 0.04828406,  0.28730401, -0.00539981,  0.03941102

In [20]:
# Keys of the last five users in training_data (dict iteration yields keys).
last_five_keys = list(training_data)[-5:]

# Print the stored data for each of those users.
for key in last_five_keys:
    print(f"User {key}: {training_data[key]}")

User 725702: [[array([ 0.00402877,  0.17952782,  0.27095037,  0.08918624, -0.09558082,
       -0.39744542,  0.06208987, -0.06749427, -0.32285371,  0.05688124,
       -0.12992124,  0.10494686, -0.142791  ,  0.22907392,  0.15597826,
        0.20544504, -0.33310507,  0.04087798,  0.0725698 ,  0.24369455,
        0.16178715,  0.23210019,  0.22156736, -0.09631198, -0.23847288,
       -0.26153891, -0.12240981,  0.22730457, -0.30986432,  0.35644614,
       -0.05387999,  0.01516913, -0.14550211, -0.06128286,  0.03920695,
       -0.05132222, -0.0019265 , -0.16886352,  0.00490754,  0.04429654,
        0.12875704,  0.33936689,  0.09257266,  0.06738484,  0.04854267,
       -0.23312604, -0.03301904, -0.27569971, -0.04307091,  0.04307332,
        0.15929629, -0.24025647, -0.14284433,  0.25747758,  0.24776973,
        0.26367985,  0.22909073,  0.12204288, -0.15180545,  0.11965657,
       -0.06851939, -0.19413263,  0.02336315, -0.25207454]), 20060411], [array([ 0.01776298,  0.21348465,  0.16453588,  0

In [21]:
# Destination of the gzip-compressed pickle holding the training basket embeddings.
output_file = f'data/{dataset}/basketembedding/training_basketembedding_{embed_dim}.pkl.gz'

# Make sure the target directory exists: gzip.open does not create parent
# directories and would raise FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# Pickle training_data straight into the gzip stream.
with gzip.open(output_file, 'wb') as file:
    pickle.dump(training_data, file)

print("Save !")

Save !
