In [1]:
import tensorflow as tf
import pandas as pd
from pprint import pprint
import numpy as np
import tensorflow_recommenders as tfrs
import tensorflow_datasets as tfds


In [6]:
# Load the "train" split of the MovieLens 100k ratings and movies datasets.
# Order matters for the unpacking: ratings first, then movies.
ratings, movies = (
    tfds.load(dataset_name, split="train")
    for dataset_name in ("movielens/100k-ratings", "movielens/100k-movies")
)

In [16]:
# Keep only the features used further down: each rating becomes a dict of
# title, user id and timestamp; each movie becomes its bare title.
# NOTE: both names are rebound here, so `movies` no longer yields dicts after
# this cell runs -- re-run the loading cell before re-running this one.
ratings = ratings.map(
    lambda rating: {
        key: rating[key] for key in ("movie_title", "user_id", "timestamp")
    }
)
movies = movies.map(lambda movie: movie["movie_title"])

In [17]:
# Collect every rating timestamp into a single NumPy array. It drives the
# bucket boundaries below and is what UserModel adapts its Normalization
# layer on.
timestamp_batches = ratings.map(lambda x: x["timestamp"]).batch(100)
timestamps = np.concatenate(list(timestamp_batches))

# 1000 evenly spaced boundaries spanning the observed timestamp range.
min_timestamp, max_timestamp = timestamps.min(), timestamps.max()
timestamp_buckets = np.linspace(min_timestamp, max_timestamp, num=1000)

# Vocabularies for the StringLookup layers of the movie and user towers.
movie_title_batches = movies.batch(1000)
unique_movie_titles = np.unique(np.concatenate(list(movie_title_batches)))

user_id_batches = ratings.batch(1_000).map(lambda x: x["user_id"])
unique_user_ids = np.unique(np.concatenate(list(user_id_batches)))

2022-10-05 15:30:29.311405: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [None]:
class UserModel(tf.keras.Model):
    """User tower: embeds a user id and, optionally, the rating timestamp.

    Relies on the notebook-level arrays `unique_user_ids`, `timestamp_buckets`
    and `timestamps` computed in an earlier cell.

    Args:
        use_timestamp: When truthy, the output also contains a bucketized
            timestamp embedding and a normalized scalar timestamp. When
            falsy, only the user-id embedding is returned and no timestamp
            layers are created.
    """

    def __init__(self, use_timestamp):
        # Keras must be initialised before any layer is assigned as an
        # attribute; otherwise attribute tracking raises at assignment time.
        super().__init__()

        self._use_timestamp = use_timestamp

        # NOTE: the attribute name keeps its original spelling ("emedding")
        # so that any code referring to it outside this cell keeps working.
        self.user_emedding = tf.keras.Sequential([
            tf.keras.layers.StringLookup(vocabulary=unique_user_ids, mask_token=None),
            # +1 row for the out-of-vocabulary index produced by StringLookup.
            tf.keras.layers.Embedding(len(unique_user_ids) + 1, 32),
        ])

        if use_timestamp:
            self.timestamp_embedding = tf.keras.Sequential([
                tf.keras.layers.Discretization(timestamp_buckets.tolist()),
                # N boundaries produce N + 1 buckets.
                tf.keras.layers.Embedding(len(timestamp_buckets) + 1, 32),
            ])

            # Same attribute name that `call` reads; the original defined it
            # as `normalized_timestamps` and then looked up
            # `normalized_timestamp`, which failed at call time.
            self.normalized_timestamp = tf.keras.layers.Normalization(axis=None)
            self.normalized_timestamp.adapt(timestamps)

    def call(self, inputs):
        """Return the user representation for a batch of features.

        Args:
            inputs: Mapping with a "user_id" entry and, when timestamps are
                enabled, a "timestamp" entry. Presumably rank-1 batches, as
                produced by batching the `ratings` dataset -- TODO confirm.

        Returns:
            The user-id embedding alone, or the concatenation along axis 1 of
            the user-id embedding, the timestamp-bucket embedding and the
            normalized timestamp reshaped to a single column.
        """
        user_vector = self.user_emedding(inputs["user_id"])

        if not self._use_timestamp:
            return user_vector

        return tf.concat([
            user_vector,
            self.timestamp_embedding(inputs["timestamp"]),
            # Reshape the normalized scalar to (-1, 1) so it can be
            # concatenated with the 2-D embeddings.
            tf.reshape(self.normalized_timestamp(inputs["timestamp"]), (-1, 1)),
        ], axis=1)

In [None]:
class MovieModel(tf.keras.Model):
    """Movie tower: combines a whole-title embedding with a bag-of-words one.

    Relies on the notebook-level `unique_movie_titles` array and the `movies`
    dataset of title strings prepared in earlier cells.
    """

    def __init__(self):
        # Was `super().__init()`, which raises AttributeError before any
        # layer is created.
        super().__init__()

        # Vocabulary cap for the title text vectorizer; also the number of
        # rows in the token embedding table.
        max_tokens = 10000

        # Embedding keyed on the full title string.
        self.title_embedding = tf.keras.Sequential([
            tf.keras.layers.StringLookup(vocabulary=unique_movie_titles, mask_token=None),
            # +1 row for the out-of-vocabulary index produced by StringLookup.
            tf.keras.layers.Embedding(len(unique_movie_titles) + 1, 32),
        ])

        self.title_vectorizer = tf.keras.layers.TextVectorization(max_tokens=max_tokens)

        # Embedding of the title's tokens, averaged over the sequence.
        # mask_zero=True marks token id 0 as padding via the Keras mask.
        self.title_text_embedding = tf.keras.Sequential([
            self.title_vectorizer,
            tf.keras.layers.Embedding(max_tokens, 32, mask_zero=True),
            tf.keras.layers.GlobalAveragePooling1D(),
        ])

        # Learn the token vocabulary from the movie titles dataset.
        self.title_vectorizer.adapt(movies)

    def call(self, titles):
        """Return the movie representation for a batch of title strings.

        Args:
            titles: Batch of movie title strings.

        Returns:
            The whole-title embedding concatenated along axis 1 with the
            averaged token embedding.
        """
        return tf.concat([
            self.title_embedding(titles),
            self.title_text_embedding(titles),
        ], axis=1)
