# Notebook to create the n-gram frequency filter

## Imports

In [2]:
from transformers import AutoTokenizer
import json
import os
import glob
import json
import numpy as np
from nltk import ngrams
from collections import Counter
import pickle

In [3]:
path_to_data = "data/maryland_ngram2_seed3.jsonl"  # JSONL data from which to compute the filter
path_to_model = ""  # path/name of the model whose tokenizer is used; left empty here, presumably must be set before running
path_to_filter = "."  # directory in which the filter is saved as filter.pkl
window_size = 2  # n-gram size for the filter computation

def count_ngrams(toks, n):
    """Count how often every n-gram occurs in a token sequence.

    Args:
        toks: iterable of tokens.
        n: number of consecutive tokens per n-gram.
    Returns:
        Counter mapping each n-gram (a tuple of ``n`` tokens, as produced by
        nltk's ``ngrams``) to its number of occurrences.
    """
    frequencies = Counter()
    # Counter.update() with an iterable counts each element it yields.
    frequencies.update(ngrams(toks, n))
    return frequencies

def parallel_count_ngrams(toks, n, num_processes):
    """Count n-gram frequencies of ``toks`` with a pool of worker processes.

    The token sequence is split into contiguous chunks, one per process.
    Consecutive chunks overlap by ``n - 1`` tokens, so every n-gram that
    crosses a chunk boundary is counted exactly once. The result therefore
    equals ``count_ngrams(toks, n)`` on the whole sequence.

    NOTE(review): with the "spawn" start method (default on macOS/Windows),
    worker processes must be able to import ``count_ngrams``. Functions
    defined inside a notebook may not be importable there; confirm on the
    target platform.

    Args:
        toks: sequence of tokens (must support ``len`` and slicing).
        n: n-gram size.
        num_processes: number of worker processes (>= 1).
    Returns:
        Counter mapping each n-gram tuple to its number of occurrences.
    Raises:
        ValueError: if ``num_processes`` is smaller than 1.
    """
    # `mp` is not imported in the notebook's import cell, so import it here.
    import multiprocessing as mp

    if num_processes < 1:
        raise ValueError("Number of processes must be at least 1")
    if not toks:
        return Counter()

    # Contiguous chunks (ceil division). A strided split toks[i::k] would pair
    # tokens that are k positions apart, yielding n-grams absent from the text.
    chunk_size = -(-len(toks) // num_processes)
    # Each chunk is extended by n - 1 tokens so that the n-grams starting in its
    # last positions are complete. The next chunk starts right after those
    # positions, so nothing is counted twice.
    overlap = max(n - 1, 0)
    chunks = [
        toks[start:start + chunk_size + overlap]
        for start in range(0, len(toks), chunk_size)
    ]

    with mp.Pool(processes=num_processes) as pool:
        results = pool.starmap(count_ngrams, [(chunk, n) for chunk in chunks])

    # Merge in place; sum(results, Counter()) would copy the accumulator each step.
    combined_counts = Counter()
    for partial_counts in results:
        combined_counts.update(partial_counts)
    return combined_counts

def _load_texts(jsonl_path, prop):
    """Return ``input + output`` strings for the leading records of a JSONL file.

    Keeps the records whose 0-based index is < ``prop * N`` (N = number of
    records), i.e. the first ceil(prop * N) records. Blank lines are skipped.
    """
    with open(jsonl_path, "r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    limit = prop * len(records)
    # Strict "<": the previous "<=" kept one record more than requested.
    return [
        record["input"] + record["output"]
        for index, record in enumerate(records)
        if index < limit
    ]


def process_file(args, mp_mode=False, prop=1, num_processes=4):
    """Tokenize a JSONL dataset, count its token n-grams and pickle the counts.

    Each JSONL record must provide string fields ``"input"`` and ``"output"``.
    Their concatenation is tokenized with the tokenizer loaded from the
    module-level ``path_to_model``. The token ids of all records are appended
    to one stream before counting, so n-grams spanning two consecutive
    records are counted as well.

    Args:
        args: (jsonl_path, nn, save_dir). These are the JSONL file, the n-gram
            size, and the directory in which ``filter.pkl`` is written
            (created if missing; "" means the current directory).
        mp_mode: whether to count the n-grams with multiprocessing.
        prop: fraction of the records to use, taken from the start of the
            file (1 = all records).
        num_processes: number of worker processes used when ``mp_mode`` is True.
    Returns:
        labels: tuple of n-grams (tuples of token ids), most frequent first
        values: tuple of counts aligned with ``labels``
        (both empty if no n-gram was found)
    """
    jsonl_path, nn, save_dir = args
    print("loading data and tokenizing...")
    tokenizer = AutoTokenizer.from_pretrained(path_to_model)
    texts = _load_texts(jsonl_path, prop)

    toks = []
    for text in texts:
        # NOTE(review): encode() may add special tokens (e.g. BOS) to every
        # record; they are kept in the stream, as before.
        toks.extend(tokenizer.encode(text))

    print("counting ngrams...")
    if mp_mode:
        n_gram_counts = parallel_count_ngrams(toks, nn, num_processes)
    else:
        n_gram_counts = Counter(ngrams(toks, nn))

    # most_common() orders by count, descending, keeping insertion order for
    # ties: the same order as sorted(items, key=count, reverse=True).
    ranked = n_gram_counts.most_common()
    labels, values = zip(*ranked) if ranked else ((), ())
    print(labels[:100])
    print(values[:100])

    if save_dir:  # os.makedirs("") would raise
        os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "filter.pkl")
    with open(save_path, "wb") as f:
        pickle.dump((labels, values), f)

    return labels, values


# Count the window_size-grams of path_to_data and pickle (ngrams, counts),
# sorted by decreasing count, to <path_to_filter>/filter.pkl
labels, values = process_file(
    (path_to_data, window_size, path_to_filter),
    mp_mode=False,
    prop=1,
)

loading data and tokenizing...


counting ngrams...
((29901, 13), (13, 2277), (29889, 13), (25580, 29962), (13, 29961), (14350, 263), (29962, 13), (7128, 2486), (263, 2933), (1614, 2167), (2167, 278), (2933, 393), (278, 2009), (13291, 29901), (2799, 4080), (4080, 29901), (518, 25580), (393, 7128), (1, 518), (29962, 14350), (2486, 1614), (2277, 2799), (29961, 29914), (29914, 25580), (2277, 13291), (3030, 29889), (278, 3030), (2183, 278), (15228, 29901), (2009, 2183), (2277, 15228), (29889, 1), (29892, 322), (29915, 29879), (2009, 29889), (13, 1576), (13, 29908), (243, 162), (13, 13), (310, 278), (13, 5631), (5631, 403), (297, 278), (363, 263), (29871, 29896), (29889, 450), (29871, 243), (403, 385), (13, 29909), (373, 278), (13, 1184), (1184, 29894), (29894, 680), (310, 263), (29892, 13), (29973, 13), (297, 263), (304, 278), (680, 263), (29871, 29906), (338, 263), (13, 29954), (6113, 263), (13, 6113), (29892, 278), (411, 263), (304, 263), (13, 29896), (1213, 1), (363, 278), (29892, 306), (13, 29902), (29908, 1), (13, 40