### Continual learning and generative replay

#### Installation:

Local:

In [None]:
!pip install numpy==1.24.2
!pip tensorflow-macos==2.11.0

Colab:

In [None]:
# Extra packages for Colab (versions are not pinned). wonderwords is imported below;
# evaluate, datasets and accelerate are presumably needed by run_clm.py — TODO confirm.
!pip install wonderwords evaluate datasets accelerate

#### Imports:

In [None]:
from continual_learning_utils import *
from grid_environment_utils import * 
from testing_utils import * 
import random
import pandas as pd
import networkx as nx
import logging
from random import shuffle
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import random
import string
import os
import re
import glob
import csv
import torch
from wonderwords import RandomWord
import os
import gc
import pickle
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr
from itertools import permutations
import logging
from random import shuffle
from matplotlib import pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import math

# Presumably switches off Weights & Biases logging in the training script launched
# by train_model_script — TODO confirm against run_clm.py.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
def train_model_script(name_or_path='spatial_model', 
                       num_epochs=3,
                       output_dir='./clm_script',
                       save_steps=100,
                       lr=5e-05):
    """Fine-tune a causal language model by shelling out to ``./run_clm.py``.

    Args:
        name_or_path: Forwarded as ``--model_name_or_path`` (the model to start from).
        num_epochs: Forwarded as ``--num_train_epochs``.
        output_dir: Directory expected to already contain ``train.txt``. It is also
            forwarded as ``--output_dir`` together with ``--overwrite_output_dir``.
        save_steps: Forwarded as ``--save_steps`` (``--save_strategy`` is 'steps').
        lr: Forwarded as ``--learning_rate``.

    Note:
        ``train.txt`` is passed as both ``--train_file`` and ``--validation_file``,
        so the ``--do_eval`` pass runs on the training data, not on held-out data.
        The ``!`` line is IPython shell syntax, so this function only works inside
        a notebook / IPython session.
    """
    # Free cached GPU memory and collect garbage before launching the training run.
    torch.cuda.empty_cache()
    gc.collect()
    ! python ./run_clm.py \
        --model_name_or_path {name_or_path} \
        --train_file {os.path.join(output_dir, 'train.txt')} \
        --validation_file {os.path.join(output_dir, 'train.txt')} \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --do_train \
        --do_eval \
        --output_dir {output_dir} \
        --overwrite_output_dir \
        --num_train_epochs {num_epochs} \
        --save_strategy 'steps' \
        --save_steps {save_steps} \
        --learning_rate {lr} 

#### Train base model

In [None]:
# Build the pre-training corpus: shortest paths on 1000 random 3x3 grids of nouns,
# then fine-tune GPT-2 on it to obtain the base model.

# `r` is not created anywhere in this notebook. Reuse it if one of the star imports
# provided it; otherwise build a wonderwords generator so the cell runs on a fresh kernel.
try:
    word_generator = r
except NameError:
    word_generator = RandomWord()

training_strs = []
for _ in range(1000):
    # Nine nouns per grid; spaces are replaced so every location is a single token-like word.
    nouns = [word_generator.word(include_parts_of_speech=["nouns"]).replace(" ", "_") for _ in range(9)]
    grid = create_unique_random_grid(nouns)
    training_strs.extend(get_all_paths_for_grid(grid))

print(f"{len(training_strs)} shortest paths on arbitrary grids generated for pre-training.")

# exist_ok keeps the cell re-runnable (a plain `mkdir` errors when the directory exists).
os.makedirs("base_model", exist_ok=True)
# The context manager closes the file even if the write fails.
with open("base_model/train.txt", "w") as text_file:
    text_file.write('\n'.join(training_strs))

train_model_script(name_or_path='gpt2', output_dir='base_model', num_epochs=5, save_steps=2000)

In [None]:
!cp -r base_model/* base_model_backup/

#### Test generative replay

Let's first create training data for 5 environments.

In [None]:
def train_on_env(training_strs, testing_strs, eps=10, lr=5e-05, num_train=100, env=0, base_model='base_model', generated_strs=None):
    !rm -rf spatial_model_{env}
    !mkdir spatial_model_{env}
    
    text_file = open(f"spatial_model_{env}/train.txt", "w")
    
    # There are 100 training sequence so by default use them all
    list_to_write = training_strs[env][0:num_train]
    
    if generated_strs is not None:
        list_to_write.extend(generated_strs)

    # We oversample the list of training sequences 
    # This avoids overfitting to a particular sequence order
    list_to_write = np.random.choice(list_to_write, 1000).tolist()
    n = text_file.write('\n'.join(list_to_write))
    text_file.close()

    text_file = open(f"spatial_model_{env}/test.txt", "w")
    n = text_file.write('\n'.join(testing_strs[env]))
    text_file.close()
    
    train_model_script(name_or_path=base_model, 
                       output_dir=f'spatial_model_{env}', 
                       num_epochs=eps, 
                       save_steps=2000,
                       lr=lr)

In [None]:
def generative_replay(model, num=100, temperature=1):
    """Sample ``num`` self-generated sequences from ``model``.

    Args:
        model: Object exposing ``continue_input(prompt, do_sample=..., temperature=...)``
            that returns a newline-separated string of sequences.
        num: Exact number of sequences to return.
        temperature: Sampling temperature forwarded to ``continue_input``.

    Returns:
        A shuffled list of exactly ``num`` sequences (empty when ``num`` is 0).

    Note:
        The loop keeps sampling until enough complete lines have been collected,
        so it does not terminate if the model never emits a newline.
    """
    examples = []
    while len(examples) < num:
        out = model.continue_input("FROM:",
                                   do_sample=True,
                                   temperature=temperature)
        # Leave out the last sequence as it stopped midway through
        examples.extend(out.split('\n')[:-1])
    shuffle(examples)
    # The final batch usually overshoots; trim so the caller gets exactly `num`
    # samples and the reported sample size matches what is actually replayed.
    return examples[:num]

def experience_replay(i, train_size=100, sample_size=10):
    """Draw real training sequences from the environments seen before environment ``i``.

    Reads the module-level ``training_strs``.

    Args:
        i: Index of the current environment; environments ``0 .. i-1`` form the pool.
        train_size: Only the first ``train_size`` sequences of each previous environment are eligible.
        sample_size: Number of sequences to draw (with replacement).

    Returns:
        A list of ``sample_size`` sequences, or an empty list when there is nothing
        to sample from (e.g. ``i == 0``).
    """
    # Flattened pool of eligible sequences from every previous environment.
    pool = [seq for j in range(i) for seq in training_strs[j][0:train_size]]
    if not pool:
        # random.choices raises IndexError on an empty population.
        return []
    return random.choices(pool, k=sample_size)

#### Test different parameters

* Vary the temperature of sampling
* Vary the number of samples
* Vary the amount of training per new environment

In [None]:
# Main experiment: for every trial and parameter setting, train sequentially on
# num_env environments with replayed data mixed in, then test on all environments
# seen so far. Results are written to disk after every training stage.
results = []
word_freq_results = []

num_env = 5
num_trials = 3


# temp == -1 is a sentinel meaning "experience replay" (real past sequences) instead of
# generative replay; see the branch on `temp` below.
# First five entries vary the replay type / temperature at a fixed sample size;
# last five vary the sample size at a fixed temperature of 1.2.
params_to_test = [{'sample_size': 50, 'train_size': 100, 'temp': -1},
                  {'sample_size': 50, 'train_size': 100, 'temp': 0.3},
                  {'sample_size': 50, 'train_size': 100, 'temp': 0.9},
                  {'sample_size': 50, 'train_size': 100, 'temp': 1.5},
                  {'sample_size': 50, 'train_size': 100, 'temp': 2.1},
                  {'sample_size': 0, 'train_size': 100, 'temp': 1.2},
                  {'sample_size': 10, 'train_size': 100, 'temp': 1.2},
                  {'sample_size': 50, 'train_size': 100, 'temp': 1.2},
                  {'sample_size': 100, 'train_size': 100, 'temp': 1.2},
                  {'sample_size': 200, 'train_size': 100, 'temp': 1.2}]

for trial_num in range(num_trials):
    
    # Data is regenerated once per trial and shared by all parameter settings of that trial.
    # NOTE(review): prepare_data comes from a star import; presumably it returns
    # per-environment lists of train/test sequences — confirm.
    training_strs, testing_strs = prepare_data(default=True)

    for params in params_to_test:
        train_size = params['train_size']
        sample_size = params['sample_size']
        temp = params['temp']
                
        for i in range(num_env):
            # No replay for the first environment; afterwards replay comes either from the
            # real data of earlier environments or from samples of the previous stage's model.
            if temp == -1:
                generated_strs = experience_replay(i, train_size=train_size, sample_size=sample_size) if i > 0 else []
            else:
                generated_strs = generative_replay(GPT(base_model=f'spatial_model_{i-1}'), num=sample_size, temperature=temp) if i > 0 else []
                
            print(generated_strs)
            # Each stage starts from the previous stage's model.
            # NOTE(review): the first stage starts from 'base_model_b8', which is not created
            # anywhere in this notebook (the cells above produce 'base_model') — confirm.
            train_on_env(training_strs, 
                         testing_strs, 
                         eps=20, 
                         lr=5e-05,
                         num_train=train_size, 
                         env=i, 
                         base_model='base_model_b8' if i == 0 else f'spatial_model_{i-1}', 
                         generated_strs=generated_strs)
            
            # Save the data from generative / experience replay, and unique locations, for analysis
            # Every entry also stores the full training_strs / testing_strs, and the whole
            # list is re-pickled after each stage.
            locs = get_unique_locations(generated_strs)
            word_freq_results.append({'model': i, 
                                      'locs': locs, 
                                      'temp': temp, 
                                      'train_size': train_size, 
                                      "sample_size": sample_size, 
                                      "seqs": generated_strs, 
                                      "training_strs": training_strs, 
                                      "testing_strs": testing_strs})
            with open('word_freq_results_imagined.pkl', 'wb') as file:
                pickle.dump(word_freq_results, file)
            
            # Test on all environments
            # Only environments seen so far (j <= i) are actually evaluated.
            model = GPT(base_model=f'spatial_model_{i}')
            for j in range(num_env):
                if j<=i:
                    # test.txt for environment j was written by train_on_env at stage j.
                    with open(f"spatial_model_{j}/test.txt", 'r') as file:
                        test_data = [line.strip() for line in file]
                        print(test_data)
                    # Result rows do not include trial_num; trials are only distinguishable by row order.
                    accuracy = test_accuracy(model, test_data)
                    results.append(['next_node', i, j, accuracy, temp, train_size, sample_size])
                    accuracy = shortest_path_accuracy(model, 
                                                      test_data_subset(test_data, training_strs[j][:train_size]), 
                                                      training_strs[j] + testing_strs[j])
                    results.append(['shortest_path', i, j, accuracy, temp, train_size, sample_size])
    
            # Save at intervals in case code errors before end
            # The CSV is rewritten in full (mode 'w') with all results accumulated so far.
            with open('replay_results_imagined.csv', 'w', newline='') as file:
                writer = csv.writer(file)
                writer.writerow(['test_type', 'trained_on', 'tested_on', 'accuracy', 'temp', 'train_size', 'sample_size'])
                writer.writerows(results)



In [None]:
# Next-node accuracy per training stage, one panel per replay sample size
# (temperature fixed at 1.2).
var_to_test = 'sample_size'

df = pd.read_csv('replay_results_imagined_1102_final.csv')
df = df.loc[(df["temp"] == 1.2) & (df['test_type'] == 'next_node')]

# Mean and SEM of accuracy for every (var_to_test, trained_on, tested_on) combination
grouped = df.groupby([var_to_test, 'trained_on', 'tested_on'])
accuracy_by_group = grouped['accuracy']
mean_df = accuracy_by_group.mean().reset_index()
sem_df = accuracy_by_group.sem().reset_index()

#sem_df['accuracy'] = 0 # ignore as there is just one trial

vals = mean_df[var_to_test].unique()
num_env = df['trained_on'].nunique()

# One subplot per selected sample size, stacked vertically with a shared x axis
vals_subset = [0, 10, 50, 100]
fig, axes = plt.subplots(len(vals_subset), 1, figsize=(6, 9), sharex=True)

for i, (ax, val) in enumerate(zip(axes, vals_subset)):
    df_sample_mean = mean_df[mean_df[var_to_test] == val]
    df_sample_sem = sem_df[sem_df[var_to_test] == val]

    # One error-bar line per tested environment
    for tested_on in range(num_env):
        mean_rows = df_sample_mean[df_sample_mean['tested_on'] == tested_on]
        sem_rows = df_sample_sem[df_sample_sem['tested_on'] == tested_on]
        ax.errorbar(mean_rows['trained_on'],
                    mean_rows['accuracy'],
                    yerr=sem_rows['accuracy'],
                    label=f'Tested on Env {tested_on}',
                    marker='o')

    ax.set_title(f'{string.ascii_lowercase[i]}) {val} self-generated samples')
    ax.set_ylabel('Accuracy')
    ax.set_ylim((0,1.05))
    ax.legend()

# The x axis is shared, so tick labels are only set on the bottom panel
bottom_ax = axes[-1]
bottom_ax.set_xticks(range(num_env))
bottom_ax.set_xticklabels([F'Env. {n}' for n in range(num_env)])

plt.tight_layout()
plt.savefig('Sample effect.png', dpi=500)
plt.show()


In [None]:
# Next-node accuracy per training stage, one panel per sampling temperature
# (sample size fixed at 50, train size at 100). temp == -1 marks experience replay.
df = pd.read_csv('replay_results_imagined_1102_final.csv')
# Combine the conditions into one mask: chaining df[mask_a][mask_b] applies a mask built
# on the unfiltered frame to the filtered one, which pandas only handles by reindexing
# (with a UserWarning).
df = df[(df["sample_size"] == 50) & (df["train_size"] == 100)]
df = df[df['test_type'] == 'next_node']
var_to_test = 'temp'

# Group by var_to_test, 'trained_on', and 'tested_on', and calculate mean and SEM
grouped = df.groupby([var_to_test, 'trained_on', 'tested_on'])
mean_df = grouped['accuracy'].mean().reset_index()
sem_df = grouped['accuracy'].sem().reset_index()

#sem_df['accuracy'] = 0 # ignore as there is just one trial

vals = mean_df[var_to_test].unique()
num_env = df['trained_on'].nunique()

# Create a figure with subplots
vals_subset = [-1, 0.3, 0.9, 1.5, 2.1]
fig, axes = plt.subplots(len(vals_subset), 1, figsize=(6, 12), sharex=True)

# Iterate over each temperature and create a subplot
for i, val in enumerate(vals_subset):
    df_sample_mean = mean_df[mean_df[var_to_test] == val]
    df_sample_sem = sem_df[sem_df[var_to_test] == val]

    for tested_on in range(num_env):
        # Filter the mean and SEM dataframes for the specific 'tested_on' value
        means = df_sample_mean[df_sample_mean['tested_on'] == tested_on]['accuracy']
        sems = df_sample_sem[df_sample_sem['tested_on'] == tested_on]['accuracy']
        trained_on_values = df_sample_mean[df_sample_mean['tested_on'] == tested_on]['trained_on']
        
        # Plot error bars
        axes[i].errorbar(trained_on_values, means, yerr=sems, label=f'Tested on Env {tested_on}', marker='o')

    letter = string.ascii_lowercase[i]
    # The -1 sentinel (first entry of vals_subset) is the experience-replay baseline.
    axes[i].set_title(f'{letter}) Temperature of {val}' if val > 0 else 'a) Experience replay')
    axes[i].set_ylabel('Accuracy')
    axes[i].set_ylim((0,1.05))
    axes[i].set_xlabel('Training stage')
    axes[i].legend()

# Set tick labels on the bottom panel (the x axis is shared)
axes[-1].set_xticks(range(num_env))
axes[-1].set_xticklabels([F'Env. {n}' for n in range(num_env)])

plt.tight_layout()
plt.savefig('Temp effect.png', dpi=500)
plt.show()


#### Plot aggregated forgetting rate stats

In [None]:
# Mean stage-to-stage change in next-node accuracy as a function of sampling temperature
# (sample size fixed at 50, train size at 100).
df = pd.read_csv('replay_results_imagined_1102_final.csv')
# Single combined mask: chaining df[mask_a][mask_b] applies a mask built on the unfiltered
# frame to the filtered one, which pandas only handles by reindexing (with a UserWarning).
df = df[(df["sample_size"] == 50) & (df["train_size"] == 100)]
df = df[df['test_type'] == 'next_node']
df = df[df['tested_on'] <= df['trained_on']]

# Group by temperature, trained_on, tested_on to calculate mean accuracy
mean_accuracy_df = df.groupby(['temp', 'trained_on', 'tested_on'])['accuracy'].mean().reset_index()

# Sort so that, within each (temp, tested_on) group, rows follow the training stages in order
mean_accuracy_df = mean_accuracy_df.sort_values(by=['temp', 'tested_on', 'trained_on'])

# Accuracy at the following stage minus accuracy at this stage.
# The column is called 'decrease' but holds the signed change (negative = forgetting).
mean_accuracy_df['next_accuracy'] = mean_accuracy_df.groupby(['temp', 'tested_on'])['accuracy'].shift(-1)
mean_accuracy_df['decrease'] = mean_accuracy_df['next_accuracy'] - mean_accuracy_df['accuracy']

# Drop the last stage for each temp and tested_on since it has no subsequent stage to compare
mean_accuracy_df = mean_accuracy_df.dropna(subset=['decrease'])

# Group by temp to calculate mean rate of forgetting
mean_forgetting_rate = mean_accuracy_df.groupby('temp')['decrease'].mean().reset_index()

# Filter out the rows where 'temp' is -1 (experience replay) before plotting
mean_forgetting_rate_filtered = mean_forgetting_rate[mean_forgetting_rate['temp'] != -1]

# Continue with plotting using the filtered DataFrame
plt.figure(figsize=(3, 3))
plt.plot(mean_forgetting_rate_filtered['temp'], mean_forgetting_rate_filtered['decrease'], marker='o', linestyle='-', color='blue')  # Plotting with temperature != -1
plt.xlabel('Temperature')
plt.ylabel('Mean accuracy change')
plt.savefig('mean_acc_change_temp.png', bbox_inches='tight', dpi=500)
plt.show()

In [None]:
# Mean stage-to-stage change in next-node accuracy as a function of replay sample size
# (train size fixed at 100, temperature at 1.2).
# Load the dataset
df = pd.read_csv('replay_results_imagined_1102_final.csv')
# Single combined mask: chaining df[mask_a][mask_b] applies a mask built on the unfiltered
# frame to the filtered one, which pandas only handles by reindexing (with a UserWarning).
df = df[(df['train_size'] == 100) & (df['temp'] == 1.2)]
df = df[df['test_type'] == 'next_node']
df = df[df['tested_on'] <= df['trained_on']]

# Collect (sample_size, mean decrease) pairs, one per sample size
rates_of_forgetting = []

# Get unique sample sizes
sample_sizes = df['sample_size'].unique()

for sample_size in sample_sizes:
    df_sample = df[df['sample_size'] == sample_size]
    
    # Group by 'trained_on' and 'tested_on' to calculate mean accuracy
    # (groupby sorts by its keys, so rows within each tested_on group follow the training stages)
    grouped = df_sample.groupby(['trained_on', 'tested_on'])['accuracy'].mean().reset_index()
    
    # Calculate rate of forgetting: accuracy at this stage minus accuracy at the next stage
    grouped['next_accuracy'] = grouped.groupby('tested_on')['accuracy'].shift(-1)
    grouped['decrease'] = grouped['accuracy'] - grouped['next_accuracy']
    # The last stage of each tested_on group has no following stage to compare against
    grouped = grouped.dropna(subset=['decrease'])
    
    # Average the decrease for this sample size
    mean_decrease = grouped['decrease'].mean()
    rates_of_forgetting.append((sample_size, mean_decrease))

# Convert the list to a DataFrame
rates_df = pd.DataFrame(rates_of_forgetting, columns=['sample_size', 'mean_rate_of_forgetting'])

# Plot
plt.figure(figsize=(3, 3))
# The decrease is negated so the y axis shows the signed accuracy change (negative = forgetting)
plt.plot(rates_df['sample_size'], -rates_df['mean_rate_of_forgetting'], marker='o', linestyle='-', color='blue')
plt.xlabel('Sample size')
plt.yticks([-0.4, -0.3, -0.2, -0.1, 0])
plt.ylabel('Mean accuracy change')
plt.savefig('mean_acc_change_sample.png', bbox_inches='tight', dpi=500)
plt.show()