### Modelling structural inference

As in Whittington et al. (2020), we model the spatial task of predicting the next location in a trajectory as the prediction of the next node in a graph. We create a large set of graphs, each one an n-by-n grid of nodes representing a simple spatial environment. Nodes are labelled with random letters to represent arbitrary associations at a particular location. Each directed edge, i.e. each possible transition in the graph, has one of four types: north, south, east, or west. Random walks on these graphs are used to train the model; these could represent sequences stored in an initial bank of memories. The generative model is trained from scratch on the replayed sequences (converted to strings of the form ‘node1 E node2 W node3 …’) using causal language modelling, i.e. next-token prediction.

#### Imports:

In [None]:
import sys
sys.path.append('../scripts')

import pandas as pd
import networkx as nx
import logging
from random import shuffle
import pandas as pd
from matplotlib import pyplot as plt
import csrgraph as cg
import numpy as np
import random
import string
from graph_utils import *
from tree_utils import *
from itertools import combinations
import pickle
import gc
import os
from transformers import GPT2LMHeadModel, GPT2Tokenizer

os.environ['WANDB_MODE'] = 'disabled'

In [None]:
class GPT:
    """Small wrapper around a GPT-2 checkpoint used to continue text prompts.

    When ``base_model`` is given, the tokenizer and language-model head are
    loaded from it with ``from_pretrained`` and the tokenizer's pad token is
    set to its EOS token. When ``base_model`` is ``None`` nothing is loaded,
    so ``tokenizer`` and ``model`` must be assigned before calling
    ``continue_input``.

    ``base_model_name`` and ``vocab_size`` are stored on the instance but are
    not used anywhere else in this class.
    """

    def __init__(self, base_model=None, base_model_name='gpt2', vocab_size=100):
        self.base_model = base_model
        self.base_model_name = base_model_name
        self.vocab_size = vocab_size

        if self.base_model is not None:
            self.tokenizer = GPT2Tokenizer.from_pretrained(base_model)
            self.model = GPT2LMHeadModel.from_pretrained(base_model)
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def continue_input(self, input_sequence, max_new_tokens=5, num_return_sequences=1, no_repeat_ngram_size=0,
                       do_sample=False, temperature=0.7, num_beams=1):
        """Generate a continuation of ``input_sequence`` and return it as text.

        Args:
            input_sequence: Prompt string to encode and continue.
            max_new_tokens: Upper bound on the number of generated tokens.
            num_return_sequences: Forwarded to ``generate``; note that only
                the first returned sequence is decoded and returned.
            no_repeat_ngram_size: Forwarded to ``generate``.
            do_sample: Sample from the model instead of decoding greedily.
            temperature: Sampling temperature; only forwarded when
                ``do_sample`` is true.
            num_beams: Forwarded to ``generate``.

        Returns:
            The decoded first output sequence (prompt plus continuation).
        """
        input_ids = self.tokenizer.encode(input_sequence, return_tensors='pt')

        generate_kwargs = {
            'max_new_tokens': max_new_tokens,
            'num_return_sequences': num_return_sequences,
            'num_beams': num_beams,
            'no_repeat_ngram_size': no_repeat_ngram_size,
            'do_sample': do_sample,
        }
        # Temperature only has an effect when sampling; passing it alongside
        # do_sample=False makes recent transformers versions warn about an
        # ignored generation flag.
        if do_sample:
            generate_kwargs['temperature'] = temperature

        # Generate text
        output = self.model.generate(input_ids, **generate_kwargs)

        # Decode the first returned sequence
        sequence = output[0].tolist()
        return self.tokenizer.decode(sequence)

In [None]:
def load_pkl(pth):
    """Read the file at ``pth`` in binary mode and return the unpickled object.

    NOTE(review): unpickling can execute arbitrary code, so only use this on
    files produced by a trusted source.
    """
    with open(pth, mode='rb') as handle:
        return pickle.load(handle)

def is_valid_path(sequence, graphs):
    """Return True if at least one graph contains every hop of ``sequence``.

    ``sequence`` is a whitespace-separated string that alternates node and
    edge tokens ('node edge node edge ...'). Only the node tokens are used:
    each consecutive pair of nodes must satisfy ``graph.has_edge(u, v)``.
    The edge (direction) tokens are NOT compared against the graph.

    Args:
        sequence: Path string with nodes at the even token positions.
        graphs: Iterable of objects exposing ``has_edge(u, v)``.

    Returns:
        True as soon as one graph contains all hops, otherwise False. A
        sequence with fewer than two nodes has no hops and is therefore
        valid for any non-empty ``graphs``; an empty ``graphs`` gives False.
    """
    # Nodes are at even indices; the odd-index edge labels are ignored.
    nodes = sequence.split()[::2]
    hops = list(zip(nodes, nodes[1:]))

    return any(
        all(graph.has_edge(src, dst) for src, dst in hops)
        for graph in graphs
    )

In [None]:
def train_model_script(num_epochs=3,
                       output_dir='outputs',
                       lr=5e-05):
    """Launch ``../scripts/run_clm_from_scratch-v1.py`` to train and evaluate a model.

    The script is started through an IPython ``!`` shell escape, so this
    function only works when the cell is run inside IPython/Jupyter.

    Args:
        num_epochs: Value passed as ``--num_train_epochs``.
        output_dir: Directory that must already contain ``train.txt`` and
            ``test.txt``; it is also passed as ``--output_dir`` together with
            ``--overwrite_output_dir``.
        lr: Value passed as ``--learning_rate``.

    The remaining flags are fixed: model type 'gpt2-medium', tokenizer
    'gpt2', train and eval batch size 1, and ``--save_strategy 'epoch'``.
    """
    # Run a garbage-collection pass before starting the training subprocess.
    gc.collect()
    train_path = f'./{output_dir}/train.txt'
    test_path = f'./{output_dir}/test.txt'
    ! python3 ../scripts/run_clm_from_scratch-v1.py \
        --model_type 'gpt2-medium' \
        --tokenizer_name 'gpt2' \
        --train_file {train_path} \
        --validation_file {test_path} \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --do_train \
        --do_eval \
        --output_dir {output_dir} \
        --overwrite_output_dir \
        --num_train_epochs {num_epochs} \
        --save_strategy 'epoch' \
        --learning_rate {lr}

In [None]:
# Equivalent code with updated script:
# (Original code above used to reproduce paper results)

# def train_model_script(num_epochs=3,
#                        output_dir='outputs',
#                        lr=5e-05):
#     gc.collect()
#     train_path = f'./{output_dir}/train.txt'
#     test_path = f'./{output_dir}/test.txt'
#     ! python3 ../scripts/run_clm_from_scratch.py \
#         --model_type 'gpt2' \
#         --tokenizer_name 'openai-community/gpt2-medium' \
#         --config_name 'openai-community/gpt2-medium' \
#         --train_file {train_path} \
#         --validation_file {test_path} \
#         --per_device_train_batch_size 1 \
#         --per_device_eval_batch_size 1 \
#         --do_train \
#         --do_eval \
#         --output_dir {output_dir} \
#         --overwrite_output_dir \
#         --num_train_epochs {num_epochs} \
#         --save_strategy 'epoch' \
#         --learning_rate {lr}

### Spatial graph

In [None]:
!rm -rf outputs_graph
!mkdir outputs_graph

text_file = open("outputs_graph/train.txt", "w")
walks, train_gs = get_walks_as_strings(n_graphs=50000, n_walks=5, walk_length=50)
shuffle(walks)
n = text_file.write('\n'.join(walks))
text_file.close()

text_file = open("outputs_graph/test.txt", "w")
walks, test_gs = get_walks_as_strings(n_graphs=500, n_walks=1, walk_length=50)
shuffle(walks)
n = text_file.write('\n'.join(walks))
text_file.close()

train_model_script(num_epochs=5,
                   output_dir='outputs_graph',
                   lr=1e-05)


In [None]:
# Pickle the graphs behind the train and test walks next to the model outputs.
for pkl_path, graph_set in (('outputs_graph/train_graphs.pkl', train_gs),
                            ('outputs_graph/test_graphs.pkl', test_gs)):
    with open(pkl_path, 'wb') as handle:
        pickle.dump(graph_set, handle)

### Family tree graph

In [None]:
!rm -rf outputs_tree
!mkdir outputs_tree

text_file = open("outputs_tree/train.txt", "w")
walks, train_gs = get_walks_for_n_trees(n_graphs=50000, n_walks=5, walk_length=50)
shuffle(walks)
n = text_file.write('\n'.join(walks))
text_file.close()

text_file = open("outputs_tree/test.txt", "w")
walks, test_gs = get_walks_for_n_trees(n_graphs=500, n_walks=1, walk_length=50)
shuffle(walks)
n = text_file.write('\n'.join(walks))
text_file.close()

train_model_script(num_epochs=5,
                   output_dir='outputs_tree',
                   lr=5e-05)


In [None]:
# Pickle the trees behind the train and test walks next to the model outputs.
for pkl_path, tree_set in (('outputs_tree/train_trees.pkl', train_gs),
                           ('outputs_tree/test_trees.pkl', test_gs)):
    with open(pkl_path, 'wb') as handle:
        pickle.dump(tree_set, handle)

### Test trained models

Provide paths to models to test:

In [None]:
# Checkpoint directories, passed as `base_model` to GPT(...) in the cells below.
# NOTE(review): these are absolute paths on one specific machine; edit them before running elsewhere.
FAMILY_MODEL_PATH = '/Users/eleanorspens/Documents/PhD Code/clean-sequence-paper/models/familygraph/'
SPATIAL_MODEL_PATH = '/Users/eleanorspens/Documents/PhD Code/clean-sequence-paper/models/spatialgraph/'

#### Test loop inferences

In [None]:
def generate_name() -> str:
    """Generate a random 2-letter name."""
    return ''.join(random.choices(string.ascii_lowercase, k=2))


def test_loop(model, loop_templates, n_trials=100, verbose=True, **generate_kwargs):
    """Measure how often ``model`` closes a loop by predicting the start node.

    Each template is a string such as ``"{} EAST {} WEST {}"``. For every
    trial the placeholders are filled with random names, the last name being
    a repeat of the first. The model is prompted with everything except that
    last name, and the trial counts as correct when the word it generates at
    that position equals the first name.

    Args:
        model: Object exposing ``continue_input(prompt, **kwargs)`` that
            returns the prompt followed by the generated continuation.
        loop_templates: Iterable of template strings, each containing at
            least two ``{}`` placeholders.
        n_trials: Number of random fillings evaluated per template.
        verbose: Print the filled template, raw prediction and outcome of
            every trial.
        **generate_kwargs: Overrides for the decoding settings passed to
            ``model.continue_input`` (defaults: ``max_new_tokens=5``,
            ``do_sample=False``).

    Returns:
        ``(overall_accuracy, results_dict)`` where ``results_dict`` maps each
        template to its mean accuracy over ``n_trials``.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1.
        ZeroDivisionError: If ``loop_templates`` is empty.
    """
    if n_trials < 1:
        raise ValueError("n_trials must be at least 1")

    # Greedy decoding by default; callers may override any setting.
    settings = {'max_new_tokens': 5, 'do_sample': False}
    settings.update(generate_kwargs)

    accuracy_scores = []  # Per-trial outcomes pooled over all templates
    results_dict = {}

    for template in loop_templates:
        template_accuracy = []  # Per-trial outcomes for the current template

        for _ in range(n_trials):
            # Fill the template with random names; the path ends where it started
            names = [generate_name() for _ in range(template.count("{}") - 1)]
            names.append(names[0])
            filled_template = template.format(*names)

            true_final_item = names[-1]
            input_len = len(filled_template.split())

            # Hide the final name by dropping the last word, so the prompt
            # does not depend on names being exactly two characters long.
            prompt = filled_template.rsplit(' ', 1)[0]
            prediction = model.continue_input(prompt, **settings)

            # The word generated at the hidden position is the model's answer
            predicted_items = prediction.strip().split()[0:input_len]
            predicted_final_item = predicted_items[-1] if predicted_items else None
            is_correct = int(predicted_final_item == true_final_item)

            if verbose:
                print(filled_template)
                print(prediction)
                print(f"True final:{true_final_item}, predicted final: {predicted_final_item}")
                print(is_correct)

            template_accuracy.append(is_correct)

        accuracy_scores.extend(template_accuracy)
        results_dict[template] = sum(template_accuracy) / len(template_accuracy)

    overall_avg_accuracy = sum(accuracy_scores) / len(accuracy_scores)
    return overall_avg_accuracy, results_dict


# Despite its name this is an evaluation helper, not a pytest test case.
test_loop.__test__ = False


In [None]:
# Spatial loop templates: each sequence of moves returns to its starting cell,
# so the last placeholder is filled with the same name as the first by test_loop.
# The list contains four 2-move, four 4-move and two 6-move loops.
loop_templates = ["{} EAST {} WEST {}",
                  "{} WEST {} EAST {}",
                  "{} NORTH {} SOUTH {}",
                  "{} SOUTH {} NORTH {}",
                  "{} EAST {} SOUTH {} WEST {} NORTH {}",
                  "{} SOUTH {} WEST {} NORTH {} EAST {}",
                  "{} WEST {} NORTH {} EAST {} SOUTH {}",
                  "{} NORTH {} EAST {} SOUTH {} WEST {}",
                  "{} EAST {} EAST {} NORTH {} WEST {} WEST {} SOUTH {}",
                  "{} NORTH {} NORTH {} WEST {} SOUTH {} SOUTH {} EAST {}"]

# Run the test with the spatial checkpoint
model = GPT(base_model=SPATIAL_MODEL_PATH, base_model_name='gpt2')
average_accuracy, spatial_results_dict = test_loop(model, loop_templates)
print(f"Average Accuracy: {average_accuracy}")

Saved results for data in paper:

In [None]:
# Hard-coded per-template accuracies for the spatial model, so the figure below
# can be drawn without re-running the model.
# NOTE: running this cell replaces any `spatial_results_dict` computed by test_loop above.
spatial_results_dict = {'{} EAST {} WEST {}': 0.8,
 '{} WEST {} EAST {}': 0.83,
 '{} NORTH {} SOUTH {}': 0.85,
 '{} SOUTH {} NORTH {}': 0.93,
 '{} EAST {} SOUTH {} WEST {} NORTH {}': 0.83,
 '{} SOUTH {} WEST {} NORTH {} EAST {}': 0.84,
 '{} WEST {} NORTH {} EAST {} SOUTH {}': 0.84,
 '{} NORTH {} EAST {} SOUTH {} WEST {}': 0.79,
 '{} EAST {} EAST {} NORTH {} WEST {} WEST {} SOUTH {}': 0.84,
 '{} NORTH {} NORTH {} WEST {} SOUTH {} SOUTH {} EAST {}': 0.82}

In [None]:
# Family-tree loop templates: each chain of relations is meant to lead back to
# the first person, so test_loop fills the last placeholder with the first name.
# The list contains four 2-relation, four 4-relation and two 6-relation chains.
loop_templates = ["{} CHILD_OF {} PARENT_OF {}",
                  "{} PARENT_OF {} CHILD_OF {}",
                  "{} GRANDCHILD_OF {} GRANDPARENT_OF {}",
                  "{} GRANDPARENT_OF {} GRANDCHILD_OF {}",
                  "{} CHILD_OF {} CHILD_OF {} GRANDPARENT_OF {} SIBLING_OF {}",
                  "{} CHILD_OF {} SPOUSE_OF {} PARENT_OF {} SIBLING_OF {}",
                  "{} PARENT_OF {} SIBLING_OF {} CHILD_OF {} SPOUSE_OF {}",
                  "{} PARENT_OF {} PARENT_OF {} GRANDCHILD_OF {} SPOUSE_OF {}",
                  "{} CHILD_OF {} SPOUSE_OF {} CHILD_OF {} SPOUSE_OF {} GRANDPARENT_OF {} SIBLING_OF {}",
                  "{} GRANDPARENT_OF {} SIBLING_OF {} CHILD_OF {} SPOUSE_OF {} CHILD_OF {} SPOUSE_OF {}"
                 ]

# Run the test with the family-tree checkpoint
model = GPT(base_model=FAMILY_MODEL_PATH, base_model_name='gpt2')
average_accuracy, family_results_dict = test_loop(model, loop_templates)
print(f"Average Accuracy: {average_accuracy}")

Saved results for data in paper:

In [None]:
# Hard-coded per-template accuracies for the family-tree model, so the figure
# below can be drawn without re-running the model.
# NOTE: running this cell replaces any `family_results_dict` computed by test_loop above.
family_results_dict = {'{} CHILD_OF {} PARENT_OF {}': 0.71,
 '{} PARENT_OF {} CHILD_OF {}': 0.74,
 '{} GRANDCHILD_OF {} GRANDPARENT_OF {}': 0.74,
 '{} GRANDPARENT_OF {} GRANDCHILD_OF {}': 0.7,
 '{} CHILD_OF {} CHILD_OF {} GRANDPARENT_OF {} SIBLING_OF {}': 0.71,
 '{} CHILD_OF {} SPOUSE_OF {} PARENT_OF {} SIBLING_OF {}': 0.7,
 '{} PARENT_OF {} SIBLING_OF {} CHILD_OF {} SPOUSE_OF {}': 0.75,
 '{} PARENT_OF {} PARENT_OF {} GRANDCHILD_OF {} SPOUSE_OF {}': 0.66,
 '{} CHILD_OF {} SPOUSE_OF {} CHILD_OF {} SPOUSE_OF {} GRANDPARENT_OF {} SIBLING_OF {}': 0.71,
 '{} GRANDPARENT_OF {} SIBLING_OF {} CHILD_OF {} SPOUSE_OF {} CHILD_OF {} SPOUSE_OF {}': 0.72}

In [None]:
def get_hops_count(key):
    """Return the number of transitions in a template string.

    A template alternates placeholders and transition labels
    ('{} EAST {} WEST {}'), so the count is half the number of
    whitespace-separated tokens, rounded down.
    """
    tokens = key.split()
    return len(tokens) // 2

# Group the saved per-template accuracies by task and by number of transitions,
# then plot the mean and standard deviation of each group as a bar chart.
combined_data = {'Family tree': family_results_dict, 'Spatial': spatial_results_dict}
averages = {}
std_devs = {}

# First pass: averages[hops][task] collects the raw accuracies of every
# template with that many transitions.
for task, data in combined_data.items():
    for pattern, accuracy in data.items():
        hops = get_hops_count(pattern)
        if hops not in averages:
            averages[hops] = {}
            std_devs[hops] = {}
        if task not in averages[hops]:
            averages[hops][task] = []
        averages[hops][task].append(accuracy)

# Second pass: replace each list by its mean and record its standard deviation
# (np.std with its default ddof=0, i.e. the population standard deviation).
for hops, tasks in averages.items():
    for task, accuracies in tasks.items():
        averages[hops][task] = np.mean(accuracies)
        std_devs[hops][task] = np.std(accuracies)

# Plotting
fig, ax = plt.subplots(figsize=(3.4, 2.8))  # Figure size in inches
tasks = list(combined_data.keys())
colors = ['blue', 'red']  # One colour per task, in the order of `tasks`
hops_labels = sorted(averages.keys())

x = np.arange(len(hops_labels))  # One x position per distinct hop count
bar_width = 0.35  # Width of each bar
offset = bar_width / 2

# Draw one bar per task at each hop count, placed side by side around x[i]
for i, hops in enumerate(hops_labels):
    positions = x[i] - offset * len(tasks) / 2
    for j, task in enumerate(tasks):
        avg = averages[hops].get(task, 0)
        std_dev = std_devs[hops].get(task, 0)
        bar_pos = positions + j * bar_width
        # Label only the first group so each task appears once in the legend
        ax.bar(bar_pos, avg, bar_width, label=task if i == 0 else "", color=colors[j], alpha=0.4,
               yerr=std_dev, capsize=3)

ax.set_xticks(x)
ax.set_xticklabels([f'{h} hops' for h in hops_labels])
ax.set_xlabel('Number of transitions')
ax.set_ylabel('Average accuracy')
ax.legend(loc='upper right', fontsize=9)
ax.set_ylim(0,1.12)

plt.tight_layout()
plt.savefig('aggregated_inf.png', dpi=300)
plt.show()


In [None]:
# `json` is not imported in the notebook's import cell, so import it here;
# without this the cell fails with a NameError unless a star import provides it.
import json

# Load the training log; the code below expects a "log_history" list of dicts.
with open('trainer_state_spatial.json', 'r') as file:
    trainer_state = json.load(file)

# Extract loss values for plotting (x values are the logged epoch numbers)
train_steps = []
train_loss = []
eval_steps = []
eval_loss = []

for entry in trainer_state["log_history"]:
    # Entries with "loss" are training logs; entries with "eval_loss" are evaluation logs
    if "loss" in entry:
        train_steps.append(entry["epoch"])
        train_loss.append(entry["loss"])
    if "eval_loss" in entry:
        eval_steps.append(entry["epoch"])
        eval_loss.append(entry["eval_loss"])

# Plotting the training and evaluation loss
plt.figure(figsize=(3, 3))
plt.plot(train_steps, train_loss, label='Train loss', marker='.', color='red', alpha=0.4)
plt.plot(eval_steps, eval_loss, label='Val loss', marker='.', color='blue', alpha=0.4)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('spatial_loss.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# `json` is not imported in the notebook's import cell, so import it here;
# without this the cell fails with a NameError unless a star import provides it.
import json

# Load the training log; the code below expects a "log_history" list of dicts.
with open('trainer_state_family.json', 'r') as file:
    trainer_state = json.load(file)

# Extract loss values for plotting (x values are the logged epoch numbers)
train_steps = []
train_loss = []
eval_steps = []
eval_loss = []

for entry in trainer_state["log_history"]:
    # Entries with "loss" are training logs; entries with "eval_loss" are evaluation logs
    if "loss" in entry:
        train_steps.append(entry["epoch"])
        train_loss.append(entry["loss"])
    if "eval_loss" in entry:
        eval_steps.append(entry["epoch"])
        eval_loss.append(entry["eval_loss"])

# Plotting the training and evaluation loss
plt.figure(figsize=(3, 3))
plt.plot(train_steps, train_loss, label='Train loss', marker='.', color='red', alpha=0.4)
plt.plot(eval_steps, eval_loss, label='Val loss', marker='.', color='blue', alpha=0.4)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig('family_loss.png', dpi=300, bbox_inches='tight')
plt.show()


#### Test imagination

In [None]:
def track_coordinates(walks):
    """Check that walks never place two different nodes on the same grid cell.

    Each walk is a string 'node DIRECTION node DIRECTION node ...'. Starting
    at (0, 0), every direction moves one cell (NORTH/SOUTH change y by +1/-1,
    EAST/WEST change x by +1/-1) and the node seen at each cell is recorded.

    Returns False when a cell is revisited with a different node, or when a
    direction token is not one of the four known directions; True otherwise.

    Note: the current position and the recorded cells are shared across all
    walks in ``walks``; a later walk continues from where the previous one
    ended rather than restarting at the origin.
    """
    moves = {'NORTH': (0, 1), 'SOUTH': (0, -1), 'EAST': (1, 0), 'WEST': (-1, 0)}
    seen = {}      # (x, y) -> node first observed at that cell
    here = (0, 0)  # deliberately not reset between walks

    for walk in walks:
        tokens = walk.split()
        # Step through (node, direction, next node) triples
        for idx in range(0, len(tokens) - 2, 2):
            node, direction, next_node = tokens[idx:idx + 3]

            # The cell we stand on must hold the node the walk claims
            expected = seen.setdefault(here, node)
            if expected != node:
                print(f"Invalid path: {node} found at {here}, but {expected} was expected.")
                return False

            step = moves.get(direction)
            if step is None:
                return False

            # Move, then check the destination cell against the next node
            here = (here[0] + step[0], here[1] + step[1])
            if seen.setdefault(here, next_node) != next_node:
                return False

    return True

# Test cases
walks1 = ['sz WEST zr EAST zr']  # This should be invalid
walks2 = ['ab EAST xy NORTH yz']  # This should be valid

print(track_coordinates(walks1))  # Expected output: False
print(track_coordinates(walks2))  # Expected output: True


In [None]:
def random_letter_pair():
    """Return two lowercase letters drawn at random, with replacement, as one string."""
    letters = random.choices('abcdefghijklmnopqrstuvwxyz', k=2)
    return ''.join(letters)

model = GPT(base_model=SPATIAL_MODEL_PATH, base_model_name='gpt2')

# temperature -> list of 50 generated paths, each seeded with a random start node
imagined_for_temps = {}

for temp in [0, 0.5, 1.0, 1.5, 2.0]:
    # A temperature of 0 stands for greedy decoding; any other value samples at that temperature
    if temp == 0:
        decoding = {'do_sample': False}
    else:
        decoding = {'do_sample': True, 'temperature': temp}

    imagined = [
        model.continue_input(random_letter_pair(), max_new_tokens=50, **decoding)
        for _ in range(50)
    ]
    imagined_for_temps[temp] = imagined

In [None]:
# Path lengths (number of transitions) at which validity is checked
lengths = [1, 2, 3, 4, 5, 6]

# Create a figure with a specific size
plt.figure(figsize=(3.3, 3.1))  # Width and height in inches

# Take evenly spaced colours from the middle of the 'magma' colormap.
# One colour is created per key of imagined_for_temps, although only the
# temperatures in temps_to_plot below are drawn.
cmap = plt.get_cmap('magma')
colors = cmap(np.linspace(0.15, 0.75, len(imagined_for_temps)))

# Plot one line per temperature (greedy decoding, temp 0, is not plotted here)
temps_to_plot = [0.5, 1.0, 1.5, 2.0]
for idx, temp in enumerate(temps_to_plot):
    paths = imagined_for_temps[temp]
    fractions = []
    for length in lengths:
        valid_count = 0
        for path in paths:
            # Keep the first `length` transitions: 2 * length + 1 tokens
            shortened_path = ' '.join(path.split()[:2 * length + 1])
            if track_coordinates([shortened_path]):
                valid_count += 1
        fraction_valid = valid_count / len(paths)
        fractions.append(fraction_valid)
    plt.plot(lengths, fractions, marker='o', label=f'{temp}', color=colors[idx])

# Add labels and legend
plt.xlabel('Number of transitions')
plt.ylabel('Fraction valid')
plt.legend(title='Temperature')
plt.savefig('Imagined_paths_by_temp.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
def path_to_coordinates(path):
    """Turn a walk string into the list of (x, y) cells it visits.

    The walk starts at (0, 0). Every NORTH/SOUTH/EAST/WEST token moves one
    cell and appends the new position; all other tokens (node names, unknown
    words) are skipped.
    """
    unit_moves = {
        'NORTH': (0, 1),
        'SOUTH': (0, -1),
        'EAST': (1, 0),
        'WEST': (-1, 0),
    }
    x = y = 0
    visited = [(x, y)]
    for token in path.split():
        move = unit_moves.get(token)
        if move is None:
            continue
        x, y = x + move[0], y + move[1]
        visited.append((x, y))
    return visited

# NOTE(review): `sns` (seaborn) is not imported in the notebook's import cell;
# this cell needs `import seaborn as sns` unless a star import already provides it.
fig, axs = plt.subplots(1, 3, figsize=(10, 3))  # 1 row, 3 columns

grid_size = 9 # Define the size of the grid (should be odd for symmetry)
center = grid_size // 2

# Create a heatmap of visited cells for each temperature
temps_to_plot = [0, 0.5, 1.0]
for idx, temp in enumerate(temps_to_plot):
    paths = imagined_for_temps[temp]
    grid = np.zeros((grid_size, grid_size))
    n_outside = 0  # visits that fall outside the plotted window

    for path in paths:
        coordinates = path_to_coordinates(path)
        for x, y in coordinates:
            row, col = center + x, center + y
            # Without this check a negative index would silently wrap around to
            # the opposite edge of the grid, and an index past the edge would raise.
            if 0 <= row < grid_size and 0 <= col < grid_size:
                grid[row, col] += 1
            else:
                n_outside += 1

    if n_outside:
        print(f"Temperature {temp}: {n_outside} visited cells fall outside the "
              f"{grid_size}x{grid_size} window and are not shown.")

    sns.heatmap(grid, cmap='coolwarm', cbar=True, ax=axs[idx], vmin=0, vmax=250, alpha=0.7)
    axs[idx].set_title(f'Temperature of {temp}')
    axs[idx].set_xticks([])  # Remove x-ticks
    axs[idx].set_yticks([])  # Remove y-ticks

plt.tight_layout()
plt.savefig('imagined_heatmaps.png', dpi=300)
plt.show()

In [None]:
def path_to_coordinates(path):
    """Turn a walk string into the list of (x, y) cells it visits.

    The walk starts at (0, 0). Every NORTH/SOUTH/EAST/WEST token moves one
    cell and appends the new position; all other tokens (node names, unknown
    words) are skipped.
    """
    unit_moves = {
        'NORTH': (0, 1),
        'SOUTH': (0, -1),
        'EAST': (1, 0),
        'WEST': (-1, 0),
    }
    x = y = 0
    visited = [(x, y)]
    for token in path.split():
        move = unit_moves.get(token)
        if move is None:
            continue
        x, y = x + move[0], y + move[1]
        visited.append((x, y))
    return visited

def calculate_max_distance_from_origin(coordinates):
    """Return the largest Manhattan distance |x| + |y| among ``coordinates``.

    Returns 0 when ``coordinates`` is empty.
    """
    return max((abs(x) + abs(y) for x, y in coordinates), default=0)


# For each temperature, compute how far each imagined path gets from its start
# (maximum Manhattan distance), then summarise across paths.
temps_to_plot = [0, 0.5, 1.0, 1.5, 2.0]
mean_distances = []
sem_distances = []
all_distances = []

for temp in temps_to_plot:
    distances = []
    paths = imagined_for_temps[temp]

    for path in paths:
        coordinates = path_to_coordinates(path)
        max_distance = calculate_max_distance_from_origin(coordinates)
        distances.append(max_distance)

    mean_distance = np.mean(distances)
    # NOTE: despite the `sem_` names this is the standard deviation; the
    # division by sqrt(n) that would give the SEM is not applied.
    sem_distance = np.std(distances) #/ np.sqrt(len(distances))

    mean_distances.append(mean_distance)
    sem_distances.append(sem_distance)
    all_distances.append(distances)

# Plot bar chart with individual data points
bar_width = 0.4  # Not used below: the bar call passes width=0.4 directly
plt.figure(figsize=(3, 3))
plt.bar(temps_to_plot, mean_distances, yerr=sem_distances, width=0.4, capsize=2, color='blue', alpha=0.4, label='Mean Distance')

# Overlay individual data points
for i, temp in enumerate(temps_to_plot):
    x_values = np.full(len(all_distances[i]), temp)  # Same x value for all points in this category
    plt.scatter(x_values, all_distances[i], color='blue', alpha=0.2, label='Individual Distances' if i == 0 else "")

plt.xlabel('Temperature')
plt.ylabel('Mean max distance')
plt.savefig('dists.png', dpi=300, bbox_inches='tight')
plt.show()


#### Can the model generalise to a larger grid?

This excludes the hypothesis that the model *just* memorises sequences of directions.

In [None]:
def generate_name() -> str:
    """Generate a random 2-letter name."""
    return ''.join(random.choices(string.ascii_lowercase, k=2))


def test_loop(model, loop_templates, n_trials=50, verbose=True, **generate_kwargs):
    """Measure how often ``model`` closes a loop by predicting the start node.

    This is the variant used for the larger-grid experiment: by default it
    runs 50 trials per template and decodes with sampling plus beam search.

    Each template is a string such as ``"{} EAST {} WEST {}"``. For every
    trial the placeholders are filled with random names, the last name being
    a repeat of the first. The model is prompted with everything except that
    last name, and the trial counts as correct when the word it generates at
    that position equals the first name.

    Args:
        model: Object exposing ``continue_input(prompt, **kwargs)`` that
            returns the prompt followed by the generated continuation.
        loop_templates: Iterable of template strings, each containing at
            least two ``{}`` placeholders.
        n_trials: Number of random fillings evaluated per template.
        verbose: Print the filled template, raw prediction and outcome of
            every trial.
        **generate_kwargs: Overrides for the decoding settings passed to
            ``model.continue_input`` (defaults: ``max_new_tokens=5``,
            ``do_sample=True``, ``temperature=1.0``, ``num_beams=5``).

    Returns:
        ``(overall_accuracy, results_dict)`` where ``results_dict`` maps each
        template to its mean accuracy over ``n_trials``.

    Raises:
        ValueError: If ``n_trials`` is smaller than 1.
        ZeroDivisionError: If ``loop_templates`` is empty.
    """
    if n_trials < 1:
        raise ValueError("n_trials must be at least 1")

    # Beam sampling by default; callers may override any setting.
    settings = {'max_new_tokens': 5, 'do_sample': True, 'temperature': 1.0, 'num_beams': 5}
    settings.update(generate_kwargs)

    accuracy_scores = []  # Per-trial outcomes pooled over all templates
    results_dict = {}

    for template in loop_templates:
        template_accuracy = []  # Per-trial outcomes for the current template

        for _ in range(n_trials):
            # Fill the template with random names; the path ends where it started
            names = [generate_name() for _ in range(template.count("{}") - 1)]
            names.append(names[0])
            filled_template = template.format(*names)

            true_final_item = names[-1]
            input_len = len(filled_template.split())

            # Hide the final name by dropping the last word, so the prompt
            # does not depend on names being exactly two characters long.
            prompt = filled_template.rsplit(' ', 1)[0]
            prediction = model.continue_input(prompt, **settings)

            # The word generated at the hidden position is the model's answer
            predicted_items = prediction.strip().split()[0:input_len]
            predicted_final_item = predicted_items[-1] if predicted_items else None
            is_correct = int(predicted_final_item == true_final_item)

            if verbose:
                print(filled_template)
                print(prediction)
                print(f"True final:{true_final_item}, predicted final: {predicted_final_item}")
                print(is_correct)

            template_accuracy.append(is_correct)

        accuracy_scores.extend(template_accuracy)
        results_dict[template] = sum(template_accuracy) / len(template_accuracy)

    overall_avg_accuracy = sum(accuracy_scores) / len(accuracy_scores)
    return overall_avg_accuracy, results_dict


# Despite its name this is an evaluation helper, not a pytest test case.
test_loop.__test__ = False


The model can generalise to larger grids than it was trained on to some extent (Figure X), suggesting performance cannot solely be attributed to memorisation of sequences of transitions.

In [None]:
def create_valid_loop_templates(n):
    """Build the four square-loop templates with ``n`` steps per side.

    Each template walks ``n`` steps in each of four directions in turn, so
    the path ends where it started. There is one template per starting
    direction; all four turn the same way.
    """
    circuits = [
        ('EAST', 'SOUTH', 'WEST', 'NORTH'),
        ('NORTH', 'EAST', 'SOUTH', 'WEST'),
        ('WEST', 'NORTH', 'EAST', 'SOUTH'),
        ('SOUTH', 'WEST', 'NORTH', 'EAST'),
    ]

    templates = []
    for circuit in circuits:
        # Repeat every direction n times, keeping the order of the circuit
        moves = [direction for direction in circuit for _ in range(n)]
        templates.append("{} " + " {} ".join(moves) + " {}")

    return templates

def create_repetition_templates(m):
    """Build four rectangular loop templates with sides of length ``m`` and 1.

    Two templates go ``m`` steps NORTH and one step sideways (EAST or WEST);
    the other two go one step NORTH and ``m`` steps sideways. All four return
    to the start. For ``m == 1`` the EAST pair coincide, as do the WEST pair.
    """
    move_lists = [
        ['NORTH'] * m + ['EAST'] + ['SOUTH'] * m + ['WEST'],
        ['NORTH'] + ['EAST'] * m + ['SOUTH'] + ['WEST'] * m,
        ['NORTH'] + ['WEST'] * m + ['SOUTH'] + ['EAST'] * m,
        ['NORTH'] * m + ['WEST'] + ['SOUTH'] * m + ['EAST'],
    ]
    return ["{} " + " {} ".join(moves) + " {}" for moves in move_lists]

def generate_loop_templates(min_n=1, max_n=4):
    """Map every n in ``min_n..max_n`` (inclusive) to its loop templates.

    For each n the value is the list of square-loop templates from
    ``create_valid_loop_templates(n)`` followed by the rectangular templates
    from ``create_repetition_templates(n)``.
    """
    return {
        n: create_valid_loop_templates(n) + create_repetition_templates(n)
        for n in range(min_n, max_n + 1)
    }

# Generate the templates (n = 1..4 with the default arguments)
loop_templates_dict = generate_loop_templates()

# Initialize the model
model = GPT(base_model=SPATIAL_MODEL_PATH, base_model_name='gpt2')

# Test each set of templates and store the results.
# Each template is scored on its own; the per-template accuracies for a given
# n are then summarised by their mean and standard error.
results = {}
for n, templates in loop_templates_dict.items():
    accuracies = []
    for template in templates:
        accuracy, _ = test_loop(model, [template])
        accuracies.append(accuracy)
    average_accuracy = np.mean(accuracies)
    # Standard error of the mean across templates, using the sample standard deviation (ddof=1)
    sem = np.std(accuracies, ddof=1) / np.sqrt(len(accuracies))
    results[n] = (average_accuracy, sem)
    print(f"n = {n}, Average Accuracy: {average_accuracy}, SEM: {sem}")

# Extract the data for plotting
ns = list(results.keys())
mean_accuracies = [results[n][0] for n in ns]
sems = [results[n][1] for n in ns]

In [None]:
# Plot mean accuracy with SEM error bars against grid size.
# The x value is n + 1; presumably because n steps along a side span n + 1 nodes.
plt.figure(figsize=(3, 3))
plt.errorbar([n+1 for n in ns], mean_accuracies, yerr=sems, fmt='o-', capsize=5, color='b')
plt.xlabel('Grid size')
plt.ylabel('Average accuracy')
plt.savefig('accuracy_by_grid_size.png', dpi=300, bbox_inches='tight')
plt.show()
