diff --git a/prompting/rewards/web_retrieval.py b/prompting/rewards/web_retrieval.py index 6de4242fc..9fbfaff91 100644 --- a/prompting/rewards/web_retrieval.py +++ b/prompting/rewards/web_retrieval.py @@ -3,6 +3,7 @@ import json import os from collections import defaultdict +from functools import lru_cache from urllib.parse import urlparse import numpy as np @@ -83,6 +84,7 @@ class WebsiteResult(BaseModel): class WebRetrievalRewardModel(RelevanceRewardModel): + @lru_cache(maxsize=1000) def _cosine_similarity(self, content1: str, content2: str) -> float: """Calculate the cosine similarity between sentence embeddings of the reference and completions.""" reference_emb_flatten = self.embedding_model.encode(content1, to_numpy=True).flatten()