In [2]:
import sys

sys.path.append('..')

import os
import pickle
import random
from copy import deepcopy
from functools import partial
from itertools import islice
from os import environ
from pickle import dump
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import wandb
from iteration_utilities import repeatfunc, unique_everseen
from sklearn.model_selection import train_test_split
from tokenizer import build_tokenizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from utils import train_collate_fn

from freegroup.sampling import CFGNormalClosureSampler, freegroup
from freegroup.sampling.helper import get_rng
from freegroup.tools import Comm, flatten, normalize, to_string

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [4]:
from freegroup.tools import magnus_is_from_normal_closure
from os import listdir

In [None]:
N_words = 50_000 * 5 #(will reduce to 50_000)
L_relation = 100
number_per_length = N_words // L_relation
fdim = 3

In [43]:
samplers_path = Path(f'data/CFG_samplers/meta_model_samplers_L_{L_relation}')

In [44]:
dataset_path = Path(f'data/datasets/L_{L_relation}_data')

In [7]:
seed = 0
rng = get_rng(seed)

In [46]:
def sample(sampler, length, n_samples=int(1e3), rng=rng, pbar=None):
    """
    return a list of samples for certain sampler
    """
    def fn():
        if pbar is not None:
            pbar.update(1)             
        try:   
            return tuple(map(lambda x: x.item(), sampler(length=length, rng=rng)))
        except:
            return None
    iterator = repeatfunc(fn) 
    iterator = islice(iterator, n_samples)

    return set(filter(lambda x: x is not None, iterator))

In [49]:
sorted(listdir(samplers_path))[0]

'sampler_008_-3,2_L=100.pkl'

In [50]:
with tqdm(total=len(listdir(samplers_path)) * N_words) as pbar: 
    for path in (sorted(listdir(samplers_path))):
        data = []
        relation = path.split('_')[2]
        sampler_code = path.split('_')[1]
        with open(str(samplers_path / path), 'rb') as f:
            sampler = pickle.load(f)
        for length in range(L_relation + 1):
            pbar.set_description(f"Processing sampler: {relation} with l={length}")
            samples = sample(sampler, length, number_per_length, rng=rng, pbar=pbar)
            data.append(samples)
        with open(str(dataset_path / (sampler_code + '_' + relation)), 'wb') as f:
            pickle.dump(data, f)
        break
        

  0%|          | 0/13000000 [00:00<?, ?it/s]

In [34]:
def sample_freegroup(fdim=2, n_samples=int(1e3), rng=rng, length=5, tqdm_on=False):
    if length <= 1:
        return set()
    def fn():
        return tuple(freegroup(fdim, length, rng=rng))
    iterator = repeatfunc(fn)
    # iterator = unique_everseen(iterator)
    iterator = islice(iterator, n_samples)
    result = set(list(tqdm(iterator, total=int(n_samples), disable=not tqdm_on)))
    # set(filter(lambda x: x is not None, iterator))
    return result

In [54]:
data = []
for length in tqdm(range(L_relation + 1)):
    data.append(sample_freegroup(3, number_per_length, rng=rng, length=length))
with open(str(dataset_path / (sampler_code + '_' + '0')), 'wb') as f:
    pickle.dump(data, f)

  0%|          | 0/101 [00:00<?, ?it/s]

In [51]:
data

KeyboardInterrupt: 

In [81]:
path.split('_')[1]

'008'

In [34]:
list(map(lambda x: x.item(), sampler(6, rng=rng)))

[3, 2, -3, -3, -3, 2]

In [5]:

with open('/main/whitehead/data/CFG_samplers/meta_model_samplers_L_200_v2/sampler_018_-1,-2,3,3_L=200.pkl', 'rb') as f:
    sampler = pickle.load(f)

In [8]:
word = sampler(200, rng=rng)

In [10]:
magnus_is_from_normal_closure(word, (-1,-2,3,3))

True