In [None]:
%load_ext autoreload
%autoreload 2

import json
from pathlib import Path

import opencc
from google import genai
from google.genai import types as gtypes

from ai_storyteller.music_generation.diffrhythm import DiffRhythm
from ai_storyteller.utils.env_utils import get_env_var
from ai_storyteller.utils.text_utils import clean_lyric_lines

In [None]:
# Model name passed to the Gemini generate_content call below.
model = "gemini-2.0-flash"

# The key is obtained through the project's env helper rather than hard-coded.
api_key = get_env_var("GEMINI_API_KEY")
client = genai.Client(api_key=api_key)

In [None]:
# Create the project's DiffRhythm wrapper once; every `generate_music` call
# later in this notebook reuses this instance.
dr = DiffRhythm()

In [None]:
# Smoke test: generate from the bundled example lyrics and a reference track,
# writing the result to ./output/sample.wav.
dr.generate_music(
    output_dir="./output",
    output_file_name="sample.wav",
    lrc_path="../../data/music_generation/lrc/eg_en_full.lrc",
    ref_audio_path="../../data/music_generation/music/snoozy beats - Feel the Glow.mp3",
)

# Generate lyrics from a children's story

In [None]:
# Load the story definition. The encoding is pinned to UTF-8 so non-ASCII
# story text decodes the same way regardless of the platform's locale default.
with open(
    "../../data/stories/the_wolf_and_the_seven_kids/story.json", encoding="utf-8"
) as f:
    data = json.load(f)
data

In [None]:
# `pages` is indexed by page key; only each page record's "text" field is used.
pages = data["pages"]
# One stripped page text per line, in the mapping's iteration order.
text = "\n".join(entry["text"].strip() for entry in pages.values()).strip()
print(text)

## Prepare prompt for lyric generation

In [None]:
# Prompt template (written in Traditional Chinese) asking for lyrics in .lrc
# format. Placeholders: {seconds} = requested song length, {story} = story text.
gen_lyrics_from_story_prompt = """\
根據以下故事生成歌詞，這首歌長度為{seconds}秒
格式應該是 .lrc，例如：

```
[00:00.00]歌詞內容
[00:01.00]歌詞內容
...
```
除了歌詞以外，其他的內容都不需要，中間不能有空行，如果故事的內容是中文，請用中文生成歌詞，如果故事的內容是英文，請用英文生成歌詞。

故事：
```
{story}
```
"""

# Fill in the template: a 95-second song based on the story text built above.
filled_gen_lyrics_from_story_prompt = gen_lyrics_from_story_prompt.format(
    seconds=95,
    story=text,
)
print(filled_gen_lyrics_from_story_prompt)

## Option 1. Generate using Gemini

In [None]:
# Option 1: single request to Gemini with the filled-in prompt.
res = client.models.generate_content(
    contents=filled_gen_lyrics_from_story_prompt,
    model=model,
)
# Keep the reply text as the raw lyrics; it is cleaned in a later cell.
lyrics = res.text
print(lyrics)

## Option 2. Generate using HuggingChat

In [None]:
from huggingface_hub import InferenceClient

# NOTE: this rebinds `client`, replacing the Gemini client created earlier.
huggingchat_api_key = get_env_var("HUGGINGCHAT_API_KEY")
client = InferenceClient(api_key=huggingchat_api_key, provider="novita")

In [None]:
# Qwen3's documented soft switch for turning thinking off is "/no_think"
# (the previous "/nothink" is not the documented spelling). The template ends
# with a newline, so the switch lands on its own line at the end of the prompt.
messages = [
    {"role": "user", "content": f"{filled_gen_lyrics_from_story_prompt}/no_think"}
]
res = client.chat.completions.create(
    model="Qwen/Qwen3-235B-A22B",
    messages=messages,
    temperature=0.5,
    max_tokens=8192,
    top_p=0.7,
)
# Take the first (only requested) completion as the raw lyrics.
lyrics = res.choices[0].message.content
print(lyrics)

## Option 3. Generate locally using Transformers

In [None]:
import torch
from transformers import (
    AutoModelForCausalLM,  # pyright: ignore[reportPrivateImportUsage]
    AutoTokenizer,  # pyright: ignore[reportPrivateImportUsage]
    BitsAndBytesConfig,  # pyright: ignore[reportPrivateImportUsage]
)

# Hugging Face model id. The earlier value "Qwen/Qwen8-4B" is not a published
# model; the Qwen3 8B checkpoint is the intended default here.
model_name = "Qwen/Qwen3-8B"
# model_name = "Qwen/Qwen3-4B"  # If you run out of VRAM, try this smaller model instead.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the weights 4-bit quantized, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE: this rebinds `model`, which earlier held the Gemini model-name string.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
)

In [None]:
messages = [{"role": "user", "content": filled_gen_lyrics_from_story_prompt}]

# Render the chat template to a prompt string. A dedicated name is used so the
# story `text` built earlier is not overwritten; otherwise re-running the
# prompt-formatting cell would embed this template output as the "story".
chat_prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,  # Switches between thinking and non-thinking modes. Default is True.
)
model_inputs = tokenizer([chat_prompt], return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(**model_inputs, max_new_tokens=32768)
# Slice off the prompt tokens so only the newly generated ids remain.
output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist()

# parsing thinking content
try:
    # Position just past the LAST occurrence of token id 151668 (</think>),
    # found by searching the reversed list.
    index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
    # No </think> token present: treat the whole output as the answer.
    index = 0

thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip(
    "\n"
)
content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

print("thinking content:", thinking_content)
print("content:", content)
lyrics = content

## Clean lyrics

In [None]:
# Pass the raw model output through the project's `clean_lyric_lines` helper
# (defined in ai_storyteller.utils.text_utils) and display the result.
lyrics = clean_lyric_lines(lyrics)
lyrics

## Simplified Chinese seems to work better, so we'll convert

In [None]:
# "t2s.json" selects OpenCC's Traditional-to-Simplified Chinese conversion.
t2s = opencc.OpenCC("t2s.json")
lyrics = t2s.convert(lyrics)
print(lyrics)

## Save lyrics to file

In [None]:
# Write the lyrics to the .lrc file read by the music-generation cells below.
# UTF-8 is set explicitly: the lyrics may contain Chinese characters, which a
# non-UTF-8 locale default could fail to encode.
with open("/tmp/lyrics.lrc", "w", encoding="utf-8") as f:
    f.write(lyrics)

# Generate a song from the lyrics

## Using reference audio

There are other samples in `data/music_generation/music`

In [None]:
# Path.glob yields matches in arbitrary order; sorting makes `audio_paths[0]`
# (used as the reference track below) the same file on every run and machine.
audio_paths = sorted(Path("../../data/music_generation/music").glob("*.mp3"))
audio_paths

In [None]:
# Generate from the saved lyrics, using the first listed reference track.
dr.generate_music(
    output_dir="./output",
    output_file_name="song_from_audio.wav",
    lrc_path="/tmp/lyrics.lrc",
    ref_audio_path=audio_paths[0],
)

## Using a text prompt seems to work better than using a reference song

Describe styles/scenes in words (e.g., `Jazzy Nightclub Vibe`, `Pop Emotional Piano` or `Indie folk ballad`, `coming-of-age themes`, `acoustic guitar picking with harmonica interludes`)

In [None]:
# Same lyrics as above, but styled by a text prompt instead of reference audio.
dr.generate_music(
    lrc_path="/tmp/lyrics.lrc",
    ref_prompt="Children's song",
    output_file_name="song_from_prompt.wav",
    output_dir="./output",
)

# Generate an instrumental only

In [None]:
# No lyrics file is passed here; `instrumental_only=True` requests a track
# without vocals, styled by the text prompt.
dr.generate_music(
    instrumental_only=True,
    ref_prompt="Children's song",
    output_file_name="instrumental_from_prompt.wav",
    output_dir="./output",
)

# Identify instruments (with Gemini)

In [None]:
# Re-create the Gemini client and model name: the Option 2 / Option 3 cells
# above rebind `client` and `model` to other objects.
model = "gemini-2.0-flash"
api_key = get_env_var("GEMINI_API_KEY")
client = genai.Client(api_key=api_key)

In [None]:
# Question sent along with the audio (Traditional Chinese: "Which instruments
# do you hear? Please list them.").
identify_instruments_prompt = "你聽到了什麼樂器？請列出來"

# Read the generated instrumental as raw bytes so it can be sent inline.
audio_bytes = Path("./output/instrumental_from_prompt.wav").read_bytes()
audio_part = gtypes.Part.from_bytes(data=audio_bytes, mime_type="audio/wav")

# Multimodal request: the text question followed by the WAV payload.
res = client.models.generate_content(
    model=model,
    contents=[identify_instruments_prompt, audio_part],
)
print(res.text)

# Exercise: Generate an instrumental using your own prompt

* **分組。** 線上參與的同學可以自行組成小組。
* **生成音樂：** 看看你的提示語可以多詳細，描述不同的音樂風格或特定樂器。測試模型的極限，試著給它一些挑戰。
* **完成後，** 將檔案上傳到 Google Drive，檔名設為你們的組別編號 (線上組別也一樣)。
* **組別請猜測試算表上順序下一組的歌曲（例如：第一組猜第二組，第二組猜第三組，...，最後一組猜第一組）。** 你們需要猜測那一首被分配到的歌曲用了什麼提示語。把你們猜的答案寫在試算表中你們組別的欄位下。
* **線上的組別可以自由選擇任何組別的歌曲來猜測提示語。** 現場的組別在完成指定的猜測後，也可以自由猜測其他組的歌曲。把你們所有的猜測都寫進試算表。看看誰能猜得最接近、最準確！
* **最後，** 原本生成歌曲的那一組要把他們用的提示語寫進試算表裡。