## Composite Transforms

In [5]:
import apache_beam as beam
import apache_beam.runners.interactive.interactive_beam as ib
from apache_beam.runners.interactive.interactive_runner import InteractiveRunner

In [6]:
def split_row(e: str):
    """Break one comma-delimited text line into a list of its fields."""
    fields = e.split(",")
    return fields

In [12]:
# Pipeline that spells out the pair-and-count steps separately in each branch.
p = beam.Pipeline(runner=InteractiveRunner())

# Shared source: every line of the input file, split into a list of fields.
input_collection = p | "Read from file" >> beam.io.ReadFromText("dept_data.txt")
input_collection = input_collection | "Split by comma" >> beam.Map(split_row)

# Branch 1: rows whose field at index 3 is "Accounts"
# (presumably the department column -- confirm against dept_data.txt).
accounts_count = (
    input_collection
    | "Filter by account" >> beam.Filter(lambda row: row[3] == "Accounts")
    # Same pair-and-sum steps as the HR branch below (intentionally duplicated).
    | "Pair each account employee with 1" >> beam.Map(lambda row: (row[1], 1))
    | "Count by account employee" >> beam.CombinePerKey(sum)
)

# Branch 2: rows whose field at index 3 is "HR".
hr_count = (
    input_collection
    | "Filter by hr" >> beam.Filter(lambda row: row[3] == "HR")
    # Same pair-and-sum steps as the Accounts branch above (intentionally duplicated).
    | "Pair each hr employee with 1" >> beam.Map(lambda row: (row[1], 1))
    | "Count by hr employee" >> beam.CombinePerKey(sum)
)

# Display the graph of the pipeline built above.
ib.show_graph(p)

/usr/bin/dot


In [15]:
class CommonTransform(beam.PTransform):
    """Composite transform bundling the pair-with-1 and sum-per-key steps.

    Each input element is keyed by its item at index 1 and paired with the
    value 1; the ones are then summed per key, yielding a count for each key.
    """

    def expand(self, input_coll):
        """Apply both steps to ``input_coll`` and return the resulting collection."""
        keyed_ones = input_coll | "Pair each employee with 1" >> beam.Map(
            lambda rec: (rec[1], 1)
        )
        return keyed_ones | "Count by employee" >> beam.CombinePerKey(sum)

# Same pipeline as before, but each branch reuses CommonTransform for counting.
p1 = beam.Pipeline(runner=InteractiveRunner())

# Shared source: every line of the input file, split into a list of fields.
input_collection = p1 | "Read from file" >> beam.io.ReadFromText("dept_data.txt")
input_collection = input_collection | "Split by comma" >> beam.Map(split_row)

# Branch 1: keep rows whose field at index 3 is "Accounts", then count per key.
accounts_count = (
    input_collection
    | "Filter by account" >> beam.Filter(lambda row: row[3] == "Accounts")
    | "Count by account employee" >> CommonTransform()
)

# Branch 2: keep rows whose field at index 3 is "HR", then count per key.
hr_count = (
    input_collection
    | "Filter by hr" >> beam.Filter(lambda row: row[3] == "HR")
    | "Count by hr employee" >> CommonTransform()
)

# Display the graph of the pipeline built above.
ib.show_graph(p1)

/usr/bin/dot
