Commit

add input coder
ohnorobo committed Jun 15, 2023
1 parent ebfde3a commit 1b90698
Showing 3 changed files with 79 additions and 8 deletions.
pipeline/beam_tables.py (13 additions, 6 deletions)
@@ -31,7 +31,7 @@
 from apache_beam.transforms.sql import SqlTransform
 from google.cloud import bigquery as cloud_bigquery  # type: ignore
 
-from pipeline.metadata.schema import BigqueryRow, DashboardRow, dict_to_gcs_json_string
+from pipeline.metadata.schema import BigqueryRow, DashboardRow, BigqueryOutputRow, BigqueryInputRow, dict_to_gcs_json_string, convert_hyperquack_row_to_bq_row_format
 from pipeline.metadata import hyperquack
 from pipeline.metadata import schema
 from pipeline.metadata import flatten_base
@@ -581,9 +581,16 @@ def derive_dashboard_rows(self, rows: beam.PCollection[schema.HyperquackRow]) ->

     sql_query = ''.join(open('table/queries/merge_hyperquack.sql').read())
 
+    coder_rows = (rows | 'convert rows' >> beam.Map(
+        convert_hyperquack_row_to_bq_row_format).with_output_types(
+            BigqueryInputRow))
+
     pprint(sql_query)
 
-    dash_rows = (rows | 'derive dashboard rows' >> SqlTransform(sql_query)).with_output_types(beam.Row)
+    pprint(coder_rows)
+    pprint(dir(coder_rows))
+    pprint(coder_rows.element_type)
+
+    dash_rows = (coder_rows | 'derive dashboard rows' >>
+                 SqlTransform(sql_query).with_output_types(BigqueryOutputRow))
 
     pprint("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     pprint(dash_rows)
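Note: the pprint calls above are debug output. `coder_rows.element_type` reports the type hint Beam attached via `with_output_types`, which is a quick way to confirm the SQL stage will see a schema-aware input. A minimal sketch of the same check outside this codebase (the `Pair` type is illustrative, not part of the repo):

import typing
import apache_beam as beam

# Hypothetical schema'd element type, for illustration only.
Pair = typing.NamedTuple('Pair', [('word', str), ('count', int)])

p = beam.Pipeline()
pairs = (p | beam.Create([('a', 1), ('b', 2)])
           | beam.Map(lambda kv: Pair(*kv)).with_output_types(Pair))
print(pairs.element_type)  # expected to show the Pair NamedTuple hint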
@@ -645,13 +652,13 @@ def run_beam_pipeline(self, scan_type: str, incremental_load: bool,
         lines |
         'reshuffle' >> beam.Reshuffle().with_output_types(Tuple[str, str]))
 
-    if scan_type == schema.SCAN_TYPE_SATELLITE:
+    #if scan_type == schema.SCAN_TYPE_SATELLITE:
       # PCollection[SatelliteRow]
-      rows = satellite.process_satellite_lines(lines, self.metadata_adder)
+    #  rows = satellite.process_satellite_lines(lines, self.metadata_adder)
 
-    else:  # Hyperquack scans
+    #else:  # Hyperquack scans
       # PCollection[HyperquackRow]
-      rows = hyperquack.process_hyperquack_lines(lines, self.metadata_adder)
+    rows = hyperquack.process_hyperquack_lines(lines, self.metadata_adder)
 
     _raise_error_if_collection_empty(rows)
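Note: SqlTransform only operates on schema-aware PCollections, i.e. elements typed as a NamedTuple (or beam.Row) that Beam can encode with a RowCoder. That is why the rows are first mapped through convert_hyperquack_row_to_bq_row_format and annotated with with_output_types(BigqueryInputRow) before the SQL stage. A minimal sketch of the same pattern with hypothetical names (`SimpleRow` and the query are illustrative; SqlTransform also needs a Java expansion service at runtime, so treat this as a sketch rather than a drop-in test):

import typing

import apache_beam as beam
from apache_beam import coders
from apache_beam.transforms.sql import SqlTransform

# Schema-aware element type: a NamedTuple with a registered RowCoder.
SimpleRow = typing.NamedTuple('SimpleRow', [('domain', str), ('count', int)])
coders.registry.register_coder(SimpleRow, coders.RowCoder)

with beam.Pipeline() as p:
  rows = (
      p
      | beam.Create([('example.com', 3), ('test.org', 5), ('example.com', 1)])
      # with_output_types attaches the schema, so SqlTransform can map
      # fields to SQL columns.
      | beam.Map(lambda kv: SimpleRow(*kv)).with_output_types(SimpleRow))
  totals = rows | SqlTransform(
      'SELECT domain, SUM(`count`) AS total FROM PCOLLECTION GROUP BY domain')
  totals | beam.Map(print)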
pipeline/metadata/schema.py (65 additions, 1 deletion)
@@ -6,9 +6,10 @@
 import dataclasses
 import json
 from dataclasses import dataclass
-from typing import Optional, List, Dict, Any, Union
+from typing import Optional, List, Dict, Any, Union, NamedTuple
 
 from apache_beam.io.gcp.internal.clients import bigquery as beam_bigquery
+from apache_beam import coders
 
 # pylint: disable=too-many-instance-attributes
@@ -218,6 +219,69 @@ class HyperquackRow(BigqueryRow):
   outcome: Optional[str] = None
 
 
+BigqueryInputRow = NamedTuple(
+    'BigqueryInputRow',
+    [('domain', str),
+     ('category', str),
+     ('ip', str),
+     ('date', str),
+     ('start_time', str),
+     ('end_time', str),
+     ('retry', int),
+     ('error', str),
+     ('anomaly', bool),
+     ('success', bool),
+     ('is_control', bool),
+     ('controls_failed', bool),
+     ('stateful_block', bool),
+     ('measurement_id', str),
+     ('source', str),
+     ('outcome', str)])
+coders.registry.register_coder(BigqueryInputRow, coders.RowCoder)
+
+
+def convert_hyperquack_row_to_bq_row_format(
+    row: HyperquackRow) -> BigqueryInputRow:
+  return BigqueryInputRow(
+      row.domain,
+      row.category,
+      row.ip,
+      row.date,
+      row.start_time,
+      row.end_time,
+      row.retry,
+      row.error,
+      row.anomaly,
+      row.success,
+      row.is_control,
+      row.controls_failed,
+      row.stateful_block,
+      row.measurement_id,
+      row.source,
+      row.outcome)
+
+
+BigqueryOutputRow = NamedTuple(
+    'BigqueryOutputRow',
+    [('date', str),
+     ('source', str),
+     ('country_name', str),
+     ('network', str),
+     ('domain', str),
+     ('outcome', str),
+     ('subnetwork', str),
+     ('category', str),
+     ('count', int),
+     ('unexpected_count', int)])
+coders.registry.register_coder(BigqueryOutputRow, coders.RowCoder)
+
+
 @dataclass
 class SatelliteRow(BigqueryRow):
   """Class for satellite specific fields"""
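Note: registering coders.RowCoder for a NamedTuple is what makes it a schema'd type that cross-language transforms like SqlTransform can understand field by field. A standalone sketch of what the registration buys (the `TinyRow` type is illustrative, not part of the repo):

import typing

from apache_beam import coders

TinyRow = typing.NamedTuple('TinyRow', [('domain', str), ('anomaly', bool)])
coders.registry.register_coder(TinyRow, coders.RowCoder)

# The registry now resolves TinyRow to a RowCoder derived from its schema,
# so rows survive an encode/decode round trip.
coder = coders.registry.get_coder(TinyRow)
row = TinyRow(domain='example.com', anomaly=False)
assert coder.decode(coder.encode(row)) == row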
wordcount_xlang_sql.py (1 addition, 1 deletion)
@@ -54,7 +54,7 @@ def run(p, input_file, output_file):
       # Split the line into individual words.
       | 'Split' >> beam.FlatMap(lambda line: re.split(r'\W+', line))
       # Map each word to an instance of MyRow.
-      | 'ToRow' >> beam.Map(MyRow).with_output_types(MyRow)
+      | 'ToRow' >> beam.Map(MyRow)  #.with_output_types(MyRow)
       # SqlTransform yields a PCollection containing elements with attributes
       # based on the output of the query.
       | 'Sql!!' >> SqlTransform(
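Note: this change drops the explicit output-type hint from the example pipeline. For reference, the upstream Beam example declares MyRow roughly as follows (paraphrased from apache_beam/examples/wordcount_xlang_sql.py; check the source for the authoritative version):

import typing

from apache_beam import coders

MyRow = typing.NamedTuple('MyRow', [('word', str)])
coders.registry.register_coder(MyRow, coders.RowCoder)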
