In [1]:
import sys

sys.path.append("..")

import os
from copy import deepcopy
from functools import partial
from itertools import islice
from os import environ
from pickle import dump


from freegroup.sampling import freegroup
from freegroup.sampling.helper import get_rng
from freegroup.tools import Comm, flatten, to_string
from iteration_utilities import repeatfunc, unique_everseen
from sklearn.model_selection import train_test_split
from tokenizer import build_tokenizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Single seed and random generator used by the sampling cells below.
seed = 0
rng = get_rng(seed)

# Passed to the tokenizer and to the closure samplers as `fdim`
# (presumably the number of free-group generators — TODO confirm).
fdim = 2
# Number of (R, S) closure pairs used by the dataset and the collate function.
num_pairs = 2

# Largest word length requested from the samplers (`rng.integers(1, L + 1)`).
L = 50
# Base dataset size. NOTE(review): the dataset cell draws N closure words in
# total plus N // 2 free-group and N // 2 commutator words, and only then
# splits into train/test, so this is not the final train size.
N = int(1e2)

# Word-level tokenizer with prompt tokens and the post-processor enabled and
# commutator tokens disabled.
tokenizer = build_tokenizer(
    "word-level",
    fdim=fdim,
    add_commutator_tokens=False,
    add_prompt_tokens=True,
    add_post_processor=True,
)



In [3]:
from freegroup.sampling import CFGNormalClosureSampler


def generate_random_closure(fdim, max_length=5, generator=None):
    """Draw a random word to be used as the generator of a normal closure.

    Letters are non-zero integers drawn uniformly from
    ``{-fdim, ..., -1, 1, ..., fdim}``; the word length is drawn uniformly
    from ``1..max_length`` (both ends included).

    Args:
        fdim: Largest letter magnitude; must be at least 1.
        max_length: Largest allowed word length; must be at least 1.
        generator: Optional random generator exposing ``integers(low, high)``
            with an exclusive ``high`` (NumPy ``Generator`` semantics). When
            omitted, the notebook-level ``rng`` is used.

    Returns:
        A list of non-zero integer letters.

    Raises:
        ValueError: If ``fdim`` or ``max_length`` is smaller than 1.
    """
    # TODO: avoid trivial closures, maybe freegroup has smth for this?
    if fdim < 1:
        raise ValueError(f"fdim must be >= 1, got {fdim}")
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")

    source = rng if generator is None else generator

    # `integers` excludes its upper bound, so `+ 1` is required for the
    # longest length and for the letter `+fdim` to be reachable at all.
    length = source.integers(1, max_length + 1)
    closure = []
    for _ in range(length):
        letter = 0
        # 0 is not a valid letter: redraw until a non-zero one comes up.
        while letter == 0:
            letter = source.integers(-fdim, fdim + 1)
        closure.append(letter)

    return closure


def generate_closure_pairs(fdim, num_pairs=5):
    """Build `num_pairs` tuples of two random closure words.

    Both members of a tuple come from `generate_random_closure(fdim)`; the
    first member is drawn before the second.
    """
    pairs = []
    for _ in range(num_pairs):
        first = generate_random_closure(fdim)
        second = generate_random_closure(fdim)
        pairs.append((first, second))
    return pairs


def create_samplers(closure_pairs, fdim):
    """Turn every (r, s) closure pair into a pair of normal-closure samplers.

    Returns a list of `(R_sampler, S_sampler)` tuples in the order of
    `closure_pairs`; each sampler is built with
    `CFGNormalClosureSampler.build(closure=..., fdim=fdim)`.
    """

    def build(closure):
        return CFGNormalClosureSampler.build(closure=closure, fdim=fdim)

    return [(build(r), build(s)) for r, s in closure_pairs]

In [4]:
# Fixed (r, s) closure words; each inner pair yields one (R, S) sampler pair.
closure_pairs = [
    [[-1, 2, 2, 1, -2, -2, -2], [-2, 1, 1, 2, -1, -1, -1]],
    [[-1, 2, 1, -2, -2], [-1, -1, -2, 1, 1, 2]],
]
# The labelling and the per-sampler sample counts are derived from
# `num_pairs`, so it has to agree with the number of pairs defined here.
assert len(closure_pairs) == num_pairs, "num_pairs must equal len(closure_pairs)"
samplers = create_samplers(closure_pairs, fdim)

In [5]:
def get_whitehead_multilabel(label, num_pairs):
    """Translate a string label into a list of class indices.

    With ``num_pairs = len(closure_pairs)`` the mapping is:

    * ``"r<i>"`` -> ``[i]``
    * ``"s<i>"`` -> ``[num_pairs + i]``
    * ``"f"``    -> ``[]``
    * ``"c"``    -> ``[0, 1, ..., 2 * num_pairs - 1]``

    Args:
        label: One of the label strings listed above.
        num_pairs: Number of closure pairs; ``i`` must satisfy
            ``0 <= i < num_pairs``.

    Returns:
        A new list of integer class indices.

    Raises:
        ValueError: If the label is unknown, its index is not an integer, or
            the index is outside ``0..num_pairs - 1``.
    """
    if label == "f":
        return []
    if label == "c":
        return list(range(2 * num_pairs))

    kind = label[:1]
    if kind in ("r", "s"):
        try:
            index = int(label[1:])
        except ValueError:
            raise ValueError(f"Unknown label: {label}") from None
        # Without this check "r<num_pairs>" would silently alias "s0" and
        # negative indices would alias other classes.
        if not 0 <= index < num_pairs:
            raise ValueError(
                f"Label index out of range for num_pairs={num_pairs}: {label}"
            )
        offset = 0 if kind == "r" else num_pairs
        return [offset + index]

    raise ValueError(f"Unknown label: {label}")

In [41]:
def sample(n_samples, rng, sampler, label, num_pairs=2):
    """Collect `n_samples` distinct labelled words drawn from `sampler`.

    Each attempt requests a word of random length in ``1..L`` from
    ``sampler(length=..., rng=rng)``. Attempts that raise are skipped, and
    duplicates (same ``word_str``) are dropped.

    Note: the function keeps drawing until enough distinct words were found,
    so it does not return if `sampler` cannot produce `n_samples` distinct
    words.

    Args:
        n_samples: Number of records to return (converted with ``int``).
        rng: Random generator used for the length and passed to `sampler`.
        sampler: Callable accepting ``length`` and ``rng`` keyword arguments.
        label: Label string stored in every record.
        num_pairs: Number of closure pairs, forwarded to
            `get_whitehead_multilabel`.

    Returns:
        A list of dicts with keys ``"label"``, ``"multilabel"``, ``"word_str"``.

    Raises:
        ValueError: If `label` is rejected by `get_whitehead_multilabel`.
    """
    n_samples = int(n_samples)
    # Fail fast on a bad label instead of skipping every draw forever.
    get_whitehead_multilabel(label, num_pairs)

    def fn():
        length = rng.integers(1, L + 1)
        try:
            word = sampler(length=length, rng=rng)
            word_str = to_string(word)
        except Exception:
            # Best effort: a failed draw is skipped. `Exception` (not a bare
            # `except`) keeps Ctrl-C / kernel interrupts working.
            return None
        return {
            "label": label,
            "multilabel": get_whitehead_multilabel(label, num_pairs),
            "word_str": word_str,
        }

    iterator = repeatfunc(fn)
    iterator = filter(lambda record: record is not None, iterator)
    # label/multilabel are constant within one call, so uniqueness of the
    # record is uniqueness of `word_str`; the string key is hashable.
    iterator = unique_everseen(iterator, key=lambda record: record["word_str"])
    iterator = islice(iterator, n_samples)

    return list(tqdm(iterator, total=n_samples))

In [42]:
# samplers[0][0](length=rng.integers(1, L + 1), rng=rng)

In [44]:
def sample_freegroup(n_samples=1e3, rng=rng, label="f", num_pairs=2):
    """Collect `n_samples` distinct labelled words from `freegroup`.

    Each draw calls ``freegroup(2, length, rng=rng)`` with a random length in
    ``1..L``; duplicates (same ``word_str``) are dropped.

    Args:
        n_samples: Number of records to return. Converted with ``int`` so the
            float default is accepted by ``islice``.
        rng: Random generator; the default is the notebook-level ``rng`` as it
            was bound when this cell was executed.
        label: Label string stored in every record.
        num_pairs: Number of closure pairs, forwarded to
            `get_whitehead_multilabel`.

    Returns:
        A list of dicts with keys ``"label"``, ``"multilabel"``, ``"word_str"``.
    """
    n_samples = int(n_samples)

    def fn():
        length = rng.integers(1, L + 1)
        # NOTE(review): the first argument is hard-coded to 2; presumably it
        # should follow the notebook's `fdim` — confirm against freegroup.
        word = freegroup(2, length, rng=rng)
        return {
            "label": label,
            "multilabel": get_whitehead_multilabel(label, num_pairs),
            "word_str": to_string(word),
        }

    iterator = repeatfunc(fn)
    # label/multilabel are constant within one call, so de-duplicating on the
    # hashable `word_str` is equivalent to comparing whole records.
    iterator = unique_everseen(iterator, key=lambda record: record["word_str"])
    iterator = islice(iterator, n_samples)

    return list(tqdm(iterator, total=n_samples))

In [45]:
def sample_comm(
    n_samples=1e3, rng=rng, samplers=None, label="c", num_pairs=2, max_attempts=10_000
):
    """Collect `n_samples` distinct commutator words built from closure words.

    For every record one word is drawn from each sampler of every
    ``(R_sampler, S_sampler)`` pair, with a random length in
    ``1..L // (5 * len(samplers))``. Two of these words are then picked at
    random without replacement, and the flattened commutator of the two, in
    random order, is stored.

    Args:
        n_samples: Number of records to return. Converted with ``int`` so the
            float default is accepted by ``islice``.
        rng: Random generator; the default is the notebook-level ``rng`` as it
            was bound when this cell was executed.
        samplers: Non-empty sequence of ``(R_sampler, S_sampler)`` pairs.
        label: Label string stored in every record.
        num_pairs: Number of closure pairs, forwarded to
            `get_whitehead_multilabel`.
        max_attempts: Maximum consecutive failed draws tolerated for a single
            word; ``None`` retries without limit.

    Returns:
        A list of dicts with keys ``"label"``, ``"multilabel"``, ``"word_str"``.

    Raises:
        TypeError: If `samplers` is not given.
        ValueError: If `samplers` is empty.
        RuntimeError: If a sampler fails `max_attempts` times in a row.
    """
    if samplers is None:
        raise TypeError("sample_comm() requires `samplers`: (R_sampler, S_sampler) pairs")
    if not samplers:
        raise ValueError("sample_comm() needs at least one (R_sampler, S_sampler) pair")

    n_samples = int(n_samples)
    # Exclusive upper bound for the per-word length draw.
    length_high = L // (5 * len(samplers)) + 1

    def draw_word(sampler):
        # Retry failed draws, but bounded: a sampler that can never produce a
        # word of such a short length must surface as an error instead of
        # spinning forever.
        attempts = 0
        last_error = None
        while max_attempts is None or attempts < max_attempts:
            length = rng.integers(1, length_high)
            try:
                return sampler(length=length, rng=rng)
            except Exception as error:  # not bare: keep Ctrl-C working
                last_error = error
                attempts += 1
        raise RuntimeError(
            f"sampler produced no word in {max_attempts} attempts "
            f"(requested lengths 1..{length_high - 1})"
        ) from last_error

    def fn():
        words = [
            draw_word(sampler)
            for R_sampler, S_sampler in samplers
            for sampler in (R_sampler, S_sampler)
        ]

        i, j = rng.choice(len(words), size=2, replace=False)
        # A coin flip decides the order of the two words in the commutator.
        if rng.integers(low=0, high=2):
            operands = [words[i], words[j]]
        else:
            operands = [words[j], words[i]]

        return {
            "label": label,
            "multilabel": get_whitehead_multilabel(label, num_pairs),
            "word_str": to_string(flatten(Comm(operands))),
        }

    iterator = repeatfunc(fn)
    # label/multilabel are constant within one call, so de-duplicating on the
    # hashable `word_str` is equivalent to comparing whole records.
    iterator = unique_everseen(iterator, key=lambda record: record["word_str"])
    iterator = islice(iterator, n_samples)

    return list(tqdm(iterator, total=n_samples))

In [104]:
# Closure words: N // (2 * num_pairs) per sampler, labelled r{i} / s{i}.
dataset = []
for i, (R, S) in enumerate(samplers):
    dataset += sample(N // (2 * num_pairs), rng, R, f"r{i}")
    dataset += sample(N // (2 * num_pairs), rng, S, f"s{i}")

# Plus N // 2 free-group words and N // 2 commutator words.
dataset += sample_freegroup(N // 2)
dataset += sample_comm(N // 2, samplers=samplers)

# Seed the split too; otherwise train/test membership changes on every run
# even though the words themselves are generated from `seed`.
train, test = train_test_split(deepcopy(dataset), test_size=0.1, random_state=seed)

100%|██████████| 25/25 [00:00<00:00, 500.42it/s]
100%|██████████| 25/25 [00:00<00:00, 914.89it/s]
100%|██████████| 25/25 [00:00<00:00, 941.36it/s]
100%|██████████| 25/25 [00:00<00:00, 900.65it/s]
100%|██████████| 50/50 [00:00<00:00, 2741.59it/s]
  0%|          | 0/50 [00:00<?, ?it/s]

In [2]:
# Load a pickled object from disk (presumably a previously generated train
# split, judging by the file name — confirm).
# NOTE(review): the path is absolute and machine-specific, and `pickle.load`
# can execute arbitrary code from the file, so only load trusted artifacts.
import pickle
with open('/main/draft-v2/pavel-tikhomirov-runs/fdim-2-whitehead:v0/train.pkl', 'rb') as f:
    data = pickle.load(f)

In [3]:
# Peek at the first record of the loaded pickle.
# NOTE(review): the record displayed below has label 'c' with multilabel
# [0, 1, 2], whereas get_whitehead_multilabel("c", 2) in this notebook gives
# [0, 1, 2, 3]; the file presumably comes from a different labelling scheme —
# confirm before mixing it with freshly generated data.
data[0]

{'label': 'c',
 'multilabel': [0, 1, 2],
 'word_str': '2 2 1 1 1 1 1 1 -2 -2 -1 -2 -1 2 1 1 2 2 -1 -1 -1 -1 -1 -1 -2 -2 -1 -1 -2 1 2 1'}

In [106]:
from utils import to_tensor


def train_collate_fn(batch, tokenizer, fdim, num_pairs):
    """Collate raw samples into a padded batch with masked training labels.

    Args:
        batch: Sequence of sample dicts with ``"word_str"`` and
            ``"multilabel"`` entries.
        tokenizer: Tokenizer forwarded to `to_tensor`.
        fdim: Forwarded to `to_tensor` as ``prompt_strategy_fdim`` and used to
            compute the prompt length.
        num_pairs: Number of closure pairs, used to compute the prompt length.

    Returns:
        The mapping returned by `to_tensor` with an added ``"labels"`` entry:
        a copy of ``"input_ids"`` in which padding positions
        (``attention_mask == 0``) and columns ``1..prompt_size - 1`` are set
        to -100 (presumably the loss's ignore index — confirm).
    """
    words = [item["word_str"] for item in batch]
    multilabels = [item["multilabel"] for item in batch]

    encoded = to_tensor(
        words, tokenizer, padding=True, prompt_multilabels=multilabels, prompt_strategy_fdim=fdim
    )

    labels = encoded["input_ids"].clone()

    # Avoid predicting <pad>
    labels[encoded["attention_mask"] == 0] = -100
    # Avoid predicting prompt
    prompt_size = 1 + fdim + 1 + 2 * num_pairs  # Start + fdim + delimiter + 2 * number of pairs
    labels[:, 1:prompt_size] = -100

    encoded["labels"] = labels
    return encoded

In [107]:
# Batches of 16 training samples; `df` is a DataLoader (not a DataFrame) whose
# collate function is pre-bound to the notebook's tokenizer, fdim and num_pairs.
df = DataLoader(
    dataset=train,
    batch_size=16,
    collate_fn=partial(
        train_collate_fn,
        tokenizer=tokenizer,
        fdim=fdim,
        num_pairs=num_pairs,
    ),
)