In [None]:
import os

import pandas as pd
import pyrootutils
import seaborn as sns

from formal_gym import grammar as fg_grammar

In [None]:
PROJECT_ROOT = pyrootutils.find_root(
    search_from=os.path.abspath(""), indicator=".project-root"
)

# Grammar under study (file stem inside data/). Alternatives previously used
# here: "sample_raw_20241022141532", "sample_raw_20241022141532_fixed".
GRAMMAR_NAME = "sample_trim_20241022141559"
grammar_path = PROJECT_ROOT / "data" / f"{GRAMMAR_NAME}.cfg"

In [None]:
grammar = fg_grammar.ContextFreeGrammar.from_file(grammar_path)

# Show the grammar's `as_pcfg` view as plain text.
pcfg = grammar.as_pcfg
print(pcfg)

In [None]:
NUM_SAMPLES = 500_000
MAX_DEPTH = 1000  # max_depth passed to grammar.generate

# Reference curve: if every draw produced a previously unseen sample, the
# number of distinct samples after step i (i.e. after i + 1 draws) is i + 1.
records = [
    {"step": i, "num_samples": i + 1, "case": "expected"} for i in range(NUM_SAMPLES)
]
samples = set()

for i in range(NUM_SAMPLES):
    sample = grammar.generate(max_depth=MAX_DEPTH, sep=" ")
    samples.add(sample)
    # Distinct samples observed after i + 1 draws.
    records.append({"step": i, "num_samples": len(samples), "case": "measured"})

df = pd.DataFrame.from_records(records)
ax = sns.lineplot(data=df, x="step", y="num_samples", hue="case", style="case")
ax.set(
    title="Distinct samples vs. number of draws",
    xlabel="step (draw index)",
    ylabel="distinct samples",
);

In [None]:
# Histogram of generated-sample lengths (whitespace-separated token counts),
# drawn on a logarithmic x-axis.
sample_lens = [dict(length=len(text.split())) for text in samples]
sl_df = pd.DataFrame(sample_lens)
ax = sns.histplot(sl_df, x="length")
ax.set_xscale("log")

In [None]:
# Derive the samples directory from the grammar file name so it always
# matches the grammar selected in the configuration cell.
samples_path = PROJECT_ROOT / "data" / "samples" / grammar_path.stem

pos_path = samples_path / "positive.txt"
neg_path = samples_path / "negative.txt"

SAMPLE_COLUMNS = ["sample", "length", "type"]


def read_sample_file(path, sample_type: str) -> list[dict]:
    """Read one sample per line from ``path`` and tag each with ``sample_type``.

    Each record holds the stripped line (``"sample"``), its whitespace-separated
    token count (``"length"``) and the given ``sample_type`` (``"type"``).
    """
    with open(path, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return [
            {"sample": text, "length": len(text.split()), "type": sample_type}
            for text in stripped_lines
        ]


# Kept under its own name (not `samples`) so the generated-sample set built
# earlier is not clobbered and earlier cells still run after this one.
sample_records = read_sample_file(pos_path, "positive") + read_sample_file(
    neg_path, "negative"
)

# Explicit columns keep the schema even when both files are empty.
samples_df = pd.DataFrame.from_records(sample_records, columns=SAMPLE_COLUMNS)

In [None]:
# Fix the category order (negative < positive) so sorting and plot hue order
# follow the same negative -> positive sequence.
samples_df["type"] = samples_df["type"].astype(
    pd.CategoricalDtype(["negative", "positive"], ordered=True)
)

In [None]:
# Length distribution of the loaded positive/negative samples, restricted to
# the same `length <= 100` window the subsampling helpers use by default.
# Log y-axis so low-count bins remain visible.
MAX_PLOT_LENGTH = 100

ax = sns.histplot(
    data=samples_df[samples_df["length"] <= MAX_PLOT_LENGTH],
    x="length",
    hue="type",
    bins=50,
)
ax.set_yscale("log")

In [None]:
# Spot-check a hand-written token sequence with `grammar.test_sample`
# (presumably a membership/parse check -- see formal_gym for exact semantics).
PROBE_SAMPLE = "t2 t2 t2 t4 t4 t2 t4 t4 t2 t2 t0"
grammar.test_sample(PROBE_SAMPLE)

In [None]:
# Draw one negative sample for eyeballing.
negative_example = grammar.generate_negative_sample()
negative_example

In [None]:
def subsample_length_matching(
    samples_df: pd.DataFrame, max_length: int | None = 100
) -> pd.DataFrame:
    if max_length is not None:
        samples_df = samples_df[samples_df["length"] <= max_length]
    min_counts_by_length = samples_df.groupby(["type", "length"]).count().reset_index()

    mc_pivot = (
        min_counts_by_length.pivot(index="length", columns="type", values="sample")
        .fillna(0)
        .astype(int)
    )
    counts_df = pd.DataFrame(
        {"length": mc_pivot.index, "count": mc_pivot.min(axis=1)}
    ).reset_index(drop=True)

    counts_df = counts_df[counts_df["count"] > 0]

    subsampled_dfs = []

    for _, row in counts_df.iterrows():
        length = row["length"]
        count = row["count"]

        length_mask = samples_df["length"] == length
        current_samples = samples_df[length_mask]

        for sample_type in samples_df["type"].unique():
            type_mask = current_samples["type"] == sample_type
            type_samples = current_samples[type_mask]

            n_samples = min(count, len(type_samples))

            if n_samples > 0:
                subsampled_samples = type_samples.sample(count)
                subsampled_dfs.append(subsampled_samples)
    result = pd.concat(subsampled_dfs, ignore_index=True, axis=0)
    return result


def subsample(
    samples_df: pd.DataFrame, max_n: int, max_length: int | None = 100
) -> pd.DataFrame:
    subsampled_dfs = []
    if max_length is not None:
        samples_df = samples_df[samples_df["length"] <= max_length]
    lengths = samples_df["length"].unique()
    sample_types = samples_df["type"].unique()

    for length in lengths:
        for sample_type in sample_types:
            mask = (samples_df["length"] == length) & (
                samples_df["type"] == sample_type
            )
            current_samples = samples_df[mask]
            n = min(max_n, len(current_samples))
            subsampled_samples = current_samples.sample(n)
            subsampled_dfs.append(subsampled_samples)
    result = pd.concat(subsampled_dfs, ignore_index=True, axis=0)
    return result

In [None]:
# Two balanced views of the data: exact per-length type matching, and a cap of
# 150 rows per (length, type) group.
equal_lengths_df = subsample_length_matching(samples_df)
one_fifty_df = subsample(samples_df, max_n=150)

In [None]:
# Length histogram of the type-matched subsample.
sns.histplot(equal_lengths_df, x="length", hue="type", bins=50)

In [None]:
# Length histogram of the subsample capped at 150 rows per (length, type).
sns.histplot(one_fifty_df, x="length", hue="type", bins=50)

In [None]:
onefifty_outpath = samples_path / "subsampled_150.csv"

# Sort into a new frame instead of overwriting `one_fifty_df`, so the frame
# plotted above stays as it was and this cell can be re-run safely.
one_fifty_sorted_df = one_fifty_df.sort_values(by=["length", "type"])
one_fifty_sorted_df.to_csv(onefifty_outpath, index=False)