diff --git a/python/examples/math/ast_comment_handling.py b/python/examples/math/ast_comment_handling.py deleted file mode 100644 index c9a39678..00000000 --- a/python/examples/math/ast_comment_handling.py +++ /dev/null @@ -1,105 +0,0 @@ -import tokenize -import ast -import io -import inspect -from schema_with_comments import MathAPI - - -def _convert_pythonic_comments_to_annotated_docs(schema_class, debug=True): - - def _extract_tokens_between_line_numbers(gen, start_lineno, end_lineno): - # Extract tokens between start_lineno and end_lineno obtained from the tokenize generator - tokens = [] - for tok in gen: - if tok.start[0] < start_lineno: # Skip tokens before start_lineno - continue - if tok.start[0] >= start_lineno and tok.end[0] <= end_lineno: - # Add token if it is within the range - tokens.append((tok.type, tok.string)) - elif tok.start[0] > end_lineno: # Stop if token is beyond end_lineno - break - - return tokens - - schema_path = inspect.getfile(schema_class) - - with open(schema_path, 'r') as f: - schema_class_source = f.read() - gen = tokenize.tokenize(io.BytesIO( - schema_class_source.encode('utf-8')).readline) - - tree = ast.parse(schema_class_source) - - if debug: - print("Source code before transformation:") - print("--"*50) - print(schema_class_source) - print("--"*50) - - has_comments = False # Flag later used to perform imports of Annotated and Doc if needed - - for node in tree.body: - if isinstance(node, ast.ClassDef): - for n in node.body: - if isinstance(n, ast.AnnAssign): # Check if the node is an annotated assignment - assgn_comment = None - tokens = _extract_tokens_between_line_numbers( - # Extract tokens between the line numbers of the annotated assignment - gen, n.lineno, n.end_lineno - ) - for toknum, tokval in tokens: - if toknum == tokenize.COMMENT: - # Extract the comment - assgn_comment = tokval - break - - if assgn_comment: - # If a comment is found, transform the annotation to include the comment - assgn_subscript = n.annotation - has_comments = True - n.annotation = ast.Subscript( - value=ast.Name(id="Annotated", ctx=ast.Load()), - slice=ast.Tuple( - elts=[ - assgn_subscript, - ast.Call( - func=ast.Name( - id="Doc", ctx=ast.Load() - ), - args=[ - ast.Constant( - value=assgn_comment.strip("#").strip() - ) - ], - keywords=[] - ) - ], - ctx=ast.Load() - ), - ctx=ast.Load() - ) - - if has_comments: - for node in tree.body: - if isinstance(node, ast.ImportFrom): - if node.module == "typing_extensions": - if ast.alias(name="Annotated") not in node.names: - node.names.append(ast.alias(name="Annotated")) - if ast.alias(name="Doc") not in node.names: - node.names.append(ast.alias(name="Doc")) - - transformed_schema_source = ast.unparse(tree) - - if debug: - print("Source code after transformation:") - print("--"*50) - print(transformed_schema_source) - print("--"*50) - - namespace = {} - exec(transformed_schema_source, namespace) - return namespace[schema_class.__name__] - - -if __name__ == "__main__": - print(_convert_pythonic_comments_to_annotated_docs(MathAPI)) diff --git a/python/examples/math/demo.py b/python/examples/math/demo.py index 8c862ae7..6b54fe60 100644 --- a/python/examples/math/demo.py +++ b/python/examples/math/demo.py @@ -4,7 +4,7 @@ import sys from typing import cast from dotenv import dotenv_values -import schema_with_comments as math +import schema as math from typechat import Failure, create_language_model, process_requests from program import TypeChatProgramTranslator, TypeChatProgramValidator, evaluate_json_program diff --git a/python/examples/math/program.py b/python/examples/math/program.py index 27941543..fc9479b7 100644 --- a/python/examples/math/program.py +++ b/python/examples/math/program.py @@ -121,7 +121,6 @@ class TypeChatProgramTranslator(TypeChatJsonTranslator[JsonProgram]): _api_declaration_str: str def __init__(self, model: TypeChatLanguageModel, validator: TypeChatProgramValidator, api_type: type): - api_type = self._convert_pythonic_comments_to_annotated_docs(api_type) super().__init__(model=model, validator=validator, target_type=api_type, _raise_on_schema_errors = False) # TODO: the conversion result here has errors! conversion_result = python_type_to_typescript_schema(api_type) diff --git a/python/examples/math/pythonic_comment_handling.py b/python/examples/math/pythonic_comment_handling.py deleted file mode 100644 index 20db9e19..00000000 --- a/python/examples/math/pythonic_comment_handling.py +++ /dev/null @@ -1,46 +0,0 @@ -import re -import inspect -from schema_with_comments import MathAPI - - -def _convert_pythonic_comments_to_annotated_docs(schema_class, debug=True): - - schema_path = inspect.getfile(schema_class) - - with open(schema_path, 'r') as file: - schema_class_source = file.read() - - if debug: - print("File contents before modification:") - print("--"*50) - print(schema_class_source) - print("--"*50) - - pattern = r"(\w+\s*:\s*.*?)(?=\s*#\s*(.+?)(?:\n|\Z))" - commented_fields = re.findall(pattern, schema_class_source) - annotated_fields = [] - - for field, comment in commented_fields: - field_separator = field.split(":") - field_name = field_separator[0].strip() - field_type = field_separator[1].strip() - - annotated_fields.append( - f"{field_name}: Annotated[{field_type}, Doc(\"{comment}\")]") - - for field, annotation in zip(commented_fields, annotated_fields): - schema_class_source = schema_class_source.replace(field[0], annotation) - - if debug: - print("File contents after modification:") - print("--"*50) - print(schema_class_source) - print("--"*50) - - namespace = {} - exec(schema_class_source, namespace) - return namespace[schema_class.__name__] - - -if __name__ == "__main__": - print(_convert_pythonic_comments_to_annotated_docs(MathAPI)) \ No newline at end of file diff --git a/python/examples/math/schema.py b/python/examples/math/schema.py index c948862b..6d58586b 100644 --- a/python/examples/math/schema.py +++ b/python/examples/math/schema.py @@ -1,5 +1,6 @@ from typing_extensions import TypedDict, Annotated, Callable, Doc + class MathAPI(TypedDict): """ This is API for a simple calculator diff --git a/python/src/typechat/_internal/translator.py b/python/src/typechat/_internal/translator.py index 598f961a..2c285592 100644 --- a/python/src/typechat/_internal/translator.py +++ b/python/src/typechat/_internal/translator.py @@ -1,10 +1,6 @@ from typing_extensions import Generic, TypeVar import pydantic_core -import ast -import io -import tokenize -import inspect from typechat._internal.model import PromptSection, TypeChatLanguageModel from typechat._internal.result import Failure, Result, Success @@ -123,99 +119,4 @@ def _create_repair_prompt(self, validation_error: str) -> str: ''' The following is a revised JSON object: """ - return prompt - - def _convert_pythonic_comments_to_annotated_docs(schema_class, debug=False): - - def _extract_tokens_between_line_numbers(gen, start_lineno, end_lineno): - # Extract tokens between start_lineno and end_lineno obtained from the tokenize generator - tokens = [] - for tok in gen: - if tok.start[0] < start_lineno: # Skip tokens before start_lineno - continue - if tok.start[0] >= start_lineno and tok.end[0] <= end_lineno: - # Add token if it is within the range - tokens.append((tok.type, tok.string)) - elif tok.start[0] > end_lineno: # Stop if token is beyond end_lineno - break - - return tokens - - schema_path = inspect.getfile(schema_class) - - with open(schema_path, 'r') as f: - schema_class_source = f.read() - gen = tokenize.tokenize(io.BytesIO( - schema_class_source.encode('utf-8')).readline) - - tree = ast.parse(schema_class_source) - - if debug: - print("Source code before transformation:") - print("--"*50) - print(schema_class_source) - print("--"*50) - - has_comments = False # Flag later used to perform imports of Annotated and Doc if needed - - for node in tree.body: - if isinstance(node, ast.ClassDef): - for n in node.body: - if isinstance(n, ast.AnnAssign): # Check if the node is an annotated assignment - assgn_comment = None - tokens = _extract_tokens_between_line_numbers( - # Extract tokens between the line numbers of the annotated assignment - gen, n.lineno, n.end_lineno - ) - for toknum, tokval in tokens: - if toknum == tokenize.COMMENT: - # Extract the comment - assgn_comment = tokval - break - - if assgn_comment: - # If a comment is found, transform the annotation to include the comment - assgn_subscript = n.annotation - has_comments = True - n.annotation = ast.Subscript( - value=ast.Name(id="Annotated", ctx=ast.Load()), - slice=ast.Tuple( - elts=[ - assgn_subscript, - ast.Call( - func=ast.Name( - id="Doc", ctx=ast.Load() - ), - args=[ - ast.Constant( - value=assgn_comment.strip("#").strip() - ) - ], - keywords=[] - ) - ], - ctx=ast.Load() - ), - ctx=ast.Load() - ) - - if has_comments: - for node in tree.body: - if isinstance(node, ast.ImportFrom): - if node.module == "typing_extensions": - if ast.alias(name="Annotated") not in node.names: - node.names.append(ast.alias(name="Annotated")) - if ast.alias(name="Doc") not in node.names: - node.names.append(ast.alias(name="Doc")) - - transformed_schema_source = ast.unparse(tree) - - if debug: - print("Source code after transformation:") - print("--"*50) - print(transformed_schema_source) - print("--"*50) - - namespace = {} - exec(transformed_schema_source, namespace) - return namespace[schema_class.__name__] + return prompt \ No newline at end of file diff --git a/python/examples/math/schema_with_comments.py b/python/utils/python_comment_handler/examples/commented_math_schema.py similarity index 100% rename from python/examples/math/schema_with_comments.py rename to python/utils/python_comment_handler/examples/commented_math_schema.py diff --git a/python/utils/python_comment_handler/examples/commented_music_schema.py b/python/utils/python_comment_handler/examples/commented_music_schema.py new file mode 100644 index 00000000..c4376a29 --- /dev/null +++ b/python/utils/python_comment_handler/examples/commented_music_schema.py @@ -0,0 +1,289 @@ +from typing_extensions import Literal, Required, NotRequired, TypedDict + + +class unknownActionParameters(TypedDict): + text: str # text typed by the user that the system did not understand + + +class UnknownAction(TypedDict): + """ + Use this action for requests that weren't understood + """ + + actionName: Literal["Unknown"] + text: unknownActionParameters + + +class EmptyParameters(TypedDict): + pass + + +class PlayParameters(TypedDict, total=False): + artist: str # artist (performer, composer) to search for to play + album: str # album to search for to play + trackName: str # track to search for to play + query: str # other description to search for to play + itemType: Literal["track", "album"] # this property is only used when the user specifies the item type + quantity: Required[int] # number of items to play, examples: three, a/an (=1), a few (=3), a couple of (=2), some (=5). Use -1 for all, 0 if unspecified. + trackNumber: int # play the track at this index in the current track list + trackRange: list[int] # play this range of tracks example 1-3 + + +class PlayAction(TypedDict): + """ + play a track, album, or artist; this action is chosen over search if both could apply + with no parameters, play means resume playback + """ + + actionName: Literal["play"] + parameters: PlayParameters + + +class StatusAction(TypedDict): + """ + show now playing including track information, and playback status including playback device + """ + + actionName: Literal["status"] + parameters: EmptyParameters + + +class PauseAction(TypedDict): + """ + pause playback + """ + + actionName: Literal["pause"] + parameters: EmptyParameters + + +class ResumeAction(TypedDict): + """ + resume playback + """ + + actionName: Literal["resume"] + parameters: EmptyParameters + + +class NextAction(TypedDict): + """ + next track + """ + + actionName: Literal["next"] + parameters: EmptyParameters + + +class PreviousAction(TypedDict): + """ + previous track + """ + + actionName: Literal["previous"] + parameters: EmptyParameters + + +class ShuffleActionParameters(TypedDict): + on: bool + + +class ShuffleAction(TypedDict): + """ + turn shuffle on or off + """ + + actionName: Literal["shuffle"] + parameters: ShuffleActionParameters + + +class ListDevicesAction(TypedDict): + """ + list available playback devices + """ + + actionName: Literal["listDevices"] + parameters: EmptyParameters + + +class SelectDeviceActionParameters(TypedDict): + keyword: str # keyword to match against device name + + +class SelectDeviceAction(TypedDict): + """ + select playback device by keyword + """ + + actionName: Literal["selectDevice"] + parameters: SelectDeviceActionParameters + + +class SelectVolumeActionParameters(TypedDict): + newVolumeLevel: int # new volume level + + +class SetVolumeAction(TypedDict): + """ + set volume + """ + + actionName: Literal["setVolume"] + parameters: SelectVolumeActionParameters + + +class ChangeVolumeActionParameters(TypedDict): + volumeChangePercentage: int # volume change percentage + + +class ChangeVolumeAction(TypedDict): + """ + change volume plus or minus a specified percentage + """ + + actionName: Literal["changeVolume"] + parameters: ChangeVolumeActionParameters + + +class SearchTracksActionParameters(TypedDict): + query: str # the part of the request specifying the the search keywords examples: song name, album name, artist name + + +class SearchTracksAction(TypedDict): + """ + this action is only used when the user asks for a search as in 'search', 'find', 'look for' + query is a Spotify search expression such as 'Rock Lobster' or 'te kanawa queen of night' + set the current track list to the result of the search + """ + + actionName: Literal["searchTracks"] + parameters: SearchTracksActionParameters + + +class ListPlaylistsAction(TypedDict): + """ + list all playlists + """ + + actionName: Literal["listPlaylists"] + parameters: EmptyParameters + + +class GetPlaylistActionParameters(TypedDict): + name: str # name of playlist to get + + +class GetPlaylistAction(TypedDict): + """ + get playlist by name + """ + + actionName: Literal["getPlaylist"] + parameters: GetPlaylistActionParameters + + +class GetAlbumActionParameters(TypedDict): + name: str # name of album to get + + +class GetAlbumAction(TypedDict): + """ + get album by name; if name is "", use the currently playing track + set the current track list the tracks in the album + """ + + actionName: Literal["getAlbum"] + parameters: GetPlaylistActionParameters + + +class GetFavoritesActionParameters(TypedDict): + count: NotRequired[int] # number of favorites to get + + +class GetFavoritesAction(TypedDict): + """ + Set the current track list to the user's favorite tracks + """ + + actionName: Literal["getFavorites"] + parameters: GetFavoritesActionParameters + + +class FilterTracksActionParameters(TypedDict): + filterType: Literal["genre", "artist", "name"] # filter type is one of 'genre', 'artist', 'name'; name does a fuzzy match on the track name + filterValue: str # filter value is the value to match against + negate: NotRequired[bool] # if negate is true, keep the tracks that do not match the filter + + +class FilterTracksAction(TypedDict): + """ + apply a filter to match tracks in the current track list + set the current track list to the tracks that match the filter + """ + + actionName: Literal["filterTracks"] + parameters: FilterTracksActionParameters + + +class CreatePlaylistActionParameters(TypedDict): + name: str # name of playlist to create + + +class CreatePlaylistAction(TypedDict): + """ + create a new playlist from the current track list + """ + + actionName: Literal["createPlaylist"] + parameters: CreatePlaylistActionParameters + + +class DeletePlaylistActionParameters(TypedDict): + name: str # name of playlist to delete + + +class DeletePlaylistAction(TypedDict): + """ + delete a playlist + """ + + actionName: Literal["deletePlaylist"] + parameters: DeletePlaylistActionParameters + + +class GetQueueAction(TypedDict): + """ + set the current track list to the queue of upcoming tracks + """ + + actionName: Literal["getQueue"] + parameters: EmptyParameters + + +PlayerAction = ( + PlayAction + | StatusAction + | PauseAction + | ResumeAction + | NextAction + | PreviousAction + | ShuffleAction + | ListDevicesAction + | SelectDeviceAction + | SetVolumeAction + | ChangeVolumeAction + | SearchTracksAction + | ListPlaylistsAction + | GetPlaylistAction + | GetAlbumAction + | GetFavoritesAction + | FilterTracksAction + | CreatePlaylistAction + | DeletePlaylistAction + | GetQueueAction + | UnknownAction +) + + +class PlayerActions(TypedDict): + actions: list[PlayerAction] diff --git a/python/utils/python_comment_handler/examples/commented_restaurant_schema.py b/python/utils/python_comment_handler/examples/commented_restaurant_schema.py new file mode 100644 index 00000000..d1c31f34 --- /dev/null +++ b/python/utils/python_comment_handler/examples/commented_restaurant_schema.py @@ -0,0 +1,46 @@ +from typing_extensions import Literal, Required, NotRequired, TypedDict + + +class UnknownText(TypedDict): + """ + Use this type for order items that match nothing else + """ + + itemType: Literal["Unknown"] + text: str # The text that wasn't understood + + +class Pizza(TypedDict, total=False): + itemType: Required[Literal["Pizza"]] + size: Literal["small", "medium", "large", "extra large"] # default: large + addedToppings: list[str] # toppings requested (examples: pepperoni, arugula + removedToppings: list[str] # toppings requested to be removed (examples: fresh garlic, anchovies + quantity: int # default: 1 + name: Literal["Hawaiian", "Yeti", "Pig In a Forest", "Cherry Bomb"] # used if the requester references a pizza by name + + +class Beer(TypedDict): + itemType: Literal["Beer"] + kind: str # examples: Mack and Jacks, Sierra Nevada Pale Ale, Miller Lite + quantity: NotRequired[int] # default: 1 + + +SaladSize = Literal["half", "whole"] + +SaladStyle = Literal["Garden", "Greek"] + + +class Salad(TypedDict, total=False): + itemType: Required[Literal["Salad"]] + portion: str # default: half + style: str # default: Garden + addedIngredients: list[str] # ingredients requested (examples: parmesan, croutons) + removedIngredients: list[str] # ingredients requested to be removed (example: red onions) + quantity: int # default: 1 + + +OrderItem = Pizza | Beer | Salad + + +class Order(TypedDict): + items: list[OrderItem | UnknownText] diff --git a/python/utils/python_comment_handler/examples/commented_sentiment_schema.py b/python/utils/python_comment_handler/examples/commented_sentiment_schema.py new file mode 100644 index 00000000..dd7f60c1 --- /dev/null +++ b/python/utils/python_comment_handler/examples/commented_sentiment_schema.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass +from typing_extensions import Literal, Annotated, Doc + +@dataclass +class Sentiment: + """ + The following is a schema definition for determining the sentiment of a some user input. + """ + + sentiment: Literal["negative", "neutral", "positive"] # The sentiment for the text diff --git a/python/utils/python_comment_handler/python_comment_handler.py b/python/utils/python_comment_handler/python_comment_handler.py new file mode 100644 index 00000000..eb3e3501 --- /dev/null +++ b/python/utils/python_comment_handler/python_comment_handler.py @@ -0,0 +1,144 @@ +import tokenize +import ast +import io +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--in_path", '-i', type=str, required=True, help='Path to the schema file containing pythonic comments') +parser.add_argument("--out_path", '-o', type=str, required=True, help='Path to the output file containing the transformed schema') +parser.add_argument("--debug", '-d', action='store_true', help='Print debug information') + +class PythonCommentHandler: + def __init__(self, in_schema_path, out_schema_path, debug=False): + self.in_schema_path = in_schema_path + self.out_schema_path = out_schema_path + self.debug = debug + + def _convert_pythonic_comments_to_annotated_docs(self): + + def _extract_tokens_between_line_numbers(gen, start_lineno, end_lineno): + # Extract tokens between start_lineno and end_lineno obtained from the tokenize generator + tokens = [] + for tok in gen: + if tok.start[0] < start_lineno: # Skip tokens before start_lineno + continue + if tok.start[0] >= start_lineno and tok.end[0] <= end_lineno: + # Add token if it is within the range + tokens.append((tok.type, tok.string)) + elif tok.start[0] > end_lineno: # Stop if token is beyond end_lineno + break + + return tokens + + with open(self.in_schema_path, 'r') as f: + schema_class_source = f.read() + gen = tokenize.tokenize(io.BytesIO( + schema_class_source.encode('utf-8')).readline) + + tree = ast.parse(schema_class_source) + + if self.debug: + print("Source code before transformation:") + print("--"*50) + print(schema_class_source) + print("--"*50) + + has_comments = False # Flag later used to perform imports of Annotated and Doc if needed + + for node in tree.body: + if isinstance(node, ast.ClassDef): + for n in node.body: + if isinstance(n, ast.AnnAssign): # Check if the node is an annotated assignment + assgn_comment = None + tokens = _extract_tokens_between_line_numbers( + # Extract tokens between the line numbers of the annotated assignment + gen, n.lineno, n.end_lineno + ) + for toknum, tokval in tokens: + if toknum == tokenize.COMMENT: + # Extract the comment + assgn_comment = tokval + # Remove the '#' character and any leading/trailing whitespaces + assgn_comment = assgn_comment.strip("#").strip() + break + + if assgn_comment: + # If a comment is found, transform the annotation to include the comment + assgn_subscript = n.annotation + has_comments = True + if isinstance(assgn_subscript, ast.Subscript) and (assgn_subscript.value.id == "Required" or assgn_subscript.value.id == "NotRequired"): + # If the annotation is a Required or NotRequired type, add the Annotated and Doc to inner type + n.annotation = ast.Subscript( + value=ast.Name(id=assgn_subscript.value.id, ctx=ast.Load()), + slice=ast.Subscript( + value=ast.Name(id="Annotated", ctx=ast.Load()), + slice=ast.Tuple( + elts=[ + assgn_subscript.slice, + ast.Call( + func=ast.Name( + id="Doc", ctx=ast.Load() + ), + args=[ + ast.Constant( + value=assgn_comment + ) + ], + keywords=[] + ) + ], + ctx=ast.Load() + ), + ctx=ast.Load() + ), + ctx=ast.Load() + ) + else: + n.annotation = ast.Subscript( + value=ast.Name(id="Annotated", ctx=ast.Load()), + slice=ast.Tuple( + elts=[ + assgn_subscript, + ast.Call( + func=ast.Name( + id="Doc", ctx=ast.Load() + ), + args=[ + ast.Constant( + value=assgn_comment + ) + ], + keywords=[] + ) + ], + ctx=ast.Load() + ), + ctx=ast.Load() + ) + + if has_comments: + for node in tree.body: + if isinstance(node, ast.ImportFrom): + if node.module == "typing_extensions": + if ast.alias(name="Annotated") not in node.names: + node.names.append(ast.alias(name="Annotated")) + if ast.alias(name="Doc") not in node.names: + node.names.append(ast.alias(name="Doc")) + + transformed_schema_source = ast.unparse(tree) + + if self.debug: + print("Source code after transformation:") + print("--"*50) + print(transformed_schema_source) + print("--"*50) + + with open(self.out_schema_path, 'w') as f: + f.write(transformed_schema_source) + + +if __name__ == "__main__": + args = parser.parse_args() + handler = PythonCommentHandler(args.in_path, args.out_path, args.debug) + handler._convert_pythonic_comments_to_annotated_docs() +