In [6]:
%autoreload 2

In [7]:
from argparse import Namespace
import copy
import enum
import gzip
import json
import logging
import marshal
import os
import pickle
import sys
import timeit
import typing

logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('numba').setLevel(logging.WARNING)

import dill
import msgpack
import numpy as np
import tatsu
import simplejson

sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('../src'))
from src import fitness_energy_utils as utils
from src.fitness_energy_utils import NON_FEATURE_COLUMNS
from src.ast_counter_sampler import *
from src.fitness_ngram_models import *
from src.ast_utils import cached_load_and_parse_games_from_file, load_games_from_file, _extract_game_id
from src import ast_printer, ast_parser, ast_mcmc_regrowth, latest_model_paths

2023-06-15 13:46:45 - ast_utils - DEBUG    - Using cache folder: /misc/vlgscratch4/LakeGroup/guy/game_generation_cache
2023-06-15 13:46:45 - src.ast_utils - DEBUG    - Using cache folder: /misc/vlgscratch4/LakeGroup/guy/game_generation_cache


In [8]:
grammar = open('../dsl/dsl.ebnf').read()
grammar_parser = tatsu.compile(grammar)
game_asts = list(cached_load_and_parse_games_from_file('../dsl/interactive-beta.pddl', grammar_parser, False, relative_path='..', force_rebuild=True))
# real_game_texts = [ast_printer.ast_to_string(ast, '\n') for ast in game_asts]
# regrown_game_texts = list(load_games_from_file('../dsl/ast-real-regrowth-samples.pddl'))

# regrown_game_asts = list(cached_load_and_parse_games_from_file('../dsl/ast-real-regrowth-samples-1024.pddl.gz', grammar_parser, True, relative_path='..'))

print(len(game_asts))


2023-06-15 13:46:57 - src.ast_utils - DEBUG    - Updating cache with 98 new games
2023-06-15 13:46:57 - src.ast_utils - DEBUG    - About to finally update the cache


98


In [5]:
def original_deepcopy(ast):
    return copy.deepcopy(ast)


def pickle_deepcopy(ast):
    return pickle.loads(pickle.dumps(ast))


def dill_deepcopy(ast):
    return dill.loads(dill.dumps(ast))


def json_deepcopy(ast):
    return json.loads(json.dumps(ast))


def simplejson_deepcopy(ast):
    return simplejson.loads(simplejson.dumps(ast))


def marshal_deepcopy(ast):
    return marshal.loads(marshal.dumps(ast))


class ExtTypeID(enum.IntEnum):
    AST = 0
    PARSE_INFO = 1
    TUPLE = 2
    CLOSURE = 3



def msgpack_default(obj):
    if isinstance(obj, tatsu.ast.AST):
        return msgpack.ExtType(ExtTypeID.AST, msgpack.packb(dict(obj), default=msgpack_default, strict_types=True))
    elif isinstance(obj, tatsu.infos.ParseInfo):
        return msgpack.ExtType(ExtTypeID.PARSE_INFO, msgpack.packb(list(obj), default=msgpack_default, strict_types=True))
    elif isinstance(obj, tuple):
        return msgpack.ExtType(ExtTypeID.TUPLE, msgpack.packb(list(obj), default=msgpack_default, strict_types=True))
    elif isinstance(obj, tatsu.contexts.closure):
        return msgpack.ExtType(ExtTypeID.CLOSURE, msgpack.packb(list(obj), default=msgpack_default, strict_types=True))
    else:
        raise ValueError(f'Unknown type in msgpack_default: {type(obj)}')

    return obj


def msgpack_ext_hook(code, data):
    if code == ExtTypeID.AST:
        return tatsu.ast.AST(msgpack.unpackb(data, ext_hook=msgpack_ext_hook))
    elif code == ExtTypeID.PARSE_INFO:
        return tatsu.infos.ParseInfo(*msgpack.unpackb(data, ext_hook=msgpack_ext_hook))
    elif code == ExtTypeID.TUPLE:
        return tuple(msgpack.unpackb(data, ext_hook=msgpack_ext_hook))
    elif code == ExtTypeID.CLOSURE:
        return tatsu.contexts.closure(msgpack.unpackb(data, ext_hook=msgpack_ext_hook))

    return data


def msgpack_deepcopy(ast):
    return msgpack.unpackb(msgpack.packb(ast, default=msgpack_default, strict_types=True), ext_hook=msgpack_ext_hook)


def msgpack_ast_restore(obj, depth: int = 0):
    if depth == 0:
        return tuple([msgpack_ast_restore(item, depth + 1) for item in obj])

    if isinstance(obj, list):
        out = [msgpack_ast_restore(item, depth + 1) for item in obj]
        if depth == 1:
            return tuple(out)
        
        return out

    if isinstance(obj, dict):
        out = {}
        for key, val in obj.items():
            if key == 'parseinfo':
                val = tatsu.infos.ParseInfo(*val)

            out[key] = msgpack_ast_restore(val, depth + 1)

        return tatsu.ast.AST(out)
        
    return obj


def msgpack_restore_deepcopy(ast):
    return msgpack_ast_restore(msgpack.unpackb(msgpack.packb(ast)))


def json_restore_deepcopy(ast):
    return msgpack_ast_restore(json_deepcopy(ast))


def simplejson_restore_deepcopy(ast):
    return msgpack_ast_restore(simplejson_deepcopy(ast))


def ast_printer_deepcopy(ast):
    # TODO: think about multiprocessing issues
    ast_printer.BUFFER = []
    ast_printer.pretty_print_ast(ast)
    return grammar_parser.parse(''.join(ast_printer.BUFFER))


DEEPCOPY_METHODS = {
    # 'original': original_deepcopy,  # also quite slow
    'pickle': pickle_deepcopy,
    # 'dill': dill_deepcopy,  # incredibly slow
    'json': json_deepcopy,
    'json_restore': json_restore_deepcopy,
    # 'simplejson': simplejson_deepcopy,  # strictly slower than regular json
    # 'simplejson_restore': simplejson_restore_deepcopy,  # strictly slower than built-in json
    # 'marshal': marshal_deepcopy,  # unmarshalable
    # 'msgpack': msgpack_deepcopy,  # too slow with ext types
    'msgpack_restore': msgpack_restore_deepcopy,
    # 'ast_printer': ast_printer_deepcopy,  # really slow
}

## TODO
* Pure-python messagepack with ext_types
* simplejson with default or with deepcopy restore?
* ...?

In [12]:
def time_methods(methods=DEEPCOPY_METHODS, asts=game_asts, number=1000):
    for name, method in methods.items():
        try:
            time = timeit.timeit(lambda: [method(ast) for ast in asts], number=number)
            print(f'{name}: {time}')
        except Exception as e:
            print(f'{name}: {e}')

time_methods(number=500)

pickle: 31.912119068205357
json: 19.5588649045676
json_restore: 36.88173092715442
msgpack_restore: 26.70975672453642


In [None]:
for ast in game_asts:
    frontier = [ast]
    while frontier:
        item = frontier.pop(0)
        if isinstance(item, (list, tuple)):
            frontier.extend(item)

        if isinstance(item, (dict, tatsu.ast.AST)):
            for key, value in item.items():
                if isinstance(value, tatsu.buffering.Buffer):
                    print(key, value)
                else:
                    frontier.append(value)

In [None]:
def msgpack_default(obj):
    print(f'msgpack_default: {type(obj)}')
    if isinstance(obj, tatsu.ast.AST):
        return msgpack.ExtType(0, msgpack.packb(obj))

    if isinstance(obj, tatsu.infos.ParseInfo):
        return msgpack.ExtType(1, msgpack.packb(obj))

    return obj


def msgpack_ext_hook(code, data):
    print(f'msgpack_ext_hook: {code}')
    if code == 0:
        return tatsu.ast.AST(msgpack.unpackb(data))
    if code == 1:
        return tatsu.infos.ParseInfo(msgpack.unpackb(data))

    return data


def msgpack_deepcopy(ast):
    return msgpack.unpackb(msgpack.packb(ast, default=msgpack_default), ext_hook=msgpack_ext_hook)





msgpack.packb(game_asts[0][3][1], default=msgpack_default, strict_types=True)

In [None]:
import enum


class ExtTypeID(enum.IntEnum):
    AST = 0
    PARSE_INFO = 1
    TUPLE = 2
    CLOSURE = 3


def msgpack_default(obj):
    if isinstance(obj, tatsu.ast.AST):
        return msgpack.ExtType(ExtTypeID.AST, msgpack.packb(dict(obj), default=msgpack_default, strict_types=True))
    elif isinstance(obj, tatsu.infos.ParseInfo):
        return msgpack.ExtType(ExtTypeID.PARSE_INFO, msgpack.packb(list(obj), default=msgpack_default, strict_types=True))
    elif isinstance(obj, tuple):
        return msgpack.ExtType(ExtTypeID.TUPLE, msgpack.packb(list(obj), default=msgpack_default, strict_types=True))
    elif isinstance(obj, tatsu.contexts.closure):
        return msgpack.ExtType(ExtTypeID.CLOSURE, msgpack.packb(list(obj), default=msgpack_default, strict_types=True))
    else:
        raise ValueError(f'Unknown type in msgpack_default: {type(obj)}')

    return obj


def msgpack_ext_hook(code, data):
    if code == ExtTypeID.AST:
        return tatsu.ast.AST(msgpack.unpackb(data, ext_hook=msgpack_ext_hook))
    elif code == ExtTypeID.PARSE_INFO:
        return tatsu.infos.ParseInfo(*msgpack.unpackb(data, ext_hook=msgpack_ext_hook))
    elif code == ExtTypeID.TUPLE:
        return tuple(msgpack.unpackb(data, ext_hook=msgpack_ext_hook))
    elif code == ExtTypeID.CLOSURE:
        return tatsu.contexts.closure(msgpack.unpackb(data, ext_hook=msgpack_ext_hook))

    return data


def msgpack_deepcopy(ast):
    return msgpack.unpackb(msgpack.packb(ast, default=msgpack_default, strict_types=True), ext_hook=msgpack_ext_hook)


for ast in game_asts:
    copied_ast = msgpack_deepcopy(ast)
    sample_str = ast_printer.ast_to_string(ast, '\n')
    copied_sample_str = ast_printer.ast_to_string(copied_ast, '\n')
    print(sample_str == copied_sample_str)
 


In [None]:
def msgpack_ast_restore(obj, depth: int = 0):
    if depth == 0:
        return tuple([msgpack_ast_restore(item, depth + 1) for item in obj])

    if isinstance(obj, list):
        out = [msgpack_ast_restore(item, depth + 1) for item in obj]
        if depth == 1:
            return tuple(out)
        
        return out

    if isinstance(obj, dict):
        out = {}
        for key, val in obj.items():
            if key == 'parseinfo':
                val = tatsu.infos.ParseInfo(*val)

            out[key] = msgpack_ast_restore(val, depth + 1)

        return tatsu.ast.AST(out)
        
    return obj


def msgpack_restore_deepcopy(ast):
    return msgpack_ast_restore(msgpack.unpackb(msgpack.packb(ast)))

for ast in game_asts:
    copied_ast = msgpack_restore_deepcopy(ast)
    sample_str = ast_printer.ast_to_string(ast, '\n')
    copied_sample_str = ast_printer.ast_to_string(copied_ast, '\n')
    if sample_str != copied_sample_str:
        print(sample_str)
        print(copied_sample_str)
        break

In [8]:
class TestParser(ast_parser.ASTParser):
    def __init__(self):
        self.game_name = ''

    def _handle_ast(self, ast, **kwargs):
        if ast.parseinfo.rule == 'game_def':
            self.game_name = ast.game_name

        elif ast.parseinfo.rule == 'then':
            if any(not isinstance(sf, tatsu.ast.AST) for sf in ast.then_funcs):
                print(self.game_name, ast.then_funcs)
            
        super()._handle_ast(ast, **kwargs)


tp = TestParser()
for ast in game_asts:
    tp(ast)

In [9]:
DEFAULT_ARGS = argparse.Namespace(
    grammar_file=os.path.join('..', DEFAULT_GRAMMAR_FILE),
    parse_counter=False,
    counter_output_path=os.path.join('..', DEFAULT_COUNTER_OUTPUT_PATH),
    random_seed=DEFAULT_RANDOM_SEED,
)

grammar = open(DEFAULT_ARGS.grammar_file).read()
grammar_parser = typing.cast(tatsu.grammars.Grammar, tatsu.compile(grammar))  # type: ignore
counter = parse_or_load_counter(DEFAULT_ARGS, grammar_parser)

sampler = ASTSampler(grammar_parser, counter, seed=DEFAULT_RANDOM_SEED) 
regrowth_sampler = RegrowthSampler(sampler, seed=DEFAULT_RANDOM_SEED, rng=sampler.rng)

initial_sampler = ast_mcmc_regrowth.create_initial_proposal_sampler(
    ast_mcmc_regrowth.InitialProposalSamplerType.SECTION_SAMPLER,
    sampler, 
    latest_model_paths.LATEST_AST_N_GRAM_MODEL_PATH,
)

In [20]:
N_SAMPLES = 100

def generate_and_duplicate_sample(idx):
    sample = None
    while sample is None:
        try:
            sample = initial_sampler.sample(global_context=dict(original_game_id=f'evo-{idx}'))

        except RecursionError:
            continue
        except SamplingException:
            continue

    copied_sample = msgpack_restore_deepcopy(sample)
    if copied_sample is sample:
        raise ValueError('Sample is not copied')
    
    sample_str = ast_printer.ast_to_string(sample, '\n')
    copied_sample_str = ast_printer.ast_to_string(copied_sample, '\n')
    return sample_str, copied_sample_str


for idx in range(N_SAMPLES):
    sample_str, copied_sample_str = generate_and_duplicate_sample(idx)
    if sample_str != copied_sample_str:
        print(sample_str)
        print(copied_sample_str)
        break

In [2]:
import multiprocess as multiprocessing
from multiprocess import pool as mpp

with mpp.Pool(12) as pool:
    for (sample_str, copied_sample_str) in pool.map(generate_and_duplicate_sample, range(2000)):
        if sample_str != copied_sample_str:
            print(sample_str)
            print(copied_sample_str)
            break

NameError: name 'generate_and_duplicate_sample' is not defined

In [11]:
from itertools import chain
from collections import deque

def total_size(o, handlers={}, verbose=False):
    """ Returns the approximate memory footprint an object and all of its contents.

    Automatically finds the contents of the following builtin containers and
    their subclasses:  tuple, list, deque, dict, set and frozenset.
    To search other containers, add handlers to iterate over their contents:

        handlers = {SomeContainerClass: iter,
                    OtherContainerClass: OtherContainerClass.get_elements}

    """
    dict_handler = lambda d: chain.from_iterable(d.items())
    all_handlers = {tuple: iter,
                    list: iter,
                    deque: iter,
                    dict: dict_handler,
                    set: iter,
                    frozenset: iter,
                   }
    all_handlers.update(handlers)     # user handlers take precedence
    seen = set()                      # track which object id's have already been seen
    default_size = sys.getsizeof(0)       # estimate sizeof object without __sizeof__

    def sizeof(o):
        if id(o) in seen:       # do not double count the same object
            return 0
        seen.add(id(o))
        s = sys.getsizeof(o, default_size)

        if verbose:
            print(s, type(o), repr(o), file=sys.stderr)

        for typ, handler in all_handlers.items():
            if isinstance(o, typ):
                s += sum(map(sizeof, handler(o)))
                break
        return s

    return sizeof(o)

In [13]:
max_size = 0
for ast in game_asts:
    max_size = max(max_size, total_size(ast))

print(max_size)

278818


In [15]:
time = timeit.timeit(lambda: [total_size(ast) for ast in game_asts], number=1000)
print(time)

94.38803822919726


In [18]:
import pebble

In [19]:
pebble.current_process()

AttributeError: module 'pebble' has no attribute 'current_process'