Skip to content

Commit

Permalink
AI Jams: a simple demo for creating new patches and elements (#130)
Browse files Browse the repository at this point in the history
Use OpenAI API to retrieve recordings for a given prompt in json, lookup
mbids for the same using labs api and generate a playlist.
  • Loading branch information
amCap1712 committed Apr 8, 2024
1 parent b7839a8 commit e998c41
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
Empty file added troi/external/__init__.py
Empty file.
48 changes: 48 additions & 0 deletions troi/external/gpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import json

from troi import Element, Artist, Recording, PipelineError
from openai import OpenAI

PLAYLIST_PROMPT_PREFIX = "Create a playlist of 50 songs that are suitable for a playlist with the given description:"
PLAYLIST_PROMPT_SUFFIX = "The output should strictly adhere to the following JSON format: the top level JSON object should have three keys, playlist_name to denote the name of the playlist, playlist_description to denote the description of the playlist, and recordings a JSON array of objects where each element JSON object has the recording_name and artist_name keys."


class GPTRecordingElement(Element):
""" Ask GPT to generate a list of recordings given the playlist description. """

def __init__(self, api_key, prompt):
super().__init__()
self.client = OpenAI(api_key=api_key)
self.prompt = prompt

@staticmethod
def outputs():
return [Recording]

def read(self, inputs):
full_prompt = PLAYLIST_PROMPT_PREFIX + " " + self.prompt + " " + PLAYLIST_PROMPT_SUFFIX
response = self.client.chat.completions.create(
model="gpt-3.5-turbo-0125",
response_format={"type": "json_object"},
messages=[
{"role": "system", "content": "You are a helpful assistant designed to output JSON."},
{"role": "user", "content": full_prompt}
]
)

try:
playlist = json.loads(response.choices[0].message.content)
except json.JSONDecodeError as e:
raise PipelineError("Cannot parse JSON response from OpenAI: %s" % (response.choices[0].message.content,)) \
from e

recordings = []
for item in playlist["recordings"]:
artist = Artist(name=item["artist_name"])
recording = Recording(name=item["recording_name"], artist=artist)
recordings.append(recording)

self.local_storage["_playlist_name"] = playlist["playlist_name"]
self.local_storage["_playlist_desc"] = playlist["playlist_description"]

return recordings
54 changes: 54 additions & 0 deletions troi/patches/ai_jams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import troi.external.gpt
import troi.musicbrainz.mbid_mapping
import troi.patch
from troi import Playlist
from troi.playlist import PlaylistMakerElement


class AiJamsPatch(troi.patch.Patch):
""" Generate a playlist using AI from the given prompt. """

@staticmethod
def inputs():
"""
Generate a playlist using AI
\b
API_KEY is the OpenAI api key.
PROMPT is the description of the playlist to generate.
"""
return [
{"type": "argument", "args": ["api_key"]},
{"type": "argument", "args": ["prompt"]}
]

@staticmethod
def outputs():
return [Playlist]

@staticmethod
def slug():
return "ai-jams"

@staticmethod
def description():
return "Generate a playlist using AI from the given prompt."

def create(self, inputs):
api_key = inputs['api_key']
prompt = inputs['prompt'].strip()

ai_recordings_lookup = troi.external.gpt.GPTRecordingElement(api_key, prompt)

recs_lookup = troi.musicbrainz.mbid_mapping.MBIDMappingLookupElement(remove_unmatched=True)
recs_lookup.set_sources(ai_recordings_lookup)

pl_maker = PlaylistMakerElement(
patch_slug=self.slug(),
max_num_recordings=50,
max_artist_occurrence=2,
shuffle=False
)
pl_maker.set_sources(recs_lookup)

return pl_maker

0 comments on commit e998c41

Please sign in to comment.