Merge pull request #178 from SinclairHudson/json-test
Adding JSON validity test
benjaminye committed Jun 3, 2024
2 parents efaa6e9 + 665fd29 commit bf691d8
Showing 3 changed files with 41 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
@@ -3,6 +3,7 @@
# experiment files
*/experiments
*/experiment
experiment/*
*/archive
*/backup
*/baseline_results
@@ -49,4 +50,4 @@ venv.bak/

# Coverage Report
.coverage
/htmlcov
/htmlcov
20 changes: 20 additions & 0 deletions llmtune/qa/qa_tests.py
@@ -3,6 +3,7 @@
import nltk
import numpy as np
import torch
from langchain.evaluation import JsonValidityEvaluator
from nltk import pos_tag
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
@@ -12,6 +13,7 @@
from llmtune.qa.generics import LLMQaTest


json_validity_evaluator = JsonValidityEvaluator()
model_name = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertModel.from_pretrained(model_name)
@@ -120,6 +122,24 @@ def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> U
        return float(overlap_percentage)


@QaTestRegistry.register("json_valid")
class JSONValidityTest(LLMQaTest):
    """
    Checks whether valid JSON can be parsed from the model output, according
    to langchain_core.utils.json.parse_json_markdown.
    The JSON may be wrapped in a Markdown code block and the test will still pass.
    """

    @property
    def test_name(self) -> str:
        return "json_valid"

    def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float:
        result = json_validity_evaluator.evaluate_strings(prediction=model_prediction)
        binary_res = result["score"]
        return float(binary_res)


class PosCompositionTest(LLMQaTest):
    def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
        words = word_tokenize(text)
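For reference, a minimal usage sketch of the new test (an illustration, not part of the change; it assumes the module's dependencies such as langchain, nltk, and transformers are installed). The expected scores mirror the parametrized cases in the test file below:

from llmtune.qa.qa_tests import JSONValidityTest

test = JSONValidityTest()
# Well-formed JSON scores 1.0, even when wrapped in a Markdown fence
test.get_metric("prompt", "The cat", '{"Answer": "The cat"}')  # 1.0
test.get_metric("prompt", "The cat", '```json\n{"Answer": "The cat"}\n```')  # 1.0
# Single-quoted keys are not valid JSON, so this scores 0.0
test.get_metric("prompt", "The cat", "{'Answer': 'The cat'}")  # 0.0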
19 changes: 19 additions & 0 deletions tests/qa/test_qa_tests.py
@@ -4,6 +4,7 @@
    AdjectivePercent,
    DotProductSimilarityTest,
    JaccardSimilarityTest,
    JSONValidityTest,
    LengthTest,
    NounPercent,
    RougeScoreTest,
@@ -23,6 +24,7 @@
        (VerbPercent, float),
        (AdjectivePercent, float),
        (NounPercent, float),
        (JSONValidityTest, float),
    ],
)
def test_metric_return_type(test_class, expected_type):
@@ -84,3 +86,20 @@ def test_noun_percent():
    test = NounPercent()
    result = test.get_metric("prompt", "The cat", "The cat and the dog")
    assert result >= 0, "Noun percentage should be non-negative."


@pytest.mark.parametrize(
    "input_string,expected_value",
    [
        ('{"Answer": "The cat"}', 1),
        ("{'Answer': 'The cat'}", 0),  # double quotes are required in JSON
        ('{"Answer": "The cat",}', 0),  # trailing commas are not valid JSON
        ('{"Answer": "The cat", "test": "case"}', 1),
        ('```json\n{"Answer": "The cat"}\n```', 1),  # a Markdown-fenced JSON block can still be parsed
        ('Here is an example of a JSON block: {"Answer": "The cat"}', 0),
    ],
)
def test_json_valid(input_string: str, expected_value: float):
    test = JSONValidityTest()
    result = test.get_metric("prompt", "The cat", input_string)
    assert result == expected_value, f"JSON validity should be {expected_value} but got {result}."
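As a usage note, the new parametrized cases can be run on their own with pytest's keyword filter, e.g. pytest tests/qa/test_qa_tests.py -k json_valid (assuming a standard pytest setup).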
