In [None]:


from itertools import product
from datasets import load_dataset,load_from_disk
import config
import os
import pandas as pd
import util
import torch
import plotting
from scipy.stats import power_divergence

import config
from itertools import product



In [None]:
import os
import pandas as pd
import torch
from itertools import product
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
import threading
from multiprocessing import Manager
from scipy.spatial.distance import jensenshannon
from scipy.stats import entropy
# Held while worker results are appended to `stats`.
# NOTE(review): this is a threading.Lock, which only serialises threads inside a
# single process — confirm whether that is sufficient for the process-pool phase.
stats_lock = threading.Lock()

# Use Manager to create shared cache and stats list
manager = Manager()
dataset_cache = manager.dict()  # dataset name -> loaded "stage"-only train split
df_cache = manager.dict()  # (dataset name, model type) -> prepared DataFrame
stats = manager.list()  # Shared list to store stats from all processes

def get_cached_dataset(dataset_name):
    """Return the "stage" column of the train split of `dataset_name`.

    The split is loaded on first use and stored in the shared `dataset_cache`
    under its dataset name; later calls are served from that cache.
    """
    already_loaded = dataset_name in dataset_cache
    if not already_loaded:
        print(f"[Dataset] Loading: {dataset_name}")
        train_split = load_dataset(dataset_name)["train"]
        dataset_cache[dataset_name] = train_split.select_columns("stage")
    return dataset_cache[dataset_name]

def get_prepared_df(dataset_name, model_type):
    """Return the per-example influence table for one dataset / model-type pair.

    The table has one column per checkpoint found in the influence output
    directory (column label = integer checkpoint step, sorted ascending) plus a
    "stage" column taken from the dataset. Results are memoised in the shared
    `df_cache` under the key ``(dataset_name, model_type)``.

    Args:
        dataset_name: Dataset identifier, also used to derive the model name
            and the influence directory.
        model_type: Model family string used to derive the model name.

    Returns:
        pandas.DataFrame as described above.
    """
    cache_key = (dataset_name, model_type)

    # A single proxy lookup instead of a membership test followed by a fetch.
    cached = df_cache.get(cache_key)
    if cached is not None:
        return cached

    print(f"[DataFrame] Preparing for: {dataset_name} / {model_type}")

    dataset = get_cached_dataset(dataset_name)
    model_name = dataset_name + "_" + model_type + "_random"

    influence_output_dir = os.path.join(
        "./influence_mean_normalized",
        os.path.basename(model_name),
        "_".join([(os.path.basename(dataset_name) + "_" + "train[0%:100%]")] * 2),
    )

    # Only "checkpoint-<step>" entries are influence results; anything else in
    # the directory (editor droppings, notebook checkpoints, ...) would make
    # int(...) below fail, so it is skipped.
    prefix = "checkpoint-"
    checkpoint_files = [
        entry for entry in os.listdir(influence_output_dir) if entry.startswith(prefix)
    ]

    df = pd.DataFrame({
        int(entry[len(prefix):]): torch.load(
            os.path.join(influence_output_dir, entry),
            weights_only=True,
            map_location="cpu"
        ).numpy().flatten()
        for entry in checkpoint_files
    })

    # Order the checkpoint columns by training step.
    df = df.reindex(sorted(df.columns), axis=1)
    df["stage"] = dataset.to_pandas()

    df_cache[cache_key] = df
    print(f"[DataFrame] Cached for: {dataset_name} / {model_type}")
    # Return the frame just built rather than fetching it back through the proxy.
    return df

# Enumerate every (dataset, model_type, curriculum_a, curriculum_b) job to run.
all_combinations = []

for model_type in config.model_types:
    # Influence curricula get the model type as a prefix; baselines are used as-is.
    curricula = [model_type + suffix for suffix in config.influence_curricula]
    curricula = curricula + config.baseline_curricula
    # Every ordered pair of curricula, including a curriculum paired with itself.
    for a, b in product(curricula, repeat=2):
        all_combinations.extend(
            (dataset, model_type, a, b) for dataset in config.datasets
        )
print(all_combinations)
def process_combination(combination):
    """Compare the stage distributions of two curricula chunk by chunk.

    Both curricula are truncated to the length of the shorter one and cut into
    consecutive chunks of ``max_len // bins`` examples; a trailing partial
    chunk is dropped. For every chunk pair one tuple is appended to the shared
    ``stats`` list::

        (chunk index, curriculum_a, curriculum_b, dataset, model_type,
         len(chunk_a), len(chunk_b),
         squared Jensen-Shannon distance of the stage counts,
         entropy of chunk_a's stage counts)

    Args:
        combination: ``(dataset_name, model_type, curriculum_a_name,
            curriculum_b_name)``.

    Returns:
        None. Any exception raised while computing is reported on stdout and
        the combination contributes no rows.
    """
    dataset_name, model_type, curriculum_a_name, curriculum_b_name = combination
    local_stats = []

    try:
        df = get_prepared_df(dataset_name, model_type)

        epsilon = 1e-100  # pseudo-count for stages that do not occur in a chunk
        bins = 1000

        curriculum_a = util.get_curriculum(dataset_name, curriculum_a_name)
        curriculum_b = util.get_curriculum(dataset_name, curriculum_b_name)

        # Only the "stage" column is needed, so select it before reordering
        # instead of copying every checkpoint column.
        stage_column = df["stage"]
        stages_a = stage_column.iloc[torch.cat(curriculum_a).flatten()]
        stages_b = stage_column.iloc[torch.cat(curriculum_b).flatten()]

        # Stages seen in either curriculum: taking them from A alone would
        # silently drop B's probability mass on stages that A never visits.
        all_stages = pd.concat([stages_a, stages_b]).unique()

        max_len = min(len(stages_a), len(stages_b))
        bin_size = max_len // bins
        if bin_size == 0:
            raise ValueError(
                f"need at least {bins} examples per curriculum to form chunks, got {max_len}"
            )

        # Start offsets of every *full* chunk, including one that ends exactly
        # at max_len.
        chunk_starts = range(0, max_len - bin_size + 1, bin_size)

        for idx, start in enumerate(chunk_starts):
            a = stages_a.iloc[start:start + bin_size]
            b = stages_b.iloc[start:start + bin_size]

            counts_a = a.value_counts().reindex(all_stages, fill_value=epsilon)
            counts_b = b.value_counts().reindex(all_stages, fill_value=epsilon)

            local_stats.append((
                idx, curriculum_a_name, curriculum_b_name, dataset_name,
                model_type, len(a), len(b),
                # jensenshannon returns the distance; squaring gives the divergence.
                jensenshannon(p=counts_a, q=counts_b) ** 2,
                entropy(pk=counts_a),
            ))

    except Exception as e:
        print(f"[Error] Skipping: {dataset_name}, {model_type}, {curriculum_a_name}, {curriculum_b_name}")
        print("Reason:", str(e))

    # Append local stats to the shared stats list safely
    with stats_lock:
        stats.extend(local_stats)

# Prepare interleaved batches
import random
from itertools import chain, zip_longest
from collections import defaultdict


# The first few combinations run in threads of this process; the rest run in a
# process pool. NOTE(review): presumably the thread phase exists to fill the
# shared caches before worker processes start — confirm.
first_batch = all_combinations[:3]
remaining_batch = all_combinations[3:]

THREAD_WORKERS = 3
PROCESS_WORKERS = 32


def _wait_for_all(futures, desc):
    """Show progress while `futures` finish and report any exception a worker raised.

    Without inspecting each future, an exception that escapes the worker
    (for example a pickling failure) would vanish silently.
    """
    for future in tqdm(as_completed(futures), total=len(futures), desc=desc):
        error = future.exception()
        if error is not None:
            print(f"[Error] Worker failed: {error!r}")


# Use ThreadPoolExecutor for the first batch
with ThreadPoolExecutor(max_workers=THREAD_WORKERS) as executor:
    futures = [executor.submit(process_combination, c) for c in first_batch]
    _wait_for_all(futures, f"Processing ({THREAD_WORKERS} workers)")
print(stats)

print(len(stats))
with ProcessPoolExecutor(max_workers=PROCESS_WORKERS) as executor:
    futures = [executor.submit(process_combination, c) for c in remaining_batch]
    # The label is derived from the real pool size so it cannot drift from it.
    _wait_for_all(futures, f"Processing ({PROCESS_WORKERS} workers)")
print(len(stats))





In [None]:
# Materialise the shared result list as a DataFrame, one row per compared chunk pair.
# Column order matches the tuples appended by process_combination.
stats_columns = [
    "chunk",
    "curriculum_a",
    "curriculum_b",
    "dataset",
    "model_type",
    "chunk_size_a",
    "chunk_size_b",
    "stat",
    "entropy",
]
stats_df = pd.DataFrame(list(stats), columns=stats_columns)


KeyboardInterrupt: 

In [None]:
config.influence_curricula

In [None]:
# NOTE(review): `exclude` is not referenced by any cell visible here — confirm it is still needed.
exclude = ["source_difficulty", "_influence_top_50_cp_shuffled", "_influence_epoch_repetition"]

In [None]:
stats_df

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Standard deviation of the 'stat' column of stats_df.
# NOTE(review): std_dev is not used by any cell visible here.
std_dev = stats_df['stat'].std()

# NOTE(review): sns.set() applies its own default context, so the preceding
# set_context("talk") is presumably overridden — confirm which one is intended.
sns.set_context("talk")
sns.set(font_scale=1.5)

# One facet row per model type.
g = sns.FacetGrid(stats_df, row="model_type", height=8, sharex=False)

def plot_heatmap(data, **kwargs):
    """Draw an annotated heatmap of the mean "stat" for one facet's rows.

    The table is pivoted with curriculum_a as rows and curriculum_b as columns,
    labels are passed through util.rename and sorted, and then only baseline
    curricula are kept as rows and only non-baseline curricula as columns.
    Extra keyword arguments are forwarded to sns.heatmap.
    """
    baseline_labels = [util.rename(c) for c in config.baseline_curricula]

    table = (
        data.pivot_table(
            index="curriculum_a",
            columns="curriculum_b",
            values="stat",
            aggfunc="mean",
        )
        .fillna(float("nan"))
        .rename(index=util.rename, columns=util.rename)
        .sort_index(axis=0)
        .sort_index(axis=1)
    )

    # Rows: baselines only. Columns: everything that is not a baseline.
    table = table[table.index.isin(baseline_labels)]
    kept_columns = [col for col in table.columns if col not in baseline_labels]
    table = table[kept_columns]

    colorbar_options = dict(
        use_gridspec=True, location="top", fraction=0.089, pad=0.04, label=r"mean JSD"
    )
    sns.heatmap(
        table,
        annot=True,
        cmap="YlGnBu",
        fmt=".2f",
        cbar_kws=colorbar_options,
        cbar=True,
        robust=True,
        **kwargs,
    )

g.map_dataframe(plot_heatmap)
# The grid is faceted by row only, so the facet title is built from the row
# template; a col_template has no effect here. The title size is passed as
# `size=` because set_titles supplies its own `size` keyword when none is given,
# which can take precedence over `fontsize=`.
g.set_titles(row_template="{row_name}", size=16, fontweight="bold")
g.set_axis_labels("expected", "observed", fontsize=16)

g.tick_params(axis='y', rotation=0, which='major', labelsize=24)
g.tick_params(axis='x', rotation=90, which='major', labelsize=24)
for ax in g.axes.flat:
    # Per-axes override of the tick label size set on the grid above.
    ax.tick_params(axis='both', which='major', labelsize=19)
    ax.set_facecolor("white")
    ax.set_aspect('equal', 'box')

plt.tight_layout()

# Make sure the output directory exists so savefig does not fail on a fresh checkout.
figure_dir = "./autogenerated_figures"
os.makedirs(figure_dir, exist_ok=True)
plt.savefig(os.path.join(figure_dir, "source_distribution_heatmap_both.pdf"), dpi=600, bbox_inches='tight')

plt.show()


In [None]:
stats_df

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm, Normalize


    

# Heatmap of the mean "stat" for the llama rows only.
fig, ax = plt.subplots(figsize=(12, 14))

data = stats_df.loc[stats_df["model_type"] == "llama"]
pivot_df = data.pivot_table(
    index="curriculum_b",
    columns="curriculum_a",
    values="stat",
    aggfunc="mean",
    fill_value=np.nan,
)
display(pivot_df)  # raw table, before labels are renamed

pivot_df = pivot_df.rename(index=util.rename, columns=util.rename)

# Keep the rows whose renamed curriculum_b is a baseline, then transpose so the
# baselines become the columns of the plotted table.
baseline_labels = [util.rename(c) for c in config.baseline_curricula]
pivot_df = pivot_df[pivot_df.index.isin(baseline_labels)]
pivot_df = pivot_df.T

# Upper-triangle mask; it is computed here but not passed to the heatmap.
mask = np.triu(np.ones_like(pivot_df, dtype=bool), k=1)

colorbar_options = dict(
    use_gridspec=True, location="top", fraction=0.02, pad=0.02, label=r"mean JSD"
)
sns.heatmap(
    pivot_df,
    ax=ax,
    annot=True,
    cmap="YlGnBu",
    fmt=".2f",
    cbar=True,
    cbar_kws=colorbar_options,
)

ax.set(xlabel="", ylabel="")

for axis_name in ("y", "x"):
    ax.tick_params(axis=axis_name, rotation=0, which="major", labelsize=24)

# Re-apply the (empty) axis labels with a larger font size.
ax.set_xlabel(ax.get_xlabel(), fontsize=25)
ax.set_ylabel(ax.get_ylabel(), fontsize=25)
ax.set_aspect("equal", "box")
ax.set_facecolor("white")

plt.tight_layout(pad=3)
plt.savefig(
    os.path.join("./autogenerated_figures", "source_distribution_llama.pdf"),
    dpi=600,
    bbox_inches="tight",
)

plt.show()


In [None]:
# Average the per-chunk values within every (dataset, model_type, curriculum_a) group.
group_keys = ["dataset", "model_type", "curriculum_a"]
d = (
    stats_df[["chunk", *group_keys, "entropy"]]
    .groupby(group_keys)
    .mean()
    .reset_index()
)
d

In [None]:
# NOTE(security): read_pickle unpickles the file, which can execute arbitrary code —
# only load benchmark_results.pkl when it comes from a trusted source.
benchmark_results = pd.read_pickle("./benchmark_results.pkl")
benchmark_results

In [None]:
benchmark_results.sort_values(by="average_improvement")

In [None]:
# Add the display label of each curriculum to `d`, then join the benchmark
# table with it on that shared "curriculum" column.
d["curriculum"] = d["curriculum_a"].apply(util.rename)
stats_df_merged = benchmark_results.merge(d, on="curriculum")
stats_df_merged

In [None]:
stats_df_merged["entropy"].corr(stats_df_merged["average_improvement"])

In [None]:
# Import only the function that is needed: binding the whole module as `stats`
# would replace the shared `stats` result list that other cells read.
from scipy.stats import pearsonr

# The local names deliberately avoid `entropy`, which is the scipy function
# called inside process_combination; rebinding it to a Series would make every
# later run of that function fail.
entropy_values = stats_df_merged["entropy"]
model_acc = stats_df_merged["model_acc"]
label = stats_df_merged["model"]

# Pearson correlation between mean chunk entropy and model accuracy.
corr_coefficient, p_value = pearsonr(entropy_values, model_acc)

print(f"Correlation Coefficient: {corr_coefficient}")
print(f"P-value: {p_value}")


In [None]:
stats_df_merged

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_palette("colorblind")

# Assign one marker per curriculum, wrapping around the marker pool when there
# are more curricula than marker shapes.
unique_curriculum = stats_df_merged["curriculum"].unique()
available_markers = ["o", "s", "^", "D", "P", "*", "X", "H", "v", "<", ">"]
markers = {
    name: available_markers[position % len(available_markers)]
    for position, name in enumerate(unique_curriculum)
}

sns.scatterplot(
    data=stats_df_merged,
    x="entropy",
    y="model_acc",
    hue="curriculum",
    style="curriculum",
    markers=markers,
)

# Dashed black regression line over the same data, without a second scatter layer.
sns.regplot(
    data=stats_df_merged,
    x="entropy",
    y="model_acc",
    scatter=False,
    color="black",
    line_kws={"lw": 2, "ls": "--"},
)

plt.xlabel("Entropy")
plt.ylabel("Acc")

# Legend sits below the axes, laid out in three columns.
plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=3)

plt.show()


In [None]:
# Entropy per curriculum_a, with rows ordered from highest to lowest entropy.
entropy_ranked = d.sort_values(by="entropy", ascending=False)
sns.barplot(data=entropy_ranked, x="curriculum_a", y="entropy")
plt.xticks(rotation=90)


In [None]:
stats_df

In [None]:
pivot_df

In [None]:
pivot_df["$C_{rand}$"].sort_values()

In [None]:
# Raw strings keep the LaTeX backslash (\searrow) literal instead of relying on
# Python passing an unrecognised escape sequence through, which newer Python
# versions warn about.
pivot_df[pivot_df.index.isin([r'$C_{MATTR}$',
       r'$C_{PPL}$', r'$C_{\searrow}$', r'$C_{rand}$', r'$C_{source}$'])]

In [None]:
# Entropy of the first recorded chunk (the last field of each result tuple is
# the "entropy" column of stats_df). It is read from stats_df because the name
# `stats` may no longer refer to the shared result list once a cell has
# imported scipy.stats under that alias.
stats_df["entropy"].iloc[0]