In [78]:
from pymilvus import connections, Collection

# Register the "default" pymilvus connection against localhost:19530.
connections.connect(alias="default", host="localhost", port="19530")

In [79]:
# Handle to the "embedded_music_data" collection; the cells below query and search it.
collection = Collection("embedded_music_data")

In [80]:
# Load the collection before the query/search calls in the cells below.
collection.load()

In [81]:
# Optional vibe filter and retrieval controls
# Use a task description -> estimate minutes with local LLM (Qwen) -> compute number of songs based on avg track length
import os, json, re, requests

# Target vibe label (e.g., "Pop", "Pump Up", "Chill") used to build the Milvus filter expression
# and to auto-pick a seed song. Leave empty/None to search across all vibes.
target_vibe = "Groove"    # change as needed; set to None or "" to disable vibe filtering
include_blended = True         # include items labeled as 'Blended' alongside the target vibe

# Task-to-playlist sizing
# When True (and task_description is set and the Ollama server responds), the playlist size is
# derived from an LLM estimate of the task's duration instead of the static result_limit.
use_task_time = True
# When True and a minutes estimate exists, cap by cumulative duration instead of fixed count
time_budget_mode = True
# Allow a small budget overshoot tolerance (ms) to avoid underfilling due to one slightly long track
budget_overshoot_ms = 15_000

task_description = None  # e.g., "Deep work sprint, 90 minutes"; if None or empty, fallback to static result_limit below
max_playlist_songs = 50  # cap to avoid huge queries if task is long; also a hard stop for the time-budget fill

# Retrieval sizing and blended cap (fallbacks when no task time available)
result_limit = 10              # fallback final number of results to return; overwritten by the task-based count when one is computed
blended_max_ratio = 0.4        # allow up to 40% of results to be 'Blended' (rounded against result_limit)
fetch_multiplier = 5           # search more candidates to allow post-filtering (e.g., 5x of result_limit)

# Ollama endpoint and model; both can be overridden through environment variables of the same name.
OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434")
OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "qwen2.5:3b-instruct")

def ollama_up(base: str) -> bool:
    """Report whether the Ollama server at ``base`` answers its tags endpoint.

    Issues ``GET {base}/api/tags`` with a 3-second timeout.

    Returns:
        True only when the reply has HTTP status 200. Any exception raised
        while making or inspecting the request is treated as "not up" and
        yields False.
    """
    reachable = False
    try:
        reply = requests.get(f"{base}/api/tags", timeout=3)
        reachable = reply.status_code == 200
    except Exception:
        # Best-effort probe: a failure of any kind just means "not available".
        reachable = False
    return reachable

def ollama_generate(prompt: str, model: str, expect_json: bool = False, num_predict: int = 64) -> str:
    """Run one non-streaming generation against the Ollama server at OLLAMA_HOST.

    Args:
        prompt: Prompt text sent to the model.
        model: Model name placed in the request body.
        expect_json: When True, adds ``"format": "json"`` to the request body.
        num_predict: Value sent as the ``num_predict`` generation option.

    Returns:
        The ``response`` field of the JSON reply with surrounding whitespace
        stripped, or an empty string when that field is missing or falsy.

    Raises:
        Whatever ``requests.post`` / ``raise_for_status`` raise for transport
        failures or error status codes; nothing is caught here.
    """
    body = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": 0.2, "num_predict": num_predict},
    }
    if expect_json:
        body["format"] = "json"

    reply = requests.post(f"{OLLAMA_HOST}/api/generate", json=body, timeout=45)
    reply.raise_for_status()

    generated = reply.json().get("response")
    if not generated:
        return ""
    return generated.strip()

def estimate_minutes_from_task(desc: str) -> int:
    """Ask the local LLM how many minutes a task will take.

    Args:
        desc: Free-text task description. Empty or whitespace-only input
            short-circuits to 0 without calling the model.

    Returns:
        The non-negative integer found under the ``"minutes"`` key of the
        model's JSON reply (truncated toward zero, negatives clamped to 0).
        Returns 0 when the reply cannot be parsed, is not a JSON object, or
        when the model call fails for any reason.
    """
    task = desc.strip() if desc else ""
    if not task:
        return 0

    prompt = (
        "Estimate the total time required for this task in minutes. "
        "Return ONLY valid JSON like {\"minutes\": 45}.\nTask: " + task
    )
    try:
        raw = ollama_generate(prompt, OLLAMA_MODEL, expect_json=True, num_predict=64)
        try:
            parsed = json.loads(raw)
        except Exception:
            # The reply was not pure JSON; look for the first {...} span inside it.
            embedded = re.search(r"\{[^\}]*\}", raw)
            parsed = {} if embedded is None else json.loads(embedded.group(0))
        if not isinstance(parsed, dict):
            return 0
        # float() first so values such as "30.7" or 30.7 are accepted, then truncate.
        return max(0, int(float(parsed.get("minutes", 0))))
    except Exception:
        # Any failure (model call, parsing, conversion) means "no estimate".
        return 0

def sample_avg_song_minutes(sample_size: int = 5000) -> float:
    """Average track length, in minutes, over rows read from the collection.

    Queries up to ``min(sample_size, 16384)`` rows matching
    ``duration_ms >= 0`` and averages their non-null ``duration_ms`` values.

    Returns:
        The mean duration in minutes, floored at 1e-6 so callers can divide
        by it. Falls back to 3.5 when the query returns nothing, no row has a
        duration, or any exception occurs.
    """
    fallback_minutes = 3.5
    window_cap = 16384
    try:
        fetched = collection.query(
            expr="duration_ms >= 0",
            output_fields=["duration_ms"],
            limit=min(sample_size, window_cap),
        )
        if not fetched:
            return fallback_minutes

        durations = []
        for row in fetched:
            ms = row.get("duration_ms")
            if ms is not None:
                durations.append(ms)
        if not durations:
            return fallback_minutes

        mean_ms = sum(durations) / len(durations)
        return max(1e-6, mean_ms / 60000.0)
    except Exception:
        return fallback_minutes

# Task-based sizing outputs; all stay None unless an estimate is actually produced below.
minutes_estimate, avg_song_minutes, playlist_count = None, None, None

# Only consult the LLM when task sizing is enabled, a description exists, and Ollama responds.
if use_task_time and task_description and ollama_up(OLLAMA_HOST):
    minutes_estimate = estimate_minutes_from_task(task_description)
    if minutes_estimate and minutes_estimate > 0:
        avg_song_minutes = sample_avg_song_minutes()
        # Whole songs that fit in the estimate, clamped to the range [1, max_playlist_songs].
        playlist_count = int(
            max(
                1,
                min(
                    max_playlist_songs,
                    minutes_estimate // max(1e-6, avg_song_minutes),
                ),
            )
        )
        # The task-derived count replaces the static fallback for the rest of the notebook.
        result_limit = playlist_count

# Build Milvus boolean expression for filtering.
# The label is stripped before use so the value that passes the emptiness check is the same
# value that ends up in the expression (" Groove " previously produced a filter that could
# never match "Groove"). Backslashes and double quotes are escaped so a label containing
# them cannot terminate the string literal early.
_vibe_label = target_vibe.strip() if isinstance(target_vibe, str) else ""
if _vibe_label:
    _vibe_literal = '"' + _vibe_label.replace("\\", "\\\\").replace('"', '\\"') + '"'
    if include_blended:
        # Match the target vibe plus rows labeled 'Blended'.
        filter_expr = f'vibe in [{_vibe_literal}, "Blended"]'
    else:
        filter_expr = f'vibe == {_vibe_literal}'
else:
    # No usable vibe label: search without a filter.
    filter_expr = None

# Echo the effective configuration so the recorded output documents the run.
print("Using filter expr:", filter_expr)
print({
    'target_vibe': target_vibe,
    'include_blended': include_blended,
    'result_limit': result_limit,
    'blended_max_ratio': blended_max_ratio,
    'fetch_multiplier': fetch_multiplier,
    'task_description': task_description,
    'minutes_estimate': minutes_estimate,
    'avg_song_minutes': avg_song_minutes,
    'playlist_count': playlist_count,
    'time_budget_mode': time_budget_mode,
    'budget_overshoot_ms': budget_overshoot_ms,
})

Using filter expr: vibe in ["Groove", "Blended"]
{'target_vibe': 'Groove', 'include_blended': True, 'result_limit': 10, 'blended_max_ratio': 0.4, 'fetch_multiplier': 5, 'task_description': None, 'minutes_estimate': None, 'avg_song_minutes': None, 'playlist_count': None, 'time_budget_mode': True, 'budget_overshoot_ms': 15000}


In [82]:
# Seed selection and vibe-aware retrieval with blended cap
from typing import List

# Configure a seed by name if desired; otherwise we auto-pick one from target_vibe
seed_song_name = None  # e.g., "Shape of You"; leave None to auto-select from vibe

# Helper to get one entity by expr safely
def get_one(expr: str):
    """Return the first row matching ``expr``, or None.

    Queries the collection with ``limit=1`` and returns the row's
    ``embedding``, ``vibe``, ``name`` and ``artists`` fields.

    Args:
        expr: Milvus boolean expression selecting candidate rows.

    Returns:
        The first matching row, or None when nothing matches or the query
        fails. A failure is printed rather than raised, so callers can keep
        chaining fallbacks with ``or`` while still seeing why a lookup came
        back empty.
    """
    try:
        rows = collection.query(expr=expr, output_fields=["embedding", "vibe", "name", "artists"], limit=1)
        return rows[0] if rows else None
    except Exception as exc:
        # Stay best-effort, but do not hide the cause: a bad expression or a dropped
        # connection would otherwise be indistinguishable from "no matching rows".
        print(f"get_one: query failed for expr {expr!r}: {exc}")
        return None

# If specific seed name is provided, query its embedding
seed_row = None
if seed_song_name:
    # Escape backslashes and double quotes so titles such as: Say "Hello"
    # do not terminate the string literal inside the expression early.
    _seed_literal = seed_song_name.replace("\\", "\\\\").replace('"', '\\"')
    seed_row = get_one(f'name == "{_seed_literal}"')
else:
    # Auto-pick: prefer any row carrying the target vibe.
    if target_vibe:
        seed_row = get_one(f'vibe == "{target_vibe}"')
    if not seed_row:
        # Fallbacks in order: any non-Blended with known vibe -> any with any vibe
        seed_row = get_one('vibe != "" and vibe != "Blended"') or get_one('vibe != ""')

if not seed_row:
    if seed_song_name:
        # A named seed that is missing has nothing to do with the 'vibe' field; say so.
        raise ValueError(
            f"No song named {seed_song_name!r} found in the collection. "
            "Check the spelling or set seed_song_name = None to auto-select a seed."
        )
    raise ValueError("No seed candidate found. Check that the collection has 'vibe' populated.")

# The seed's stored embedding becomes the query vector for the similarity search.
query_embedding = seed_row["embedding"]
print(f"Using seed: {seed_row['name']} | Vibe={seed_row.get('vibe')}")

# Fetch a larger candidate pool to allow post-filtering (respect Milvus constraints).
# int() keeps the limit integral even if fetch_multiplier is set to a fractional value.
MAX_QUERY_WINDOW = 16384
fetch_k = int(max(result_limit * fetch_multiplier, result_limit))
fetch_k = min(fetch_k, MAX_QUERY_WINDOW)  # cap

# Prepare search params.
# `ef` must not be smaller than the number of requested hits: with task-based sizing the
# limit can reach max_playlist_songs * fetch_multiplier (250 with the defaults), which a
# fixed ef of 64 does not cover. 64 stays the floor, so small searches behave as before.
# NOTE(review): `ef` is an HNSW search parameter; confirm the index on `embedding` is HNSW.
search_params = {"metric_type": "L2", "params": {"ef": max(64, fetch_k)}}
expr = filter_expr if 'filter_expr' in globals() and filter_expr else None

# One query vector in, so take the single result list at index 0.
raw_results = collection.search(
    data=[query_embedding],
    anns_field="embedding",
    param=search_params,
    limit=fetch_k,
    output_fields=["id", "name", "artists", "vibe", "duration_ms", "embedding_json"],
    expr=expr
)[0]

# Post-process: blended ratio cap, then time-accurate capping if enabled
DEFAULT_TRACK_MS = int(3.5 * 60_000)  # assumed length (3.5 min) for hits without a usable duration


def _track_ms(hit):
    """Return the hit's duration in ms, or DEFAULT_TRACK_MS when it is missing, non-numeric or <= 0."""
    dur = hit.entity.get('duration_ms')
    if isinstance(dur, (int, float)) and dur > 0:
        return dur
    return DEFAULT_TRACK_MS


def _is_blended(hit):
    """True when the hit's vibe label is exactly 'Blended'."""
    return (hit.entity.get('vibe') or '') == 'Blended'


# Time-budget capping applies only when an LLM minutes estimate exists and the mode is on.
use_time_budget = (
    'minutes_estimate' in globals() and minutes_estimate and minutes_estimate > 0 and
    'time_budget_mode' in globals() and time_budget_mode
)

# Number of 'Blended' hits allowed in the final list (0 when blended items are excluded).
max_blended = int(round(blended_max_ratio * result_limit)) if include_blended else 0
final: List = []
blended_used = 0
filled_ms = 0  # running total of the selected tracks' durations

if use_time_budget:
    target_ms = int(minutes_estimate * 60_000)
    budget = target_ms
    # Simple greedy fill by ranking order; stop when next track would exceed budget + tolerance.
    # The running total replaces re-summing `final` on every iteration and applies the same
    # duration fallback to accepted tracks as to the candidate being tested.
    for hit in raw_results:
        is_blended = _is_blended(hit)
        if is_blended and blended_used >= max_blended:
            continue
        dur = _track_ms(hit)
        if filled_ms + dur > budget + budget_overshoot_ms:
            break
        final.append(hit)
        filled_ms += dur
        if is_blended:
            blended_used += 1
        # Hard stop on max songs to avoid pathological cases
        if len(final) >= max_playlist_songs:
            break
    # If nothing was selected, fall back to the top candidate when it fits the budget on
    # its own. NOTE(review): this looks reachable only when the top hit is 'Blended' and
    # was skipped by the cap above; the counters are updated so the summary stays accurate.
    if not final and len(raw_results) > 0:
        top = raw_results[0]
        top_dur = _track_ms(top)
        if top_dur <= budget + budget_overshoot_ms:
            final = [top]
            filled_ms = top_dur
            if _is_blended(top):
                blended_used = 1
else:
    # Fixed-count mode: walk the ranking until result_limit hits are collected.
    for hit in raw_results:
        is_blended = _is_blended(hit)
        if is_blended and blended_used >= max_blended:
            continue
        final.append(hit)
        filled_ms += _track_ms(hit)
        if is_blended:
            blended_used += 1
        if len(final) >= result_limit:
            break

print(f"Selected {len(final)} results (blended_used={blended_used}, max_blended={max_blended}, filled_ms={filled_ms})")

# Print results
# One block per selected hit, in selection order: the search score (metric_type "L2"),
# then name / vibe / duration, then artists, followed by a "---" separator line.
for hit in final:
    print(f"Score: {hit.score:.4f}")
    print(f"Name: {hit.entity.get('name')} | Vibe: {hit.entity.get('vibe')} | Duration(ms): {hit.entity.get('duration_ms')}")
    print(f"Artists: {hit.entity.get('artists')}")
    print("---")

Using seed: Emotion | Vibe=Groove
Selected 10 results (blended_used=3, max_blended=4, filled_ms=2467895)
Score: 0.0000
Name: Emotion | Vibe: Groove | Duration(ms): 236400
Artists: ["Destiny's Child"]
---
Score: 141.9996
Name: Emotions In Motion | Vibe: Blended | Duration(ms): 298400
Artists: ['Billy Squier']
---
Score: 147.1324
Name: Oh What A Feeling | Vibe: Groove | Duration(ms): 301667
Artists: ['Wailing Souls']
---
Score: 152.7369
Name: Hello | Vibe: Blended | Duration(ms): 226308
Artists: ['OMFG']
---
Score: 153.8161
Name: Heartbeat | Vibe: Groove | Duration(ms): 269840
Artists: ['Childish Gambino']
---
Score: 155.7005
Name: Act Of Affection | Vibe: Groove | Duration(ms): 164907
Artists: ['Wailing Souls']
---
Score: 158.2440
Name: Straight On | Vibe: Groove | Duration(ms): 307973
Artists: ['Heart']
---
Score: 160.7546
Name: Star | Vibe: Blended | Duration(ms): 264200
Artists: ['Earth, Wind & Fire']
---
Score: 162.0261
Name: NO NAME | Vibe: Groove | Duration(ms): 183827
Artists: ['

In [83]:
# connections.disconnect(alias="default")