# In [None]:
from run import main
from absl import flags
from absl import app
import sys

# --- Task selection ---
flags.DEFINE_string('algorithm', 'lcs_length', 'Which algorithm to run.')
flags.DEFINE_integer('seed', 42, 'Random seed to set.')

# --- Training schedule ---
flags.DEFINE_integer('batch_size', 16, 'Batch size used for training.')
flags.DEFINE_boolean('chunked_training', False,
                     'Whether to use chunking for training.')
flags.DEFINE_integer('chunk_length', 100,
                     'Time chunk length used for training (if '
                     '`chunked_training` is True).')
# Adjacent string literals are concatenated verbatim, so each fragment that
# continues a sentence must end with a space.
flags.DEFINE_integer('train_items', 320000,
                     'Number of items (i.e., individual examples, possibly '
                     'repeated) processed during training. With non-chunked '
                     'training, this is the number of training batches times '
                     'the number of training steps. For chunked training, '
                     'as many chunks will be processed as needed to get this '
                     'many full examples.')
flags.DEFINE_integer('eval_every', 320,
                     'Logging frequency (in training examples).')
flags.DEFINE_boolean('verbose_logging', False, 'Whether to log aux losses.')

# --- Model / optimisation hyperparameters ---
flags.DEFINE_integer('hidden_size', 128,
                     'Number of hidden units of the model.')
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate to use.')
flags.DEFINE_float('dropout_prob', 0.0, 'Dropout rate to use.')
flags.DEFINE_float('hint_teacher_forcing_noise', 0.5,
                   'Probability that rematerialized hints are encoded during '
                   'training instead of ground-truth teacher hints. Only '
                   'pertinent in encoded_decoded modes.')
flags.DEFINE_integer('nb_heads', 1, 'Number of heads for GAT processors.')

flags.DEFINE_enum('hint_mode', 'encoded_decoded_nodiff',
                  ['encoded_decoded', 'decoded_only',
                   'encoded_decoded_nodiff', 'decoded_only_nodiff',
                   'none'],
                  'How should hints be used? Note, each mode defines a '
                  'separate task, with various difficulties. `encoded_decoded` '
                  'requires the model to explicitly materialise hint sequences '
                  'and therefore is hardest, but also most aligned to the '
                  'underlying algorithmic rule. Hence, `encoded_decoded` '
                  'should be treated as the default mode for our benchmark. '
                  'In `decoded_only`, hints are only used for defining '
                  'reconstruction losses. Often, this will perform well, but '
                  'note that we currently do not make any efforts to '
                  'counterbalance the various hint losses. Hence, for certain '
                  'tasks, the best performance will now be achievable with no '
                  'hint usage at all (`none`). The `_nodiff` variants '
                  'try to predict all hint values instead of just the values '
                  'that change from one timestep to the next.')

# --- Processor architecture ---
flags.DEFINE_boolean('use_ln', True,
                     'Whether to use layer normalisation in the processor.')
# String-valued (not boolean): names the memory module to use.
flags.DEFINE_string('use_memory', "NTM",
                    'Which memory module (e.g. "NTM") to insert after '
                    'message passing.')
flags.DEFINE_enum(
    'processor_type', 'gatv2',
    ['deepsets', 'mpnn', 'pgn', 'pgn_mask',
     'gat', 'gatv2', 'gat_full', 'gatv2_full',
     'memnet_full', 'memnet_masked'],
    'The processor type to use.')

# --- Paths and misc ---
flags.DEFINE_string('checkpoint_path', '/tmp/CLRS30',
                    'Path in which checkpoints are saved.')
flags.DEFINE_string('dataset_path', '/tmp/CLRS30',
                    'Path in which dataset is stored.')
flags.DEFINE_boolean('freeze_processor', False,
                     'Whether to freeze the processor of the model.')
flags.DEFINE_integer('memory_size', 20,
                     'Size of differentiable data structure memory.')
FLAGS = flags.FLAGS

# Parse the command line up front so flag values can be read and overridden
# programmatically below.
# NOTE(review): inside a Jupyter kernel, sys.argv carries kernel arguments
# (e.g. `-f <connection file>`) that absl rejects as unknown flags; if this
# cell runs in a notebook, parsing `sys.argv[:1]` avoids that.
FLAGS(sys.argv)
# Algorithms to sweep with GAT-style processors ('gat' / 'gatv2').
GAT_BEST = [
    'dfs',
    'jarvis_march',
    'kmp_matcher',
    'lcs_length',
    'quickselect',
    'task_scheduling',
]

# Algorithms to sweep with the MPNN processor.
MPNN_BEST = [
    'articulation_points',
    'activity_selector',
    'bfs',
    'bridges',
    'dijkstra',
    'graham_scan',
    'mst_kruskal',
    'mst_prim',
    'naive_string_matcher',
    'segments_intersect',
    'strongly_connected_components',
]

# Algorithms to sweep with the PGN processor (also the fallback list for any
# processor type without a dedicated list).
PGN_BEST = [
    'activity_selector',
    'bellman_ford',
    'binary_search',
    'dag_shortest_paths',
    'find_maximum_subarray_kadane',
    'floyd_warshall',
    'matrix_chain_order',
    'minimum',
    'mst_prim',
    'optimal_bst',
    'quickselect',
    'strongly_connected_components',
    'task_scheduling',
    'topological_sort',
]
# Backward-compatible alias: the list was originally named with mixed case,
# unlike its GAT_BEST / MPNN_BEST siblings.
PGN_best = PGN_BEST

# Sweep every (processor, algorithm) pair from the curated lists above and
# train each configuration that has no entry in results.txt yet.
FLAGS.memory_size = 100


def _algorithms_for(processor_type):
    """Return the curated algorithm list to sweep for ``processor_type``.

    GAT variants use GAT_BEST, 'mpnn' uses MPNN_BEST, and every other
    processor type falls back to PGN_best.
    """
    if processor_type in ("gatv2", "gat"):
        return GAT_BEST
    if processor_type == "mpnn":
        return MPNN_BEST
    return PGN_best


def _read_results(path="results.txt"):
    """Return the text of the results log, or "" if it does not exist yet.

    The file is closed before returning, so it is not held open while a
    training run executes.
    """
    try:
        # Only ASCII run keys are searched for, so undecodable bytes are
        # replaced rather than allowed to abort the sweep.
        with open(path, encoding="utf-8", errors="replace") as results_file:
            return results_file.read()
    except FileNotFoundError:
        # First run: nothing has been recorded yet.
        return ""


def _already_done(results_text, algo):
    """Return True if ``results_text`` records ``algo`` for the current flags.

    Matches either ``{algo}_{processor}_{memory}_{size}`` or the
    ``{algo}_best_...`` variant of that key.
    NOTE(review): this is a plain substring test, so e.g. a memory_size of
    100 also matches an entry recorded with 1000 - confirm the log format.
    """
    suffix = f"{FLAGS.processor_type}_{FLAGS.use_memory}_{FLAGS.memory_size}"
    return (f"{algo}_{suffix}" in results_text
            or f"{algo}_best_{suffix}" in results_text)


for model in ["gatv2", "mpnn"]:
    FLAGS.processor_type = model
    for algo in _algorithms_for(FLAGS.processor_type):
        FLAGS.algorithm = algo

        # Re-read on every iteration so entries presumably appended by
        # earlier runs in this sweep are taken into account.
        if _already_done(_read_results(), algo):
            continue

        print(f"running with specs: {algo}, {FLAGS.use_memory}, {FLAGS.processor_type}, {FLAGS.memory_size}")
        # absl's app.run(main) finishes with sys.exit(main(argv)), which would
        # end the whole sweep after the first run. Flags are already parsed
        # above, so call main directly with the argv shape app.run passes
        # (the program name only).
        main(sys.argv[:1])