In [None]:
! pip install requests sqlite-vec openai pandas python-fasthtml git+https://github.com/callmephilip/fasthtml-nb-ext.git

In [2]:
import os

DB_URL = "https://huggingface.co/datasets/callmephilip/movies/resolve/main/movies.db"
DB_PATH = "./data/movies.db"
# Read the key from the environment rather than hard-coding it in the notebook
# (where it would end up in version control / shared copies). Defaults to "" so
# the notebook still loads when the variable is unset; API calls will then fail.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")

In [3]:
from typing import List
import struct
from openai import OpenAI

# Shared OpenAI client; embed() and embed_batch() below send their embedding requests through it.
client = OpenAI(api_key=OPENAI_API_KEY)

def serialize_f32(vector: List[float]) -> bytes:
    """Pack ``vector`` into a raw bytes blob of consecutive float32 values.

    Uses ``struct``'s default (native) byte order; this blob is the format
    passed to the sqlite-vec ``embeddings`` table in this notebook.
    """
    layout = f"{len(vector)}f"
    return struct.pack(layout, *vector)

def embed(text, model="text-embedding-3-small", pack=True) -> bytes | List[float]:
    """Embed a single string with the OpenAI embeddings API.

    Newlines in ``text`` are replaced by spaces before the request is sent.
    Returns the vector packed via ``serialize_f32`` when ``pack`` is true,
    otherwise the plain list of floats from the response.
    """
    single_line = text.replace("\n", " ")
    response = client.embeddings.create(input=[single_line], model=model)
    vector = response.data[0].embedding
    if pack:
        return serialize_f32(vector)
    return vector

def embed_batch(texts: List[str], model="text-embedding-3-small", pack=True) -> List[bytes] | List[float]: 
   texts = [t.replace("\n", " ") for t in texts]
   d = client.embeddings.create(input = texts, model=model).data
   return [serialize_f32(v.embedding) for v in d] if pack else [v.embedding for v in d]

In [None]:
# Set up database

import sqlite_vec, os, requests
from fasthtml.common import *

OPENAI_TEXT_EMBEDDING_3_SMALL_SIZE = 1536


# Download the prebuilt movies.db from DB_URL unless it already exists locally.
if not os.path.exists(DB_PATH):
    # The target directory (./data) may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(DB_PATH) or ".", exist_ok=True)
    # Stream into a temporary file and rename at the end, so an interrupted
    # download never leaves a partial file that the exists() check would accept.
    tmp_path = DB_PATH + ".part"
    with requests.get(DB_URL, allow_redirects=True, timeout=60, stream=True) as r:
        # Fail loudly on HTTP errors instead of saving an error page as the database.
        r.raise_for_status()
        with open(tmp_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    os.replace(tmp_path, DB_PATH)


db = database(DB_PATH)
# Enable extension loading only for the sqlite_vec.load() call; the finally
# guarantees it is switched back off even if loading the extension fails.
db.conn.enable_load_extension(True)
try:
    sqlite_vec.load(db.conn)
finally:
    db.conn.enable_load_extension(False)

# Movie metadata rows; movie_id is the primary key (and the rowid used in `embeddings`).
class Movie: movie_id:int; title: str; year: int; poster_link: str; genres: str; actors: str; director: str; description: str; plot: str; keywords: str
movies = db.create(Movie, pk='movie_id')

vec_version, = db.execute("select vec_version()").fetchone()
# vec0 virtual table with one float32 vector per movie, sized for text-embedding-3-small.
if not db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='embeddings'").fetchone():
    db.execute(f"CREATE VIRTUAL TABLE embeddings USING vec0(embedding float[{OPENAI_TEXT_EMBEDDING_3_SMALL_SIZE}])")


def find_movies(query: str, limit=10) -> List[Movie]:
    """Return up to ``limit`` movies most similar to ``query``, nearest first.

    ``query`` is embedded with ``embed`` and matched against the sqlite-vec
    ``embeddings`` table; each hit's rowid is a ``Movie.movie_id``. Returns an
    empty list when there are no hits.
    """
    # int() keeps the interpolated LIMIT a plain integer, so `limit` cannot inject SQL.
    rows = db.execute(
        f"SELECT rowid, distance FROM embeddings WHERE embedding MATCH ? ORDER BY distance LIMIT {int(limit)}",
        [embed(query)],
    ).fetchall()
    if not rows:
        return []
    ids = [int(row[0]) for row in rows]
    found = movies(where=f"movie_id in ({','.join(map(str, ids))})")
    # The IN (...) lookup comes back in table order; restore the nearest-first
    # order produced by `ORDER BY distance`.
    by_id = {m.movie_id: m for m in found}
    return [by_id[i] for i in ids if i in by_id]

print(f"vec_version={vec_version}")

In [None]:
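## ================ OPTIONAL: rebuild ./data/movies.db from raw data ================
## Not needed normally: the setup cell above downloads a prebuilt movies.db from DB_URL.
## Running this requires ./data/movie_data.csv and ./data/wiki_movie_plots_deduped.csv,
## and calls the OpenAI embeddings API for every movie (costs money and time).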
# import pandas as pd
# import os


# # Load and prepare dataset
# df=pd.read_csv("./data/movie_data.csv", 
#                usecols = ['id', 'Name', 'PosterLink', 'Genres', 'Actors', 
#                           'Director','Description', 'DatePublished', 'Keywords'], 
#                parse_dates = ["DatePublished"])
# df["year"] = df["DatePublished"].dt.year.fillna(0).astype(int)
# df.drop(["DatePublished"], axis=1, inplace=True)
# df = df[df.year > 1970]

# # Plot dataset
# plots = pd.read_csv('./data/wiki_movie_plots_deduped.csv')
# plots = plots[plots['Release Year'] > 1970]
# plots = plots[plots.duplicated(subset=['Title', 'Release Year', 'Plot']) == False]
# plots = plots[plots.duplicated(subset=['Title', 'Release Year']) == False]
# plots = plots[['Title', 'Plot', 'Release Year']]

# plots.columns = ['Name', 'Plot', 'year']

# # Merge
# df = df.merge(plots, on=['Name', 'year'], how='left').fillna('')
# df.reset_index(drop=True, inplace=True)

# recs = []

# for i in range(len(df)):
#     it = df.iloc[i]
#     movie_id=int(it['id'])
#     recs.append(Movie(
#         movie_id=movie_id,
#         title=str(it['Name']).lower(),
#         year=int(it['year']),
#         poster_link=str(it['PosterLink']),
#         genres=str(it['Genres']),
#         actors=str(it['Actors']).lower(),
#         director=str(it['Director']).lower(),
#         description=str(it['Description']),
#         plot=str(it['Plot']),
#         keywords=str(it['Keywords']),
#     ))


# batch_size = 100

# for i in range(0, len(recs), batch_size):
#     print(f"{i} / {len(recs)}")
#     batch = recs[i:i+batch_size]
#     embeds = embed_batch(["\n".join([it.title, it.description, it.plot]) for it in batch])
#     # iterate over batch
#     for j, movie in enumerate(batch):
#         movies.insert(movie)
#         db.execute("INSERT INTO embeddings(rowid, embedding) VALUES (?, ?)", [movie.movie_id, embeds[j]])

In [6]:
from fasthtml.common import *
from fasthtml_nb_ext import Playground

@patch
def __ft__(self: Movie):
    """Render a Movie as a poster card (poster, title, year, genres) for the results grid."""
    title = Div(style='text-transform:capitalize;', cls='movie-item-title')(self.title)
    infos = Div(cls='movie-infos')(
        Div(cls='movie-info')(Span(self.year)),
        Div(cls='movie-info')(Span(self.genres)),
    )
    content = Div(cls='movie-item-content')(title, infos)
    poster = Img(src=self.poster_link, alt='')
    submit = Button(type='submit')(poster, content)
    item = Div(cls='movie-item')(Form(submit))
    return Div(cls='card')(item)

Playground.config(hdrs=[Link(rel="stylesheet", href="/styles.pure.css"), Link(rel="stylesheet", href="/navbar.css")])

In [None]:
with Playground(path="/") as p:
    # Each handler is named `get` (fasthtml convention: the function name selects the HTTP method).
    @p.rt("/{fname:path}.{ext:static}")
    def get(fname: str, ext: str):
        """Serve static assets (stylesheets, gifs, icons, jquery) from ./static."""
        return FileResponse(f"static/{fname}.{ext}")

    @p.rt("/search")
    def get(request):
        """Render result cards for the `search` query parameter sent by the search bar."""
        # Treat whitespace-only input as empty instead of embedding it.
        q = (request.query_params.get("search") or "").strip()
        results = find_movies(q) if q else []
        if len(results) == 0:
            return Div(cls='section-header')("No matches found!")
        return Section(cls='popular-tours')(
            Div(cls='cards-wrapper')(*results)
        )

    @p.rt("/")
    def get():
        """Render the landing page: credits navbar, logo, search bar and results container."""
        # '_blank' must be the plain keyword (the previous curly-quoted value named a
        # single reusable window); noopener/noreferrer keeps the new tab isolated.
        new_tab = dict(target='_blank', rel='noopener noreferrer')
        icon_style = 'height: 30px;width: 30px; vertical-align:middle;'
        nav = Div(cls='nav-wrapper')(
            Img(src='holoskull.gif', style='width:70px; position:absolute; top:10px; left:10px;'),
            Div(cls='nav')(
                Ul(id='nav-menu', cls='nav-menu')(
                    Li(style='text-align: center;')(
                        "Inspired by",
                        A("@karpathy", href='https://twitter.com/karpathy/status/1647372603907280896', **new_tab)
                    ),
                    Li(style='text-align: center;')(
                        "Powered by",
                        A(href='https://github.com/weaviate')(
                            Img(src='weaviate-icon.png', alt='Weaviate', style=icon_style)
                        )
                    ),
                    Li(style='text-align: center;')(
                        "Built with 💚 by",
                        A("Leonie", href='https://www.linkedin.com/in/804250ab/', **new_tab)
                    ),
                    Li(style='text-align: center;')(
                        A(href='https://github.com/weaviate-tutorials/awesome-moviate')(
                            Img(src='github-icon.png', alt='Github', style=icon_style)
                        )
                    )
                )
            ),
            Img(src='holoskull.gif', style='width:70px; position:absolute; top:10px; right:10px;')
        )
        header = Div(style="text-align:center;")(
            Div(cls='logo')(
                Img(src='flame.gif'),
                H1(
                    "!!! awesome-",
                    # `center` is not a valid vertical-align value; `middle` is.
                    Img(src='moviate-logo.png', alt='Moviate', style='height: 25px; vertical-align:middle;'),
                    "!!!"
                ),
                Img(src='flame.gif')
            ),
            Div(cls='bar')(
                # htmx: re-query /search 500ms after typing stops and swap the cards into #search-results.
                Input(type='text', name='search', title='Search', autocomplete='off', cls='searchbar',
                      hx_get="/search", hx_trigger="input changed delay:500ms, search", hx_target="#search-results")
            )
        )
        return nav, header, Div(id="search-results")(Strong("search results")), Script(src='jquery-3.4.1.min.js')