diff --git a/prompting/tasks/sentiment.py b/prompting/tasks/sentiment.py index 47b8b1e73..7b4b58945 100644 --- a/prompting/tasks/sentiment.py +++ b/prompting/tasks/sentiment.py @@ -25,7 +25,6 @@ class SentimentAnalysisTask(Task): def __init__(self, llm_pipeline, context, create_reference=True): self.context = context - self.query_prompt = QUERY_PROMPT_TEMPLATE.format(context=context.content) self.query = self.generate_query(llm_pipeline) self.reference = context.subtopic diff --git a/prompting/tools/datasets/review.py b/prompting/tools/datasets/review.py index 5c9300d91..41e512971 100644 --- a/prompting/tools/datasets/review.py +++ b/prompting/tools/datasets/review.py @@ -12,7 +12,7 @@ class ReviewDataset(TemplateDataset): SENTIMENTS = ["positive", "neutral", "negative"] # TODO: filter nonsense combinations of params # "Create a {topic} review of a {title} in the style of {mood} person in a {subtopic} tone. The review must be of {sentiment} sentiment." - query_template = "Create a {topic} review of a {title}. The review must be of {sentiment} sentiment." + query_template = "Create a {topic} review of a {title}. The review must be of {subtopic} sentiment." params = dict( topic=[ "short", @@ -44,5 +44,6 @@ class ReviewDataset(TemplateDataset): "company", "live event", ], - sentiment=SENTIMENTS, - ) \ No newline at end of file + subtopic=SENTIMENTS, + ) + \ No newline at end of file