In [1]:
from operator import itemgetter
from itertools import count
import mmap
import multiprocessing as mp
import os
from pathlib import Path
from pprint import pprint
from time import perf_counter_ns
from typing import TypeAlias

import polars as pl

In [2]:
PathLike: TypeAlias = os.PathLike | str
Accumulator: TypeAlias = dict[bytes, list[float, float, float, int]]

In [17]:
here = Path(os.path.abspath("")).resolve()
data_dir = here.parent.parent.parent / "data"
n_lines = 10_000
data_path = data_dir / f"measurements_{n_lines}.txt"
data_path = data_dir / "measurements.txt"

In [4]:
def pretty_print_solution(d: Accumulator):
    formatted_strings = (
        f"{k.decode('utf-8')}={v[0]}/{v[1]}/{v[2]}"
        for k, v in sorted(d.items(), key=itemgetter(0))
    )
    print("{", end="")
    print(", ".join(formatted_strings), end="")
    print("}")

# How fast can we iterate through the file?

Execution times on real data set:
- 2min 21s
- 1min 18s

In [5]:
%%time
with open(data_path) as fp:
    for line in fp:
        pass

CPU times: user 5.25 ms, sys: 221 µs, total: 5.47 ms
Wall time: 4.43 ms


In [6]:
%%time
with open(data_path, "r+b") as fp, mmap.mmap(fp.fileno(), 0) as mm:
    while (line := mm.readline()):
        pass

CPU times: user 1.64 ms, sys: 255 µs, total: 1.9 ms
Wall time: 1.79 ms


# Polars

Execution times on real data set:
- 53.7s

In [7]:
%%time
records = (
    pl.scan_csv(data_path, has_header=False, separator=";", new_columns=["station", "temperature"])
    .group_by("station")
    .agg(
        pl.col("temperature").min().name.suffix("_min"),
        pl.col("temperature").mean().name.suffix("_avg"), 
        pl.col("temperature").max().name.suffix("_max")
    )
    .sort(pl.col("station"))
    .collect(streaming=True)
    .to_dicts()
)
#pretty_print_solution({d["station"]: [d["temperature_min"], d["temperature_avg"], d["temperature_max"]] for d in records})

CPU times: user 54.3 ms, sys: 4.99 ms, total: 59.3 ms
Wall time: 69.2 ms


# Naive solution

Execution times on real data set:
- 15min 4s
- 15min 57s

In [8]:
def update_in_place(line: bytes, d: Accumulator):
    station, temperature = line.split(b";")
    temperature = float(temperature.decode("utf-8"))
    if station in d:
        old_min, old_avg, old_max, old_count = old = d[station]
        if temperature < old_min:
            old[0] = temperature
        old[1] = old_avg + ((temperature - old_avg) / old_count)
        if temperature > old_max:
            old[2] = temperature
        old[3] += 1
    else:
        d[station] = [float("+inf"), temperature, float("-inf"), 1]

In [9]:
%%time
d = {}  # station_name => [min, avg, max, count]
with open(data_path, "rb") as fp:
    for line in fp:
        update_in_place(line, d)
pretty_print_solution(d)

{A Coruña=inf/-65.8/-inf, Aarsâl=4.2/4.200000000000003/4.2, Aasiaat=inf/-65.6/-inf, Abakaliki=inf/16.6/-inf, Abbeville=inf/-26.5/-inf, Abbiategrasso=-24.0/-24.0/-24.0, Abdullahnagar=-5.1/-5.1/-5.1, Abdurahmoni Jomí=-34.6/-34.6/-34.6, Abergavenny=inf/-55.8/-inf, Abergele=inf/27.1/-inf, Abertawe=-33.9/-33.9/-33.9, Abertillery=-23.2/8.933333333333334/51.2, Abhwar=62.0/62.0/62.0, Abinsk=inf/-15.0/-inf, Abjīj=inf/6.9/-inf, Ablu=inf/28.9/-inf, Aboso=-86.9/-30.266666666666666/19.1, Abqaiq=8.3/8.299999999999997/8.3, Abram=-75.6/-75.6/-75.6, Abrandābād-e Shāhedīyeh=inf/-97.0/-inf, Abrantes=45.1/45.1/45.1, Abu=-22.7/6.466666666666666/45.6, Abéché=inf/40.8/-inf, Abī al Khaşīb=-36.2/-7.550000000000004/21.1, Acahay=14.8/14.800000000000002/14.8, Acajutla=inf/4.8/-inf, Acala=inf/-70.7/-inf, Acandí=inf/42.2/-inf, Acarigua=inf/51.7/-inf, Acatenango=-77.5/-68.2/-58.9, Acatzingo=-42.8/-42.800000000000004/-42.8, Accokeek=inf/-61.9/-inf, Accrington=26.3/26.299999999999997/26.3, Acerra=-1.9/38.15/78.2, Achc

In [10]:
%%time
d = {}  # station_name => [min, avg, max, count]
with open(data_path, "r+b") as fp, mmap.mmap(fp.fileno(), 0) as mm:
    while (line := mm.readline()):
        update_in_place(line, d)
# pretty_print_solution(d)

CPU times: user 9.5 ms, sys: 0 ns, total: 9.5 ms
Wall time: 9.49 ms


# Batching

Execution time in real data set:
- 13min 9s

In [11]:
def batch_indices(filepath: PathLike, n: int) -> tuple[int, ...]:
    """Returns the byte indices that divide `filepath` in `n` almost-equal chunks.
    We say "almost" cause for easiness downstream we find the closest newline after
    the actual chunking index byte.
    """
    n_bytes = os.path.getsize(filepath)
    chunk_size = n_bytes // (n - 1)
    indices = [0]
    with open(filepath, "rb") as fp:
        for i in range(chunk_size, n_bytes, chunk_size):
            for j in count(i):
                fp.seek(j)
                c = fp.read(1)
                if c == b"\n":
                    break
            indices.append(j)
    indices.append(n_bytes)
    return tuple(indices)

In [12]:
def process_batch(batch: bytes) -> Accumulator:
    d: Accumulator = {}
    for line in batch.splitlines():
        if not line: continue
        update_in_place(line, d)
    return d

In [13]:
def consolidate_accumulators(*accumulators: Accumulator) -> Accumulator:
    iter_accs = iter(accumulators)
    consolidated: Accumulator = next(iter_accs)
    for accumulator in iter_accs:
        common_keys = set(consolidated) & set(accumulator)
        for k, v in accumulator.items():
            if k not in common_keys:
                consolidated[k] = v
            else:
                old_min, old_avg, old_max, old_count = old = consolidated[k]
                new_min, new_avg, new_max, new_count = new = accumulator[k]
                if new_min < old_min:
                    old[0] = new_min
                old[1] = (1 / (old_count + new_count)) * ((old_count * old_avg) + (new_count * new_avg))
                if new_max > old_max:
                    old[2] = new_max
                old[3] += new_count
    return consolidated

In [14]:
%%time
indices = batch_indices(data_path, 8)
accumulators = []
with open(data_path, "rb") as fp:
    for start, end in zip(indices, indices[1:]):
        fp.seek(start)
        accumulator = process_batch(fp.read(end - start))
        # pretty_print_solution(processed)
        accumulators.append(accumulator)
consolidated = consolidate_accumulators(*accumulators)
# pretty_print_solution(consolidated)

CPU times: user 12.9 ms, sys: 133 µs, total: 13 ms
Wall time: 12.9 ms


# Parallelization

Execution time on real dataset:
- 14min 2s

In [None]:
%%time
n_workers = 4
indices = batch_indices(data_path, 8)
jobs = []
with mp.Pool(processes=n_workers) as pool, open(data_path, "rb") as fp:
    contents = []
    for start, end in zip(indices, indices[1:]):
        fp.seek(start)
        contents.append(fp.read(end - start))
    pool.map(process_batch, contents)

In [16]:
%%time
result = consolidate_accumulators(*results)
pretty_print_solution(result)

NameError: name 'results' is not defined