# In [None]:
from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import product
from tqdm import tqdm
import os
import pandas as pd
import torch
import config
import util
import plotting
from datasets import load_dataset

_CHECKPOINT_PREFIX = "checkpoint-"


def _load_influence_frame(influence_output_dir):
    """Load all ``checkpoint-<step>`` results of a directory into one DataFrame.

    Entries whose name does not start with ``checkpoint-`` (editor/notebook
    artifacts such as ``.ipynb_checkpoints`` or ``.DS_Store``) are ignored
    instead of aborting the whole job on an ``int()`` conversion error.

    Args:
        influence_output_dir: Directory holding one saved tensor per checkpoint.

    Returns:
        A DataFrame with one integer-named column per checkpoint step, ordered
        by ascending step; each column is the flattened tensor of that file.

    Raises:
        FileNotFoundError: If the directory contains no ``checkpoint-*`` entry.
        ValueError: If a ``checkpoint-*`` entry has a non-integer suffix.
    """
    scores_by_step = {}
    for entry in os.listdir(influence_output_dir):
        if not entry.startswith(_CHECKPOINT_PREFIX):
            continue
        step = int(entry[len(_CHECKPOINT_PREFIX):])
        scores = torch.load(
            os.path.join(influence_output_dir, entry),
            weights_only=True,
            map_location="cpu",
        )
        scores_by_step[step] = scores.numpy().flatten()

    if not scores_by_step:
        raise FileNotFoundError(
            "no " + _CHECKPOINT_PREFIX + "* results found in " + influence_output_dir
        )

    frame = pd.DataFrame(scores_by_step)
    return frame.reindex(sorted(frame.columns), axis=1)


def process_job(args):
    """Build the influence table for one model/dataset pair and plot it.

    Any exception raised while loading or plotting is caught and reported on
    stdout as a "skipping ..." line, so one failing job does not stop the
    others.

    Args:
        args: A 4-tuple ``(model_name, dataset_name, model_type,
            curriculum_name)``. It is a single tuple so the job can be
            submitted to an executor as one argument.

    Returns:
        None. The results are whatever the ``plotting`` functions produce.
    """
    model_name, dataset_name, model_type, curriculum_name = args
    try:
        # The directory name repeats the dataset tag twice, joined by "_".
        dataset_tag = os.path.basename(dataset_name) + "_" + "train[0%:100%]"
        influence_output_dir = os.path.join(
            "./influence_mean_normalized",
            os.path.basename(model_name),
            "_".join([dataset_tag] * 2),
        )
        print(influence_output_dir)
        dataset = load_dataset(dataset_name)["train"]
        curriculum = util.get_curriculum(dataset_name, curriculum_name)

        df = _load_influence_frame(influence_output_dir)
        influence_cols = df.columns
        df["total"] = df.sum(axis=1)
        # NOTE(review): this assigns the dataset's columns positionally, so it
        # relies on the train split having exactly these three columns in this
        # order -- confirm against the datasets in config.datasets.
        df[["text", "source", "stage"]] = dataset.to_pandas()
        # Whitespace-separated token count per document. The misspelled column
        # name is kept because it is passed on to the plotting module.
        df["document_lenght"] = df["text"].str.split().str.len()

        # Only the column list is shared; every plot call still receives its
        # own freshly selected frame, exactly as before.
        plot_columns = influence_cols.to_list() + ["stage", "document_lenght"]

        plotting.plot_per_token_in_order(df[plot_columns],
                                         curriculum_name, model_type, dataset_name, curriculum)
        plotting.plot_per_token_per_stage(df[plot_columns],
                                          curriculum_name, model_type, dataset_name, curriculum)

        for influence_curriculum_name in config.influence_curricula:
            print(dataset_name, influence_curriculum_name)
            # Influence curricula are looked up under a model-type-prefixed name.
            influence_curriculum_name = model_type + influence_curriculum_name
            influence_curriculum = util.get_curriculum(dataset_name, influence_curriculum_name)
            plotting.plot_per_token_in_order(df[plot_columns],
                                             influence_curriculum_name, model_type, dataset_name,
                                             influence_curriculum)

    except Exception as e:
        # Best-effort batch processing: report the failure and move on.
        print("skipping", model_name, dataset_name, model_type, curriculum_name, "due to", str(e))


if __name__ == "__main__":
    # One job per (dataset, model type) pair; each job is the tuple
    # (model_name, dataset_name, model_type, curriculum_name) that
    # process_job unpacks.
    jobs = [
        (d + "_" + t + "_random", d, t, "random.pt")
        for d, t in product(config.datasets, config.model_types)
    ]

    with ProcessPoolExecutor() as executor:
        # Remember which job each future belongs to so failures can be named.
        futures = {executor.submit(process_job, job): job for job in jobs}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing jobs"):
            # process_job handles ordinary exceptions itself; anything that
            # still surfaces here (e.g. a worker process dying) would otherwise
            # be lost, because a future's error is only visible when asked for.
            error = future.exception()
            if error is not None:
                print("job failed", futures[future], "due to", repr(error))
