In [1]:
import sys

sys.path.append('..')

import os
import pickle
import random
from copy import deepcopy
from functools import partial
from itertools import islice
from os import environ
from pickle import dump
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import wandb
from iteration_utilities import repeatfunc, unique_everseen
from sklearn.model_selection import train_test_split
from tokenizer import build_tokenizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from utils import train_collate_fn

from freegroup.sampling import CFGNormalClosureSampler, freegroup
from freegroup.sampling.helper import get_rng
from freegroup.tools import Comm, flatten, normalize, to_string

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
from freegroup.tools import magnus_is_from_normal_closure
from os import listdir

In [None]:
N_words = 50_000 * 5 #(will reduce to 50_000) #number of words per relation 
L_RELATION = 100 # max length of generated word
number_per_length = N_words // L_RELATION # number of words per length
fdim = 3

In [None]:
samplers_path = Path(f'data/CFG_samplers/meta_model_samplers_L_{L_RELATION}')

In [None]:
dataset_path = Path(f'data/datasets/L_{L_RELATION}_data')

In [13]:
seed = 0
rng = get_rng(seed)

In [14]:
def sample(sampler, length, n_samples=int(1e3), rng=rng, pbar=None):
    """
    return a list of samples for certain sampler
    """
    def fn():
        if pbar is not None:
            pbar.update(1)             
        try:   
            return tuple(map(lambda x: x.item(), sampler(length=length, rng=rng)))
        except:
            return None
    iterator = repeatfunc(fn) 
    iterator = islice(iterator, n_samples)

    return set(filter(lambda x: x is not None, iterator))

In [15]:
sorted(listdir(samplers_path))[0]

'sampler_008_-3,2_L=100.pkl'

In [None]:
k = 0
all_sampled_words = {}
with tqdm(total=len(listdir(samplers_path)) * N_words) as pbar: 
    for path in (sorted(listdir(samplers_path))):
        k +=1
        if k == 5: break
        data = []
        relation = path.split('_')[2]
        sampler_code = path.split('_')[1]
        with open(str(samplers_path / path), 'rb') as f:
            sampler = pickle.load(f)
        for length in range(L_RELATION + 1):
            pbar.set_description(f"Processing sampler: {relation} with l={length}")
            samples = sample(sampler, length, number_per_length, rng=rng, pbar=pbar)
            data.append(samples)
        # with open(str(dataset_path / (sampler_code + '_' + relation)), 'wb') as f:
        #     pickle.dump(data, f)
        all_sampled_words[relation] = data
        

  0%|          | 0/26000 [00:00<?, ?it/s]

### generate commutator

In [77]:
max_pairs = 4
relation_list = list(all_sampled_words.keys())
NOT_EMPTY_LENGTHS = {relation: [i for i in range(len(samples)) if samples[i]]
                     for relation, samples in all_sampled_words.items()}

In [None]:
from freegroup.tools import Comm, normalize, flatten

In [None]:
def choose_random_word_from_rel(rel, n_pairs, words_samples=all_sampled_words):
    max_com_length = int(L_RELATION / 4 / n_pairs)
    length_list = list(filter(lambda x: x <= max_com_length, NOT_EMPTY_LENGTHS[rel]))
    if len(length_list) == 0:
        return None, None
    # length_list = NOT_EMPTY_LENGTHS[rel]
    length = random.choice(length_list)
    word = list(random.choice(list(words_samples[rel][length])))
    return word, length

In [141]:
n_com_words = N_words * len(relation_list)

In [144]:
comm_data = {}
for _ in tqdm(range(n_com_words)):
    n_pairs = np.random.randint(1, max_pairs+1)
    rel1 = random.choice(relation_list)
    rel2 = random.choice(list(set(relation_list).difference({rel1})))
    n_pairs_real = 0
    words = []
    debug_info = []
    general_length = 0
    for _ in range(n_pairs):
        word1, length1 = choose_random_word_from_rel(rel1, n_pairs)
        word2, length2 = choose_random_word_from_rel(rel2, n_pairs)
        if word1 is None or word2 is None:
            continue

        words += flatten(Comm([word1, word2]))
        debug_info.append((word1, word2, length1, length2))
        general_length += (length1 + length2) * 2
        n_pairs_real += 1
    words = tuple(normalize(words))
    final_length = len(words)
    comm_data.setdefault(final_length, set())
    comm_data[final_length].add(words)


  0%|          | 0/2000 [00:00<?, ?it/s]

In [94]:
len(normalize(flatten(Comm([(word1), (word2)]))))

TypeError: unsupported operand type(s) for +: 'NoneType' and 'NoneType'

In [41]:
word2

(1, -2)

In [39]:
Comm([word1, word2])

Comm(children=[(-1, -1, -1, -1, -1, 2, -1, 3, -1, -3, -3, 1, 3, 1, -3, 1, 3, -1, -2, -3, -3, -3, -1, 3, 3, 3, 1, 3, 2, -3, 1, -2, -1, 3, 1, -3, 1, -3, -2, -3, -1, 3, 3, -2, -1, 3, 2, 3, 1, -3, -1, -1, 3, 2, 3, 2, -1, 3, -2, -1, -3, 1, -2, -3, -3, 1, 1, 2, 3, 1, 1, 1, 1, -3, 1, 3, 2, -3, 1, -2, -3, 1, -3, -1, -3, 1), (1, -2)])

In [34]:
def sample_freegroup(fdim=2, n_samples=int(1e3), rng=rng, length=5, tqdm_on=False):
    if length <= 1:
        return set()
    def fn():
        return tuple(freegroup(fdim, length, rng=rng))
    iterator = repeatfunc(fn)
    # iterator = unique_everseen(iterator)
    iterator = islice(iterator, n_samples)
    result = set(list(tqdm(iterator, total=int(n_samples), disable=not tqdm_on)))
    # set(filter(lambda x: x is not None, iterator))
    return result

In [None]:
data = []
for length in tqdm(range(L_RELATION + 1)):
    data.append(sample_freegroup(3, number_per_length, rng=rng, length=length))
with open(str(dataset_path / (sampler_code + '_' + '0')), 'wb') as f:
    pickle.dump(data, f)

  0%|          | 0/101 [00:00<?, ?it/s]