# Imports

In [None]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from skimage import io
from weaviate import Client

# Data

In [None]:
# Load the post-processed movie metadata. Later cells read `poster_url` from
# this frame via `movies.loc[movie_ids]` (presumably it is indexed by movie id).
movies = pd.read_parquet(
    "../data/movies_postprocessed.parquet"
)
movies.head(n=5)


# Weaviate

In [None]:
# Connect to Weaviate. The URL defaults to http://weaviate:8080 (the host name
# `weaviate` is presumably a docker-compose service) and can be overridden with
# the WEAVIATE_URL environment variable.
WEAVIATE_URL = os.environ.get("WEAVIATE_URL", "http://weaviate:8080")
client = Client(WEAVIATE_URL)

# Show node status so a bad connection is obvious before anything else runs.
client.cluster.get_nodes_status()


Get the movie classes:

In [None]:
# Keep only the schema classes whose name starts with "Movie"; each one is a
# separate Movie collection queried independently below.
movie_classes = [
    schema_class["class"]
    for schema_class in client.schema.get()["classes"]
    if schema_class["class"].startswith("Movie")
]

movie_classes


# Schema

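Define a schema to represent a user's viewing history: one View class per Movie class, whose vector is the mean of the vectors of the movies the user has watched (ref2vec-centroid):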

In [None]:
# Derive one View class name per Movie class by swapping the "Movie" prefix for
# "View" (e.g. MovieCos -> ViewCos). Every name in movie_classes starts with
# "Movie" (they were filtered that way), so only that prefix is replaced.
# A plain str.replace would also rewrite any later "Movie" inside the name.
view_classes = ["View" + movie_class[len("Movie"):]
                for movie_class in movie_classes]
view_classes


In [None]:
# Drop View classes left over from a previous run so the next cell can recreate
# them from scratch (this also deletes any View objects they hold).
# Only classes that actually exist are deleted. Genuine failures such as a lost
# connection therefore raise instead of being printed and ignored.
existing_classes = {
    schema_class["class"]
    for schema_class in (client.schema.get().get("classes") or [])
}
for view_class in view_classes:
    if view_class in existing_classes:
        client.schema.delete_class(view_class)


In [None]:
def _view_class_schema(view_class, movie_class):
    """Build the schema dict for a View class linked to `movie_class`.

    The class's vector comes from the ref2vec-centroid module: the mean of the
    vectors of the objects referenced through its `movies` property.
    """
    ref2vec_config = {"referenceProperties": ["movies"], "method": "mean"}
    user_id_property = {
        "dataType": ["string"],
        "name": "user_id",
        "description": "The user id",
    }
    movies_property = {
        "dataType": [movie_class],
        "name": "movies",
        "description": "The movies the user has watched",
    }
    return {
        "class": view_class,
        "description": "The movies a user has watched",
        "moduleConfig": {"ref2vec-centroid": ref2vec_config},
        "properties": [user_id_property, movies_property],
        "vectorizer": "ref2vec-centroid",
    }


# Create one View class per Movie class; each View references its paired Movie class.
for view_class, movie_class in zip(view_classes, movie_classes):
    view_class_schema = _view_class_schema(view_class, movie_class)
    client.schema.create_class(view_class_schema)


In [None]:
# Display the stored schema of one of the View classes just created ("ViewCos").
client.schema.get("ViewCos")


# Demo

Define a dummy user with a viewing history:

In [None]:
# Dummy user whose viewing history drives the rest of the demo.
user_id = "test_user"

# Alternative viewing histories to experiment with. To use one, replace the
# active `movie_ids` assignment at the bottom of this cell.
# sleeping beauty, aladdin and the little mermaid
# movie_ids = ["2096", "588", "2081"]

# mortal kombat, mortal kombat annihilation, street fighter
# movie_ids = ["44", "1681", "393"]


# the running man, "demolition man", "assassins"
# movie_ids = ["3698", "442", "23"]

# Active history: cinderella, peter pan, Lady and the Tramp, sleeping beauty, the jungle book.
# The ids are strings because they are matched against `movie_id` with a valueString filter.
movie_ids = ["1022", "2087", "2080", "2096", "362"]


In [None]:
def get_movie_uuid(movie_id, movie_class="MovieCos"):
    """Return the Weaviate object UUID of the movie with the given `movie_id`.

    Args:
        movie_id: value of the movie's `movie_id` property (a string).
        movie_class: Weaviate class to search.

    Returns:
        The UUID string of the first matching object.

    Raises:
        RuntimeError: if the response holds no result list for `movie_class`
            (e.g. the GraphQL query returned errors).
        LookupError: if no object has that `movie_id`.
    """
    where_filter = {
        "path": ["movie_id"],
        "operator": "Equal",
        "valueString": movie_id}

    response = client.query.get(movie_class).with_additional(
        "id").with_where(where_filter).do()

    # `data` / `Get` can be missing or None when the query fails.
    hits = ((response.get("data") or {}).get("Get") or {}).get(movie_class)
    if hits is None:
        raise RuntimeError(
            f"Weaviate query on {movie_class} failed: {response.get('errors')}")
    if not hits:
        raise LookupError(f"No {movie_class} object with movie_id={movie_id!r}")

    return hits[0]["_additional"]["id"]


In [None]:
def build_user_view_history(user_id, movie_ids, view_class="ViewCos", movie_class="MovieCos"):
    """Create a View object for `user_id` and link it to each watched movie.

    When `view_class` uses the ref2vec-centroid vectorizer (as the schema cell
    configures), the object's vector is derived from these `movies` references.

    Args:
        user_id: value stored in the View object's `user_id` property.
        movie_ids: `movie_id` values of the watched movies.
        view_class: Weaviate class of the View object to create.
        movie_class: Weaviate class the movie references point to.

    Returns:
        The UUID of the newly created View object.
    """
    # Resolve every movie id first, so an unknown id fails before a
    # partially linked View object is written.
    movie_uuids = [get_movie_uuid(movie_id, movie_class=movie_class)
                   for movie_id in movie_ids]

    # Create the object and add its references directly; no batch is used.
    user_uuid = client.data_object.create(
        {"user_id": user_id}, class_name=view_class)

    for movie_uuid in movie_uuids:
        client.data_object.reference.add(
            from_uuid=user_uuid,
            from_property_name="movies",
            to_uuid=movie_uuid,
            from_class_name=view_class,
            to_class_name=movie_class)

    return user_uuid


In [None]:
# Write the same viewing history into every View class, each one linked to its
# paired Movie class.
for view_class, movie_class in zip(view_classes, movie_classes):
    build_user_view_history(
        user_id,
        movie_ids,
        view_class=view_class,
        movie_class=movie_class,
    )


In [None]:
# TODO: checkout if this is a bug in the client and/or ref2vec module
# with client.batch() as batch:
#     # user_uuid = batch.add_data_object({
#     #     "user_id": user_id,
#     # }, class_name="View"
#     # )

#     user_uuid = client.data_object.create(
#         {"user_id": user_id}, class_name="View")

#     for movie_uuid in movie_uuids:
#         # batch.add_reference(
#         #     from_object_uuid=user_uuid,
#         #     from_object_class_name="View",
#         #     from_property_name="movies",
#         #     to_object_uuid=movie_uuid,
#         #     to_object_class_name="Movie"
#         # )
#         client.data_object.reference.add(
#             from_uuid=user_uuid,
#             from_property_name="movies",
#             to_uuid=movie_uuid,
#             from_class_name="View",
#             to_class_name="Movie"
#         )


Let's sanity-check the embeddings that ref2vec has generated for the user:

In [None]:
def get_user_vector(user_id, view_class="ViewCos"):
    """Return the vector of the View object whose `user_id` matches.

    Args:
        user_id: value of the View object's `user_id` property.
        view_class: Weaviate View class to search.

    Returns:
        The object's vector, as returned under `_additional.vector`.

    Raises:
        RuntimeError: if the response holds no result list for `view_class`
            (e.g. the GraphQL query returned errors).
        LookupError: if no View object has that `user_id`.
    """
    where_filter = {
        "path": ["user_id"],
        "operator": "Equal",
        "valueString": user_id}

    response = client.query.get(view_class).with_additional(
        "vector").with_where(where_filter).do()

    # `data` / `Get` can be missing or None when the query fails.
    hits = ((response.get("data") or {}).get("Get") or {}).get(view_class)
    if hits is None:
        raise RuntimeError(
            f"Weaviate query on {view_class} failed: {response.get('errors')}")
    if not hits:
        raise LookupError(f"No {view_class} object with user_id={user_id!r}")

    # NOTE: if the history was built more than once, several objects share this
    # user_id and only the first one returned is used.
    return hits[0]["_additional"]["vector"]


In [None]:
def get_movie_vector(movie_id, movie_class="MovieCos"):
    """Return the vector of the movie with the given `movie_id`.

    Args:
        movie_id: value of the movie's `movie_id` property (a string).
        movie_class: Weaviate class to search.

    Returns:
        The object's vector, as returned under `_additional.vector`.

    Raises:
        RuntimeError: if the response holds no result list for `movie_class`
            (e.g. the GraphQL query returned errors).
        LookupError: if no object has that `movie_id`.
    """
    where_filter = {
        "path": ["movie_id"],
        "operator": "Equal",
        "valueString": movie_id}

    response = client.query.get(movie_class).with_additional(
        "vector").with_where(where_filter).do()

    # `data` / `Get` can be missing or None when the query fails.
    hits = ((response.get("data") or {}).get("Get") or {}).get(movie_class)
    if hits is None:
        raise RuntimeError(
            f"Weaviate query on {movie_class} failed: {response.get('errors')}")
    if not hits:
        raise LookupError(f"No {movie_class} object with movie_id={movie_id!r}")

    return hits[0]["_additional"]["vector"]


In [None]:
user_embeddings = []

# Check 1: each View class is configured with ref2vec-centroid method "mean",
# so the user's vector should match the element-wise mean of the watched
# movies' vectors in the paired Movie class.
for view_class, movie_class in zip(view_classes, movie_classes):
    user_embedding = get_user_vector(user_id, view_class)
    user_embeddings.append(user_embedding)

    mean_movie_embedding = np.mean(
        [get_movie_vector(movie_id, movie_class) for movie_id in movie_ids],
        axis=0,
    )
    assert np.allclose(mean_movie_embedding, user_embedding)

# Check 2: the user vector should be identical in every View class (this assumes
# every Movie class stores the same movie vectors; presumably the classes differ
# only in distance metric). Compare each embedding with the one before it.
for previous_embedding, current_embedding in zip(user_embeddings, user_embeddings[1:]):
    assert np.allclose(previous_embedding, current_embedding)


Now we find the nearest movies to the `user_embedding`:

In [None]:
def build_movie_id_exclude_filter(movie_ids, genre=None):
    """Build a Weaviate `where` filter that excludes the given movie ids.

    Each id becomes a `movie_id NotEqual <id>` clause. When `genre` is truthy,
    an extra clause requires the movie's linked Genre `name` to equal it. All
    clauses are combined with `And`.
    """
    clauses = []
    for movie_id in movie_ids:
        clauses.append({"path": ["movie_id"], "operator": "NotEqual",
                        "valueString": movie_id})

    if genre:
        clauses.append({"path": ["genres", "Genre", "name"],
                        "operator": "Equal", "valueString": genre})

    return {"operator": "And", "operands": clauses}


In [None]:
def get_recommendations(user_embedding, watched_movie_ids, genre_constraint=None, movie_class="MovieCos", top_k=10):
    """Return up to `top_k` movies of `movie_class` near `user_embedding`.

    Movies in `watched_movie_ids` are excluded. If `genre_constraint` is given,
    results are restricted to that genre (via `build_movie_id_exclude_filter`).

    Args:
        user_embedding: query vector.
        watched_movie_ids: `movie_id` values to exclude.
        genre_constraint: optional Genre name the results must have.
        movie_class: Weaviate Movie class to search.
        top_k: maximum number of results.

    Returns:
        A list of dicts with `movie_id`, `title`, `plot`, `genres`, `poster_url`
        and `_additional.distance`, in the order Weaviate returns them
        (presumably nearest first).

    Raises:
        RuntimeError: if the response holds no result list for `movie_class`
            (e.g. the GraphQL query returned errors).
    """
    near_vector = {"vector": user_embedding}
    where_filter = build_movie_id_exclude_filter(
        watched_movie_ids, genre=genre_constraint)

    # Named `response` rather than `movies` so it does not shadow the global
    # movies DataFrame.
    response = (
        client.query
        .get(movie_class, properties=["movie_id", "title", "plot", "genres {... on Genre {name}}", "poster_url"])
        .with_near_vector(near_vector)
        .with_where(where_filter)
        .with_additional("distance")
        .with_limit(top_k)
        .do()
    )

    # `data` / `Get` can be missing or None when the query fails.
    results = ((response.get("data") or {}).get("Get") or {}).get(movie_class)
    if results is None:
        raise RuntimeError(
            f"Weaviate query on {movie_class} failed: {response.get('errors')}")
    return results


In [None]:
def get_poster_images(movie_ids):
    """Load the poster image of each movie, in the order of `movie_ids`.

    The `poster_url` values are looked up by label in the global `movies`
    DataFrame (via `.loc`), and each URL is read with `io.imread`.
    """
    urls = movies.loc[movie_ids, "poster_url"].tolist()
    return [io.imread(url) for url in urls]


In [None]:

def visualize_results(history_movie_ids, rec_movie_ids):
    """Show watched-movie posters (left) next to recommended posters (right).

    Each side is a 2 x 5 grid, so at most 10 posters per side are drawn. Extra
    ids are ignored, and grid cells without a poster stay blank.

    Args:
        history_movie_ids: movie ids the user has watched.
        rec_movie_ids: recommended movie ids.

    Returns:
        The matplotlib Figure.
    """
    # based on https://stackoverflow.com/questions/70083434/combine-two-matplotlib-figures-side-by-side-high-quality
    fig = plt.figure(constrained_layout=True, figsize=(6, 2), dpi=500)
    titles_size = 12
    subfigs = fig.subfigures(1, 2, wspace=0.1, hspace=0)
    subfigs[0].suptitle('If you liked these ...', fontsize=titles_size)
    subfigs[1].suptitle('Then, you might like these', fontsize=titles_size)

    axs_left = subfigs[0].subplots(2, 5)
    axs_right = subfigs[1].subplots(2, 5)
    n_cells = axs_left.size  # posters per side

    # Only download the posters that fit in the grids.
    rec_images = get_poster_images(list(rec_movie_ids)[:n_cells])
    history_images = get_poster_images(list(history_movie_ids)[:n_cells])

    for axs, images in ((axs_left, history_images), (axs_right, rec_images)):
        for ax_idx, ax in enumerate(axs.reshape(-1)):
            # Hide ticks and frame on every cell, including ones left empty.
            ax.axis('off')
            # Guard both sides: there can be fewer posters than cells, e.g. when
            # a genre filter matches fewer than 10 movies.
            if ax_idx < len(images):
                ax.imshow(images[ax_idx])

    return fig


In [None]:
# Query the first Movie class with the user vector taken from its paired View class.
recs = get_recommendations(
    user_embeddings[0],
    watched_movie_ids=movie_ids,
    genre_constraint=None,
    top_k=10,
    movie_class=movie_classes[0],
)

# Keep only the ids; the plotting helper fetches the posters itself.
rec_movie_ids = [recommendation["movie_id"] for recommendation in recs]

fig = visualize_results(movie_ids, rec_movie_ids)
fig;


The same recommendations, but restricted to the Children's genre:

In [None]:
# Same query as above, but only movies whose Genre name is "Children's".
recs = get_recommendations(
    user_embeddings[0],
    watched_movie_ids=movie_ids,
    genre_constraint="Children's",
    top_k=10,
    movie_class=movie_classes[0],
)

rec_movie_ids = [recommendation["movie_id"] for recommendation in recs]

fig = visualize_results(movie_ids, rec_movie_ids)
fig;


Let's visualize how the recommendations change as we change the distance metric:

In [None]:
def extract_distance_function(movie_class):
    """Return the `distance` setting from `movie_class`'s vectorIndexConfig."""
    return client.schema.get(movie_class)["vectorIndexConfig"]["distance"]


def visualize_recs_across_distance_metrics(user_embedding, watched_movie_ids, genre_constraint=None, top_k=10,
                                           movie_class_names=None, dpi=1000):
    """Plot recommendations from every Movie class, one row per class.

    Each class is queried with the same `user_embedding`. Its row is titled
    with the class's distance metric and shows up to 10 posters (2 x 5 grid),
    each captioned with the movie title and its distance.

    Args:
        user_embedding: query vector.
        watched_movie_ids: movie ids to exclude from the results.
        genre_constraint: optional Genre name the results must have.
        top_k: results requested per class; only the first 10 fit in the grid.
        movie_class_names: classes to compare, one row each; defaults to the
            notebook-level `movie_classes`.
        dpi: figure resolution. The default of 1000 dpi on a 10-inch-wide
            figure produces a very large image.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if there are no classes to plot.
    """
    if movie_class_names is None:
        movie_class_names = movie_classes
    if not movie_class_names:
        raise ValueError("no movie classes to visualize")

    n_rows = len(movie_class_names)
    # A bare figure (not plt.subplots) so no stray Axes sits behind the
    # subfigures. At 4 inches per row, five classes give the original 10 x 20 size.
    fig = plt.figure(constrained_layout=True, figsize=(10, 4 * n_rows), dpi=dpi)
    # squeeze=False so a single class still yields an iterable of subfigures.
    subfigs = fig.subfigures(n_rows, 1, squeeze=False).ravel()

    for movie_class, subfig in zip(movie_class_names, subfigs):
        distance_function = extract_distance_function(movie_class)
        subfig.suptitle(f"{distance_function}", fontsize=18, fontweight="bold")

        axs = subfig.subplots(2, 5, gridspec_kw={'height_ratios': [
                              1, 1], 'wspace': 0, 'hspace': 0}).reshape(-1)
        # Blank every cell first so cells without a recommendation show nothing.
        for ax in axs:
            ax.axis('off')

        recs = get_recommendations(
            user_embedding, watched_movie_ids=watched_movie_ids, genre_constraint=genre_constraint, top_k=top_k, movie_class=movie_class)

        for ax, rec in zip(axs, recs):
            rec_title = rec["title"]
            rec_distance = rec["_additional"]["distance"]
            rec_image = get_poster_images([rec["movie_id"]])[0]

            ax.imshow(rec_image)
            # Caption: title, then the distance to the user vector in parentheses.
            ax.set_title(f"{rec_title}\n({rec_distance:.2f})",
                         fontsize=4, ha="center", wrap=True)

    return fig


In [None]:
# Unfiltered recommendations, one row per Movie class / distance metric.
visualize_recs_across_distance_metrics(
    user_embedding=user_embeddings[0],
    watched_movie_ids=movie_ids,
    genre_constraint=None,
    top_k=10,
);


In [None]:
# Same comparison as the previous cell, restricted to the "Children's" genre.
visualize_recs_across_distance_metrics(
    user_embedding=user_embeddings[0],
    watched_movie_ids=movie_ids,
    genre_constraint="Children's",
    top_k=10,
);
