In [2]:
from itertools import permutations, islice
import numpy as np
import math
import pickle
import torch
from sklearn.preprocessing import OneHotEncoder
import torch 

In [6]:
def num_permutations(N, p):

    '''
    Returns how many ordered, repetition-free length p lists can be formed
    from a vocabulary of size N, i.e. N! / (N - p)!.
    @param N: vocabulary size
    @param p: list length
    @return: the exact count as a Python int
    @raises ValueError: if p > N or N is negative (raised by math.factorial)
    '''

    # Integer floor division keeps the result exact. True division (/) goes
    # through a float, which silently loses precision once the count exceeds
    # 2**53 and raises OverflowError when the count exceeds the float range.
    return math.factorial(N) // math.factorial(N - p)

In [8]:
def generate_test_set(test_name, num_lists_per_length, max_frac, protrusions, 
                        ll_range, letters_testing=26, protrusion_ll=6, num_protrusion_lists=300):

    '''
    Generates a held-out test set of lists and pickles it to disk.

    For every list length ll in 1..ll_range, test lists are ordered,
    repetition-free sequences of ll tokens drawn from the vocabulary
    0 .. letters_testing-1. Each list is a tuple of numpy integer scalars.

    @param test_name: file name prefix (may include a directory) for the pickles
    @param num_lists_per_length: target number of test lists per list length
    @param max_frac: the max fraction of all lists at a given length that can be reserved
    for testing.
    @param protrusions: if true, split the vocabulary into two disjoint halves and form
    lists drawn entirely from one half. These lists are added to the test set.
    @param ll_range: largest list length; lengths 1..ll_range are generated
    @param letters_testing: vocabulary size (must be even when protrusions is true)
    @param protrusion_ll: list length used to test protrusions; must lie in 1..ll_range
    @param num_protrusion_lists: target number of protrusion lists taken from each half
    @return: (test_lists_dict, test_lists_dict_paired).
    test_lists_dict maps str(length) to a list of tuples.
    test_lists_dict_paired maps 'even' / 'odd' to the protrusion lists of each half.
    It is an empty dict when protrusions is false.
    @raises ValueError: if protrusions is true and protrusion_ll is outside 1..ll_range
    or num_protrusion_lists is not positive.

    Files written:
    - {test_name}.pkl
    - {test_name}_set.pkl (same data as sets, same keys)
    - when protrusions is true: {test_name}_protrusions.pkl and
      {test_name}_protrusions_set.pkl
    '''

    if protrusions:
        # Validate before the (potentially slow) enumeration below.
        if not 1 <= protrusion_ll <= ll_range:
            raise ValueError(
                f'protrusion_ll must be in 1..ll_range ({ll_range}), got {protrusion_ll}')
        if num_protrusion_lists <= 0:
            raise ValueError(
                f'num_protrusion_lists must be positive, got {num_protrusion_lists}')

    # Lengths whose permutation count is below this threshold are sampled by
    # evenly stepping through every permutation. Larger ones are sampled at random.
    max_enumerable_lists = 1e8

    test_lists_dict = {}
    # Stays empty unless protrusions is true. It is defined here so the return
    # statement never references an unbound name.
    test_lists_dict_paired = {}

    letters = np.arange(letters_testing)

    for ll in range(1, ll_range + 1):
        key = str(ll)

        # number of possible lists of this length
        num_possible_lists = num_permutations(letters_testing, ll)

        # cap the number of test lists at max_frac of all possible lists
        num_lists = min(num_lists_per_length, int(max_frac * num_possible_lists))

        if num_lists <= 0:
            # max_frac leaves no room for test lists at this length. Without this
            # guard the step size below would divide by zero.
            test_lists_dict[key] = []
        elif num_possible_lists < max_enumerable_lists:
            # Round up so the step is at least 1 and at most num_lists lists are taken.
            step_size = math.ceil(num_possible_lists / num_lists)
            test_lists_dict[key] = list(islice(permutations(letters, ll),
                                               None, None, step_size))
        else:
            # Too many permutations to step through. Draw each list with its own
            # seeded generator so the result is reproducible.
            # NOTE(review): seeds 0..num_lists-1 are reused for every length and
            # duplicate draws are not removed. Confirm this is intended.
            test_lists_dict[key] = [
                tuple(np.random.default_rng(seed=i).choice(letters, ll, replace=False))
                for i in np.arange(num_lists)
            ]

    if protrusions:
        # Split the vocabulary into two disjoint halves. Despite the 'even'/'odd'
        # labels, these are the first and second halves of 0..letters_testing-1.
        # np.split raises if letters_testing is odd.
        list_even, list_odd = np.split(letters, 2)
        num_possible_protrusion_lists = num_permutations(list_even.shape[0], protrusion_ll)
        step_size = math.ceil(num_possible_protrusion_lists / num_protrusion_lists)

        test_lists_dict_paired['even'] = list(islice(permutations(list_even, protrusion_ll),
                                                     None, None, step_size))
        test_lists_dict_paired['odd'] = list(islice(permutations(list_odd, protrusion_ll),
                                                    None, None, step_size))

        # Interleave the two halves' lists: even, odd, even, odd, ...
        p_trials = []
        for x, y in zip(test_lists_dict_paired['even'], test_lists_dict_paired['odd']):
            p_trials.extend((x, y))

        # add protrusion trials to the normal test set so they are not
        # included in training
        test_lists_dict[str(protrusion_ll)].extend(p_trials)

        p_trials_dict = {'even_odd': p_trials}

    # Set versions speed up membership checks when removing
    # test lists from training.
    test_lists_set = {ll: set(lists) for ll, lists in test_lists_dict.items()}

    with open(f'{test_name}.pkl', 'wb') as f:
        pickle.dump(test_lists_dict, f)

    with open(f'{test_name}_set.pkl', 'wb') as f:
        pickle.dump(test_lists_set, f)

    if protrusions:
        with open(f'{test_name}_protrusions.pkl', 'wb') as f:
            pickle.dump(p_trials_dict, f)

        # Store one set of trials per key, mirroring the {test_name}_set.pkl layout.
        # set(p_trials_dict) would keep only the dict's keys ({'even_odd'}).
        with open(f'{test_name}_protrusions_set.pkl', 'wb') as f:
            pickle.dump({name: set(trials) for name, trials in p_trials_dict.items()}, f)

    return test_lists_dict, test_lists_dict_paired

In [9]:
# Build the held-out test set for a 26-letter vocabulary (list lengths 1-9)
# and pickle it. Protrusion trials are included so they are kept out of training.
test_lists_dict, tl_paired = generate_test_set(
    test_name='test_lists_cleaned_26',
    num_lists_per_length=1500,
    max_frac=0.3,
    protrusions=True,
    ll_range=9,
    letters_testing=26,
)

In [6]:
# Inspect the vocabulary split used for protrusion trials: the first half
# ('even') and the second half ('odd') of 0..25. Ending the cell with list_odd
# makes it display the array shown as this cell's output.
list_even, list_odd = np.split(np.arange(26), 2)
list_odd


array([13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25])