Skip to content

Commit

Permalink
Minor refactor for simplicity and potential speedup.
Browse files Browse the repository at this point in the history
  • Loading branch information
jab committed Feb 17, 2024
1 parent 9822192 commit e6d404e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 30 deletions.
40 changes: 19 additions & 21 deletions bidict/_base.py
Expand Up @@ -18,7 +18,6 @@

import typing as t
import weakref
from functools import partial
from itertools import starmap
from operator import eq
from types import MappingProxyType
Expand Down Expand Up @@ -46,7 +45,7 @@

OldKV: t.TypeAlias = t.Tuple[OKT[KT], OVT[VT]]
DedupResult: t.TypeAlias = t.Optional[OldKV[KT, VT]]
Unwrite: t.TypeAlias = t.Callable[[], None]
Unwrites: t.TypeAlias = t.List[t.Tuple[t.Any, ...]]
BT = t.TypeVar('BT', bound='BidictBase[t.Any, t.Any]')


Expand Down Expand Up @@ -353,7 +352,7 @@ def _dedup(self, key: KT, val: VT, on_dup: OnDup) -> DedupResult[KT, VT]:
# else neither isdupkey nor isdupval.
return oldkey, oldval

def _write(self, newkey: KT, newval: VT, oldkey: OKT[KT], oldval: OVT[VT], unwrites: list[Unwrite] | None) -> None:
def _write(self, newkey: KT, newval: VT, oldkey: OKT[KT], oldval: OVT[VT], unwrites: Unwrites | None) -> None:
"""Insert (newkey, newval), extending *unwrites* with associated inverse operations if provided.
*oldkey* and *oldval* are as returned by :meth:`_dedup`.
Expand All @@ -377,38 +376,38 @@ def _write(self, newkey: KT, newval: VT, oldkey: OKT[KT], oldval: OVT[VT], unwri
# {0: 1, 2: 3} | {4: 5} => {0: 1, 2: 3, 4: 5}
if unwrites is not None:
unwrites.extend((
partial(fwdm_del, newkey),
partial(invm_del, newval),
(fwdm_del, newkey),
(invm_del, newval),
))
elif oldval is not MISSING and oldkey is not MISSING: # key and value duplication across two different items
# {0: 1, 2: 3} | {0: 3} => {0: 3}
fwdm_del(oldkey)
invm_del(oldval)
if unwrites is not None:
unwrites.extend((
partial(fwdm_set, newkey, oldval),
partial(invm_set, oldval, newkey),
partial(fwdm_set, oldkey, newval),
partial(invm_set, newval, oldkey),
(fwdm_set, newkey, oldval),
(invm_set, oldval, newkey),
(fwdm_set, oldkey, newval),
(invm_set, newval, oldkey),
))
elif oldval is not MISSING: # just key duplication
# {0: 1, 2: 3} | {2: 4} => {0: 1, 2: 4}
invm_del(oldval)
if unwrites is not None:
unwrites.extend((
partial(fwdm_set, newkey, oldval),
partial(invm_set, oldval, newkey),
partial(invm_del, newval),
(fwdm_set, newkey, oldval),
(invm_set, oldval, newkey),
(invm_del, newval),
))
else:
assert oldkey is not MISSING # just value duplication
# {0: 1, 2: 3} | {4: 3} => {0: 1, 4: 3}
fwdm_del(oldkey)
if unwrites is not None:
unwrites.extend((
partial(fwdm_set, oldkey, newval),
partial(invm_set, newval, oldkey),
partial(fwdm_del, newkey),
(fwdm_set, oldkey, newval),
(invm_set, newval, oldkey),
(fwdm_del, newkey),
))

def _update(
Expand Down Expand Up @@ -449,18 +448,17 @@ def _update(
# as we go. If the update results in a DuplicationError and rollback is enabled, apply the accumulated unwrites
# before raising, to ensure that we fail clean.
write = self._write
unwrites: list[Unwrite] | None = [] if rollback else None
unwrites: Unwrites | None = [] if rollback else None
for key, val in iteritems(arg, **kw):
try:
dedup_result = self._dedup(key, val, on_dup)
except DuplicationError:
if unwrites is not None:
for unwrite in reversed(unwrites):
unwrite()
for fn, *args in reversed(unwrites):
fn(*args)
raise
if dedup_result is None: # no-op
continue
write(key, val, *dedup_result, unwrites=unwrites)
if dedup_result is not None:
write(key, val, *dedup_result, unwrites=unwrites)

def __copy__(self: BT) -> BT:
"""Used for the copy protocol. See the :mod:`copy` module."""
Expand Down
17 changes: 8 additions & 9 deletions bidict/_orderedbase.py
Expand Up @@ -17,11 +17,10 @@
from __future__ import annotations

import typing as t
from functools import partial
from weakref import ref as weakref

from ._base import BidictBase
from ._base import Unwrite
from ._base import Unwrites
from ._bidict import bidict
from ._iter import iteritems
from ._typing import KT
Expand Down Expand Up @@ -167,7 +166,7 @@ def _init_from(self, other: MapOrItems[KT, VT]) -> None:
for k, v in iteritems(other):
korv_by_node_set(new_node(), k if bykey else v)

def _write(self, newkey: KT, newval: VT, oldkey: OKT[KT], oldval: OVT[VT], unwrites: list[Unwrite] | None) -> None:
def _write(self, newkey: KT, newval: VT, oldkey: OKT[KT], oldval: OVT[VT], unwrites: Unwrites | None) -> None:
"""See :meth:`bidict.BidictBase._spec_write`."""
super()._write(newkey, newval, oldkey, oldval, unwrites)
assoc, dissoc = self._assoc_node, self._dissoc_node
Expand All @@ -177,7 +176,7 @@ def _write(self, newkey: KT, newval: VT, oldkey: OKT[KT], oldval: OVT[VT], unwri
newnode = self._sntl.new_last_node()
assoc(newnode, newkey, newval)
if unwrites is not None:
unwrites.append(partial(dissoc, newnode))
unwrites.append((dissoc, newnode))
elif oldval is not MISSING and oldkey is not MISSING: # key and value duplication across two different items
# {0: 1, 2: 3} | {0: 3} => {0: 3}
# n1, n2 => n1 (collapse n1 and n2 into n1)
Expand All @@ -192,25 +191,25 @@ def _write(self, newkey: KT, newval: VT, oldkey: OKT[KT], oldval: OVT[VT], unwri
assoc(newnode, newkey, newval)
if unwrites is not None:
unwrites.extend((
partial(assoc, newnode, newkey, oldval),
partial(assoc, oldnode, oldkey, newval),
oldnode.relink,
(assoc, newnode, newkey, oldval),
(assoc, oldnode, oldkey, newval),
(oldnode.relink,),
))
elif oldval is not MISSING: # just key duplication
# {0: 1, 2: 3} | {2: 4} => {0: 1, 2: 4}
# oldkey: MISSING, oldval: 3, newkey: 2, newval: 4
node = node_by_korv[newkey if bykey else oldval]
assoc(node, newkey, newval)
if unwrites is not None:
unwrites.append(partial(assoc, node, newkey, oldval))
unwrites.append((assoc, node, newkey, oldval))
else:
assert oldkey is not MISSING # just value duplication
# {0: 1, 2: 3} | {4: 3} => {0: 1, 4: 3}
# oldkey: 2, oldval: MISSING, newkey: 4, newval: 3
node = node_by_korv[oldkey if bykey else newval]
assoc(node, newkey, newval)
if unwrites is not None:
unwrites.append(partial(assoc, node, oldkey, newval))
unwrites.append((assoc, node, oldkey, newval))

def __iter__(self) -> t.Iterator[KT]:
"""Iterator over the contained keys in insertion order."""
Expand Down

0 comments on commit e6d404e

Please sign in to comment.