# [ReCogLab] Generate Dataset
This is a Colab interface for accessing the ReCogLab library and generating examples.

We've provided an interface with easy-to-configure options. For a description of what each option corresponds to, see the ReCogLab repository: https://github.com/google-deepmind/recoglab


In [None]:
%%capture

# Clone recoglab from github.
!git clone https://github.com/google-deepmind/recoglab.git
!pip install -r recoglab/requirements.txt

If you encounter

```ValueError: All ufuncs must have type `numpy.ufunc`.```

when importing libraries, restart the Colab kernel so that the Python environment reloads the newly installed libraries.

In [None]:
%load_ext autoreload

%autoreload 2
import collections
import datetime
import os
import ml_collections
from recoglab import common_types
from recoglab import eval_io_lib
from recoglab import recoglab_dataset
from recoglab import utils
import tensorflow as tf

## ReCogLab Instructions

Different configurations create different relational reasoning problems. We defined some common defaults that we found useful and effective for probing language-model reasoning.

## Configuration

* **Global Configurations**: This section controls framework and I/O parameters such as splits, number of examples, seeding, and save location. Use `domain` to choose which domain to build a configuration for.

* **Common Options**: This section contains configuration options shared by all four domains.

* **Social Network**: This section controls Social Network domain tasks.
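* **Comparison**: This section controls Comparison domain tasks.
* **Syllogism**: This section controls Syllogism domain tasks.
* **Family JSON**: This section controls Family JSON domain tasks.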

In [None]:
# @title Global Configurations
def get_global_config():
  """Returns the global (framework and I/O) options set in the Colab form.

  Each form field is a plain local assignment on its own line, because Colab
  rewrites those lines when the form is edited.

  Returns:
    A dict (built from `locals()`) with the keys `domain`, `dataset_name`,
    `recoglab_dataset_dir`, `split`, `num_examples`, `csv_seeds` and
    `csv_seed_list`. `csv_seed_list` holds the integer seeds parsed from the
    comma-separated `csv_seeds` string.
  """
  domain = 'social_network'  # @param ['social_network', 'comparison', 'syllogism', 'family_json']
  dataset_name = 'recoglab_dataset_name'  # @param {type: 'string'}
  recoglab_dataset_dir = 'generated_dataset'  # @param {type: 'string'}
  split = 'test'  # @param ['train', 'test', 'val']
  num_examples = 500  # @param {type: 'number'}
  csv_seeds = '42'  # @param {type: 'string'}
  # Blank entries (e.g. from a trailing comma such as '42,43,') are skipped
  # rather than letting int('') raise ValueError. int() already tolerates
  # surrounding whitespace such as '42, 43'.
  csv_seed_list = [int(s) for s in csv_seeds.split(',') if s.strip()]
  return locals()

In [None]:
# @title Common Options
def common_config_form():
  """Returns the options shared by all domains, as set in the Colab form.

  Each option is a plain local assignment on its own line so the Colab form
  widgets can rewrite it.

  Returns:
    A dict (built from `locals()`) with the keys `num_entities`, `graph_type`,
    `ordering`, `use_heuristic_rebalance`, `heuristic_rebalance_field`,
    `add_filler`, `num_filler_lines`, `filler_type` and `filler_position`.
  """
  num_entities = 20  # @param{type: 'number'}
  graph_type = 'linear'  # @param ['linear', 'random_tree', 'erdos_renyi']
  ordering = 'random'  # @param ['random', 'inorder', 'reverse']

  # The domain builders turn these two into `heuristic_field`
  # ('' when rebalancing is disabled).
  use_heuristic_rebalance = True  # @param{type:'boolean'}
  heuristic_rebalance_field = 'answer/symbolic_distance'  # @param ['answer','answer/symbolic_distance']

  # filler options
  # The domain builders only copy the three filler fields below into a config
  # when add_filler is True.
  add_filler = False  # @param{type:"boolean"}
  num_filler_lines = 10  # @param{type:"integer"}
  filler_type = 'random_text'  # @param['random_text', 'entity_filler']
  filler_position = 'interspersed'  # @param['before', 'after', 'interspersed']
  return locals()

In [None]:
# @title Social Network Options
def social_network_form():
  """Returns the Social Network domain options set in the Colab form.

  Returns:
    A dict (built from `locals()`) with the keys `task_type`, `entity_type` and
    `relation_type`.
  """
  task_type = 'FastestMessage_NumHops'  # @param ['FastestMessage', 'FastestMessage_NumHops', 'FastestMessage_ExactPath']
  entity_type = 'baby-names'  # @param ['baby-names']
  relation_type = 'friend_advanced'  # @param ['friend', 'friend_advanced', 'relative']
  return locals()


def initialize_social_network_config():
  """Builds the ConfigDict for the Social Network domain from the Colab forms.

  Combines the shared options from `common_config_form()` with the
  domain-specific options from `social_network_form()`.

  Returns:
    An `ml_collections.ConfigDict` for 'SocialNetworkModule'. The filler fields
    are only present when `add_filler` is enabled, and `heuristic_field` is ''
    when heuristic rebalancing is disabled.
  """
  shared = common_config_form()
  social = social_network_form()

  cfg = ml_collections.ConfigDict()
  cfg.name = 'SocialNetworkModule'

  # Shared graph options. NOTE(review): the form's `ordering` value is stored
  # under `randomize_relations` here, while the comparison/syllogism builders
  # store it as `ordering`; presumably SocialNetworkModule expects this name.
  cfg.network_type = shared['graph_type']
  cfg.randomize_relations = shared['ordering']
  cfg.num_entities_max = shared['num_entities']

  # Domain-specific options.
  cfg.entity_type = social['entity_type']
  cfg.relation_type = social['relation_type']
  cfg.query_task = social['task_type']
  # When True, the pair (A, B) may be rendered as either "A is friends with B"
  # or "B is friends with A"; it is kept off here.
  cfg.randomize_direction = False

  cfg.add_filler = shared['add_filler']
  if shared['add_filler']:
    # Overwrite the module's filler defaults with the form values.
    for filler_key in ('num_filler_lines', 'filler_type', 'filler_position'):
      cfg[filler_key] = shared[filler_key]

  cfg.entities_mode = common_types.EntityMode.PRESET
  cfg.preamble = ''
  cfg.query_preamble = (
      'Any two friends are able to pass along a message, which allows messages '
      'to move from one friend to another. Thus, messages can be passed between'
      ' two people through friends they have in common.\n'
  )
  if shared['use_heuristic_rebalance']:
    cfg.heuristic_field = shared['heuristic_rebalance_field']
  else:
    cfg.heuristic_field = ''
  return cfg

In [None]:
# @title Comparison Options
def comparison_form():
  """Returns the Comparison domain options set in the Colab form.

  Returns:
    A dict (built from `locals()`) with the keys `task_type`, `entity_type`,
    `relation_type` and `congruency_mode`.
  """
  task_type = 'ConsistencyDetection'  # @param ['Comparison', 'ConsistencyDetection', 'FeasibilityDetection']
  entity_type = 'baby-names'  # @param ['baby-names', 'basic_objects', 'congruent_objects', 'random_name']
  relation_type = 'size'  # @param['size', 'age', 'weight']
  congruency_mode = 'all_congruent'  # @param['all_congruent', 'all_incongruent', 'random']
  return locals()


def initialize_comparison_config():
  """Builds the ConfigDict for the Comparison domain from the Colab forms.

  `task_type` selects the module:
    * 'Comparison' uses 'ComparisonModule'.
    * 'ConsistencyDetection' uses 'ComparisonValidModule'.
    * 'FeasibilityDetection' uses 'ComparisonModule' and forces a
      'random_tree' network.
  The last two always rebalance on the 'answer' field, overriding the
  common-form rebalance setting.

  Returns:
    An `ml_collections.ConfigDict` for the selected comparison module.

  Raises:
    ValueError: if the form's `task_type` is not one of the three supported
      comparison tasks, instead of returning a config with no `name`.
  """
  common_options = common_config_form()
  options = comparison_form()
  config = ml_collections.ConfigDict()
  # common_options
  config.network_type = common_options['graph_type']
  config.ordering = common_options['ordering']
  config.num_entities_max = common_options['num_entities']

  # task-specific options
  config.entity_type = options['entity_type']
  config.relation_type = options['relation_type']
  config.congruency_mode = options['congruency_mode']

  # misc
  config.entities_mode = common_types.EntityMode.PRESET
  config.randomize_relations = False
  config.preamble = ''
  config.query_preamble = ''
  config.add_filler = common_options['add_filler']

  # Add filler options
  if common_options['add_filler']:
    # Overwrite the module's filler defaults with the form values.
    config.num_filler_lines = common_options['num_filler_lines']
    config.filler_type = common_options['filler_type']
    config.filler_position = common_options['filler_position']
  config.heuristic_field = (
      common_options['heuristic_rebalance_field']
      if common_options['use_heuristic_rebalance']
      else ''
  )

  # Specific customization based on the task
  task_type = options['task_type']
  if task_type == 'Comparison':
    config.name = 'ComparisonModule'
  elif task_type == 'ConsistencyDetection':
    config.name = 'ComparisonValidModule'
    config.heuristic_field = 'answer'
  elif task_type == 'FeasibilityDetection':
    # NOTE(review): same module name as 'Comparison'; only network_type and
    # heuristic_field differ. Confirm ComparisonModule can tell feasibility
    # questions apart from this config alone.
    config.name = 'ComparisonModule'
    # feasibility only makes sense in directed trees
    config.network_type = 'random_tree'
    config.heuristic_field = 'answer'
  else:
    raise ValueError('unrecognized task_type: %s' % task_type)

  return config

In [None]:
# @title Syllogism Options
def syllogism_form():
  """Returns the Syllogism domain options set in the Colab form.

  Returns:
    A dict (built from `locals()`) with the keys `entity_type` and
    `entities_mode`.
  """
  entity_type = 'plural_nouns'  # @param ['plural_nouns']
  entities_mode = 'preset'  # @param ['congruent', 'incongruent', 'preset']
  return locals()


def initialize_syllogism_config():
  """Builds the ConfigDict for the Syllogism domain from the Colab forms.

  The form's `entities_mode` selects the module: 'congruent' maps to
  'CongruentSyllogismModule', 'incongruent' to 'IncongruentSyllogismModule',
  and anything else to 'SyllogismModule'.

  Returns:
    An `ml_collections.ConfigDict` for the selected syllogism module.

  Raises:
    RuntimeError: if filler text is requested (not implemented here).
  """
  shared = common_config_form()
  syllogism = syllogism_form()
  # Filler text is not implemented for syllogisms, so fail before building.
  if shared['add_filler']:
    raise RuntimeError('Filler not yet implemented for this module.')

  module_by_mode = {
      'congruent': 'CongruentSyllogismModule',
      'incongruent': 'IncongruentSyllogismModule',
  }
  mode = syllogism['entities_mode']

  cfg = ml_collections.ConfigDict()
  cfg.network_type = shared['graph_type']
  cfg.ordering = shared['ordering']
  cfg.num_entities_max = shared['num_entities']
  cfg.entity_type = syllogism['entity_type']
  # NOTE(review): the raw form string is stored here, whereas the other
  # builders store common_types.EntityMode.PRESET; confirm the syllogism
  # modules accept a string.
  cfg.entities_mode = mode
  cfg.randomize_relations = False
  cfg.preamble = ''
  cfg.query_preamble = ''
  cfg.add_filler = shared['add_filler']
  if shared['use_heuristic_rebalance']:
    cfg.heuristic_field = shared['heuristic_rebalance_field']
  else:
    cfg.heuristic_field = ''
  cfg.name = module_by_mode.get(mode, 'SyllogismModule')
  return cfg

In [None]:
# @title Family JSON Options

family = recoglab_dataset.family


def family_json_form():
  """Returns the Family JSON domain options set in the Colab form.

  Returns:
    A dict (built from `locals()`) with the keys `task`, `num_families` and
    `max_members`.
  """
  task = 'family_member_hobby_comparison'  # @param ['family_size', 'family_member_hobby', 'family_size_comparison','family_member_age_comparison','family_member_hobby_comparison']
  num_families = 10  # @param
  max_members = 5  # @param
  return locals()


def get_family_module_config(relation_type):
  """Builds the 'FamilyModule' ConfigDict for a given relation type.

  Family sizes come from `family_json_form()`.

  Args:
    relation_type: the `family.RelationType` member selecting the family task.

  Returns:
    An `ml_collections.ConfigDict` with the family module fields filled in.
  """
  form = family_json_form()

  module_cfg = ml_collections.ConfigDict()
  module_cfg.name = 'FamilyModule'
  # Both sizes are swept; roughly 180 tokens per family.
  module_cfg.num_families = form['num_families']
  module_cfg.max_members = form['max_members']
  module_cfg.relation_type = relation_type
  module_cfg.preamble = ''
  module_cfg.query_preamble = ''
  module_cfg.termination_type = 'None'
  # -1 means hop lengths are sampled between 1 and num_fam/2.
  module_cfg.hop_length = -1
  return module_cfg


def initialize_family_config():
  """Builds the ConfigDict for the Family JSON domain from the Colab forms.

  The form's `task` selects the `family.RelationType` that is passed to
  `get_family_module_config`.

  Returns:
    The family module ConfigDict with `preamble`, `query_preamble`,
    `add_filler`, `heuristic_field` and `entities_mode` set. Heuristic
    rebalancing and entity modes do not apply to this domain, so the last two
    are always None.

  Raises:
    ValueError: if the form's `task` is not a known family task.
    RuntimeError: if filler text is requested (not implemented here).
  """
  common_options = common_config_form()
  options = family_json_form()
  task = options['task']

  relation_types = {
      'family_size': family.RelationType.FAMILY_SIZE,
      'family_member_hobby': family.RelationType.FAMILY_MEMBER_HOBBY,
      'family_size_comparison': family.RelationType.FAMILY_SIZE_COMPARISON,
      'family_member_age_comparison': (
          family.RelationType.FAMILY_MEMBER_AGE_COMPARISON
      ),
      'family_member_hobby_comparison': (
          family.RelationType.FAMILY_MEMBER_HOBBY_COMPARISON
      ),
  }
  if task not in relation_types:
    raise ValueError('unrecognized task: %s' % task)
  config = get_family_module_config(relation_type=relation_types[task])

  # misc
  config.preamble = ''
  config.query_preamble = ''
  config.add_filler = common_options['add_filler']
  # common_config_form() defines no 'heuristic_field' or 'entities_mode' keys,
  # so these are always None for family tasks. The form's
  # use_heuristic_rebalance setting is ignored by this domain.
  config.heuristic_field = None
  config.entities_mode = None
  # Filler not implemented
  if common_options['add_filler']:
    raise RuntimeError('Filler not yet implemented for this module.')

  return config

In [None]:
# @title Construct Config
def construct_config_given_global(global_config):
  """Wraps the selected domain's module config in the top-level dataset config.

  Args:
    global_config: the dict returned by `get_global_config()`; only its
      'domain' entry is read here.

  Returns:
    A ConfigDict holding:
      * the module config under `_config`;
      * `all_module_names` and `maintain_entity_uniqueness`;
      * `config_only_hash`, computed before the timestamp is attached;
      * the creation `timestamp`;
      * `heuristic_field`, moved up from the module config.

  Raises:
    ValueError: if `global_config['domain']` is not a known domain.
  """
  created_at = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
  builders = {
      'social_network': initialize_social_network_config,
      'comparison': initialize_comparison_config,
      'syllogism': initialize_syllogism_config,
      'family_json': initialize_family_config,
  }
  domain = global_config['domain']
  if domain not in builders:
    raise ValueError('unrecognized domain: %s' % domain)
  module_config = builders[domain]()

  wrapper = ml_collections.ConfigDict()
  wrapper._config = module_config
  wrapper.all_module_names = ['_config']
  wrapper.maintain_entity_uniqueness = False
  # Hashed before the timestamp is attached, so the hash does not depend on
  # the creation time.
  wrapper.config_only_hash = utils.order_invariant_hash(wrapper)
  wrapper.timestamp = created_at
  # Move heuristic_field from the module config to the top level.
  wrapper.heuristic_field = module_config.heuristic_field
  del module_config.heuristic_field
  return wrapper


global_config = get_global_config()
dataset_config = construct_config_given_global(global_config)
print(dataset_config)

## Writing ReCogLab generated datasets to disk.

Each ReCogLab example is encoded entirely in Python and TensorFlow primitives when written to disk, so you won't need ReCogLab to load your dataset!

In [None]:
# @title Write TFRecord code


def write_examples_and_given_config(
    examples, global_config, seed, split, config=None
):
  """Writes examples to a TFRecord file and the dataset config to a JSON file.

  Both files go under `global_config['recoglab_dataset_dir']`, named
  `<dataset_name>_<seed>_<split>.tfrecord` and
  `<dataset_name>_<seed>_<split>.config`.

  Args:
    examples: iterable of ReCogLab examples. Each one is converted with
      `eval_io_lib.recoglab_dataset_example_to_tf_example`, using its position
      in the iterable as the index.
    global_config: dict from `get_global_config()`; its 'dataset_name' and
      'recoglab_dataset_dir' entries are read.
    seed: seed used to generate `examples`; it becomes part of the file names.
    split: dataset split name; it becomes part of the file names.
    config: the dataset ConfigDict to save. Defaults to the notebook-level
      `dataset_config` built by the Construct Config cell, so existing calls
      behave as before.
  """
  if config is None:
    config = dataset_config
  dataset_name = f"{global_config['dataset_name']}_{seed}_{split}"
  recoglab_dir = global_config['recoglab_dataset_dir']
  tfrecord_path = os.path.join(recoglab_dir, f'{dataset_name}.tfrecord')
  config_path = os.path.join(recoglab_dir, f'{dataset_name}.config')

  # One exist_ok makedirs replaces the exists-then-create check (racy, and
  # done twice for the same directory). Creating the file's parent also covers
  # dataset names that contain a '/'. An empty parent means the current
  # directory, which needs no creation.
  parent_dir = os.path.dirname(tfrecord_path)
  if parent_dir:
    os.makedirs(parent_dir, exist_ok=True)

  with tf.io.TFRecordWriter(tfrecord_path) as writer:
    for index, example in enumerate(examples):
      example_proto = eval_io_lib.recoglab_dataset_example_to_tf_example(
          example, index
      )
      writer.write(example_proto.SerializeToString())
  with open(config_path, 'w', encoding='utf-8') as f:
    f.write(config.to_json_best_effort(indent=2))

In [None]:
# Generate one dataset per seed and write each one to disk. After the loop,
# `examples` is still bound to the last seed's examples, which the
# visualization cells below use.
split = global_config['split']
seeds = global_config['csv_seed_list']
for seed in seeds:
  examples = recoglab_dataset.generate_dataset(
      dataset_config,
      seed=seed,
      split=split,
      num_examples=global_config['num_examples'],
      metadata_rebalance_field=dataset_config.heuristic_field,
  )
  write_examples_and_given_config(examples, global_config, seed, split)

In [None]:
# @title ReCogLab Visualization code
def render_example(example):
  """Prints an example's prompt, question and answers under headings.

  Args:
    example: either a dict as returned by `tf_to_python_primitives`, or an
      example object exposing `get_prompts()`, `get_question()` and
      `get_answers()`. For a dict, the last line of 'question' is shown as the
      question and the preceding lines as the prompt. The answers are
      'answer' followed by 'alternative_answers'.
  """
  if not isinstance(example, dict):
    prompt = example.get_prompts()
    question = example.get_question()
    answers = example.get_answers()
  else:
    *prompt_lines, question = example["question"].split("\n")
    prompt = "\n".join(prompt_lines)
    answers = [example["answer"]] + example["alternative_answers"]

  for heading, body in (
      ("Prompt", prompt),
      ("Question", question),
      ("Answer", answers),
  ):
    print()
    print(heading)
    print("======")
    print(body)


def count_answers(examples):
  """Counts how often each answer tuple occurs.

  Args:
    examples: iterable of example objects exposing `get_answers()`.

  Returns:
    A `collections.Counter` mapping `tuple(e.get_answers())` to the number of
    examples with that answer. Missing keys read as 0, as with the previous
    `defaultdict(int)`, and the repr lists the most common answers first.
  """
  return collections.Counter(tuple(e.get_answers()) for e in examples)


def conut_metadata(examples, metadata_field):
  """Counts how often each value of one metadata field occurs.

  Args:
    examples: iterable of example objects exposing `get_metadata()`.
    metadata_field: key to read from each example's metadata mapping.

  Returns:
    A `collections.Counter` mapping each `e.get_metadata()[metadata_field]`
    value to the number of examples with that value.
  """
  return collections.Counter(
      e.get_metadata()[metadata_field] for e in examples
  )


# Correctly spelled alias; the original misspelled name is kept for existing
# callers.
count_metadata = conut_metadata

In [None]:
# Examine the output produced by ReCogLab
# `examples` holds the examples generated for the last seed in the loop above.
render_example(examples[0])
print()
print("Dataset Metadata")
# NOTE(review): the metadata field is hard-coded. If an example's
# get_metadata() lacks 'answer/symbolic_distance' (possibly some domains,
# e.g. family_json — confirm), this line fails, likely with KeyError.
print(conut_metadata(examples, "answer/symbolic_distance"))
print(count_answers(examples))

In [None]:
# Loads the dataset from disk and examine the outputs.
# This only requires immutabledict and Tensorflow Dataset.
import ast
import immutabledict
import tensorflow as tf

# Separator that tf_to_python_primitives splits 'alternative_answers' on.
# Presumably this must match the separator used when the records were written
# (the writer is not shown here).
DEFAULT_MULTI_ANSWER_SEPARATOR = '\t'  # default answer splitter in strings
# Feature spec passed to tf.io.parse_example. Every feature is a scalar
# (shape []) whose default_value is used when the record lacks that feature.
# 'question_only' is parsed but not used by tf_to_python_primitives.
_FEATURE_SPEC = immutabledict.immutabledict({
    'index': tf.io.FixedLenFeature([], tf.int64, default_value=-1),
    'question': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'question_only': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'answer': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'metadata': tf.io.FixedLenFeature([], tf.string, default_value=''),
    'alternative_answers': tf.io.FixedLenFeature(
        [], tf.string, default_value=''
    ),
})


def tf_to_python_primitives(
    dict_example,
):
  """Converts one parsed example (a dict of scalar tensors) to Python values.

  Args:
    dict_example: a dictionary outputted by a tf dataset for an example. Its
      'index', 'question', 'answer', 'alternative_answers' and 'metadata'
      entries are read via `.numpy()`.

  Returns:
    A dict with:
      * 'index': an int;
      * 'question' and 'answer': str;
      * 'alternative_answers': a list of str (empty when the feature is empty);
      * 'metadata': the value parsed with `ast.literal_eval`, which only
        accepts Python literals ({} when the feature is empty).
  """

  def as_text(key):
    return dict_example[key].numpy().decode()

  output = {
      'index': int(dict_example['index'].numpy()),
      'question': as_text('question'),
      'answer': as_text('answer'),
  }
  alternatives = as_text('alternative_answers')
  output['alternative_answers'] = (
      alternatives.split(DEFAULT_MULTI_ANSWER_SEPARATOR) if alternatives else []
  )
  metadata = as_text('metadata')
  output['metadata'] = ast.literal_eval(metadata) if metadata else {}
  return output


def parse_example(serialized_item: tf.Tensor) -> dict[str, tf.Tensor]:
  """Parses one serialized tf.Example into a dict of tensors.

  The annotations use `tf.Tensor` (already imported in this cell) instead of
  `typing.Any`, which is never imported here and made this `def` raise
  NameError.

  Args:
    serialized_item: a scalar string tensor holding one serialized
      tf.Example, e.g. an element of a `tf.data.TFRecordDataset`.

  Returns:
    A dict mapping each key of `_FEATURE_SPEC` to its parsed tensor. The spec's
    default value is used for any feature missing from the record.
  """
  return tf.io.parse_example(serialized_item, _FEATURE_SPEC)


def load_tfrecord(filepath):
  """Returns a tf.data pipeline of parsed examples read from `filepath`.

  Each raw record is passed through `parse_example`, so the elements are dicts
  of tensors keyed like `_FEATURE_SPEC`.
  """
  return tf.data.TFRecordDataset(filepath).map(parse_example)


def print_examples_from_tfrecord(tfrecord_path, num_examples=1):
  """Prints the leading examples of a TFRecord file as Python primitives.

  Args:
    tfrecord_path: path of a TFRecord file written by this notebook.
    num_examples: how many leading examples to print. Defaults to 1, which
      matches the previous fixed behavior.
  """
  for example in load_tfrecord(tfrecord_path).take(num_examples):
    print(tf_to_python_primitives(example))


# Print the first example of the file written for the first seed. The path
# follows the same `<dataset_name>_<seed>_<split>.tfrecord` naming as
# write_examples_and_given_config, so it tracks the form settings instead of
# assuming the default name, seed, split and '/content' location.
print_examples_from_tfrecord(
    os.path.join(
        global_config['recoglab_dataset_dir'],
        f"{global_config['dataset_name']}_"
        f"{global_config['csv_seed_list'][0]}_"
        f"{global_config['split']}.tfrecord",
    )
)

In [None]:
!zip -r '/content/recoglab_generated_dataset.zip' '/content/generated_dataset'

In [None]:
# Downloads the zipped folder of ReCogLab to your local computer.
# Requires the Google Colab runtime (google.colab) and the zip file created by
# the previous cell.
from google.colab import files

files.download('/content/recoglab_generated_dataset.zip')