# [ReCogLab] Generate Dataset
This is a Colab interface that can be used to access the ReCogLab library and generate examples.

We've provided an interface of easy-to-configure options. For a description of what each option corresponds to, see the **Configuration** section below.


In [None]:
%%capture

# Clone recoglab from github.
!git clone https://github.com/google-deepmind/recoglab.git
!pip install -r recoglab/requirements.txt

If you encounter:

```ValueError: All ufuncs must have type `numpy.ufunc`.```

 when importing libraries, you will need to restart the Colab runtime so that the Python environment reloads the newly installed libraries.

In [None]:
%load_ext autoreload

%autoreload 2
import collections
import datetime
import ml_collections
from recoglab import common_types
from recoglab import recoglab_dataset
from recoglab import utils

## Instructions

Different configurations will create different relational reasoning problems. We defined some common defaults that we found useful and that work well for evaluating language-model reasoning.

## Configuration

* **Global Configurations**: This section controls framework and I/O parameters such as the split, number of examples, seeds, and save location. Use `domain` to select which domain's configuration is built.

* **Common Options**: This section contains configuration options shared by the four domains.

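* **Social Network**: This section controls Social Network domain tasks.
* **Comparison**: This section controls Comparison domain tasks.
* **Syllogism**: This section controls Syllogism domain tasks.
* **Family JSON**: This section controls Family JSON domain tasks.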

In [None]:
# @title Global Configurations
def _parse_seed_list(csv_seeds):
  """Parse a comma-separated string of integer seeds.

  Whitespace around each entry is ignored and blank entries (for example from
  a trailing comma, as in '42,7,') are skipped.

  Args:
    csv_seeds: String such as '42' or '1, 2, 3'.

  Returns:
    List of ints in the order given.

  Raises:
    ValueError: if an entry is not an integer, or no seed is given at all.
  """
  seeds = [int(token) for token in csv_seeds.split(',') if token.strip()]
  if not seeds:
    raise ValueError(
        'csv_seeds must contain at least one integer seed, got %r' % csv_seeds
    )
  return seeds


def get_global_config():
  """Read the "Global Configurations" Colab form.

  Each assignment below is a Colab form field (see its `# @param`
  annotation). The function returns `locals()`, so the local variable names
  are exactly the keys of the returned dict.

  Returns:
    Dict with keys 'domain', 'dataset_name', 'recoglab_dataset_dir', 'split',
    'num_examples', 'csv_seeds' and 'csv_seed_list' (the parsed int seeds).

  Raises:
    ValueError: if `csv_seeds` is not a comma-separated list of integers.
  """
  domain = 'social_network'  # @param ['social_network', 'comparison', 'syllogism', 'family_json']
  dataset_name = ''  # @param {type: 'string'}
  recoglab_dataset_dir = ''  # @param {type: 'string'}
  split = 'test'  # @param ['train', 'test', 'val']
  num_examples = 500  # @param {type: 'number'}
  csv_seeds = '42'  # @param {type: 'string'}
  # Parsed in a helper so no extra local leaks into the `locals()` result.
  csv_seed_list = _parse_seed_list(csv_seeds)
  return locals()

In [None]:
# @title Common Options
def common_config_form():
  """Read the "Common Options" Colab form shared by every domain.

  Each assignment below is a Colab form field (see its `# @param`
  annotation). The function returns `locals()`, so the local variable names
  are exactly the keys of the returned dict; any extra local added here would
  also end up in the result.
  """
  num_entities = 20  # @param{type: 'number'}
  graph_type = 'linear'  # @param ['linear', 'random_tree', 'erdos_renyi']
  ordering = 'random'  # @param ['random', 'inorder', 'reverse']

  # The social network, comparison and syllogism builders use
  # `heuristic_rebalance_field` as the config's heuristic (rebalance) field
  # when this is enabled, and '' otherwise.
  use_heuristic_rebalance = True  # @param{type:'boolean'}
  heuristic_rebalance_field = 'answer/symbolic_distance'  # @param ['answer','answer/symbolic_distance']

  # filler options
  # The three options after `add_filler` are only copied into a module
  # config when `add_filler` is True.
  add_filler = False # @param{type:"boolean"}
  num_filler_lines = 10 # @param{type:"integer"}
  filler_type = 'random_text' # @param['random_text', 'entity_filler']
  filler_position = 'interspersed' # @param['before', 'after', 'interspersed']
  return locals()

In [None]:
# @title Social Network Options
def social_network_form():
  """Read the "Social Network Options" Colab form.

  Returns:
    `locals()`: a dict with keys 'task_type', 'entity_type' and
    'relation_type', consumed by `initialize_social_network_config`.
  """
  task_type = 'FastestMessage_NumHops'  # @param ['FastestMessage', 'FastestMessage_NumHops', 'FastestMessage_ExactPath']
  entity_type = 'baby-names'  # @param ['baby-names']
  relation_type = 'friend_advanced'  # @param ['friend', 'friend_advanced', 'relative']
  return locals()


def initialize_social_network_config():
  """Build the module config for the Social Network domain.

  Merges the "Common Options" and "Social Network Options" forms into an
  `ml_collections.ConfigDict` describing a `SocialNetworkModule`.

  Returns:
    The populated ConfigDict. Its `heuristic_field` is the common form's
    rebalance field, or '' when rebalancing is disabled.
  """
  shared = common_config_form()
  social = social_network_form()
  with_filler = shared['add_filler']

  if shared['use_heuristic_rebalance']:
    rebalance_field = shared['heuristic_rebalance_field']
  else:
    rebalance_field = ''

  config = ml_collections.ConfigDict()
  config.name = 'SocialNetworkModule'

  # Options shared across domains. The common `ordering` choice is stored as
  # `randomize_relations` for this module.
  config.network_type = shared['graph_type']
  config.randomize_relations = shared['ordering']
  config.num_entities_max = shared['num_entities']

  # Options specific to the social network task.
  config.entity_type = social['entity_type']
  config.relation_type = social['relation_type']
  config.query_task = social['task_type']
  # Keep edge direction fixed. With True, an edge (A, B) could be verbalized
  # as either "A is friends with B" or "B is friends with A".
  config.randomize_direction = False
  config.add_filler = with_filler

  if with_filler:
    # These override the module's filler defaults.
    for field in ('num_filler_lines', 'filler_type', 'filler_position'):
      config[field] = shared[field]

  config.entities_mode = common_types.EntityMode.PRESET
  config.preamble = ''
  config.query_preamble = (
      'Any two friends are able to pass along a message, which allows messages '
      'to move from one friend to another. Thus, messages can be passed between'
      ' two people through friends they have in common.\n'
  )
  config.heuristic_field = rebalance_field
  return config

In [None]:
# @title Comparison Options
def comparison_form():
  """Read the "Comparison Options" Colab form.

  Returns:
    `locals()`: a dict with keys 'task_type', 'entity_type', 'relation_type'
    and 'congruency_mode', consumed by `initialize_comparison_config`.
  """
  task_type = 'ConsistencyDetection'  # @param ['Comparison', 'ConsistencyDetection', 'FeasibilityDetection']
  entity_type = 'baby-names'  # @param ['baby-names', 'basic_objects', 'congruent_objects', 'random_name']
  relation_type = 'size'  # @param['size', 'age', 'weight']
  congruency_mode = 'all_congruent'  # @param['all_congruent', 'all_incongruent', 'random']
  return locals()


def initialize_comparison_config():
  """Build the module config for the Comparison domain.

  Merges the "Common Options" and "Comparison Options" forms into an
  `ml_collections.ConfigDict`. The selected task decides the module class and
  may override some of the common options:

  * 'Comparison': `ComparisonModule`; common rebalance settings apply.
  * 'ConsistencyDetection': `ComparisonValidModule`; rebalanced on 'answer'.
  * 'FeasibilityDetection': `ComparisonModule` on a 'random_tree' graph (the
    common `graph_type` is overridden); rebalanced on 'answer'.

  Returns:
    The populated ConfigDict.

  Raises:
    ValueError: if the selected task type is not one of the above.
  """
  common_options = common_config_form()
  options = comparison_form()
  config = ml_collections.ConfigDict()
  # common_options
  config.network_type = common_options['graph_type']
  config.ordering = common_options['ordering']
  config.num_entities_max = common_options['num_entities']

  # task-specific options
  config.entity_type = options['entity_type']
  config.relation_type = options['relation_type']
  config.congruency_mode = options['congruency_mode']

  # misc
  config.entities_mode = common_types.EntityMode.PRESET
  config.randomize_relations = False
  config.preamble = ''
  config.query_preamble = ''
  config.add_filler = common_options['add_filler']

  # Add filler options
  if common_options['add_filler']:
    # Overwrite defaults
    config.num_filler_lines = common_options['num_filler_lines']
    config.filler_type = common_options['filler_type']
    config.filler_position = common_options['filler_position']
  config.heuristic_field = (
      common_options['heuristic_rebalance_field']
      if common_options['use_heuristic_rebalance']
      else ''
  )

  # Specific customization based on the task
  if options['task_type'] == 'Comparison':
    config.name = 'ComparisonModule'
  elif options['task_type'] == 'ConsistencyDetection':
    config.name = 'ComparisonValidModule'
    # Always rebalance on the answer, regardless of the common form.
    config.heuristic_field = 'answer'
  elif options['task_type'] == 'FeasibilityDetection':
    config.name = 'ComparisonModule'
    # feasibility only makes sense in directed trees
    config.network_type = 'random_tree'
    config.heuristic_field = 'answer'
  else:
    # Without this, the config would have no `name` and fail later with a
    # much less obvious error.
    raise ValueError('unrecognized task type: %s' % options['task_type'])

  return config

In [None]:
# @title Syllogism Options
def syllogism_form():
  """Read the "Syllogism Options" Colab form.

  Returns:
    `locals()`: a dict with keys 'entity_type' and 'entities_mode'. The
    `entities_mode` value also selects the syllogism module class in
    `initialize_syllogism_config`.
  """
  entity_type = 'plural_nouns'  # @param ['plural_nouns']
  entities_mode = 'preset'  # @param ['congruent', 'incongruent', 'preset']
  return locals()


def initialize_syllogism_config():
  """Build the module config for the Syllogism domain.

  Merges the "Common Options" and "Syllogism Options" forms into an
  `ml_collections.ConfigDict`. The form's `entities_mode` picks the module:
  'congruent' -> `CongruentSyllogismModule`, 'incongruent' ->
  `IncongruentSyllogismModule`, anything else -> `SyllogismModule`.

  NOTE(review): unlike the other builders, `entities_mode` is stored as the
  raw form string rather than a `common_types.EntityMode` value; confirm the
  syllogism modules expect a string.

  Returns:
    The populated ConfigDict.

  Raises:
    RuntimeError: if filler text is requested; not supported here yet.
  """
  shared = common_config_form()
  syllogism = syllogism_form()
  mode = syllogism['entities_mode']

  # Filler text is not implemented for the syllogism modules.
  if shared['add_filler']:
    raise RuntimeError('Filler not yet implemented for this module.')

  module_by_mode = {
      'congruent': 'CongruentSyllogismModule',
      'incongruent': 'IncongruentSyllogismModule',
  }
  if shared['use_heuristic_rebalance']:
    rebalance_field = shared['heuristic_rebalance_field']
  else:
    rebalance_field = ''

  config = ml_collections.ConfigDict()
  config.network_type = shared['graph_type']
  config.ordering = shared['ordering']
  config.num_entities_max = shared['num_entities']
  config.entity_type = syllogism['entity_type']
  config.entities_mode = mode
  config.randomize_relations = False
  config.preamble = ''
  config.query_preamble = ''
  config.add_filler = shared['add_filler']
  config.heuristic_field = rebalance_field
  config.name = module_by_mode.get(mode, 'SyllogismModule')
  return config

In [None]:
#@title Family JSON Options

# Alias for the family module; the builders below select a family task
# through its `RelationType` members.
family = recoglab_dataset.family
def family_json_form():
  """Read the "Family JSON Options" Colab form.

  Returns:
    `locals()`: a dict with keys 'task', 'num_families' and 'max_members'.
  """
  task = 'family_member_hobby_comparison' # @param ['family_size', 'family_member_hobby', 'family_size_comparison','family_member_age_comparison','family_member_hobby_comparison']
  num_families = 10   #@param
  max_members = 5   #@param
  return locals()

def get_family_module_config(relation_type):
  """Build the `FamilyModule` config for one family relation type.

  Args:
    relation_type: A `family.RelationType` member selecting the family task.

  Returns:
    A ConfigDict for `FamilyModule`. `num_families` and `max_members` come
    from the "Family JSON Options" form.
  """
  form = family_json_form()
  settings = {
      'name': 'FamilyModule',
      # Used in sweeps; one family is roughly 180 tokens.
      'num_families': form['num_families'],
      # Used in sweeps.
      'max_members': form['max_members'],
      'relation_type': relation_type,
      'preamble': '',
      'query_preamble': '',
      # NOTE: the string 'None', not the Python value None.
      'termination_type': 'None',
      # -1 means the hop length is sampled between 1 and num_families / 2.
      'hop_length': -1,
  }
  module_config = ml_collections.ConfigDict()
  for key, value in settings.items():
    setattr(module_config, key, value)
  return module_config

def initialize_family_config():
  """Build the module config for the Family JSON domain.

  The task chosen in the "Family JSON Options" form selects the matching
  `family.RelationType` member, and `get_family_module_config` builds the
  module config for it.

  Only `add_filler` is read from the "Common Options" form. That form defines
  no 'heuristic_field' or 'entities_mode' option, so both are set to None for
  this domain. Its `use_heuristic_rebalance` choice is therefore ignored here.

  Returns:
    The populated ConfigDict.

  Raises:
    ValueError: if the selected task is not a supported family task.
    RuntimeError: if filler text is requested; not supported here yet.
  """
  common_options = common_config_form()
  options = family_json_form()
  task = options['task']

  supported_tasks = (
      'family_size',
      'family_member_hobby',
      'family_size_comparison',
      'family_member_age_comparison',
      'family_member_hobby_comparison',
  )
  if task not in supported_tasks:
    raise ValueError('unrecognized task: %s' % options['task'])
  # Each supported task name is the lower-case spelling of its RelationType
  # member, e.g. 'family_size' -> family.RelationType.FAMILY_SIZE.
  config = get_family_module_config(
      relation_type=getattr(family.RelationType, task.upper())
  )

  # misc
  config.preamble = ''
  config.query_preamble = ''
  config.add_filler = common_options['add_filler']
  # Rebalancing and entity modes do not apply to this domain.
  # construct_config_given_global copies `heuristic_field` to the outer config,
  # so None there means no rebalance field.
  config.heuristic_field = None
  config.entities_mode = None
  # Filler not implemented
  if common_options['add_filler']:
    raise RuntimeError('Filler not yet implemented for this module.')

  return config



In [None]:

# @title Construct Config
def construct_config_given_global(global_config):
  """Wrap the selected domain's module config into a dataset config.

  Args:
    global_config: Dict from `get_global_config()`; only 'domain' is read.

  Returns:
    A ConfigDict holding the module config under `_config`, the module name
    list, a configuration hash, a creation timestamp, and the module's
    `heuristic_field` (moved up from the module config).

  Raises:
    ValueError: if the domain is not recognized.
  """
  created_at = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
  domain = global_config['domain']
  builders = {
      'social_network': initialize_social_network_config,
      'comparison': initialize_comparison_config,
      'syllogism': initialize_syllogism_config,
      'family_json': initialize_family_config,
  }
  if domain not in builders:
    raise ValueError('unrecognized domain: %s' % global_config['domain'])
  module_config = builders[domain]()

  wrapper = ml_collections.ConfigDict()
  wrapper._config = module_config
  wrapper.all_module_names = ['_config']
  wrapper.maintain_entity_uniqueness = False
  # Hashed before the timestamp is attached, so the hash reflects the
  # configuration only, not the creation time.
  wrapper.config_only_hash = utils.order_invariant_hash(wrapper)
  wrapper.timestamp = created_at
  # Move the rebalance field from the module config to the top level.
  wrapper.heuristic_field = module_config.heuristic_field
  del module_config.heuristic_field
  return wrapper


# Read the "Global Configurations" form and build the nested dataset config
# for the selected domain; printing it makes the chosen options visible.
global_config = get_global_config()
dataset_config = construct_config_given_global(global_config)
print(dataset_config)

In [None]:
split = global_config['split']
seeds = global_config['csv_seed_list']
# Keep the output of every seed. `examples_by_seed` maps each seed to its own
# examples, and `examples` pools all of them (in seed order) for the
# inspection cells below.
examples_by_seed = {}
examples = []
for seed in seeds:
  seed_examples = list(
      recoglab_dataset.generate_dataset(
          dataset_config,
          split=split,
          seed=seed,
          num_examples=global_config['num_examples'],
          metadata_rebalance_field=dataset_config.heuristic_field,
      )
  )
  examples_by_seed[seed] = seed_examples
  examples.extend(seed_examples)

In [None]:
def render_example(example):
  """Print the prompt, question and answer of `example` as titled sections.

  Each section is a blank line, the title, a '======' rule, then the value
  returned by the corresponding getter on `example`.
  """
  sections = (
      ('Prompt', 'get_prompts'),
      ('Question', 'get_question'),
      ('Answer', 'get_answers'),
  )
  for title, getter_name in sections:
    print()
    print(title)
    print("======")
    # Look the getter up only when its section is printed.
    print(getattr(example, getter_name)())


def count_answers(examples):
  """Count how often each answer occurs in `examples`.

  Answers are converted to tuples so list-valued answers can be dict keys.

  Returns:
    A `collections.defaultdict(int)` mapping answer tuple -> count, in
    first-seen order.
  """
  tallies = collections.Counter(tuple(ex.get_answers()) for ex in examples)
  return collections.defaultdict(int, tallies)


def count_metadata(examples, metadata_field):
  """Build a histogram of one metadata field over `examples`.

  Args:
    examples: Iterable of examples whose `get_metadata()` returns a mapping
      containing `metadata_field`.
    metadata_field: Key to read from each example's metadata.

  Returns:
    A `collections.defaultdict(int)` mapping each observed value to its count,
    in first-seen order.

  Raises:
    KeyError: if an example's metadata lacks `metadata_field`.
  """
  metadata_hist = collections.defaultdict(int)
  for e in examples:
    metadata_hist[e.get_metadata()[metadata_field]] += 1
  return metadata_hist


# Original (misspelled) name, kept so existing callers keep working.
conut_metadata = count_metadata

In [None]:
# Show the first generated example; avoid an IndexError when none were made.
if examples:
  render_example(examples[0])
else:
  print('No examples were generated.')

In [None]:
# Histograms of the 'answer/symbolic_distance' metadata and of the answers
# across the generated examples.
# NOTE(review): assumes every example's metadata contains
# 'answer/symbolic_distance'; conut_metadata raises KeyError otherwise. Confirm
# this holds for every domain (e.g. family_json).
print(conut_metadata(examples, 'answer/symbolic_distance'))
print(count_answers(examples))