In [None]:
import os
import polars as pl
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
# KAGGLE_KERNEL_RUN_TYPE is expected to be set only inside a Kaggle kernel
# (assumption based on the variable name). When it is present and non-empty,
# read from the competition input directory; otherwise use the
# repository-relative data folder.
kaggle_run_type = os.environ.get("KAGGLE_KERNEL_RUN_TYPE")
DATA_PATH = (
    "/kaggle/input/linking-writing-processes-to-writing-quality"
    if kaggle_run_type
    else "../../data"
)

In [None]:
# Per-event writing logs; later cells use the columns
# `id`, `down_time`, `activity` and `word_count`.
logs_csv = f"{DATA_PATH}/train_logs.csv"
# One row per essay; later cells use the columns `id` and `score`.
scores_csv = f"{DATA_PATH}/train_scores.csv"

logs = pl.read_csv(logs_csv)
scores = pl.read_csv(scores_csv)

In [None]:
def agg_log_data(essay_id, resolution_ms, max_events_clip, max_time_ms):
    """Bin one essay's event log into fixed-width time buckets.

    Reads the module-level ``logs`` frame, keeps the rows whose ``id`` equals
    ``essay_id`` and groups them into windows of ``resolution_ms`` units of
    ``down_time`` (milliseconds, going by the parameter names). The result is
    left-joined onto a complete grid ``0, resolution_ms, 2*resolution_ms, ...``
    (strictly below ``max_time_ms``), so buckets without any event still
    appear, and events at or beyond ``max_time_ms`` are dropped.

    Args:
        essay_id: Value matched against the ``id`` column of ``logs``.
        resolution_ms: Bucket width, in the units of ``down_time``.
        max_events_clip: Upper bound applied to each per-bucket event count.
        max_time_ms: Exclusive upper end of the time grid.

    Returns:
        A ``pl.DataFrame`` with one row per bucket and the columns:
        ``down_time`` (bucket index, i.e. bucket start // ``resolution_ms``),
        ``input_events`` (count of Input/Paste/Replace rows, clipped),
        ``remove_events`` (count of Remove/Cut/Replace rows, clipped, negated),
        ``nonproduction_events`` (count of Nonproduction rows, clipped) and
        ``word_count`` (last value seen in the bucket, carried forward through
        empty buckets; 0 before the first event).
    """

    def count_events(activities, name):
        # Number of rows in the window whose activity is one of `activities`.
        return (
            pl.col("id")
            .filter(pl.col("activity").is_in(activities))
            .count()
            .alias(name)
        )

    binned = (
        logs
        .filter(pl.col("id").eq(essay_id))
        # group_by_dynamic requires the index column to be sorted.
        .sort("down_time")
        .group_by_dynamic("down_time", every=f"{resolution_ms}i")
        .agg(
            # "Replace" is counted both as an input and as a removal.
            count_events(["Input", "Paste", "Replace"], "input_events"),
            count_events(["Remove/Cut", "Replace"], "remove_events"),
            count_events(["Nonproduction"], "nonproduction_events"),
            pl.col("word_count").last(),
        )
    )

    # Full grid of bucket starts; int_range excludes max_time_ms itself, so no
    # additional upper-bound filter is needed after the join.
    time_grid = pl.DataFrame(
        {"down_time": pl.int_range(0, max_time_ms, step=resolution_ms, eager=True)}
    )

    return (
        time_grid
        .join(binned, how="left", on="down_time")
        # The forward fill below depends on chronological row order; sort
        # explicitly instead of relying on the join's output order.
        .sort("down_time")
        .with_columns(
            pl.col("down_time").floordiv(resolution_ms),
            pl.col("input_events").fill_null(0).clip(upper_bound=max_events_clip),
            # Negated so removals can be drawn below the axis.
            pl.col("remove_events").fill_null(0).clip(upper_bound=max_events_clip).mul(-1),
            pl.col("nonproduction_events").fill_null(0).clip(upper_bound=max_events_clip),
            # Carry the last known count through empty buckets; buckets before
            # the first event have nothing to carry forward and become 0.
            pl.col("word_count").fill_null(strategy="forward").fill_null(0),
        )
    )


In [None]:
def sample_sparklines(
        scores_per_column_dict,
        essays_to_sample_for_rows=5,
        resolution_ms=5_000,
        max_events_clip=60,
        max_word_count_clip=600,
        max_time_ms=2_000_000,
        plot_width=800, row_height=100,
        seed=None):
    """Draw a grid of per-essay sparklines, one column per score group.

    For every entry of ``scores_per_column_dict`` a random sample of essay ids
    is drawn and each sampled essay gets one subplot: stacked bars for input
    events (green, upwards) and removal events (red, downwards) on the primary
    y axis, plus the running word count as a line on the secondary y axis.

    Args:
        scores_per_column_dict: Mapping of column title -> polars DataFrame
            with an ``id`` column; one subplot column is created per entry.
        essays_to_sample_for_rows: Number of subplot rows, i.e. essays sampled
            per column. A group with fewer essays contributes all of them and
            leaves its remaining cells empty.
        resolution_ms: Bucket width forwarded to ``agg_log_data``.
        max_events_clip: Per-bucket event cap forwarded to ``agg_log_data``.
        max_word_count_clip: Upper end of the secondary (word count) y range.
        max_time_ms: Time horizon forwarded to ``agg_log_data``.
        plot_width: Figure width in pixels.
        row_height: Height of one subplot row in pixels.
        seed: Optional seed passed to ``DataFrame.sample`` so the same essays
            are drawn on every run; ``None`` keeps the sampling random.

    Returns:
        The assembled plotly figure.
    """
    column_titles = list(scores_per_column_dict.keys())
    score_frames = list(scores_per_column_dict.values())

    cols = len(score_frames)
    rows = essays_to_sample_for_rows

    # Every cell needs a secondary y axis for the word-count line.
    specs = [[{"secondary_y": True} for _ in range(cols)] for _ in range(rows)]

    fig = make_subplots(
        rows=rows, cols=cols,
        column_titles=column_titles,
        specs=specs,
        horizontal_spacing=0.005,
        shared_xaxes="all", shared_yaxes="all",
    )

    # (column in the aggregated frame, bar colour), in drawing order.
    bar_series = (("input_events", "green"), ("remove_events", "red"))

    for col, score_frame in enumerate(score_frames, start=1):
        # Sampling without replacement raises when asked for more rows than
        # the frame holds, so cap the sample at the group size.
        n_samples = min(rows, score_frame.height)
        essay_ids = score_frame.sample(n_samples, seed=seed)["id"]

        for row, essay_id in enumerate(essay_ids, start=1):
            df = agg_log_data(essay_id, resolution_ms, max_events_clip, max_time_ms)

            for column, color in bar_series:
                fig.add_trace(
                    go.Bar(
                        x=df["down_time"],
                        y=df[column],
                        marker_color=color,
                    ), secondary_y=False, row=row, col=col,
                )

            fig.add_trace(
                go.Scatter(
                    x=df["down_time"],
                    y=df["word_count"],
                    mode="lines",
                    marker_color="black",
                    opacity=0.5,
                ), secondary_y=True, row=row, col=col,
            )

    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    # Give all secondary axes the same fixed range and tie them to "y2" so the
    # word-count lines are comparable across subplots.
    fig.update_yaxes(range=[0, max_word_count_clip], secondary_y=True, matches="y2")
    fig.update_layout(
        # "relative" stacks positive bars above and negative bars below zero.
        barmode="relative",
        showlegend=False,
        plot_bgcolor="white",
        autosize=False,
        width=plot_width,
        height=row_height * rows,
        margin=dict(t=50, l=10, b=10, r=10),
    )

    return fig

In [None]:
# Score bands shown as subplot columns, as (low, high) pairs; both bounds are
# passed to `is_between` and also appear in the column title.
score_bands = [(0.5, 1.5), (2.0, 3.0), (3.5, 4.5), (5.0, 6.0)]

scores_per_column_dict = {
    f"score = [{low}, {high}]": scores.filter(pl.col("score").is_between(low, high))
    for low, high in score_bands
}

# Ten sampled essays per band, rendered on a wide canvas.
fig = sample_sparklines(
    scores_per_column_dict,
    plot_width=1600,
    essays_to_sample_for_rows=10,
)
fig.show()