In [1]:
import sys

sys.path.append('..')

import os
import pickle
import random
from copy import deepcopy
from datetime import datetime
from functools import partial
from itertools import islice
from os import environ
from pickle import dump

import matplotlib.pyplot as plt
import numpy as np
import wandb
from iteration_utilities import repeatfunc, unique_everseen
from sklearn.model_selection import train_test_split
from tokenizer import build_tokenizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from utils import train_collate_fn

from freegroup.sampling import CFGNormalClosureSampler, freegroup
from freegroup.sampling.helper import get_rng
from freegroup.tools import Comm, flatten, normalize, to_string

# Turn off tokenizer-internal parallelism before any tokenizer/DataLoader is created.
# NOTE(review): presumably read by the HuggingFace `tokenizers` library behind
# `build_tokenizer` — confirm.
environ["TOKENIZERS_PARALLELISM"] = "false"

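### Generate relations

For every word length from 0 to 30, sample `n_relations_per_length` words with `freegroup(fdim, length)`. Words are distinct for lengths above 1; lengths 0 and 1 contribute only empty words. The `small` / `medium` / `big` subsets keep lengths from 4 upward with steps of 5, 3 and 1 respectively.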

In [2]:
# Global configuration for the whole notebook.
fdim = 3  # forwarded as `fdim` to both `freegroup(...)` and `CFGNormalClosureSampler.build(...)`
seed = 0
# Seeded generator; it becomes the default `rng` of `sample_freegroup`.
rng = get_rng(seed)

In [3]:
# How many distinct relations to sample for each word length.
n_relations_per_length: int = 4

In [4]:
def sample_freegroup(fdim=2, n_samples=int(1e3), rng=rng, length=5, tqdm_on=False, max_attempts=None):
    """Sample up to `n_samples` distinct words of the given `length`.

    Words are drawn one at a time with `freegroup(fdim, length, rng=rng)`,
    converted to tuples, and de-duplicated in order of first appearance.

    Args:
        fdim: forwarded to `freegroup` (presumably the number of generators).
        n_samples: number of distinct words to collect (must be an int).
        rng: random generator forwarded to `freegroup`. NOTE: the default is the
            module-level `rng` captured when this cell is executed.
        length: word length forwarded to `freegroup`.
        tqdm_on: show a progress bar while collecting.
        max_attempts: upper bound on the number of raw draws. Defaults to
            `100 * n_samples`. Without a bound, the sampling loop never terminates
            when fewer than `n_samples` distinct words can be produced.

    Returns:
        For `length <= 1`: a list of `n_samples` empty tuples. Duplicates are
        kept, so the flattened relation list still gets `n_samples` entries for
        this length.
        Otherwise: a set of word tuples. It has fewer than `n_samples` elements
        (and a warning is emitted) only if `max_attempts` draws were exhausted.
    """
    if length <= 1:
        # Lengths 0 and 1 are treated as the empty (trivial) relation.
        return [() for _ in range(n_samples)]

    if max_attempts is None:
        # Generous cap: only reached when distinct words are (nearly) exhausted.
        max_attempts = 100 * int(n_samples)

    def draw_word():
        return tuple(freegroup(fdim, length, rng=rng))

    # Bound the raw draws *before* de-duplication so the pipeline always terminates.
    draws = islice(repeatfunc(draw_word), max_attempts)
    distinct = islice(unique_everseen(draws), n_samples)
    result = set(tqdm(distinct, total=int(n_samples), disable=not tqdm_on))

    if len(result) < n_samples:
        import warnings
        warnings.warn(
            f'sample_freegroup: only {len(result)} of {n_samples} distinct words '
            f'of length {length} found after {max_attempts} draws'
        )
    return result

In [5]:
# relations_all[length] holds the relations sampled for that word length (0..30).
relations_all = [
    sample_freegroup(fdim=fdim, n_samples=n_relations_per_length, length=length)
    for length in tqdm(range(31))
]

  0%|          | 0/31 [00:00<?, ?it/s]

In [6]:
# Subsets of `relations_all` keyed by word length, all starting at length 4:
#   'small'  -> every 5th length, 'medium' -> every 3rd length, 'big' -> every length.
# (Earlier 'extra_tiny'/'tiny' variants sliced the per-length groups with [:2]/[:3];
#  those groups are sets for length > 1 and would need `list(...)` before slicing.)
relations = {
    name: {length: relations_all[length] for length in range(4, len(relations_all), step)}
    for name, step in (('small', 5), ('medium', 3), ('big', 1))
}

In [7]:
# Total number of relations in each subset, summed over all of its word lengths.
# Summing (rather than multiplying by the first group's size) stays correct when
# some lengths yield fewer relations, and does not fail on an empty subset.
for name, relations_by_length in relations.items():
    print(name, sum(len(group) for group in relations_by_length.values()))

small 24
medium 36
big 108


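### Generate data

For every non-empty relation, build a `CFGNormalClosureSampler` with maximum word length `L` and pickle it to disk, one file per relation.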

In [8]:
# Passed as `max_length` to every sampler and embedded in the output directory and file names.
L: int = 500

In [9]:
# Flatten the per-length groups into a single list (shortest lengths first).
# The position in this list becomes the sampler index used in the output file names.
flat_relations_all = [word for words_of_length in relations_all for word in words_of_length]


In [None]:
# Build one CFG normal-closure sampler per non-empty relation and pickle it to disk.
# Each finished sampler is also appended to a log file, so completed work is recorded as it happens.
sampler_dir = f'/main/whitehead/data/CFG_samplers/meta_model_samplers_L_{L}'
log_path = 'log_generate_samples.log'
# open(..., 'wb') below fails if the target directory does not exist yet.
os.makedirs(sampler_dir, exist_ok=True)

total_relations = len(flat_relations_all)
# Iterating the bar itself advances it for every item, including skipped empty
# relations, so it reaches `total_relations` (a manual update placed after the
# `continue` never ran for those items).
with tqdm(flat_relations_all, total=total_relations) as pbar:
    for i, relation in enumerate(pbar):
        if not relation:
            continue  # empty relation (lengths 0/1): nothing to build
        pbar.set_description(f"Processing relation: {relation} ({len(relation)})")
        sampler = CFGNormalClosureSampler.build(closure=list(relation), fdim=fdim, max_length=L)
        sampler_path = os.path.join(sampler_dir, f'sampler_{i:03d}_{",".join(map(str, relation))}_L={L}.pkl')
        with open(sampler_path, 'wb') as f:
            pickle.dump(sampler, f)
        with open(log_path, 'a') as f:
            # datetime.now() must be *called*; the bare attribute logs a method repr.
            print(f'{i:03d} / {total_relations} {datetime.now()} {relation} ready! Path={sampler_path}', file=f)

  0%|          | 0/124 [00:00<?, ?it/s]