In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_recommenders as tfrs
from tqdm import tqdm

from src.utils.model.retrieval_model import RetrievalModel

In [21]:
val_rate = 0.2
test_rate = 0.1
batch_size = 100
embedding_dimension = 100
learning_rate = 0.1
early_stopping_flg = True
tensorboard_flg = False
max_epoch_num = 20

In [3]:
behaviors_df = pd.read_table(
    "data/MIND/MINDsmall_train/behaviors.tsv", names=("Impression_ID", "User_ID", "Time", "History", "Impressions")
)
# news_df = pd.read_table(
#     "data/MIND/MINDsmall_train/news.tsv",
#     names=("News_ID", "Category", "SubCategory", "Title", "Abstract", "URL", "Title_Entities", "Abstract_Entities"),
# )

In [4]:
behaviors_df

Unnamed: 0,Impression_ID,User_ID,Time,History,Impressions
0,1,U13740,11/11/2019 9:05:58 AM,N55189 N42782 N34694 N45794 N18445 N63302 N104...,N55689-1 N35729-0
1,2,U91836,11/12/2019 6:11:30 PM,N31739 N6072 N63045 N23979 N35656 N43353 N8129...,N20678-0 N39317-0 N58114-0 N20495-0 N42977-0 N...
2,3,U73700,11/14/2019 7:01:48 AM,N10732 N25792 N7563 N21087 N41087 N5445 N60384...,N50014-0 N23877-0 N35389-0 N49712-0 N16844-0 N...
3,4,U34670,11/11/2019 5:28:05 AM,N45729 N2203 N871 N53880 N41375 N43142 N33013 ...,N35729-0 N33632-0 N49685-1 N27581-0
4,5,U8125,11/12/2019 4:11:21 PM,N10078 N56514 N14904 N33740,N39985-0 N36050-0 N16096-0 N8400-1 N22407-0 N6...
...,...,...,...,...,...
156960,156961,U21593,11/14/2019 10:24:05 PM,N7432 N58559 N1954 N43353 N14343 N13008 N28833...,N2235-0 N22975-0 N64037-0 N47652-0 N11378-0 N4...
156961,156962,U10123,11/13/2019 6:57:04 AM,N9803 N104 N24462 N57318 N55743 N40526 N31726 ...,N3841-0 N61571-0 N58813-0 N28213-0 N4428-0 N25...
156962,156963,U75630,11/14/2019 10:58:13 AM,N29898 N59704 N4408 N9803 N53644 N26103 N812 N...,N55913-0 N62318-0 N53515-0 N10960-0 N9135-0 N5...
156963,156964,U44625,11/13/2019 2:57:02 PM,N4118 N47297 N3164 N43295 N6056 N38747 N42973 ...,N6219-0 N3663-0 N31147-0 N58363-0 N4107-0 N457...


In [5]:
# news_df

In [6]:
behaviors_df["User_ID"].value_counts()

U32146    62
U15740    44
U20833    41
U51286    40
U44201    40
          ..
U60416     1
U20588     1
U84385     1
U89164     1
U72015     1
Name: User_ID, Length: 50000, dtype: int64

In [7]:
user2clicks = {}
# for index,data in behaviors_df[:1000].iterrows():
for index, data in tqdm(behaviors_df.iterrows()):
    user = data["User_ID"]
    impressions = data["Impressions"].split(" ")
    clicks = []
    for impression in impressions:
        # print(impression)
        if impression[-1] == "1":
            clicks.append(impression[:-2])
    if user not in user2clicks:
        user2clicks[user] = clicks
    else:
        user2clicks[user] = user2clicks[user] + clicks

156965it [00:07, 20434.33it/s]


In [8]:
user_list = []
click_list = []
for user, v in tqdm(user2clicks.items()):
    for click in v:
        user_list.append(user)
        click_list.append(click)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 393846.89it/s]


In [9]:
print(len(user_list))
print(len(click_list))

236344
236344


In [10]:
click_df = pd.DataFrame(list(zip(user_list, click_list)), columns=["user_id", "item_id"])

In [11]:
click_df

Unnamed: 0,user_id,item_id
0,U13740,N55689
1,U13740,N28910
2,U13740,N58133
3,U91836,N17059
4,U91836,N26365
...,...,...
236339,U43157,N64152
236340,U43157,N41533
236341,U66493,N51048
236342,U66493,N11817


In [22]:
ratings = tf.data.Dataset.from_tensor_slices({"user_id": click_df["user_id"], "item_id": click_df["item_id"]})
# train_size = int(len(click_df) * (1 - test_rate))
# val_size = int(train_size * (1 - val_rate))
val_size = int(len(click_df) * val_rate)
test_size = int(len(click_df) * test_rate)
train_size = len(click_df) - val_size - test_size
train = ratings.take(train_size).batch(batch_size)
val = ratings.skip(train_size).take(val_size).batch(batch_size)
test = ratings.skip(train_size + val_size).take(test_size).batch(batch_size)

In [23]:
unique_user_ids = np.array(list(set(user_list)))
unique_item_ids = np.array(list(set(click_list)))
unique_item_dataset = tf.data.Dataset.from_tensor_slices(unique_item_ids)

In [24]:
# unique_user_ids.size
unique_item_ids.size

7713

In [25]:
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = RetrievalModel(
        unique_user_ids=unique_user_ids,
        unique_item_ids=unique_item_ids,
        user_dict_key="user_id",
        item_dict_key="item_id",
        embedding_dimension=embedding_dimension,
        metrics_candidate_dataset=unique_item_dataset,
    )
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


  return bool(asarray(a1 == a2).all())


In [26]:
callbacks = []
if early_stopping_flg:
    callbacks.append(
        tf.keras.callbacks.EarlyStopping(
            monitor="total_loss",
            min_delta=0,
            patience=3,
            verbose=0,
            mode="auto",
            baseline=None,
            restore_best_weights=False,
        )
    )
if tensorboard_flg:
    tfb_log_path = log_path + datetime.now().strftime("%Y%m%d-%H%M%S")
    callbacks.append(
        tf.keras.callbacks.TensorBoard(
            log_dir=tfb_log_path,
            histogram_freq=1,
        )
    )

In [27]:
model.fit(x=train, validation_data=val, epochs=max_epoch_num, callbacks=callbacks)

Epoch 1/20


2022-09-10 15:41:34.965486: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
      type: DT_STRING
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 236344
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\027TensorSliceDataset:4355"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_STRING
        }
      }
      args {
  



2022-09-10 15:42:34.311239: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
      type: DT_STRING
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 236344
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\027TensorSliceDataset:4355"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_STRING
        }
      }
      args {
  

Epoch 2/20

2022-09-10 15:44:02.502180: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 3/20
   5/1655 [..............................] - ETA: 55s - factorized_top_k/top_1_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_5_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_10_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_50_categorical_accuracy: 0.0000e+00 - factorized_top_k/top_100_categorical_accuracy: 0.0000e+00 - loss: 32328.3785 - regularization_loss: 0.0000e+00 - total_loss: 32328.3785

2022-09-10 15:44:34.341094: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.




2022-09-10 15:45:34.116523: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.


Epoch 4/20

2022-09-10 15:47:01.822633: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.




<keras.callbacks.History at 0x1370d97f0>

In [28]:
model.evaluate(test, return_dict=True)

2022-09-10 15:48:18.714571: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2"
op: "TensorSliceDataset"
input: "Placeholder/_0"
input: "Placeholder/_1"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
      type: DT_STRING
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 236344
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\027TensorSliceDataset:4355"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_STRING
        }
      }
      args {
  



{'factorized_top_k/top_1_categorical_accuracy': 4.231192360748537e-05,
 'factorized_top_k/top_5_categorical_accuracy': 0.00021155961439944804,
 'factorized_top_k/top_10_categorical_accuracy': 0.00038080732338130474,
 'factorized_top_k/top_50_categorical_accuracy': 0.002496403409168124,
 'factorized_top_k/top_100_categorical_accuracy': 0.0066006602719426155,
 'loss': 136.3968963623047,
 'regularization_loss': 0,
 'total_loss': 136.3968963623047}