Skip to content

Commit

Permalink
#30 make test more descriptive when failing
Browse files Browse the repository at this point in the history
  • Loading branch information
kyuridenamida committed Dec 25, 2018
1 parent 78779c5 commit f76307f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
14 changes: 13 additions & 1 deletion tests/test_fmtprediction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import tempfile
import unittest
import os
Expand All @@ -9,6 +10,9 @@
os.path.dirname(os.path.abspath(__file__)),
'./resources/test_fmtprediction/answer.txt')

fmt = "%(asctime)s %(levelname)s: %(message)s"
logging.basicConfig(level=logging.DEBUG, format=fmt)


class TestFormatPrediction(unittest.TestCase):

Expand All @@ -27,6 +31,7 @@ def test_overall(self):
output_text = ""
for case in case_names:
response = runner.run(case)

if response.status == "OK":
output_text += "{:40} {:20} {} {}\n".format(case, response.status, response.simple_format,
response.types)
Expand All @@ -37,7 +42,14 @@ def test_overall(self):
answer = f.read()

for ans, out in zip(answer.split("\n"), output_text.split("\n")):
self.assertEqual(ans, out)
if ans != out:
case_name = ans.split()[0] # case name is expected to be stored to the first column in the file
content = runner.load_problem_content(case_name)
logging.debug("=== {} ===".format(case_name))
logging.debug("Input Format:\n{}".format(content.input_format_text))
for idx, s in enumerate(content.samples):
logging.debug("Sample Input {num}:\n{inp}".format(inp=s.get_input(), num=idx + 1))
self.assertEqual(ans, out)

self.assertEqual(len(answer), len(output_text))

Expand Down
13 changes: 9 additions & 4 deletions tests/utils/fmtprediction_test_runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
from typing import Optional

from atcodertools.fmtprediction.predict_format import FormatPredictor, MultiplePredictionResultsError, NoPredictionResultError
from atcodertools.fmtprediction.predict_format import FormatPredictor, MultiplePredictionResultsError, \
NoPredictionResultError
from atcodertools.models.problem_content import ProblemContent
from atcodertools.models.sample import Sample
from atcodertools.models.predictor.format_prediction_result import FormatPredictionResult
Expand All @@ -28,7 +29,7 @@ def __init__(self, test_dir):
def is_valid_case(self, case_name):
return os.path.isdir(self._get_test_case_dir(case_name))

def run(self, case_name: str) -> Response:
def load_problem_content(self, case_name: str) -> ProblemContent:
case_dir = self._get_test_case_dir(case_name)
format_file = os.path.join(case_dir, FORMAT_FILE_NAME)
example_files = [os.path.join(case_dir, file)
Expand All @@ -41,10 +42,14 @@ def run(self, case_name: str) -> Response:
for ex_file in example_files:
with open(ex_file, 'r') as f:
examples.append(Sample(f.read(), None))
problem_content = ProblemContent(input_format, examples)

return ProblemContent(input_format, examples)

def run(self, case_name: str) -> Response:
content = self.load_problem_content(case_name)

try:
result = FormatPredictor.predict(problem_content)
result = FormatPredictor.predict(content)
return Response(result, "OK")
except MultiplePredictionResultsError:
return Response(None, "Multiple results")
Expand Down

0 comments on commit f76307f

Please sign in to comment.