In [10]:
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer

In [11]:
import torch

In [12]:
import re

In [13]:
class ConversationGenerator:
    """Turn a topic summary into a two-host, podcast-style dialogue.

    A causal language model is prompted (through its chat template) to write
    a transcript between two hosts, "Alex" and "Jordan". The raw generation
    is then filtered so that only well-formed speaker lines are returned.
    """

    # A transcript line must begin with one of the two host labels.
    # Compiled once here instead of being looked up again for every line.
    _SPEAKER_LINE = re.compile(r"^(Alex|Jordan):")

    def __init__(self, model_name="stabilityai/stablelm-2-zephyr-1_6b"):
        """Load the tokenizer and model and wrap them in a generation pipeline.

        Args:
            model_name: Model identifier passed to ``from_pretrained`` for
                both the tokenizer and the causal LM.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            # Half precision only when CUDA is available; float32 otherwise.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        self.generator = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)

    def _format_chat(self, summary_text: str, max_turns: int) -> str:
        """Build the chat-formatted prompt string for the model.

        Args:
            summary_text: Topic summary the hosts should discuss.
            max_turns: Approximate number of exchanges requested in the
                prompt. This is a hint to the model, not an enforced limit.

        Returns:
            The prompt rendered by the tokenizer's chat template, ending with
            the generation prompt so the model starts its own turn.
        """
        system_message = {
            "role": "system",
            "content": (
                "You produce podcast-style dialogue between two hosts, Alex and Jordan. "
                "Output ONLY lines that start with 'Alex:' or 'Jordan:'. No narration, no summary."
            ),
        }
        user_message = {
            "role": "user",
            "content": (
                "Topic summary:\n"
                f"{summary_text}\n\n"
                "Write a natural, back-and-forth conversation (about "
                f"{max_turns} exchanges). Keep each line concise (1–2 sentences). "
                "Vary questions and insights. Do NOT restate the summary—discuss it."
                "\n\nExample style:\n"
                "Alex: Quick hook question.\n"
                "Jordan: Concise insight plus a follow-up question.\n"
                "Alex: Builds on it, asks why.\n"
                "Jordan: Gives context, adds a nuance.\n"
                "----\n"
                "Begin the transcript now."
            ),
        }
        return self.tokenizer.apply_chat_template(
            [system_message, user_message],
            tokenize=False,
            add_generation_prompt=True,
        )

    def _postprocess(self, text: str) -> str:
        """Reduce raw model output to a clean transcript.

        Keeps only lines that start with ``Alex:`` or ``Jordan:`` (after
        stripping surrounding whitespace), drops speaker labels that carry no
        text, and removes exact duplicate lines while preserving the order of
        first occurrence.

        Args:
            text: Raw text produced by the model.

        Returns:
            The surviving lines joined with newlines; an empty string when no
            line qualifies.
        """
        utterances = []
        for raw_line in text.splitlines():
            line = raw_line.strip()
            label = self._SPEAKER_LINE.match(line)
            if label is None:
                # Narration, separators, blank lines: not part of the dialogue.
                continue
            if not line[label.end():].strip():
                # A bare "Alex:" / "Jordan:" with nothing after it is a
                # truncated or empty turn and adds nothing to the transcript.
                continue
            utterances.append(line)
        # dict.fromkeys de-duplicates while keeping first-seen order, which
        # removes verbatim repeats when the model starts looping.
        return "\n".join(dict.fromkeys(utterances))

    def generate_conversation(
        self,
        summary_text: str,
        max_turns: int = 6,
        max_new_tokens: int = 480,
    ) -> str:
        """Generate a cleaned Alex/Jordan transcript for a topic summary.

        Args:
            summary_text: Topic summary the hosts should discuss.
            max_turns: Approximate number of exchanges to request in the
                prompt.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The post-processed transcript, one speaker line per row. Sampling
            is enabled, so repeated calls may return different text.
        """
        prompt = self._format_chat(summary_text, max_turns)

        outputs = self.generator(
            prompt,
            max_new_tokens=max_new_tokens,  # keeps generation bounded
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,         # discourages looping
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
            # Ask the pipeline for the continuation only, rather than slicing
            # the prompt off by character count afterwards (which silently
            # breaks if the echoed prompt differs from the input string).
            return_full_text=False,
        )
        return self._postprocess(outputs[0]["generated_text"])