diff --git a/prompting/rewards/ordinal.py b/prompting/rewards/ordinal.py index d8cfd433b..ac8ac4e45 100644 --- a/prompting/rewards/ordinal.py +++ b/prompting/rewards/ordinal.py @@ -12,17 +12,11 @@ def name(self) -> str: def __init__(self, **kwargs): super().__init__() #TODO: Expand to allow for more than 3 classes (Must also adjust dataset/review.py) - self.sentiments = [ - "casual", - "basic", - "silly", - "random", - "thoughtful", - "serious", - "rushed", + self.sentiments = [ + "positive", + "neutral", + "negative", ] - #NOTE: These sentimens are not the same as the sentiments defined in the dataset/review.py file. These are the subtopic - def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput: """Compute difference scores given a completion and reference pair.""" @@ -34,7 +28,8 @@ def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput: # Check if exactly one answer can be found in the completion if sum(option in completion for option in classes) == 1: - reward = abs(classes.index(reference) - classes.index(completion)) + answer = [option for option in classes if option in completion][0] + reward = abs(classes.index(reference) - classes.index(answer)) else: reward = 0 timings.append(time.time() - t0) @@ -47,4 +42,4 @@ def reward(self, reference: str, completions: List[str]) -> BatchRewardOutput: "type": "math", }, ) - return output \ No newline at end of file + return output