In [4]:
import polars as pl


OPENALEX_ID_PREFIX = "https://openalex.org/"


def _short_id(expr: pl.Expr) -> pl.Expr:
    """Strip the OpenAlex URL prefix from a string ID expression (output name unchanged)."""
    return expr.str.strip_prefix(OPENALEX_ID_PREFIX)


def get_frame(file: str, directory: str = "sample") -> pl.LazyFrame:
    """Lazily scan one gzipped OpenAlex works NDJSON file into a compact schema.

    Selects and renames a subset of fields, reshapes the nested authorship,
    concept, topic and citation lists into smaller structs, and strips the
    ``https://openalex.org/`` prefix from every OpenAlex ID. This includes
    ``work_id``, so it can be matched directly against ``referenced_works``
    and ``related_works``.

    Args:
        file: File stem; ``f"{directory}/{file}.gz"`` is scanned.
        directory: Directory containing the file. Defaults to ``"sample"``.

    Returns:
        A ``pl.LazyFrame``; the file is only read when the frame is collected
        or sunk.
    """
    # authorships -> authors: [{id, name, country_codes, institutions: [{id, name, country_code}]}]
    authors = (
        pl.col("authorships")
        .list.eval(
            pl.struct(
                _short_id(pl.element().struct.field("author").struct.field("id")),
                pl.element()
                .struct.field("author")
                .struct.field("display_name")
                .alias("name"),
                pl.element().struct.field("countries").alias("country_codes"),
                pl.element()
                .struct.field("institutions")
                .list.eval(
                    pl.struct(
                        _short_id(pl.element().struct.field("id")),
                        pl.element().struct.field("display_name").alias("name"),
                        pl.element().struct.field("country_code"),
                    )
                ),
            )
        )
        .alias("authors")
    )

    concepts = pl.col("concepts").list.eval(
        pl.struct(
            _short_id(pl.element().struct.field("id")),
            pl.element().struct.field("display_name").alias("name"),
            pl.element().struct.field("score").cast(pl.Float32),
        )
    )

    # NOTE(review): concept scores are cast to Float32 but topic scores keep
    # their inferred dtype — confirm whether that asymmetry is intended.
    topics = pl.col("topics").list.eval(
        pl.struct(
            _short_id(pl.element().struct.field("id")),
            pl.element().struct.field("display_name").alias("name"),
            pl.element()
            .struct.field("domain")
            .struct.field("display_name")
            .alias("domain"),
            pl.element()
            .struct.field("field")
            .struct.field("display_name")
            .alias("field"),
            pl.element().struct.field("score"),
        )
    )

    citations = (
        pl.col("counts_by_year")
        .list.eval(
            pl.struct(
                pl.element().struct.field("year").cast(pl.Int32),
                pl.element()
                .struct.field("cited_by_count")
                .alias("count")
                .cast(pl.Int32),
            )
        )
        .alias("citations")
    )

    return pl.scan_ndjson(f"{directory}/{file}.gz").select(
        _short_id(pl.col("id")).alias("work_id"),
        pl.col("doi"),
        pl.col("title"),
        pl.col("publication_date").alias("date"),
        authors,
        concepts,
        topics,
        pl.col("referenced_works").list.eval(_short_id(pl.element())),
        pl.col("related_works").list.eval(_short_id(pl.element())),
        pl.col("keywords").list.eval(pl.element().struct.field("display_name")),
        citations,
    )

In [5]:
import os
import time
from tqdm.auto import tqdm


# zstd accepts compression levels 1..22 inclusive.
ZSTD_LEVELS = range(1, 23)

# One record per (file, level) run: output size in bytes and encode time in ns.
data = []


files = [
    "part_000"
]

for index, file in enumerate(files):
    # Run the (expensive) gzip NDJSON scan + reshaping once per file so the
    # timed region below measures Parquet encoding/compression only, instead
    # of being dominated by JSON parsing on every level.
    frame = get_frame(file).collect()
    for comp_level in tqdm(ZSTD_LEVELS, desc=f"file {index + 1}/{len(files)}"):
        path = f"sample/{file}@{comp_level}.parquet"
        try:
            # perf_counter_ns is monotonic and high-resolution; time_ns is the
            # wall clock and can jump (e.g. NTP adjustments) mid-measurement.
            time_start = time.perf_counter_ns()
            frame.write_parquet(
                path,
                compression="zstd",
                compression_level=comp_level,
            )
            time_elapsed = time.perf_counter_ns() - time_start
            byte_size = os.path.getsize(path)
        finally:
            # Always clean up the temporary output, even if the write failed.
            if os.path.exists(path):
                os.remove(path)
        data.append(
            {"size": byte_size, "time": time_elapsed, "level": comp_level, "file": file}
        )
    # Release the materialised frame before loading the next file.
    del frame

file 1/1: 100%|██████████| 22/22 [00:16<00:00,  1.31it/s]


In [6]:
# Collapse the benchmark records in `data` to one row per compression level
# (mean time and mean size across runs), ordered by level, as a pandas frame
# for the Plotly cell below.
df = (
    pl.DataFrame(data)
    .group_by("level")
    .agg(
        pl.col("time").mean(),
        pl.col("size").mean(),
    )
    .sort("level")
    .to_pandas()
)

In [7]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# One subplot per metric. `scale` converts the raw value stored in `df`
# (bytes for size, nanoseconds for time) into the displayed `unit`.
PANELS = [
    {"column": "size", "name": "File Size", "color": "blue", "unit": "MiB", "scale": 1 / 2**20},
    {"column": "time", "name": "Compression Time", "color": "red", "unit": "ms", "scale": 1e-6},
]

fig = make_subplots(
    rows=len(PANELS),
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.1,
    subplot_titles=[f"{panel['name']} vs Compression Level" for panel in PANELS],
)

for row, panel in enumerate(PANELS, start=1):
    y_title = f"{panel['name']} ({panel['unit']})"
    fig.add_trace(
        go.Scatter(
            x=df["level"],
            y=df[panel["column"]] * panel["scale"],
            mode="lines+markers+text",
            name=panel["name"],
            line=dict(color=panel["color"]),
            text=df["level"].astype(str),
            textposition="top center",
            textfont=dict(size=10),
            # Label the hover value with the metric and its unit instead of a
            # generic "Value".
            hovertemplate=(
                f"<b>Level</b>: %{{x}}<br><b>{y_title}</b>: %{{y:.3f}}<extra></extra>"
            ),
        ),
        row=row,
        col=1,
    )
    fig.update_yaxes(title_text=y_title, row=row, col=1)


fig.update_layout(
    title_text="Compression Statistics by Compression Level",
    height=700,
    hovermode="x unified",
    template="plotly_white",
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
)

# The x axis is shared, so only the bottom subplot gets a title.
fig.update_xaxes(title_text="Compression Level", row=len(PANELS), col=1)


fig.show()