In [None]:
import correctionlib
import correctionlib.schemav2 as cs
import numpy as np
import rich

np.set_printoptions(precision=3, floatmode='fixed')

# Initialize some dummy data.
# A seeded Generator (rather than the global np.random state) makes every printed
# result further down reproducible between runs; change the seed for a new sample.
rng = np.random.default_rng(seed=12345)
sz = 5
# Dummy event numbers in [low, high). The max() calls keep low < high for any
# positive sz, even if sz is made very large.
event_number = rng.integers(max(sz, 123456), max(sz * 100, 1234567), size=sz)
# Dummy jet pt values: the sum of two exponential draws, so always non-negative.
jet_pt = rng.exponential(scale=10, size=sz) + rng.exponential(scale=15.0, size=sz)

### Example 1

Every correction we want to use is constructed as a `Correction` object, and all of these objects go into a `CorrectionSet` for later lookup in your event processor. Here we create a very basic correction that always returns a flat value. The only input variable here (`"some shape array"`) lets `correctionlib` infer the shape of the output array it should return.

In [None]:
# Example 1: a constant scale factor of 1.03 for every element.
flat_correction = cs.Correction(
    data=1.03,
    name="flat correction",
    version=1,
    output=cs.Variable(name="weight", type="real"),
    inputs=[
        # Placeholder input: it never changes the returned value, it only
        # determines the shape of the output array.
        cs.Variable(
            name="some shape array",
            type="real",
            description="Placeholder input to control shape of output array",
        ),
    ],
)

Usage of this correction might look like:
```python
output = evaluator["flat correction"].evaluate(some_pt_array)
```
This is functionally equivalent to:
```python
output = numpy.full_like(some_pt_array,1.03,dtype='float')
```

### Example 2

Very similar to the first example, but now we use a `Category`, instead of a single `float`, to create named bins to switch between different variations of a particular correction. This lends itself very naturally to computing systematic up/down variations.

In [None]:
# Example 2: named up/down variations of a flat 3% shift. Any "direction" key
# that is not listed (e.g. "nominal") falls back to the category default of 1.0.
flat_variation = cs.Correction(
    name="flat variation",
    version=1,
    inputs=[
        # Shape-only placeholder, as in Example 1
        cs.Variable(name="shape", type="real", description="Placeholder input to control shape of output array"),
        # Selects which variation is returned, so this one does affect the values
        cs.Variable(name="direction", type="string"),
    ],
    output=cs.Variable(name="weight", type="real"),
    data=cs.Category(
        nodetype="category",
        input="direction",
        default=1.0,
        content=[
            cs.CategoryItem(key=key, value=val)
            for key, val in (("up", 1.0 + 0.03), ("down", 1.0 - 0.03))
        ],
    ),
)

### Example 3

Now for an example using a more dynamic correction. As before, we use the `Category` object to define up/down bins, but now the returned value is defined using the `cs.Formula` class. This class lets us define a `TFormula` expression that computes an output from a per-element input array, e.g. object pt; in the expression, `x` stands for the first entry of `variables` (here `"obj pt"`). Instead of a `TFormula`, we could just as easily use a `correctionlib` `Binning` or `MultiBinning` object to build corrections/variations binned in `pt`, `eta`, `BDT` output, etc.

In [None]:
# Example 3: a pt-dependent variation. The "obj pt" input both feeds the
# formulas and fixes the output shape, so no placeholder input is needed.
# Keys other than "up"/"down" fall back to the default of 1.0.
variable_variation = cs.Correction(
    name="variable variation",
    version=1,
    inputs=[
        cs.Variable(name="obj pt", type="real"),
        cs.Variable(name="direction", type="string"),
    ],
    output=cs.Variable(name="weight", type="real"),
    data=cs.Category(
        nodetype="category",
        input="direction",
        content=[
            # "up" adds and "down" subtracts the same pt-proportional shift
            cs.CategoryItem(
                key=direction_key,
                value=cs.Formula(
                    nodetype="formula",
                    parser="TFormula",
                    variables=["obj pt"],
                    expression=f"1.0 {sign} (x*0.075 / 50)",
                ),
            )
            for direction_key, sign in (("up", "+"), ("down", "-"))
        ],
        default=1.0,
    ),
)

### Putting it all together

We collect all of our corrections into a `CorrectionSet` object, which is what we feed our inputs into to get back a correction array whose shape matches the input array shape.

In [None]:
# Bundle the three example corrections into one schema-v2 CorrectionSet; this
# container is what gets printed, saved to JSON, and later evaluated.
simple_cset = cs.CorrectionSet(
    corrections=[flat_correction, flat_variation, variable_variation],
    schema_version=2,
)

In [None]:
# Pretty-print the full CorrectionSet model (all corrections and their nodes) with rich.
rich.print(simple_cset)

### Saving Corrections

Now we save our corrections to a `JSON` file for later use by us or a colleague. This is done simply via the `CorrectionSet.json()` method (under pydantic v2 the non-deprecated equivalent is `model_dump_json()`):

In [None]:
# Serialize the CorrectionSet to JSON. exclude_unset=True writes only the fields
# that were explicitly set when building the corrections, keeping the file compact.
# pydantic v2 deprecates .json() in favour of .model_dump_json(); fall back to
# .json() so this cell also works with pydantic-v1-based correctionlib releases.
_dump_json = getattr(simple_cset, "model_dump_json", None) or simple_cset.json
# JSON is defined as UTF-8, so don't depend on the platform's default encoding.
with open("demo_corrections.json", "w", encoding="utf-8") as fout:
    fout.write(_dump_json(exclude_unset=True))

### Loading Corrections

We can then use this `JSON` file to instantiate our `CorrectionSet` object directly.

**Note:** When loading from a file, the code automatically converts the `CorrectionSet` to an evaluator. If we instead wanted to use the `simple_cset` object we constructed earlier directly, we would need to do something like `ceval = simple_cset.to_evaluator()` before proceeding.

In [None]:
# Loading from the file yields an evaluator right away, so corrections can be
# looked up by name and evaluated without calling to_evaluator().
ceval = correctionlib.CorrectionSet.from_file("demo_corrections.json")

flat_weights = ceval['flat correction'].evaluate(jet_pt)
print(f"flat correction -- {flat_weights}")

# "nominal" is not a category key, so it returns the category default (1.0).
syst_direction_pairs = [
    (name, direction)
    for name in ('flat variation', 'variable variation')
    for direction in ('up', 'nominal', 'down')
]
for syst, d in syst_direction_pairs:
    print(f"{syst} -- {d:<7}: {ceval[syst].evaluate(jet_pt, d)}")

In [None]:
import json

# Re-read the file we just wrote and dump it with indentation to show the raw JSON.
with open("demo_corrections.json", "r") as fin:
    j = json.loads(fin.read())
print(json.dumps(j, indent=4))

### Alternate `CorrectionSet` Structure

The nestable nature of the `Correction` objects makes `correctionlib` very flexible. Here we show how we can put each of the above correction examples into individual `CategoryItem` objects that are grouped together under a single overarching `Correction`, which might correspond to some more abstract class of corrections, e.g. all of the jet systematics, or all of the pt/eta binned systematics, etc.

In [None]:
# Note: every Variable name must be used consistently throughout a single
#       Correction object. To make that easy, instantiate each Variable once
#       here and refer to it (or its `.name`) wherever it is needed below.
syst_name_variable = cs.Variable(name="systematic name", type="string", description="Name of systematic")
direction_variable = cs.Variable(name="direction", type="string", description="Direction of the variation")
jet_pt_variable = cs.Variable(name="pt", type="real", description="The pt of the jets")
output_variable = cs.Variable(name="weight", type="real")

# Example 1 -> constant scale factor
jet_pt_scale = cs.CategoryItem(key="flat correction", value=1.03)

# Example 2 -> flat +/-3% variation; directions other than up/down give 1.0
jet_pt_variation = cs.CategoryItem(
    key="flat variation",
    value=cs.Category(
        nodetype="category",
        input=direction_variable.name,
        default=1.0,
        content=[
            cs.CategoryItem(key=key, value=val)
            for key, val in (("up", 1.0 + 0.03), ("down", 1.0 - 0.03))
        ],
    ),
)

# Example 3 -> pt-dependent variation. By TFormula convention `x` is the first
# entry of `variables`, i.e. the jet pt here.
jet_pt_formula_up, jet_pt_formula_down = (
    cs.Formula(
        nodetype="formula",
        parser="TFormula",
        variables=[jet_pt_variable.name],
        expression=f"1.0 {op} (x*0.075 / 50)",
    )
    for op in ("+", "-")
)
jet_pt_variable_variation = cs.CategoryItem(
    key="variable variation",
    value=cs.Category(
        nodetype="category",
        input=direction_variable.name,
        content=[
            cs.CategoryItem(key="up", value=jet_pt_formula_up),
            cs.CategoryItem(key="down", value=jet_pt_formula_down),
        ],
        default=1.0,
    ),
)

In [None]:
# Putting it all together: a single "jet systematics" Correction whose top-level
# Category switches on the "systematic name" input; each item then reads the
# "direction" and/or "pt" inputs it needs. Unknown systematic names give 1.0.
jet_syst_inputs = [syst_name_variable, direction_variable, jet_pt_variable]
jet_syst_items = [jet_pt_variation, jet_pt_scale, jet_pt_variable_variation]

jet_systs = cs.Correction(
    name="jet systematics",
    version=1,
    inputs=jet_syst_inputs,
    output=output_variable,
    data=cs.Category(
        nodetype="category",
        input=syst_name_variable.name,
        content=jet_syst_items,
        default=1.0,
    ),
)
compact_cset = cs.CorrectionSet(schema_version=2, corrections=[jet_systs])
rich.print(compact_cset)

In [None]:
# Build the evaluator straight from the in-memory CorrectionSet (no JSON round trip).
compact_ceval = compact_cset.to_evaluator()
jet_systs_eval = compact_ceval['jet systematics']

# Positional arguments follow the Correction's inputs: (systematic name, direction, pt).
print(f"flat correction -- {jet_systs_eval.evaluate('flat correction','nominal',jet_pt)}")
for syst in ('flat variation', 'variable variation'):
    for d in ('up', 'nominal', 'down'):
        weights = jet_systs_eval.evaluate(syst, d, jet_pt)
        print(f"{syst} -- {d:<7}: {weights}")