In [1]:
import math
import numpy as np
import os
import pickle
import random
import sympy
import uuid

from collections import namedtuple
from datasets.base import KnownEquation
from datasets.sampling import DefaultSampling
from eq.conversion import sympy2sequence, sequence2model
from eq.eval import convert_pred_sequence_to_eqs
from eq.vocabulary import SOS_TOKEN
from sympy.core import numbers
from sympy.core import power
from torchdistill.common.file_util import get_file_path_list, make_parent_dirs
from torchdistill.common.yaml_util import load_yaml_file
from tqdm import tqdm

## Variables

In [2]:
# Input locations: the dataset-generation YAML config and the directory holding
# the pickled reference equations the language model is built from.
DATASET_CONFIG_FILE_PATH = os.path.expanduser('./configs/datasets/random/feynman_lm.yaml')
REF_EQ_DIR_PATH = os.path.expanduser('~/dataset/symbolic_regression/proposed/random_split/full_set/true_eq/')

# Language-model settings: n-gram order, the padding token used to fill the
# initial context, and the separator used when a context is joined into a key.
NGRAM = 2
NULL_TOKEN = 'Null'
TOKEN_DELIMITER = ' '

# Generation budget: tokens per attempt, number of attempts, and how many
# constant/sampling variations are tried per generated equation.
MAX_SEQ_LENGTH = 30
MAX_NUM_EQS = 5000
NUM_VARIATIONS = 10

# One entry of a sampling table: `value` is a cumulative probability and
# `op_str` is the token selected when a uniform draw falls below it.
Node = namedtuple('Node', ('value', 'op_str'))

# Default train/val/test split, given as 8:1:1 and stored as fractions of the total.
total = 8 + 1 + 1
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 8 / total, 1 / total, 1 / total

In [3]:
def load_eq_tree_seqs(eq_dir_path):
    """Load every pickled equation under ``eq_dir_path`` as a tree token sequence.

    Files are visited in the order produced by
    ``get_file_path_list(eq_dir_path, is_sorted=True)``. Each unpickled object
    is passed through ``.evalf()`` and then converted with
    ``sympy2sequence(..., returns_binary_tree=True)``.

    NOTE: ``pickle.load`` can execute arbitrary code, so only point this at
    trusted files.

    Args:
        eq_dir_path: Directory containing the pickled equations.

    Returns:
        A list with one converted sequence per file.
    """
    def _read_sequence(pickle_path):
        # Keep the file open only for as long as unpickling needs it.
        with open(pickle_path, 'rb') as fp:
            equation = pickle.load(fp)
        return sympy2sequence(equation.evalf(), returns_binary_tree=True)

    return [_read_sequence(path) for path in get_file_path_list(eq_dir_path, is_sorted=True)]

In [4]:
# Token sequences of the reference equations; they are the corpus the n-gram model is built from.
ref_eq_tree_seqs = load_eq_tree_seqs(REF_EQ_DIR_PATH)

## Build n-gram language model using naive Bayes

In [5]:
def update_freq_dict(eq_tree_seq, ngram, freq_dict):
    """Add the n-gram counts of one token sequence to ``freq_dict`` in place.

    The context starts as ``ngram - 1`` padding tokens followed by the
    start-of-sequence token. For each token, the oldest context entry is
    dropped, the remainder is joined with ``TOKEN_DELIMITER`` to form the
    context key, and the count of the token under that key is incremented.

    Args:
        eq_tree_seq: Iterable of string tokens.
        ngram: Order of the model.
        freq_dict: Mapping ``context key -> {token: count}``; mutated in place.
    """
    window = [NULL_TOKEN] * (ngram - 1) + [SOS_TOKEN]
    for token in eq_tree_seq:
        context = window[1:]
        context_key = TOKEN_DELIMITER.join(context)
        # Slide the window forward so the next key ends with the current token.
        window = context + [token]
        token_counts = freq_dict.setdefault(context_key, dict())
        token_counts[token] = token_counts.get(token, 0) + 1


def build_ngram_lm(eq_tree_seqs, ngram):
    """Build a sampling table for an n-gram language model from token sequences.

    Counts are collected with ``update_freq_dict``. For every context key the
    continuations are ordered from most to least frequent (ties keep their
    first-seen order) and stored as ``Node(cumulative_probability, token)``
    entries, so a uniform draw in ``[0, 1)`` selects the first node whose value
    exceeds it.

    The cumulative probability is computed as ``cumulative_count / total``
    rather than by summing per-token float probabilities. Summing floats can
    leave the last node slightly below 1.0 (ten times ``0.1`` gives
    ``0.9999999999999999``), which opens a gap a draw can fall into. With
    integer accumulation the last node is always exactly 1.0.

    Args:
        eq_tree_seqs: Iterable of token sequences.
        ngram: Order of the model, forwarded to ``update_freq_dict``.

    Returns:
        Dict mapping each context key to its list of ``Node`` entries.
    """
    freq_dict = dict()
    for eq_tree_seq in eq_tree_seqs:
        update_freq_dict(eq_tree_seq, ngram, freq_dict)

    ngram_lm = dict()
    for given_str, sub_dict in freq_dict.items():
        denominator = sum(sub_dict.values())
        # sorted() is stable, also with reverse=True, so equally frequent
        # tokens stay in insertion order.
        ranked = sorted(sub_dict.items(), key=lambda item: item[1], reverse=True)
        cumulative_count = 0
        node_list = list()
        for op_str, count in ranked:
            cumulative_count += count
            node_list.append(Node(cumulative_count / denominator, op_str))
        ngram_lm[given_str] = node_list
    return ngram_lm

In [6]:
# Build the sampling table from the reference sequences; each context key is made of NGRAM - 1 tokens.
ngram_lm = build_ngram_lm(ref_eq_tree_seqs, NGRAM)

## Randomly generate equations

In [7]:
def generate_random_equation(ngram_lm, ngram, max_seq_length, token_delimiter):
    """Sample one equation token-by-token from the n-gram sampling table.

    Starting from the padded start-of-sequence context, a token is drawn for
    the current context and appended to the sequence. After every token the
    sequence is handed to ``sequence2model(..., returns_parent_stack=True)``.
    As soon as that call succeeds with an empty parent stack, the model's
    ``sympy_str()`` is parsed with ``sympy.sympify`` and returned.

    Args:
        ngram_lm: Dict mapping a context key to a list of nodes with ``value``
            (cumulative probability) and ``op_str`` (token) attributes.
        ngram: Order of the model; the context holds ``ngram - 1`` tokens.
        max_seq_length: Maximum number of tokens to draw.
        token_delimiter: Separator used to join the context into a key.

    Returns:
        ``(sympy expression, token list)`` on success, or ``None`` when the
        context is unknown to the model or no complete equation was formed
        within ``max_seq_length`` tokens.
    """
    ngram_list = [NULL_TOKEN] * (ngram - 1) + [SOS_TOKEN]
    op_list = list()
    for _ in range(max_seq_length):
        ngram_list.pop(0)
        given_str = token_delimiter.join(ngram_list)
        if given_str not in ngram_lm:
            return None

        random_nodes = ngram_lm[given_str]
        if not random_nodes:
            return None

        random_value = random.random()
        # Cumulative values built from floats may stop marginally short of 1.0.
        # Falling back to the last node keeps the context window intact instead
        # of silently skipping a token.
        chosen_node = next(
            (node for node in random_nodes if random_value < node.value),
            random_nodes[-1]
        )
        op_list.append(chosen_node.op_str)
        ngram_list.append(chosen_node.op_str)

        try:
            random_sr_model, parent_stack = sequence2model(op_list, returns_parent_stack=True)
            if len(parent_stack) == 0:
                sympy_eq_str = random_sr_model.sympy_str()
                random_eq = sympy.sympify(sympy_eq_str)
                return random_eq, op_list
        except Exception:
            # Best effort: a prefix that cannot be converted or parsed yet is
            # not fatal, so keep drawing. `Exception` (not a bare except) lets
            # KeyboardInterrupt stop a long run.
            continue
    return None


def reindex_variables(op_list):
    """Rename variable tokens so their indices are contiguous from zero.

    Tokens of the form ``x<digits>`` are collected, their distinct indices are
    sorted, and the k-th smallest index is renamed to ``x<k>``. All other
    tokens are returned unchanged.

    Indices are de-duplicated before ranking. Otherwise a variable that occurs
    more than once would take up several ranks, so ``['mul', 'x3', 'x3']``
    would become ``['mul', 'x1', 'x1']`` instead of ``['mul', 'x0', 'x0']``.

    Args:
        op_list: List of string tokens.

    Returns:
        A new list with the variable tokens renamed.
    """
    var_indices = sorted({int(op[1:]) for op in op_list if op.startswith('x') and op[1:].isdigit()})
    var_dict = {f'x{old_index}': f'x{new_index}' for new_index, old_index in enumerate(var_indices)}
    return [var_dict.get(op, op) for op in op_list]

In [8]:
# Make MAX_NUM_EQS sampling attempts. Every successful attempt is kept in
# `random_eq_list` (duplicates included), while `random_eq_set` records the
# distinct token sequences after variable re-indexing, for reporting only.
random_eq_list = list()
random_eq_set = set()
for _ in tqdm(range(MAX_NUM_EQS)):
    output = generate_random_equation(ngram_lm, NGRAM, MAX_SEQ_LENGTH, TOKEN_DELIMITER)
    if output is None:
        # No complete equation was produced in this attempt.
        continue

    random_eq, op_list = output
    random_eq_list.append(random_eq)
    op_list = reindex_variables(op_list)
    eq_key = '\t'.join(op_list)
    random_eq_set.add(eq_key)

print(f'{len(random_eq_list)} random equations generated, and {len(random_eq_set)} of them are unique w.r.t. their equation tree')

100%|██████████████████████████████████████| 5000/5000 [00:32<00:00, 156.21it/s]

4501 random equations generated, and 1841 of them are unique w.r.t. their equation tree





## Create datasets using the randomly generated equations

In [9]:
def generate_random_sampling_objs(num_vars):
    """Create one randomly configured ``DefaultSampling`` object per variable.

    For each variable an integer exponent ``e`` is drawn from ``[-32, 32]`` and
    the object is built with bounds ``10**(e - 1)`` and ``10**(e + 1)``;
    ``uses_negative`` is set with probability 0.5.

    Args:
        num_vars: Number of sampling objects to create.

    Returns:
        List of ``num_vars`` ``DefaultSampling`` instances.
    """
    def _draw_sampling_obj():
        # Draw order (exponent first, sign flag second) is kept so that a
        # seeded run consumes the random stream in the same sequence.
        exponent = random.randint(-32, 32)
        allows_negative = random.random() < 0.5
        lower_bound = np.power(10.0, exponent - 1)
        upper_bound = np.power(10.0, exponent + 1)
        return DefaultSampling(lower_bound, upper_bound, uses_negative=allows_negative)

    return [_draw_sampling_obj() for _ in range(num_vars)]


def random_init_constants(random_eq, sub_eq=None, parent_op=None):
    """Replace the Float constants of an expression with random values.

    The expression tree is walked depth-first through ``.args``. Whenever the
    visited node is a ``numbers.Float`` it is substituted in ``random_eq``:

    * if the node's parent is a ``power.Pow``, by an integer magnitude in
      ``[2, 5]`` with a random sign, stored as a float;
    * otherwise, by ``u * 10**e`` with ``u`` drawn from ``[0, 1)`` and ``e``
      drawn uniformly from ``[-32, 32]``.

    NOTE(review): the substitution is applied to the whole expression via
    ``random_eq.subs(node, value)``, so every occurrence equal to this Float is
    presumably replaced at once, and the walk continues over the original
    sub-tree. Confirm this is intended.

    Args:
        random_eq: Expression being rewritten.
        sub_eq: Node currently visited; defaults to ``random_eq`` itself.
        parent_op: Parent of ``sub_eq`` (``None`` at the root).

    Returns:
        The expression after all substitutions.
    """
    node = random_eq if sub_eq is None else sub_eq

    if isinstance(node, numbers.Float):
        # The wide-range value is always drawn, even when it is overridden
        # below, so the consumption of the random stream stays the same.
        const_value = random.random() * math.pow(10, random.uniform(-32, 32))
        if isinstance(parent_op, power.Pow):
            magnitude = random.randint(2, 5)
            sign = -1 if random.random() < 0.5 else 1
            const_value = float(magnitude * sign)
        random_eq = random_eq.subs(node, const_value)

    for child in node.args:
        random_eq = random_init_constants(random_eq, child, node)
    return random_eq


def split_dataset(dataset, train_ratio, val_ratio, test_ratio):
    """Split ``dataset`` sequentially into train, validation and test parts.

    The ratios are normalised by their sum. Train and validation sizes are the
    truncated products of their fraction and ``len(dataset)``; the test part
    receives all remaining samples. ``test_ratio`` therefore only contributes
    to the normalisation.

    Args:
        dataset: Sliceable, sized collection of samples.
        train_ratio: Relative weight of the training split.
        val_ratio: Relative weight of the validation split.
        test_ratio: Relative weight of the test split.

    Returns:
        ``(train, val, test)``, where a part that would be empty is ``None``.
    """
    ratio_sum = train_ratio + val_ratio + test_ratio
    num_samples = len(dataset)

    train_end = int((train_ratio / ratio_sum) * num_samples)
    val_end = train_end + int((val_ratio / ratio_sum) * num_samples)
    num_test_samples = num_samples - val_end

    train_dataset = None
    if train_end > 0:
        train_dataset = dataset[:train_end]

    val_dataset = None
    if val_end > train_end:
        val_dataset = dataset[train_end:val_end]

    test_dataset = None
    if num_test_samples > 0:
        test_dataset = dataset[-num_test_samples:]

    return train_dataset, val_dataset, test_dataset


def generate_dataset(eq_instance, eq_name, dataset_config, default_train_ratio, default_val_ratio, default_test_ratio):
    """Sample a tabular dataset for one equation and write its splits to disk.

    The dataset is created with ``eq_instance.create_dataset`` using
    ``dataset_config['sample_size']`` and split with ``split_dataset``. Every
    non-empty split is written with ``np.savetxt`` to
    ``<output_dir>/<split>/<eq_name><output_ext>``. ``eq_instance.sympy_eq`` is
    pickled to ``<output_dir>/true_eq/<eq_name>.pkl``.

    Args:
        eq_instance: Object exposing ``create_dataset(sample_size)`` and ``sympy_eq``.
        eq_name: Base file name (without extension) for all outputs.
        dataset_config: Dict with ``sample_size``, ``output_dir`` and
            ``output_ext``; optional ``train_ratio``, ``val_ratio``,
            ``test_ratio`` and ``output_delim``.
        default_train_ratio: Train ratio used when the config has none.
        default_val_ratio: Validation ratio used when the config has none.
        default_test_ratio: Test ratio used when the config has none.

    Returns:
        ``True`` when the files were written, ``False`` when the dataset could
        not be created.
    """
    # Generate tabular dataset
    try:
        dataset = eq_instance.create_dataset(dataset_config['sample_size'])
    except Exception:
        # Best effort: an equation that cannot be sampled is skipped. Catching
        # `Exception` rather than using a bare except lets KeyboardInterrupt and
        # SystemExit propagate, so a long run can still be stopped.
        return False

    train_ratio = dataset_config.get('train_ratio', default_train_ratio)
    val_ratio = dataset_config.get('val_ratio', default_val_ratio)
    test_ratio = dataset_config.get('test_ratio', default_test_ratio)
    train_dataset, val_dataset, test_dataset = split_dataset(dataset, train_ratio, val_ratio, test_ratio)

    # Write out each split
    output_dir_path = os.path.expanduser(dataset_config['output_dir'])
    output_ext = dataset_config['output_ext']
    # Tab-separated for '.tsv', space-separated otherwise, unless overridden.
    delimiter = dataset_config.get('output_delim', '\t' if output_ext == '.tsv' else ' ')
    for sub_dataset, split_name in zip((train_dataset, val_dataset, test_dataset), ('train', 'val', 'test')):
        if sub_dataset is None:
            continue

        output_file_path = os.path.join(output_dir_path, split_name, eq_name + output_ext)
        make_parent_dirs(output_file_path)
        # Save tabular dataset
        np.savetxt(output_file_path, sub_dataset, delimiter=delimiter)

    # Save ground-truth sympy expression
    pickle_file_path = os.path.join(output_dir_path, 'true_eq', eq_name + '.pkl')
    make_parent_dirs(pickle_file_path)
    with open(pickle_file_path, 'wb') as fp:
        pickle.dump(eq_instance.sympy_eq, fp)
    return True


def generate_datasets_from_eq(random_eq, num_trials, base_random_eq_name, dataset_config, train_ratio, val_ratio, test_ratio):
    """Create up to ``num_trials`` dataset variations of one equation template.

    Each trial draws new constants with ``random_init_constants`` and new
    sampling objects with ``generate_random_sampling_objs``, wraps them in a
    ``KnownEquation`` and writes the result under the name
    ``<base_random_eq_name>-<trial index>`` via ``generate_dataset``.

    Every trial starts from the unmodified ``random_eq`` template. Rebinding
    the parameter to the randomised result would make trial ``i`` depend on
    the outcome of trial ``i - 1``, while ``num_vars`` is computed only once
    from the original expression.

    Args:
        random_eq: Expression template exposing ``free_symbols``.
        num_trials: Number of variations to attempt.
        base_random_eq_name: Prefix of the per-trial dataset name.
        dataset_config: Forwarded to ``generate_dataset``.
        train_ratio: Default train ratio forwarded to ``generate_dataset``.
        val_ratio: Default validation ratio forwarded to ``generate_dataset``.
        test_ratio: Default test ratio forwarded to ``generate_dataset``.

    Returns:
        Number of trials for which ``generate_dataset`` reported success.
    """
    success_count = 0
    num_vars = len(random_eq.free_symbols)
    for i in range(num_trials):
        random_eq_name = f'{base_random_eq_name}-{i}'
        # Randomise from the template, never from a previous trial's output.
        trial_eq = random_init_constants(random_eq)
        sampling_objs = generate_random_sampling_objs(num_vars)
        random_eq_instance = KnownEquation.from_sympy_eq(trial_eq, sampling_objs, reindexes=True)
        success = generate_dataset(random_eq_instance, random_eq_name, dataset_config,
                                   train_ratio, val_ratio, test_ratio)
        if success:
            success_count += 1
    return success_count

In [10]:
# Dataset-generation settings; generate_dataset reads 'sample_size', 'output_dir', 'output_ext' and optional ratio/delimiter keys from it.
dataset_config = load_yaml_file(DATASET_CONFIG_FILE_PATH)

In [11]:
# Write NUM_VARIATIONS dataset variations for every generated equation and
# count how many were created. Datasets of equation i are named 'random-<i>-<trial>'.
success_count = 0
for i, random_eq in enumerate(tqdm(random_eq_list)):
    base_random_eq_name = f'random-{i}'
    success_count += generate_datasets_from_eq(
        random_eq, NUM_VARIATIONS, base_random_eq_name, dataset_config,
        TRAIN_RATIO, VAL_RATIO, TEST_RATIO
    )

print(f'{success_count} datasets were created.')

100%|█████████████████████████████████████| 4501/4501 [1:51:55<00:00,  1.49s/it]

24232 datasets were created.



