<a href="https://colab.research.google.com/github/mohammadalmalt/AI/blob/main/MusicGen.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U git+https://github.com/facebookresearch/audiocraft.git
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install scipy numpy


In [None]:
import torch
from audiocraft.models import MusicGen
from audiocraft.utils.notebook import display_audio
import scipy.io.wavfile
import numpy as np

# Generate a short music clip from a text prompt with MusicGen, play it in the
# notebook, and save it to disk as a 16-bit PCM WAV file.

# Pick the compute device and load the pretrained model onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pass the device explicitly so the selection above is actually honored rather
# than being computed and then ignored.
model = MusicGen.get_pretrained("facebook/musicgen-small", device=device)
print("MusicGen model loaded!")

# Read the sample rate once; the same value drives both notebook playback and
# the WAV header, so the two can never disagree.
sample_rate = int(model.sample_rate)

# Define prompt and generation parameters
prompt = ["oriental music using Oud and violin. Maqam Hijaz"]
model.set_generation_params(duration=15)  # 15 seconds of music

# Generate music. The result is a batch with one entry per prompt.
print("Generating music...")
generated_audio = model.generate(prompt, progress=True)

# Only one prompt was supplied, so take the first clip and move it to the CPU
# so it can be played back and converted to NumPy.
first_clip = generated_audio[0].cpu()

# Listen in Colab. The CPU tensor is passed directly; there is no need to
# round-trip through NumPy and back into a new tensor.
display_audio(first_clip, sample_rate=sample_rate)

# --- Saving the audio correctly ---
output_path = "generated_music.wav"

# Drop singleton dimensions (e.g. a leading channel axis of size 1 for mono).
music_waveform = np.squeeze(first_clip.numpy())

# scipy expects multi-channel data as [samples, channels]; if the clip is
# stereo with shape [2, samples], transpose it.
if music_waveform.ndim == 2 and music_waveform.shape[0] == 2:
    music_waveform = music_waveform.T

# Clip to [-1, 1] so the integer scaling below cannot wrap around on overshoot
# (values are presumably already in range, but this makes it certain).
music_waveform = np.clip(music_waveform, -1, 1)

# Convert float [-1, 1] to int16 exactly once. Scaling by 32767 keeps +1.0
# inside the int16 range.
music_waveform_int16 = (music_waveform * 32767).astype(np.int16)

# Save the WAV file
scipy.io.wavfile.write(output_path, sample_rate, music_waveform_int16)
print(f"High-quality music saved as: {output_path}")
