Skip to content

Commit

Permalink
fix: Use model revision in results folder (#842)
Browse files Browse the repository at this point in the history
* use model revision in results folder
* make lint
* tests (not dataset missing ones) passing
* load specified model revision
* check for model.revision first
* use no_revision_available
* make lint
* add revisions to test dir
* points
* make lint
  • Loading branch information
isaac-chung committed Jun 1, 2024
1 parent 5fa2aee commit 2c6065b
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 10 deletions.
2 changes: 2 additions & 0 deletions docs/mmteb/points/842.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
{"GitHub": "isaac-chung", "Bug fixes": 3}
{"GitHub": "KennethEnevoldsen", "Review PR": 2}
12 changes: 11 additions & 1 deletion mteb/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def main():
help="Display the available tasks",
)

## model revision
parser.add_argument(
"--model_revision",
type=str,
default=None,
help="Revision of the model to be loaded. Revisions are automatically read if the model is loaded from huggingface. ",
)

args = parser.parse_args()

# set logging based on verbosity level
Expand All @@ -142,7 +150,9 @@ def main():
if args.output_folder is None:
args.output_folder = f"results/{_name_to_path(args.model)}"

model = SentenceTransformer(args.model, device=args.device)
model = SentenceTransformer(
args.model, device=args.device, revision=args.model_revision
)

tasks = mteb.get_tasks(
categories=args.categories,
Expand Down
15 changes: 14 additions & 1 deletion mteb/evaluation/MTEB.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,19 @@ def run(

# Create output folder
if output_folder is not None:
pathlib.Path(output_folder).mkdir(parents=True, exist_ok=True)
# check for model revision. For SentenceTransformers, this is available from sentence-transformers==3.0.0.
# if not found, check for supplied model revision.
if hasattr(model, "revision"):
model_revision = model.revision
else:
try:
model_revision = model.model_card_data.base_model_revision
except AttributeError:
logger.warning("Please supply model revision.")
model_revision = "no_revision_available"
pathlib.Path(os.path.join(output_folder, model_revision)).mkdir(
parents=True, exist_ok=True
)

# Run selected tasks
logger.info(f"\n\n## Evaluating {len(self.tasks)} tasks:")
Expand All @@ -316,6 +328,7 @@ def run(
if output_folder is not None:
save_path = os.path.join(
output_folder,
model_revision,
f"{task.metadata_dict['name']}{task.save_suffix}.json",
)
if os.path.exists(save_path) and overwrite_results is False:
Expand Down
1 change: 1 addition & 0 deletions mteb/tasks/Clustering/eng/ArXivHierarchicalClustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools

from datasets import Dataset, DatasetDict

from mteb.abstasks.AbsTaskClusteringFast import AbsTaskClusteringFast
from mteb.abstasks.TaskMetadata import TaskMetadata

Expand Down
1 change: 1 addition & 0 deletions mteb/tasks/PairClassification/multilingual/XStance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from datasets import load_dataset

from mteb.abstasks.TaskMetadata import TaskMetadata

from ....abstasks import MultilingualTask
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"requests>=2.26.0",
"scikit_learn>=1.0.2",
"scipy",
"sentence_transformers>=2.2.0",
"sentence_transformers>=3.0.0",
"torch",
"tqdm",
"rich",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"dataset_revision": "22793b6a6465bf00120ad525e38c51210858132c",
"evaluation_time": 5.2783849239349365,
"kg_co2_emissions": null,
"mteb_version": "1.8.3",
"scores": {
"test": [
{
"hf_subset": "default",
"languages": [
"deu-Latn"
],
"main_score": 0.328715325522886,
"v_measure": 0.328715325522886,
"v_measures": {
"Level 0": [
0.3471582917364298,
0.33187678394488757,
0.3339154791528878,
0.33650398216453725,
0.32595313807570325,
0.32245265936378653,
0.31757950383888045,
0.3359519050422736,
0.3242846547475274,
0.3114768571619466
]
}
}
]
},
"task_name": "BlurbsClusteringS2S.v2"
}
11 changes: 6 additions & 5 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@ def test_available_tasks():
def test_run_task(
model_name: str = "average_word_embeddings_komninos",
task_name="BornholmBitextMining",
model_revision="21eec43590414cb8e3a6f654857abed0483ae36e",
):
command = f"mteb -m {model_name} -t {task_name} --verbosity 3 --output_folder tests/results/test_model"
command = f"mteb -m {model_name} -t {task_name} --verbosity 3 --output_folder tests/results/test_model --model_revision {model_revision}"
result = subprocess.run(command, shell=True, capture_output=True, text=True)
assert result.returncode == 0, "Command failed"

path = Path("tests/results/test_model")
assert path.exists(), "Output folder not created"
json_files = list(path.glob("*.json"))
results_path = Path(f"tests/results/test_model/{model_revision}")
assert results_path.exists(), "Output folder not created"
assert "model_meta.json" in [
f.name for f in json_files
f.name for f in list(path.glob("*.json"))
], "model_meta.json not found in output folder"
assert f"{task_name}.json" in [
f.name for f in json_files
f.name for f in list(results_path.glob("*.json"))
], f"{task_name} not found in output folder"
6 changes: 4 additions & 2 deletions tests/test_mteb_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,12 @@ def test_reranker_same_ndcg1():
)

# read in stage 1 and stage two and check ndcg@1 is the same
with open("tests/results/stage1/SciFact.json") as f:
with open(
"tests/results/stage1/21eec43590414cb8e3a6f654857abed0483ae36e/SciFact.json"
) as f:
stage1 = json.load(f)

with open("tests/results/stage2/SciFact.json") as f:
with open("tests/results/stage2/no_revision_available/SciFact.json") as f:
stage2 = json.load(f)

assert (
Expand Down

0 comments on commit 2c6065b

Please sign in to comment.