In [0]:
%pip install tensorflow

In [0]:
import json, re
import tensorflow as tf
from tensorflow.keras import layers
from pyspark.sql import functions as F
import mlflow
import mlflow.tensorflow

# MLflow run that produced the pretrained language model and its vocabulary.
RUN_ID = "fe9ecda0ef3d49d0ae5b96847f975de9"

# Fine-tuning settings. SEQ_LEN presumably has to match the context length the
# pretrained model was built for — TODO confirm against the pretraining run.
SEQ_LEN = 128
BATCH_SIZE = 128

# 1) Load the pretrained model logged under the run's "model" artifact.
model_uri = f"runs:/{RUN_ID}/model"
lm_model = mlflow.tensorflow.load_model(model_uri)

# 2) Download the vocabulary artifact logged alongside the model.
vocab_path = mlflow.artifacts.download_artifacts(
    run_id=RUN_ID,
    artifact_path="vocab.json",
)

# JSON is UTF-8 by spec; don't depend on the platform's default encoding.
with open(vocab_path, "r", encoding="utf-8") as f:
    vocab = json.load(f)

# Fail early with a clear message instead of deep inside set_vocabulary().
if not isinstance(vocab, list) or not vocab:
    raise ValueError(
        f"vocab.json must contain a non-empty JSON list of tokens, got {type(vocab).__name__}"
    )

# 3) Rebuild the tokenizer from the saved vocabulary so token ids line up with
#    the ones the model was trained on. These standardize/split settings
#    presumably mirror the pretraining tokenizer — TODO confirm.
vectorizer = layers.TextVectorization(
    max_tokens=len(vocab),
    standardize="lower_and_strip_punctuation",
    split="whitespace",
    output_mode="int",
)
vectorizer.set_vocabulary(vocab)

print("Loaded model + vocab_size:", len(vocab))

In [0]:
import base64

lines_df = spark.table("workspace.default.movie_lines")

# Every raw movie_lines.txt record lives in the single "value" column, e.g.
#   L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.
# Fields are separated by " +++$+++ " (a regex, hence the escapes);
# field 0 is the line id and field 4 the utterance text.
line_col = "value"
parts = F.split(F.col(line_col), " \\+\\+\\+\\$\\+\\+\\+ ")

line_id_col = parts.getItem(0).alias("line_id")
text_col = parts.getItem(4).alias("text")

# Keep only records where both fields exist and the text is non-empty.
has_id_and_text = (
    F.col("line_id").isNotNull()
    & F.col("text").isNotNull()
    & (F.length(F.col("text")) > 0)
)
line_map_df = lines_df.select(line_id_col, text_col).where(has_id_and_text)

print("Lines rows:", line_map_df.count())

# ---------------------------
# 2. Load conversations
# ---------------------------
# Each raw movie_conversations.txt record looks like
#   u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', ...]
# and field 3 holds the list of line ids making up the conversation.
convs_df = spark.table("workspace.default.movie_conversations")
conv_col = "value"

conv_parts = F.split(F.col(conv_col), " \\+\\+\\+\\$\\+\\+\\+ ")
ids_field = conv_parts.getItem(3)

# Pull every "L<digits>" token out of the list literal; a conversation needs
# at least two lines to contain an exchange.
id_list = F.expr("regexp_extract_all(ids_field, 'L\\\\d+', 0)")
conv_ids_df = (
    convs_df
    .select(ids_field.alias("ids_field"))
    .withColumn("ids", id_list)
    .filter(F.size("ids") >= 2)
)

print("Conversations:", conv_ids_df.count())

# ---------------------------
# 3. Explode conversations into turns
# ---------------------------
# Maximum number of turns kept per conversation.
MAX_TURNS = 10

# Assign the conversation id BEFORE exploding. When monotonically_increasing_id()
# appears in the same select() as posexplode(), Spark's analyzer evaluates it in
# the projection above the Generate node, i.e. once per exploded row: every turn
# would get its own id and groupBy("conv_id") would never rebuild a conversation.
exploded = (conv_ids_df
    .withColumn("conv_id", F.monotonically_increasing_id())
    .select("conv_id", F.posexplode("ids").alias("pos", "line_id"))
)

# Attach the utterance text (left join: ids missing from movie_lines get text = null).
joined = exploded.join(line_map_df, on="line_id", how="left")

# Alternate roles by position within the conversation: even = User, odd = Bot.
# concat() yields null when text is null.
with_roles = joined.withColumn(
    "turn",
    F.when(F.col("pos") % 2 == 0, F.concat(F.lit("User: "), F.col("text")))
     .otherwise(F.concat(F.lit("Bot: "), F.col("text")))
)

# Keep only the first MAX_TURNS turns of each conversation.
with_roles_limited = with_roles.where(F.col("pos") < MAX_TURNS)

# ---------------------------
# 4. Build chat documents
# ---------------------------
# collect_list() gives no ordering guarantee after the groupBy shuffle, so
# collect (pos, turn) structs, sort them (by pos, the first field), and keep
# only the turn strings. Null turns (missing line text) are dropped up front.
chat_docs_df = (with_roles_limited
    .where(F.col("turn").isNotNull())
    .groupBy("conv_id")
    .agg(F.sort_array(F.collect_list(F.struct("pos", "turn"))).alias("turns"))
    .select("conv_id", F.concat_ws("\n", F.col("turns.turn")).alias("chat_doc"))
    .where(F.length("chat_doc") > 0)
)

display(chat_docs_df.limit(5))

# ---------------------------
# 5. SAFE COLLECT (no UTF-8 crash)
# ---------------------------
# Upper bound on how many documents are pulled to the driver
# (limit() without an ordering picks an arbitrary subset).
MAX_DOCS = 50000

# Ship each document as base64 of its UTF-8 bytes and decode on the driver;
# per the section title this sidesteps a UTF-8 crash seen when collecting raw
# strings. Undecodable bytes become U+FFFD instead of raising.
chat_docs_b64_df = chat_docs_df.select(
    F.base64(F.encode(F.col("chat_doc"), "UTF-8")).alias("chat_doc_b64")
)

chat_docs_b64 = [
    row["chat_doc_b64"]
    for row in chat_docs_b64_df.limit(MAX_DOCS).collect()
]

chat_docs = [
    base64.b64decode(s).decode("utf-8", errors="replace")
    for s in chat_docs_b64
]

# Everything downstream needs at least one document; fail with a clear message.
if not chat_docs:
    raise ValueError(
        "No chat documents were built; check the movie_lines / movie_conversations parsing above."
    )

print("chat_docs:", len(chat_docs))
print(chat_docs[0][:300])

In [0]:
def make_windows(token_ids):
    """Cut a token sequence into fixed-length training windows.

    Args:
        token_ids: a 1-D tensor/array of token ids, or a ``tf.data.Dataset`` whose
            elements are scalar token ids (e.g. a stream spanning many documents).

    Returns:
        A dataset of 1-D tensors of ``SEQ_LEN + 1`` ids. Consecutive windows start
        ``SEQ_LEN`` tokens apart (sharing one boundary token); a trailing partial
        window is dropped.
    """
    if not isinstance(token_ids, tf.data.Dataset):
        token_ids = tf.data.Dataset.from_tensor_slices(token_ids)
    return token_ids.window(
        SEQ_LEN + 1, shift=SEQ_LEN, drop_remainder=True
    ).flat_map(lambda w: w.batch(SEQ_LEN + 1))


def split_xy(seq):
    """Split a window into (inputs, next-token targets) by shifting one step."""
    return seq[:-1], seq[1:]


def doc_to_ids(doc):
    """Tokenize one document and drop id 0 (the vectorizer's padding id)."""
    ids = vectorizer(tf.expand_dims(doc, 0))[0]
    return tf.boolean_mask(ids, ids > 0)


def doc_to_ds(doc):
    """(inputs, targets) windows drawn from a single document.

    Documents shorter than ``SEQ_LEN + 1`` tokens yield no windows at all.
    """
    return make_windows(doc_to_ids(doc)).map(split_xy, num_parallel_calls=tf.data.AUTOTUNE)


SHUFFLE_BUFFER = 20000
SHUFFLE_SEED = 42
VAL_EXAMPLES = 2000

# Fix the document order once (reshuffle_each_iteration=False) so take()/skip()
# below always see the same sequence and the validation/training splits stay
# disjoint across .repeat() passes. With the default reshuffling, each pass draws
# a new order and validation windows leak into training.
text_ds = tf.data.Dataset.from_tensor_slices(chat_docs).shuffle(
    SHUFFLE_BUFFER, seed=SHUFFLE_SEED, reshuffle_each_iteration=False
)

# Concatenate all documents into one token stream before windowing. Windowing
# each document separately (doc_to_ds) discards every document shorter than
# SEQ_LEN + 1 tokens, which likely covers most of these short dialogues.
token_stream = text_ds.map(doc_to_ids, num_parallel_calls=tf.data.AUTOTUNE).unbatch()
lm_ds = make_windows(token_stream).map(split_xy, num_parallel_calls=tf.data.AUTOTUNE)

# Validation = the first VAL_EXAMPLES windows of the fixed-order stream;
# training = the remainder, reshuffled on every pass.
val_ds = lm_ds.take(VAL_EXAMPLES).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
train_ds = (lm_ds.skip(VAL_EXAMPLES)
    .shuffle(SHUFFLE_BUFFER, seed=SHUFFLE_SEED)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE))

# Infinite datasets; fit() bounds each epoch via steps_per_epoch / validation_steps.
train_ds_rep = train_ds.repeat()
val_ds_rep = val_ds.repeat()


In [0]:
STEPS_PER_EPOCH = 200
VAL_STEPS = 20
EPOCHS = 1

# Both datasets repeat forever, so the length of an epoch and of each
# validation pass are fixed explicitly by step counts.
fit_kwargs = dict(
    validation_data=val_ds_rep,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VAL_STEPS,
    epochs=EPOCHS,
)
history = lm_model.fit(train_ds_rep, **fit_kwargs)


In [0]:
import mlflow, mlflow.tensorflow, json

# NOTE(review): hard-coded personal workspace path; consider making this configurable.
EXPERIMENT_PATH = "/Users/desiborisovab@gmail.com/movie_chatbot_experiment"
mlflow.set_experiment(EXPERIMENT_PATH)

with mlflow.start_run(run_name="finetune_cornell_dialogue") as run:
    run_id = run.info.run_id

    # The model plus the vocabulary needed to reproduce its tokenization.
    mlflow.tensorflow.log_model(lm_model, artifact_path="model")
    mlflow.log_text(json.dumps(vocab), "vocab.json")

    mlflow.log_params({
        "finetune_dataset": "cornell",
        # Lineage: the pretrained run this model was fine-tuned from.
        "base_run_id": RUN_ID,
        "SEQ_LEN": SEQ_LEN,
        "BATCH_SIZE": BATCH_SIZE,
        "STEPS_PER_EPOCH": STEPS_PER_EPOCH,
        "VAL_STEPS": VAL_STEPS,
    })

    # Record the final value of every metric Keras tracked during fit().
    mlflow.log_metrics({
        name: float(values[-1])
        for name, values in history.history.items()
        if values
    })

print("Fine-tuned model logged. run_id =", run_id)
