### Retrieval-augmented generation and inference

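Before memories are consolidated into the generative network, we test whether the system can use the hippocampus (stored walks, retrieved as context) and the neocortex (the trained generative model) jointly to solve inference tasks: predicting the endpoint of an edge that was held out of every stored walk. We compare three conditions: the neocortex alone (NC only), the hippocampus alone (HPC only), and retrieval-augmented generation (RAG).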

#### Imports:

In [None]:
import sys
sys.path.append('../scripts/')

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import pandas as pd
import networkx as nx
import logging
from random import shuffle
import pandas as pd
from matplotlib import pyplot as plt
import csrgraph as cg
import numpy as np
import random
import string
from graph_utils import *
from tree_utils import *
import networkx as nx
import matplotlib.pyplot as plt
import random
import string
from itertools import combinations
import networkx as nx
import random
import pickle
import gc
import os
import torch

# Global seed shared by Python's `random`, NumPy and PyTorch RNGs so that
# edge removal, walk sampling and the HPC-only guesses are repeatable.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

In [None]:
walks, test_gs = get_walks_as_strings(n_graphs=500, n_walks=1, walk_length=50)
with open(f'outputs_graph/test_graphs.pkl', 'wb') as handle:
      pickle.dump(test_gs, handle)

walks, test_gs = get_walks_for_n_trees(n_graphs=500, n_walks=1, walk_length=50)
with open(f'outputs_tree/test_trees.pkl', 'wb') as handle:
      pickle.dump(test_gs, handle)

#### Paths to models:

In [None]:
# Directories holding the fine-tuned checkpoints (and the pickled test sets):
# family-tree task first, spatial-graph task second.
family_model_dir, spatial_model_dir = 'outputs_tree', 'outputs_graph'

#### Functions:

In [None]:
class GPT:
    """Thin wrapper around a GPT-2 checkpoint used to continue text prompts.

    Args:
        base_model: Path or hub name handed to ``from_pretrained`` for both the
            tokenizer and the model. If ``None``, nothing is loaded and
            ``continue_input`` raises ``RuntimeError``.
        base_model_name: Stored on the instance; not used by this class.
        vocab_size: Stored on the instance; not used by this class.
    """

    def __init__(self, base_model=None, base_model_name='gpt2', vocab_size=100):
        self.base_model = base_model
        self.base_model_name = base_model_name
        self.vocab_size = vocab_size
        self.tokenizer = None
        self.model = None

        if self.base_model is not None:
            self.tokenizer = GPT2Tokenizer.from_pretrained(base_model)
            self.model = GPT2LMHeadModel.from_pretrained(base_model)
            # GPT-2 ships without a pad token; reuse EOS so generation has one.
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def continue_input(self, input_sequence, max_new_tokens=5, num_return_sequences=1, no_repeat_ngram_size=0,
                       do_sample=False, temperature=0.7, num_beams=1):
        """Generate a continuation of ``input_sequence`` and return it decoded.

        The returned text is the decoding of the full first output sequence,
        i.e. the (re-decoded) prompt followed by up to ``max_new_tokens`` new
        tokens. Only the first sequence is decoded even when
        ``num_return_sequences > 1``.

        ``temperature`` is only forwarded when ``do_sample`` is True; greedy
        and beam decoding ignore it.

        Raises:
            RuntimeError: if the instance was created without ``base_model``.
        """
        if getattr(self, 'model', None) is None or getattr(self, 'tokenizer', None) is None:
            raise RuntimeError("GPT was created without base_model; no model is loaded.")

        input_ids = self.tokenizer.encode(input_sequence, return_tensors='pt')
        # A single, unpadded sequence: every position is a real token.
        attention_mask = torch.ones_like(input_ids)

        generate_kwargs = dict(
            max_new_tokens=max_new_tokens,
            num_return_sequences=num_return_sequences,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            do_sample=do_sample,
            attention_mask=attention_mask,
            # Explicit pad id (EOS, as set in __init__) avoids the
            # "Setting pad_token_id to eos_token_id" warning on every call.
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Passing temperature with do_sample=False only produces a warning
        # that the flag is ignored, so forward it for sampling only.
        if do_sample:
            generate_kwargs['temperature'] = temperature

        output = self.model.generate(input_ids, **generate_kwargs)

        # Decode the first returned sequence (prompt + continuation).
        return self.tokenizer.decode(output[0].tolist())

In [None]:
# Load the graphs from the pickled file
def load_graphs(filename):
    """Unpickle and return the object stored at ``filename``.

    Used here for the test graphs/trees this notebook pickles itself; do not
    point it at untrusted files (pickle can execute arbitrary code).
    """
    with open(filename, mode='rb') as pickle_file:
        return pickle.load(pickle_file)

# Function to remove a random edge and get n walks of length m
def get_walks_with_removed_edge(G, n, m, edge_name='direction'):
    """Hold out one random edge of ``G`` and sample random walks on the rest.

    ``G`` itself is not modified; the edge is removed from a copy.

    Args:
        G: networkx-style graph (uses ``copy``, ``edges``, ``nodes``,
            ``neighbors``, ``get_edge_data``, ``remove_edge``).
        n: Number of walks to sample.
        m: Maximum number of nodes per walk (at least one node is always
            emitted; a walk stops early at a node with no neighbours).
        edge_name: Edge attribute used as the label between nodes.

    Returns:
        (walks_str, removed_edge_str) where each walk is
        ``"node label node label ... node "`` (note the trailing space after
        the final node; an empty graph gives ``""``) and the held-out edge is
        ``"u label v"``.

    NOTE(review): for undirected graphs the stored label is emitted whichever
    way an edge is traversed — confirm the graph construction intends that.
    """
    pruned = G.copy()

    # Pick the edge to hold out and remember its label before removing it.
    held_out = random.choice(list(pruned.edges))
    held_out_label = pruned.get_edge_data(held_out[0], held_out[1])[edge_name]
    pruned.remove_edge(*held_out)

    # The pruned graph is not mutated while walking, so list its nodes once.
    all_nodes = list(pruned.nodes)

    walks_str = []
    for _ in range(n):
        if not all_nodes:
            walks_str.append('')
            continue
        here = random.choice(all_nodes)
        pieces = []
        for _step in range(m - 1):
            options = list(pruned.neighbors(here))
            if not options:
                break
            there = random.choice(options)
            label = pruned.get_edge_data(here, there)[edge_name]
            pieces.append(f"{here} {label}")
            here = there
        # The final node carries an empty label, hence the trailing space.
        pieces.append(f"{here} ")
        walks_str.append(' '.join(pieces))

    removed_edge_str = f"{held_out[0]} {held_out_label} {held_out[1]}"
    return walks_str, removed_edge_str

def retrieve_fn(query, hpc):
    """Return the stored sequences in ``hpc`` that mention the query's cue node.

    The cue is the first whitespace-separated token of ``query`` (the source
    node of a ``"u label v"`` test item). Matching is on whole tokens, so a
    cue such as ``"ab"`` does not match ``"cab"`` or ``"fabric"``, which a
    raw substring test would. Order of ``hpc`` is preserved; an empty query
    retrieves nothing.
    """
    tokens = query.split()
    if not tokens:
        return []
    cue = tokens[0]
    return [s for s in hpc if cue in s.split()]

In [None]:
def find_edges(G, node, relation, attribute_name='relationship'):
    """Return ``(node, neighbor)`` pairs for edges leaving ``node`` whose
    ``attribute_name`` attribute equals ``relation``.

    Args:
        G: networkx-style graph (uses ``neighbors`` and ``G[u][v]`` access).
        node: Node whose outgoing edges are inspected.
        relation: Edge label to match.
        attribute_name: Edge attribute holding the label (defaults to
            ``'relationship'``, as used by the family-tree graphs).
    """
    return [
        (node, neighbor)
        for neighbor in G.neighbors(node)
        if G[node][neighbor].get(attribute_name) == relation
    ]

def get_all_edge_types(G, attribute_name='direction'):
    """List the distinct values of ``attribute_name`` across ``G``'s edges.

    Edges lacking the attribute are skipped. The result comes from a set, so
    its order is not meaningful.
    """
    labels = {
        attrs[attribute_name]
        for _u, _v, attrs in G.edges(data=True)
        if attribute_name in attrs
    }
    return list(labels)


#### Test inference:

In [None]:
# Results container: task name -> method name -> list of per-repeat correct counts.
data_to_plot = {task_name: {} for task_name in ('Spatial task', 'Family tree task')}

In [None]:
filename = spatial_model_dir + '/test_graphs.pkl'
graphs = load_graphs(filename)

n = 500                   # random walks sampled per graph (after holding out one edge)
m = 3                     # nodes per walk
n_candidate_graphs = 500  # graphs scanned when building each test set
n_test = 100              # test items scored per repeat (the plot divides counts by 100)
n_repeats = 3             # independent repeats; each resamples held-out edges and walks
verbose = False           # True prints every RAG prompt and model continuation

model = GPT(base_model=spatial_model_dir, base_model_name='gpt2')

rag_counts = []
nc_counts = []
hpc_counts = []

for repeat in range(n_repeats):

    test_seqs = []
    hpc = []
    hpc_tokens = set()      # every token currently stored in hpc
    selected_graphs = []

    for G in graphs[:n_candidate_graphs]:
        seqs_to_encode, test_seq = get_walks_with_removed_edge(G, n, m, edge_name='direction')
        # Store only walks that contain the held-out target node.
        seqs_filtered = [s for s in seqs_to_encode if test_seq[-2:] in s]
        # Skip graphs whose query node already occurs in memory, so that the
        # first walk retrieved for a query comes from its own graph.
        if test_seq[:2] not in hpc_tokens:
            test_seqs.append(test_seq)
            hpc.extend(seqs_filtered)
            hpc_tokens.update(' '.join(seqs_filtered).split())
            selected_graphs.append(G)
        if len(test_seqs) >= n_test:
            print(f"Found {n_test} graphs.")
            break

    if len(test_seqs) < n_test:
        print(f"Warning: only found {len(test_seqs)} graphs; scoring those.")

    # Edge labels from every selected graph (not just the last one scanned),
    # so the HPC-only baseline never guesses a label instead of a node.
    edge_types = set()
    for G_sel in selected_graphs:
        edge_types.update(get_all_edge_types(G_sel, attribute_name='direction'))

    rag_count = 0
    hpc_count = 0
    nc_count = 0

    for test_seq in test_seqs:
        target = test_seq[-2:]        # held-out endpoint (2-character node name)
        query_prompt = test_seq[0:-3]  # "source label"
        retrieved_seqs = retrieve_fn(test_seq, hpc)[0:1]

        # RAG: retrieved walk as context, then the query.
        prompt = '\n'.join(retrieved_seqs) + '\n' + query_prompt
        out = model.continue_input(prompt, do_sample=False)
        if verbose:
            print(prompt)
            print(out)
        if out[len(prompt)+1:len(prompt)+3] == target:
            rag_count += 1

        # HPC only: random node from the retrieved walk. Sorting makes the
        # choice independent of set iteration order (which varies per run).
        candidates = sorted(set(' '.join(retrieved_seqs).split()) - edge_types)
        if candidates and np.random.choice(candidates) == target:
            hpc_count += 1

        # NC only: the model with no retrieved context.
        out = model.continue_input(query_prompt, do_sample=False)
        nc_pred = out[len(query_prompt)+1:len(query_prompt)+3]
        if nc_pred == target:
            nc_count += 1

    rag_counts.append(rag_count)
    hpc_counts.append(hpc_count)
    nc_counts.append(nc_count)

data_to_plot['Spatial task']['RAG'] = rag_counts
data_to_plot['Spatial task']['NC only'] = nc_counts
data_to_plot['Spatial task']['HPC only'] = hpc_counts

In [None]:
filename = family_model_dir + '/test_trees.pkl'
graphs = load_graphs(filename)

n = 500                   # random walks sampled per tree (after holding out one edge)
m = 3                     # nodes per walk
n_candidate_graphs = 500  # trees scanned when building each test set
n_test = 100              # test items scored per repeat (the plot divides counts by 100)
n_repeats = 3             # independent repeats; each resamples held-out edges and walks

model = GPT(base_model=family_model_dir, base_model_name='gpt2')

rag_counts = []
nc_counts = []
hpc_counts = []

for repeat in range(n_repeats):

    test_seqs = []
    hpc = []
    hpc_tokens = set()      # every token currently stored in hpc
    graphs_subset = []      # graph for each entry of test_seqs (same index)

    for G in graphs[:n_candidate_graphs]:
        seqs_to_encode, test_seq = get_walks_with_removed_edge(G, n, m, edge_name='relationship')
        # Store only walks that contain the held-out target node.
        seqs_filtered = [s for s in seqs_to_encode if test_seq[-2:] in s]
        # Skip trees whose query node already occurs in memory, so that the
        # first walk retrieved for a query comes from its own tree.
        if test_seq[:2] not in hpc_tokens:
            test_seqs.append(test_seq)
            hpc.extend(seqs_filtered)
            hpc_tokens.update(' '.join(seqs_filtered).split())
            graphs_subset.append(G)
        if len(test_seqs) >= n_test:
            print(f"Found {n_test} graphs.")
            break

    if len(test_seqs) < n_test:
        print(f"Warning: only found {len(test_seqs)} graphs; scoring those.")

    # Relationship labels from every selected tree (not just the last one
    # scanned), so the HPC-only baseline never guesses a label as a node.
    edge_types = set()
    for G_sel in graphs_subset:
        edge_types.update(get_all_edge_types(G_sel, attribute_name='relationship'))

    rag_count = 0
    hpc_count = 0
    nc_count = 0

    for test_seq, G_true in zip(test_seqs, graphs_subset):
        source, relation = test_seq.split()[0], test_seq.split()[1]
        query_prompt = test_seq[0:-3]  # "source relation"
        retrieved_seqs = retrieve_fn(test_seq, hpc)[0:1]

        # Several relatives can satisfy the same relation; any of them counts.
        # G_true is the unpruned tree, so the held-out edge is included.
        valid_answers = [edge[1] for edge in find_edges(G_true, source, relation)]

        # RAG: retrieved walk as context, then the query.
        prompt = '\n'.join(retrieved_seqs) + '\n' + query_prompt
        out = model.continue_input(prompt, do_sample=False)
        if out[len(prompt)+1:len(prompt)+3] in valid_answers:
            rag_count += 1

        # HPC only: random node from the retrieved walk. Sorting makes the
        # choice independent of set iteration order (which varies per run).
        candidates = sorted(set(' '.join(retrieved_seqs).split()) - edge_types)
        if candidates and np.random.choice(candidates) in valid_answers:
            hpc_count += 1

        # NC only: the model with no retrieved context.
        out = model.continue_input(query_prompt, do_sample=False)
        nc_pred = out[len(query_prompt)+1:len(query_prompt)+3]
        if nc_pred in valid_answers:
            nc_count += 1

    rag_counts.append(rag_count)
    hpc_counts.append(hpc_count)
    nc_counts.append(nc_count)

data_to_plot['Family tree task']['RAG'] = rag_counts
data_to_plot['Family tree task']['NC only'] = nc_counts
data_to_plot['Family tree task']['HPC only'] = hpc_counts

In [None]:
# Show the raw per-repeat correct counts for each task and method (the
# notebook renders the value of this final expression).
data_to_plot

In [None]:
# Function to calculate mean and SEM
def calculate_mean_sem(data):
    """Return ``(mean, sem)`` for a sequence of numbers.

    The standard error of the mean uses the sample standard deviation
    (``ddof=1``): ``std(data, ddof=1) / sqrt(len(data))``. A single value has
    SEM 0.0; empty input gives ``(nan, nan)``.
    """
    values = np.asarray(data, dtype=float)
    if values.size == 0:
        return float('nan'), float('nan')
    mean = np.mean(values)
    if values.size < 2:
        # The sample std is undefined for one observation; report no spread.
        return mean, 0.0
    sem = np.std(values, ddof=1) / np.sqrt(values.size)
    return mean, sem

# Calculate means and SEMs of the per-repeat correct counts
processed_data = {
    task: {method: calculate_mean_sem(values) for method, values in task_methods.items()}
    for task, task_methods in data_to_plot.items()
}

# Number of test items scored per repeat; counts / N_TEST_ITEMS = accuracy.
N_TEST_ITEMS = 100

# Specifying the order explicitly
methods = ['NC only', 'HPC only', 'RAG']
tasks = list(processed_data.keys())

n_methods = len(methods)
n_tasks = len(tasks)

# Create figure and axes
fig, ax = plt.subplots(figsize=(4.3, 3))

# Set the positions and width for the bars
positions = np.arange(n_methods)
bar_width = 0.38

colours = ['red', 'blue']
n_plotted = min(n_tasks, len(colours))  # zip below stops at the shorter list

# One group of bars per method, one bar per task within each group
bar_tops = []
for i, (task, colour) in enumerate(zip(tasks, colours)):
    means = [processed_data[task][method][0] / N_TEST_ITEMS for method in methods]
    sems = [processed_data[task][method][1] / N_TEST_ITEMS for method in methods]
    bar_tops.extend(mu + se for mu, se in zip(means, sems))
    ax.bar(positions + i * bar_width, means, bar_width, yerr=sems, label=task,
           alpha=0.4, capsize=5, color=colour)

# Formatting
ax.set_xlabel('Method')
ax.set_ylabel('Accuracy')
# Keep the usual 0-0.9 range, but never clip a bar or error bar above it.
ax.set_ylim(0, max(0.9, max(bar_tops, default=0.0)))
# Centre each tick under its group of bars.
ax.set_xticks(positions + bar_width * (n_plotted - 1) / 2)
ax.set_xticklabels(methods)
ax.legend(title="Task")

# Show plot
plt.tight_layout()
plt.savefig('RAG_graph_by_method_inf.png', dpi=300)
plt.show()
