Skip to content

Commit

Permalink
Merge branch 'tweaks'
Browse files Browse the repository at this point in the history
  • Loading branch information
gwax committed Jul 3, 2016
2 parents 42a36c8 + 2a6831f commit a3b68f7
Show file tree
Hide file tree
Showing 14 changed files with 126 additions and 147 deletions.
20 changes: 11 additions & 9 deletions mtg_ssm/mtg/counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ def find_printing(cdb, set_code, name, set_number, multiverseid, strict=True):
])
for snnm_key in snnm_keys:
found_printings = cdb.set_name_num_mv_to_printings.get(snnm_key, [])
if len(found_printings) == 1 or found_printings and not strict:
if len(found_printings) == 1:
return found_printings[0]
elif found_printings and not strict:
return sorted(found_printings, key=lambda p: p.id_)[0]

return None

Expand All @@ -84,7 +86,7 @@ def coerce_card_row(card_count):
return card_count


def aggregate_print_counts(cdb, card_rows, strict=True):
def aggregate_print_counts(cdb, card_rows, strict):
"""Given a card database Iterable[card_row], return print_counts"""
print_counts = collections.defaultdict(
lambda: collections.defaultdict(int))
Expand All @@ -109,28 +111,28 @@ def aggregate_print_counts(cdb, card_rows, strict=True):
ct_name = count_type.name
count = card_row.get(ct_name)
if count:
print_counts[printing][count_type] += count
print_counts[printing.id_][count_type] += count
return print_counts


def merge_print_counts(*print_counts_args):
"""Merge two sets of print_counts."""
print_counts = new_print_counts()
for in_print_counts in print_counts_args:
for printing, counts in in_print_counts.items():
for print_id, counts in in_print_counts.items():
for key, value in counts.items():
print_counts[printing][key] += value
print_counts[print_id][key] += value
return print_counts


def diff_print_counts(left, right):
"""Subtract right print counts from left print counts."""
print_counts = new_print_counts()
for printing in left.keys() | right.keys():
left_counts = left.get(printing, {})
right_counts = right.get(printing, {})
for print_id in left.keys() | right.keys():
left_counts = left.get(print_id, {})
right_counts = right.get(print_id, {})
for key in left_counts.keys() | right_counts.keys():
value = left_counts.get(key, 0) - right_counts.get(key, 0)
if value:
print_counts[printing][key] = value
print_counts[print_id][key] = value
return print_counts
7 changes: 4 additions & 3 deletions mtg_ssm/mtg/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import datetime as dt
import string
import weakref

VARIANT_CHARS = (string.ascii_letters + '★')
STRICT_BASICS = {'Plains', 'Island', 'Swamp', 'Mountain', 'Forest'}
Expand All @@ -12,7 +13,7 @@ class Card:
__slots__ = ('cdb', 'name', 'layout', 'names')

def __init__(self, card_db, card_data):
self.cdb = card_db
self.cdb = weakref.proxy(card_db)
self.name = card_data['name']
self.layout = card_data['layout']
self.names = card_data.get('names', [self.name])
Expand Down Expand Up @@ -41,7 +42,7 @@ class CardPrinting:
'counts')

def __init__(self, card_db, set_code, card_data):
self.cdb = card_db
self.cdb = weakref.proxy(card_db)
self.id_ = card_data['id']
self.card_name = card_data['name']
self.set_code = set_code
Expand Down Expand Up @@ -85,7 +86,7 @@ class CardSet:
'type_', 'online_only')

def __init__(self, card_db, set_data):
self.cdb = card_db
self.cdb = weakref.proxy(card_db)
self.code = set_data['code']
self.name = set_data['name']
self.block = set_data.get('block')
Expand Down
4 changes: 2 additions & 2 deletions mtg_ssm/serialization/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def rows_for_printings(cdb, print_counts, verbose):
"""Generator that yields csv rows from a card_db."""
for card_set in cdb.card_sets:
for printing in card_set.printings:
printing_counts = print_counts.get(printing, {})
printing_counts = print_counts.get(printing.id_, {})
if verbose or any(printing_counts):
yield row_for_printing(printing, printing_counts)

Expand All @@ -58,7 +58,7 @@ def read(self, filename: str):
"""Read print counts from file."""
with open(filename, 'r') as csv_file:
return counts.aggregate_print_counts(
self.cdb, csv.DictReader(csv_file))
self.cdb, csv.DictReader(csv_file), strict=True)


class CsvTerseDialect(CsvFullDialect):
Expand Down
2 changes: 1 addition & 1 deletion mtg_ssm/serialization/deckbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def rows_for_printing(printing, print_counts):
}
if name is not None:
row_base['Name'] = name
row_counts = print_counts.get(printing, {})
row_counts = print_counts.get(printing.id_, {})
copies = row_counts.get(counts.CountTypes.copies, 0)
foils = row_counts.get(counts.CountTypes.foils, 0)
if copies:
Expand Down
3 changes: 1 addition & 2 deletions mtg_ssm/serialization/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from mtg_ssm.mtg import card_db
from mtg_ssm.mtg import counts
from mtg_ssm.mtg import models


class Error(Exception):
Expand Down Expand Up @@ -50,7 +49,7 @@ def write(self, filename: str, print_counts) -> None:

@abc.abstractmethod
def read(self, filename: str) -> Dict[
models.CardPrinting, Dict[counts.CountTypes, int]]:
str, Dict[counts.CountTypes, int]]:
"""Read print counts from file."""

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions mtg_ssm/serialization/xlsx.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def create_set_sheet(sheet, card_set, print_counts):
printing.set_number,
printing.artist,
]
row_counts = print_counts.get(printing, {})
row_counts = print_counts.get(printing.id_, {})
for counttype in counts.CountTypes:
row.append(row_counts.get(counttype))
row.append(get_references(printing.card, exclude_sets={card_set}))
Expand Down Expand Up @@ -235,5 +235,5 @@ def read(self, filename: str):
'No known set with code {}'.format(sheet.title))
print_counts = counts.merge_print_counts(
print_counts, counts.aggregate_print_counts(
self.cdb, counts_from_sheet(sheet)))
self.cdb, counts_from_sheet(sheet), strict=True))
return print_counts
12 changes: 10 additions & 2 deletions mtg_ssm/ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,22 @@ def get_serializer(cdb, dialect_mapping, filename):
return serialization_class(cdb)


def get_backup_name(filename):
"""Given a filename, return a timestamped backup name for the file."""
basename, extension = os.path.splitext(filename)
extension = extension.lstrip('.')
now = datetime.datetime.now()
return '{basename}.{now:%Y%m%d_%H%M%S}.{extension}'.format(
basename=basename, now=now, extension=extension)


def write_file(serializer, print_counts, filename):
"""Write print counts to a file, backing up existing target files."""
if not os.path.exists(filename):
print('Writing collection to file.')
serializer.write(filename, print_counts)
else:
backup_name = filename + '.bak-{:%Y%m%d_%H%M%S}'.format(
datetime.datetime.now())
backup_name = get_backup_name(filename)
with tempfile.NamedTemporaryFile() as temp_coll:
print('Writing to temporary file.')
serializer.write(temp_coll.name, print_counts)
Expand Down
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
xfail_strict=true
108 changes: 39 additions & 69 deletions tests/mtg/test_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,78 +29,48 @@ def test_coerce_card_row(raw_card_row, expected_card_row):
assert counts.coerce_card_row(raw_card_row) == expected_card_row


# aggregate_print_counts tests
def test_apc_bad_printing(cdb):
card_counts = [{}]
with pytest.raises(counts.UnknownPrintingError):
counts.aggregate_print_counts(cdb, card_counts)


def test_apc_no_counts(cdb):
card_counts = [{'id': TEST_PRINT_ID}]
print_counts = counts.aggregate_print_counts(cdb, card_counts)
assert not print_counts


def test_apc_zeros(cdb):
card_counts = [{'id': TEST_PRINT_ID, 'copies': 0, 'foils': 0}]
print_counts = counts.aggregate_print_counts(cdb, card_counts)
assert not print_counts


def test_apc_once(cdb):
print_ = cdb.id_to_printing[TEST_PRINT_ID]
card_counts = [{'id': TEST_PRINT_ID, 'copies': 1, 'foils': 2}]
print_counts = counts.aggregate_print_counts(cdb, card_counts)
assert print_counts == {
print_: {
@pytest.mark.parametrize('card_rows,strict,expected', [
([], True, {}),
pytest.mark.xfail(
([{}], True, 'N/A'),
raises=counts.UnknownPrintingError),
([{'id': TEST_PRINT_ID}], True, {}),
([{'id': TEST_PRINT_ID, 'copies': 0, 'foils': 0}], True, {}),
([{'id': TEST_PRINT_ID, 'copies': 1, 'foils': 2}], True, {
TEST_PRINT_ID: {
counts.CountTypes.copies: 1,
counts.CountTypes.foils: 2,
}
}


def test_apc_with_find(cdb):
print_ = cdb.id_to_printing[
'536d407161fa03eddee7da0e823c2042a8fa0262']
card_counts = [{'set': 'S00', 'name': 'Rhox', 'copies': 1}]
print_counts = counts.aggregate_print_counts(cdb, card_counts)
assert print_counts == {
print_: {counts.CountTypes.copies: 1}
}


def test_apc_multiple(cdb):
print1 = cdb.id_to_printing[TEST_PRINT_ID]
print2 = cdb.id_to_printing[
'536d407161fa03eddee7da0e823c2042a8fa0262']
card_counts = [
{'id': TEST_PRINT_ID, 'copies': 1, 'foils': 2},
{'set': 'S00', 'name': 'Rhox', 'copies': 1},
]
print_counts = counts.aggregate_print_counts(cdb, card_counts)
assert print_counts == {
print1: {
}}),
([{'set': 'S00', 'name': 'Rhox', 'copies': 1}], True, {
'536d407161fa03eddee7da0e823c2042a8fa0262': {
counts.CountTypes.copies: 1,
counts.CountTypes.foils: 2,
},
print2: {counts.CountTypes.copies: 1},
}


def test_apc_repeat(cdb):
print_ = cdb.id_to_printing[TEST_PRINT_ID]
card_counts = [
{'id': TEST_PRINT_ID, 'copies': 4},
{'id': TEST_PRINT_ID, 'copies': 3, 'foils': '8'},
]
print_counts = counts.aggregate_print_counts(cdb, card_counts)
assert print_counts == {
print_: {
counts.CountTypes.copies: 7,
counts.CountTypes.foils: 8,
}
}
}}),
([{'id': TEST_PRINT_ID, 'copies': 1, 'foils': 2},
{'set': 'S00', 'name': 'Rhox', 'copies': 1}],
True,
{TEST_PRINT_ID: {
counts.CountTypes.copies: 1,
counts.CountTypes.foils: 2},
'536d407161fa03eddee7da0e823c2042a8fa0262': {
counts.CountTypes.copies: 1}}),
([{'id': TEST_PRINT_ID, 'copies': 4},
{'id': TEST_PRINT_ID, 'copies': 3, 'foils': '8'}],
True,
{TEST_PRINT_ID: {
counts.CountTypes.copies: 7,
counts.CountTypes.foils: 8}}),
([{'set': 'LEA', 'name': 'Forest', 'copies': 1}], False, {
'5ede9781b0c5d157c28a15c3153a455d7d6180fa': {
counts.CountTypes.copies: 1}}),
pytest.mark.xfail(
([{'set': 'LEA', 'name': 'Forest', 'copies': 1}], True, {
'5ede9781b0c5d157c28a15c3153a455d7d6180fa': {
counts.CountTypes.copies: 1}}),
raises=counts.UnknownPrintingError),
])
def test_aggregate_print_counts(cdb, card_rows, strict, expected):
print_counts = counts.aggregate_print_counts(cdb, card_rows, strict)
assert print_counts == expected


@pytest.mark.parametrize('in_print_counts,out_print_counts', [
Expand Down
12 changes: 8 additions & 4 deletions tests/mtg/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@

from mtg_ssm.mtg import models

CARD_DB_SENTINEL = object()

class Sentinel:
"""Simple sentinel object class."""

CARD_DB_SENTINEL = Sentinel()


def test_card(cards_data):
ag_card_data = cards_data['958ae1416f8f6287115ccd7c5c61f2415a313546']
card = models.Card(CARD_DB_SENTINEL, ag_card_data)
assert card.cdb is CARD_DB_SENTINEL
assert card.cdb == CARD_DB_SENTINEL
assert card.name == 'Abattoir Ghoul'
assert not card.strict_basic

Expand All @@ -32,7 +36,7 @@ def test_card_strict_basic(cards_data, name, id_, strict_basic):
def test_card_printing(cards_data):
ag_card_data = cards_data['958ae1416f8f6287115ccd7c5c61f2415a313546']
printing = models.CardPrinting(CARD_DB_SENTINEL, 'ISD', ag_card_data)
assert printing.cdb is CARD_DB_SENTINEL
assert printing.cdb == CARD_DB_SENTINEL
assert printing.id_ == '958ae1416f8f6287115ccd7c5c61f2415a313546'
assert printing.card_name == 'Abattoir Ghoul'
assert printing.set_code == 'ISD'
Expand Down Expand Up @@ -64,7 +68,7 @@ def test_card_set(sets_data):
# Execute
card_set = models.CardSet(CARD_DB_SENTINEL, set_data)
# Verify
assert card_set.cdb is CARD_DB_SENTINEL
assert card_set.cdb == CARD_DB_SENTINEL
assert card_set.code == 'PLS'
assert card_set.name == 'Planeshift'
assert card_set.block == 'Invasion'
Expand Down
Loading

0 comments on commit a3b68f7

Please sign in to comment.