In [1]:
import nest_asyncio
import asyncio
import concurrent.futures

import socket
import json
from json import JSONDecodeError
import datetime

import tempfile
import shutil

import IPython.display

import pyspark
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import *
import pyspark.sql.functions as func

nest_asyncio.apply()

def stream_packets(_parse):
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.bind(("0.0.0.0", 22055))
        while True:
            p, _ = s.recvfrom(65536)
            parsed = ""
            while parsed == "":
                try:
                    parsed = _parse(p)
                except JSONDecodeError:
                    parsed = ""
                    pn, _ = s.recvfrom(65536)
                    p += pn
            yield parsed

def parse_message(x):
    s = x.decode("utf-8").rstrip()
    return [json.loads(m) for m in s.split("\n")]

def record_packets(data_dir):
    streamer = iter(stream_packets(parse_message))
    while True:
        start = datetime.datetime.now()
        now = str(int(start.timestamp()))
        with tempfile.NamedTemporaryFile(dir="/tmp", delete=False, mode="w") as f:
            fname = f.name
            i = 0
            while i < 10 and (datetime.datetime.now() - start).total_seconds() < 5:
                p = next(streamer)
                for m in p:
                    f.write(json.dumps(m) + "\n")
        shutil.move(fname, data_dir + "/" + now + ".jsonl")

spark = SparkSession.builder\
    .appName("Read Socket")\
    .master("local")\
    .enableHiveSupport()\
    .getOrCreate()

sc = spark.sparkContext
ssc = StreamingContext(sc, 1)

In [2]:
class DnsRecordOperator:
    
    def __getstate__(self):
        self._ncalled["__getstate__"] += 1
        state = self.__dict__.copy()
        state["_hostnames_updated_str"] = state["_hostnames_updated"].isoformat()
        del state["_hostnames_updated"]
        return state
    
    def __setstate__(self, state):
        state["_hostnames_updated"] = datetime.datetime.fromisoformat(
            state["_hostnames_updated_str"]
        )
        del state["_hostnames_updated_str"]
        self.__dict__.update(state)
        self._ncalled["__setstate__"] += 1
        return

    def __init__(self, display_handle):
        self._ncalled = {
            "__getstate__": 0,
            "__setstate__": 0,
            "_update_hostnames": 0,
            "_extract_host": 0,
            "_addr_to_hostname": 0,
            "open": 0,
            "process": 0,
            "close": 0
        }        
        self._display_handle = display_handle
        self._update_hostnames()
    
    def _extract_host(self, line):
        self._ncalled["_extract_host"] += 1
        sline = line.rstrip().split("\t")
        return (sline[0], sline[-1])

    def _update_hostnames(self):
        self._ncalled["_update_hostnames"] += 1
        with open("/etc/lx2-hosts", "r") as f:
            self._hostnames = dict(self._extract_host(line) for line in f)
        self._hostnames_updated = datetime.datetime.now()

    def _addr_to_hostname(self, addr):
        self._ncalled["_addr_to_hostname"] += 1
        if (datetime.datetime.now() - self._hostnames_updated).total_seconds() > 300:
            self._update_hostnames()
        return self._hostnames.get(addr, addr)

    def open(self, partition_id, epoch_id):
        self._ncalled["open"] += 1
        return True

    def process(self, row):
        self._ncalled["process"] += 1
        record = row.asDict(True)
        dst_addr = record["layers"]["ip"]["ip_ip_dst"]
        dst_hostname = self._addr_to_hostname(dest_addr)
        qry_name = record["layers"]["dns"]["text_dns_qry_name"]
        out = "\t".join([dst_hostname, qry_name])
        self._display_handle.update(out)
        return out

    def close(self, error):
        self._ncalled["close"] += 1
        return True

In [3]:
with tempfile.TemporaryDirectory() as tempdir:

    packets = spark.readStream\
        .option("cleanSource", "delete")\
        .json(tempdir + "/*.jsonl", schema=schema)\
        .withColumn("time", func.to_timestamp("timestamp"))

    record_loop = asyncio.new_event_loop()
    record_handle = record_loop.run_in_executor(
        concurrent.futures.ThreadPoolExecutor(max_workers=1),
        record_packets, 
        tempdir
    )

    display_handle = IPython.display.display(display_id=True)
    dns_operator = DnsRecordOperator(display_handle)

    def stream_results(df, epoch_id):
        def parse_row(row):
            record = row.asDict(True)
            dst_addr = record["layers"]["ip"]["ip_ip_dst"]
            dst_hostname = dns_operator._addr_to_hostname(dest_addr)
            qry_name = record["layers"]["dns"]["text_dns_qry_name"]
            return {
                "dst_hostname": dst_hostname,
                "qry_name": qry_name
            }

        out = pd.DataFrame(
            df.rdd.map(parse_row).collect()
        )
        display_handle.update(out)
        pass

    dns = packets.filter(func.col("layers.dns.text_dns_qry_name").isNotNull())

    stream_query = dns.withWatermark("time", "2 minutes")\
        .writeStream\
        .outputMode("append")\
        .foreachBatch(stream_results)\
        .start()

In [8]:
tempdir

'/tmp/tmpw5rgm4vs'

In [7]:
json.dumps(dns_operator)

TypeError: Object of type DnsRecordOperator is not JSON serializable