In [25]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_recommenders as tfrs
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from src.utils.model.retrieval_model import RetrievalModel

In [3]:
val_rate = 0.2
test_rate = 0.1
batch_size = 100
embedding_dimension = 100
learning_rate = 0.1
early_stopping_flg = True
tensorboard_flg = False
max_epoch_num = 20

In [6]:
behaviors_df = pd.read_csv(
    "data/RentalProperties/user_activity.csv", names=("item_id", "user_id", "event_type", "create_timestamp")
)

In [7]:
behaviors_df

Unnamed: 0,item_id,user_id,event_type,create_timestamp
0,item_id,user_id,event_type,create_timestamp
1,00062bc5-2535-4b1e-bbcb-228526c990b8,182aa519-83a8-848f-84a1-8697046d84c2,seen,2020-02-03 15:47:25.273977
2,00062bc5-2535-4b1e-bbcb-228526c990b8,189a081a-ae0f-499d-9092-01758d93fa7f,seen,2020-02-04 20:19:31.040304
3,00062bc5-2535-4b1e-bbcb-228526c990b8,189a081a-ae0f-499d-9092-01758d93fa7f,sent_catalog_link,2020-02-04 20:19:00.110416
4,00062bc5-2535-4b1e-bbcb-228526c990b8,189a081a-ae0f-499d-9092-01758d93fa7f,visit_request-canceled,2020-02-04 20:54:31.595305
...,...,...,...,...
323889,fffbf497-b7e8-434b-b8e9-d74dbbb492bb,f8df45aa-77ae-45aa-8235-c4d4806d2ad3,seen_in_list,2020-02-13 11:03:34.598830
323890,fffbf497-b7e8-434b-b8e9-d74dbbb492bb,fa5f0121-8a84-87aa-871f-81d3e2e16a2a,seen_in_list,2020-02-13 13:18:08.231600
323891,fffbf497-b7e8-434b-b8e9-d74dbbb492bb,faca30a5-8faa-8661-810f-8cf36f8e1d54,seen_in_list,2020-02-12 18:59:36.575983
323892,fffbf497-b7e8-434b-b8e9-d74dbbb492bb,fc3a6175-af38-47aa-a9fa-15fc7fd5fae1,seen_in_list,2020-02-13 13:58:44.110670


In [30]:
# behaviors_df.query('event_type == "seen"')["user_id"].value_counts()
seen_df = behaviors_df.query('event_type == "seen"')
count_df = pd.DataFrame(seen_df["user_id"].value_counts()).reset_index().rename(columns={"index": "user_id", "user_id": "count"})

unique_user_ids = list(count_df.query("count >= 10")["user_id"])

In [55]:
len(unique_user_ids)

2435

In [43]:
count_df

Unnamed: 0,user_id,count
0,69c257cd-8033-82aa-8950-8804a03c2ed1,1141
1,e4261ef5-8f23-848d-87c4-875798d7e077,498
2,62495dc1-8639-8cdc-842a-89f5cf08f5d5,439
3,cef54613-859c-8dc0-8e69-83f032adfc65,229
4,19e066a4-275d-4a46-a5e6-ac66993954ea,203
...,...,...
16682,a392f478-8d44-817a-82aa-8c6efead7039,1
16683,9a6a1869-8794-8606-8ea6-8c39c3359a19,1
16684,d82380d2-84c9-8cec-8600-81c075fa28e8,1
16685,6489102d-0c93-45a5-901c-cd6c9041917f,1


In [32]:
seen_df

Unnamed: 0,item_id,user_id,event_type,create_timestamp
1,00062bc5-2535-4b1e-bbcb-228526c990b8,182aa519-83a8-848f-84a1-8697046d84c2,seen,2020-02-03 15:47:25.273977
2,00062bc5-2535-4b1e-bbcb-228526c990b8,189a081a-ae0f-499d-9092-01758d93fa7f,seen,2020-02-04 20:19:31.040304
10,00062bc5-2535-4b1e-bbcb-228526c990b8,69c257cd-8033-82aa-8950-8804a03c2ed1,seen,2020-02-05 10:18:54.229749
21,0019bb07-bb6d-44dd-b6b7-d1b5405338d2,054f4d81-ee93-407c-af31-ff84c3a103a8,seen,2020-02-10 22:36:14.563238
23,0019bb07-bb6d-44dd-b6b7-d1b5405338d2,08a44ae5-8419-88dd-8c61-8a9cd391677c,seen,2020-02-12 16:03:27.785961
...,...,...,...,...
323845,fffbf497-b7e8-434b-b8e9-d74dbbb492bb,a75fe72d-a92d-4c01-aa4e-5adaf925d7c3,seen,2020-02-12 16:51:17.061706
323846,fffbf497-b7e8-434b-b8e9-d74dbbb492bb,aaffad4c-61fa-4d13-ad92-0767cd3a3347,seen,2020-02-11 20:38:47.089192
323867,fffbf497-b7e8-434b-b8e9-d74dbbb492bb,cef54613-859c-8dc0-8e69-83f032adfc65,seen,2020-02-12 12:02:43.374588
323878,fffbf497-b7e8-434b-b8e9-d74dbbb492bb,e4261ef5-8f23-848d-87c4-875798d7e077,seen,2020-02-12 12:22:58.122055


In [34]:
seen_df = seen_df[seen_df["user_id"].isin(unique_user_ids)]

In [44]:
train_val_df,test_df = train_test_split(seen_df,test_size=0.1,stratify=seen_df['user_id'])

In [47]:
train_df,val_df = train_test_split(train_val_df,test_size=0.2,stratify=train_val_df['user_id'])

In [51]:
print(len(train_df['user_id'].unique()))
print(len(val_df['user_id'].unique()))
print(len(test_df['user_id'].unique()))

2435
2435
2435


In [52]:
train_df

Unnamed: 0,item_id,user_id,event_type,create_timestamp
19022,0f8f6029-d7e7-4fd6-b269-242cd83d3694,94a53c1d-a06e-49d0-8daa-a4ae13839a68,seen,2020-02-03 23:20:06.953841
39377,1f7d7e95-1404-4549-bb29-f5b63bf11983,765fdfdf-8754-8818-81da-89a6c0742a20,seen,2020-02-11 12:16:26.289445
279416,dfe2cd84-4bdd-4e1b-b68c-ec99cfc25291,48f1a7f0-c584-4730-ae0f-9d197225846e,seen,2020-02-12 23:00:39.691387
223636,b8c5c55f-b0ef-4771-bdb6-b69e237966d5,f6d63567-a745-4393-8737-73797c1cd658,seen,2020-02-08 11:20:25.238062
133620,6f4667c1-8fdc-452e-8507-2bd5c0bdb37b,4e5aa771-66a9-4781-9648-0ca904731d77,seen,2020-02-09 16:30:26.213081
...,...,...,...,...
10830,092bf58c-113d-4c3f-b6f6-d0b82bb63d80,f63adaa8-0aa5-4209-9d6a-df7974a16a26,seen,2020-02-11 20:10:55.791859
246649,c685b023-4cd0-4f3f-9b0f-2b5f858e9d00,46826d84-81ce-878c-8352-83e954f55f7a,seen,2020-02-09 10:40:43.396456
9399,082e2c8f-54f0-4b3d-b882-d0c00b6721bd,59ae1e57-d15a-4952-a78e-aaf652f3a991,seen,2020-02-07 13:48:19.818211
259635,d01cb895-9688-4935-92e0-fbc589d2b959,19ecce1c-604c-4cf8-9519-a4128088195a,seen,2020-02-11 09:47:11.017766


In [53]:
val_df

Unnamed: 0,item_id,user_id,event_type,create_timestamp
128327,6b516020-f480-44d3-b70b-108d749eb257,7dea581e-8a34-843f-8ecc-824a960f86d6,seen,2020-02-10 10:04:05.094784
256336,ced54018-f31b-4945-b62c-50dfc47b439d,62495dc1-8639-8cdc-842a-89f5cf08f5d5,seen,2020-02-11 11:49:07.772855
206319,b3035c18-e9c6-4e19-bbbc-edc55e31b46f,a673224f-82c3-836f-8890-845765790ded,seen,2020-02-08 01:37:00.007594
171486,8b24983b-ecfb-4691-be55-db237c73945d,7a88a9e9-2ac2-41a0-9f27-2edd12286dea,seen,2020-02-05 14:29:27.168170
34510,1b551809-0bbd-4db9-bb30-dfb7d377f203,ac39f982-600a-4aa3-92cf-97ff7d8090c1,seen,2020-02-04 20:14:57.085780
...,...,...,...,...
227258,bb07e2de-9dcf-4c42-97ce-98f68b0d58d8,ac165fa2-8e73-823a-8989-8aa1cdacd605,seen,2020-02-08 13:44:08.073464
234414,bd1d6eb9-0710-4f23-bb20-0bf5b5121c01,2c6d8005-a8e7-4ead-a82f-f804e90cea1c,seen,2020-02-04 19:06:25.995066
78944,4022f756-df82-402b-b557-50be34338ce6,af8a675a-aed6-4981-8c16-262fa31f2e64,seen,2020-02-13 12:22:53.751543
125760,692b86ed-7bde-405d-b791-6f51cc57ff02,06988f66-66f3-4ca1-a259-7558f95a594c,seen,2020-02-09 21:39:11.101324


In [54]:
test_df

Unnamed: 0,item_id,user_id,event_type,create_timestamp
166020,85e28b4c-ed5d-4b9d-b4d0-d95351d2b966,8ef662cc-a967-4ae7-a367-422e15618ea2,seen,2020-02-04 17:26:39.541734
15505,0c48c1d7-dc18-4585-97cc-86d8343eebbb,91c177e7-8cc3-8833-80c3-875c0c99f66c,seen,2020-02-08 14:01:18.719157
175762,8e65bb64-4bde-4696-b5df-8fef0bef3696,5ad79f59-8fe1-8d34-81c6-88dd2afc1c77,seen,2020-02-06 07:34:47.000717
277848,de56c3cb-3d22-4398-8b94-3fdbbe62b9f5,548fa9fa-82d5-8c65-8096-8667dae10128,seen,2020-02-05 14:21:13.400765
300055,ee798534-6b75-438e-be17-5db1dbd089b1,09a84873-1a79-45a2-9d99-2d663fa15563,seen,2020-02-11 02:38:54.791735
...,...,...,...,...
226435,b9c90696-421d-4485-91dd-43193efb9318,f7a96dd6-2419-4f5c-9c2c-045135a889a6,seen,2020-02-09 12:51:00.713844
285228,e4720389-75b5-4d65-809b-1c9dc93e8470,54e3a601-faf4-4c7a-a7a4-8595cf3514e4,seen,2020-02-08 13:58:39.573425
16764,0d2d1b0b-4459-4bb3-b7b6-eb931c1561b1,6d8eaede-8ed5-8727-8780-8c049d137670,seen,2020-02-03 14:53:14.335257
316234,fb18042b-679b-47cd-8e7d-74f17261ee06,95f1a35d-8480-8efc-8dc1-8e7cc93f7e56,seen,2020-02-06 21:22:43.656481


In [21]:
count_df.describe()

Unnamed: 0,count
count,16687.0
mean,5.718164
std,14.357755
min,1.0
25%,1.0
50%,2.0
75%,6.0
max,1141.0


In [7]:
user2clicks = {}
# for index,data in behaviors_df[:1000].iterrows():
for index, data in tqdm(behaviors_df.iterrows()):
    user = data["User_ID"]
    impressions = data["Impressions"].split(" ")
    clicks = []
    for impression in impressions:
        # print(impression)
        if impression[-1] == "1":
            clicks.append(impression[:-2])
    if user not in user2clicks:
        user2clicks[user] = clicks
    else:
        user2clicks[user] = user2clicks[user] + clicks

156965it [00:07, 20434.33it/s]


In [8]:
user_list = []
click_list = []
for user, v in tqdm(user2clicks.items()):
    for click in v:
        user_list.append(user)
        click_list.append(click)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 393846.89it/s]


In [9]:
print(len(user_list))
print(len(click_list))

236344
236344


In [10]:
click_df = pd.DataFrame(list(zip(user_list, click_list)), columns=["user_id", "item_id"])

In [11]:
click_df

Unnamed: 0,user_id,item_id
0,U13740,N55689
1,U13740,N28910
2,U13740,N58133
3,U91836,N17059
4,U91836,N26365
...,...,...
236339,U43157,N64152
236340,U43157,N41533
236341,U66493,N51048
236342,U66493,N11817


In [22]:
ratings = tf.data.Dataset.from_tensor_slices({"user_id": click_df["user_id"], "item_id": click_df["item_id"]})
# train_size = int(len(click_df) * (1 - test_rate))
# val_size = int(train_size * (1 - val_rate))
val_size = int(len(click_df) * val_rate)
test_size = int(len(click_df) * test_rate)
train_size = len(click_df) - val_size - test_size
train = ratings.take(train_size).batch(batch_size)
val = ratings.skip(train_size).take(val_size).batch(batch_size)
test = ratings.skip(train_size + val_size).take(test_size).batch(batch_size)

In [23]:
unique_user_ids = np.array(list(set(user_list)))
unique_item_ids = np.array(list(set(click_list)))
unique_item_dataset = tf.data.Dataset.from_tensor_slices(unique_item_ids)

In [24]:
# unique_user_ids.size
unique_item_ids.size

7713

In [25]:
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = RetrievalModel(
        unique_user_ids=unique_user_ids,
        unique_item_ids=unique_item_ids,
        user_dict_key="user_id",
        item_dict_key="item_id",
        embedding_dimension=embedding_dimension,
        metrics_candidate_dataset=unique_item_dataset,
    )
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


  return bool(asarray(a1 == a2).all())


In [26]:
callbacks = []
if early_stopping_flg:
    callbacks.append(
        tf.keras.callbacks.EarlyStopping(
            monitor="total_loss",
            min_delta=0,
            patience=3,
            verbose=0,
            mode="auto",
            baseline=None,
            restore_best_weights=False,
        )
    )
if tensorboard_flg:
    tfb_log_path = log_path + datetime.now().strftime("%Y%m%d-%H%M%S")
    callbacks.append(
        tf.keras.callbacks.TensorBoard(
            log_dir=tfb_log_path,
            histogram_freq=1,
        )
    )

In [27]:
model.fit(x=train, validation_data=val, epochs=max_epoch_num, callbacks=callbacks)

Epoch 1/20


2022-09-10 15:41:34.965486: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
      type: DT_STRING
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 236344
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\027TensorSliceDataset:4355"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_STRING
        }
      }
      args {
  



2022-09-10 15:42:34.311239: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
      type: DT_STRING
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 236344
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\027TensorSliceDataset:4355"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_STRING
        }
      }
      args {
  

Epoch 2/20

2022-09-10 15:44:02.502180: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 3/20
   5/1655 [..............................] - ETA: 55s - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 32328.3785 - regularization_loss: 0.0000e+00 - total_loss: 32328.3785

2022-09-10 15:44:34.341094: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.




2022-09-10 15:45:34.116523: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 4/20

2022-09-10 15:47:01.822633: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.




<keras.callbacks.History at 0x1370d97f0>

In [28]:
model.evaluate(test, return_dict=True)

2022-09-10 15:48:18.714571: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
      type: DT_STRING
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 236344
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\027TensorSliceDataset:4355"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_STRING
        }
      }
      args {
  



{'factorized_top_k/top_1_categorical_accuracy': 4.231192360748537e-05,
 'factorized_top_k/top_5_categorical_accuracy': 0.00021155961439944804,
 'factorized_top_k/top_10_categorical_accuracy': 0.00038080732338130474,
 'factorized_top_k/top_50_categorical_accuracy': 0.002496403409168124,
 'factorized_top_k/top_100_categorical_accuracy': 0.0066006602719426155,
 'loss': 136.3968963623047,
 'regularization_loss': 0,
 'total_loss': 136.3968963623047}