In [1]:
from __future__ import absolute_import, division, print_function

import argparse
import random
import wandb
import random
import json

from utils import *
from data_utils import Dataset
# from knowledge_graph_m import KnowledgeGraph
from knowledge_graph_m import KnowledgeGraph
from easydict import EasyDict as edict


def generate_labels(data_dir, filename):
    """Read an interaction file and group item indices by user index.

    Every non-blank line must hold at least two whitespace-separated
    integer fields: the user index followed by the item index. Any further
    fields on the line are ignored. Blank lines are skipped.

    Args:
        data_dir: Directory that contains the interaction file.
        filename: Name of the interaction file inside ``data_dir``.

    Returns:
        dict mapping each user index (int) to the list of item indices
        (int) read for that user, in file order, duplicates kept.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if one of the first two fields is not an integer.
        IndexError: if a non-blank line has fewer than two fields.
    """
    interaction_file = f"{data_dir}/{filename}"
    user_items = {}  # {uid: [cid,...], ...}
    with open(interaction_file, "r", encoding="utf-8") as f:
        for line in f:
            # split() without an argument accepts tabs and repeated spaces,
            # and returns [] for a blank line (split(" ") would yield ['']
            # and make int() fail on a trailing empty line).
            arr = line.split()
            if not arr:
                continue
            user_idx = int(arr[0])
            item_idx = int(arr[1])
            user_items.setdefault(user_idx, []).append(item_idx)
    return user_items


def get_cold_users(user_items, cold_users_prop):
    """Randomly partition the user ids into train / validation / test sets.

    The keys of ``user_items`` are shuffled with the global ``random``
    module and then cut at two indices:
    ``int((1 - 2 * cold_users_prop) * n)`` and
    ``int((1 - cold_users_prop) * n)``, where ``n`` is the number of users.
    Users before the first cut go to "train", users between the cuts to
    "validation", and the remainder to "test".

    Returns:
        dict with keys "train", "validation" and "test" (in that order),
        each mapping to a set of user ids.
    """
    shuffled_users = list(user_items)
    random.shuffle(shuffled_users)
    total = len(shuffled_users)

    train_end = int((1 - 2 * cold_users_prop) * total)
    validation_end = int((1 - cold_users_prop) * total)

    return {
        "train": set(shuffled_users[:train_end]),
        "validation": set(shuffled_users[train_end:validation_end]),
        "test": set(shuffled_users[validation_end:]),
    }


def get_cold_items(user_items, cold_items_prop):
    """Randomly partition the item ids into train / validation / test sets.

    The distinct item ids found in the value lists of ``user_items`` are
    shuffled with the global ``random`` module and then cut at two indices:
    ``int((1 - 2 * cold_items_prop) * n)`` and
    ``int((1 - cold_items_prop) * n)``, where ``n`` is the number of
    distinct items. Items before the first cut go to "train", items between
    the cuts to "validation", and the remainder to "test".

    Returns:
        dict with keys "train", "validation" and "test" (in that order),
        each mapping to a set of item ids.
    """
    unique_items = set()
    for purchased in user_items.values():
        unique_items.update(purchased)

    shuffled_items = list(unique_items)
    random.shuffle(shuffled_items)
    total = len(shuffled_items)

    train_end = int((1 - 2 * cold_items_prop) * total)
    validation_end = int((1 - cold_items_prop) * total)

    return {
        "train": set(shuffled_items[:train_end]),
        "validation": set(shuffled_items[train_end:validation_end]),
        "test": set(shuffled_items[validation_end:]),
    }



In [2]:

# def split_train_test_data_by_user(
#     data_dir,
#     data_file,
#     validation_prop,
#     test_prop,
#     cold_users_prop,
#     cold_items_prop,
# ):
#     user_items = generate_labels(data_dir, data_file)

#     train_data = []
#     test_data = []
#     validation_data = []

#     # Make the cold start users sets
#     cold_start_users = get_cold_users(user_items, cold_users_prop)
#     # Make the cold start items sets
#     cold_start_items = get_cold_items(user_items, cold_items_prop)

#     # Split the data into train, validation and test sets
#     for user in user_items:
#         items = user_items[user]
#         nb_users_validation = max(1, int(len(items) * validation_prop))
#         nb_users_test = max(1, int(len(items) * test_prop))
#         # get the list of items for the user for each set
#         l_train_data = items[: -(nb_users_validation + nb_users_test)]
#         l_validation_data = items[
#             -(nb_users_validation + nb_users_test) : -(nb_users_test)
#         ]
#         l_test_data = items[-(nb_users_test):]

#         # remove cold start items from each set
#         l_train_data = [
#             item for item in l_train_data if item in cold_start_items["train"]
#         ]

#         l_validation_data = [
#             item for item in l_validation_data if item not in cold_start_items["test"]
#         ]

#         l_test_data = [
#             item for item in l_test_data if item not in cold_start_items["validation"]
#         ]

#         # remove cold start users from each set
#         if user in cold_start_users["train"]:
#             for c in l_train_data:
#                 train_data.append(f"{user} {c}\n")
#             for c in l_validation_data:
#                 validation_data.append(f"{user} {c}\n")
#             for c in l_test_data:
#                 test_data.append(f"{user} {c}\n")
#         elif user in cold_start_users["validation"]:
#             for c in l_validation_data:
#                 validation_data.append(f"{user} {c}\n")
#         else:
#             for c in l_test_data:
#                 test_data.append(f"{user} {c}\n")

#     random.shuffle(train_data)
#     random.shuffle(test_data)
#     random.shuffle(validation_data)

#     create_data_file(data_dir, train_data, "train.txt")
#     create_data_file(data_dir, validation_data, "validation.txt")
#     create_data_file(data_dir, test_data, "test.txt")

#     for key in cold_start_users:
#         cold_start_users[key] = list(cold_start_users[key])

#     for key in cold_start_items:
#         cold_start_items[key] = list(cold_start_items[key])

#     with open(data_dir + "/cold_start_users.json", "w") as f:
#         json.dump(cold_start_users, f, indent=4)

#     with open(data_dir + "/cold_start_items.json", "w") as f:
#         json.dump(cold_start_items, f, indent=4)

In [42]:
# Minimal stand-in for parsed command-line arguments: it only carries the path
# of the JSON config file.
# NOTE(review): "Agrument" looks like a misspelling of "Argument"; the name is
# left unchanged because other code may refer to it.
class Agrument:
    # Relative path of the JSON configuration file opened below.
    config = "../../config/beauty/graph_reasoning/UPGPR.json"
    
args = Agrument()

# Parse the JSON config and wrap it in an EasyDict; later cells read its keys
# as attributes (e.g. config.data_file).
with open(args.config, "r") as f:
    config = edict(json.load(f))
    
# Set (or override) the data directories used by this notebook.
config['original_data_dir'] = "../../data/beauty/Amazon_Beauty"
config['processed_data_dir'] = "../../data/beauty/Amazon_Beauty_01_01"

In [43]:
# Seed first (set_random_seed comes from utils; presumably it seeds the RNGs
# used by the shuffles in the split below - confirm against utils).
set_random_seed(config.data_split_seed)

# Unpack the split settings from the config into plain notebook variables.
data_dir, data_file = config.processed_data_dir, config.data_file
validation_prop, test_prop = config.validation_prop, config.test_prop
cold_users_prop, cold_items_prop = config.cold_users_prop, config.cold_items_prop
    
# split_train_test_data_by_user(
#     config.processed_data_dir,
#     data_file=config.data_file,
#     validation_prop=config.validation_prop,
#     test_prop=config.test_prop,
#     cold_users_prop=config.cold_users_prop,
#     cold_items_prop=config.cold_items_prop,
# )

In [44]:
# Load the interactions file (purchase.txt) as {user index: [item index, ...]}.
user_items = generate_labels(data_dir, data_file)

# len(user_items) ##22363

# Accumulators for the "user item" lines of each split; the split loop fills them.
train_data, test_data, validation_data = [], [], []

# Draw the cold-start partitions. Users are drawn before items, and both draws
# consume the shared `random` state, so the order of these two calls matters.
cold_start_users = get_cold_users(user_items, cold_users_prop)
cold_start_items = get_cold_items(user_items, cold_items_prop)


In [6]:
# Display the whole {user index: [item index, ...]} mapping returned by
# generate_labels. The output is very large; it is truncated in this export.
user_items

{319: [8099, 10825, 2167, 1636, 5980, 10569, 6222],
 9816: [8099, 7804, 8153, 7247, 587, 3572, 1205, 2086, 8365],
 10232: [8099, 5515, 7173, 12024, 5745, 8820],
 3222: [8099, 5578, 7029, 2777, 10787, 3135, 5625, 9028, 1648],
 1449: [8099, 1368, 5795, 6572, 2797, 9042, 6097],
 14999: [8099, 10266, 7288, 7127, 4631],
 2238: [4443, 8653, 6013, 4631, 4032, 8249, 10482, 8238, 1009, 10970],
 21974: [4443, 1103, 11644, 11193, 980, 144, 9472],
 11785: [4443, 8566, 6324, 10516, 7014, 2777, 9314, 8153, 4330],
 15019: [4443, 8322, 9723, 4488, 6176, 3532],
 19222: [4443, 1117, 5728, 2111, 8876, 5230, 5392, 12099, 11119],
 4771: [4443, 5872, 1800, 9555, 9495, 9598, 1751],
 22257: [4443, 715, 422, 507, 5200],
 22073: [4443, 3479, 5517, 6221, 5848, 8587],
 17614: [4443, 2268, 5693, 5607, 11630, 8007, 6320, 6369, 1901],
 13144: [8978, 2167, 4255, 11524, 568, 9546, 929, 4122],
 15301: [8978,
  8299,
  7949,
  1557,
  7753,
  4593,
  9232,
  11045,
  6320,
  1071,
  9793,
  5015,
  288],
 8552: [8978, 4

In [33]:
# Item indices recorded for user index 12.
user_items[12]

[11932, 5205, 11229, 9610, 8718]

In [34]:
# Number of users, i.e. number of entries in the user -> items mapping.
len(user_items)

22363

In [35]:
# user_items.values()

# Flatten every user's item list into one list (duplicates are kept here).
all_numbers = []
for numbers_list in user_items.values():
    all_numbers += numbers_list

# Deduplicate through a set, then count the distinct item ids.
unique_numbers = set(all_numbers)
num_unique_numbers_item = len(unique_numbers)

# Last expression of the cell: shown as the cell output.
num_unique_numbers_item

12101

In [36]:
# Both partition dicts are expected to expose the same three split names.
cold_start_users.keys(), cold_start_items.keys()

(dict_keys(['train', 'validation', 'test']),
 dict_keys(['train', 'validation', 'test']))

In [37]:
# cold_start_users #{'train' :{0,1,2,3,4,5,6,7,8,9,10,11,13,15,16,... }, 'test' : {}, 'validation' : {}}
# cold_start_items #{'train' :{0,1,2,3,4,5,6,7,8,10,11,12, 13,15,16,... }, 'test' : {}, 'validation' : {}}

# Sanity check: the three partitions together should account for every user
# (22363) and every item (12101).
# NOTE(review): a cell only displays its last expression, so the user total on
# the first line below is computed but never shown; only the item total is.
len(cold_start_users['train']) + len(cold_start_users['test']) + len(cold_start_users['validation']) #22363
len(cold_start_items['train']) + len(cold_start_items['test']) + len(cold_start_items['validation']) #12101

12101

In [12]:
# NOTE(review): at notebook level `items` is only assigned inside the split
# loop (the cell starting with "# Split the data into train, validation and
# test sets"), so this shows the item list of the last user that loop handled.
# Run before that loop on a fresh kernel, this cell would raise NameError,
# unless `from utils import *` happens to provide an `items` name - confirm.
items

[2136, 11383, 3161, 3331, 8179, 7836, 1895]

In [13]:
# Per-user train / validation / test item lists left over from the last
# iteration of the split loop, after the cold-start item filtering.
# NOTE(review): these names are only assigned inside that loop, which appears
# further down in this notebook; the cell depends on it having been run first.
l_train_data, l_validation_data, l_test_data

([2136, 11383, 3161, 3331, 8179], [], [1895])

In [15]:
# l_train_data, l_validation_data, l_test_data

In [49]:
# Split the data into train, validation and test sets
from tqdm.auto import tqdm

# Start from empty accumulators so that re-running this cell does not append a
# second copy of every line to lists that were created in an earlier cell.
train_data = []
test_data = []
validation_data = []

# The end of this cell replaces the partition sets by lists (for JSON). On a
# re-run the partitions would therefore be lists, and each `in` test below
# would be a linear scan per interaction. Build set views once so that
# membership tests stay constant-time whatever the current container type is.
cold_user_sets = {key: set(members) for key, members in cold_start_users.items()}
cold_item_sets = {key: set(members) for key, members in cold_start_items.items()}

for user in tqdm(user_items):
    items = user_items[user] #list of items
    # At least one interaction is reserved for validation and one for test.
    nb_users_validation = max(1, int(len(items) * validation_prop))
    nb_users_test       = max(1, int(len(items) * test_prop))
    # get the list of items for the user for each set:
    # the tail of the list is test, the slice before it validation, the rest train
    l_train_data        = items[: -(nb_users_validation + nb_users_test)]
    l_validation_data   = items[-(nb_users_validation + nb_users_test) : -(nb_users_test)]
    l_test_data         = items[-(nb_users_test):]

    # remove cold start items from each set:
    # train keeps only "train" items; validation drops "test" items;
    # test drops "validation" items
    l_train_data = [
        item for item in l_train_data if item in cold_item_sets["train"]
    ]

    l_validation_data = [
        item for item in l_validation_data if item not in cold_item_sets["test"]
    ]

    l_test_data = [
        item for item in l_test_data if item not in cold_item_sets["validation"]
    ]

    # remove cold start users from each set:
    # "train" users contribute to all three files, "validation" users only to
    # the validation file, every other user only to the test file
    if user in cold_user_sets["train"]:
        for c in l_train_data:
            train_data.append(f"{user} {c}\n")
        for c in l_validation_data:
            validation_data.append(f"{user} {c}\n")
        for c in l_test_data:
            test_data.append(f"{user} {c}\n")
    elif user in cold_user_sets["validation"]:
        for c in l_validation_data:
            validation_data.append(f"{user} {c}\n")
    else:
        for c in l_test_data:
            test_data.append(f"{user} {c}\n")

random.shuffle(train_data)
random.shuffle(test_data)
random.shuffle(validation_data)

# create_data_file(data_dir, train_data, "train.txt")
# create_data_file(data_dir, validation_data, "validation.txt")
# create_data_file(data_dir, test_data, "test.txt")

# Sets are not JSON serializable: store each partition as a list.
for key in cold_start_users:
    cold_start_users[key] = list(cold_start_users[key])

for key in cold_start_items:
    cold_start_items[key] = list(cold_start_items[key])

# with open(data_dir + "/cold_start_users.json", "w") as f:
#     json.dump(cold_start_users, f, indent=4)

# with open(data_dir + "/cold_start_items.json", "w") as f:
#     json.dump(cold_start_items, f, indent=4)

  0%|          | 0/22363 [00:00<?, ?it/s]

In [50]:
# Total number of users across the three partitions.
sum(len(cold_start_users[split]) for split in ("train", "test", "validation"))

22363

In [48]:
# Snapshot the current user partitions before the split cell is run again.
# A plain `old_cold_start_users = cold_start_users` only creates a second name
# for the same dict, and the split cell rewrites that dict's values in place,
# so the "old" and current partitions could never differ. Copy each partition
# into a new set so the snapshot is independent of later changes.
old_cold_start_users = {key: set(members) for key, members in cold_start_users.items()}
len(cold_start_users['train'])

17890

In [51]:
# NOTE(review): the first line is a bare expression whose value is discarded;
# only the length on the last line is displayed by the cell.
cold_start_users['train']
len(cold_start_users['train'])

17890

In [58]:
# "train" user ids referenced by old_cold_start_users, shown as a set so they
# can be compared by eye with the current cold_start_users in the next cell.
set(old_cold_start_users['train'])

{0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 27,
 28,
 30,
 31,
 32,
 34,
 35,
 36,
 37,
 38,
 39,
 41,
 42,
 43,
 45,
 46,
 48,
 49,
 50,
 51,
 52,
 53,
 56,
 57,
 58,
 60,
 62,
 64,
 66,
 67,
 68,
 72,
 73,
 75,
 76,
 77,
 78,
 79,
 81,
 83,
 84,
 85,
 86,
 89,
 91,
 93,
 94,
 96,
 98,
 99,
 101,
 102,
 104,
 105,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 116,
 117,
 118,
 119,
 120,
 123,
 124,
 125,
 126,
 127,
 130,
 131,
 132,
 133,
 134,
 135,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 153,
 154,
 156,
 157,
 159,
 160,
 161,
 162,
 163,
 164,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 174,
 175,
 177,
 178,
 179,
 181,
 182,
 183,
 184,
 186,
 187,
 188,
 190,
 191,
 192,
 193,
 194,
 195,
 196,
 198,
 200,
 201,
 203,
 204,
 205,
 206,
 207,
 208,
 209,
 210,
 211,
 213,
 215,
 216,
 218,
 219,
 220,
 222,
 227,
 231,
 232,
 234,
 237,
 238,
 239,
 240,
 241,
 242,

In [59]:
# Current "train" user ids, shown as a set (the full output is very large).
set(cold_start_users['train'])

{0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 27,
 28,
 30,
 31,
 32,
 34,
 35,
 36,
 37,
 38,
 39,
 41,
 42,
 43,
 45,
 46,
 48,
 49,
 50,
 51,
 52,
 53,
 56,
 57,
 58,
 60,
 62,
 64,
 66,
 67,
 68,
 72,
 73,
 75,
 76,
 77,
 78,
 79,
 81,
 83,
 84,
 85,
 86,
 89,
 91,
 93,
 94,
 96,
 98,
 99,
 101,
 102,
 104,
 105,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 116,
 117,
 118,
 119,
 120,
 123,
 124,
 125,
 126,
 127,
 130,
 131,
 132,
 133,
 134,
 135,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 153,
 154,
 156,
 157,
 159,
 160,
 161,
 162,
 163,
 164,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 174,
 175,
 177,
 178,
 179,
 181,
 182,
 183,
 184,
 186,
 187,
 188,
 190,
 191,
 192,
 193,
 194,
 195,
 196,
 198,
 200,
 201,
 203,
 204,
 205,
 206,
 207,
 208,
 209,
 210,
 211,
 213,
 215,
 216,
 218,
 219,
 220,
 222,
 227,
 231,
 232,
 234,
 237,
 238,
 239,
 240,
 241,
 242,

In [17]:
# split_train_test_data_by_user

In [16]:
# Create and save a Dataset and a KnowledgeGraph for each split.
# ========== BEGIN ========== #

for set_name in ["train", "test", "validation"]:

    # The same processed data directory is used for all three splits;
    # set_name is what differs between the iterations.
    print(f"Loading dataset from folder: {config.processed_data_dir}")
    dataset = Dataset(config.processed_data_dir, config.KG_ARGS, set_name)
    save_dataset(config.processed_data_dir, dataset, config.use_wandb)

    kg = KnowledgeGraph(
        dataset,
        config.KG_ARGS,
        set_name=set_name,
        use_user_relations=config.use_user_relations,
        use_entity_relations=config.use_entity_relations,
    )
    # Degrees are computed before the graph is saved.
    kg.compute_degrees()
    save_kg(config.processed_data_dir, kg, config.use_wandb)
# =========== END =========== #

# Generate train/test/validation labels.
# ========== BEGIN ========== #
# NOTE(review): these calls read train.txt, test.txt and validation.txt from
# the processed data directory, but the create_data_file calls that would
# write them are commented out in this notebook; presumably the files already
# exist on disk - confirm.
print("Generate train/test labels.")
train_labels = generate_labels(config.processed_data_dir, "train.txt")
test_labels = generate_labels(config.processed_data_dir, "test.txt")
validation_labels = generate_labels(config.processed_data_dir, "validation.txt")

# save_labels(
#     config.processed_data_dir,
#     train_labels,
#     mode="train",
#     use_wandb=config.use_wandb,
# )
# save_labels(
#     config.processed_data_dir, test_labels, mode="test", use_wandb=config.use_wandb
# )
# save_labels(
#     config.processed_data_dir,
#     validation_labels,
#     mode="validation",
#     use_wandb=config.use_wandb,
# )

# # =========== END =========== #
# if config.use_wandb:
#     wandb.finish()