### Modelling relational inference

As in Whittington et al. (2020), we model the spatial task of predicting the next location in a trajectory as predicting the next node in a graph. We create a large set of graphs, each an n-by-n grid of nodes representing a simple spatial environment. Nodes are labelled with random letters to represent arbitrary associations at particular locations. Each directed edge, i.e. each possible transition in the graph, has one of four types: north, south, east, or west. The model is trained on random walks in these graphs, which could represent sequences stored in an initial bank of memories. The generative model is trained from scratch on the replayed sequences (converted to strings of the form ‘node1 E node2 W node3 …’) using causal language modelling.
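To make the data format concrete, here is a minimal standard-library sketch of how a labelled grid and a random-walk string of this form could be produced. The names `make_grid`, `random_walk_string` and `DIRECTIONS`, the 3-by-3 default, and the row/column direction convention are all assumptions for illustration; the notebook itself relies on the helpers in `graph_utils`, which may work differently.

```python
import random
import string

# Hypothetical direction convention: (row offset, column offset).
DIRECTIONS = {"N": (-1, 0), "S": (1, 0), "E": (0, 1), "W": (0, -1)}

def make_grid(n=3, rng=random):
    """Return {(row, col): label} for an n-by-n grid with distinct random letters."""
    labels = rng.sample(string.ascii_lowercase, n * n)
    cells = [(r, c) for r in range(n) for c in range(n)]
    return dict(zip(cells, labels))

def random_walk_string(grid, n_steps=10, rng=random):
    """Walk the grid at random and serialise it as 'node1 E node2 W node3 ...'."""
    pos = rng.choice(list(grid))
    tokens = [grid[pos]]
    for _ in range(n_steps):
        # Keep only the moves that stay inside the grid.
        moves = [(d, (pos[0] + dr, pos[1] + dc))
                 for d, (dr, dc) in DIRECTIONS.items()
                 if (pos[0] + dr, pos[1] + dc) in grid]
        direction, pos = rng.choice(moves)
        tokens += [direction, grid[pos]]
    return " ".join(tokens)

grid = make_grid(3)
walk = random_walk_string(grid, n_steps=5)
print(walk)
```

Each walk alternates node labels and direction tokens, so a walk with five transitions contains eleven whitespace-separated tokens.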

Tested with the conda_pytorch_latest_p36 kernel in AWS SageMaker.

#### Installation:

In [None]:
!pip install simpletransformers csrgraph networkx==2.8

#### Imports:

In [None]:
import logging
import random
import string
from random import shuffle

import csrgraph as cg
import networkx as nx
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from graph_utils import *
from gpt import GPT

#### Train generative model

Write random walks to train and test files, then train GPT-2 from scratch on them.

In [None]:
# Write the shuffled walk strings, newline-separated, to the training file.
walks = get_walks_as_strings(n_graphs=20000, n_walks=1)
shuffle(walks)
with open("train.txt", "w") as text_file:
    text_file.write('\n'.join(walks))

# Repeat with a smaller, freshly generated set for the test file.
walks = get_walks_as_strings(n_graphs=1000, n_walks=1)
shuffle(walks)
with open("test.txt", "w") as text_file:
    text_file.write('\n'.join(walks))

gpt = GPT(vocab_size=100)
gpt.train(segmented_sequence_list=[], best_model_dir='outputs_graph', train_file="train.txt", test_file="test.txt", eps=10)

#### Load trained model for sequence generation

In [None]:
model = GPT(base_model='outputs_graph', base_model_name='gpt2')

In [None]:
model.continue_input("a E b S e W d N", do_sample=False)

In [None]:
loops = ["{} E {} S {} W {} N", "{} S {} W {} N {} E", "{} W {} N {} E {} S", "{} N {} E {} S {} W",
        "{} E {} N {} W {} S", "{} N {} W {} S {} E", "{} W {} S {} E {} N", "{} S {} E {} N {} W"]

In [None]:
def test_loop():
    """Return 1 if the model closes a random four-step loop on its start node, else 0."""
    random_nodes = random.sample(string.ascii_lowercase, 4)
    loop = random.choice(loops)
    test_string = loop.format(*random_nodes)
    output = model.continue_input(test_string, do_sample=False)
    # Keep the prompt plus the next two characters (a space and the predicted node).
    output = output[0:len(test_string)+2]
    print(output)
    return int(output[-1] == output[0])

results = [test_loop() for _ in range(100)]

In what percentage of trials was the next node correct? (With 100 trials, the count below equals the percentage.)

In [None]:
results.count(1)
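The check in `test_loop` compares single characters at fixed offsets, which works because node labels are one letter long and the output starts with the prompt. A token-based variant is sketched below as one possible alternative. The helper name `closes_loop` is hypothetical, and it assumes, as the slicing above does, that the generated string begins with the prompt.

```python
def closes_loop(prompt, output):
    """Return True if the first token generated after `prompt` equals the prompt's first node."""
    # Assumption: the model output echoes the prompt before continuing it.
    if not output.startswith(prompt):
        return False
    continuation = output[len(prompt):].split()
    return bool(continuation) and continuation[0] == prompt.split()[0]

print(closes_loop("a E b S e W d N", "a E b S e W d N a E b"))  # True
print(closes_loop("a E b S e W d N", "a E b S e W d N c"))      # False
```

Comparing tokens rather than characters would also tolerate multi-character node labels and prompts that end with a trailing space.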

#### A more challenging test

For an arbitrary loop in the graph, can the model predict the final item?
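A prompt has a well-defined answer only if its transitions lead back to the starting position, i.e. the direction tokens sum to zero net displacement on the grid. The sketch below checks this for a prompt string. `STEP` and `returns_to_start` are hypothetical names, and the coordinate convention is an assumption; any convention in which N/S and E/W are opposites gives the same result.

```python
# Hypothetical coordinate convention: (x offset, y offset) per direction.
STEP = {"N": (0, 1), "S": (0, -1), "E": (1, 0), "W": (-1, 0)}

def returns_to_start(path):
    """Return True if the directions in 'node D node D ...' sum to zero displacement."""
    directions = path.split()[1::2]  # every second token is a direction
    x = sum(STEP[d][0] for d in directions)
    y = sum(STEP[d][1] for d in directions)
    return (x, y) == (0, 0)

print(returns_to_start("a E b S e W d N"))  # True
print(returns_to_start("a E b E c W"))      # False
```

Strings built from simple cycles of the graph satisfy this by construction, so the node that follows the final direction is always the first node of the string.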

In [None]:
def get_cycles_for_graph(G):
    """Return each simple cycle of G as a 'node direction node direction ' string."""
    loops = []
    for c in nx.simple_cycles(G):
        path_string = ""
        for ind, node in enumerate(c):
            # Wrap around so the last node's edge leads back to the first node.
            next_node = c[(ind + 1) % len(c)]
            direction = G.get_edge_data(node, next_node)['direction']
            path_string += '{} {} '.format(node, direction)
        loops.append(path_string)
    return loops

def test_loop(num_graphs=5):
    """Score next-node prediction on every simple cycle of `num_graphs` random graphs."""
    results = []
    lens = []

    for _ in range(num_graphs):
        # Nine distinct letters label the nodes of one graph.
        nodes = random.sample(string.ascii_lowercase, 9)
        G = get_graph(nodes=nodes)
        test_strings = get_cycles_for_graph(G)

        for test_string in test_strings:
            # Each 'node direction ' pair is four characters, so this is the number of transitions.
            lens.append(len(test_string) / 4)
            output = model.continue_input(test_string)
            # The prompt ends with a space, so one extra character is the predicted node.
            output = output[0:len(test_string)+1]
            results.append(int(output[-1] == output[0]))

    return results, lens

results, lens = test_loop()

In [None]:
results, lens

#### Plot structural inference accuracy against graph cycle length

In [None]:
def acc_for_len(length):
    accs = [r for ind, r in enumerate(results) if lens[ind] == length]
    return accs.count(1) / len(accs)

lengths = [2, 4, 6, 8]
accuracies = [acc_for_len(i) for i in lengths]

plt.bar(lengths, accuracies)
plt.title('Next node inference accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Number of transitions')
plt.tight_layout()
plt.savefig('graph_cycle_length.png')
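`acc_for_len` divides by the number of cycles of a given length, so it raises `ZeroDivisionError` if a requested length never occurs. As a hedged alternative, the sketch below groups the scores by whatever lengths are actually present. `accuracy_by_length` is a hypothetical helper, not part of the notebook's modules, and the example inputs are made up purely to show the shape of the result.

```python
from collections import defaultdict

def accuracy_by_length(results, lens):
    """Map each observed cycle length to the mean of its 0/1 scores."""
    buckets = defaultdict(list)
    for score, length in zip(results, lens):
        buckets[int(length)].append(score)
    return {length: sum(scores) / len(scores)
            for length, scores in sorted(buckets.items())}

# Illustrative inputs only, not model results.
example = accuracy_by_length([1, 0, 1, 1], [2.0, 2.0, 4.0, 4.0])
print(example)  # {2: 0.5, 4: 1.0}
```

The keys and values of the returned dictionary can be passed to `plt.bar` in place of `lengths` and `accuracies`.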

In [None]:
plt.rcParams.update({'font.size' : 15})

# DataFrame.plot creates its own figure, so no separate plt.figure() call is needed.
df = pd.read_csv('outputs_graph/training_progress_scores.csv')
df = df.iloc[0:7]
df.plot(x='global_step', y='eval_loss', title='Loss over time',
        ylabel='Loss on test set', xlabel='Training step', legend=False)

plt.tight_layout()
plt.savefig('graph-gpt.png')