### Modelling sequence learning

As in Whittington et al. (2020), we model the spatial task of predicting the next location in a trajectory as the prediction of the next node in a graph. We create a large set of graphs, each one an n-by-n grid of nodes (n = 3 in this notebook) representing a simple spatial environment. Nodes are labelled with random letters to represent arbitrary associations at a particular location. Each directed edge, i.e. each possible transition in the graph, is of the type north, south, east, or west. Random walks on these graphs are used to train the model; they could represent sequences stored in an initial bank of memories. The generative model is trained from scratch on these replayed sequences (converted to strings of the form ‘node1 E node2 W node3 …’) with a causal language modelling objective.

This is a Colab version of the original notebook; to run this outside of Colab, please use the version in the shapes_vae directory.

This code will only work with access to a GPU. To switch to GPU on Colab, go to 'Runtime' > 'Change runtime type', and select 'GPU' from the 'Hardware accelerator' dropdown menu. (Note that the free version of Colab will only allow this for one notebook at a time.)

#### Colab installation:

Make sure you click 'Restart runtime' after running this cell.

In [None]:
!pip install simpletransformers csrgraph networkx

#### Imports:

In [None]:
import pandas as pd
import networkx as nx
import logging
from random import shuffle
import pandas as pd
from matplotlib import pyplot as plt
import csrgraph as cg
import numpy as np
import random
import string
from simpletransformers.language_modeling import (
    LanguageModelingModel,
    LanguageModelingArgs,
)
from simpletransformers.language_generation import (
    LanguageGenerationModel, 
    LanguageGenerationArgs,
)

#### Prepare training data

The function below takes a list of node names (which could represent arbitrary characteristics of points in space), laid out row by row, and constructs a directed graph in the shape of a 3x3 grid. Every pair of neighbouring nodes is joined by an edge in each direction, and each transition / edge is labelled north, south, east or west (i.e. N, S, E or W).

In [None]:
def get_graph(nodes=("a", "b", "c", "d", "e", "f", "g", "h", "i")):
    """Build a directed square-grid graph whose nodes carry the given labels.

    The labels in *nodes* are laid out row by row (row-major) on an n-by-n
    grid, where n * n == len(nodes). Every pair of horizontally adjacent
    nodes is joined by an 'E' edge (left to right) and a 'W' edge (right to
    left). Every vertically adjacent pair is joined by an 'S' edge (top to
    bottom) and an 'N' edge (bottom to top). The default labels give the 3x3
    grid::

        a b c
        d e f
        g h i

    Args:
        nodes: sequence of n * n distinct hashable labels.

    Returns:
        nx.DiGraph in which every edge has a 'direction' attribute
        ('N', 'S', 'E' or 'W').

    Raises:
        ValueError: if len(nodes) is not a perfect square or labels repeat.
    """
    from math import isqrt

    nodes = list(nodes)
    side = isqrt(len(nodes))
    if side * side != len(nodes):
        raise ValueError(
            f"expected a square number of nodes, got {len(nodes)}")
    if len(set(nodes)) != len(nodes):
        # Repeated labels would silently merge grid cells into one node.
        raise ValueError("node labels must be distinct")

    def at(row, col):
        return nodes[row * side + col]

    # Row-major east pairs and column-major south pairs; for a 3x3 grid this
    # reproduces exactly the pair lists of the original hard-coded version.
    east_pairs = [(at(r, c), at(r, c + 1))
                  for r in range(side) for c in range(side - 1)]
    south_pairs = [(at(r, c), at(r + 1, c))
                   for c in range(side) for r in range(side - 1)]
    north_pairs = [(v, u) for u, v in south_pairs]
    west_pairs = [(v, u) for u, v in east_pairs]

    G = nx.DiGraph()
    for n in nodes:
        G.add_node(n)

    # Adding the edges in the order E, N, W, S keeps each node's successor
    # order identical to the original implementation.
    for pairs, direction in ((east_pairs, 'E'), (north_pairs, 'N'),
                             (west_pairs, 'W'), (south_pairs, 'S')):
        for u, v in pairs:
            G.add_edge(u, v, direction=direction)

    return G

Simple function to plot the graph:

In [None]:
def plot_G(G):
    """Draw *G* with its node labels and show the figure.

    A spring layout with a fixed seed is used, so repeated calls on the
    same graph produce the same picture.
    """
    figure, axis = plt.subplots(1)
    layout = nx.spring_layout(G, iterations=100, seed=39775)
    nx.draw(G, layout, ax=axis, font_size=8, with_labels=True)
    figure.tight_layout()
    plt.show()

# Example: plot_G(get_graph())

Function to get random walks of length 50 from a given graph:

In [None]:
def get_random_walks(G, walklen=50, epochs=10, threads=12):
    """Sample random walks on *G* with csrgraph and return them as node labels.

    Args:
        G: graph to walk on (e.g. the result of get_graph).
        walklen: length of each walk, forwarded to csrgraph (default 50).
        epochs: forwarded to csrgraph's random_walks (default 10);
            presumably the number of walks started from each node -- confirm
            against the csrgraph documentation.
        threads: number of threads csrgraph may use (default 12).

    Returns:
        numpy array of node labels with the same shape as csrgraph's walk
        array (iterated row by row as individual walks elsewhere in this
        notebook).
    """
    csr_G = cg.csrgraph(G, threads=threads)
    node_names = csr_G.names
    walks = csr_G.random_walks(walklen=walklen,
                               epochs=epochs,
                               start_nodes=None,
                               # Both weights are 1, i.e. unbiased walks
                               # (assuming node2vec-style p/q semantics).
                               return_weight=1.,
                               neighbor_weight=1.)

    # csrgraph returns integer node indices; map them back to G's labels.
    walks = np.vectorize(lambda x: node_names[x])(walks)
    return walks

The cell below:
* Defines a function to convert a random walk into a string (e.g. 'node1 E node2 W node3 ...')
* Defines a function to pull all this together: it randomly selects sets of 9 distinct letters as arbitrary node names, creates a 3x3 grid graph for each set, gets random walks in each graph, and converts them to strings
* Runs this final function to gather the training data

In [None]:
def walk_to_string(walk, G):
    """Render a walk as alternating node labels and edge directions.

    For example, the walk ['a', 'b', 'e'] on the default grid from get_graph
    becomes 'a E b S e'.

    Args:
        walk: sequence of node labels (a list, tuple or 1-D numpy array).
        G: graph whose edges carry a 'direction' attribute, indexable as
            G.edges[(u, v)].

    Returns:
        The space-separated string; '' for an empty walk. Every label is
        converted with str(), including the last one (previously the last
        label was concatenated raw, which failed for non-string labels).

    Raises:
        KeyError: if two consecutive nodes are not joined by an edge in G.
    """
    if len(walk) == 0:
        return ""
    tokens = []
    for node1, node2 in zip(walk, walk[1:]):
        tokens.append(str(node1))
        tokens.append(str(G.edges[(node1, node2)]['direction']))
    tokens.append(str(walk[-1]))
    return " ".join(tokens)

def get_walks_as_strings():
    """Generate the random-walk corpus over many randomly labelled 3x3 grids.

    1000 candidate label sets are drawn, each made of 9 lowercase letters
    chosen with replacement. Only sets whose 9 letters are all different are
    kept. Each surviving set labels one grid, and every random walk on that
    grid is rendered with walk_to_string.

    Returns:
        List of walk strings, grouped by graph in generation order.
    """
    alphabet = string.ascii_lowercase  # same 26 letters as ascii_letters[0:26]

    # Draw every candidate before building any graph (same draw order as before).
    candidates = []
    for _ in range(1000):
        candidates.append([random.choice(alphabet) for _ in range(9)])
    distinct_label_sets = [labels for labels in candidates
                           if len(set(labels)) == 9]

    corpus = []
    for labels in distinct_label_sets:
        grid = get_graph(nodes=labels)
        corpus += [walk_to_string(walk, grid) for walk in get_random_walks(grid)]
    return corpus

# NOTE(review): this corpus is not referenced again in this notebook -- the
# training cell below calls get_walks_as_strings() afresh for the train and
# test files. The call still consumes `random` state, so removing it would
# change which graphs the later cells sample.
walks_as_strings = get_walks_as_strings()

#### Train generative model

Train GPT-2 from scratch on random-walk strings generated as above (the cell below draws fresh walks for separate training and test files).

In [None]:
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

model_args = LanguageModelingArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.num_train_epochs = 10
model_args.dataset_type = "simple"
model_args.save_model_every_epoch = False
model_args.evaluate_during_training = True
model_args.mlm = False  # mlm must be False for CLM
model_args.learning_rate = 1e-5
model_args.vocab_size = 100
model_args.use_early_stopping = True
model_args.manual_seed = 123

train_file = "train.txt"
test_file = "test.txt"

# Each split comes from its own call to get_walks_as_strings(), i.e. from
# freshly sampled random graphs; the train split is generated first, as before.
# The context manager guarantees each file is flushed and closed.
for path, num_walks in ((train_file, 10000), (test_file, 1000)):
    walks = get_walks_as_strings()[0:num_walks]
    shuffle(walks)
    with open(path, "w", encoding="utf-8") as text_file:
        text_file.write('\n'.join(walks))

# model_name is None, so the model is trained from scratch; train_files is
# presumably used to train the new tokenizer (confirm in simpletransformers docs).
model = LanguageModelingModel(
    "gpt2", None, train_files=train_file, args=model_args
)

# Train the model
model.train_model(train_file, eval_file=test_file)

# Evaluate the model
result = model.eval_model(test_file)

#### Load trained model for sequence generation

In [None]:
# Greedy (non-sampled) generation from the model saved to "outputs" by training.
generation_args = {'do_sample': False, 'evaluate_generated_text': True}
model = LanguageGenerationModel("gpt2", "outputs", args=generation_args)

In [None]:
# On the default grid this prompt traces the closed loop a -> b -> e -> d,
# so the structurally correct continuation is "a".
prompt = "a E b S e W d N"
model.generate(prompt)

In any spatial environment, going N, E, S, W by one unit each takes you back to your starting point. Can the model perform structural inference to predict the next node in this way?

Let's start by specifying the set of 4-transition cycles / loops in the graph:

In [None]:
# The eight closed 4-step loop templates ("{}" is a node label): the four
# rotations of the clockwise E-S-W-N circuit, followed by the four rotations
# of the anticlockwise E-N-W-S circuit.
loops = [
    "{} E {} S {} W {} N",
    "{} S {} W {} N {} E",
    "{} W {} N {} E {} S",
    "{} N {} E {} S {} W",
    "{} E {} N {} W {} S",
    "{} N {} W {} S {} E",
    "{} W {} S {} E {} N",
    "{} S {} E {} N {} W",
]

In [None]:
def test_loop():
    """Probe structural inference on one random closed 4-step loop.

    Four distinct random lowercase letters are inserted into a randomly chosen
    template from the global `loops`, e.g. 'q E x S k W m N'. Because the
    directions form a closed loop, the correct next node is the first one.
    The prompt is passed to the global `model`. The first whitespace-separated
    token generated after the prompt is compared with that start node.

    Returns:
        1 if the predicted node equals the start node, otherwise 0.
    """
    # Distinct labels: a real grid never reuses a label, and the previous
    # with-replacement draw could produce prompts such as 'a E a S ...'.
    random_nodes = random.sample(string.ascii_lowercase, 4)
    test_string = random.choice(loops).format(*random_nodes)
    generated = model.generate(test_string)[0]
    # As the original slicing assumed, the generated text begins with the
    # prompt; look only at the continuation.
    continuation = generated[len(test_string):].split()
    if continuation and continuation[0] == random_nodes[0]:
        return 1
    return 0

results = [test_loop() for _ in range(100)]

In what percentage of trials was the next node correct?

In [None]:
# Percentage of trials with a correct prediction (normalised by the number of
# trials, so it stays correct if the trial count above is changed from 100).
100 * results.count(1) / len(results)

#### A more challenging test

For an arbitrary loop in the graph, can the model predict the final item?

Here we define a function to get all simple cycles (loops in the graph that visit each node at most once). In each case we then test whether the GPT-2 model can infer the final node, i.e. the node that closes the loop, which is the starting node.

In [None]:
def get_cycles_for_graph(G):
    """Return every simple cycle of *G* rendered as a prompt string.

    A cycle [n0, n1, ..., nk] becomes 'n0 D0 n1 D1 ... nk Dk ', where Di is
    the 'direction' attribute of the edge from ni to the next node (wrapping
    from nk back to n0). The trailing space is deliberate: the next thing the
    model should produce is the start node n0 that closes the cycle.
    """
    prompts = []
    for cycle in nx.simple_cycles(G):
        parts = []
        for idx, node in enumerate(cycle):
            successor = cycle[(idx + 1) % len(cycle)]
            direction = G.get_edge_data(node, successor)['direction']
            parts.append('{} {} '.format(node, direction))
        prompts.append("".join(parts))
    return prompts

def test_loop(num_graphs = 5):
    """Test next-node inference on every simple cycle of several random grids.

    Args:
        num_graphs: number of randomly labelled 3x3 grids to test.

    Returns:
        (results, lens): results[i] is 1 if the first token the global
        `model` generates after cycle prompt i is the cycle's start node,
        else 0; lens[i] is that cycle's number of transitions.
    """
    results = []
    lens = []

    for _ in range(num_graphs):
        # Draw 9 distinct letters directly. The previous rejection sampling
        # kept the first of 100 with-replacement draws that happened to be
        # distinct, and raised IndexError if none were.
        nodes = random.sample(string.ascii_lowercase, 9)
        G = get_graph(nodes=nodes)

        for test_string in get_cycles_for_graph(G):
            tokens = test_string.split()
            # Tokens alternate node/direction, so transitions = tokens // 2.
            # This stays correct for labels longer than one character.
            lens.append(len(tokens) // 2)
            generated = model.generate(test_string)[0]
            continuation = generated[len(test_string):].split()
            if continuation and continuation[0] == tokens[0]:
                results.append(1)
            else:
                results.append(0)

    return results, lens

results, lens = test_loop()

#### Plot structural inference accuracy against graph cycle length

In [None]:
def acc_for_len(length, outcomes=None, cycle_lens=None):
    """Return the fraction of correct predictions among cycles of *length*.

    Args:
        length: number of transitions in the cycles to select.
        outcomes: per-cycle 1/0 correctness values; defaults to the global
            `results`.
        cycle_lens: per-cycle lengths aligned with *outcomes*; defaults to
            the global `lens`.

    Returns:
        Accuracy in [0, 1], or NaN when no cycle of that length was tested
        (previously this raised ZeroDivisionError).
    """
    if outcomes is None:
        outcomes = results
    if cycle_lens is None:
        cycle_lens = lens
    accs = [r for r, cycle_len in zip(outcomes, cycle_lens) if cycle_len == length]
    if not accs:
        return float("nan")
    return accs.count(1) / len(accs)

# A grid graph is bipartite, so every cycle has an even number of
# transitions, and a 3x3 grid has at most 9 nodes: 2, 4, 6 and 8 cover them.
lengths = [2, 4, 6, 8]
accuracies = list(map(acc_for_len, lengths))

ax = plt.gca()
ax.bar(lengths, accuracies)
ax.set_title('Next node inference accuracy')
ax.set_ylabel('Accuracy')
ax.set_xlabel('Number of transitions')
plt.tight_layout()
plt.savefig('graph_cycle_length.png')

In [None]:
plt.rcParams.update({'font.size' : 15})

# Evaluation loss logged by simpletransformers during training.
df = pd.read_csv('outputs/training_progress_scores.csv')
# NOTE(review): only the first 7 logged evaluations are plotted; this
# cut-off is hard-coded -- confirm it is intentional.
df = df.iloc[0:7]

# Plot into an explicit Axes. DataFrame.plot opens its own new figure when
# no `ax` is given, so the previous bare plt.figure() call only left an extra
# empty figure behind.
fig, ax = plt.subplots()
df.plot(x='global_step', y='eval_loss', title='Loss over time',
        ylabel='Loss on test set', xlabel='Training step', legend=False,
        ax=ax)

fig.tight_layout()
fig.savefig('graph-gpt.png')