<a href="https://colab.research.google.com/github/cburchett/podcastcreator/blob/main/PodcastCreator_Dia_1_6B.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install gradio

In [2]:
# Install directly from GitHub
!pip -q install git+https://github.com/nari-labs/dia.git

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [3]:
from dia.model import Dia

audio_model = Dia.from_pretrained("nari-labs/Dia-1.6B")

  WeightNorm.apply(module, name, dim)


In [4]:
import os
from google import genai

from google.colab import userdata
API_KEY = userdata.get('GOOGLE_API_KEY')

os.environ["GOOGLE_API_KEY"] = API_KEY

# Create a client
client = genai.Client(api_key=API_KEY)

In [5]:
#MODEL_ID = "gemini-2.5-pro-exp-03-25"
YOUTUBE_URL = "https://www.youtube.com/watch?v=rSCaiHFRx0k"
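Gemini fetches the video itself from `YOUTUBE_URL`, so a malformed link is likely to surface only when the API call fails. Below is a minimal pre-flight check. `youtube_video_id` is a hypothetical helper, not part of the Gemini SDK. It accepts only the two common public URL shapes (`youtube.com/watch?v=…` and `youtu.be/…`). Other forms, such as Shorts links, are deliberately left unhandled:

```python
from urllib.parse import urlparse, parse_qs

def youtube_video_id(url):
    """Return the video ID from a YouTube watch or youtu.be URL, or None if not recognized.

    Hypothetical helper: only the two common URL shapes are handled.
    """
    parsed = urlparse(url.strip())
    host = (parsed.hostname or "").lower()
    if host in ("youtube.com", "www.youtube.com", "m.youtube.com") and parsed.path == "/watch":
        # Long form: the ID lives in the "v" query parameter
        ids = parse_qs(parsed.query).get("v")
        return ids[0] if ids else None
    if host == "youtu.be":
        # Short form: the ID is the first path component
        video_id = parsed.path.lstrip("/").split("/")[0]
        return video_id or None
    return None

print(youtube_video_id("https://www.youtube.com/watch?v=rSCaiHFRx0k"))  # rSCaiHFRx0k
```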

In [6]:
PROMPT = """Analyze the attached YouTube video.

Based on the key topics, information, and events presented in the video, generate a medium-length, conversational podcast script between two speakers, labeled S1 and S2.

The script should summarize or discuss the main points of the video in a natural, back-and-forth dialogue format.

**Crucially, format the output *exactly* as follows:**

*   Each line of dialogue must start with either `[S1]` or `[S2]`.
*   Follow the speaker tag with a space, then their dialogue.
*   Present the dialogue turns sequentially, mimicking a conversation.
*   Don't add any prefix or suffix to the conversation.

**Use this specific structure as your template:**

```
[S1] {Dialogue for speaker 1}
[S2] {Dialogue for speaker 2}
[S1] {Dialogue for speaker 1, potentially a reaction or follow-up}
[S2] {Dialogue for speaker 2}
[S1] {Dialogue for speaker 1}
```

**Example of the desired output format:**

```
[S1] Hey Sam, how are you? Let me tell you about Dia. It's an open-weights text-to-dialogue model.
[S2] You get full control over scripts and voices.
[S1] Wow. Amazing. (laughs)
[S2] Try it now on GitHub or Hugging Face.
[S1] You bet I will!
```

**Constraints:**

*   Keep the turns relatively short and conversational.
*   Focus on the core message or interesting aspects of the video.
*   Adhere strictly to the `[S1]` / `[S2]` formatting.
*   **Incorporate non-verbal cues where natural and appropriate.** These should be enclosed in parentheses within the dialogue line (e.g., `(laughs)` or `(sighs)`). You may use cues from this list: `(laughs)`, `(clears throat)`, `(sighs)`, `(gasps)`, `(coughs)`, `(singing)`, `(sings)`, `(mumbles)`, `(beep)`, `(groans)`, `(sniffs)`, `(claps)`, `(screams)`, `(inhales)`, `(exhales)`, `(applause)`, `(burps)`, `(humming)`, `(sneezes)`, `(whistles)`.
*   Do not add any introductory text, explanations, or summaries outside of the formatted script itself.

**Now, analyze the video and generate the script.**

---
"""

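The prompt asks for a strict `[S1]`/`[S2]` format and a fixed list of non-verbal cues, but the model is not guaranteed to comply. Below is a minimal checker. `validate_script` is a hypothetical helper that reports lines breaking the format and parenthesized cues outside the allowed list. It treats every parenthesized phrase as an intended cue, so it may also flag ordinary parenthetical asides:

```python
import re

# Non-verbal cues allowed by the prompt
ALLOWED_CUES = {
    "laughs", "clears throat", "sighs", "gasps", "coughs", "singing", "sings",
    "mumbles", "beep", "groans", "sniffs", "claps", "screams", "inhales",
    "exhales", "applause", "burps", "humming", "sneezes", "whistles",
}
LINE_PATTERN = re.compile(r"^\[(S[12])\] \S")  # tag, one space, then dialogue
CUE_PATTERN = re.compile(r"\(([^()]*)\)")      # text inside each (...) pair

def validate_script(script):
    """Return a list of (line_number, problem) tuples; an empty list means the script looks valid."""
    problems = []
    for number, line in enumerate(script.strip().splitlines(), start=1):
        if not line.strip():
            continue  # tolerate blank lines between turns
        if not LINE_PATTERN.match(line):
            problems.append((number, "line does not start with '[S1] ' or '[S2] '"))
        for cue in CUE_PATTERN.findall(line):
            if cue.strip().lower() not in ALLOWED_CUES:
                problems.append((number, f"unlisted cue '({cue})'"))
    return problems
```

An empty result suggests the Gemini output can go straight to the Podcast tab. Otherwise, the line numbers point at what to edit or regenerate.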
In [7]:
import os
from datetime import datetime

def ensure_folder_exists():
    now = datetime.now()
    timestamp = now.strftime('%Y-%m-%d %H:%M:%S')
    folder_path = '/content/' + timestamp
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
        print(f"Folder '{folder_path}' created.")
    else:
        print(f"Folder '{folder_path}' already exists.")
    return folder_path
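The folder name built by `ensure_folder_exists` contains a space and colons (for example `2025-05-02 19:29:06`). This works on Colab's Linux filesystem, but the path then needs quoting in shell commands, and colons are not allowed in Windows filenames. Below is a hedged alternative. `make_run_folder` is a hypothetical variant that uses a filesystem-safe timestamp and relies on `os.makedirs(..., exist_ok=True)` instead of an explicit existence check:

```python
import os
from datetime import datetime

def make_run_folder(base_dir="/content"):
    """Create (if needed) and return a timestamped folder under base_dir.

    Hypothetical variant of ensure_folder_exists: underscores and dashes replace
    the space and colons, so the path needs no quoting in shell commands.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    folder_path = os.path.join(base_dir, timestamp)
    os.makedirs(folder_path, exist_ok=True)  # no error if the folder already exists
    return folder_path
```

As with the original, two runs started within the same second share a folder.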

In [8]:
from google.genai import types

def generate_podcast_script(youtube_url, model, prompt):
    response = client.models.generate_content(
        model=model,
        contents=types.Content(
            parts=[
                types.Part(text=prompt),
                types.Part(
                    file_data=types.FileData(file_uri=youtube_url)
                )
            ]
        )
    )
    return response.text

In [9]:
import re

def split_podcast_transcript(transcript, pairs):
    """Splits a podcast transcript into segments of speaker-tagged lines.

    Args:
        transcript: The podcast transcript as a string, one "[S1]"/"[S2]" line per turn.
        pairs: The number of tagged lines to group into each segment.

    Returns:
        A list of strings, where each string represents a segment of the transcript.
        Returns an empty list if the input is invalid or no valid segments are found.
    """

    segments = []
    try:
        # Split the transcript into lines, ignoring surrounding whitespace
        lines = [line.strip() for line in transcript.strip().split('\n')]

        # Capture the speaker tag and the dialogue; \s* drops the space after the tag
        # so the rebuilt line below does not end up with a double space
        pattern = r"\[(S[12])\]\s*(.*)"
        turns = []
        for line in lines:
            match = re.match(pattern, line)
            if match:
                turns.append(match.groups())

        # Group the tagged lines into segments of `pairs` lines each
        for i in range(0, len(turns), pairs):
            segment = "\n".join(f"[{speaker}] {text}" for speaker, text in turns[i:i + pairs])
            segments.append(segment)
    except Exception as e:
        print(f"Error processing transcript: {e}")
        return []

    return segments

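Dia's README recommends starting the input with `[S1]` and alternating `[S1]`/`[S2]` instead of repeating the same tag. If Gemini emits two consecutive lines for the same speaker, those lines can be merged before splitting. Once the turns strictly alternate, an even `pairs` value also keeps every segment starting with `[S1]`, because segment *k* begins at turn index *k × pairs*, which is always even. Below is a minimal sketch. `merge_consecutive_turns` is a hypothetical helper that joins the merged dialogue with a single space:

```python
import re

TURN_PATTERN = re.compile(r"^\[(S[12])\]\s*(.*)$")

def merge_consecutive_turns(transcript):
    """Merge back-to-back lines by the same speaker so [S1]/[S2] tags alternate.

    Lines without a speaker tag are dropped, as split_podcast_transcript also ignores them.
    """
    merged = []  # list of [speaker, text] pairs
    for line in transcript.strip().splitlines():
        match = TURN_PATTERN.match(line.strip())
        if not match:
            continue
        speaker, text = match.groups()
        if merged and merged[-1][0] == speaker:
            # Same speaker as the previous turn: append to that turn
            merged[-1][1] = f"{merged[-1][1]} {text}".strip()
        else:
            merged.append([speaker, text])
    return "\n".join(f"[{speaker}] {text}" for speaker, text in merged)
```

Running the script through this helper before `split_podcast_transcript` is optional. It only changes scripts that repeat a tag.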
In [10]:
import os
import re
from pydub import AudioSegment

def combine_mp3s(folder_path, output_file):
    """Combines all MP3 files in a folder into a single MP3 file.

    Args:
        folder_path: The path to the folder containing the MP3 files.
        output_file: The path to the output MP3 file.
    """
    def natural_key(filename):
        # Compare embedded numbers numerically so podcast_2.mp3 sorts before podcast_10.mp3
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", filename)]

    combined = AudioSegment.empty()
    # Only MP3 segments, excluding the output file itself, in numeric order
    mp3_files = sorted(
        (f for f in os.listdir(folder_path)
         if f.endswith(".mp3") and os.path.join(folder_path, f) != output_file),
        key=natural_key,
    )
    combined_count = 0
    for filename in mp3_files:
        filepath = os.path.join(folder_path, filename)
        try:
            combined += AudioSegment.from_mp3(filepath)
            combined_count += 1
        except Exception as e:
            print(f"Error processing {filename}: {e}")
    combined.export(output_file, format="mp3")
    print(f"Combined {combined_count} MP3 files into {output_file}")
    return output_file

In [11]:
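import os
import numpy as np
import soundfile as sf

def generate_podcast(
    transcript: str,
    pairs: int,
    max_new_tokens: int,
    cfg_scale: float,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int,
    speed_factor: float,
    ):

    folder_path = ensure_folder_exists()

    segments = split_podcast_transcript(transcript, pairs)
    print(f"Number of segments: {len(segments)}")

    for idx, seg in enumerate(segments):
        print(f"Generating segment {idx+1}")
        print(seg)
        output = audio_model.generate(
            text=seg,
            max_tokens=max_new_tokens,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
            cfg_filter_top_k=cfg_filter_top_k,
        )
        # Apply speed_factor by linear resampling, similar to Dia's Gradio demo;
        # pitch shifts along with tempo, and values below 1.0 slow the audio down.
        original_len = len(output)
        target_len = int(original_len / speed_factor)
        if target_len > 0 and target_len != original_len:
            output = np.interp(
                np.linspace(0, original_len - 1, target_len),
                np.arange(original_len),
                output,
            ).astype(np.float32)
        # Zero-padded names keep the segments in order when the folder listing is sorted
        sf.write(os.path.join(folder_path, f"podcast_{idx+1:03d}.mp3"), output, 44100)

    return combine_mp3s(folder_path, os.path.join(folder_path, "final_podcast.mp3"))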

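Each `audio_model.generate` call samples voices afresh, so consecutive segments can come out with different-sounding speakers. Dia's README suggests keeping speakers consistent by fixing the random seed or by supplying an audio prompt. Below is a minimal seeding helper. `set_seed` is hypothetical and follows common PyTorch practice. It could be called with the same seed right before each segment's `generate` call. It makes sampling repeatable for the same input, but it does not guarantee identical voices across different texts, and GPU kernels may still add small nondeterminism:

```python
import random

try:
    import numpy as np
except ImportError:  # numpy is normally installed alongside Dia
    np = None
try:
    import torch
except ImportError:  # lets the helper run where PyTorch is absent
    torch = None

def set_seed(seed: int) -> None:
    """Seed Python, NumPy, and PyTorch RNGs so sampling is repeatable."""
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)  # seeds the CPU and all CUDA devices
```

In `generate_podcast`, this would amount to a `set_seed(42)` line (any fixed value) at the top of the segment loop.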

In [12]:
import gradio as gr
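The "Max New Tokens" slider in the interface cell caps output length in codec frames, not seconds. The conversion below assumes that Dia decodes through the 44.1 kHz Descript Audio Codec, which produces one frame per 512 samples (roughly 86 frames per second). Under that assumption, the slider's range converts approximately as follows. A segment whose speech would run past the cap is likely to be cut off, which is one reason to keep segments short via the pairs slider:

```python
SAMPLE_RATE = 44_100  # Hz, the rate passed to sf.write in generate_podcast
HOP_LENGTH = 512      # assumed audio samples per codec frame

def tokens_to_seconds(tokens: int) -> float:
    """Approximate audio duration for a given number of generated codec frames."""
    return tokens * HOP_LENGTH / SAMPLE_RATE

for tokens in (860, 3072):
    print(f"{tokens} tokens ~ {tokens_to_seconds(tokens):.1f} s")
# 860 tokens ~ 10.0 s
# 3072 tokens ~ 35.7 s
```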

In [None]:
with gr.Blocks() as podcast_script_generator:
  gr.Markdown("## Podcast Generator")
  with gr.Tab("Script"):
    with gr.Row():
      with gr.Column():
        youtube_url = gr.Textbox(label="URL", value=YOUTUBE_URL)
        selected_model = gr.Dropdown(["gemini-2.5-pro-exp-03-25"], value="gemini-2.5-pro-exp-03-25", label="Model")
        system_prompt = gr.Textbox(label="Prompt", value=PROMPT)
        generate_script_button = gr.Button("Generate script")
      with gr.Column():
        podcast_script = gr.Textbox(label="Podcast script", lines=20)
  generate_script_button.click(fn=generate_podcast_script, inputs=[youtube_url, selected_model, system_prompt], outputs=[podcast_script])

  with gr.Tab("Podcast"):
     with gr.Row():
       with gr.Column():
          final_podcast_script = gr.Textbox(label="Final podcast script", lines=20)
          pairs = gr.Slider(
                    label="Segment pairs",
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    info="Number of [S1]/[S2] lines grouped into each generation call; higher values give fewer, longer segments.",
                )
          with gr.Accordion("Generation Parameters", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens (Audio Length)",
                    minimum=860,
                    maximum=3072,
                    value=audio_model.config.data.audio_length,  # Default audio length from the loaded model's config
                    step=50,
                    info="Controls the maximum length of the generated audio (more tokens = longer audio).",
                )
                cfg_scale = gr.Slider(
                    label="CFG Scale (Guidance Strength)",
                    minimum=1.0,
                    maximum=5.0,
                    value=3.0,  # Default from inference.py
                    step=0.1,
                    info="Higher values increase adherence to the text prompt.",
                )
                temperature = gr.Slider(
                    label="Temperature (Randomness)",
                    minimum=1.0,
                    maximum=1.5,
                    value=1.3,  # Default from inference.py
                    step=0.05,
                    info="Lower values make the output more deterministic, higher values increase randomness.",
                )
                top_p = gr.Slider(
                    label="Top P (Nucleus Sampling)",
                    minimum=0.80,
                    maximum=1.0,
                    value=0.95,  # Default from inference.py
                    step=0.01,
                    info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
                )
                cfg_filter_top_k = gr.Slider(
                    label="CFG Filter Top K",
                    minimum=15,
                    maximum=50,
                    value=30,
                    step=1,
                    info="Top k filter for CFG guidance.",
                )
                speed_factor_slider = gr.Slider(
                    label="Speed Factor",
                    minimum=0.8,
                    maximum=1.0,
                    value=0.94,
                    step=0.02,
                    info="Adjusts the speed of the generated audio (1.0 = original speed).",
                )
          generate_podcast_button = gr.Button("Generate podcast")
       with gr.Column():
          podcast_audio = gr.Audio(label="Final podcast audio", type='filepath')
  generate_podcast_button.click(fn=generate_podcast, inputs=[
            final_podcast_script,
            pairs,
            max_new_tokens,
            cfg_scale,
            temperature,
            top_p,
            cfg_filter_top_k,
            speed_factor_slider,], outputs=[podcast_audio])
  podcast_script.change(fn=lambda x: x, inputs=podcast_script, outputs=final_podcast_script)

podcast_script_generator.launch(share=True, debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://8a037483557a52af94.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Folder '/content/2025-05-02 19:29:06' created.
Number of segments: 7
Generating segment 1
[S1]  Hey, did you catch that NetApp video about AI and data infrastructure? It got me thinking about how much pressure AI is putting on traditional storage.
[S2]  Totally. Tom Shields kicked it off by saying exactly that – AI data pipelines are really straining storage architectures because the workloads are constantly evolving.


Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/gradio/queueing.py", line 625, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gradio/route_utils.py", line 322, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gradio/blocks.py", line 2146, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/gradio/blocks.py", line 1664, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^