Copyright 2020 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# The dataset used for the paired associative inference task

This is the dataset used for the paired associative inference (PAI) task in
["MEMO: A Deep Network for Flexible Combination of Episodic Memories"](https://arxiv.org/abs/2001.10913).

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import tensorflow as tf
import collections
import os

from google.colab import auth
# Authenticate the Colab user with Google. NOTE(review): presumably needed so
# that the gs://deepmind-memo bucket read further down is accessible — confirm.
auth.authenticate_user()

In [0]:
#@title Choices about the dataset you want to load.
# Make choices about the dataset here.
# chain_length (3 or 4) selects the dataset variant: it picks the
# gs://deepmind-memo/length<N>/ root and the feature shapes used for parsing.
chain_length = 3  #@param {type:"slider", min:3, max:4, step:1}
# mode selects the split subdirectory to read: 'train', 'test' or 'valid'.
mode = 'valid' #@param ['train', 'test', 'valid']

**If you choose chain_length 3, the data will look like this:**

*   trials: (48, 3, 1000); 48 trials × 3 pictures (the target picture, the left option and the right option) × picture dimensions.
*   correct answer: (48); whether the left or the right option is correct.
*   difficulty: (48); how many steps apart the target picture and the two options are (e.g. A and B are 0 steps apart; A and C are 1 step apart).
*   trial type: (48); see below.
*   memory: (32, 2, 1000); contents of the memory store, 32 pairs of images.

Trial types:
*   1: AB
*   2: BC
*   3: AC


**If you choose chain_length 4, the data will look like this:**
*   trials: (96, 3, 1000)
*   correct answer: (96)
*   difficulty: (96)
*   trial type: (96)
*   memory: (48, 2, 1000)

Trial types:
*   1: AB
*   2: BC
*   3: AC
*   4: CD
*   5: BD
*   6: AD

In [0]:
# Number of TFRecord shards in each split: train has 500, valid 150, test 100.
_NUM_SHARDS_BY_MODE = {'train': 500, 'test': 100, 'valid': 150}
if mode not in _NUM_SHARDS_BY_MODE:
  # Fail here with a clear message. Otherwise num_shards would be undefined
  # (NameError further down) or, after an earlier run of this notebook,
  # silently keep a stale value.
  raise ValueError('mode must be one of %s, got %r'
                   % (sorted(_NUM_SHARDS_BY_MODE), mode))
num_shards = _NUM_SHARDS_BY_MODE[mode]

In [0]:
# Static description of a dataset split:
#   basepath     - split subdirectory, joined under the root directory;
#   size         - number of shard files in that split;
#   chain_length - chain length of the dataset variant.
DatasetInfo = collections.namedtuple(
    'DatasetInfo', 'basepath size chain_length')

# Registry of loadable datasets; 'memo' is the only entry and is built from
# the form choices (split, its shard count, and chain length).
_DATASETS = {
    'memo': DatasetInfo(mode, num_shards, chain_length),
}

In [0]:
def _get_dataset_files(dataset_info, root):
  """Generates lists of files for a given dataset version.

  Args:
    dataset_info: a DatasetInfo-like record. `basepath` is the split
      subdirectory under `root`; `size` is the number of shard files.
    root: root directory (or gs:// prefix) holding the split subdirectories.

  Returns:
    A list of `size` paths of the form
    `<root>/<basepath>/trials-IIIII-of-NNNNN`, where IIIII runs over
    0 .. size-1 and NNNNN is `size`, both zero-padded to 5 digits.
  """
  base = os.path.join(root, dataset_info.basepath)
  num_files = dataset_info.size
  # Shard index and shard count are always padded to 5 digits, regardless of
  # how many digits num_files actually has.
  template = 'trials-{:05d}-of-{:05d}'
  return [os.path.join(base, template.format(i, num_files))
          for i in range(num_files)]

In [0]:
def parser_tf_examples(raw_data, chain_length=chain_length):
  """Parses serialized MEMO Example protos into a list of tensors.

  Args:
    raw_data: serialized `tf.train.Example` data (e.g. records yielded by
      `tf.data.TFRecordDataset`); passed straight to `tf.io.parse_example`.
    chain_length: 3 or 4; selects the fixed feature shapes of the dataset.
      The default is the global `chain_length` at the time this cell ran.

  Returns:
    A list `[trials, correct_answer, difficulty, trial_type, memory]` with
    per-example shapes:
      trials: float32, [num_trials, 3, 1000]
      correct_answer, difficulty, trial_type: int64, [num_trials]
      memory: float32, [num_pairs, 2, 1000]
    where (num_trials, num_pairs) is (48, 32) for chain_length 3 and
    (96, 48) for chain_length 4. A batch of serialized records gets an extra
    leading batch dimension.

  Raises:
    ValueError: if `chain_length` is not 3 or 4.
  """
  # (num_trials, num_pairs) for each supported chain length.
  sizes_by_chain_length = {3: (48, 32), 4: (96, 48)}
  if chain_length not in sizes_by_chain_length:
    # Fail loudly instead of reaching parse_example with no feature map.
    raise ValueError('chain_length must be 3 or 4, got %r' % (chain_length,))
  num_trials, num_pairs = sizes_by_chain_length[chain_length]

  feature_map = {
      'trials': tf.io.FixedLenFeature(
          shape=[num_trials, 3, 1000], dtype=tf.float32),
      'correct_answer': tf.io.FixedLenFeature(
          shape=[num_trials], dtype=tf.int64),
      'difficulty': tf.io.FixedLenFeature(
          shape=[num_trials], dtype=tf.int64),
      'trial_type': tf.io.FixedLenFeature(
          shape=[num_trials], dtype=tf.int64),
      'memory': tf.io.FixedLenFeature(
          shape=[num_pairs, 2, 1000], dtype=tf.float32),
  }
  example = tf.io.parse_example(raw_data, feature_map)
  # Fixed output order: later cells index the parsed batch positionally.
  return [example[key] for key in ('trials', 'correct_answer', 'difficulty',
                                   'trial_type', 'memory')]

## Load the data.

In [0]:
# Shards for the selected chain length live under this GCS directory.
root = 'gs://deepmind-memo/length{}/'.format(chain_length)
num_epochs = 100           # How many times the shard file list is repeated.
shuffle_buffer_size = 150  # Used for both the file- and record-level shuffle.
num_readers = 4            # Shard files interleaved concurrently.
dataset_info = _DATASETS['memo']
filenames = _get_dataset_files(dataset_info, root)
num_map_threads = 4        # Parallel calls used when parsing records.
batch_size = 10

In [0]:
# Input pipeline: shard files -> interleaved TFRecords -> shuffled records
# -> parsed tensors -> batches of batch_size.
data = (
    tf.data.Dataset.from_tensor_slices(filenames)
    # Repeat the file list num_epochs times and shuffle the file order.
    .repeat(num_epochs)
    .shuffle(shuffle_buffer_size)
    # Read num_readers files at a time, taking one record from each in turn.
    .interleave(tf.data.TFRecordDataset,
                cycle_length=num_readers, block_length=1)
    # Shuffle again at the record level.
    .shuffle(shuffle_buffer_size)
    .map(parser_tf_examples, num_parallel_calls=num_map_threads)
    .batch(batch_size)
)

# Looking at what we loaded.

In [0]:
# Pull a single batch out of the pipeline so its shapes can be inspected.
iterator = iter(data)
element = next(iterator)

In [0]:
# Print the shape of each batched component, in the order produced by
# parser_tf_examples: trials, correct answer, difficulty, trial type, memory.
for component_index in range(5):
  print(element[component_index].shape)