In [None]:
import tensorflow as tf
import tensorflow_recommenders as tfrs
import tensorflow_datasets as tfds
import pandas as pd
import numpy as np
from pprint import pprint

In [None]:
# Load the "train" split of the MovieLens 100k ratings and movie datasets.
ratings = tfds.load(name="movielens/100k-ratings", split="train")
movies = tfds.load(name="movielens/100k-movies", split="train")

In [None]:
# Pretty-print the first rating record to inspect the available features.
for rating_example in ratings.take(1).as_numpy_iterator():
  pprint(rating_example)

In [None]:
# Pretty-print the first movie record to inspect the available features.
for movie_example in movies.take(1).as_numpy_iterator():
  pprint(movie_example)

In [None]:
def _select_rating_features(record):
  """Keep only the rating features used downstream by the models."""
  return {
      "user_id": record["user_id"],
      "timestamp": record["timestamp"],
      "movie_title": record["movie_title"],
      "raw_user_age": record["raw_user_age"],
  }


def _select_movie_title(record):
  """Reduce a movie record to its title."""
  return record["movie_title"]


ratings = ratings.map(_select_rating_features)

movies = movies.map(_select_movie_title)

In [None]:
# Gather every rating timestamp into a single NumPy array (read in batches of 100).
timestamp_batches = ratings.map(lambda x: x["timestamp"]).batch(100)
timestamps = np.concatenate(list(timestamp_batches))

# Observed timestamp range.
min_timesteps = timestamps.min()
max_timestep = timestamps.max()

# 1000 evenly spaced boundaries spanning the observed timestamp range.
# (The spelling of this name is kept because other cells reference it.)
timetamps_buckets = np.linspace(min_timesteps, max_timestep, 1000)

# Vocabulary of distinct movie titles.
movie_title_batches = movies.batch(1000)
unique_movie_titles = np.unique(np.concatenate(list(movie_title_batches)))

# Vocabulary of distinct user ids.
user_id_batches = ratings.batch(1_000).map(lambda x: x["user_id"])
unique_user_ids = np.unique(np.concatenate(list(user_id_batches)))

In [None]:
# Gather every rating's raw user age into a single NumPy array (read in batches of 100).
user_age_batches = ratings.map(lambda x: x["raw_user_age"]).batch(100)
user_ages = np.concatenate(list(user_age_batches))

# Observed age range.
min_user_age = user_ages.min()
max_user_age = user_ages.max()

# 1000 evenly spaced boundaries spanning the observed age range.
user_age_buckets = np.linspace(min_user_age, max_user_age, 1000)

In [None]:
class UserModel(tf.keras.Model):
  """Feature model for the query (user) side.

  Always embeds the user id. Optionally adds, for the rating timestamp and/or
  the raw user age, a bucketized embedding plus a normalized scalar column.

  Args:
    use_timestamp: when truthy, include the timestamp features.
    use_age: when truthy, include the user-age features.

  The output of `call` is a rank-2 tensor whose columns are, in order: the
  user-id embedding, the enabled bucket embeddings (age, then timestamp), and
  the enabled normalized scalars (age, then timestamp).
  """

  def __init__(self, use_timestamp, use_age):
    # Keras requires the base initializer to run before layers are assigned
    # as attributes; without it attribute assignment raises.
    super().__init__()

    self.use_timestamp_ = use_timestamp
    self.use_age_ = use_age

    # +1 on the embedding input size leaves room for the out-of-vocabulary
    # index produced by StringLookup.
    self.user_embedding = tf.keras.Sequential([
        tf.keras.layers.StringLookup(vocabulary=unique_user_ids, mask_token=None),
        tf.keras.layers.Embedding(len(unique_user_ids) + 1, 32),
    ])

    if use_timestamp:
      # Bucketize the raw timestamp, then embed the bucket index.
      self.timestamp_embeddings = tf.keras.Sequential([
          tf.keras.layers.Discretization(timetamps_buckets.tolist()),
          tf.keras.layers.Embedding(len(timetamps_buckets) + 1, 32),
      ])

      self.normalized_timestamps = tf.keras.layers.Normalization(axis=None)
      self.normalized_timestamps.adapt(timestamps)

    if use_age:
      # Bucketize the raw age, then embed the bucket index. A single Embedding
      # is used: an Embedding expects integer indices, so it cannot be stacked
      # on top of another Embedding's float output.
      self.age_embeddings = tf.keras.Sequential([
          tf.keras.layers.Discretization(user_age_buckets.tolist()),
          tf.keras.layers.Embedding(len(user_age_buckets) + 1, 32),
      ])

      self.normalized_ages = tf.keras.layers.Normalization(axis=None)
      self.normalized_ages.adapt(user_ages)

  def call(self, inputs):
    """Build the concatenated user feature vector.

    Args:
      inputs: mapping with key "user_id", plus "raw_user_age" when age
        features are enabled and "timestamp" when timestamp features are
        enabled.

    Returns:
      The enabled features concatenated along axis 1.
    """
    embeddings = [self.user_embedding(inputs["user_id"])]
    scalars = []

    if self.use_age_:
      ages = inputs["raw_user_age"]
      embeddings.append(self.age_embeddings(ages))
      # Reshape to a column so it can be concatenated with the embeddings.
      scalars.append(tf.reshape(self.normalized_ages(ages), (-1, 1)))

    if self.use_timestamp_:
      stamps = inputs["timestamp"]
      embeddings.append(self.timestamp_embeddings(stamps))
      scalars.append(tf.reshape(self.normalized_timestamps(stamps), (-1, 1)))

    return tf.concat(embeddings + scalars, axis=1)

In [None]:
class MovieModel(tf.keras.Model):
  """Feature model for the candidate (movie) side.

  Represents a movie title two ways and concatenates them: an embedding of the
  whole title looked up as a single vocabulary entry, and the average of the
  embeddings of the title's individual tokens.
  """

  def __init__(self):
    super().__init__()

    # Upper bound on the token vocabulary learned by the text vectorizer.
    max_tokens = 10000

    # +1 on the embedding input size leaves room for the out-of-vocabulary
    # index produced by StringLookup.
    self.title_embedding = tf.keras.Sequential([
        tf.keras.layers.StringLookup(vocabulary=unique_movie_titles, mask_token=None),
        tf.keras.layers.Embedding(len(unique_movie_titles) + 1, 32),
    ])

    self.title_vectorizer = tf.keras.layers.TextVectorization(max_tokens=max_tokens)

    # mask_zero=True lets the pooling layer ignore padding tokens when
    # averaging the per-token embeddings.
    self.title_text_embedding = tf.keras.Sequential([
        self.title_vectorizer,
        tf.keras.layers.Embedding(max_tokens, 32, mask_zero=True),
        tf.keras.layers.GlobalAveragePooling1D(),
    ])

    # Learn the token vocabulary from the movie titles.
    self.title_vectorizer.adapt(movies)

  def call(self, titles):
    """Return the title-id and title-text embeddings concatenated along axis 1.

    Args:
      titles: batch of movie-title strings.
    """
    return tf.concat([
        self.title_embedding(titles),
        self.title_text_embedding(titles),
    ], axis=1)

In [None]:
class MovieLensModel(tfrs.models.Model):
  """Two-tower retrieval model over MovieLens ratings.

  Args:
    use_timestamp: forwarded to UserModel; enables timestamp features.
    use_age: forwarded to UserModel; enables user-age features.
  """

  def __init__(self, use_timestamp, use_age):
    # Keras requires the base initializer to run before layers are assigned
    # as attributes.
    super().__init__()

    # Query tower: user features followed by a small MLP projecting to 32 dims.
    self.query_model = tf.keras.Sequential([
        UserModel(use_timestamp, use_age),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(32),
    ])

    # Candidate tower: movie features projected to the same 32-dim space.
    self.candidate_model = tf.keras.Sequential([
        MovieModel(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(32),
    ])

    # The top-K metric scores queries against candidate *embeddings*, so the
    # batched titles are mapped through the candidate tower.
    self.task = tfrs.tasks.Retrieval(
        metrics=tfrs.metrics.FactorizedTopK(
            candidates=movies.batch(128).map(self.candidate_model)
        )
    )

  def compute_loss(self, features, training=False):
    """Compute the retrieval loss for one batch.

    Args:
      features: mapping with keys "user_id", "timestamp", "raw_user_age" and
        "movie_title".
      training: whether the call happens during training (unused here).

    Returns:
      The loss produced by the retrieval task.
    """
    query_embeddings = self.query_model({
        "user_id": features["user_id"],
        "timestamp": features["timestamp"],
        "raw_user_age": features["raw_user_age"],
    })
    movie_embeddings = self.candidate_model(features["movie_title"])

    return self.task(query_embeddings, movie_embeddings)

In [None]:
# Fix the global TensorFlow seed.
tf.random.set_seed(42)

# Shuffle once with a fixed seed; reshuffle_each_iteration=False keeps the
# order stable across iterations so take/skip below select disjoint records.
shuffled = ratings.shuffle(
    buffer_size=100_000,
    seed=42,
    reshuffle_each_iteration=False,
)

# First 80,000 records for training, the following 20,000 for evaluation.
train = shuffled.take(count=80_000)
test = shuffled.skip(count=80_000).take(count=20_000)

# Batched pipelines. Only the test pipeline calls .cache(); the training
# pipeline shuffles before batching and is not cached, despite its name.
cached_train = train.shuffle(buffer_size=100_000).batch(batch_size=2048)
cached_test = test.batch(batch_size=4096).cache()