Skip to content

Commit

Permalink
Changing up column removal API
Browse files Browse the repository at this point in the history
- making the column removal API more organized
- adding a "master" function for common use (aggressively_strip)
- some docs updates
  • Loading branch information
douglasdavis committed Apr 24, 2019
1 parent a57c4e6 commit ba64b7b
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 71 deletions.
5 changes: 3 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@
author = "Doug Davis"

# The short X.Y version
version = ".".join(twaml.version.version.split(".")[:3])
#version = ".".join(twaml.version.version.split(".")[:3])
version = twaml.__version__

# The full version, including alpha/beta/rc tags
release = twaml.version.version
release = twaml.__version__

# -- General configuration ---------------------------------------------------

Expand Down
66 changes: 37 additions & 29 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from twaml.data import from_root, from_pytables, from_h5

branches = ["pT_lep1", "pT_lep2", "eta_lep1", "eta_lep2"]
ds = from_root(["tests/data/test_file.root"], name="myds", branches=branches, TeXlabel=r"$t\bar{t}$")
ds = from_root(
["tests/data/test_file.root"],
name="myds",
branches=branches,
TeXlabel=r"$t\bar{t}$",
)


def test_name():
Expand Down Expand Up @@ -37,9 +42,7 @@ def test_nothing():


def test_with_executor():
lds = from_root(
["tests/data/test_file.root"], branches=branches, nthreads=4
)
lds = from_root(["tests/data/test_file.root"], branches=branches, nthreads=4)
np.testing.assert_array_almost_equal(lds.weights, ds.weights, 8)


Expand All @@ -53,9 +56,7 @@ def test_weight():


def test_add():
ds2 = from_root(
["tests/data/test_file.root"], name="ds2", branches=branches
)
ds2 = from_root(["tests/data/test_file.root"], name="ds2", branches=branches)
ds2.weights = ds2.weights * 22
combined = ds + ds2
comb_w = np.concatenate([ds.weights, ds2.weights])
Expand Down Expand Up @@ -87,12 +88,8 @@ def test_selection():

def test_append():
branches = ["pT_lep1", "pT_lep2", "eta_lep1", "eta_lep2"]
ds1 = from_root(
["tests/data/test_file.root"], name="myds", branches=branches
)
ds2 = from_root(
["tests/data/test_file.root"], name="ds2", branches=branches
)
ds1 = from_root(["tests/data/test_file.root"], name="myds", branches=branches)
ds2 = from_root(["tests/data/test_file.root"], name="ds2", branches=branches)
ds2.weights = ds2.weights * 5
# raw
comb_w = np.concatenate([ds1.weights, ds2.weights])
Expand Down Expand Up @@ -161,9 +158,7 @@ def test_auxweights():


def test_label():
ds2 = from_root(
["tests/data/test_file.root"], name="ds2", branches=branches
)
ds2 = from_root(["tests/data/test_file.root"], name="ds2", branches=branches)
assert ds2.label is None
assert ds2.label_asarray() is None
ds2.label = 6
Expand All @@ -173,9 +168,7 @@ def test_label():


def test_auxlabel():
ds2 = from_root(
["tests/data/test_file.root"], name="ds2", branches=branches
)
ds2 = from_root(["tests/data/test_file.root"], name="ds2", branches=branches)
assert ds2.auxlabel is None
assert ds2.auxlabel_asarray() is None
ds2.auxlabel = 3
Expand All @@ -197,9 +190,7 @@ def test_save_and_read():


def test_raw_h5():
inds = from_h5(
"tests/data/raw.h5", "WtLoop_nominal", ["pT_jet1", "nbjets", "met"]
)
inds = from_h5("tests/data/raw.h5", "WtLoop_nominal", ["pT_jet1", "nbjets", "met"])
rawf = h5py.File("tests/data/raw.h5")["WtLoop_nominal"]
raww = rawf["weight_nominal"]
rawm = rawf["met"]
Expand All @@ -208,12 +199,8 @@ def test_raw_h5():


def test_scale_weight_sum():
ds1 = from_root(
["tests/data/test_file.root"], name="myds", branches=branches
)
ds2 = from_root(
["tests/data/test_file.root"], name="ds2", branches=branches
)
ds1 = from_root(["tests/data/test_file.root"], name="myds", branches=branches)
ds2 = from_root(["tests/data/test_file.root"], name="ds2", branches=branches)
ds2.weights = np.random.randn(len(ds1)) * 10
scale_weight_sum(ds1, ds2)
testval = abs(1.0 - ds2.weights.sum() / ds1.weights.sum())
Expand Down Expand Up @@ -250,14 +237,35 @@ def test_columnrming():
auxweights=["pT_lep1", "pT_lep2", "pT_jet1"],
)

ds1.rmcolumns(["met", "sumet"])
ds1.rm_columns(["met", "sumet"])
list_of_cols = list(ds1.df.columns)
assert (
len(list_of_cols) == 2
and "pT_jet2" in list_of_cols
and "reg2j2b" in list_of_cols
)

ds1 = from_root(["tests/data/test_file.root"], name="myds")
list_of_cols = list(ds1.df.columns)
assert "OS" in list_of_cols
assert "SS" in list_of_cols
assert "elmu" in list_of_cols
assert "elel" in list_of_cols
assert "mumu" in list_of_cols
list_of_regs = [reg for reg in list_of_cols if "reg" in reg]
ds1.rm_chargeflavor_columns()
ds1.rm_region_columns()
ds1.rm_weight_columns()
list_of_cols_after = list(ds1.df.columns)
assert "OS" not in list_of_cols_after
assert "SS" not in list_of_cols_after
assert "elmu" not in list_of_cols_after
assert "mumu" not in list_of_cols_after
assert "elel" not in list_of_cols_after
assert "reg1j1b" not in list_of_cols_after
for r in list_of_regs:
assert r not in list_of_cols_after


def test_apply_selections():
ds2 = from_root(
Expand Down
6 changes: 3 additions & 3 deletions twaml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
"""

from .version import version
__version__ = version

import logging

logging.basicConfig(
Expand All @@ -27,6 +30,3 @@
logging.addLevelName(
logging.DEBUG, "\033[1;34m{:8}\033[1;0m".format(logging.getLevelName(logging.DEBUG))
)


from .data import dataset
4 changes: 2 additions & 2 deletions twaml/_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import argparse
from twaml.data import dataset
from twaml.data import from_root
import twaml.utils
import yaml
Expand Down Expand Up @@ -110,7 +109,8 @@ def root2pytables():
"r1j1b": twaml.utils.SELECTION_1j1b,
"r2j1b": twaml.utils.SELECTION_2j1b,
"r2j2b": twaml.utils.SELECTION_2j2b,
"r3j": twaml.utils.SELECTION_3j,
"r3j1b": twaml.utils.SELECTION_3j1b,
"r3jHb": twaml.utils.SELECTION_3jHb,
}

elif args.selection.endswith(".yml") or args.selection.endswith(".yaml"):
Expand Down
105 changes: 71 additions & 34 deletions twaml/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def _set_df_and_weights(
self._auxweights = auxw

def keep_columns(self, cols: List[str]) -> None:

"""Drop all columns not included in ``cols``
Parameters
Expand All @@ -233,51 +234,67 @@ def keep_columns(self, cols: List[str]) -> None:
"""
self._df = self._df[cols]

def keep_weights(self, weights: List[str]) -> None:
"""Drop all columns from the aux weights frame that are not in
``weights``
def aggressively_strip(self) -> None:
"""Drop all columns that should never be used in a classifier.
Parameters
----------
weights: List[str]
Weights to keep in the aux weights frame
This calls the following functions:
- :meth:`rm_meta_columns`
- :meth:`rm_region_columns`
- :meth:`rm_chargeflavor_columns`
- :meth:`rm_weight_columns`
"""
self._auxweights = self._auxweights[weights]
self.rm_meta_columns()
self.rm_region_columns()
self.rm_chargeflavor_columns()
self.rm_weight_columns()

def rm_weight_columns(self) -> None:
"""Remove all payload df columns which begin with ``weight_``
def rm_meta_columns(self) -> None:
"""Drop all columns that are considered metadata from the payload
If you are reading a dataset that was created retaining
weights in the main payload, this is a useful function to
remove them. The design of ``twaml.data.dataset`` expects
weights to be separated from the payload's main dataframe.
This includes runNumber, eventNumber, randomRunNumber
Internally this is done by calling
:meth:`pandas.DataFrame.drop` with ``inplace`` on the payload
:meth:`pandas.DataFrame.drop` with ``inplace`` on the payload.
"""
import re
self.df.drop(
columns=["runNumber", "randomRunNumber", "eventNumber"], inplace=True
)

pat = re.compile("^weight_")
rmthese = [c for c in self._df.columns if re.match(pat, c)]
def rm_region_columns(self) -> None:
"""Drop all columns that are prefixed with "reg", e.g. "reg2j2b"
Internally this is done by calling
:meth:`pandas.DataFrame.drop` with ``inplace`` on the payload.
"""
rmthese = [c for c in self._df.columns if re.match(r"^reg[0-9]\w+", c)]
self._df.drop(columns=rmthese, inplace=True)

def rmcolumns_re(self, pattern: str) -> None:
"""Remove some columns from the payload based on regex patterns
def rm_chargeflavor_columns(self) -> None:
"""Drop all columns that are related to charge and flavor
This would be [elmu, elel, mumu, OS, SS]
Internally this is done by calling
:meth:`pandas.DataFrame.drop` with ``inplace`` on the payload.
"""
self.df.drop(columns=["OS", "SS", "elmu", "elel", "mumu"], inplace=True)

def rm_weight_columns(self) -> None:
"""Remove all payload df columns which begin with ``weight_``
If you are reading a dataset that was created retaining
weights in the main payload, this is a useful function to
remove them. The design of ``twaml.data.dataset`` expects
weights to be separated from the payload's main dataframe.
Internally this is done by calling
:meth:`pandas.DataFrame.drop` with ``inplace`` on the payload
Parameters
----------
pattern : str
Regex used to remove columns
"""
pat = re.compile(pattern)
rmthese = [c for c in self._df.columns if re.search(pat, c)]
rmthese = [c for c in self._df.columns if re.match(r"^weight_\w+", c)]
self._df.drop(columns=rmthese, inplace=True)

def rmcolumns(self, cols: List[str]) -> None:
def rm_columns(self, cols: List[str]) -> None:
"""Remove columns from the dataset
Internally this is done by calling
Expand All @@ -291,6 +308,17 @@ def rmcolumns(self, cols: List[str]) -> None:
"""
self._df.drop(columns=cols, inplace=True)

def keep_weights(self, weights: List[str]) -> None:
"""Drop all columns from the aux weights frame that are not in
``weights``
Parameters
----------
weights: List[str]
Weights to keep in the aux weights frame
"""
self._auxweights = self._auxweights[weights]

def change_weights(self, wname: str) -> None:
"""Change the main weight of the dataset
Expand Down Expand Up @@ -477,7 +505,7 @@ def apply_selections(self, selections: Dict[str, str]) -> Dict[str, "dataset"]:
>>> selections = { '1j1b' : '(reg1j1b == True) & (OS == True) & (elmu == True)',
... '2j1b' : '(reg2j1b == True) & (OS == True) & (elmu == True)',
... '2j2b' : '(reg2j2b == True) & (OS == True) & (elmu == True)',
... '3j' : '(reg3j == True) & (OS == True) & (elmu == True)'}
... '3j1b' : '(reg3j1b == True) & (OS == True) & (elmu == True)'}
>>> selected_datasets = ds.apply_selections(selections)
"""
Expand Down Expand Up @@ -518,6 +546,7 @@ def from_root(
label: Optional[int] = None,
auxlabel: Optional[int] = None,
allow_weights_in_df: bool = False,
aggressively_strip: bool = False,
auxweights: Optional[List[str]] = None,
detect_weights: bool = False,
nthreads: Optional[int] = None,
Expand Down Expand Up @@ -547,7 +576,9 @@ def from_root(
auxlabel:
Give the dataset an integer auxiliary label
allow_weights_in_df:
Allow "^weight_" branches in the payload dataframe
Allow "^weight_\w+" branches in the payload dataframe
aggressively_strip:
Call :meth:`twaml.data.dataset.aggressively_strip` during construction
auxweights:
Auxiliary weights to store in a second dataframe.
detect_weights:
Expand Down Expand Up @@ -632,7 +663,7 @@ def from_root(

uproot_trees = [uproot.open(file_name)[tree_name] for file_name in input_files]

wpat = re.compile("^weight_")
wpat = re.compile(r"^weight_\w+")
if auxweights is not None:
w_branches = auxweights
elif detect_weights:
Expand Down Expand Up @@ -679,6 +710,9 @@ def from_root(

ds._set_df_and_weights(df, weights_array, auxw=aw_df)

if aggressively_strip:
ds.aggressively_strip()

return ds


Expand Down Expand Up @@ -721,7 +755,9 @@ def from_pytables(
Examples
--------
>>> ds1 = dataset.from_pytables("ttbar.h5", "ttbar")
Creating a dataset from pytables where everything is auto-detected:
>>> ds1 = dataset.from_pytables("ttbar.h5")
>>> ds1.label = 1 ## add a label to the dataset after the fact
"""
Expand Down Expand Up @@ -796,7 +832,7 @@ def from_h5(
Names of columns (branches) to include in payload
tree_name:
Name of tree dataset originates from (HDF5 dataset name)
weight_name: str
weight_name:
Name of the weight array inside the h5 file
label:
Give the dataset an integer label
Expand All @@ -808,7 +844,8 @@ def from_h5(
Examples
--------
>>> ds = dataset.from_h5('file.h5', 'dsname', tree_name='WtLoop_EG_RESOLUTION_ALL__1up')
>>> ds = dataset.from_h5("file.h5", "dsname", TeXlabel=r"$tW$",
... tree_name="WtLoop_EG_RESOLUTION_ALL__1up")
"""
ds = dataset()
Expand Down
2 changes: 2 additions & 0 deletions twaml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
SELECTION_1j1b = "(OS == True) & (elmu == True) & (reg1j1b == True)"
SELECTION_2j1b = "(OS == True) & (elmu == True) & (reg2j1b == True)"
SELECTION_2j2b = "(OS == True) & (elmu == True) & (reg2j2b == True)"
SELECTION_3j1b = "(OS == True) & (elmu == True) & (reg3j1b == True)"
SELECTION_3jHb = "(OS == True) & (elmu == True) & (reg3jHb == True)"
SELECTION_3j = "(OS == True) & (elmu == True) & (reg3j == True)"

TEXIT = {
Expand Down

0 comments on commit ba64b7b

Please sign in to comment.