### Continual learning and generative replay

#### Installation:

Local:

In [None]:
!pip install numpy==1.24.2
!pip tensorflow-macos==2.11.0

Colab:

In [None]:
# Packages to install when running on Colab.
!pip install wonderwords evaluate datasets accelerate

#### Imports:

In [None]:
from continual_learning_utils import *
from grid_environment_utils import * 
from testing_utils import * 
import random
import pandas as pd
import networkx as nx
import logging
from random import shuffle
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import random
import string
import os
import re
import glob
import csv
import torch
from wonderwords import RandomWord
import os
import gc
import pickle
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr
from itertools import permutations
import logging
from random import shuffle
from matplotlib import pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import math

# Set WANDB_DISABLED for this process and the training subprocesses it launches
# (presumably read by the Hugging Face training code in run_clm.py to switch off
# Weights & Biases logging — confirm).
os.environ["WANDB_DISABLED"] = "true"

In [None]:
def train_model_script(name_or_path='spatial_model', 
                       num_epochs=3,
                       output_dir='./clm_script',
                       save_steps=100,
                       lr=5e-05):
    """Fine-tune a causal language model by shelling out to ``./run_clm.py``.

    Empties the CUDA cache and runs the garbage collector, then launches
    ``run_clm.py`` with both training and evaluation enabled. ``train.txt``
    inside ``output_dir`` is passed as the training file AND as the validation
    file, so the evaluation reported by the script is measured on the training
    data. Batch size is 1 per device for training and evaluation, and
    ``output_dir`` is overwritten.

    The ``!`` shell syntax means this function only works inside
    IPython/Jupyter.

    Args:
        name_or_path: Value passed as ``--model_name_or_path`` (the starting model).
        num_epochs: Value passed as ``--num_train_epochs``.
        output_dir: Directory that must contain ``train.txt``; also passed as
            ``--output_dir``.
        save_steps: Value passed as ``--save_steps`` (save strategy is 'steps').
        lr: Value passed as ``--learning_rate``.
    """
    torch.cuda.empty_cache()
    gc.collect()
    # IPython expands the {…} expressions from Python variables before running the command.
    ! python ./run_clm.py \
        --model_name_or_path {name_or_path} \
        --train_file {os.path.join(output_dir, 'train.txt')} \
        --validation_file {os.path.join(output_dir, 'train.txt')} \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --do_train \
        --do_eval \
        --output_dir {output_dir} \
        --overwrite_output_dir \
        --num_train_epochs {num_epochs} \
        --save_strategy 'steps' \
        --save_steps {save_steps} \
        --learning_rate {lr} 

#### Train base model

In [None]:
# Build the pre-training corpus: path strings for 3000 random grids, each grid
# labelled with 9 random nouns (spaces replaced by underscores).
# NOTE(review): `r` is not defined in the visible cells — presumably a
# RandomWord() instance provided by one of the star imports; confirm.
training_strs = []
for i in range(3000):
    nouns = [r.word(include_parts_of_speech=["nouns"]).replace(" ", "_") for _ in range(9)]
    grid = create_unique_random_grid(nouns)
    training_strs.extend(get_all_paths_for_grid(grid))

print(f"{len(training_strs)} shortest paths on arbitrary grids generated for pre-training.")

# exist_ok=True keeps the cell re-runnable (plain `mkdir` complains if the directory exists).
os.makedirs("base_model", exist_ok=True)

# The context manager guarantees the file is flushed and closed before the
# training subprocess reads it, even if the write raises.
with open("base_model/train.txt", "w") as text_file:
    text_file.write('\n'.join(training_strs))

# Fine-tune GPT-2 on the corpus; the resulting model is stored in base_model/.
train_model_script(name_or_path='gpt2', output_dir='base_model', num_epochs=5, save_steps=2000)

In [None]:
!cp -r base_model/* base_model_backup/

#### Test generative replay

Let's first create training data for 5 environments.

In [None]:
def train_on_env(training_strs, testing_strs, eps=10, lr=5e-05, num_train=100, env=0, base_model='base_model', generated_strs=None, num_samples=1000):
    """Fine-tune ``base_model`` on one environment, optionally mixed with replay data.

    Recreates the directory ``spatial_model_{env}`` from scratch, writes
    ``train.txt`` and ``test.txt`` into it, and then calls
    ``train_model_script`` with that directory as the output directory.

    Args:
        training_strs: Indexable collection; ``training_strs[env]`` holds the
            training sequences (strings) for environment ``env``.
        testing_strs: Indexable collection; ``testing_strs[env]`` holds the
            test sequences written to ``test.txt``.
        eps: Number of training epochs.
        lr: Learning rate.
        num_train: How many of the environment's training sequences to use
            (the first ``num_train``).
        env: Index of the environment to train on.
        base_model: Model name or path to start fine-tuning from.
        generated_strs: Optional replay sequences added to the training pool.
        num_samples: Number of lines written to ``train.txt``. They are drawn
            with replacement from the pool, so the file size is fixed
            regardless of the pool size (1000 is five times a pool of 100
            environment sequences plus 100 replayed ones).

    Raises:
        ValueError: If the training pool is empty.
    """
    import shutil  # local import: shutil is not part of the notebook's import cell

    model_dir = f"spatial_model_{env}"
    # Start from an empty directory so checkpoints of earlier runs cannot leak in.
    shutil.rmtree(model_dir, ignore_errors=True)
    os.makedirs(model_dir, exist_ok=True)

    # Copy the slice so that extending it never touches the caller's data.
    list_to_write = list(training_strs[env][0:num_train])
    if generated_strs:
        list_to_write.extend(generated_strs)

    if not list_to_write:
        raise ValueError(f"No training sequences available for environment {env}.")

    # Resample with replacement: this both oversamples the pool and randomises
    # the order, which avoids overfitting to one particular sequence order.
    list_to_write = np.random.choice(list_to_write, num_samples).tolist()

    with open(f"{model_dir}/train.txt", "w") as text_file:
        text_file.write('\n'.join(list_to_write))

    with open(f"{model_dir}/test.txt", "w") as text_file:
        text_file.write('\n'.join(testing_strs[env]))

    train_model_script(name_or_path=base_model,
                       output_dir=model_dir,
                       num_epochs=eps,
                       save_steps=2000,
                       lr=lr)

In [None]:
def generative_replay(model, num=100, temperature=1, max_empty_calls=50):
    """Sample ``num`` self-generated sequences from ``model``.

    Repeatedly asks the model to continue the prompt ``"FROM:"`` with sampling
    enabled, splits each output on newlines and keeps every line except the
    last one (generation may have stopped midway through it).

    Args:
        model: Object exposing ``continue_input(prompt, do_sample=..., temperature=...)``
            that returns the generated text as a string.
        num: Number of sequences to return.
        temperature: Sampling temperature forwarded to the model.
        max_empty_calls: Number of consecutive generations without a single
            complete line after which sampling is aborted.

    Returns:
        A shuffled list of exactly ``num`` sequences (empty if ``num`` is 0).

    Raises:
        RuntimeError: If the model yields no complete line in
            ``max_empty_calls`` consecutive calls; without this guard the loop
            would never terminate.
    """
    examples = []
    empty_calls = 0
    while len(examples) < num:
        out = model.continue_input("FROM:",
                                   do_sample=True,
                                   temperature=temperature)
        # Leave out the last sequence as it stopped midway through
        complete_lines = out.split('\n')[:-1]
        if complete_lines:
            empty_calls = 0
            examples.extend(complete_lines)
            continue
        empty_calls += 1
        if empty_calls >= max_empty_calls:
            raise RuntimeError(
                f"Model produced no complete sequence in {max_empty_calls} consecutive generations."
            )
    shuffle(examples)
    # The last generation usually overshoots; trim after shuffling so the
    # caller gets exactly `num` sequences, chosen without positional bias.
    return examples[:num]

def experience_replay(i, train_size=100, sample_size=10):
    """Draw replay sequences from the stored data of earlier environments.

    Pools the first ``train_size`` entries of ``training_strs[j]`` for every
    ``j < i`` (``training_strs`` is the notebook-level global) and returns
    ``sample_size`` items picked from that pool with ``random.choices``, so
    ``sample_size`` is the total number drawn, not a per-environment count.
    """
    pool = []
    for env_idx in range(i):
        pool.extend(training_strs[env_idx][0:train_size])
    return random.choices(pool, k=sample_size)

#### Exploring the effect of different amounts of generative replay

In [None]:
# Sequentially train on num_env environments for each replay configuration and
# record the accuracy on every environment after each training stage.
results = []
word_freq_results = []

num_env = 5
num_trials = 1

# One entry per replay configuration. In the loop below, temp == -1 selects
# experience replay; any other value is passed to generative replay as the
# sampling temperature.
params_to_test = [{'sample_size': 0, 'train_size': 100, 'temp': 1.0},
                  {'sample_size': 50, 'train_size': 100, 'temp': 1.0},
                  {'sample_size': 100, 'train_size': 100, 'temp': 1.0},
                  {'sample_size': 500, 'train_size': 100, 'temp': 1.0}]


for trial_num in range(num_trials):

    training_strs, testing_strs = prepare_data()

    for params in params_to_test:
        train_size = params['train_size']
        sample_size = params['sample_size']
        temp = params['temp']

        for i in range(num_env):
            if i == 0 or sample_size == 0:
                # Nothing to replay before the first environment or when no
                # samples are requested; this also avoids loading the previous
                # model only to generate zero sequences.
                generated_strs = []
            elif temp == -1:
                generated_strs = experience_replay(i, train_size=train_size, sample_size=sample_size)
            else:
                generated_strs = generative_replay(GPT(base_model=f'spatial_model_{i-1}'), num=sample_size, temperature=temp)

            print(generated_strs)
            train_on_env(training_strs, testing_strs, num_train=train_size, env=i, base_model='base_model' if i == 0 else f'spatial_model_{i-1}', generated_strs=generated_strs)

            # Save the data from generative / experience replay, and unique locations, for analysis
            locs = get_unique_locations(generated_strs)
            word_freq_results.append({'model': i,
                                      'locs': locs,
                                      'temp': temp,
                                      'train_size': train_size,
                                      "sample_size": sample_size,
                                      "seqs": generated_strs,
                                      "training_strs": training_strs,
                                      "testing_strs": testing_strs})
            with open('word_freq_results_sample.pkl', 'wb') as file:
                pickle.dump(word_freq_results, file)

            # Test on all environments. The test sequences are taken from
            # memory instead of spatial_model_{j}/test.txt: for j > i that file
            # has not been written yet in this pass (missing on a fresh run,
            # stale from an earlier trial otherwise).
            model = GPT(base_model=f'spatial_model_{i}')
            for j in range(num_env):
                test_data = [line.strip() for line in testing_strs[j]]
                accuracy = test_accuracy(model, test_data)
                results.append([i, j, accuracy, temp, train_size, sample_size])

            # Save at intervals in case code errors before end
            with open('replay_results_sample.csv', 'w', newline='') as file:
                writer = csv.writer(file)
                writer.writerow(['trained_on', 'tested_on', 'accuracy', 'temp', 'train_size', 'sample_size'])
                writer.writerows(results)



#### Aggregate across trials and plot

In [None]:
# Load the results of the sample-size experiment. The file name must match the
# one that experiment writes ('replay_results_sample.csv').
df = pd.read_csv('replay_results_sample.csv')

# Group by 'sample_size', 'trained_on' and 'tested_on', and calculate mean and SEM
grouped = df.groupby(['sample_size', 'trained_on', 'tested_on'])
mean_df = grouped['accuracy'].mean().reset_index()
sem_df = grouped['accuracy'].sem().reset_index()

# With a single trial per group the SEM is NaN; draw those as zero-length
# error bars, but keep real SEM values when several trials are available.
sem_df['accuracy'] = sem_df['accuracy'].fillna(0)

In [None]:
# Plot mean accuracy (with SEM error bars) per test environment, one subplot
# per replay sample size.

# Unique sample sizes and number of environments
sample_sizes = mean_df['sample_size'].unique()
num_env = df['trained_on'].nunique()

# squeeze=False always returns a 2-D array of axes, so indexing below also
# works when there is only a single sample size (and therefore one subplot).
fig, axes = plt.subplots(len(sample_sizes), 1, figsize=(10, 15), sharex=True, squeeze=False)
axes = axes[:, 0]

# Iterate over each sample size and create a subplot
for i, sample_size in enumerate(sample_sizes):
    ax = axes[i]
    df_sample_mean = mean_df[mean_df['sample_size'] == sample_size]
    df_sample_sem = sem_df[sem_df['sample_size'] == sample_size]

    for tested_on in range(num_env):
        # Filter the mean and SEM dataframes for the specific 'tested_on' value
        mean_rows = df_sample_mean[df_sample_mean['tested_on'] == tested_on]
        sem_rows = df_sample_sem[df_sample_sem['tested_on'] == tested_on]

        # Plot error bars
        ax.errorbar(mean_rows['trained_on'], mean_rows['accuracy'], yerr=sem_rows['accuracy'],
                    label=f'Tested on Env {tested_on}', marker='o')

    letter = string.ascii_lowercase[i]
    ax.set_title(f'{letter}) {sample_size} self-generated samples')
    ax.set_ylabel('accuracy')
    ax.legend()

# Set common labels and title
axes[-1].set_xlabel('Trained On Environment')
axes[-1].set_xticks(range(num_env))
# plt.suptitle('Mean Model Accuracy Across Trials with SEM')
plt.savefig('Number of samples effect three trials.png', dpi=500)
plt.show()


#### Explore the effect of different temperatures

In [None]:
# Sequentially train on num_env environments for each sampling temperature and
# record the accuracy on every environment after each training stage.
results = []
word_freq_results = []

num_env = 3
num_trials = 1

# One entry per replay configuration. In the loop below, temp == -1 selects
# experience replay; any other value is passed to generative replay as the
# sampling temperature.
params_to_test = [{'sample_size': 100, 'train_size': 100, 'temp': 0.5},
                  {'sample_size': 100, 'train_size': 100, 'temp': 1.0},
                  {'sample_size': 100, 'train_size': 100, 'temp': 1.5}]

for trial_num in range(num_trials):

    training_strs, testing_strs = prepare_data()

    for params in params_to_test:
        train_size = params['train_size']
        sample_size = params['sample_size']
        temp = params['temp']

        for i in range(num_env):
            if i == 0 or sample_size == 0:
                # Nothing to replay before the first environment or when no
                # samples are requested; this also avoids loading the previous
                # model only to generate zero sequences.
                generated_strs = []
            elif temp == -1:
                generated_strs = experience_replay(i, train_size=train_size, sample_size=sample_size)
            else:
                generated_strs = generative_replay(GPT(base_model=f'spatial_model_{i-1}'), num=sample_size, temperature=temp)

            print(generated_strs)
            train_on_env(training_strs, testing_strs, num_train=train_size, env=i, base_model='base_model' if i == 0 else f'spatial_model_{i-1}', generated_strs=generated_strs)

            # Save the data from generative / experience replay, and unique locations, for analysis
            locs = get_unique_locations(generated_strs)
            word_freq_results.append({'model': i,
                                      'locs': locs,
                                      'temp': temp,
                                      'train_size': train_size,
                                      "sample_size": sample_size,
                                      "seqs": generated_strs,
                                      "training_strs": training_strs,
                                      "testing_strs": testing_strs})
            with open('word_freq_results_temp.pkl', 'wb') as file:
                pickle.dump(word_freq_results, file)

            # Test on all environments. The test sequences are taken from
            # memory instead of spatial_model_{j}/test.txt: for j > i that file
            # has not been written yet in this pass (missing on a fresh run,
            # stale from an earlier trial otherwise).
            model = GPT(base_model=f'spatial_model_{i}')
            for j in range(num_env):
                test_data = [line.strip() for line in testing_strs[j]]
                accuracy = test_accuracy(model, test_data)
                results.append([i, j, accuracy, temp, train_size, sample_size])

            # Save at intervals in case code errors before end
            with open('replay_results_temp.csv', 'w', newline='') as file:
                writer = csv.writer(file)
                writer.writerow(['trained_on', 'tested_on', 'accuracy', 'temp', 'train_size', 'sample_size'])
                writer.writerows(results)



In [None]:
# Plot mean accuracy per test environment, one subplot per sampling temperature.
# File name and column names follow what the temperature experiment writes:
# 'replay_results_temp.csv' with the header
# trained_on, tested_on, accuracy, temp, train_size, sample_size.
df = pd.read_csv('replay_results_temp.csv')
group_cols = ['temp', 'trained_on', 'tested_on']
mean_df = df.groupby(group_cols)['accuracy'].mean().reset_index()
sem_df = df.groupby(group_cols)['accuracy'].sem().reset_index()

temperatures = mean_df['temp'].unique()
num_env = df['trained_on'].nunique()

# squeeze=False always returns a 2-D array of axes, so indexing below also
# works when only one temperature is present.
fig, axes = plt.subplots(len(temperatures), 1, figsize=(8, 10), sharex=True, squeeze=False)
axes = axes[:, 0]

for i, temp in enumerate(temperatures):
    df_temp_mean = mean_df[mean_df['temp'] == temp]
    df_temp_sem = sem_df[sem_df['temp'] == temp]

    if df_temp_mean.empty or df_temp_sem.empty:
        continue

    for tested_on in range(num_env):
        means = df_temp_mean[df_temp_mean['tested_on'] == tested_on]['accuracy']
        sems = df_temp_sem[df_temp_sem['tested_on'] == tested_on]['accuracy']
        trained_on_values = df_temp_mean[df_temp_mean['tested_on'] == tested_on]['trained_on']

        if means.empty or sems.empty or trained_on_values.empty:
            continue

        # Only the means are drawn here; the SEM is computed but not plotted.
        axes[i].plot(trained_on_values, means, label=f'Tested on Env {tested_on}', marker='o')

    letter = string.ascii_lowercase[i]
    axes[i].set_title(f'{letter}) Temperature of {temp}')
    axes[i].set_ylabel('Accuracy')
    axes[i].legend()

axes[-1].set_xticks(range(num_env))
axes[-1].set_xlabel('Trained On Environment')
plt.savefig('Temperature_effect.png')
plt.show()


In [None]:
# Count, per (model, temperature), how many replayed locations belong to each
# environment's training data.

# `temps` is not defined in any visible cell, so it is derived from the records
# themselves; this guarantees every (model, temp) key looked up below exists.
temps = sorted({record['temp'] for record in word_freq_results})
word_counts = {(model, temp): {env: 0 for env in range(num_env)} for model in range(5) for temp in temps}

# Loop-invariant: compute each environment's locations once instead of once
# per word (assumes get_unique_locations depends only on its argument — confirm).
env_locations = [get_unique_locations(training_strs[env]) for env in range(num_env)]

# NOTE(review): `all_env_words` is not defined in the visible cells — confirm
# where it comes from before running this cell.
for record in word_freq_results:
    model, temp, locs = record['model'], record['temp'], record['locs']
    for word in locs:
        if word in all_env_words:
            for env, locations in enumerate(env_locations):
                if word in locations:
                    word_counts[(model, temp)][env] += 1

In [None]:
# Grouped bar chart of word_counts: one subplot per temperature, one group per
# model and one bar per environment.

# squeeze=False always returns a 2-D array of axes, so indexing below also
# works when there is only one temperature.
fig, axes = plt.subplots(len(temps), 1, figsize=(8, 10), squeeze=False)
axes = axes[:, 0]
bar_width = 0.2
num_groups = 5  # Assuming 5 models
group_width = num_env * bar_width
# Left edge of each model's group; groups are separated by one bar width.
group_starts = [x * (group_width + bar_width) for x in range(num_groups)]

# Iterate through each temperature and plot on a separate axis
for i, temp in enumerate(temps):
    ax = axes[i]
    for env in range(num_env):
        frequencies = [word_counts[(model, temp)][env] for model in range(num_groups)]
        ax.bar([start + env * bar_width for start in group_starts], frequencies, width=bar_width, label=f'Env {env} locations')

    # Set the title for each subplot
    letter = string.ascii_lowercase[i]
    ax.set_title(f'{letter}) Location Distributions for Temperature {temp}')
    ax.legend()
    ax.set_xlabel('Trained On Environment')
    ax.set_ylabel('Number of Locations')
    ax.set_xticks([start + group_width / 2 for start in group_starts], range(num_groups))

plt.tight_layout()
plt.savefig('location_distributions.png')
plt.show()


#### Generated vs. imagined sequences

In [None]:
# Sequentially train on num_env environments (short-path data) for several
# replay configurations; after each stage record next-node and shortest-path
# accuracy on every environment seen so far.
results = []
word_freq_results = []

num_env = 5
num_trials = 3

# temp == -1 selects experience replay; any other value is the sampling
# temperature used for generative replay.
params_to_test = [{'sample_size': 50, 'train_size': 20, 'temp': -1},
                  {'sample_size': 50, 'train_size': 20, 'temp': 0.3},
                  {'sample_size': 50, 'train_size': 20, 'temp': 0.6},
                  {'sample_size': 50, 'train_size': 20, 'temp': 0.9},
                  {'sample_size': 50, 'train_size': 20, 'temp': 1.2},
                  {'sample_size': 50, 'train_size': 20, 'temp': 1.5},
                  # testing effect of sample size (generative replay)
                  {'sample_size': 0, 'train_size': 20, 'temp': 1.0},
                  {'sample_size': 10, 'train_size': 20, 'temp': 1.0},
                  {'sample_size': 50, 'train_size': 20, 'temp': 1.0},
                  {'sample_size': 100, 'train_size': 20, 'temp': 1.0},
                  {'sample_size': 200, 'train_size': 20, 'temp': 1.0},
                  # testing effect of sample size (experience replay; train_size stays 20)
                  {'sample_size': 10, 'train_size': 20, 'temp': -1},
                  {'sample_size': 100, 'train_size': 20, 'temp': -1},
                  {'sample_size': 200, 'train_size': 20, 'temp': -1},]

for trial_num in range(num_trials):

    training_strs, testing_strs = prepare_data(short_paths=True)

    for params in params_to_test:
        train_size = params['train_size']
        sample_size = params['sample_size']
        temp = params['temp']

        for i in range(num_env):
            if i == 0 or sample_size == 0:
                # Nothing to replay before the first environment or when no
                # samples are requested; this also avoids loading the previous
                # model only to generate zero sequences.
                generated_strs = []
            elif temp == -1:
                generated_strs = experience_replay(i, train_size=train_size, sample_size=sample_size)
            else:
                generated_strs = generative_replay(GPT(base_model=f'spatial_model_{i-1}'), num=sample_size, temperature=temp)

            print(generated_strs)
            # NOTE(review): 'base_model_b8' is not created in the visible cells
            # (only base_model / base_model_backup are) — confirm it exists.
            train_on_env(training_strs,
                         testing_strs,
                         eps=20,
                         lr=5e-05,
                         num_train=train_size,
                         env=i,
                         base_model='base_model_b8' if i == 0 else f'spatial_model_{i-1}',
                         generated_strs=generated_strs)

            # Save the data from generative / experience replay, and unique locations, for analysis
            locs = get_unique_locations(generated_strs)
            word_freq_results.append({'model': i,
                                      'locs': locs,
                                      'temp': temp,
                                      'train_size': train_size,
                                      "sample_size": sample_size,
                                      "seqs": generated_strs,
                                      "training_strs": training_strs,
                                      "testing_strs": testing_strs})
            with open('word_freq_results_imagined.pkl', 'wb') as file:
                pickle.dump(word_freq_results, file)

            # Test on the environments trained on so far (0..i); their test.txt
            # files have all been written during this pass.
            model = GPT(base_model=f'spatial_model_{i}')
            for j in range(i + 1):
                with open(f"spatial_model_{j}/test.txt", 'r') as file:
                    test_data = [line.strip() for line in file]
                print(test_data)
                accuracy = test_accuracy(model, test_data)
                results.append(['next_node', i, j, accuracy, temp, train_size, sample_size])
                accuracy = shortest_path_accuracy(model,
                                                  test_data_subset(test_data, training_strs[j][:train_size]),
                                                  training_strs[j] + testing_strs[j])
                results.append(['shortest_path', i, j, accuracy, temp, train_size, sample_size])

            # Save at intervals in case code errors before end
            with open('replay_results_imagined.csv', 'w', newline='') as file:
                writer = csv.writer(file)
                writer.writerow(['test_type', 'trained_on', 'tested_on', 'accuracy', 'temp', 'train_size', 'sample_size'])
                writer.writerows(results)



In [None]:
# Load the results of the imagined-sequences experiment and keep only the
# shortest-path accuracy rows.
df = pd.read_csv('replay_results_imagined.csv')
df = df[df['test_type'] == 'shortest_path']

# Group by 'temp', 'trained_on' and 'tested_on', and calculate mean and SEM.
# NOTE(review): the experiment runs several sample sizes for temp 1.0 and
# temp -1, so those groups average over sample sizes as well as trials —
# confirm this is intended.
grouped = df.groupby(['temp', 'trained_on', 'tested_on'])
mean_df = grouped['accuracy'].mean().reset_index()
sem_df = grouped['accuracy'].sem().reset_index()

# SEM is NaN for groups with a single observation; draw those as zero-length
# error bars, but keep the real SEM values everywhere else.
sem_df['accuracy'] = sem_df['accuracy'].fillna(0)

In [None]:
# Plot mean shortest-path accuracy (with SEM error bars) per test environment,
# one subplot per replay temperature.
vals = mean_df['temp'].unique()
num_env = df['trained_on'].nunique()

# squeeze=False always returns a 2-D array of axes, so indexing below also
# works when only one temperature is present.
fig, axes = plt.subplots(len(vals), 1, figsize=(8, 16), sharex=True, squeeze=False)
axes = axes[:, 0]

# Iterate over each temperature and create a subplot
for i, val in enumerate(vals):
    ax = axes[i]
    df_sample_mean = mean_df[mean_df['temp'] == val]
    df_sample_sem = sem_df[sem_df['temp'] == val]

    for tested_on in range(num_env):
        # Filter the mean and SEM dataframes for the specific 'tested_on' value
        mean_rows = df_sample_mean[df_sample_mean['tested_on'] == tested_on]
        sem_rows = df_sample_sem[df_sample_sem['tested_on'] == tested_on]

        # Plot error bars
        ax.errorbar(mean_rows['trained_on'], mean_rows['accuracy'], yerr=sem_rows['accuracy'],
                    label=f'Tested on Env {tested_on}', marker='o')

    letter = string.ascii_lowercase[i]
    # `val` is a temperature, not a sample count; -1 is the marker the
    # experiment uses for experience replay.
    condition = 'Experience replay' if val == -1 else f'Temperature of {val}'
    ax.set_title(f'{letter}) {condition}')
    ax.set_ylabel('Accuracy')
    ax.set_ylim((0,1))
    ax.legend()

# Set common labels and title
axes[-1].set_xlabel('Trained On Environment')
axes[-1].set_xticks(range(num_env))
# plt.suptitle('Mean Model Accuracy Across Trials with SEM')
# Use a dedicated file name so the sample-size figure saved earlier in the
# notebook is not overwritten.
plt.savefig('Temperature effect shortest path three trials.png', dpi=500)
plt.show()
