In [None]:
import os
import random

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from utils.data import extract_embedding, get_interactions_dataframe, mark_evaluation_rows
from utils.hashing import pre_hash, HashesContainer


# Triplet sampling

In [None]:
# Training mode switch:
#   MODE_PROFILE = True  -> CuratorNet-like training
#   MODE_PROFILE = False -> VBPR-like training
MODE_PROFILE = False

# Short label for the selected mode (it becomes part of the output file names)
if MODE_PROFILE:
    MODE_PROFILE_VERBOSE = "profile"
else:
    MODE_PROFILE_VERBOSE = "user"


In [None]:
# Feature extractor
# Name of the feature extractor; it is interpolated into EMBEDDING_PATH and
# therefore selects which embedding file gets loaded.
FEATURE_EXTRACTOR = "resnet50"


In [None]:
# Paths (general): every input and output lives under the same data directory
_DATA_DIR = "data"
_OUTPUT_STEM = f"naive-{MODE_PROFILE_VERBOSE}"
EMBEDDING_PATH = os.path.join(_DATA_DIR, f"embedding-{FEATURE_EXTRACTOR}.npy")
INTERACTIONS_PATH = os.path.join(_DATA_DIR, "wikimedia.csv")
OUTPUT_TRAIN_PATH = os.path.join(_DATA_DIR, f"{_OUTPUT_STEM}-train.csv")
OUTPUT_VALID_PATH = os.path.join(_DATA_DIR, f"{_OUTPUT_STEM}-validation.csv")
OUTPUT_EVAL_PATH = os.path.join(_DATA_DIR, f"{_OUTPUT_STEM}-evaluation.csv")

# General constants
# Seed for the RNGs; None leaves them unseeded
RNG_SEED = 0

# Sampling constants
# When True, interactions sharing (timestamp, user_id) are grouped into baskets
GROUP_USER_INTERACTIONS_BY_TIMESTAMP = True
# Upper bound on the number of items kept in a profile (falsy disables the cap)
MAX_PROFILE_SIZE = 10
# Total sample budgets; each is divided evenly among the users when sampling
TOTAL_SAMPLES_TRAIN = 5_000_000
TOTAL_SAMPLES_VALID = 500_000


In [None]:
# Seed both the stdlib and the NumPy global RNGs (skipped when RNG_SEED is None)
if RNG_SEED is not None:
    print(f"\nUsing random seed... ({RNG_SEED})")
    for _seed_rng in (random.seed, np.random.seed):
        _seed_rng(RNG_SEED)


In [None]:
# Load embedding from file
print(f"\nLoading embedding from file... ({EMBEDDING_PATH})")
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# run arbitrary code; only load embedding files that come from a trusted source.
embedding = np.load(EMBEDDING_PATH, allow_pickle=True)

# Extract features and "id2index" mapping
# (the third value returned by extract_embedding is discarded here)
print("\nExtracting data into variables...")
features, item_id2index, _ = extract_embedding(embedding, verbose=True)
print(f">> Features shape: {features.shape}")
del embedding  # Release some memory


In [None]:
# Load interactions CSVs
print(f"\nLoading interactions from files...")
interactions_df = get_interactions_dataframe(
    INTERACTIONS_PATH,
    display_stats=True,
)

# Apply 'item_id2index', to work with indexes only
print("\nApply 'item_id2index' mapping for items...")
interactions_df["item_id"] = interactions_df["item_id"].map(str)
_known_items = interactions_df["item_id"].isin(item_id2index)
n_missing_ids = interactions_df.loc[~_known_items, "item_id"].count()
# Copy the filtered rows so the assignment below targets an independent frame
# instead of a slice of the original one (avoids SettingWithCopyWarning)
interactions_df = interactions_df[_known_items].copy()
interactions_df["item_id"] = interactions_df["item_id"].map(item_id2index)
# 'n_missing_ids' counts the rows that were dropped, i.e. ids NOT in the mapping
print(f">> Mapping applied, ({n_missing_ids} values not in 'item_id2index')")

# Store mapping from user_id to index (0-index, no skipping)
print("\nCreate 'user_id2index' mapping for users...")
unique_user_ids = interactions_df["user_id"].unique()
# NOTE(review): np.argsort returns a permutation of 0..n-1, so the mapping is
# dense and one-to-one, but the indexes need not follow the sorted order of the
# original ids. Kept as is so previously generated user indexes stay valid.
new_user_ids = np.argsort(unique_user_ids)
user_id2index = dict(zip(unique_user_ids, new_user_ids))

# Apply 'user_id2index', to work with indexes only
print("\nApply 'user_id2index' mapping for users...")
_known_users = interactions_df["user_id"].isin(user_id2index)
n_missing_ids = interactions_df.loc[~_known_users, "user_id"].count()
interactions_df = interactions_df[_known_users].copy()
interactions_df["user_id"] = interactions_df["user_id"].map(user_id2index)
print(f">> Mapping applied, ({n_missing_ids} values not in 'user_id2index')")

# Mark interactions used for evaluation procedure if needed
if "evaluation" not in interactions_df:
    print("\nApply evaluation split...")
    interactions_df = mark_evaluation_rows(interactions_df)
    # Check if new column exists and has boolean dtype
    assert interactions_df["evaluation"].dtype.name == "bool"
    print(f">> Interactions: {interactions_df.shape}")

# Split interactions data according to evaluation column
# ('evaluation_df' gets new columns later on, so it is detached from the
# parent frame with an explicit copy)
evaluation_df = interactions_df[interactions_df["evaluation"]].copy()
interactions_df = interactions_df[~interactions_df["evaluation"]]
assert not interactions_df.empty
assert not evaluation_df.empty
print(f">> Evaluation: {evaluation_df.shape} | Interactions: {interactions_df.shape}")

# Form interactions baskets, grouping by timestamp and user_id
if GROUP_USER_INTERACTIONS_BY_TIMESTAMP:
    print("\nForm interactions groups (baskets), by timestamp and user_id...")
    # After this step every row holds the list of items of one (timestamp, user_id) pair
    interactions_df = interactions_df.groupby(["timestamp", "user_id"])["item_id"].apply(list)
    interactions_df = interactions_df.reset_index()
    interactions_df = interactions_df.sort_values("timestamp")
    interactions_df = interactions_df.reset_index(drop=True)
else:
    print("\nInteractions groups (baskets), by timestamp and user_id, skipped")


In [None]:
# Complete the evaluation dataframe with the profile of every user, i.e. the
# user's items from 'interactions_df' ordered by timestamp
_idf = interactions_df.sort_values("timestamp").groupby(["user_id"])["item_id"].apply(list).reset_index().copy()
if GROUP_USER_INTERACTIONS_BY_TIMESTAMP:
    # Every row of 'interactions_df' is a basket (list of items), so each user
    # entry is a list of baskets that has to be flattened
    _profile_by_user = {
        user_id: [item_id for basket in baskets for item_id in basket]
        for user_id, baskets in zip(_idf["user_id"], _idf["item_id"])
    }
else:
    # Without baskets each user entry already is a flat list of items.
    # (Previously no profile was built in this case, which made the cells
    # reading the 'profile' column fail with a KeyError.)
    _profile_by_user = dict(zip(_idf["user_id"], _idf["item_id"]))

# Work on an independent frame: 'evaluation_df' may still be a slice of the
# interactions frame, and adding a column to a slice raises SettingWithCopyWarning
evaluation_df = evaluation_df.copy()
# One dictionary lookup per row instead of one full-frame scan per row. Users
# without interactions get an empty profile, and every row receives its own list.
evaluation_df["profile"] = evaluation_df["user_id"].apply(
    lambda user_id: list(_profile_by_user.get(user_id, ())),
)
if MAX_PROFILE_SIZE:
    # Reduce size of profiles if needed (keep the most recent items)
    evaluation_df["profile"] = evaluation_df["profile"].apply(lambda profile: profile[-MAX_PROFILE_SIZE:])

# Rename predict column and drop evaluation column
evaluation_df = evaluation_df.rename(columns={"item_id": "predict"}).drop(columns=["evaluation"])


In [None]:
print("\nCreating helpers instances...")
# Creating hashes container for duplicates detection
# (the same instance is handed to both sampling runs; presumably this keeps a
# triple already enrolled for training out of validation — confirm against
# HashesContainer.enroll)
hashes_container = HashesContainer()

# Sampling constants
print("\nCalculating important values...")
# Number of distinct users left in the (non-evaluation) interactions
N_USERS = interactions_df["user_id"].nunique()
# Number of entries in the feature matrix; used as the range for negative sampling
N_ITEMS = len(features)
print(f">> N_USERS = {N_USERS} | N_ITEMS = {N_ITEMS}")


In [None]:
def random_triplet_sampling(samples_per_user, hashes_container, desc=None, limit_iteration=10000):
    """Randomly sample unique training tuples for every user.

    For each user in the module-level ``interactions_df`` a positive item is
    drawn from the user's interactions and a negative item is drawn uniformly
    from the items the user never interacted with. A sample is kept only when
    ``hashes_container.enroll`` accepts its hash.

    Reads the module-level names ``interactions_df``, ``N_ITEMS``,
    ``MAX_PROFILE_SIZE`` and ``MODE_PROFILE``.

    Args:
        samples_per_user: Number of samples to collect per user (int or float;
            sampling continues while the remaining count is positive).
        hashes_container: Object whose ``enroll(hash)`` returns a falsy value
            for hashes that must be rejected.
        desc: Optional label for the progress bar.
        limit_iteration: Maximum number of rejected samples tolerated per user
            before moving on to the next user.

    Returns:
        A list of ``(profile, pi, ni, ui)`` tuples: profile items, positive
        item, negative item and user.
    """
    samples = []
    for ui, group in tqdm(interactions_df.groupby("user_id"), desc=desc):
        # Get profile artworks
        full_profile = np.hstack(group["item_id"].values).tolist()
        full_profile_set = set(full_profile)
        if len(full_profile_set) >= N_ITEMS:
            # The user interacted with every item, so no negative item exists
            # and the negative sampling loop below would never terminate
            continue
        n = samples_per_user
        # Per-user budget of rejected samples
        aux_limit = limit_iteration
        while n > 0:
            if aux_limit == 0:
                break
            # Sample positive and negative items
            pi_index = random.randrange(len(full_profile))
            pi = full_profile[pi_index]
            # Get profile
            if MAX_PROFILE_SIZE:
                profile = random.sample(full_profile, min(len(full_profile), MAX_PROFILE_SIZE))
                if pi not in profile:
                    # Make room for the positive item so it is part of the profile
                    profile = profile[0: -1]
                    profile.append(pi)
            else:
                profile = list(full_profile)
            # Rejection sampling of an item the user never interacted with
            while True:
                ni = random.randint(0, N_ITEMS - 1)
                if ni not in full_profile_set:
                    break
            # If conditions are met, hash and enroll triple
            if MODE_PROFILE:
                triple = (profile, pi, ni)
            else:
                triple = (ui, pi, ni)
            if not hashes_container.enroll(pre_hash(triple, contains_iter=MODE_PROFILE)):
                # Spend the per-user budget. The previous code decremented
                # 'limit_iteration' instead, so the check above never fired
                # and the budget leaked from one user to the next.
                aux_limit -= 1
                continue
            # If not seen, store sample
            samples.append((profile, pi, ni, ui))
            n -= 1
    return samples


In [None]:
# Per-user quotas: each total budget is spread evenly over the users
_quota_train = np.ceil(TOTAL_SAMPLES_TRAIN / N_USERS)
_quota_valid = np.ceil(TOTAL_SAMPLES_VALID / N_USERS)

# Both runs receive the same hashes container
samples_training = random_triplet_sampling(
    samples_per_user=_quota_train,
    hashes_container=hashes_container,
    desc="Random sampling (training)",
)
samples_testing = random_triplet_sampling(
    samples_per_user=_quota_valid,
    hashes_container=hashes_container,
    desc="Random sampling (testing)",
)

# Collected amounts, with the requested totals in parentheses
print(f"Training samples: {len(samples_training)} ({TOTAL_SAMPLES_TRAIN})")
print(f"Testing samples: {len(samples_testing)} ({TOTAL_SAMPLES_VALID})")

# Report the collisions counted by the container
print(f">> Total hash collisions: {hashes_container.collisions}")


In [None]:
# Expose the sampled tuples under the names used by the following cells
# (a single sampling strategy is used, so no actual merging is required)
print("\nMerging strategies samples into a single list")
TRAINING_DATA, VALIDATION_DATA = samples_training, samples_testing
print(f">> Training samples: {len(TRAINING_DATA)}")
print(f">> Validation samples: {len(VALIDATION_DATA)}")


In [None]:
# Search for duplicated hashes
print(f"\nNaive triples validation and looking for duplicates...")
validation_hash_check = HashesContainer()
all_samples = [
    triple
    for subset in (TRAINING_DATA, VALIDATION_DATA)
    for triple in subset
]
user_ids = interactions_df["user_id"].unique()
# Set of known users: constant-time membership checks instead of scanning the
# 'user_ids' array once per sample
_known_user_ids = set(user_ids)
# Group the items by user once, instead of filtering the whole dataframe for
# every user seen in the samples
_items_by_user = interactions_df.groupby("user_id")["item_id"]
# Cache: user -> set of items the user interacted with
user_data = dict()
for triple in tqdm(all_samples, desc="Naive validation"):
    profile, pi, ni, ui = triple
    # Every sample must be unique across training and validation
    if MODE_PROFILE:
        assert validation_hash_check.enroll(pre_hash((profile, pi, ni)))
    else:
        assert validation_hash_check.enroll(pre_hash((ui, pi, ni), contains_iter=False))
    # Items must be valid indexes and the positive must differ from the negative
    assert 0 <= pi < N_ITEMS
    assert 0 <= ni < N_ITEMS
    assert pi != ni
    # Samples with ui == -1 skip the user-related checks
    if ui == -1:
        continue
    assert ui in _known_user_ids
    if ui not in user_data:
        user_data[ui] = set(np.hstack(_items_by_user.get_group(ui).values))
    user_artworks = user_data[ui]
    # The profile may only contain items the user interacted with
    assert all(i in user_artworks for i in profile)
print(">> No duped hashes found")


In [None]:
def _join_profile(profile):
    """Serialize a profile as a single string of space-separated items."""
    return " ".join(str(item) for item in profile)


print("\nCreating output files (train and valid)...")
_SAMPLE_COLUMNS = ["profile", "pi", "ni", "ui"]

# Training samples -> CSV
df_train = pd.DataFrame(TRAINING_DATA, columns=_SAMPLE_COLUMNS)
df_train["profile"] = df_train["profile"].map(_join_profile)
print(f">> Saving training samples ({OUTPUT_TRAIN_PATH})")
df_train.to_csv(OUTPUT_TRAIN_PATH, index=False)

# Validation samples -> CSV
df_validation = pd.DataFrame(VALIDATION_DATA, columns=_SAMPLE_COLUMNS)
df_validation["profile"] = df_validation["profile"].map(_join_profile)
print(f">> Saving validation samples ({OUTPUT_VALID_PATH})")
df_validation.to_csv(OUTPUT_VALID_PATH, index=False)

# Evaluation data -> CSV
# Only the 'profile' column is serialized; 'predict' is written unchanged
df_evaluation = evaluation_df.copy()
df_evaluation["profile"] = df_evaluation["profile"].map(_join_profile)
print(f">> Saving evaluation data ({OUTPUT_EVAL_PATH})")
df_evaluation.to_csv(OUTPUT_EVAL_PATH, index=False)
