In [None]:
import pickle
from pathlib import Path

import pandas as pd
import plotly.express as px

from analytics.plotting.common.dataset_histogram import (
    build_countplot,
    build_cum_barplot,
    build_histogram_multicategory_barnorm,
)
from analytics.plotting.common.save import save_plot
from benchmark.arxiv_kaggle.data_generation import ArxivKaggleDataGenerator

%load_ext autoreload
%autoreload 2

In [None]:
# Plot mode switch: True renders the exploratory plotly histogram,
# False builds the polished count plot and hands it to save_plot.
interactive = False

In [None]:
# Location of the arXiv Kaggle dataset; defined once so the generator and the
# extraction step cannot drift apart.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
arxiv_data_dir = Path("/Users/robinholzinger/robin/dev/eth/modyn-2/.data/datasets/arxiv_kaggle/")
arxiv_zip_path = arxiv_data_dir / "raw" / "arxiv.zip"

# Build the generator, extract the raw archive and load it into a DataFrame
# (keep_true_category=True is forwarded to the loader).
arxiv_dataset = ArxivKaggleDataGenerator(arxiv_data_dir, arxiv_zip_path)
arxiv_dataset.extract_data(arxiv_zip_path)
arxiv_df = arxiv_dataset.load_into_dataframe(keep_true_category=True)

In [None]:
# File used to cache the loaded DataFrame between notebook sessions
# (written by the pickle.dump cell, read back by the pickle.load cell).
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
pickle_path = Path("/Users/robinholzinger/robin/dev/eth/modyn-2/.data/datasets/arxiv_kaggle/raw/arxiv_kaggle.pkl")

In [None]:
# Cache the loaded DataFrame on disk. The context manager closes (and thereby
# flushes) the file handle even if pickling raises.
with open(pickle_path, "wb") as pickle_file:
    pickle.dump(arxiv_df, pickle_file)

In [None]:
# Restore the cached DataFrame; the context manager closes the file handle.
# NOTE: pickle.load can execute arbitrary code — only load files this notebook wrote.
with open(pickle_path, "rb") as pickle_file:
    arxiv_df = pickle.load(pickle_file)

In [None]:
# Keep only papers whose first version is dated 1990-01-01 or later.
from_1990_onwards = arxiv_df["first_version_timestamp"] >= "1990-01-01"
arxiv_df = arxiv_df[from_1990_onwards]

In [None]:
# Inspect the distinct category labels present in the data.
arxiv_df["category"].unique()

In [None]:
# Derive a calendar-year column from the first-version timestamp
# (relies on the `.dt` accessor, i.e. a datetime-like column).
arxiv_df["year"] = arxiv_df["first_version_timestamp"].dt.year

In [None]:
# Largest value of the derived "year" column, i.e. the most recent year in the data.
arxiv_df["year"].max()

In [None]:
# number of samples over time
if not interactive:
    # polished figure: paper count per year, ticks every five years from 1990 to 2020
    year_ticks = list(range(1990, 2020 + 1, 5))
    fig1 = build_countplot(
        arxiv_df,
        x="year",
        x_ticks=year_ticks,
        y_ticks_bins=4,
        height_factor=0.4,
        width_factor=1.0,
        x_label="Sample Time",
        y_label="Num Samples",
    )
    save_plot(fig1, "arxiv_kaggle_samples_over_time")

else:
    # exploratory view: one plotly histogram facet per category
    fig = px.histogram(
        arxiv_df,
        x="first_version_timestamp",
        color="category",
        facet_col="category",
        facet_col_wrap=6,
        height=5000,
        facet_row_spacing=0.001,
    )
    # give every facet its own y-axis range and show tick labels on all facets
    fig.update_yaxes(matches=None, showticklabels=True)
    fig.update_xaxes(showticklabels=True)
    fig.show()

In [None]:
# For every category, count the number of distinct years in which it has at
# least one paper.
category_and_years = (
    arxiv_df[["category", "first_version_timestamp"]]
    # assign() returns a new frame, so nothing is written into a slice of
    # arxiv_df (the previous in-place column assignment raised SettingWithCopyWarning).
    .assign(year=lambda frame: frame["first_version_timestamp"].dt.year)[["category", "year"]]
    .drop_duplicates()
    .groupby("category")
    .size()
    .reset_index(name="num_years")
)

# Display the categories that span more than 9 distinct years.
# NOTE(review): this filtered view is only displayed, not stored — the merge in
# the next cell still uses every category. Confirm whether that is intended.
category_and_years[category_and_years["num_years"] > 9]

In [None]:
# Attach each category's `num_years` to every paper row (default inner join on "category").
arxiv_df_reduced = pd.merge(arxiv_df, category_and_years, on="category")

In [None]:
def find_category_ratios(df: pd.DataFrame) -> pd.DataFrame:
    """Count the samples per category and express each count as a share of `df`.

    Args:
        df: Frame with a "category" column.

    Returns:
        One row per distinct category with the columns "category", "count" and
        "ratio" (count divided by the number of rows in `df`), ordered by
        descending count.
    """
    total_samples = len(df)
    category_counts = (
        df["category"]
        .value_counts()
        # Name the index and the values explicitly: the column names produced
        # by a bare reset_index() differ between pandas 1.x and 2.x.
        .rename_axis("category")
        .reset_index(name="count")
        # A stable sort keeps tied categories in value_counts order, so the
        # output (and its index labels) is deterministic.
        .sort_values("count", ascending=False, kind="stable")
    )
    category_counts["ratio"] = category_counts["count"] / total_samples
    return category_counts

In [None]:
# Per-category sample counts and their share of the (merged) dataset;
# the resulting table is shown as the cell output.
category_counts = find_category_ratios(arxiv_df_reduced)
category_counts

In [None]:
# Word-cloud input: the category at position `rank` of the count-sorted table is
# repeated (100 - rank) times, so categories further down appear less often
# (and not at all from position 100 on).
wordcloud: list[str] = [
    category_name
    for rank, category_name in enumerate(category_counts["category"].unique())
    for _ in range(100 - rank)
]
print(wordcloud)

In [None]:
# Export for thesis table
from analytics.plotting.common.save import save_csv_df

# Keep the first 8 rows of the count-sorted table (category_counts.tail(2) is
# currently not included) and express the ratio in percent with one decimal.
export_csv = category_counts.head(8).copy()
export_csv["ratio"] = export_csv["ratio"].map(lambda share: round(share * 100, 1))
print(export_csv)

save_csv_df(export_csv, "arxiv_kaggle_category_ratios")

In [None]:
# Categories ordered from most to least frequent; this order defines the
# ordered categorical used for sorting below.
sorted_categories = category_counts.sort_values("count", ascending=False)["category"]

# Tag every paper with its category as an ordered categorical and sort on it in
# descending order, so rows of the least frequent categories come first.
arxiv_df_reduced = arxiv_df_reduced.assign(
    sort_idx=pd.Categorical(arxiv_df_reduced["category"], categories=sorted_categories, ordered=True)
).sort_values("sort_idx", ascending=False)

In [None]:
# x = number of categories included, y = share of the dataset they cover.
# reset_index(drop=True) first, so that "index" is the row's position in the
# count-sorted table rather than whatever index labels survived earlier sorting.
plotting_threshold = category_counts.reset_index(drop=True).reset_index()[["index", "ratio"]]
plotting_threshold["index"] = plotting_threshold["index"] + 1
# add first row: 0 categories cover 0 % of the dataset
plotting_threshold = pd.concat([pd.DataFrame({"index": [0], "ratio": [0]}), plotting_threshold])

# cumulative sum, scaled to percent
plotting_threshold["ratio"] = plotting_threshold["ratio"].cumsum() * 100
plotting_threshold

In [None]:
# Plot coverage of categories: cumulative share of the dataset (y) against the
# number of categories included (x).
coverage_plot_options = {
    "x": "index",
    "y": "ratio",
    "x_label": "Categories",
    "y_label": "% of Dataset",
    "height_factor": 0.4,
    "width_factor": 0.4,
    "y_ticks_bins": 3,
    "x_ticks_bins": 4,
}
label_hist = build_cum_barplot(plotting_threshold, **coverage_plot_options)
save_plot(label_hist, "arxiv_kaggle_category_coverage")

In [None]:
# Share of the dataset that is covered when only the 24 categories with the
# largest ratio are shown.
category_counts.sort_values("ratio", ascending=False).head(n=24)["ratio"].sum()

In [None]:
# legend: only the first 10 categories of the count-sorted table get an entry
labels = category_counts["category"].head(n=10)

# tick positions on the time axis
time_axis_ticks = list(
    map(pd.to_datetime, ["1991-09-01", "1996-07-01", "2000-01-01", "2009-01-01", "2014-01-01", "2018-01-01"])
)

# category distribution over time, using all papers
fig_labels_distribution = build_histogram_multicategory_barnorm(
    arxiv_df_reduced,
    x="first_version_timestamp",
    label="category",
    sorted_coloring_categories=sorted_categories,
    legend_labels=list(labels),
    legend_title="Paper Category",
    x_label="Sample Time",
    y_label="Label Distribution",
    x_ticks=time_axis_ticks,
    y_ticks=[1.0, 0.75, 0.5, 0.25, 0.0],
    y_ticks_bins=4,
    nbins=60,
    height_factor=0.65,
    width_factor=1.0,
)
save_plot(fig_labels_distribution, "arxiv_kaggle_category_relative")

In [None]:
# legend: the first 10 categories (in count order) whose name contains "astro-ph"
astro_ph_in_counts = category_counts["category"].str.contains("astro-ph")
labels = category_counts.loc[astro_ph_in_counts, "category"].head(n=10)

# tick positions on the time axis
astro_time_axis_ticks = list(
    map(pd.to_datetime, ["1991-09-01", "1996-07-01", "2000-01-01", "2009-01-01", "2014-01-01", "2018-01-01"])
)

# same plot as for the full dataset, restricted to the astro-ph papers
astro_ph_papers = arxiv_df_reduced[arxiv_df_reduced["category"].str.contains("astro-ph")]
fig_labels_distribution = build_histogram_multicategory_barnorm(
    astro_ph_papers,
    x="first_version_timestamp",
    label="category",
    sorted_coloring_categories=sorted_categories,
    legend_labels=list(labels),
    legend_title="Paper Category",
    x_label="Sample Time",
    y_label="Label Distribution",
    x_ticks=astro_time_axis_ticks,
    y_ticks=[1.0, 0.75, 0.5, 0.25, 0.0],
    y_ticks_bins=4,
    nbins=60,
    height_factor=0.65,
    width_factor=1.0,
)
save_plot(fig_labels_distribution, "arxiv_kaggle_category_relative_astro_ph")