In [7]:
import pandas as pd
import random
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
import numpy as np

Load dataset

In [3]:
# Read the movie table, discard rows whose short plot is missing, and renumber
# the surviving rows from 0 so positions in `df` are contiguous.
df = (
    pd.read_csv("movies_with_short_plot.csv")
    .dropna(subset=["short_plot"])
    .reset_index(drop=True)
)

# Plain list of plot strings, in the same order as the rows of `df`.
short_plot_column = df["short_plot"].astype(str)
texts = short_plot_column.tolist()

Build training examples

In [4]:
# One training example per plot; both entries of each pair are the same string.
train_examples = [InputExample(texts=[plot, plot]) for plot in texts]

print("Training samples:", len(train_examples))

Training samples: 508


Load base model and fine-tune on the short plots

In [6]:
# Base checkpoint that is fine-tuned on the short plots below.
model = SentenceTransformer("all-MiniLM-L6-v2")

# dataloader
train_dataloader = DataLoader(
    train_examples,
    shuffle=True,
    batch_size=16
)

# loss function
train_loss = losses.MultipleNegativesRankingLoss(model)

# Training length and warm-up.
# The warm-up is derived from the actual number of optimisation steps instead
# of being a fixed count: with the 508 samples reported above and batch size 16
# there are 32 batches per epoch, i.e. 64 steps over 2 epochs, so a fixed
# 100-step warm-up would be longer than the whole run and the learning rate
# would presumably never reach its configured peak (see the library's
# scheduler documentation).
NUM_EPOCHS = 2
WARMUP_FRACTION = 0.1
total_steps = len(train_dataloader) * NUM_EPOCHS
warmup_steps = int(WARMUP_FRACTION * total_steps)

#train
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=NUM_EPOCHS,
    warmup_steps=warmup_steps,
    output_path="short_plot_similarity_model"
)

print("Model Training Completed")

Loading weights: 100%|██████████| 103/103 [00:00<00:00, 337.42it/s, Materializing param=pooler.dense.weight]                             
[1mBertModel LOAD REPORT[0m from: all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.[0m
  super().__init__(loader)


Step,Training Loss


Writing model shards: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


Model Training Completed


In [9]:

# Encode all short plots
# The column is cast to str exactly as it was when the training texts were
# built, so the encoder receives the same kind of input it was fine-tuned on
# and never sees a non-string value.
embeddings = model.encode(
    df["short_plot"].astype(str).tolist(),
    show_progress_bar=True,
    batch_size=32
)

# Save embeddings
np.save("short_plot_embeddings.npy", embeddings)

print("✅ short_plot_embeddings.npy created successfully!")

# Write out the frame the embeddings were computed from, so the rows of the CSV
# and the rows of the .npy file are in the same order.
df.to_csv("short_plot_dataset_clean.csv", index=False)

Batches: 100%|██████████| 16/16 [00:03<00:00,  5.10it/s]


✅ short_plot_embeddings.npy created successfully!


In [1]:
import numpy as np
import pandas as pd

# Reload both saved artifacts from disk and check that they still line up.
emb = np.load("short_plot_embeddings.npy")
df = pd.read_csv("short_plot_dataset_clean.csv")

print("Embeddings:", emb.shape)
print("Rows:", len(df))

# The two files are meant to be used together row by row, so fail loudly if
# their lengths differ instead of relying on a visual comparison of the prints.
assert emb.shape[0] == len(df), (
    f"Row mismatch: {emb.shape[0]} embeddings vs {len(df)} dataset rows"
)

Embeddings: (508, 384)
Rows: 508
