# Chance level calculations for Rule Extrapolation

In [9]:
import torch
import numpy as np
from itertools import product
import math
from tqdm import tqdm

In [10]:
#@title Code for grammar rules

# these are always used
SOS_token = np.array([0])
EOS_token = np.array([1])
PAD_token = np.array([2])

# only for aNbNcN and variants
A_token = np.array([3])
B_token = np.array([4])
C_token = np.array([5])

# only for parentheses and brackets
OPENING_PARENTHESIS_token = np.array([3])
CLOSING_PARENTHESIS_token = np.array([4])
OPENING_BRACKET_token = np.array([5])
CLOSING_BRACKET_token = np.array([6])



import dataclasses
from typing import Dict


# to_dict: creates a dictionary: {'as_before_bs_accuracy': 0.0, 'as_before_bs_completion_accuracy':0.0, etc}
@dataclasses.dataclass
class GrammarMetrics:
    rule_2_accuracy: float = 0.0
    rule_2_completion_accuracy: float = 0.0

    rule_1_accuracy: float = 0.0
    finished_accuracy: float = 0.0
    grammatical_accuracy: float = 0.0

    def to_dict(self) -> Dict[str, float]:
        return dataclasses.asdict(self)


# generates aNbN grammar: all sequences, all even, all odd or sequences of random length and num_samples number
def generate_aNbN_grammar_data(
    num_samples: int,
    max_length: int = 32,
    all_sequences: bool = True,
    only_even: bool = False,
    only_odd: bool = False,
) -> list:
    """
    PCFG with two rules:
    - number of a's and b's must be the same
    - a's come first, followed by b's

    :param only_even: generates only sequences with even number of a's and b's
    :param only_odd: generates only sequences with odd number of a's and b's
    :param all_sequences: generates all sequences up to max_length (i.e., the longest will have max_length // 2 a's and b's)
    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length

    """

    if all_sequences + only_even + only_odd > 1:
        raise ValueError("Only one of all_sequences, only_even, only_odd can be True")

    if all_sequences is True:
        lengths = np.linspace(
            start=1, stop=max_length // 2, num=max_length // 2, dtype=int, endpoint=True
        )
    elif only_even is True:
        lengths = np.array(list(range(2, max_length // 2 + 1, 2)))
    elif only_odd is True:
        lengths = np.array(list(range(1, max_length // 2 + 1, 2)))
    else:
        lengths = np.random.randint(low=1, high=max_length // 2 + 1, size=num_samples)

    data = []

    for length in lengths:
        data.append(
            np.concatenate(
                (
                    SOS_token,
                    A_token * np.ones(length),
                    B_token * np.ones(length),
                    EOS_token,
                )
            )
        )

    return data  # list containing the sequences of max length max_length+2


def generate_aNbNaN_grammar_data(
    num_samples: int, max_length: int = 32, all_sequences=True
) -> list:
    """
    PCFG with two rules:
    - number of a's is twice the number of b's
    - N a's come first, followed by N b's, then N a's again

    :param all_sequences:
    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length

    """
    if all_sequences is True:
        lengths = np.linspace(
            start=1, stop=max_length // 3, num=max_length // 3, dtype=int, endpoint=True
        )
    else:
        lengths = np.random.randint(low=1, high=max_length // 3 + 1, size=num_samples)

    data = []

    for length in lengths:
        data.append(
            np.concatenate(
                (
                    SOS_token,
                    A_token * np.ones(length),
                    B_token * np.ones(length),
                    A_token * np.ones(length),
                    EOS_token,
                )
            )
        )

    return data


def generate_aNbNcN_grammar_data(
    num_samples: int, max_length: int = 32, all_sequences=True
) -> list:
    """
    PCFG with two rules:
    - number of a's is equal to the number of b's, equal to the number of c's
    - N a's come first, followed by N b's, then N c's

    :param all_sequences:
    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length

    """
    if all_sequences is True:
        lengths = np.linspace(
            start=1, stop=max_length // 3, num=max_length // 3, dtype=int, endpoint=True
        )
    else:
        lengths = np.random.randint(low=1, high=max_length // 3 + 1, size=num_samples)

    data = []

    for length in lengths:
        data.append(
            np.concatenate(
                (
                    SOS_token,
                    A_token * np.ones(length),
                    B_token * np.ones(length),
                    C_token * np.ones(length),
                    EOS_token,
                )
            )
        )

    return data


def generate_abN_grammar_data(num_samples: int, max_length: int = 32) -> list:
    """
    PCFG with one rule:
    - number of a's and b's must be the same

    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length
    """

    lengths = np.random.randint(low=1, high=max_length // 2 + 1, size=num_samples)

    data = []

    for lengths in lengths:
        abN = np.concatenate((A_token * np.ones(lengths), B_token * np.ones(lengths)))
        # shuffle the symbols between start and end tokens
        np.random.shuffle(abN)
        data.append(np.concatenate((SOS_token, abN, EOS_token)))

    return data


def generate_aNbM_grammar_data(num_samples: int, max_length: int = 32) -> list:
    """
    PCFG with one rule:
    - a's are before b's

    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length
    """

    lengths_a = np.random.randint(low=1, high=max_length - 2, size=num_samples)
    lengths_b = np.ones_like(lengths_a) * max_length - lengths_a - 2

    data = []

    for la, lb in zip(lengths_a, lengths_b):
        data.append(
            np.concatenate(
                (SOS_token, A_token * np.ones(la), B_token * np.ones(lb), EOS_token)
            )
        )

    return data


def generate_bNaM_grammar_data(num_samples: int, max_length: int = 32) -> list:
    """
    PCFG with one rule:
    - b's are before a's (begins with b, without SOS, EOS)

    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length
    """

    lengths_b = np.random.randint(low=1, high=max_length, size=num_samples)
    lengths_a = np.ones_like(lengths_b) * max_length - lengths_b

    data = []

    for lb, la in zip(lengths_b, lengths_a):
        data.append(np.concatenate((B_token * np.ones(la), A_token * np.ones(lb))))

    return data


def generate_baN_grammar_data(num_samples: int, max_length: int = 32) -> list:
    """
    PCFG with two rules:
    - begins with b
    - even number of a's

    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length
    """

    lengths = np.random.randint(low=1, high=max_length + 1, size=num_samples)

    data = []

    for l in lengths:
        num_a = np.random.randint(low=0, high=(l - 1) // 2 + 1)
        second_part = np.concatenate(
            (A_token * np.ones(num_a * 2), B_token * np.ones(l - 1 - num_a * 2))
        )
        # shuffle the symbols
        np.random.shuffle(second_part)

        data.append(np.concatenate((SOS_token, B_token, second_part, EOS_token)))

    return data


def generate_bbaN_grammar_data(num_samples: int, max_length: int = 32) -> list:
    """
    PCFG with two rules:
    - b's before a's ('bbbb' ok but 'aaaa' not)
    - even number of a's

    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length
    """

    lengths = np.random.randint(low=1, high=max_length + 1, size=num_samples)

    data = []

    for l in lengths:
        num_a = np.random.randint(low=0, high=(l - 1) // 2 + 1)
        second_part = np.concatenate(
            (B_token * np.ones(l - 1 - num_a * 2), A_token * np.ones(num_a * 2))
        )

        data.append(
            np.concatenate(
                (
                    SOS_token,
                    B_token * np.ones(l - num_a * 2),
                    A_token * np.ones(num_a * 2),
                    EOS_token,
                )
            )
        )

    return data


def pad(data: list, max_seq_length: int = 0) -> np.ndarray:
    """
    Pad data with PAD token
    :param data:
    :param max_seq_length: maximum sequence length
    :return: numpy array with padded data of shape (batch_size, max_batch_length)
    """

    if max_seq_length == 0:
        # Get longest sequence in the dataset
        for seq in data:
            if len(seq) > max_seq_length:
                max_seq_length = len(seq)

    # Append padding tokens until it reaches the max length
    for i, seq in enumerate(data):
        remaining_length = max_seq_length - len(seq)

        if remaining_length > 0:
            data[i] = np.concatenate((data[i], [PAD_token.item()] * remaining_length))

    return np.array(data)


def check_as_before_bs(sequence: torch.Tensor):
    """
    Check if the first b comes after the last a
    :param sequence:
    :return:
    """

    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    if len(a_tokens := torch.where(sequence == A_token.item())[0]) > 0:
        # find the last a
        last_a = a_tokens[-1]

        if len(b_tokens := torch.where(sequence == B_token.item())[0]) > 0:
            # find the first b
            first_b = b_tokens[0]

            return first_b > last_a
        else:
            return True
    else:
        return True


def check_bs_before_as(sequence: torch.Tensor):
    """
    Check if the first a comes after the last b. 'bbbb' ok, 'aaaa' not
    :param sequence:
    :return:
    """

    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    if len(b_tokens := torch.where(sequence == B_token.item())[0]) > 0:
        # find the last b
        last_b = b_tokens[-1]

        if len(a_tokens := torch.where(sequence == A_token.item())[0]) > 0:
            # find the first a
            first_a = a_tokens[0]

            return first_a > last_b
        else:
            return True
    else:
        return False


def check_as_before_cs(sequence: torch.Tensor):
    """
    Check if the first c comes after the last a
    :param sequence:
    :return:
    """

    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    if len(a_tokens := torch.where(sequence == A_token.item())[0]) > 0:
        # find the last a
        last_a = a_tokens[-1]

        if len(c_tokens := torch.where(sequence == C_token.item())[0]) > 0:
            # find the first c
            first_c = c_tokens[0]

            return first_c > last_a
        else:
            return True
    else:
        return True


def check_bs_before_cs(sequence: torch.Tensor):
    """
    Check if the first c comes after the last b
    :param sequence:
    :return:
    """

    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    if len(b_tokens := torch.where(sequence == B_token.item())[0]) > 0:
        # find the last b
        last_b = b_tokens[-1]

        if len(c_tokens := torch.where(sequence == C_token.item())[0]) > 0:
            # find the first c
            first_c = c_tokens[0]

            return first_c > last_b
        else:
            return True
    else:
        return True


def check_bs_in_the_middle(sequence: torch.Tensor):
    """
    Check if the b's are in the middle
    :param sequence:
    :return:
    """

    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    if len(b_tokens := torch.where(sequence == B_token.item())[0]) > 0:
        # find the first b
        first_b = b_tokens[0]
        last_b = b_tokens[-1]

        if len(sequence[:first_b]) == len(sequence[last_b + 1 :]):
            return True
        else:
            return False
    else:
        return False


def check_bs_together(sequence: torch.Tensor):
    """
    Check if the b's are in the middle
    :param sequence:
    :return:
    """

    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    if len(b_tokens := torch.where(sequence == B_token.item())[0]) > 0:
        # find the first b
        first_b = b_tokens[0]
        last_b = b_tokens[-1]

        if (
            (b_subsequence := sequence[first_b : last_b + 1]) == B_token.item()
        ).sum() == len(b_subsequence):
            return True
        else:
            return False
    else:
        return False


def check_same_number_as_bs(sequence: torch.Tensor):
    """
    Check if the number of a's and b's is the same
    :param sequence:
    :return:
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    num_as = torch.sum(sequence == A_token.item())
    num_bs = torch.sum(sequence == B_token.item())
    return num_as == num_bs


def check_twice_many_as_than_bs(sequence: torch.Tensor):
    """
    Check if the number of a's and b's is the same
    :param sequence:
    :return:
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    num_as = torch.sum(sequence == A_token.item())
    num_bs = torch.sum(sequence == B_token.item())
    return num_as == 2 * num_bs


def check_more_as_than_bs(sequence: torch.Tensor):
    """
    Check if there are more a's than b's
    :param sequence:
    :return:
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    num_as = torch.sum(sequence == A_token.item())
    num_bs = torch.sum(sequence == B_token.item())
    return num_as >= num_bs


def check_more_bs_than_cs(sequence: torch.Tensor):
    """
    Check if there are more b's than c's
    :param sequence:
    :return:
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    num_bs = torch.sum(sequence == B_token.item())
    num_cs = torch.sum(sequence == C_token.item())
    return num_bs >= num_cs


def check_more_as_before_bs(sequence: torch.Tensor):
    """
    Check if there are more a's than b's
    :param sequence:
    :return:
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    if len(b_tokens := torch.where(sequence == B_token.item())[0]) > 0:
        first_b = b_tokens[0]

        num_as = torch.sum(sequence[:first_b] == A_token.item())
        num_bs = torch.sum(sequence == B_token.item())
        return num_as >= num_bs

    else:
        return True


def check_same_number_as_bs_cs(sequence: torch.Tensor):
    """
    Check if the number of a's, b's and c's is the same
    :param sequence:
    :return:
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    num_as = torch.sum(sequence == A_token.item())
    num_bs = torch.sum(sequence == B_token.item())
    num_cs = torch.sum(sequence == C_token.item())
    return (num_as == num_bs) and (num_bs == num_cs)


def check_as_before_bs_before_cs(sequence: torch.Tensor):
    """
    Check if the first b comes after the last a and the first c comes after the last b
    :param sequence:
    :return:
    """

    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    if len(c_tokens := torch.where(sequence == C_token.item())[0]) > 0:
        # find the first c
        first_c = c_tokens[0]

        if len(b_tokens := torch.where(sequence == B_token.item())[0]) > 0:
            # find the first and last b
            last_b = b_tokens[-1]
            first_b = b_tokens[0]

            if len(a_tokens := torch.where(sequence == A_token.item())[0]) > 0:
                # find the last a
                last_a = a_tokens[-1]
                if (last_a < first_b) and (last_b < first_c):
                    return True
                else:
                    return False
            else:
                return check_bs_before_cs(sequence)
        else:
            return check_as_before_cs(sequence)
    else:
        return check_as_before_bs(sequence)


def check_in_dist_anbncn(sequence: torch.Tensor):
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    if len(c_tokens := torch.where(sequence == C_token.item())[0]) == 0:
        if len(b_tokens := torch.where(sequence == B_token.item())[0]) == 0:
            return True
        else:
            return check_as_before_bs(sequence) and check_more_as_than_bs(sequence)
    else:
        return (
            check_as_before_bs(sequence)
            and check_bs_before_cs(sequence)
            and check_same_number_as_bs(sequence)
            and check_more_bs_than_cs(sequence)
        )


def check_sequence_finished(sequence: torch.Tensor):
    """
    Check if the sequence is finished (EOS token is in the sequence)
    :param sequence:
    :return:
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    # find the first EOS token
    return len(torch.where(sequence == EOS_token.item())[0]) > 0


def check_even_number_of_as(sequence: torch.Tensor):
    """
    Check if the sequence has even number of a's
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    num_as = torch.sum(sequence == A_token.item())

    return num_as % 2 == 0


def check_begins_with_b(sequence: torch.Tensor):
    """
    Check if the sequence begins with a B_token (after SOS)
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    if sequence[0] == SOS_token.item():
        return sequence[1] == B_token.item()
    else:
        return sequence[0] == B_token.item()


def generate_test_prompts(length: int = 6, grammar: str = "aNbN"):
    """
    Generates all prompts of a given length with symbols a and b or (and c)
    :param length:
    :return:
    """

    num_samples = 2**length
    if grammar in ["aNbN", "abN", "aNbM", "aNbNaN", "baN"]:
        symbols = [A_token.item(), B_token.item()]
        prompts = torch.tensor(list(product(symbols, repeat=length)), dtype=torch.long)

        # add SOS
        prompts = torch.cat(
            (torch.ones((prompts.shape[0], 1), dtype=torch.long) * SOS_token, prompts),
            dim=1,
        )
    elif grammar == "aNbNcN":
        symbols = [A_token.item(), B_token.item(), C_token.item()]
        prompts = torch.tensor(list(product(symbols, repeat=length)), dtype=torch.long)

        # add SOS
        prompts = torch.cat(
            (torch.ones((prompts.shape[0], 1), dtype=torch.long) * SOS_token, prompts),
            dim=1,
        )
    elif grammar == "bbaN":
        ID_data = torch.tensor(
            generate_bNaM_grammar_data(num_samples=num_samples // 2, max_length=length),
            dtype=torch.long,
        )
        OOD_data = torch.tensor(
            generate_bNaM_grammar_data(
                num_samples=num_samples // 2, max_length=length - 1
            ),
            dtype=torch.long,
        )
        id_prompts = torch.cat(
            (torch.ones((ID_data.shape[0], 1), dtype=torch.long) * SOS_token, ID_data),
            dim=1,
        )
        ood_prompts = torch.cat(
            (
                torch.ones((OOD_data.shape[0], 1), dtype=torch.long) * SOS_token,
                torch.ones((OOD_data.shape[0], 1), dtype=torch.long) * A_token,
                OOD_data,
            ),
            dim=1,
        )
        prompts = torch.cat((ood_prompts, id_prompts), dim=0)

    elif grammar == "parentheses":
        data = torch.tensor(
            generate_matched_parentheses_data(
                num_samples=num_samples / 2, max_length=length, fixed_length=True
            ),
            dtype=torch.long,
        )
        ood_prompts = torch.cat(
            (
                data[:, 0].view(-1, 1),
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * CLOSING_PARENTHESIS_token,
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * OPENING_PARENTHESIS_token,
                data[:, 1:-1],
            ),
            dim=1,
        )  # remove EOS

        id_prompts = torch.cat(
            (
                data[:, 0].view(-1, 1),
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * OPENING_PARENTHESIS_token,
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * CLOSING_PARENTHESIS_token,
                data[:, 1:-1],
            ),
            dim=1,
        )  # remove EOS

        prompts = torch.cat((ood_prompts, id_prompts), dim=0)
    elif grammar == "brackets":
        data = torch.tensor(
            generate_matched_brackets_data(
                num_samples=num_samples / 2, max_length=length, fixed_length=True
            ),
            dtype=torch.long,
        )
        ood_prompts = torch.cat(
            (
                data[:, 0].view(-1, 1),
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * CLOSING_BRACKET_token,
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * OPENING_BRACKET_token,
                data[:, 1:-1],
            ),
            dim=1,
        )  # remove EOS

        id_prompts = torch.cat(
            (
                data[:, 0].view(-1, 1),
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * OPENING_BRACKET_token,
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * CLOSING_BRACKET_token,
                data[:, 1:-1],
            ),
            dim=1,
        )  # remove EOS
        prompts = torch.cat((ood_prompts, id_prompts), dim=0)

    elif grammar == "parentheses_and_brackets":
        data = torch.tensor(
            generate_matched_parentheses_and_brackets_data(
                num_samples=num_samples / 2, max_length=length, fixed_length=True
            ),
            dtype=torch.long,
        )
        # generate torch 0-1 sequence in shape (data.shape[0], 1)
        ood_prompts = torch.cat(
            (
                data[:, 0].view(-1, 1),
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * CLOSING_PARENTHESIS_token,
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * OPENING_PARENTHESIS_token,
                data[:, 1:-1],
            ),
            dim=1,
        )  # remove EOS

        id_prompts = torch.cat(
            (
                data[:, 0].view(-1, 1),
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * OPENING_PARENTHESIS_token,
                torch.ones((data.shape[0], 1), dtype=torch.long)
                * CLOSING_PARENTHESIS_token,
                data[:, 1:-1],
            ),
            dim=1,
        )  # remove EOS

        prompts = torch.cat((ood_prompts, id_prompts), dim=0)
    return prompts


def grammar_rules(grammar):
    """
    Selects the rules the grammar needs to satisfy.
    :param grammar:
    """
    if grammar == "aNbN":
        return lambda x: check_same_number_as_bs(x) and check_as_before_bs(x)
    elif grammar == "aNbNcN":
        return lambda x: check_same_number_as_bs_cs(x) and check_as_before_bs_before_cs(
            x
        )
    elif grammar == "baN":
        return lambda x: check_even_number_of_as(x) and check_begins_with_b(x)
    elif grammar == "bbaN":
        return lambda x: check_even_number_of_as(x) and check_bs_before_as(x)
    elif grammar == "abN":
        return lambda x: check_same_number_as_bs(x)
    elif grammar == "aNbM":
        return lambda x: check_as_before_bs(x)
    elif grammar == "aNbNaN":
        return (
            lambda x: check_twice_many_as_than_bs(x)
            and check_bs_in_the_middle(x)
            and check_bs_together(x)
        )
    elif grammar == "brackets":
        return lambda x: check_matched_brackets(x)
    elif grammar == "parentheses":
        return lambda x: check_matched_parentheses(x)
    elif grammar == "parentheses_and_brackets":
        return lambda x: check_matched_parentheses_and_brackets(x)
    else:
        raise ValueError(f"Unknown grammar {grammar}")


def prompt_grammar_rules(grammar):
    """
    Selects the rules that check whether a prompt can be completed as such that it satisfies the rules of the grammar.
    It is used to split the test_prompts into in-distribution and out-of-distribution.

    NOTE: these rules are LESS strict than the grammar_rules, because even if the prompt does not satisfy the grammar rules,
    it might be completed as such that it does.
    :param grammar:

    """
    if grammar == "aNbN":
        return lambda x: check_as_before_bs(x) and check_more_as_than_bs(x)
    elif grammar == "aNbNcN":
        return lambda x: check_in_dist_anbncn(x)
    elif grammar == "abN":
        return lambda x: True
    elif grammar == "baN":
        return lambda x: check_begins_with_b(x)
    elif grammar == "bbaN":
        return lambda x: check_begins_with_b(x)
    elif grammar == "aNbM":
        return lambda x: check_as_before_bs(x)
    elif grammar == "aNbNaN":
        return lambda x: check_as_before_bs(x) and check_bs_together(x)
    elif grammar == "brackets":
        return lambda x: check_matched_brackets(x)
    elif grammar == "parentheses":
        return lambda x: check_matched_parentheses(x)
    elif grammar == "parentheses_and_brackets":
        return lambda x: check_matched_parentheses_and_brackets(x)
    else:
        raise ValueError(f"Unknown grammar {grammar}")


import random


def generate_matched_parentheses_and_brackets(n):
    """
    Generate a word of length n with paired () and [].
    """
    if n == 0:
        return np.concatenate((SOS_token, EOS_token))
    elif n % 2 == 1:
        raise ValueError("Length can only be even")
    else:
        word = []
        stack = []
        while len(word) < n:  # Each pair of parentheses or brackets adds 2 characters
            if len(stack) == 0:
                choice = random.choice(
                    [OPENING_PARENTHESIS_token, OPENING_BRACKET_token]
                )
            elif stack[-1] == OPENING_PARENTHESIS_token:
                choice = random.choice(
                    [
                        OPENING_PARENTHESIS_token,
                        OPENING_BRACKET_token,
                        CLOSING_PARENTHESIS_token,
                    ]
                )
                if len(word) + len(stack) >= n:
                    choice = CLOSING_PARENTHESIS_token

            elif stack[-1] == OPENING_BRACKET_token:
                choice = random.choice(
                    [
                        OPENING_PARENTHESIS_token,
                        OPENING_BRACKET_token,
                        CLOSING_BRACKET_token,
                    ]
                )
                if len(word) + len(stack) >= n:
                    choice = CLOSING_BRACKET_token

            if choice == OPENING_PARENTHESIS_token:
                word.append(OPENING_PARENTHESIS_token)
                stack.append(OPENING_PARENTHESIS_token)
            elif choice == OPENING_BRACKET_token:
                word.append(OPENING_BRACKET_token)
                stack.append(OPENING_BRACKET_token)
            elif choice == CLOSING_PARENTHESIS_token:
                word.append(CLOSING_PARENTHESIS_token)
                stack.pop()
            elif choice == CLOSING_BRACKET_token:
                word.append(CLOSING_BRACKET_token)
                stack.pop()

            if len(stack) == 0:
                break

        return np.concatenate((SOS_token, *word, EOS_token))


def check_matched_parentheses_and_brackets(sequence: torch.Tensor) -> bool:
    """
    Check if the parentheses and brackets are matched
    :param sequence:
    :return:
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    stack = []
    for token in sequence:
        if token == OPENING_PARENTHESIS_token.item():
            stack.append(token)
        elif token == CLOSING_PARENTHESIS_token.item():
            if len(stack) == 0 or stack[-1] != OPENING_PARENTHESIS_token.item():
                return False
            stack.pop()
        elif token == OPENING_BRACKET_token.item():
            stack.append(token)
        elif token == CLOSING_BRACKET_token.item():
            if len(stack) == 0 or stack[-1] != OPENING_BRACKET_token.item():
                return False
            stack.pop()

    return len(stack) == 0


def generate_matched_parentheses(n):
    """
    Generate a word of length n with paired ().
    """
    if n == 0:
        return np.concatenate((SOS_token, EOS_token))
    elif n % 2 == 1:
        raise ValueError("Length can only be even")
    else:
        word = []
        stack = []
        while len(word) < n:  # Each pair of parentheses or brackets adds 2 characters
            if len(stack) == 0:
                choice = OPENING_PARENTHESIS_token
            elif stack[-1] == OPENING_PARENTHESIS_token:
                choice = random.choice(
                    [OPENING_PARENTHESIS_token, CLOSING_PARENTHESIS_token]
                )
                if len(word) + len(stack) >= n:
                    choice = CLOSING_PARENTHESIS_token

            if choice == OPENING_PARENTHESIS_token:
                word.append(OPENING_PARENTHESIS_token)
                stack.append(OPENING_PARENTHESIS_token)

            elif choice == CLOSING_PARENTHESIS_token:
                word.append(CLOSING_PARENTHESIS_token)
                stack.pop()

            if len(stack) == 0:
                break

        return np.concatenate((SOS_token, *word, EOS_token))


def check_matched_parentheses(sequence: torch.Tensor) -> bool:
    """
    Check if the parentheses are matched
    :param sequence:
    :return:
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    stack = []
    for token in sequence:
        if token == OPENING_PARENTHESIS_token.item():
            stack.append(token)
        elif token == CLOSING_PARENTHESIS_token.item():
            if len(stack) == 0:
                return False
            stack.pop()

    return len(stack) == 0


def generate_matched_brackets(n):
    """
    Generate a word of length n with paired [].
    """
    if n == 0:
        return np.concatenate((SOS_token, EOS_token))
    elif n % 2 == 1:
        raise ValueError("Length can only be even")
    else:
        word = []
        stack = []
        while len(word) < n:  # Each pair of parentheses or brackets adds 2 characters
            if len(stack) == 0:
                choice = OPENING_BRACKET_token

            elif stack[-1] == OPENING_BRACKET_token:
                choice = random.choice([2, CLOSING_BRACKET_token])
                if len(word) + len(stack) >= n:
                    choice = CLOSING_BRACKET_token

            if choice == OPENING_BRACKET_token:
                word.append(OPENING_BRACKET_token)
                stack.append(OPENING_BRACKET_token)
            elif choice == CLOSING_BRACKET_token:
                word.append(CLOSING_BRACKET_token)
                stack.pop()

            if len(stack) == 0:
                break

        return np.concatenate((SOS_token, *word, EOS_token))


def check_matched_brackets(sequence: torch.Tensor) -> bool:
    """
    Check if the brackets are matched
    :param sequence:
    :return:
    """
    if type(sequence) == np.ndarray:
        sequence = torch.from_numpy(sequence)

    stack = []
    for token in sequence:
        if token == OPENING_BRACKET_token.item():
            stack.append(token)
        elif token == CLOSING_BRACKET_token.item():
            if len(stack) == 0:
                return False
            stack.pop()

    return len(stack) == 0


def generate_matched_parentheses_data(
    num_samples: int, max_length: int = 32, fixed_length=False
) -> list:
    """


    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length
    """

    if fixed_length is False:
        lengths = np.random.randint(low=1, high=max_length // 2 + 1, size=num_samples)
        data = [generate_matched_parentheses(2 * l) for l in lengths]
    else:
        data = []
        while len(data) < num_samples:
            sample = generate_matched_parentheses(max_length)
            if len(sample) == (max_length + 2):  # +SOS, EOS
                data.append(sample)

    return data


def generate_matched_brackets_data(
    num_samples: int, max_length: int = 32, fixed_length=False
) -> list:
    """


    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length
    """

    if fixed_length is False:
        lengths = np.random.randint(low=1, high=max_length // 2 + 1, size=num_samples)
        data = [generate_matched_brackets(2 * l) for l in lengths]
    else:
        data = []
        while len(data) < num_samples:
            sample = generate_matched_parentheses(max_length)
            if len(sample) == (max_length + 2):  # +SOS, EOS
                data.append(sample)

    return data


def generate_matched_parentheses_and_brackets_data(
    num_samples: int, max_length: int = 32, fixed_length=False
) -> list:
    """


    :param num_samples: number of samples
    :param max_length: maximum sequence length (inclusive SOS and EOS tokens)
    :return: list of length num_samples with maximal sequences of length max_length
    """

    if fixed_length is False:
        lengths = np.random.randint(low=1, high=max_length // 2 + 1, size=num_samples)
        data = [generate_matched_parentheses_and_brackets(2 * l) for l in lengths]
    else:
        data = []
        while len(data) < num_samples:
            sample = generate_matched_parentheses_and_brackets(max_length)
            if len(sample) == (max_length + 2):  # +SOS, EOS
                data.append(sample)

    return data


## Dyck language

In [11]:
def brackets(num_to_close=1):
  tot=0
  for l in tqdm(range(299, 300)):
    for n in range(num_to_close, l+1):
      for i in range(0, (n-num_to_close)//2+1):
        for k in range(0, n-2*i-num_to_close+1):
          tot+=math.comb(n, 2*i + num_to_close) * math.comb(n-2*i-num_to_close, k) * math.comb(2*i, i)*(1/(i+1))*(1/(5**(n+1)))
  return tot
print(f"Num to close: 1 {brackets(1)}")
print(f"Num to close: 0 {brackets(0)}")


100%|██████████| 1/1 [01:11<00:00, 71.07s/it]


Num to close: 1 0.12732200375000433


100%|██████████| 1/1 [00:53<00:00, 53.72s/it]

Num to close: 0 0.38196601125004565





##aNbNcN

In [8]:
def calc_distances(prompt):
  num_3 = sum((prompt==3)).item()
  num_4 = sum((prompt==4)).item()
  num_5 = sum((prompt==5)).item()
  max_num = max(num_3, num_4, num_5)
  distances = list(sorted([max_num-num_3, max_num-num_4, max_num-num_5]))
  distances = distances[1:]
  return tuple(distances)


def r2_formula(prompt):
  n = 307-len(prompt[1:])
  acc = 0
  if prompt[-1]==5:
    for i in range(0, n+1):
      acc += 1/(4**(i+1))
  elif prompt[-1]==4:
    for i in range(0, n+1):
      for b in range(0, i+1):
        acc += 1/(4**(i+1))
  elif prompt[-1]==3:
    for i in range(0, n+1):
      for a in range(0, i+1):
        for b in range(0, i-a+1):
          acc += 1/(4**(i+1))
  else:
    raise ValueError
  return acc

def r2_ood(len_n):
  tot = 0
  all = 0
  for lists in len_n:
    n, number = lists[0], lists[1]
    all += number
    for i in range(0, n+1):
      for a in range(0, i+1):
        for b in range(0, i-a+1):
          tot += 1/(4**(i+1)) * number
  return tot/all


def r1_formula(N1, N2, n):
  tot = 0
  for m in range(0, (n-(N1+N2))//3 + 1):
    tot += math.comb(N1 + N2 + 3 * m, N1 + m)* math.comb(N2 + 2 * m, m)*(1/(4**(N1 + N2 + 3 * m + 1)))
  return tot

def aNbNcN():
  my_dict_id={}
  my_dict_ood={}
  id_prompts = []
  ood_prompts = []
  id_r1_acc = 0
  ood_r1_acc = 0
  id_r2_accuracy = []
  len_n = []
  for i in tqdm(range(5, 6)):
    test_prompts = generate_test_prompts(i, "aNbNcN")
    for prompt in test_prompts:
      if check_in_dist_anbncn(prompt):
        id_prompts.append(prompt)
      else:
        ood_prompts.append(prompt)

    for prompt in id_prompts:
      print(prompt)
      r2 = r2_formula(prompt)
      id_r2_accuracy.append(r2)
      n = 307-len(prompt[1:])
      distances = calc_distances(prompt)
      #print(f"id {prompt}, dist {distances}")
      distances += (n,)


      if distances not in my_dict_id.keys():
        my_dict_id[distances]=1
      else:
        my_dict_id[distances]+=1

    for prompt in ood_prompts:
        n = 307-len(prompt[1:])
        if [n, len(ood_prompts)] not in len_n:
          len_n.append([n, len(ood_prompts)])
        distances = calc_distances(prompt)
        #print(f" ood {prompt}, dist {distances}")
        distances += (n,)
        if distances not in my_dict_ood.keys():
          my_dict_ood[distances]=1
        else:
          my_dict_ood[distances]+=1

  for distances in my_dict_id.keys():
        id_r1_acc += r1_formula(distances[0], distances[1], distances[2]) * my_dict_id[distances]
        #print(f"Adding id len={(distances[2]-distances[1]-distances[0])//3}, number in set {my_dict_id[distances]} of distances {distances}, and prob {r1_formula(distances[0], distances[1], distances[2])}")

  id_r1_acc /= sum([my_dict_id[distances] for distances in my_dict_id.keys()])
  print(sum([my_dict_id[distances] for distances in my_dict_id.keys()]))
  print(id_r1_acc)

  for distances in my_dict_ood.keys():
        ood_r1_acc += r1_formula(distances[0], distances[1], distances[2]) * my_dict_ood[distances]
        #print(f"Adding len={(distances[2]-distances[1]-distances[0])//3}, number in set {my_dict_ood[distances]} of distances {distances}, and prob {r1_formula(distances[0], distances[1], distances[2])}")
  ood_r1_acc /= sum([my_dict_ood[distances] for distances in my_dict_ood.keys()])
  print(ood_r1_acc)

  #print(id_r2_accuracy)
  id_r2_accuracy = sum(id_r2_accuracy)/len(id_r2_accuracy)
  print(id_r2_accuracy)

  ood_r2_acc = r2_ood(len_n)
  print(ood_r2_acc)

  return id_r1_acc, ood_r1_acc, id_r2_accuracy, ood_r2_acc

id_r1_acc, ood_r1_acc, id_r2_acc, ood_r2_acc = aNbNcN()
print(id_r1_acc, ood_r1_acc, id_r2_acc, ood_r2_acc)




  0%|          | 0/1 [00:00<?, ?it/s]

tensor([0, 3, 3, 3, 3, 3])


100%|██████████| 1/1 [00:05<00:00,  5.97s/it]

tensor([0, 3, 3, 3, 3, 4])
tensor([0, 3, 3, 3, 4, 4])
tensor([0, 3, 3, 4, 4, 5])
4
0.02165618851160017
0.03342730665586524
0.4537037037036963





0.5925925925925722
0.02165618851160017 0.03342730665586524 0.4537037037036963 0.5925925925925722


##aNbN

In [None]:
def calc_distances(prompt):
  num_3 = sum((prompt==3)).item()
  num_4 = sum((prompt==4)).item()
  N = max(num_3, num_4)-min(num_3, num_4)
  return N


def r2_formula(prompt):
  n = 307-len(prompt[1:])
  acc = 0
  if prompt[-1]==4:
    for i in range(0, n+1):
      acc += 1/(4**(i+1))
  elif prompt[-1]==3:
    for i in range(0, n+1):
      for a in range(0, i+1):
        acc += 1/(4**(i+1))
  else:
    raise ValueError
  return acc

def r2_ood(len_n):
  tot = 0
  all = 0
  for lists in len_n:
    n, number = lists[0], lists[1]
    all += number
    for i in range(0, n+1):
      for a in range(0, i+1):
          tot += 1/(4**(i+1)) * number
  return tot/all


def r1_formula(N, n):
  tot = 0
  for m in range(0, (n-(N))//2 + 1):
    tot += math.comb(N + 2 * m, m) *(1/(3**(N + 2 * m + 1)))
  return tot

def aNbN():
  my_dict_id={}
  my_dict_ood={}
  id_prompts = []
  ood_prompts = []
  id_r1_acc = 0
  ood_r1_acc = 0
  id_r2_accuracy = []
  len_n = []
  for i in tqdm(range(8, 9)):
    test_prompts = generate_test_prompts(i, "aNbN")
    for prompt in test_prompts:
      if prompt_grammar_rules("aNbN")(prompt):

        id_prompts.append(prompt)
      else:
        ood_prompts.append(prompt)

    for prompt in id_prompts:
      r2 = r2_formula(prompt)
      id_r2_accuracy.append(r2)
      n = 307-len(prompt[1:])
      N = calc_distances(prompt)
      distances = tuple((N, n))

      if distances not in my_dict_id.keys():
        my_dict_id[distances]=1
      else:
        my_dict_id[distances]+=1

    for prompt in ood_prompts:
        n = 307-len(prompt[1:])
        if [n, len(ood_prompts)] not in len_n:
          len_n.append([n, len(ood_prompts)])
        N = calc_distances(prompt)
        distances = tuple((N, n))
        if distances not in my_dict_ood.keys():
          my_dict_ood[distances]=1
        else:
          my_dict_ood[distances]+=1

  for distances in my_dict_id.keys():
        id_r1_acc += r1_formula(distances[0], distances[1]) * my_dict_id[distances]
  id_r1_acc /= sum([my_dict_id[distances] for distances in my_dict_id.keys()])
  print(id_r1_acc)

  for distances in my_dict_ood.keys():
        ood_r1_acc += r1_formula(distances[0], distances[1]) * my_dict_ood[distances]
        #print(f"Adding N={distances[0]}, len={(distances[1]-distances[0])//2}, acc={r1_formula(distances[0], distances[1])} this many times: {my_dict_ood[distances]}")
  ood_r1_acc /= sum([my_dict_ood[distances] for distances in my_dict_ood.keys()])
  print(ood_r1_acc)
  print(len(id_r2_accuracy))
  id_r2_accuracy = sum(id_r2_accuracy)/len(id_r2_accuracy)

  print(id_r2_accuracy)

  ood_r2_acc = r2_ood(len_n)
  print(ood_r2_acc)



  return id_r1_acc, ood_r1_acc, id_r2_accuracy, ood_r2_acc, my_dict_ood, my_dict_id

id_r1_acc, ood_r1_acc, id_r2_acc, ood_r2_acc, my_dict_ood, my_dict_id = aNbN()
print(id_r1_acc, ood_r1_acc, id_r2_acc, ood_r2_acc)

100%|██████████| 1/1 [00:00<00:00,  5.12it/s]

0.10471443673912728
0.1539634577351729
5
0.3555555555555554
0.44444444444444486
0.10471443673912728 0.1539634577351729 0.3555555555555554 0.44444444444444486





##bbaN

In [None]:
def generate_bbaN(length: int = 6):

        ID_data = torch.tensor(
            generate_bNaM_grammar_data(100 // 2, max_length=length),
            dtype=torch.long,
        )

        id_prompts = torch.cat(
            (torch.ones((ID_data.shape[0], 1), dtype=torch.long) * SOS_token, ID_data),
            dim=1,
        )
        return id_prompts
def generate_bbaN_ood(length: int=8):
        OOD_data = torch.tensor(
            generate_bNaM_grammar_data(
                num_samples=100 // 2, max_length=length - 1
            ),
            dtype=torch.long,
        )

        ood_prompts = torch.cat(
            (
                torch.ones((OOD_data.shape[0], 1), dtype=torch.long) * SOS_token,
                torch.ones((OOD_data.shape[0], 1), dtype=torch.long) * A_token,
                OOD_data,
            ),
            dim=1,
        )
        return ood_prompts



def bbaN():
  id_r1_acc = 0
  ood_r1_acc

  all_prompt = 0
  all_ood_prompt=0

  for i in tqdm(range(8, 9)):
    id_prompts = generate_bbaN(i)
    ood_prompts = generate_bbaN_ood(i)

    for prompt in id_prompts:
      all_prompt +=1
      n = 307-len(prompt[1:])
      if 3 in prompt:
        for a in range(0, n+1):
          id_r1_acc += 1/(4*(n+1))
      else:
        for i in range(0, n+1):
          for b in range(0, n+1):
            id_r1_acc += 1/(4**(n+1))

    for prompt in ood_prompts:
      all_ood_prompt +=1
      n = 307-len(prompt[1:])
      if 3 in prompt:
        for a in range(0, n+1):
          ood_r1_acc += 1/(4*(n+1))
      else:
        for i in range(0, n+1):
          for b in range(0, n+1):
            ood_r1_acc += 1/(4**(n+1))

  return id_r1_acc / all_prompt, ood_r1_acc/all_prompt

def bbaN_even_a():
  id_prompts = []
  ood_prompts = []
  id_r1_acc = 0
  id_r2_acc = 0
  ood_r1_acc =  0
  ood_r2_acc = 0

  all_prompt = 0
  all_ood_prompt = 0

  for i in tqdm(range(8, 9)):
    id_prompts = generate_bbaN(i)
    ood_prompts = generate_bbaN_ood(i)

    for prompt in id_prompts:
      all_prompt +=1
      num_a = sum((prompt==3))

      n = 307-len(prompt[1:])
      if num_a % 2 ==0:
        for l in range(0, n+1):
          for a in range(0, l//2+1):
            id_r2_acc += math.comb(l, 2*a)*(1/3**(l+1))
      else:
        for l in range(1, n+1):
          for a in range(1, (l+1)//2 + 1):
            id_r2_acc += math.comb(l, 2*a-1)*(1/3**(l+1))

      if 3 in prompt:
        for a in range(0, n+1):
          id_r1_acc += 1/(4*(n+1))
      else:
        for l in range(0, n+1):
          for b in range(0, l+1):
            id_r1_acc += 1/(4**(l+1))

    for prompt in ood_prompts:
      all_ood_prompt +=1
      num_a = sum((prompt==3))

      n = 307-len(prompt[1:])
      if num_a % 2 ==0:
        for l in range(0, n+1):
          for a in range(0, l//2+1):
            ood_r2_acc += math.comb(l, 2*a)*(1/3**(l+1))
      else:
        for l in range(1, n+1):
          for a in range(1, (l+1)//2 + 1):
            ood_r2_acc += math.comb(l, 2*a-1)*(1/3**(l+1))

      for l in range(1, n+1):
        for b in range(1, l+1):
          ood_r1_acc += 1/(3**(l+1))


  return id_r1_acc / all_prompt, id_r2_acc / all_prompt, ood_r1_acc / all_ood_prompt, ood_r2_acc / all_ood_prompt

r1, r2, ood_r1, ood_r2 = bbaN_even_a()
print(r1)
print(r2)
print(ood_r1)
print(ood_r2)

100%|██████████| 1/1 [00:32<00:00, 32.29s/it]

0.24999999999993502
0.4733333333333011
0.24999999999996164
0.4999999999999653





##baN

In [None]:
def baN():
  id = []
  ood = []
  all_ood_prompt = 0
  ood_r2_acc = 0
  all_id_prompt = 0
  id_r2_acc = 0

  test_prompts = generate_test_prompts(length=8, grammar="baN")
  for prompt in test_prompts:
    if prompt_grammar_rules("baN")(prompt):
      id.append(prompt)
    else:
      ood.append(prompt)
  for prompt in ood:
      all_ood_prompt +=1
      num_a = sum((prompt==3))

      n = 307-len(prompt[1:])
      if num_a % 2 ==0:
        for l in range(0, n+1):
          for a in range(0, l//2+1):
            ood_r2_acc += math.comb(l, 2*a)*(1/3**(l+1))
      else:
        for l in range(1, n+1):
          for a in range(1, (l+1)//2 + 1):
            ood_r2_acc += math.comb(l, 2*a-1)*(1/3**(l+1))
  for prompt in id:
      all_id_prompt +=1
      num_a = sum((prompt==3))

      n = 307-len(prompt[1:])
      if num_a % 2 ==0:
        for l in range(0, n+1):
          for a in range(0, l//2+1):
            id_r2_acc += math.comb(l, 2*a)*(1/3**(l+1))
      else:
        for l in range(1, n+1):
          for a in range(1, (l+1)//2 + 1):
            id_r2_acc += math.comb(l, 2*a-1)*(1/3**(l+1))
  return id_r2_acc / all_id_prompt, ood_r2_acc / all_ood_prompt

id, ood = baN()
print(id)
print(ood)


0.4999999999998808
0.4999999999998808
