In [15]:

class DebuggingTask:
    """Reward configuration for the debugging task.

    Each entry in ``reward_models`` names a reward model (``'name'``); the
    remaining keys are passed to that model's constructor as keyword arguments.
    """

    reward_models = [
        {'name': 'diff', 'lines': False, 'threshold': 0.5},
    ]
    
class SummarizationTask:
    """Reward configuration for the summarization task.

    Each entry in ``reward_models`` names a reward model (``'name'``); the
    remaining keys are passed to that model's constructor as keyword arguments.
    """

    reward_models = [
        {'name': 'rouge', 'ngram': 'l', 'metric': 'f'},
        {'name': 'relevance', 'threshold': None},
    ]

class QuestionAnsweringTask:
    """Reward configuration for the question-answering task.

    Each entry in ``reward_models`` names a reward model (``'name'``); the
    remaining keys are passed to that model's constructor as keyword arguments.
    """

    # NOTE(review): ngram is '1' here while the sibling tasks use 'l' --
    # confirm this difference is intentional and not a typo.
    reward_models = [
        {'name': 'rouge', 'ngram': '1', 'metric': 'f'},
        {'name': 'relevance', 'threshold': None},
    ]

class MathTask:
    """Reward configuration for the math task: one 'rouge' model (ngram='l', metric='f').

    Each entry in ``reward_models`` names a reward model (``'name'``); the
    remaining keys are passed to that model's constructor as keyword arguments.
    """

    reward_models = [
        {'name': 'rouge', 'ngram': 'l', 'metric': 'f'},
    ]
    
class DateQuestionAnsweringTask:
    """Reward configuration for the date question-answering task.

    Uses one 'rouge' model (ngram='l', metric='f'). Each entry in
    ``reward_models`` names a reward model (``'name'``); the remaining keys are
    passed to that model's constructor as keyword arguments.
    """

    reward_models = [
        {'name': 'rouge', 'ngram': 'l', 'metric': 'f'},
    ]


# Reward-model configs gathered from every selected task, in selection order.
# Duplicates are kept: two tasks that share a config each contribute an entry.
required_reward_models = []

# selected tasks can come from wandb config, or other testing suites so that we can enable them as required
selected_tasks = ['math', 'date_qa','debugging']

# Registry of supported task names -> task classes.
all_tasks = {
    'debugging': DebuggingTask, 
    'summarization': SummarizationTask, 
    'qa': QuestionAnsweringTask, 
    'math': MathTask, 
    'date_qa': DateQuestionAnsweringTask
}

for task in selected_tasks:
    if task not in all_tasks:
        # List only the task names; interpolating the dict itself would also dump class reprs.
        raise ValueError(f'Task {task} not supported. Please choose from {list(all_tasks)}')
    
    print(all_tasks[task])
    # Copy each config dict so that later in-place edits of the collected
    # configs (e.g. popping 'name') cannot alter the class-level definitions.
    required_reward_models += [dict(config) for config in all_tasks[task].reward_models]

required_reward_models


<class '__main__.MathTask'>
<class '__main__.DateQuestionAnsweringTask'>
<class '__main__.DebuggingTask'>


[{'name': 'rouge', 'ngram': 'l', 'metric': 'f'},
 {'name': 'rouge', 'ngram': 'l', 'metric': 'f'},
 {'name': 'diff', 'lines': False, 'threshold': 0.5}]

In [16]:
import torch
import difflib
from typing import List
from angle_emb import AnglE
from torch.nn.functional import cosine_similarity
from rouge import Rouge

class RougeRewardModel:
    """Scores completions against a reference text using ROUGE.

    Args:
        ngram: ROUGE variant to read from the result, e.g. '1', '2' or 'l'.
            A full key such as 'rouge-l' is accepted as well.
        metric: Which statistic of that variant to return: 'f', 'p' or 'r'.
        avg: Forwarded to ``Rouge.get_scores``.
    """

    def __init__(self, ngram='l', metric='f', avg=False):
        self.ngram = ngram
        self.metric = metric
        self.avg = avg
        # TODO: Add init args to Rouge if required
        self.rouge = Rouge()

    def rouge_score(self, reference, completion):
        """Return the configured ROUGE statistic for one completion.

        Returns 0.0 when either text is empty or whitespace-only, since there
        is nothing to overlap (the ``rouge`` package is expected to raise on an
        empty hypothesis -- see its documentation).
        """
        if not (reference and reference.strip() and completion and completion.strip()):
            return 0.0

        # The rouge package takes the hypothesis first and the reference second;
        # swapping them exchanges precision and recall.
        scores = self.rouge.get_scores(completion, reference, avg=self.avg)

        # avg=False yields a list with one dict per pair, avg=True a single dict.
        if isinstance(scores, list):
            scores = scores[0]

        # Results are keyed by variant first ('rouge-l'), then by statistic ('f').
        ngram = str(self.ngram)
        key = ngram if ngram.startswith('rouge-') else f'rouge-{ngram}'
        return scores[key][self.metric]

    def reward(self, reference: str, completions: List[str]) -> torch.FloatTensor:
        """Compute ROUGE scores given a completion and reference pair.

        Returns a 1-D float tensor with one score per completion, in order.
        """
        return torch.FloatTensor([self.rouge_score(reference, completion) for completion in completions])


class RelevanceRewardModel:
    """Scores completions by embedding similarity to a reference text.

    Embeddings come from the 'WhereIsAI/UAE-Large-V1' AnglE model loaded with
    the 'cls' pooling strategy. ``threshold`` is stored on the instance but is
    not used anywhere in this class.
    """

    def __init__(self, threshold=None, device='cuda'):
        self.threshold = threshold
        encoder = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls')
        # Only the exact string 'cuda' moves the model to the GPU.
        self.model = encoder.cuda() if device == 'cuda' else encoder

    def reward(self, reference: str, completions: List[str]) -> torch.FloatTensor:
        """Return the cosine similarity between the reference and each completion.

        The reference and the completions are encoded separately and compared
        along dim 1.
        """
        # assumes encode() returns 2-D tensors with the embedding axis at dim 1 -- TODO confirm
        encode = self.model.encode
        reference_vector = encode(reference, to_numpy=False)
        completion_vectors = encode(completions, to_numpy=False)
        return cosine_similarity(reference_vector, completion_vectors, dim=1)


class DiffRewardModel:
    """Scores completions against a reference using ``difflib``.

    Args:
        lines: If True, score by the number of lines in the unified diff;
            if False, score by the ``SequenceMatcher`` similarity ratio.
        threshold: Stored on the instance; not applied anywhere in this class.
    """

    def __init__(self, lines=False, threshold=None):
        self.lines = lines
        self.threshold = threshold

    def unified_diff(self, reference, completion):
        """Return the number of lines in the unified diff of the two texts.

        Identical texts produce an empty diff and therefore 0; note that this
        count grows as the texts diverge.
        """
        diff = difflib.unified_diff(reference.splitlines(), completion.splitlines())
        # unified_diff returns a lazy generator, so it has to be counted by
        # iterating; calling len() on it raises TypeError.
        return sum(1 for _ in diff)

    def seq_match(self, reference, completion):
        """Return the character-level ``SequenceMatcher`` ratio of the two texts."""
        return difflib.SequenceMatcher(None, reference, completion).ratio()

    def reward(self, reference: str, completions: List[str]) -> torch.FloatTensor:
        """Get the score between two strings.
        lines: If True, return a unified diff. If False, return a ratio.

        Returns a 1-D float tensor with one score per completion, in order.
        """
        score = self.unified_diff if self.lines else self.seq_match
        return torch.FloatTensor([score(reference, completion) for completion in completions])


# Registry of reward-model names -> reward-model classes.
all_reward_models = {
    'rouge': RougeRewardModel,
    'relevance': RelevanceRewardModel,
    'diff': DiffRewardModel,
}

# Instantiate only the required reward models
reward_models = {}
_instantiated_kwargs = {}
for model in required_reward_models:
    # Read the config without mutating it: popping 'name' in place would alter
    # the shared config dicts and make this cell fail with KeyError on re-run.
    name = model['name']
    kwargs = {key: value for key, value in model.items() if key != 'name'}

    # Several tasks may request the same model with the same settings; build it once.
    if name in reward_models and _instantiated_kwargs[name] == kwargs:
        continue

    # Instances are keyed by name only, so a later config with different
    # settings for the same name replaces the earlier instance.
    reward_models[name] = all_reward_models[name](**kwargs)
    _instantiated_kwargs[name] = kwargs


reward_models

{'rouge': <__main__.RougeRewardModel at 0x7fd557247a60>,
 'diff': <__main__.DiffRewardModel at 0x7fd557247d60>}