diff --git a/troi/external/__init__.py b/troi/external/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/troi/external/gpt.py b/troi/external/gpt.py new file mode 100644 index 0000000..6fc5cc0 --- /dev/null +++ b/troi/external/gpt.py @@ -0,0 +1,48 @@ +import json + +from troi import Element, Artist, Recording, PipelineError +from openai import OpenAI + +PLAYLIST_PROMPT_PREFIX = "Create a playlist of 50 songs that are suitable for a playlist with the given description:" +PLAYLIST_PROMPT_SUFFIX = "The output should strictly adhere to the following JSON format: the top level JSON object should have three keys, playlist_name to denote the name of the playlist, playlist_description to denote the description of the playlist, and recordings a JSON array of objects where each element JSON object has the recording_name and artist_name keys." + + +class GPTRecordingElement(Element): + """ Ask GPT to generate a list of recordings given the playlist description. """ + + def __init__(self, api_key, prompt): + super().__init__() + self.client = OpenAI(api_key=api_key) + self.prompt = prompt + + @staticmethod + def outputs(): + return [Recording] + + def read(self, inputs): + full_prompt = PLAYLIST_PROMPT_PREFIX + " " + self.prompt + " " + PLAYLIST_PROMPT_SUFFIX + response = self.client.chat.completions.create( + model="gpt-3.5-turbo-0125", + response_format={"type": "json_object"}, + messages=[ + {"role": "system", "content": "You are a helpful assistant designed to output JSON."}, + {"role": "user", "content": full_prompt} + ] + ) + + try: + playlist = json.loads(response.choices[0].message.content) + except json.JSONDecodeError as e: + raise PipelineError("Cannot parse JSON response from OpenAI: %s" % (response.choices[0].message.content,)) \ + from e + + recordings = [] + for item in playlist["recordings"]: + artist = Artist(name=item["artist_name"]) + recording = Recording(name=item["recording_name"], artist=artist) + recordings.append(recording) + + self.local_storage["_playlist_name"] = playlist["playlist_name"] + self.local_storage["_playlist_desc"] = playlist["playlist_description"] + + return recordings diff --git a/troi/patches/ai_jams.py b/troi/patches/ai_jams.py new file mode 100755 index 0000000..41af65c --- /dev/null +++ b/troi/patches/ai_jams.py @@ -0,0 +1,54 @@ +import troi.external.gpt +import troi.musicbrainz.mbid_mapping +import troi.patch +from troi import Playlist +from troi.playlist import PlaylistMakerElement + + +class AiJamsPatch(troi.patch.Patch): + """ Generate a playlist using AI from the given prompt. """ + + @staticmethod + def inputs(): + """ + Generate a playlist using AI + + \b + API_KEY is the OpenAI api key. + PROMPT is the description of the playlist to generate. + """ + return [ + {"type": "argument", "args": ["api_key"]}, + {"type": "argument", "args": ["prompt"]} + ] + + @staticmethod + def outputs(): + return [Playlist] + + @staticmethod + def slug(): + return "ai-jams" + + @staticmethod + def description(): + return "Generate a playlist using AI from the given prompt." + + def create(self, inputs): + api_key = inputs['api_key'] + prompt = inputs['prompt'].strip() + + ai_recordings_lookup = troi.external.gpt.GPTRecordingElement(api_key, prompt) + + recs_lookup = troi.musicbrainz.mbid_mapping.MBIDMappingLookupElement(remove_unmatched=True) + recs_lookup.set_sources(ai_recordings_lookup) + + pl_maker = PlaylistMakerElement( + patch_slug=self.slug(), + max_num_recordings=50, + max_artist_occurrence=2, + shuffle=False + ) + pl_maker.set_sources(recs_lookup) + + return pl_maker