In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_recommenders as tfrs
from tqdm import tqdm

from src.utils.model.retrieval_model import RetrievalModel

In [2]:
def make_click_df(behaviors_df):
    user2clicks = {}
    for index, data in tqdm(behaviors_df.iterrows()):
        user = data["User_ID"]
        impressions = data["Impressions"].split(" ")
        clicks = []
        for impression in impressions:
            if impression[-1] == "1":
                clicks.append(impression[:-2])
        if user not in user2clicks:
            user2clicks[user] = clicks
        else:
            user2clicks[user] = user2clicks[user] + clicks

    user_list = []
    click_list = []
    for user, v in tqdm(user2clicks.items()):
        for click in v:
            user_list.append(user)
            click_list.append(click)

    print("user_list", len(user_list))
    print("click_list", len(click_list))

    click_df = pd.DataFrame(list(zip(user_list, click_list)), columns=["user_id", "item_id"])

    return click_df

In [3]:
val_rate = 0.2
test_rate = 0.1
batch_size = 200
embedding_dimension = 512
learning_rate = 0.1
early_stopping_flg = True
tensorboard_flg = False
max_epoch_num = 20

In [4]:
train_behaviors_df = pd.read_table(
    "data/MIND/MINDsmall_train/behaviors.tsv", names=("Impression_ID", "User_ID", "Time", "History", "Impressions")
)
val_behaviors_df = pd.read_table(
    "data/MIND/MINDsmall_dev/behaviors.tsv", names=("Impression_ID", "User_ID", "Time", "History", "Impressions")
)
test_behaviors_df = pd.read_table(
    "data/MIND/MINDsmall_dev/behaviors.tsv", names=("Impression_ID", "User_ID", "Time", "History", "Impressions")
)
# news_df = pd.read_table(
#     "data/MIND/MINDsmall_train/news.tsv",
#     names=("News_ID", "Category", "SubCategory", "Title", "Abstract", "URL", "Title_Entities", "Abstract_Entities"),
# )

In [5]:
# train_behaviors_df

In [6]:
print("unique user number of train", len(train_behaviors_df["User_ID"].unique()))
print("unique user number of val", len(val_behaviors_df["User_ID"].unique()))
print("unique user number of test", len(test_behaviors_df["User_ID"].unique()))

# print(train_behaviors_df["User_ID"].value_counts())
# print(val_behaviors_df["User_ID"].value_counts())
# print(test_behaviors_df["User_ID"].value_counts())

unique user number of train 50000
unique user number of val 50000
unique user number of test 50000


In [7]:
train_click_df = make_click_df(train_behaviors_df)
train_ratings = tf.data.Dataset.from_tensor_slices({"user_id": train_click_df["user_id"], "item_id": train_click_df["item_id"]})
val_click_df = make_click_df(val_behaviors_df)
val_ratings = tf.data.Dataset.from_tensor_slices({"user_id": val_click_df["user_id"], "item_id": val_click_df["item_id"]})
test_click_df = make_click_df(test_behaviors_df)
test_ratings = tf.data.Dataset.from_tensor_slices({"user_id": test_click_df["user_id"], "item_id": test_click_df["item_id"]})

156965it [00:07, 21185.50it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 665272.56it/s]
2022-09-11 12:59:49.246701: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


user_list 236344
click_list 236344


73152it [00:03, 22144.60it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 1202964.46it/s]


user_list 111383
click_list 111383


73152it [00:03, 21419.42it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 1151307.14it/s]


user_list 111383
click_list 111383


In [8]:
train = train_ratings.batch(batch_size)
val = val_ratings.batch(batch_size)
test = test_ratings.batch(batch_size)

In [16]:
unique_user_ids = np.array(
    list(
        (set(train_click_df["user_id"].unique()) | set(val_click_df["user_id"].unique()) | set(test_click_df["user_id"].unique()))
    )
)
unique_item_ids = np.array(
    list(set(train_click_df["item_id"].unique()) | set(val_click_df["item_id"].unique()) | set(test_click_df["item_id"].unique()))
)
unique_item_dataset = tf.data.Dataset.from_tensor_slices(unique_item_ids)

In [17]:
print(len(unique_user_ids))
print(len(unique_item_ids))
print(len(set(train_click_df["item_id"].unique())))

94057
9100
7713


In [18]:
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = RetrievalModel(
        unique_user_ids=unique_user_ids,
        unique_item_ids=unique_item_ids,
        user_dict_key="user_id",
        item_dict_key="item_id",
        embedding_dimension=embedding_dimension,
        metrics_candidate_dataset=unique_item_dataset,
    )
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


  return bool(asarray(a1 == a2).all())


In [19]:
callbacks = []
if early_stopping_flg:
    callbacks.append(
        tf.keras.callbacks.EarlyStopping(
            monitor="total_loss",
            min_delta=0,
            patience=3,
            verbose=0,
            mode="auto",
            baseline=None,
            restore_best_weights=False,
        )
    )
if tensorboard_flg:
    tfb_log_path = log_path + datetime.now().strftime("%Y%m%d-%H%M%S")
    callbacks.append(
        tf.keras.callbacks.TensorBoard(
            log_dir=tfb_log_path,
            histogram_freq=1,
        )
    )

In [20]:
model.fit(x=train, validation_data=val, epochs=max_epoch_num, callbacks=callbacks)

Epoch 1/20


2022-09-11 13:30:28.123393: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
      type: DT_STRING
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 236344
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:0"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_STRING
        }
      }
      args {
     



2022-09-11 13:37:04.521531: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
      type: DT_STRING
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 111383
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:1"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_STRING
        }
      }
      args {
     

Epoch 2/20


2022-09-11 13:38:01.473510: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


 154/1182 [==>...........................] - ETA: 5:42 - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 38660.6893 - regularization_loss: 0.0000e+00 - total_loss: 38660.6893

KeyboardInterrupt: 

In [21]:
model.evaluate(test, return_dict=True)

2022-09-11 13:38:55.137533: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
      type: DT_STRING
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 111383
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024TensorSliceDataset:2"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_STRING
        }
      }
      args {
     



{'factorized_top_k/top_1_categorical_accuracy': 0.00014364848902914673,
 'factorized_top_k/top_5_categorical_accuracy': 0.0005566378822550178,
 'factorized_top_k/top_10_categorical_accuracy': 0.001274880371056497,
 'factorized_top_k/top_50_categorical_accuracy': 0.00631155539304018,
 'factorized_top_k/top_100_categorical_accuracy': 0.012380704283714294,
 'loss': 6055.6240234375,
 'regularization_loss': 0,
 'total_loss': 6055.6240234375}