# Content Segmentation and Labeling with PCA, K-Means, and Generative AI

This notebook demonstrates an approach to video content recommendations using machine learning and Generative AI. Instead of relying only on metadata tags, it groups content by the semantic similarity of each item's genres, keywords, and description.

## Key Features

- **Data Analysis:** Examines metadata from 150 recent movies, shows, and series to identify similarities and relationships between them
- **Semantic Vectorization:** Transforms each item's metadata into a dense vector embedding using a text embedding model
- **Dimensional Analysis:** Employs Principal Component Analysis (PCA), or optionally t-SNE, for dimensionality reduction and K-Means for cluster identification
- **Visual Representation:** Creates interactive 2D scatter plots to visualize content relationships and clustering patterns
- **AI-Powered Insights:** Leverages Generative AI to produce natural language descriptions of each content cluster and extract meaningful keywords

This approach enables content platforms to automatically group similar videos by semantic connections rather than superficial categorization, which can improve recommendation quality and content discovery.

## Create Text Embeddings

Loads the JSON file containing metadata for recent movies, shows, and series, then transforms each item's metadata into a dense vector embedding using a text embedding model.
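The premise of this step is that items with similar meaning end up with embeddings that point in similar directions. A minimal sketch, assuming cosine similarity as the comparison measure (the notebook itself compares items through reduced coordinates and K-Means, not cosine similarity); the vectors below are hypothetical toy values, not model output.

```python
import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine of the angle between two vectors: 1.0 means same direction, 0.0 orthogonal."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Hypothetical 4-dimensional toy vectors; the notebook requests 1024-dimensional embeddings
space_opera = np.array([0.9, 0.1, 0.0, 0.2])
alien_invasion = np.array([0.8, 0.2, 0.1, 0.3])
period_romance = np.array([0.0, 0.9, 0.8, 0.1])

print(f"space opera vs. alien invasion: {cosine_similarity(space_opera, alien_invasion):.3f}")
print(f"space opera vs. period romance: {cosine_similarity(space_opera, period_romance):.3f}")
```

The two science-fiction vectors score much closer to 1.0 than the science-fiction and romance pair, which is the kind of relationship the clustering below relies on.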

In [None]:
# Constants - Change to your preferred models
EMBEDDING_MODEL = "qwen3-embedding:4b"  # "embeddinggemma:300m"
LARGE_LANGUAGE_MODEL = "gpt-oss:20b"  # "gemma3n:e4b"

from ollama import Client
import json

client = Client(
    host="http://localhost:11434",
)

with open("content.json", "r") as f:
    contents = json.load(f)

In [None]:
def prepare_text_for_embedding(content: dict) -> str:
    """Prepares text for embedding by combining genres, keywords, and description.

    Args:
        content (dict): A single content item with "genres", "keywords", and "description" keys

    Returns:
        str: Combined text for embedding.
    """
    # sorted() returns new lists, so the caller's content dict is not modified
    genres_as_string = ", ".join(str(genre) for genre in sorted(content["genres"])).lower()
    keywords_as_string = ", ".join(
        str(keyword) for keyword in sorted(content["keywords"])
    ).lower()
    text_to_embed = (
        f"{genres_as_string}, {keywords_as_string}, {content['description']}"
    )

    return text_to_embed

In [None]:
def create_embedding(client: Client, text_to_embed: str) -> list:
    """Creates an embedding for the given text.

    Args:
        client (Client): Ollama client used to call the embedding model
        text_to_embed (str): Text to create an embedding for
    Returns:
        list: Dense vector embedding - list of floats
    """
    response = client.embed(model=EMBEDDING_MODEL, input=text_to_embed, dimensions=1024)
    embedding = response["embeddings"][0]

    return embedding

In [None]:
%%time

for content in contents:
    text_to_embed = prepare_text_for_embedding(content)
    print(
        f"Creating embedding for {content['id'] + 1:>3}: {content['title']:<50}",
        end="\r",
        flush=True,
    )
    embedding = create_embedding(client, text_to_embed)
    content["embedding"] = embedding

## PCA or t-SNE

PCA (Principal Component Analysis) is a linear method that reduces dimensions by finding the directions of greatest variance in the data. These directions, the principal components, are new, uncorrelated variables, and projecting onto them preserves global structure and as much of the overall variance as possible. t-SNE (t-distributed Stochastic Neighbor Embedding) is a non-linear method that places points in a low-dimensional space, usually 2D, according to their local neighborhood proximity. That makes it well suited to visualizing clusters, but the distances between clusters and the apparent cluster sizes in a t-SNE plot do not reliably reflect the original data.
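To make the PCA description concrete, here is a minimal NumPy sketch, assuming the common formulation of centering the data and taking a singular value decomposition (scikit-learn's `PCA` also centers the data before decomposing it, though its component signs may differ from this sketch). `pca_2d` and the toy data are hypothetical.

```python
import numpy as np


def pca_2d(embeddings: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Project rows onto the top two principal components.

    Returns the 2D coordinates and the explained variance ratio of each component.
    """
    # Center each feature; principal directions are defined relative to the mean
    centered = embeddings - embeddings.mean(axis=0)
    # Rows of vt are the principal axes, ordered by singular value (largest first)
    _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
    coords = centered @ vt[:2].T
    # Variance along each axis is proportional to the squared singular value
    variance = singular_values**2
    explained_ratio = variance[:2] / variance.sum()
    return coords, explained_ratio


# Toy data: 6 points in 3D that vary almost entirely along one direction
rng = np.random.default_rng(0)
t = np.linspace(-1.0, 1.0, 6)
toy = np.column_stack([t, 2 * t, -t]) + 0.01 * rng.standard_normal((6, 3))
coords, ratio = pca_2d(toy)
print(coords.shape, ratio)
```

Because the toy points lie close to a single line, nearly all of the variance lands on the first component, which is what `pca.explained_variance_ratio_` reports for the real embeddings.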

**Run one method or the other: PCA or t-SNE.**

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import numpy as np

In [None]:
def perform_pca(contents: list) -> np.ndarray:
    """Perform PCA to reduce dimensionality of embeddings to 2D

    Args:
        contents (list): content dicts, each with an "embedding" key

    Returns:
        reduced_dims (np.ndarray): array of shape (n_items, 2)
    """

    # Stack the embeddings into an (n_items, embedding_dim) array
    embeddings_array = np.array([content["embedding"] for content in contents])

    # Reduce dimensions from high-D to 2D
    pca = PCA(n_components=2, random_state=42)
    reduced_dims = pca.fit_transform(embeddings_array)

    print(
        f"Original shape: {embeddings_array.shape}, Reduced shape: {reduced_dims.shape}"
    )
    print(f"Explained variance ratio: {pca.explained_variance_ratio_}")

    return reduced_dims

In [None]:
def perform_tsne(contents: list) -> np.ndarray:
    """Perform t-SNE to reduce dimensionality of embeddings to 2D

    Args:
        contents (list): content dicts, each with an "embedding" key

    Returns:
        reduced_dims (np.ndarray): array of shape (n_items, 2)
    """
    # Stack the embeddings into an (n_items, embedding_dim) array
    embeddings_array = np.array([content["embedding"] for content in contents])

    # Reduce dimensions from high-D to 2D
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    reduced_dims = tsne.fit_transform(embeddings_array)

    print(
        f"Original shape: {embeddings_array.shape}, Reduced shape: {reduced_dims.shape}"
    )
    print(f"Final KL Divergence: {tsne.kl_divergence_}")

    return reduced_dims

In [None]:
reduced_dims = perform_pca(contents)  # perform_tsne(contents)

## K-Means Clustering

K-means clustering is an unsupervised machine learning algorithm that partitions unlabeled data into a user-defined number of clusters (k) by minimizing the sum of squared distances between data points and their assigned cluster centroids, which are the mean values of the points within each cluster. The algorithm alternates between assigning each data point to its nearest centroid and recalculating each centroid from its assigned points, until the cluster assignments stabilize. Reference: https://en.wikipedia.org/wiki/K-means_clustering

### Elbow Method: Determining Optimal Number of Clusters

Manual Inertia Calculation: the within-cluster sum of squares (WCSS), also called inertia, is the sum of squared distances from each point to its assigned cluster centroid. To compute it for a range of cluster counts, fit a K-Means model for each value of k (the number of clusters) and read the fitted model's `.inertia_` attribute. Plotting inertia against k is the basis of the elbow method for choosing the number of clusters.
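The definition of WCSS can be checked directly. A minimal sketch with a hypothetical `wcss` helper: for a converged scikit-learn fit, `wcss(reduced_dims, kmeans.labels_, kmeans.cluster_centers_)` should closely match `kmeans.inertia_`, assuming no sample weights are used.

```python
import numpy as np


def wcss(points: np.ndarray, labels: np.ndarray, centers: np.ndarray) -> float:
    """Within-cluster sum of squares: each point's squared distance to its own center, summed."""
    # centers[labels] lines up each point with the center of the cluster it belongs to
    diffs = points - centers[labels]
    return float((diffs**2).sum())


points = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [10.0, 12.0]])
labels = np.array([0, 0, 1, 1])
centers = np.array([[1.0, 0.0], [10.0, 11.0]])
print(wcss(points, labels, centers))  # 1 + 1 + 1 + 1 = 4.0
```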

In [None]:
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

# Adjust the range as needed
MIN_CLUSTERS = 5
MAX_CLUSTERS = 15

# Calculate inertia for different K values
inertias = []
k_range = range(MIN_CLUSTERS, MAX_CLUSTERS)  # MAX_CLUSTERS itself is excluded
for k in k_range:
    kmeans = KMeans(n_clusters=k, random_state=0)
    kmeans.fit(reduced_dims)
    inertias.append(kmeans.inertia_)

# Plot the elbow curve
plt.figure(figsize=(10, 6))
plt.plot(k_range, inertias, "bo-")
plt.title("Elbow Method for Optimal K")
plt.xlabel("Number of Clusters (K)")
plt.ylabel("Inertia (WCSS)")
plt.grid(True)
plt.show()

Using Yellowbrick's `KElbowVisualizer`: the visualizer implements the “elbow” method by fitting the model over a range of values for k. If the resulting line chart resembles an arm, the “elbow” (the point of inflection on the curve) is a good indication of the k at which the underlying model fits best. The visualizer marks the detected elbow with a dashed line.
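Yellowbrick locates the elbow automatically. As a rough illustration of the idea, here is a simpler heuristic, not Yellowbrick's own detector: normalize both axes, draw a straight line from the first point of the curve to the last, and pick the k whose point lies farthest from that line. `find_elbow` and the inertia values are hypothetical; with the notebook's variables it could be called as `find_elbow(list(k_range), inertias)`.

```python
import numpy as np


def find_elbow(ks, inertias) -> int:
    """Return the k whose point lies farthest from the line joining the curve's endpoints."""
    x = np.asarray(ks, dtype=float)
    y = np.asarray(inertias, dtype=float)
    # Rescale both axes to [0, 1] so k and inertia contribute on comparable scales
    x_norm = (x - x.min()) / ((x.max() - x.min()) or 1.0)
    y_norm = (y - y.min()) / ((y.max() - y.min()) or 1.0)
    # Chord from the first point to the last point of the normalized curve
    start = np.array([x_norm[0], y_norm[0]])
    direction = np.array([x_norm[-1], y_norm[-1]]) - start
    # Perpendicular distance to the chord: 2D cross product divided by chord length
    distances = np.abs(direction[0] * (y_norm - start[1]) - direction[1] * (x_norm - start[0]))
    distances /= np.linalg.norm(direction)
    return int(x[distances.argmax()])


# Hypothetical inertia values that drop steeply and then flatten out
print(find_elbow([1, 2, 3, 4, 5], [100.0, 40.0, 20.0, 15.0, 12.0]))  # 2
```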

In [None]:
# https://www.scikit-yb.org/en/latest/api/cluster/elbow.html

from yellowbrick.cluster import KElbowVisualizer

# Instantiate the clustering model and visualizer
model = KMeans(random_state=0)  # Fixed seed, matching the manual elbow loop above
visualizer = KElbowVisualizer(model, k=(MIN_CLUSTERS, MAX_CLUSTERS))

visualizer.fit(reduced_dims)  # Fit the data to the visualizer
visualizer.show()  # Finalize and render the figure

In [None]:
def kmeans_clustering(
    reduced_dims: np.ndarray, CLUSTER_COUNT: int = 10
) -> tuple[np.ndarray, np.ndarray]:
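    """Perform K-Means clustering on reduced dimensions

    Args:
        reduced_dims (np.ndarray): 2D array of reduced dimensions
        CLUSTER_COUNT (int): number of clusters (k) to create

    Returns:
        labels (np.ndarray): array of KMeans cluster labels, one per item
        centers (np.ndarray): coordinates of the cluster centroids
    """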
    kmeans = KMeans(n_clusters=CLUSTER_COUNT, random_state=0)
    kmeans.fit(reduced_dims)
    labels = kmeans.labels_
    centers = kmeans.cluster_centers_

    print(f"Created {CLUSTER_COUNT} clusters with sizes:")
    for cluster in range(CLUSTER_COUNT):
        print(f"Cluster {cluster}: {np.sum(labels == cluster)} items")

    return labels, centers

In [None]:
# Apply K-Means clustering
CLUSTER_COUNT = 12
labels, centers = kmeans_clustering(reduced_dims, CLUSTER_COUNT=CLUSTER_COUNT)

In [None]:
import plotly.graph_objects as go
import pandas as pd

# Count items per cluster and sort by cluster ID
cluster_counts = pd.Series(labels).value_counts().sort_index()

# Create custom labels with format "Cluster X count: Y"
custom_labels = [
    f"Cluster {cluster} count: {count}" for cluster, count in cluster_counts.items()
]

# Create pie chart with explicit ordering
fig = go.Figure(
    data=[
        go.Pie(
            labels=custom_labels,
            values=cluster_counts.values,
            sort=False,  # Disable automatic sorting
        )
    ]
)

fig.update_layout(
    title="Content Distribution Across Clusters",
    width=800,  # Width in pixels
    height=600,  # Height in pixels
    showlegend=True,  # Show the legend
)
fig.update_legends(title="Clusters")
fig.show()

### D3 Animated Pie Chart

This cell builds an animated D3.js pie chart from the current `cluster_counts` in the Python kernel. Running the code cell writes a standalone HTML file, `d3_pie_chart.html`, and opens it in your default browser, where the SVG slices animate into place and are labeled in the format `Cluster X: Y items`.

Customizable parameters (in the JavaScript section of the HTML template):
- width / height: SVG size in pixels
- animation duration: in milliseconds
- colors: change the D3 color scale

The chart loads D3 from a CDN. If you are working offline, download d3 locally and update the script `src` to a local path.
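`d3.pie()` turns the counts into start and end angles that `d3.arc()` then draws. This Python sketch mirrors that calculation under D3's defaults (no padding, full circle, angles in radians measured clockwise from 12 o'clock), assuming the counts are already in the desired slice order; `pie_angles` is a hypothetical name. It also shows the rule the chart uses to put a label on the right or left side.

```python
import math


def pie_angles(counts: list[int]) -> list[tuple[float, float]]:
    """Start and end angle of each slice, each slice's sweep proportional to its count."""
    total = sum(counts)
    angles = []
    start = 0.0
    for count in counts:
        end = start + 2 * math.pi * count / total
        angles.append((start, end))
        start = end
    return angles


for start, end in pie_angles([1, 1, 2]):
    mid = (start + end) / 2
    # Same rule as the chart's labels: a midpoint in the first half-turn goes on the right
    side = "right" if mid < math.pi else "left"
    print(f"{start:.3f} -> {end:.3f} ({side})")
```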

In [None]:
# D3 Animated Pie Chart - Save to HTML and open in browser
import json
import webbrowser
import os

# Build a JSON serializable version of cluster counts
cluster_counts = pd.Series(labels).value_counts().sort_index()
data = [
    {"label": int(cluster), "count": int(count)}
    for cluster, count in cluster_counts.items()
]

# Create complete HTML document
html_content = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>D3 Animated Pie Chart - Content Distribution</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
            margin: 0;
            background: #f5f5f5;
        }}
        #container {{
            background: white;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 2px 8px rgba(0,0,0,0.1);
        }}
        h1 {{
            text-align: center;
            color: #333;
            margin-top: 0;
        }}
    </style>
</head>
<body>
    <div id="container">
        <h1>Content Distribution Across Clusters</h1>
        <div id="d3-pie-container"></div>
    </div>
    
    <script src="https://d3js.org/d3.v7.min.js"></script>
    <script>
        const data = {json.dumps(data)};
        const width = 900;
        const height = 600;
        const radius = Math.min(width, height) / 2 - 60;

        const svg = d3.select('#d3-pie-container')
            .append('svg')
            .attr('width', width)
            .attr('height', height)
            .append('g')
            .attr('transform', 'translate(' + width/2 + ',' + height/2 + ')');

        const color = d3.scaleOrdinal()
            .domain(data.map(d => d.label))
            .range(d3.schemeTableau10);

        const pie = d3.pie()
            .value(d => d.count)
            .sort((a,b) => a.label - b.label);

        const arc = d3.arc()
            .innerRadius(0)
            .outerRadius(radius);

        const outerArc = d3.arc()
            .innerRadius(radius * 1.1)
            .outerRadius(radius * 1.1);

        // Create tooltip
        const tooltip = d3.select('body')
            .append('div')
            .style('position', 'absolute')
            .style('background', 'rgba(0, 0, 0, 0.8)')
            .style('color', 'white')
            .style('padding', '8px 12px')
            .style('border-radius', '4px')
            .style('font-size', '14px')
            .style('pointer-events', 'none')
            .style('opacity', 0);

        const arcs = svg.selectAll('.arc')
            .data(pie(data))
            .enter()
            .append('g')
            .attr('class', 'arc');

        // Animate slices
        arcs.append('path')
            .attr('fill', d => color(d.data.label))
            .attr('stroke', 'white')
            .attr('stroke-width', 2)
            .style('cursor', 'pointer')
            .transition()
            .duration(1000)
            .attrTween('d', function(d) {{
                const i = d3.interpolate({{startAngle: 0, endAngle: 0}}, d);
                return function(t) {{ return arc(i(t)); }};
            }});

        // Add hover effects
        arcs.selectAll('path')
            .on('mouseover', function(event, d) {{
                d3.select(this)
                    .transition()
                    .duration(200)
                    .attr('opacity', 0.7)
                    .attr('transform', function() {{
                        const centroid = arc.centroid(d);
                        return 'translate(' + centroid[0] * 0.05 + ',' + centroid[1] * 0.05 + ')';
                    }});
                
                tooltip
                    .style('opacity', 1)
                    .html('<strong>Cluster ' + d.data.label + '</strong><br/>Count: ' + d.data.count)
                    .style('left', (event.pageX + 10) + 'px')
                    .style('top', (event.pageY - 10) + 'px');
            }})
            .on('mouseout', function(event, d) {{
                d3.select(this)
                    .transition()
                    .duration(200)
                    .attr('opacity', 1)
                    .attr('transform', 'translate(0,0)');
                
                tooltip.style('opacity', 0);
            }});

        // Add labels with polylines
        arcs.append('text')
            .attr('transform', function(d) {{
                const pos = outerArc.centroid(d);
                const midAngle = d.startAngle + (d.endAngle - d.startAngle) / 2;
                pos[0] = radius * 1.25 * (midAngle < Math.PI ? 1 : -1);
                return 'translate(' + pos + ')';
            }})
            .attr('text-anchor', function(d) {{
                const midAngle = d.startAngle + (d.endAngle - d.startAngle) / 2;
                return midAngle < Math.PI ? 'start' : 'end';
            }})
            .style('font-size', '12px')
            .style('fill', '#333')
            .style('opacity', 0)
            .text(d => 'Cluster ' + d.data.label + ': ' + d.data.count + ' items')
            .transition()
            .delay(1000)
            .duration(500)
            .style('opacity', 1);

        // Add polylines
        arcs.append('polyline')
            .attr('stroke', '#999')
            .attr('stroke-width', 1)
            .attr('fill', 'none')
            .style('opacity', 0)
            .attr('points', function(d) {{
                const pos = outerArc.centroid(d);
                const midAngle = d.startAngle + (d.endAngle - d.startAngle) / 2;
                pos[0] = radius * 1.25 * (midAngle < Math.PI ? 1 : -1);
                return [arc.centroid(d), outerArc.centroid(d), pos];
            }})
            .transition()
            .delay(1000)
            .duration(500)
            .style('opacity', 0.5);
    </script>
</body>
</html>
"""

# Save to file
output_file = "d3_pie_chart.html"
with open(output_file, "w") as f:
    f.write(html_content)

# Open in browser
file_path = os.path.abspath(output_file)
webbrowser.open("file://" + file_path)

print(f"✓ Chart saved to: {output_file}")
print("✓ Opening in your default browser...")

### Show Clustered Content

In [None]:
# Attach each item's cluster ID, then list the titles in each cluster
for idx, label in enumerate(labels):
    contents[idx]["cluster_id"] = int(label)

contents_sorted = sorted(
    contents, key=lambda content: (content["cluster_id"], content["title"])
)

for cluster in range(CLUSTER_COUNT):
    print(f"\nCluster {cluster}:")
    for content in contents_sorted:
        if content["cluster_id"] == cluster:
            # Single quotes inside the f-string keep this valid before Python 3.12
            print(f"  - {content['title']}")

In [None]:
# Save extended content with embeddings and cluster IDs
with open("content_extended.json", "w") as f:
    json.dump(contents, f)

## 2D Scatter Plot

Visualize the two reduced dimensions (the principal components, when PCA is used) as x, y coordinates in a 2D scatter plot, and outline each cluster with a circle.
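The cluster outlines are simple circles: each one is centered on the mean of the cluster's points, with a radius equal to the distance to its farthest member. A standalone sketch of that computation with a hypothetical `cluster_circles` helper. Note the design trade-off: such a circle always encloses every member but is not the smallest enclosing circle, and a single outlier can inflate it considerably.

```python
import numpy as np


def cluster_circles(points: np.ndarray, labels: np.ndarray) -> dict[int, tuple[np.ndarray, float]]:
    """Centroid and radius (distance to the farthest member) for each cluster label."""
    circles = {}
    for cluster in np.unique(labels):
        members = points[labels == cluster]
        center = members.mean(axis=0)
        # The farthest member sets the radius, so every member lies inside the circle
        radius = float(np.linalg.norm(members - center, axis=1).max())
        circles[int(cluster)] = (center, radius)
    return circles


points = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [10.0, 14.0]])
labels = np.array([0, 0, 1, 1])
for cluster, (center, radius) in cluster_circles(points, labels).items():
    print(cluster, center, radius)
```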

In [None]:
import pandas as pd
import plotly.express as px
import plotly.colors as colors

# Create a mapping from cluster IDs to colors
unique_clusters = np.unique(labels)
colorscale = colors.qualitative.Dark24  # 24 distinct colors, so up to 24 clusters stay unique
cluster_colors = {}
for idx, cluster in enumerate(unique_clusters):
    cluster_colors[cluster] = colorscale[idx % len(colorscale)]

# Create a DataFrame for plotting
plot_data = pd.DataFrame(
    {
        "x": [point[0] for point in reduced_dims],
        "y": [point[1] for point in reduced_dims],
        "title": [content["title"] for content in contents],
        "cluster": labels,
        "color": [cluster_colors[label] for label in labels],
    }
)
plot_data = plot_data.sort_values(by="cluster")

fig = px.scatter(
    plot_data,
    x="x",
    y="y",
    hover_name="title",  # Show title on hover
    hover_data={"color": False, "cluster": True},  # Show cluster ID but hide color
    color="color",  # Color by cluster
    color_discrete_map="identity",  # Use the hex values in "color" as-is, matching the circles
    title="Content Clustering using PCA and K-Means",
    labels={"x": "PCA Dimension 1", "y": "PCA Dimension 2"},
    width=900,
    height=600,
    opacity=0.8,
    # animation_frame="cluster",
)

fig.update_layout(
    xaxis=dict(
        gridcolor="lightgray",
        griddash="dot",
        gridwidth=1,
        zeroline=True,
        zerolinecolor="lightgray",
        zerolinewidth=1,
        scaleanchor="y",
        scaleratio=1,
        showgrid=True,
        dtick=0.25,
    ),
    yaxis=dict(
        gridcolor="lightgray",
        griddash="dot",
        gridwidth=1,
        zeroline=True,
        zerolinecolor="lightgray",
        zerolinewidth=1,
        scaleanchor="x",
        scaleratio=1,
        showgrid=True,
        dtick=0.25,
    ),
    plot_bgcolor="White",
)

# Add circle boundaries for each cluster
for cluster in unique_clusters:
    # Get points in this cluster
    cluster_points = reduced_dims[labels == cluster]

    # Calculate the center of the cluster
    center = cluster_points.mean(axis=0)

    # Calculate the radius (maximum distance from center to any point)
    radius = np.linalg.norm(cluster_points - center, axis=1).max()
    # Add a circle shape
    fig.add_shape(
        type="circle",
        x0=center[0] - radius,
        y0=center[1] - radius,
        x1=center[0] + radius,
        y1=center[1] + radius,
        line=dict(color="#555", width=1),
        fillcolor=cluster_colors[cluster],
        opacity=0.08,
    )

# fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 1000

fig.update_traces(showlegend=False)
fig.show()

### D3 Animated Scatter Plot

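Interactive D3.js scatter plot with points that animate from the center of the plot area to their final positions, animated cluster boundaries, hover tooltips, and zoom/pan. Like the pie chart, the cell writes a standalone HTML file, `d3_scatter_plot.html`, and opens it in your default browser.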
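The D3 plot maps data coordinates to pixels with `d3.scaleLinear`, inverting the y range so larger values are drawn higher, and converts each cluster radius to pixels by taking the larger of its x and y conversions. Because the two axes generally have different pixels per data unit, taking the larger value keeps the drawn circle enclosing its points. A Python sketch of that arithmetic, assuming hypothetical padded data extents; `linear_scale` and `pixel_radius` are illustrative names, not D3 APIs.

```python
def linear_scale(domain: tuple[float, float], output_range: tuple[float, float]):
    """Return a function mapping domain values linearly onto output_range (like d3.scaleLinear)."""
    d0, d1 = domain
    r0, r1 = output_range

    def scale(value: float) -> float:
        return r0 + (value - d0) / (d1 - d0) * (r1 - r0)

    return scale


# Plot area from the chart's margins: 900 - 60 - 20 wide, 600 - 20 - 50 tall
width, height = 820, 530
# Hypothetical padded data extents
x_scale = linear_scale((-0.6, 0.6), (0, width))
y_scale = linear_scale((-0.5, 0.5), (height, 0))  # Inverted: larger y is drawn higher up


def pixel_radius(center_x: float, center_y: float, radius: float) -> float:
    """Radius in pixels; the larger per-axis value keeps the circle enclosing its points."""
    return max(
        abs(x_scale(center_x + radius) - x_scale(center_x)),
        abs(y_scale(center_y + radius) - y_scale(center_y)),
    )


print(x_scale(0.0), y_scale(0.0), pixel_radius(0.0, 0.0, 0.1))
```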

In [None]:
# D3 Animated Scatter Plot - Save to HTML and open in browser
import json
import webbrowser
import os

# Prepare data for D3
scatter_data = []
for idx in range(len(reduced_dims)):
    scatter_data.append(
        {
            "x": float(reduced_dims[idx][0]),
            "y": float(reduced_dims[idx][1]),
            "cluster": int(labels[idx]),
            "title": contents[idx]["title"],
        }
    )

# Calculate cluster boundaries
cluster_boundaries = []
for cluster in unique_clusters:
    cluster_points = reduced_dims[labels == cluster]
    center = cluster_points.mean(axis=0)
    radius = np.linalg.norm(cluster_points - center, axis=1).max()
    cluster_boundaries.append(
        {
            "cluster": int(cluster),
            "centerX": float(center[0]),
            "centerY": float(center[1]),
            "radius": float(radius),
            "color": cluster_colors[cluster],
        }
    )

# Create complete HTML document
html_content = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>D3 Animated Scatter Plot - Content Clustering</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
            margin: 0;
            background: #f5f5f5;
        }}
        #container {{
            background: white;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 2px 8px rgba(0,0,0,0.1);
        }}
        h1 {{
            text-align: center;
            color: #333;
            margin-top: 0;
            margin-bottom: 10px;
        }}
        .controls {{
            text-align: center;
            margin-bottom: 15px;
        }}
        button {{
            padding: 8px 16px;
            margin: 0 5px;
            background: #007bff;
            color: white;
            border: none;
            border-radius: 4px;
            cursor: pointer;
            font-size: 14px;
        }}
        button:hover {{
            background: #0056b3;
        }}
        .tooltip {{
            position: absolute;
            background: rgba(0, 0, 0, 0.85);
            color: white;
            padding: 10px 14px;
            border-radius: 4px;
            font-size: 13px;
            pointer-events: none;
            opacity: 0;
            transition: opacity 0.2s;
            max-width: 300px;
        }}
    </style>
</head>
<body>
    <div id="container">
        <h1>Content Clustering using PCA and K-Means</h1>
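        <div class="controls">
            <button onclick="resetZoom()">Reset Zoom</button>
            <button onclick="replayAnimation()">Replay Animation</button>
        </div>
        <div id="d3-scatter-container"></div>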
    </div>
    <script src="https://d3js.org/d3.v7.min.js"></script>
    <script>
        const data = {json.dumps(scatter_data)};
        const boundaries = {json.dumps(cluster_boundaries)};
        
        const margin = {{top: 20, right: 20, bottom: 50, left: 60}};
        const width = 900 - margin.left - margin.right;
        const height = 600 - margin.top - margin.bottom;

        // Create SVG
        const svg = d3.select('#d3-scatter-container')
            .append('svg')
            .attr('width', width + margin.left + margin.right)
            .attr('height', height + margin.top + margin.bottom);

        // Add zoom behavior
        const zoom = d3.zoom()
            .scaleExtent([0.5, 10])
            .on('zoom', zoomed);

        svg.call(zoom);

        const g = svg.append('g')
            .attr('transform', 'translate(' + margin.left + ',' + margin.top + ')');

        // Scales
        const xExtent = d3.extent(data, d => d.x);
        const yExtent = d3.extent(data, d => d.y);
        const padding = 0.1;

        const xScale = d3.scaleLinear()
            .domain([xExtent[0] - padding, xExtent[1] + padding])
            .range([0, width]);

        const yScale = d3.scaleLinear()
            .domain([yExtent[0] - padding, yExtent[1] + padding])
            .range([height, 0]);

        // Create clip path
        g.append('defs')
            .append('clipPath')
            .attr('id', 'clip')
            .append('rect')
            .attr('width', width)
            .attr('height', height);

        // Axes
        const xAxis = d3.axisBottom(xScale).ticks(10);
        const yAxis = d3.axisLeft(yScale).ticks(10);

        const xAxisG = g.append('g')
            .attr('class', 'x-axis')
            .attr('transform', 'translate(0,' + height + ')')
            .call(xAxis);

        const yAxisG = g.append('g')
            .attr('class', 'y-axis')
            .call(yAxis);

        // Axis labels
        g.append('text')
            .attr('x', width / 2)
            .attr('y', height + 40)
            .attr('text-anchor', 'middle')
            .style('font-size', '14px')
            .text('PCA Dimension 1');

        g.append('text')
            .attr('transform', 'rotate(-90)')
            .attr('x', -height / 2)
            .attr('y', -45)
            .attr('text-anchor', 'middle')
            .style('font-size', '14px')
            .text('PCA Dimension 2');

        // Grid
        g.append('g')
            .attr('class', 'grid')
            .attr('transform', 'translate(0,' + height + ')')
            .call(d3.axisBottom(xScale).ticks(10).tickSize(-height).tickFormat(''))
            .style('stroke', 'lightgray')
            .style('stroke-dasharray', '2,2')
            .style('opacity', 0.5);

        g.append('g')
            .attr('class', 'grid')
            .call(d3.axisLeft(yScale).ticks(10).tickSize(-width).tickFormat(''))
            .style('stroke', 'lightgray')
            .style('stroke-dasharray', '2,2')
            .style('opacity', 0.5);

        // Clipped content group
        const content = g.append('g')
            .attr('clip-path', 'url(#clip)');

        // Draw cluster boundaries
        const boundaryGroup = content.append('g').attr('class', 'boundaries');
        
        boundaries.forEach(b => {{
            boundaryGroup.append('circle')
                .attr('cx', xScale(b.centerX))
                .attr('cy', yScale(b.centerY))
                .attr('r', 0)
                .attr('fill', b.color)
                .attr('opacity', 0.08)
                .attr('stroke', '#555')
                .attr('stroke-width', 1)
                .transition()
                .duration(1500)
                .delay(500)
                .attr('r', Math.max(
                    Math.abs(xScale(b.centerX + b.radius) - xScale(b.centerX)),
                    Math.abs(yScale(b.centerY + b.radius) - yScale(b.centerY))
                ));
        }});

        // Create tooltip
        const tooltip = d3.select('body')
            .append('div')
            .attr('class', 'tooltip');

        // Create color mapping from boundaries data
        const clusterColorMap = {{}};
        boundaries.forEach(b => {{
            clusterColorMap[b.cluster] = b.color;
        }});

        // Draw points
        const points = content.append('g')
            .attr('class', 'points')
            .selectAll('circle')
            .data(data)
            .enter()
            .append('circle')
            .attr('cx', width / 2)
            .attr('cy', height / 2)
            .attr('r', 0)
            .attr('fill', d => clusterColorMap[d.cluster])
            .attr('stroke', 'white')
            .attr('stroke-width', 1.5)
            .style('cursor', 'pointer')
            .style('opacity', 0.8);

        // Animate points from center to final positions
        points.transition()
            .duration(2000)
            .delay((d, i) => i * 10)
            .attr('cx', d => xScale(d.x))
            .attr('cy', d => yScale(d.y))
            .attr('r', 4);

        // Hover effects
        points.on('mouseover', function(event, d) {{
            d3.select(this)
                .transition()
                .duration(200)
                .attr('r', 8)
                .style('opacity', 1);
            
            tooltip
                .style('opacity', 1)
                .html('<strong>' + d.title + '</strong><br/>Cluster: ' + d.cluster + '<br/>Position: (' + d.x.toFixed(2) + ', ' + d.y.toFixed(2) + ')')
                .style('left', (event.pageX + 10) + 'px')
                .style('top', (event.pageY - 10) + 'px');
        }})
        .on('mouseout', function(event, d) {{
            d3.select(this)
                .transition()
                .duration(200)
                .attr('r', 4)
                .style('opacity', 0.8);
            
            tooltip.style('opacity', 0);
        }});

        // Zoom function
        function zoomed(event) {{
            const transform = event.transform;
            
            // Update scales
            const newXScale = transform.rescaleX(xScale);
            const newYScale = transform.rescaleY(yScale);
            
            // Update axes
            xAxisG.call(d3.axisBottom(newXScale));
            yAxisG.call(d3.axisLeft(newYScale));
            
            // Update points
            points
                .attr('cx', d => newXScale(d.x))
                .attr('cy', d => newYScale(d.y));
            
            // Update boundaries
            boundaryGroup.selectAll('circle')
                .data(boundaries)
                .attr('cx', d => newXScale(d.centerX))
                .attr('cy', d => newYScale(d.centerY))
                .attr('r', d => Math.max(
                    Math.abs(newXScale(d.centerX + d.radius) - newXScale(d.centerX)),
                    Math.abs(newYScale(d.centerY + d.radius) - newYScale(d.centerY))
                ));
        }}

        // Reset zoom function
        function resetZoom() {{
            svg.transition()
                .duration(750)
                .call(zoom.transform, d3.zoomIdentity);
        }}

        // Replay animation function
        function replayAnimation() {{
            points
                .attr('cx', width / 2)
                .attr('cy', height / 2)
                .attr('r', 0)
                .transition()
                .duration(2000)
                .delay((d, i) => i * 10)
                .attr('cx', d => xScale(d.x))
                .attr('cy', d => yScale(d.y))
                .attr('r', 4);

            // Bind the boundary data so each radius can be recomputed from it
            boundaryGroup.selectAll('circle')
                .data(boundaries)
                .attr('r', 0)
                .transition()
                .duration(1500)
                .delay(500)
                .attr('r', d => Math.max(
                    Math.abs(xScale(d.centerX + d.radius) - xScale(d.centerX)),
                    Math.abs(yScale(d.centerY + d.radius) - yScale(d.centerY))
                ));
        }}
    </script>
</body>
</html>
"""

# Save to file
output_file = "d3_scatter_plot.html"
with open(output_file, "w") as f:
    f.write(html_content)

# Open in browser
file_path = os.path.abspath(output_file)
webbrowser.open("file://" + file_path)

print(f"✓ Scatter plot saved to: {output_file}")
print("✓ Opening in your default browser...")
print(f"✓ Data points: {len(scatter_data)}")
print(f"✓ Clusters: {len(cluster_boundaries)}")

## Use Generative AI to Describe the Clusters

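Use the local language model (`LARGE_LANGUAGE_MODEL`) to write a concise, natural language description of the themes in each cluster, based on the metadata of the items it contains.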
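`prepare_cluster_data` scans `labels` with `np.where` once per cluster. As an alternative sketch, items can be grouped by cluster in a single pass with `collections.defaultdict`; this hypothetical `group_by_cluster` collects only titles for brevity, whereas the notebook builds each item's text with `prepare_text_for_embedding`.

```python
from collections import defaultdict


def group_by_cluster(items: list[dict], labels) -> dict[int, list[str]]:
    """Map each cluster ID to the titles of its items, preserving their original order."""
    groups: dict[int, list[str]] = defaultdict(list)
    for item, label in zip(items, labels):
        groups[int(label)].append(item["title"])
    return dict(groups)


# Hypothetical items and cluster labels
items = [{"title": "Alpha"}, {"title": "Bravo"}, {"title": "Charlie"}]
print(group_by_cluster(items, [1, 0, 1]))
```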

In [None]:
def prepare_cluster_data(cluster_id: int, labels: np.ndarray, contents: list) -> list:
    """
    Prepare data from a specific cluster for analysis

    Args:
        cluster_id: The ID of the cluster to analyze
        labels: Array of cluster labels for each content item
        contents: List of content dicts (title, genres, keywords, description)

    Returns:
        List of strings, one per content item in the specified cluster
    """
    # Get indices of items in this cluster
    cluster_indices = np.where(labels == cluster_id)[0]

    # Combine each item's title with its genres, keywords, and description,
    # since the prompt asks the model to consider titles as well
    cluster_items = []
    for idx in cluster_indices:
        text = f"{contents[idx]['title']}: {prepare_text_for_embedding(contents[idx])}"
        cluster_items.append(text)

    return cluster_items
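`prepare_cluster_data` scans the full label array once per cluster, and the description and keyword loops each call it for every cluster. A minimal alternative sketch groups every item in a single pass, keyed by cluster ID. It assumes `labels` and the prepared texts are aligned by index; `group_by_cluster` is a hypothetical name. The `int(...)` call converts NumPy integer labels to plain Python ints, because `json` cannot serialize NumPy integers as dictionary keys.

```python
from collections import defaultdict
from typing import Sequence


def group_by_cluster(labels: Sequence, texts: Sequence[str]) -> dict[int, list[str]]:
    """Group prepared texts by cluster label in one pass (hypothetical helper).

    labels[i] must be the cluster assigned to texts[i].
    """
    if len(labels) != len(texts):
        raise ValueError("labels and texts must have the same length")
    groups: dict[int, list[str]] = defaultdict(list)
    for label, text in zip(labels, texts):
        groups[int(label)].append(text)  # int() normalizes NumPy integer labels
    return dict(groups)  # items keep their original (index) order within each cluster


# Usage, assuming contents, labels, and prepare_text_for_embedding from earlier cells:
# texts = [prepare_text_for_embedding(c) for c in contents]
# cluster_texts = group_by_cluster(labels, texts)
# cluster_texts[3]  # same items as prepare_cluster_data(3, labels, contents)
```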

In [None]:
def generate_description(client: Client, cluster_items: list) -> str:
    """
    Generate a one-sentence thematic description of a cluster using a language model

    Args:
        client: Ollama client used to call the model
        cluster_items: List of prepared text strings for the items in the cluster

    Returns:
        Natural language description of the cluster
    """
    prompt = f"""You are analyzing a group of movies and TV shows.
Based on their titles, genres, keywords, and descriptions below, identify the common themes that define this group.

# Instructions
 - Summarize the overarching themes present in the group in one concise sentence.
 - Do not include phrases like "The cluster centers on" or "These works are about".
 - Do not include any preamble, explanation, or additional context; just provide the thematic description directly.

# Example responses:
 - "Psychological suspense in close communities where complex women face trauma, uncover secrets, and confront moral ambiguity."
 - "Explorations of identity, family dynamics, and societal challenges through the lens of diverse characters in contemporary settings."

# Content items:
{json.dumps(cluster_items, indent=2)}

Your thematic description:
"""
    response = client.generate(
        model=LARGE_LANGUAGE_MODEL,
        prompt=prompt,
        options={"temperature": 0.3},
    )

    description = response["response"].strip().strip('"')
    return description
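Every item in a cluster is serialized into the prompt with `json.dumps`, so a large cluster can produce a prompt longer than the model's configured context window, and what happens then depends on the server settings. A minimal sketch, assuming a simple character budget is an acceptable stand-in for a token count, keeps leading items until the serialized list would exceed the budget. `fit_items_to_budget` and its default `max_chars=12_000` are hypothetical and arbitrary, not values tied to any particular model.

```python
import json


def fit_items_to_budget(items: list[str], max_chars: int = 12_000) -> list[str]:
    """Keep leading items while their JSON serialization stays within max_chars.

    Hypothetical helper: max_chars is an arbitrary character budget, not a token count.
    """
    kept: list[str] = []
    for item in items:
        candidate = kept + [item]
        if len(json.dumps(candidate, indent=2)) > max_chars:
            break  # adding this item would exceed the budget
        kept = candidate
    return kept


# Usage inside a prompt, in place of json.dumps(cluster_items, indent=2):
# json.dumps(fit_items_to_budget(cluster_items), indent=2)
```

Keeping the leading items is the simplest choice; a random sample of the cluster would represent it more evenly at the cost of run-to-run variation.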

In [None]:
%%time

# Generate descriptions for all clusters
cluster_descriptions = []
for cluster_id in range(CLUSTER_COUNT):
    print(
        f"Generating description for Cluster {cluster_id:>2}...", end="\r", flush=True
    )
    cluster_items = prepare_cluster_data(cluster_id, labels, contents)
    description = generate_description(client, cluster_items)
    cluster_descriptions.append({"cluster_id": cluster_id, "description": description})

print(json.dumps(cluster_descriptions, indent=2))

In [None]:
import ast
import re


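def extract_keywords(client: Client, cluster_items: list) -> list:
    """
    Extract key themes or keywords from a cluster using a language model

    Args:
        client: Ollama client used to call the model
        cluster_items: List of prepared text strings for the items in the cluster

    Returns:
        List of keywords or short phrases (5 are requested), or an empty list
        if the response could not be parsed
    """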
    prompt = f"""You are analyzing a group of movies and TV shows.
Based on their titles, genres, keywords, and descriptions in the content items, identify the common themes that define this group.

# Instructions
 - Provide exactly 5 keywords or short phrases (1-4 words each) that capture the essence of this cluster.
 - Format your response as a list of title-case strings, as shown in the example responses below.
 - Do not include any preamble, explanation, or additional context; just provide the list.

# Example responses:
["Female-Led Investigation", "Small Town Secrets", "Psychological Trauma", "Murder Mystery", "Family Dynamics"]
["Superhero", "Corporate Dystopia", "Animated", "Multiverse", "Psychological Thriller"]

# Content items:
{json.dumps(cluster_items, indent=2)}

Your keywords:
"""
    response = client.generate(
        model=LARGE_LANGUAGE_MODEL,
        prompt=prompt,
        options={"temperature": 0.3},
    )

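    # Collapse newlines so the regex can match a list that spans several lines
    keywords = response["response"].strip().replace("\n", "")
    match = re.search(r"\[.*?\]", keywords)
    if not match:
        return []
    try:
        keyword_list = ast.literal_eval(match.group(0))
    except (ValueError, SyntaxError):
        return []  # Malformed list; the caller can retry
    # Only accept a list; anything else is treated as a failed parse
    return keyword_list if isinstance(keyword_list, list) else []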
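Model output is not guaranteed to follow the requested format: it may wrap the list in prose, split it across lines, use JSON or Python quoting, repeat a keyword, or return more than five items. A more defensive parser, sketched below as the hypothetical `parse_keyword_list`, tries `json.loads` first and then `ast.literal_eval`, drops blanks, non-strings, and duplicates, and returns an empty list rather than raising. Capping the result at five items is an assumption matching the prompt's request.

```python
import ast
import json
import re


def parse_keyword_list(response_text: str, max_items: int = 5) -> list[str]:
    """Parse the first bracketed list in an LLM response into keyword strings.

    Hypothetical helper: returns [] instead of raising when nothing parseable is found.
    """
    # DOTALL lets the match span line breaks inside the list
    match = re.search(r"\[.*?\]", response_text, flags=re.DOTALL)
    if not match:
        return []
    snippet = match.group(0)
    parsed = None
    for parser in (json.loads, ast.literal_eval):  # JSON first, then Python literal syntax
        try:
            parsed = parser(snippet)
            break
        except (ValueError, SyntaxError):
            continue
    if not isinstance(parsed, list):
        return []
    keywords: list[str] = []
    for item in parsed:
        # Keep only non-empty strings, stripped, without duplicates
        if isinstance(item, str) and item.strip() and item.strip() not in keywords:
            keywords.append(item.strip())
    return keywords[:max_items]
```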

In [None]:
%%time

# Extract keywords for all clusters
cluster_keywords = []
for cluster_id in range(CLUSTER_COUNT):
    print(f"Extracting keywords for Cluster {cluster_id:>2}...", end="\r", flush=True)
    cluster_items = prepare_cluster_data(cluster_id, labels, contents)
    keywords = extract_keywords(client, cluster_items)
    if len(keywords) == 0:
        # Retry once if the response could not be parsed
        keywords = extract_keywords(client, cluster_items)
    cluster_keywords.append({"cluster_id": cluster_id, "keywords": keywords})

print(json.dumps(cluster_keywords, indent=2))
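The loop above retries keyword extraction once when nothing could be parsed. Because the model samples with `temperature` 0.3, a repeated call can return a different answer, so the same idea generalizes to a small retry helper with a configurable number of attempts and a validity check. `call_with_retries` is a hypothetical name; returning the last result when every attempt fails mirrors what the loop does.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_retries(
    fn: Callable[[], T], is_valid: Callable[[T], bool], attempts: int = 3
) -> T:
    """Call fn until is_valid(result) is True or attempts run out (hypothetical helper).

    Returns the last result even if it never passed validation.
    """
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    result = fn()
    for _ in range(attempts - 1):
        if is_valid(result):
            break
        result = fn()  # sampling is non-deterministic, so a new call may succeed
    return result


# Usage, assuming client, cluster_items, and extract_keywords from the cells above:
# keywords = call_with_retries(
#     lambda: extract_keywords(client, cluster_items), is_valid=lambda kw: len(kw) == 5
# )
```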

In [None]:
def combine_descriptions_and_keywords(
    cluster_descriptions: list, cluster_keywords: list
) -> list:
    """
    Add each cluster's keywords to its description entry (modifies cluster_descriptions in place)

    Args:
        cluster_descriptions: List of cluster descriptions
        cluster_keywords: List of cluster keywords

    Returns:
        Combined list of clusters with descriptions and keywords
    """
    # Index keywords by cluster ID for a direct lookup
    keywords_by_id = {k["cluster_id"]: k["keywords"] for k in cluster_keywords}

    for cluster in cluster_descriptions:
        cluster_id = cluster["cluster_id"]
        description = cluster["description"]
        # Default to an empty list so a cluster without keywords does not raise KeyError
        keywords = keywords_by_id.get(cluster_id, [])
        cluster["keywords"] = keywords

        print(f"\nCluster {cluster_id}:")
        print(f"Description: {description}")
        print(f"Keywords: {keywords}")

    return cluster_descriptions

In [None]:
combined_descriptions = combine_descriptions_and_keywords(
    cluster_descriptions, cluster_keywords
)

with open("cluster_descriptions.json", "w") as f:
    json.dump(combined_descriptions, f)
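Downstream code, such as a recommender that shows a cluster's description next to a title, usually needs to look clusters up by ID. A minimal sketch, assuming the file layout written above (a list of objects with `cluster_id`, `description`, and `keywords`), reloads it into a dictionary keyed by cluster ID. `load_cluster_lookup` is a hypothetical helper.

```python
import json
from pathlib import Path


def load_cluster_lookup(path: str = "cluster_descriptions.json") -> dict[int, dict]:
    """Reload saved cluster metadata keyed by cluster ID (hypothetical helper)."""
    clusters = json.loads(Path(path).read_text(encoding="utf-8"))
    # int() guards against IDs that were saved as strings
    return {int(c["cluster_id"]): c for c in clusters}


# Usage, assuming labels from the K-Means step:
# lookup = load_cluster_lookup()
# lookup[int(labels[i])]["keywords"]  # keywords for the i-th content item
```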