In [0]:
# Restart the notebook's Python process (Databricks `dbutils` helper).
# NOTE(review): this restart runs *before* the `%pip install tensorflow` cell that
# follows. Databricks documents restartPython() as the step to run *after*
# `%pip install` so newly installed/upgraded packages are importable. Consider
# moving this call after the install cell, or confirm it is only meant to clear state.
dbutils.library.restartPython()

In [0]:
%pip install tensorflow
# Installs TensorFlow into this notebook's Python environment (needed for the
# `tensorflow` / `mlflow.tensorflow` imports in the next cell).
# NOTE(review): the version is unpinned; consider pinning it to the TensorFlow
# version used when the model in MLflow run RUN_ID was logged, so load_model
# behaves the same on every run.

In [0]:
import json, re
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from pyspark.sql import functions as F
import mlflow
import mlflow.tensorflow

# MLflow run whose logged model, vocabulary artifact and params this notebook reuses.
RUN_ID = "fe9ecda0ef3d49d0ae5b96847f975de9"

# Language model stored under the run's "model" artifact path.
lm_model = mlflow.tensorflow.load_model(f"runs:/{RUN_ID}/model")

# Token list stored as vocab.json in the same run; a token's list position is the
# integer id used both by the vectorizer and by id_to_token below.
vocab = json.loads(mlflow.artifacts.load_text(f"runs:/{RUN_ID}/vocab.json"))
VOCAB_SIZE = len(vocab)

# Rebuild the text -> id tokenizer from the saved vocabulary.
# NOTE(review): presumably these standardize/split settings match the ones used at
# training time; confirm, otherwise prompt ids will not line up with training ids.
vectorizer = layers.TextVectorization(
    output_mode="int",
    split="whitespace",
    standardize="lower_and_strip_punctuation",
    max_tokens=VOCAB_SIZE,
)
vectorizer.set_vocabulary(vocab)

# Window length logged as the run's SEQ_LEN param (params are stored as strings).
client = mlflow.tracking.MlflowClient()
params = client.get_run(RUN_ID).data.params
SEQ_LEN = int(params["SEQ_LEN"])
print("SEQ_LEN:", SEQ_LEN)

# Reverse lookup (token id -> token string) used when detokenizing generated ids.
id_to_token = np.array(vocab)

print("Loaded model. vocab_size =", VOCAB_SIZE)


In [0]:
def sample_from_logits(logits, temperature=0.8, top_k=50):
    """Sample one token id from a single vector of next-token logits.

    Args:
        logits: 1-D tensor/array of unnormalized scores, shape ``(vocab_size,)``.
        temperature: Logits are divided by ``max(temperature, 1e-6)``. Lower values
            sharpen the distribution, and very small values approach greedy
            (argmax) sampling.
        top_k: If a positive int smaller than ``vocab_size``, tokens whose scaled
            logit is below the k-th largest cannot be sampled. Ties at the cutoff
            are kept. ``None``, ``<= 0`` or ``>= vocab_size`` disables the filter.

    Returns:
        The sampled token id as a Python ``int``.
    """
    logits = tf.cast(logits, tf.float32)
    logits = logits / max(float(temperature), 1e-6)

    k = int(top_k) if top_k is not None else 0
    if 0 < k < logits.shape[-1]:
        cutoff = tf.math.top_k(logits, k=k).values[-1]
        # Mask with -inf instead of a finite sentinel like -1e10. After dividing by
        # a tiny temperature the *kept* logits can themselves fall below -1e10,
        # which would make the "masked" tokens the most likely ones.
        logits = tf.where(logits < cutoff, tf.constant(float("-inf"), logits.dtype), logits)

    # tf.random.categorical takes unnormalized log-probabilities, so the (masked)
    # logits are sampled from directly. A softmax -> log round trip adds nothing.
    return int(tf.random.categorical(logits[tf.newaxis, :], num_samples=1)[0, 0])

def detokenize(ids):
    """Turn token ids back into text, skipping the "" and "[UNK]" vocabulary entries.

    Args:
        ids: Sequence of integer token ids, used to index the module-level
            ``id_to_token`` array.

    Returns:
        The surviving token strings joined with single spaces.
    """
    skipped = ("", "[UNK]")
    return " ".join(tok for tok in id_to_token[ids] if tok not in skipped)

def generate(prompt, max_new_tokens=120, temperature=0.7, top_k=50, include_prompt=True):
    """Extend ``prompt`` by sampling ``max_new_tokens`` tokens from ``lm_model``.

    Each step passes the most recent ``SEQ_LEN`` token ids to the model. When fewer
    ids are available, the window is left-padded with id 0. The next id is sampled
    from the last-position logits with ``sample_from_logits`` and appended. There is
    no early stopping: exactly ``max_new_tokens`` ids are generated.

    Args:
        prompt: Conditioning text, tokenized with the module-level ``vectorizer``.
        max_new_tokens: Number of tokens to sample.
        temperature: Sampling temperature forwarded to ``sample_from_logits``.
        top_k: Top-k filter size forwarded to ``sample_from_logits``.
        include_prompt: When True (default, the original behavior), the returned
            text is the detokenized prompt followed by the continuation. When False,
            only the newly generated tokens are detokenized and returned.

    Returns:
        Space-joined token text produced by ``detokenize``.
    """
    # Tokenize the prompt and drop id 0 (padding/mask) entries.
    encoded = vectorizer(tf.constant([prompt]))[0]
    prompt_ids = tf.boolean_mask(encoded, encoded > 0).numpy().tolist()

    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        window = ids[-SEQ_LEN:]
        if len(window) < SEQ_LEN:
            # NOTE(review): assumes the model was trained with left zero-padding
            # for short contexts; confirm against the training pipeline.
            window = [0] * (SEQ_LEN - len(window)) + window

        x = tf.constant([window], dtype=tf.int32)
        logits = lm_model(x)  # assumed (1, SEQ_LEN, vocab_size); TODO confirm
        ids.append(sample_from_logits(logits[0, -1], temperature=temperature, top_k=top_k))

    return detokenize(ids if include_prompt else ids[len(prompt_ids):])

In [0]:
MOVIE_TABLE = "default.wiki_movie_plots_deduped"

def retrieve_movies(question: str, k: int = 5):
    """Keyword-retrieve up to ``k`` movie rows from ``MOVIE_TABLE`` for a question.

    Tokenization: the question is lower-cased and split into runs of letters/digits.
    Tokens shorter than 3 characters are dropped, and duplicates are removed with
    the first occurrence kept.

    Scoring: each row gets one point per token found as a substring of its
    lower-cased Title / Plot / Genre / Director / Cast text. Rows scoring 0 are
    discarded. The rest are ordered by score (descending), then by Title
    (ascending), so equal-score rows come back in a stable order.

    Args:
        question: Free-text question; ``None`` is treated as empty.
        k: Maximum number of rows to return.

    Returns:
        A Spark DataFrame (not collected). Scored results carry an extra ``_score``
        column. When the question yields no tokens, ``k`` unscored rows of the
        table are returned in no particular order.
    """
    q = (question or "").lower().strip()
    # [^\W_]+ matches runs of Unicode letters/digits. The previous ASCII-only
    # [a-z0-9]+ split accented words (e.g. "amélie" -> "am", "lie"). That lost the
    # real search term and added noisy substring fragments.
    tokens = [t for t in re.findall(r"[^\W_]+", q) if len(t) >= 3]
    tokens = list(dict.fromkeys(tokens))  # unique, preserve order

    df = spark.table(MOVIE_TABLE)

    if not tokens:
        # Fallback: nothing to match on, so just return some rows.
        return df.limit(k)

    # One lower-cased "haystack" string per row for simple substring checks;
    # NULL columns are coalesced to "".
    hay = F.lower(F.concat_ws(
        " ",
        *(F.coalesce(F.col(c), F.lit("")) for c in ("Title", "Plot", "Genre", "Director", "Cast")),
    ))

    # NOTE(review): common words that pass the length filter (e.g. "tell", "about")
    # match many plots and add +1 to lots of rows. Consider a stopword list if
    # ranking quality matters.
    score = sum(
        (F.when(hay.contains(t), F.lit(1)).otherwise(F.lit(0)) for t in tokens),
        F.lit(0),
    )

    return (
        df.withColumn("_score", score)
        .where(F.col("_score") > 0)
        .orderBy(F.col("_score").desc(), F.col("Title").asc())
        .limit(k)
    )


In [0]:
def format_context(rows, max_plot_chars=500):
    """Render retrieved movie rows as the plain-text CONTEXT section of a prompt.

    Each row becomes a "Label: value" block (Title, Year, Genre, Director, Cast,
    Plot). Blocks are separated by a "---" line. Missing (None) text fields render
    as empty strings; "Release Year" is written as-is.

    Args:
        rows: Iterable of mappings/Rows indexable by "Title", "Release Year",
            "Genre", "Director", "Cast" and "Plot".
        max_plot_chars: Plots longer than this are cut back to the last space
            inside the limit and suffixed with "...".

    Returns:
        The joined text with surrounding whitespace stripped ("" for no rows).
    """
    def _clean(value):
        return (value or "").strip()

    rendered = []
    for row in rows:
        fields = {
            "Title": _clean(row["Title"]),
            "Year": row["Release Year"],
            "Genre": _clean(row["Genre"]),
            "Director": _clean(row["Director"]),
            "Cast": _clean(row["Cast"]),
            "Plot": _clean(row["Plot"]),
        }
        if len(fields["Plot"]) > max_plot_chars:
            # Cut at a word boundary so no word is split, then mark the truncation.
            fields["Plot"] = fields["Plot"][:max_plot_chars].rsplit(" ", 1)[0] + "..."
        rendered.append("".join(f"{label}: {value}\n" for label, value in fields.items()))

    return "\n---\n".join(rendered).strip()


In [0]:
def answer(question: str, k: int = 5,
           max_new_tokens: int = 120,
           temperature: float = 0.7,
           top_k: int = 50):
    """Answer a movie question with retrieval-augmented generation.

    Pipeline:
    1. Collect up to ``k`` rows from ``retrieve_movies``.
    2. Render them with ``format_context``.
    3. Wrap them in an instruction / CONTEXT / QUESTION / ANSWER prompt.
    4. Pass that prompt to ``generate`` with the given sampling settings.

    Returns:
        Whatever ``generate`` returns for the prompt.

    NOTE(review): ``generate`` detokenizes the prompt ids together with the new
    ids, so the result starts with the (lower-cased, punctuation-stripped) prompt
    text. ``generate`` also feeds only the last ``SEQ_LEN`` ids to the model, so a
    prompt longer than that loses its beginning (the instructions) first. Confirm
    both are intended.
    """
    retrieved_rows = retrieve_movies(question, k=k).collect()

    prompt_parts = [
        "You are a movie assistant.\n",
        "Use only the CONTEXT to answer. If the answer is not in the context, say \"I don't know.\"\n\n",
        f"CONTEXT:\n{format_context(retrieved_rows)}\n\n",
        f"QUESTION:\n{question}\n\n",
        "ANSWER:\n",
    ]
    prompt = "".join(prompt_parts)

    return generate(
        prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
    )

# Sanity-check retrieval on its own before running the full retrieve-then-generate path.
test_rows = retrieve_movies("Kansas Saloon Smashers", k=3).collect()
print(f"Retrieved rows: {len(test_rows)}")
print(f"First title: {test_rows[0]['Title'] if test_rows else 'NONE'}")

# End-to-end run: retrieve context for the question, then sample a response.
# NOTE(review): the printed text includes the detokenized prompt, because
# generate() returns prompt ids + new ids.
q = "Tell me about Kansas Saloon Smashers."
resp = answer(
    q,
    k=5,
    max_new_tokens=120,
    temperature=0.7,
    top_k=50,
)
print(resp)

