# Please run this notebook on Google Colab with a T4 GPU runtime.

In [None]:
%%capture
!pip install TTS
!pip install torch
!pip install audiocraft

## Audio generation class (speech + background music)

In [None]:
import torch
from transformers import pipeline
from TTS.api import TTS
from IPython.display import Audio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
import wave

class GenerateAudio:
    """
    A class to generate audio files based on a provided text input. The class generates
    speech, classifies the sentiment of the text, and creates background music to accompany
    the speech based on the classified sentiment.

    Attributes:
    device (str): The device to run the models on ('cuda' if GPU is available, 'cpu' otherwise).
    text (str): The input text to generate speech and classify sentiment.
    """

    def __init__(self, text: str):
        """
        Initializes the GenerateAudio class with the provided text and sets up the device.

        Parameters:
        text (str): The input text for generating speech and sentiment analysis.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.text = text

    def _generate_sentiment(self, candidate_labels=None) -> str:
        """
        Classifies the sentiment of the input text using a zero-shot classification model.

        Parameters:
        candidate_labels (list | None): Sentiment labels to classify the text against.
                                        When None, ["happy", "sad", "scary"] is used.

        Returns:
        str: The label with the highest classification score indicating the sentiment of the text.
        """
        # Build the default inside the call so no list object is shared between calls.
        if candidate_labels is None:
            candidate_labels = ["happy", "sad", "scary"]

        # Run on GPU if available.
        classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
            device=self.device,
        )
        result = classifier(self.text, candidate_labels)

        # Pick the label paired with the highest score.
        best_label, _ = max(
            zip(result['labels'], result['scores']),
            key=lambda pair: pair[1],
        )
        return best_label

    def _generate_speech(self, outfile: str = 'tts_output.wav') -> str:
        """
        Generates speech from the input text using a Tacotron 2 TTS model and saves it to a file.

        Parameters:
        outfile (str): The path to save the generated speech audio file. Default is 'tts_output.wav'.

        Returns:
        str: The path to the generated speech audio file.
        """
        tts_model = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC").to(self.device)
        tts_model.tts_to_file(text=self.text, file_path=outfile)
        return outfile

    def _generate_background_music(self, label: str, duration: float, outfile: str = "bg_audio") -> str:
        """
        Generates background music based on a given label and duration using the MusicGen model.

        Parameters:
        label (str): The sentiment label used to guide the generation of background music.
        duration (float): The duration (in seconds) for which the background music should play.
        outfile (str): The path to save the generated background music file, with or without
                       a trailing '.wav'. Default is 'bg_audio'.

        Returns:
        str: The path to the generated background music audio file, including the '.wav'
             suffix appended by audio_write.
        """
        model = MusicGen.get_pretrained('small', device=self.device)
        model.set_generation_params(duration=duration)  # Duration of the generated waveform in seconds
        output = model.generate(
            descriptions=[f'{label} + Orchestral Background Music']
        )

        # audio_write takes a stem and appends the extension itself, so drop a
        # caller-supplied '.wav' to avoid writing 'name.wav.wav'.
        if outfile.endswith('.wav'):
            outfile = outfile[:-4]

        # Move the waveform to the CPU before writing, as in the MusicGen usage example.
        written_path = audio_write(
            stem_name=outfile,
            wav=output[0].cpu(),
            sample_rate=model.sample_rate,
        )
        # Return the file that was really written (stem + '.wav'), not the bare stem.
        return str(written_path)

    def _find_duration(self, filePath: str) -> float:
        """
        Calculates the duration of an audio file based on its sample rate and number of frames.

        Parameters:
        filePath (str): The path to the WAV audio file (read with the standard `wave` module).

        Returns:
        float: The duration of the audio file in seconds.
        """
        with wave.open(filePath, 'rb') as audio_file:
            sample_rate = audio_file.getframerate()  # Sample rate (Hz)
            num_frames = audio_file.getnframes()     # Total number of frames
        return num_frames / float(sample_rate)       # Duration in seconds

    def create_audio_files(self, ttsPath: str, bgMusicPath: str) -> tuple:
        """
        Creates speech and background music audio files based on the sentiment of the input text.

        The function generates speech from the input text, calculates its duration,
        and generates background music of that duration based on the classified sentiment.

        Parameters:
        ttsPath (str): The path to save the generated speech audio file.
        bgMusicPath (str): The path to save the generated background music audio file.

        Returns:
        tuple: A tuple containing the paths to the generated speech and background music files.
        """
        sentiment = self._generate_sentiment()
        ttsPath = self._generate_speech(ttsPath)
        # The music is generated to match the length of the spoken audio.
        duration = self._find_duration(ttsPath)
        bgMusicPath = self._generate_background_music(sentiment, duration, outfile=bgMusicPath)

        return ttsPath, bgMusicPath



In [None]:
# Demo run: synthesize speech for a sample sentence and generate matching
# background music. The trailing expression is left bare so the notebook
# displays the returned (speech_path, music_path) tuple.
obj = GenerateAudio(text="The dark forest gave me chills.")
obj.create_audio_files(ttsPath="tts_test.wav", bgMusicPath="bg_audio_test.wav")