Skip to content

Commit

Permalink
restructure dp.audit
Browse files Browse the repository at this point in the history
reduces the number of events being emitted

also reduces the amount of memory needed, by removing the list of
duration floats that it was using to compute the average loop duration.

a python float = 24 bytes (in bits => 192 bits)
iters = 1,200,000,000

(float * iters) in KB => 28,800,000 KB
(float * iters) in MB => 28,800 MB
(float * iters) in GB => 28.8 GB

that's right, for a 1.2 billion loop audit, we're using 28 gigabytes of
RAM, just to store the audit loop durations.

This should reduce that overhead to, oh, about 24 bytes.
  • Loading branch information
hawkrives committed Jan 4, 2020
1 parent 211a002 commit da4c35c
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 122 deletions.
11 changes: 5 additions & 6 deletions dp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,8 @@ def main() -> int: # noqa: C901
first_progress_message = False

if not cli_args.quiet or (cli_args.tracemalloc_init or cli_args.tracemalloc_each):
avg_iter_s = sum(msg.recent_iters) / max(len(msg.recent_iters), 1)
avg_iter_time = pretty_ms(avg_iter_s * 1_000, format_sub_ms=True)
print(f"{msg.count:,} at {avg_iter_time} per audit (best: {msg.best_rank})", file=sys.stderr)
avg_iter_time = pretty_ms(msg.avg_iter_ms, format_sub_ms=True)
print(f"{msg.iters:,} at {avg_iter_time} per audit (best: {msg.best_rank})", file=sys.stderr)

elif isinstance(msg, ResultMsg):
if not cli_args.quiet:
Expand Down Expand Up @@ -158,9 +157,9 @@ def result_str(
return "\n" + "".join(summarize(
result=dict_result,
transcript=msg.transcript,
count=msg.count,
elapsed=msg.elapsed,
iterations=msg.iterations,
count=msg.iters,
avg_iter_ms=msg.avg_iter_ms,
elapsed=pretty_ms(msg.elapsed_ms),
show_paths=show_paths,
show_ranks=show_ranks,
claims=msg.result.keyed_claims(),
Expand Down
114 changes: 34 additions & 80 deletions dp/audit.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import attr
from typing import List, Optional, Tuple, Sequence, Iterator, Union, Dict, Any
from datetime import datetime
from typing import List, Optional, Tuple, Sequence, Iterator, Union, Dict
from decimal import Decimal
import time

from .constants import Constants
from .exception import RuleException
from .area import AreaOfStudy, AreaResult
from .ms import pretty_ms
from .data import CourseInstance, AreaPointer, MusicAttendance, MusicPerformance, MusicProficiencies


Expand All @@ -22,37 +20,13 @@ class Arguments:
progress_every: int = 1_000


@attr.s(slots=True, kw_only=True, auto_attribs=True)
class NoStudentsMsg:
pass


@attr.s(slots=True, kw_only=True, auto_attribs=True)
class AuditStartMsg:
stnum: str
area_code: str
area_catalog: str
student: Dict[str, Any]


@attr.s(slots=True, kw_only=True, auto_attribs=True)
class ResultMsg:
result: AreaResult
transcript: Tuple[CourseInstance, ...]
count: int
elapsed: str
iters: int
avg_iter_ms: float
elapsed_ms: float
iterations: List[float]
startup_time: float
potentials_for_all_clauses: Dict[int, List[str]]


@attr.s(slots=True, kw_only=True, auto_attribs=True)
class ExceptionMsg:
ex: Exception
tb: str
stnum: Optional[str]
area_code: Optional[str]


@attr.s(slots=True, kw_only=True, auto_attribs=True)
Expand All @@ -67,21 +41,14 @@ class EstimateMsg:

@attr.s(slots=True, kw_only=True, auto_attribs=True)
class ProgressMsg:
count: int
recent_iters: List[float]
start_time: datetime
best_rank: Union[int, Decimal]


@attr.s(slots=True, kw_only=True, auto_attribs=True)
class AreaFileNotFoundMsg:
area_file: str
stnums: Sequence[str]
iters: int
avg_iter_ms: float
elapsed_ms: float


Message = Union[
EstimateMsg,
ExceptionMsg,
NoAuditsCompletedMsg,
ProgressMsg,
ResultMsg,
Expand All @@ -100,15 +67,11 @@ def audit(
music_proficiencies: MusicProficiencies = MusicProficiencies(),
transcript: Tuple[CourseInstance, ...] = tuple(),
transcript_with_failed: Tuple[CourseInstance, ...] = tuple(),
) -> Iterator[Message]: # noqa: C901
) -> Iterator[Message]:
best_sol: Optional[AreaResult] = None
best_rank: Union[int, Decimal] = 0
total_count = 0
iterations: List[float] = []
start_time = datetime.now()
start = time.perf_counter()
iter_start = time.perf_counter()
startup_time = 0.00
total_count = 0

estimate = area.estimate(
transcript=transcript,
Expand All @@ -124,8 +87,6 @@ def audit(
if args.estimate_only:
return

potentials_for_all_clauses = find_potentials(area, constants)

for sol in area.solutions(
transcript=transcript,
areas=tuple(area_pointers),
Expand All @@ -135,73 +96,66 @@ def audit(
exceptions=list(exceptions),
transcript_with_failed=transcript_with_failed,
):
if total_count == 0:
iter_start = time.perf_counter()
startup_time = time.perf_counter() - iter_start

total_count += 1

result = sol.audit()
result_rank = result.rank()

if total_count % args.progress_every == 0:
elapsed_ms = ms_since(start)
yield ProgressMsg(
count=total_count,
recent_iters=iterations[-args.progress_every:],
start_time=start_time,
best_rank=best_sol.rank() if best_sol else 0,
best_rank=best_rank,
iters=total_count,
avg_iter_ms=elapsed_ms / total_count,
elapsed_ms=elapsed_ms,
)

result = sol.audit()
result_rank = result.rank()

if args.print_all:
elapsed_ms = ms_since(start)
yield ResultMsg(
result=result,
transcript=transcript,
count=total_count,
elapsed='∞',
elapsed_ms=0,
iterations=[],
startup_time=startup_time,
potentials_for_all_clauses=potentials_for_all_clauses,
iters=total_count,
avg_iter_ms=elapsed_ms / total_count,
elapsed_ms=elapsed_ms,
)

# if this is the first solution, store it, because it's the best so far
if best_sol is None:
best_sol, best_rank = result, result_rank
elif result_rank > best_rank:

# if the current solution is better, then store it
if result_rank > best_rank:
best_sol, best_rank = result, result_rank

# if the current solution is OK, then store it, and end the loop
if result.ok():
best_sol, best_rank = result, result_rank
iter_end = time.perf_counter()
iterations.append(iter_end - iter_start)
break

iter_end = time.perf_counter()
iterations.append(iter_end - iter_start)
iter_start = time.perf_counter()

if args.stop_after is not None and total_count >= args.stop_after:
break

if not best_sol:
yield NoAuditsCompletedMsg()
return

end = time.perf_counter()
elapsed_ms = (end - start) * 1000
elapsed = pretty_ms(elapsed_ms)

elapsed_ms = ms_since(start)
yield ResultMsg(
result=best_sol,
transcript=transcript,
count=total_count,
elapsed=elapsed,
iters=total_count,
avg_iter_ms=elapsed_ms / total_count,
elapsed_ms=elapsed_ms,
iterations=iterations,
startup_time=startup_time,
potentials_for_all_clauses=potentials_for_all_clauses,
)


def ms_since(start: float, *, now: Optional[float] = None) -> float:
    """Return the number of milliseconds elapsed since *start*.

    Args:
        start: A reference timestamp in seconds, on the same clock as
            ``time.perf_counter()``.
        now: Optional end timestamp in seconds. When omitted (``None``),
            ``time.perf_counter()`` is sampled at call time.

    Returns:
        ``(now - start)`` converted from seconds to milliseconds.
    """
    # Compare against None explicitly: 0.0 is a legitimate timestamp and
    # must not trigger the fallback to the live clock.
    if now is None:
        now = time.perf_counter()
    return (now - start) * 1000


def find_potentials(area: AreaOfStudy, constants: Constants) -> Dict[int, List[str]]:
    """Return the potential matches for the clauses of *area*.

    This is currently a stub: both arguments are ignored and an empty
    mapping is always returned.
    """
    return {}

Expand Down
14 changes: 6 additions & 8 deletions dp/bin/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,13 @@ def main() -> int: # noqa: C901

elif isinstance(msg, ProgressMsg):
if not cli_args.quiet:
avg_iter_s = sum(msg.recent_iters) / max(len(msg.recent_iters), 1)
avg_iter_time = pretty_ms(avg_iter_s * 1_000, format_sub_ms=True)
print(f"{msg.count:,} at {avg_iter_time} per audit (best: {msg.best_rank})", file=sys.stderr)
avg_iter_time = pretty_ms(msg.avg_iter_ms, format_sub_ms=True)
print(f"{msg.iters:,} at {avg_iter_time} per audit (best: {msg.best_rank})", file=sys.stderr)

elif isinstance(msg, ResultMsg):
result = json.loads(json.dumps(msg.result.to_dict()))
if cli_args.table:
avg_iter_s = sum(msg.iterations) / max(len(msg.iterations), 1)
avg_iter_time = pretty_ms(avg_iter_s * 1_000, format_sub_ms=True)
avg_iter_time = pretty_ms(msg.avg_iter_ms, format_sub_ms=True)
print(','.join([
stnum,
catalog,
Expand All @@ -89,9 +87,9 @@ def main() -> int: # noqa: C901
print("\n" + "".join(summarize(
result=result,
transcript=msg.transcript,
count=msg.count,
elapsed=msg.elapsed,
iterations=msg.iterations,
count=msg.iters,
avg_iter_ms=msg.avg_iter_ms,
elapsed=pretty_ms(msg.elapsed_ms),
show_paths=cli_args.show_paths,
show_ranks=cli_args.show_ranks,
claims=msg.result.keyed_claims(),
Expand Down
28 changes: 7 additions & 21 deletions dp/server/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,15 @@ def audit(*, area_spec: Dict, area_code: str, area_catalog: str, student: Dict,
pass

elif isinstance(msg, ProgressMsg):
avg_iter_s = sum(msg.recent_iters) / max(len(msg.recent_iters), 1)
avg_iter_time = pretty_ms(avg_iter_s * 1_000, format_sub_ms=True, unit_count=1)
avg_iter_time = pretty_ms(msg.avg_iter_ms, format_sub_ms=True)

curs.execute("""
UPDATE result
SET iterations = %(count)s, duration = cast(now() - %(start_time)s as interval)
SET iterations = %(count)s, duration = interval %(elapsed)s
WHERE id = %(result_id)s
""", {"result_id": result_id, "count": msg.count, "start_time": msg.start_time})
""", {"result_id": result_id, "count": msg.iters, "elapsed": f"{msg.elapsed_ms}ms"})

logger.info(f"{msg.count:,} at {avg_iter_time} per audit")
logger.info(f"{msg.iters:,} at {avg_iter_time} per audit")

elif isinstance(msg, ResultMsg):
record(curs=curs, result_id=result_id, message=msg)
Expand All @@ -80,9 +79,6 @@ def audit(*, area_spec: Dict, area_code: str, area_catalog: str, student: Dict,
def record(*, message: ResultMsg, curs: psycopg2.extensions.cursor, result_id: int) -> None:
result = message.result.to_dict()

avg_iter_s = sum(message.iterations) / max(len(message.iterations), 1)
avg_iter_time = pretty_ms(avg_iter_s * 1_000, format_sub_ms=True, unit_count=1)

curs.execute("""
UPDATE result
SET iterations = %(total_count)s
Expand All @@ -99,23 +95,13 @@ def record(*, message: ResultMsg, curs: psycopg2.extensions.cursor, result_id: i
WHERE id = %(result_id)s
""", {
"result_id": result_id,
"total_count": message.count,
"elapsed": message.elapsed,
"avg_iter_time": avg_iter_time.strip("~"),
"total_count": message.iters,
"elapsed": f"{message.elapsed_ms}ms",
"avg_iter_time": f"{message.avg_iter_ms}ms",
"result": json.dumps(result),
"claimed_courses": json.dumps(message.result.keyed_claims()),
"rank": result["rank"],
"max_rank": result["max_rank"],
"gpa": result["gpa"],
"ok": result["ok"],
})

for clause_hash, clbids in message.potentials_for_all_clauses.items():
curs.execute("""
INSERT INTO potential_clbids (result_id, clause_hash, clbids)
VALUES (%(result_id)s, %(clause_hash)s, %(clbids)s)
""", {
"result_id": result_id,
"clause_hash": clause_hash,
"clbids": clbids,
})
5 changes: 2 additions & 3 deletions dp/stringify.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@ def summarize(
result: Dict[str, Any],
count: int,
elapsed: str,
iterations: List[float],
avg_iter_ms: float,
show_paths: bool = True,
show_ranks: bool = True,
claims: Dict[str, List[List[str]]],
) -> Iterator[str]:
avg_iter_s = sum(iterations) / max(len(iterations), 1)
avg_iter_time = pretty_ms(avg_iter_s * 1_000, format_sub_ms=True)
avg_iter_time = pretty_ms(avg_iter_ms, format_sub_ms=True)
mapped_transcript = {c.clbid: c for c in transcript}
endl = "\n"

Expand Down
6 changes: 2 additions & 4 deletions dp/testbed/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .sqlite import sqlite_connect, sqlite_cursor

from dp.run import run
from dp.audit import ResultMsg, Arguments, EstimateMsg, AuditStartMsg
from dp.audit import ResultMsg, Arguments, EstimateMsg


def audit(
Expand Down Expand Up @@ -65,7 +65,7 @@ def audit(
"stnum": stnum,
"catalog": catalog,
"code": code,
"iterations": message.count,
"iterations": message.iters,
"duration": message.elapsed_ms / 1000,
"gpa": result["gpa"],
"ok": result["ok"],
Expand Down Expand Up @@ -111,8 +111,6 @@ def estimate(
for message in run(args=Arguments(estimate_only=True), student=student, area_spec=area_spec):
if isinstance(message, EstimateMsg):
return message.estimate
elif isinstance(message, AuditStartMsg):
pass
else:
assert False, type(message)

Expand Down

0 comments on commit da4c35c

Please sign in to comment.