# Load sequences from the Pile

In [44]:
import sys, os, functools, torch
sys.path.append('/home/jupyter/')
#from paraMem.utils import metrics, helpers
ROOT = "/home/jupyter"

import transformer_lens
import numpy as np
from pathlib import Path
from toolz import compose

In [15]:
DEVICE = "cpu"  # device the pretrained model is moved to
model = (
    transformer_lens.HookedTransformer
    .from_pretrained("gpt-neo-125M")
    .to(DEVICE)
)

Loaded pretrained model gpt-neo-125M into HookedTransformer
Moving model to device:  cpu


## Prefixes

Source of the prefix/count data loaded below: https://github.com/ethz-spylab/lm_memorization_data


In [43]:
def load_pile_seqs(seq_length=100, select_idcs:torch.LongTensor=None, data_folder=None):
    """Load the Pile prompts and their counts stored for one sequence length.

    Files in ``data_folder`` are matched by name as ``<type>_<seq_length>.<ext>``:
    a ``<type>`` containing "counts" is loaded as the counts array, one containing
    "prompts" as the prompts array (both with ``np.load``). File names without an
    "_" are skipped instead of raising ``IndexError``.

    Args:
        seq_length: sequence length encoded in the file names to look for.
        select_idcs: optional row indices (torch tensor, numpy array or list)
            applied to both arrays.
        data_folder: directory to search; defaults to ``ROOT + "/paraMem/data/lm_mem"``.

    Returns:
        ``(prompts, counts)`` as loaded by ``np.load`` (optionally row-selected).

    Raises:
        FileNotFoundError: if no prompts file or no counts file matches ``seq_length``.
    """
    if data_folder is None:
        data_folder = ROOT + "/paraMem/data/lm_mem"
    data_folder = Path(data_folder)

    prompts, counts = None, None
    for f in os.listdir(data_folder):
        parts = f.split("_")
        if len(parts) < 2:
            continue  ## not "<type>_<len>.<ext>" (e.g. README.md)
        f_type, f_seq_length = parts[0], parts[1].split(".")[0]
        if f_seq_length != str(seq_length):
            continue
        if "counts" in f_type:  ## number of prefixes (and their counts)
            counts = np.load(data_folder / f)
        elif "prompts" in f_type:  ## number of prefixes, sequence length
            prompts = np.load(data_folder / f)

    # every caller indexes both arrays, so fail here with a clear message
    # instead of returning None and crashing later with a TypeError
    if prompts is None or counts is None:
        raise FileNotFoundError(
            f"missing prompts and/or counts file for seq_length={seq_length} in {data_folder}"
        )

    if select_idcs is not None:
        idcs = np.asarray(select_idcs)  # accepts CPU torch tensors, numpy arrays and lists
        prompts, counts = prompts[idcs], counts[idcs]
    return prompts, counts


def _select_rows(idx, *arrays):
    """Index every array in ``arrays`` with the same bool mask / integer index array."""
    return tuple(arr[idx] for arr in arrays)


def _row_unique_counts(prompts):
    """Return the number of distinct tokens in every row of a 2-D token array."""
    if prompts.shape[-1] == 0:
        return np.zeros(prompts.shape[0], dtype=np.int64)
    # after sorting a row, each change between neighbours starts a new distinct token
    sorted_rows = np.sort(prompts, axis=-1)
    return (np.diff(sorted_rows, axis=-1) != 0).sum(axis=-1) + 1


def _token_filter_mask(prompts, filter_toks, eos_token_id):
    """Return a bool mask that is False for prompts containing all non-EOS tokens of any filter row."""
    filter_sets = [set(row[row != eos_token_id].tolist()) for row in filter_toks]
    # a row of only EOS/padding gives an empty set, which is a subset of every
    # prompt and would drop the whole dataset -- ignore such rows
    filter_sets = [tok_set for tok_set in filter_sets if tok_set]
    keep = []
    for prompt in prompts:
        prompt_toks = set(prompt.tolist())
        keep.append(not any(tok_set <= prompt_toks for tok_set in filter_sets))
    return np.array(keep, dtype=bool)


def _k_uniform_mask(counts, k_uniform):
    """Return a bool mask keeping the first k_uniform rows of each count value that occurs >= k_uniform times."""
    mask = np.zeros(len(counts), dtype=bool)
    values, value_freqs = np.unique(counts, return_counts=True)
    for val in values[value_freqs >= k_uniform]:
        mask[np.flatnonzero(counts == val)[:k_uniform]] = True
    return mask


def _count_range_mask(counts, count_ranges):
    """Return a bool mask of counts lying in any inclusive (low, high) pair of count_ranges."""
    mask = np.zeros(len(counts), dtype=bool)
    for count_range in count_ranges:
        mask |= (counts >= count_range[0]) & (counts <= count_range[1])
    return mask


def preprocess_pile_seqs(seq_length=100, n_seqs=None, uni_tok_frac=0.0, filter_toks:torch.Tensor=None,
                         count_ranges=None, k_uniform=None, eos_token_id=50256, seed=None):
    """Filter, sample and shuffle the Pile sequences loaded by ``load_pile_seqs``.

    Stages, applied in order:
      1. drop sequences with count 0 or with at most ``int(uni_tok_frac * seq_len)``
         distinct tokens;
      2. if ``filter_toks`` is given, drop every prompt that contains all the
         non-``eos_token_id`` tokens of any row of ``filter_toks``;
      3. if ``k_uniform`` is given, keep the first ``k_uniform`` sequences of each
         count value occurring at least ``k_uniform`` times; then, if
         ``count_ranges`` is given, keep counts inside any inclusive range;
      4. shuffle and keep the first ``n_seqs`` (all when None).

    Args:
        seq_length: sequence length passed to ``load_pile_seqs``.
        n_seqs: number of sequences to return after shuffling (None = all).
        uni_tok_frac: minimum fraction of distinct tokens a sequence must exceed.
        filter_toks: token ids, one row per filter string (a 1-D input is treated
            as a single row); rows made only of ``eos_token_id`` are ignored.
        count_ranges: iterable of ``(low, high)`` inclusive count ranges.
        k_uniform: number of sequences to keep per frequent count value.
        eos_token_id: token id stripped from ``filter_toks`` rows (50256 by default,
            as before).
        seed: if given, shuffle with ``np.random.default_rng(seed)`` for a
            reproducible order; if None use the global numpy RNG as before.

    Returns:
        ``(prompts, counts, data_indices)`` tensors; ``data_indices`` are row
        indices into the arrays returned by ``load_pile_seqs`` and carry the
        preprocessing settings in a ``.meta`` dict.
    """
    prompts, counts = load_pile_seqs(seq_length=seq_length)
    data_indices = np.arange(len(prompts))
    if filter_toks is not None:
        filter_toks = torch.atleast_2d(torch.as_tensor(filter_toks))
    metadata = {"seq_length": seq_length, "uni_tok_frac": uni_tok_frac,
                "filter_toks": None if filter_toks is None else filter_toks.to("cpu"),
                "count_ranges": count_ranges, "k_uniform": k_uniform}

    ## (1) pre-filtering: zero counts and sequences with too few distinct tokens
    min_unique_toks = int(uni_tok_frac * prompts.shape[-1])
    keep = (counts > 0) & (_row_unique_counts(prompts) > min_unique_toks)
    prompts, counts, data_indices = _select_rows(keep, prompts, counts, data_indices)

    ## (2) token filters
    if filter_toks is not None:
        print(f"filtering tokens and eos: {eos_token_id}")
        keep = _token_filter_mask(prompts, filter_toks, eos_token_id)
        prompts, counts, data_indices = _select_rows(keep, prompts, counts, data_indices)

    ## (3) sampling (masks are always recomputed, so no stale indices leak between stages)
    if k_uniform is not None:
        print(f"k_uniform sampling {k_uniform}")
        keep = _k_uniform_mask(counts, k_uniform)
        prompts, counts, data_indices = _select_rows(keep, prompts, counts, data_indices)

    if count_ranges is not None:
        print(f"count_ranges {count_ranges}")
        keep = _count_range_mask(counts, count_ranges)
        prompts, counts, data_indices = _select_rows(keep, prompts, counts, data_indices)

    ## (4) shuffling and selection
    if seed is None:
        order = np.arange(len(counts))
        np.random.shuffle(order)
    else:
        order = np.random.default_rng(seed).permutation(len(counts))
    prompts, counts, data_indices = _select_rows(order, prompts, counts, data_indices)

    data_indices = torch.tensor(data_indices[:n_seqs], dtype=torch.long)
    data_indices.meta = metadata  # load_pile_splits reads .meta["seq_length"] from saved index tensors
    return torch.tensor(prompts[:n_seqs]), torch.tensor(counts[:n_seqs]), data_indices

filter_toks = model.to_tokens(
    [" license", " License", " LICENSE", " copyright", " Copyright", " COPYRIGHT"]
)
prompts, counts, idcs = preprocess_pile_seqs(
    seq_length=100,
    n_seqs=500,
    uni_tok_frac=0.5,
    filter_toks=filter_toks,
    k_uniform=None,
)

# position of the highest count among the returned sequences (argsort is ascending)
most_frequent = counts.argsort()[-1]
print(f"count:{counts[most_frequent]}, prompt{model.to_string(prompts[most_frequent])}")

filtering tokens and eos: 50256
count:58251, prompt SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SU


In [54]:
def load_pile_splits(folder:str, as_torch:bool=False):
    """Load a saved memorized / non-memorized split of the Pile sequences.

    ``<ROOT>/paraMem/data/pile_splits/<folder>`` must contain ``mem.pt`` and
    ``non_mem.pt`` (loaded with ``torch.load``; presumably index tensors produced
    by ``preprocess_pile_seqs`` -- confirm). The sequence length is read from
    ``mem_idcs.meta["seq_length"]``, and the matching prompts/counts are loaded
    with ``load_pile_seqs`` and indexed by both index sets.

    Returns:
        ``((mem_prompts, mem_counts), (non_mem_prompts, non_mem_counts))``;
        converted to ``torch.LongTensor`` before indexing when ``as_torch`` is True.
    """
    split_dir = Path(ROOT + "/paraMem/data/pile_splits") / str(folder)
    mem_idcs = torch.load(split_dir / "mem.pt")
    non_mem_idcs = torch.load(split_dir / "non_mem.pt")

    ## the split records the sequence length it was built from
    prompts, counts = load_pile_seqs(seq_length=int(mem_idcs.meta["seq_length"]))
    if as_torch:
        prompts = torch.LongTensor(prompts)
        counts = torch.LongTensor(counts)

    def subset(idcs):
        return prompts[idcs], counts[idcs]

    return subset(mem_idcs), subset(non_mem_idcs)

# memorized / non-memorized prompts and counts of the "mock_mem_em" split, as torch tensors
(mem_prompts, mem_counts), (non_mem_prompts, non_mem_counts) = load_pile_splits(
    "mock_mem_em", as_torch=True
)