### Sleep simulations - items variant

#### Installation:

In [None]:
!pip install numpy==1.24.2
!pip tensorflow-macos==2.11.0
!pip install evaluate

#### Imports:

In [None]:
from continual_learning_utils import *
from grid_environment_utils import * 
from testing_utils import * 
import random
import pandas as pd
import networkx as nx
import logging
from random import shuffle
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import random
import string
import os
import re
import glob
import csv
import torch
from wonderwords import RandomWord
import os
import gc
import pickle
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr
from itertools import permutations
import logging
from random import shuffle
from matplotlib import pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import math
import evaluate
from scipy.stats import sem

# NOTE(review): presumably switches off the Weights & Biases integration in
# the training run launched by run_clm.py (the child inherits this env var).
os.environ.update({"WANDB_DISABLED": "true"})

In [None]:
def train_model_script(name_or_path='spatial_model', 
                       num_epochs=3,
                       output_dir='./clm_script',
                       save_steps=100,
                       lr=5e-05 ):
    """Fine-tune a causal LM by running ``./run_clm.py`` in a child process.

    Training uses ``<output_dir>/train.txt``. The same file is also passed as
    ``--validation_file``, so ``--do_eval`` reports loss on the training data.
    The resulting model is written to ``output_dir``, overwriting earlier output.

    Args:
        name_or_path: model name or local directory for ``--model_name_or_path``.
        num_epochs: value for ``--num_train_epochs``.
        output_dir: directory holding ``train.txt``; also the ``--output_dir``.
        save_steps: checkpoint interval for ``--save_steps``.
        lr: value for ``--learning_rate``.

    Raises:
        subprocess.CalledProcessError: if ``run_clm.py`` exits with a non-zero
            status. The former ``!python`` shell escape ignored failures, so a
            crashed run silently left a stale or missing model for later steps.
    """
    import subprocess
    import sys

    # Free memory held by this kernel before the child process starts
    # (presumably so the child can use the GPU memory cached here).
    torch.cuda.empty_cache()
    gc.collect()

    train_file = os.path.join(output_dir, 'train.txt')
    cmd = [
        # Use the kernel's own interpreter rather than whichever `python` is
        # first on PATH, so the child sees the same installed packages.
        sys.executable, './run_clm.py',
        '--model_name_or_path', str(name_or_path),
        '--train_file', train_file,
        '--validation_file', train_file,
        '--per_device_train_batch_size', '1',
        '--per_device_eval_batch_size', '1',
        '--do_train',
        '--do_eval',
        '--output_dir', str(output_dir),
        '--overwrite_output_dir',
        '--num_train_epochs', str(num_epochs),
        '--save_strategy', 'steps',
        '--save_steps', str(save_steps),
        '--learning_rate', str(lr),
        '--log_level', 'error',
    ]
    # An argument list (no shell) avoids quoting problems with paths.
    subprocess.run(cmd, check=True)

#### Test generative replay

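Let's first create the training data for the environments and the helpers that train on it. The experiments below train and test only on environments 0 and 1.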

In [None]:
def train_on_env(training_strs, testing_strs, eps=10, lr=5e-05, num_train=100, env=0, base_model='base_model', generated_strs=None):
    """Write the train/test files for one environment and fine-tune on them.

    Writes ``spatial_model_{env}/train.txt`` and ``spatial_model_{env}/test.txt``;
    the directory must already exist. Then calls ``train_model_script``, starting
    from ``base_model`` and saving the result into that same directory.

    Args:
        training_strs: per-environment training sequences, indexed by ``env``.
        testing_strs: per-environment test sequences, indexed by ``env``.
        eps: number of training epochs.
        lr: learning rate.
        num_train: number of sequences drawn from ``training_strs[env]``.
        env: environment id; selects the data and the output directory.
        base_model: model name or path to start fine-tuning from.
        generated_strs: optional extra sequences (e.g. generative replay)
            appended to the sampled training sequences.
    """
    out_dir = f"spatial_model_{env}"

    # Draw `num_train` sequences *with replacement* (np.random.choice default),
    # so even num_train == pool size does not guarantee every sequence is used.
    list_to_write = np.random.choice(training_strs[env], num_train).tolist()

    if generated_strs is not None:
        list_to_write.extend(generated_strs)

    # Oversample the combined list 10x (again with replacement). This avoids
    # overfitting to one particular sequence order.
    list_to_write = np.random.choice(list_to_write, len(list_to_write) * 10).tolist()
    with open(os.path.join(out_dir, "train.txt"), "w", encoding="utf-8") as text_file:
        text_file.write('\n'.join(list_to_write))

    with open(os.path.join(out_dir, "test.txt"), "w", encoding="utf-8") as text_file:
        text_file.write('\n'.join(testing_strs[env]))

    train_model_script(name_or_path=base_model, 
                       output_dir=out_dir, 
                       num_epochs=eps, 
                       save_steps=2000,
                       lr=lr)

In [None]:
def get_mean_perplexity(input_texts, model_path):
    """Return the `evaluate` perplexity metric's mean perplexity of `model_path` on `input_texts`.

    No start token is prepended to the inputs (``add_start_token=False``).
    """
    metric = evaluate.load("perplexity", module_type="metric")
    scores = metric.compute(
        predictions=input_texts,
        model_id=model_path,
        add_start_token=False,
    )
    return scores['mean_perplexity']

In [None]:
def accuracy_paired_bar(results):
    """Bar-plot mean test accuracy per test environment, before vs after 'sleep'.

    Args:
        results: per-seed results as returned by ``train_with_schedule``.
            ``res[0]`` is a list of ``[model_idx, test_env, accuracy]`` rows.
            Model index 0 (``spatial_model_0``, trained on env 0 only) is
            plotted as "Before" and index 1 (``spatial_model_1``, after the
            sleep cycles) as "After".

    Returns:
        The matplotlib Figure, e.g. for saving to a PDF.
    """
    structured_data = np.concatenate([np.asarray(res[0]) for res in results])
    model_col = structured_data[:, 0]
    env_col = structured_data[:, 1]
    acc_col = structured_data[:, 2]

    test_envs = np.unique(env_col)
    phases = (0, 1)  # model index: 0 -> "Before", 1 -> "After"
    x = np.arange(len(phases))

    fig, ax = plt.subplots(figsize=(4, 3))
    bar_width = 0.35
    opacity = 0.8

    def _mean_and_sem(values):
        # SEM is undefined for fewer than two samples; draw no error bar then.
        return np.mean(values), (sem(values) if len(values) > 1 else 0)

    for i, env in enumerate(test_envs):
        stats = [_mean_and_sem(acc_col[(env_col == env) & (model_col == phase)])
                 for phase in phases]
        means = [m for m, _ in stats]
        errors = [e for _, e in stats]
        ax.bar(x + i * bar_width, means, bar_width, alpha=opacity,
               color=plt.cm.Paired(i), yerr=errors, label=f'Test Env {int(env)}')

    ax.set_xlabel('Training Phase')
    ax.set_ylabel('Accuracy')
    ax.set_title("Accuracy before and after 'sleep'")
    # Centre each tick under its group of bars. With two environments this is
    # bar_width / 2, as before. It no longer relies on a loop variable that is
    # undefined when there are no environments.
    ax.set_xticks(x + bar_width * (len(test_envs) - 1) / 2)
    ax.set_xticklabels(['Before', 'After'])
    ax.legend()

    plt.tight_layout()
    plt.show()
    return fig

    
def perplexity_plot(results):
    """Plot mean test perplexity per cycle, one line per test environment.

    Args:
        results: per-seed results as returned by ``train_with_schedule``.
            ``res[1]`` is a list of ``[cycle, test_env, perplexity]`` rows.

    Returns:
        The matplotlib Figure, e.g. for saving to a PDF.
    """
    combined_perplexity = np.concatenate([np.asarray(res[1]) for res in results])

    fig, ax = plt.subplots(figsize=(4, 3))
    # Environments 0 and 1 keep their original colours. Any further
    # environment uses matplotlib's default colour cycle instead of raising
    # IndexError.
    colors = ['blue', 'green']

    for env in np.unique(combined_perplexity[:, 1]):
        env_rows = combined_perplexity[combined_perplexity[:, 1] == env]
        cycles = np.unique(env_rows[:, 0])

        mean_perplexities = []
        perplexity_sems = []
        for cycle in cycles:
            values = env_rows[env_rows[:, 0] == cycle, 2]
            mean_perplexities.append(np.mean(values))
            # SEM is undefined for a single sample (one seed); draw no error bar.
            perplexity_sems.append(sem(values) if len(values) > 1 else 0)

        env_id = int(env)
        color = colors[env_id] if 0 <= env_id < len(colors) else None
        ax.errorbar(cycles, mean_perplexities, yerr=perplexity_sems,
                    label=f'Test: Env {env_id}', marker='o', color=color)

    ax.set_xlabel('Cycle')
    ax.set_ylabel('Perplexity')
    ax.set_title('Model Perplexity Across Cycles')
    ax.legend()
    plt.show()
    return fig


In [None]:
import matplotlib.backends.backend_pdf
import matplotlib.pyplot as plt
import numpy as np
import random

# Per-environment train/test sequence collections, indexed by environment id
# (see `training_strs[env]` in train_on_env). NOTE(review): `prepare_data` is
# presumably provided by one of the star-imported *_utils modules.
(training_strs, testing_strs) = prepare_data(default = True)


def _rem_fraction(cycle, num_cycles, start_fraction_rem, end_fraction_rem):
    """Return the REM share for `cycle`, linearly interpolated from start to end."""
    # A single cycle has nothing to interpolate; the linear formula would
    # divide by num_cycles - 1 == 0.
    if num_cycles <= 1:
        return start_fraction_rem
    return start_fraction_rem + (end_fraction_rem - start_fraction_rem) * cycle / (num_cycles - 1)


def _read_test_lines(env):
    """Return the stripped lines of ``spatial_model_{env}/test.txt``."""
    with open(f"spatial_model_{env}/test.txt", 'r', encoding="utf-8") as file:
        return [line.strip() for line in file]


def train_with_schedule(num_cycles=20, start_fraction_rem=0.2, end_fraction_rem=0.8,
                        seed=0, lr=0.001, temperature=1, total_items=200, eps_per_item=3):
    """Train on environment 0, then run NREM/REM "sleep" cycles on environment 1.

    First, ``base_model_b8`` is fine-tuned on environment 0 into
    ``spatial_model_0``. Each cycle then has two phases, both writing to
    ``spatial_model_1``:

    * NREM: fine-tune on ``nrem_items`` newly sampled env-1 training sequences.
    * REM: sample ``rem_items`` sequences from the current model with
      ``generative_replay`` and fine-tune on those alone.

    The REM share of each cycle's items moves linearly from
    ``start_fraction_rem`` to ``end_fraction_rem``. A phase with zero items is
    skipped rather than run on an empty training file.

    After every cycle, the mean perplexity of ``spatial_model_1`` is recorded on
    the env-0 and env-1 test sets. At the end, both models are scored with
    ``test_accuracy`` on both test sets.

    Args:
        num_cycles: number of sleep cycles (must be >= 1).
        start_fraction_rem: REM share of items in the first cycle.
        end_fraction_rem: REM share of items in the last cycle.
        seed: seeds NumPy, Python ``random`` (used by the replay shuffle) and
            torch, so that runs are reproducible.
        lr: learning rate for the cycle phases. Env-0 training uses the
            ``train_on_env`` default.
        temperature: sampling temperature for generative replay.
        total_items: items across all cycles; each cycle gets
            ``int(total_items / num_cycles)``.
        eps_per_item: epochs passed to ``train_on_env`` for each phase.

    Returns:
        ``(results, perplexity_results)``. ``results`` rows are
        ``[model_idx, test_env, accuracy]`` and ``perplexity_results`` rows are
        ``[cycle, test_env, mean_perplexity]``.

    Raises:
        ValueError: if ``num_cycles < 1`` or a cycle would get no items.
    """
    import shutil

    if num_cycles < 1:
        raise ValueError(f"num_cycles must be >= 1, got {num_cycles}")
    items_per_cycle = int(total_items / num_cycles)
    if items_per_cycle < 1:
        raise ValueError(
            f"total_items={total_items} gives no items per cycle for num_cycles={num_cycles}")

    # Start from clean model directories (equivalent to `rm -rf` + `mkdir`).
    for env in range(2):
        model_dir = f'spatial_model_{env}'
        shutil.rmtree(model_dir, ignore_errors=True)
        os.makedirs(model_dir)

    # Seed every RNG that sampling touches. The replay shuffle uses Python's
    # `random`, and model sampling presumably uses torch.
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    nrem_eps = rem_eps = eps_per_item
    env1_model = 'spatial_model_1'

    # Train on env. 0
    train_on_env(training_strs, testing_strs, env=0, base_model='base_model_b8', generated_strs=[])

    # The model the next phase starts from. It switches to spatial_model_1 as
    # soon as anything has been trained on env 1.
    current_model = 'spatial_model_0'
    perplexity_results = []
    for cycle in range(num_cycles):
        current_fraction_rem = _rem_fraction(cycle, num_cycles, start_fraction_rem, end_fraction_rem)
        rem_items = int(current_fraction_rem * items_per_cycle)
        nrem_items = items_per_cycle - rem_items

        # NREM: train on freshly sampled env-1 sequences.
        if nrem_items > 0:
            print("NREM phase")
            train_on_env(training_strs,
                         testing_strs,
                         env=1,
                         base_model=current_model,
                         generated_strs=[],
                         num_train=nrem_items,
                         eps=nrem_eps,
                         lr=lr)
            current_model = env1_model
        else:
            print("NREM phase skipped (0 items)")

        # REM: train only on sequences sampled from the current model.
        if rem_items > 0:
            print("REM phase")
            generated_strs = generative_replay(GPT(base_model=current_model), num=rem_items, temperature=temperature)
            print(generated_strs[0:5])
            train_on_env(training_strs,
                         testing_strs,
                         env=1,
                         base_model=current_model,
                         generated_strs=generated_strs,
                         num_train=0,
                         eps=rem_eps,
                         lr=lr)
            current_model = env1_model
        else:
            print("REM phase skipped (0 items)")

        for j in range(2):
            perplexity = get_mean_perplexity(_read_test_lines(j), env1_model)
            perplexity_results.append([cycle, j, perplexity])
        print("Perplexity results so far:")
        print(perplexity_results)

    # Test both saved models on all environments. The test files do not
    # change here, so read each one once.
    test_sets = [_read_test_lines(j) for j in range(2)]
    results = []
    for i in range(2):
        model = GPT(base_model=f'spatial_model_{i}')
        for j, test_data in enumerate(test_sets):
            results.append([i, j, test_accuracy(model, test_data)])

    return results, perplexity_results


def train_with_schedule_multiple_seeds(seeds=(0,), num_cycles=20, start_fraction_rem=0.2, end_fraction_rem=0.8,
                                       lr=0.001, total_items=200, temperature=1, eps_per_item=3):
    """Run ``train_with_schedule`` once per seed and save all plots to one PDF.

    The PDF goes under ``./outputs/``, which is created if missing. Its pages
    are, in order: the planned NREM/REM schedule, the before/after accuracy
    bars, and perplexity across cycles.

    Args:
        seeds: one training run per seed.
        num_cycles, start_fraction_rem, end_fraction_rem, lr, total_items,
        temperature, eps_per_item: forwarded to ``train_with_schedule``.
            ``total_items`` was previously used only for the schedule plot and
            not for training, so the two could disagree.

    Returns:
        List of ``train_with_schedule`` results, one per seed.
    """
    os.makedirs("./outputs", exist_ok=True)
    # total_items is part of the name so runs that differ only in the item
    # budget no longer overwrite each other's PDF.
    pdf_path = (f"./outputs/num_cycles={num_cycles}_start_rem={start_fraction_rem}"
                f"_end_rem={end_fraction_rem}_lr={lr}_items={total_items}.pdf")

    items_per_cycle = int(total_items / num_cycles)

    # The context manager closes the PDF even if a (long) training run fails.
    with matplotlib.backends.backend_pdf.PdfPages(pdf_path) as pdf:
        fig = plot_schedule(num_cycles, total_items, items_per_cycle, start_fraction_rem, end_fraction_rem)
        pdf.savefig(fig, bbox_inches="tight")

        results = [train_with_schedule(num_cycles=num_cycles,
                                       start_fraction_rem=start_fraction_rem,
                                       end_fraction_rem=end_fraction_rem,
                                       lr=lr,
                                       seed=s,
                                       temperature=temperature,
                                       total_items=total_items,
                                       eps_per_item=eps_per_item) for s in seeds]

        for plot in (accuracy_paired_bar, perplexity_plot):
            pdf.savefig(plot(results), bbox_inches="tight")

    return results
   

def plot_schedule(NUM_CYCLES, TOTAL_EPS, eps_per_cycle, starting_fraction_rem, ending_fraction_rem):
    """Draw the planned split of items between NREM and REM in each cycle.

    Cycle ``c`` occupies items ``[c * eps_per_cycle, (c + 1) * eps_per_cycle)``.
    Its first ``nrem`` items are drawn as a red NREM bar and the remaining
    ``rem`` items as a blue REM bar. ``rem = int(fraction * eps_per_cycle)``,
    where the fraction moves linearly from ``starting_fraction_rem`` to
    ``ending_fraction_rem`` (the same rule ``train_with_schedule`` uses).

    Args:
        NUM_CYCLES: number of cycles.
        TOTAL_EPS: total number of items. Kept for backward compatibility; the
            cycle start positions are now computed directly from
            ``eps_per_cycle``.
        eps_per_cycle: items per cycle.
        starting_fraction_rem: REM share in the first cycle.
        ending_fraction_rem: REM share in the last cycle.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots()
    fig.set_size_inches(10, 2)

    for cycle in range(NUM_CYCLES):
        # With one cycle there is nothing to interpolate; the linear formula
        # would divide by zero.
        if NUM_CYCLES > 1:
            current_fraction_rem = (starting_fraction_rem
                                    + (ending_fraction_rem - starting_fraction_rem) * cycle / (NUM_CYCLES - 1))
        else:
            current_fraction_rem = starting_fraction_rem

        rem_eps = int(current_fraction_rem * eps_per_cycle)
        nrem_eps = eps_per_cycle - rem_eps

        # NREM comes first in each cycle and REM follows directly after it.
        nrem_start = cycle * int(eps_per_cycle)
        rem_start = nrem_start + nrem_eps

        ax.broken_barh([(rem_start, rem_eps)],
                       (10, 9),
                       facecolors='tab:blue')
        ax.broken_barh([(nrem_start, nrem_eps)],
                       (20, 9),
                       facecolors='tab:red')

    ax.set_xlabel('Items')
    ax.set_yticks([15, 25], labels=['REM', 'NREM'])
    ax.grid(True)

    plt.show()
    return fig

#### Baseline

#### Number of cycles comparison

In [None]:
def generative_replay(model, num=100, temperature=1, max_empty_attempts=100):
    """Sample `num` sequences from `model` for REM-phase replay.

    The model is repeatedly asked to continue the prompt ``"FROM:"`` with
    sampling enabled. Each output is split into lines, and every line except
    the last is kept; the last line is dropped because the generation stopped
    partway through it. Blank lines are skipped.

    Args:
        model: object with a ``continue_input(prompt, do_sample=..., temperature=...)``
            method that returns a string.
        num: number of sequences to return.
        temperature: sampling temperature forwarded to the model.
        max_empty_attempts: number of consecutive generations with no complete
            line to tolerate before giving up. Previously the loop would never
            end in that case.

    Returns:
        A shuffled list of exactly ``num`` sequences (empty if ``num <= 0``).
        The old version returned every line collected, which could exceed
        ``num`` and overshoot the REM item budget.

    Raises:
        RuntimeError: after ``max_empty_attempts`` consecutive generations
            without a complete line.
    """
    examples = []
    empty_attempts = 0
    while len(examples) < num:
        out = model.continue_input("FROM:", 
                                   do_sample=True,
                                   temperature=temperature)
        # Leave out the last line as it stopped midway through.
        new_lines = [line for line in out.split('\n')[:-1] if line.strip()]
        if new_lines:
            examples.extend(new_lines)
            empty_attempts = 0
        else:
            empty_attempts += 1
            if empty_attempts >= max_empty_attempts:
                raise RuntimeError(
                    f"generative_replay: {empty_attempts} consecutive generations "
                    f"produced no complete line ({len(examples)}/{num} collected)")
    shuffle(examples)
    # Shuffle first, then truncate, so the returned subset is random.
    return examples[:num]

def experience_replay(i, train_size=100, sample_size=10, source_strs=None):
    """Draw replay examples from the training data of earlier environments.

    The first ``train_size`` training sequences of each environment
    ``0 .. i-1`` are pooled, and ``sample_size`` of them are drawn with
    replacement (``random.choices``).

    Args:
        i: index of the current environment; only environments ``< i`` are used.
        train_size: how many leading sequences of each environment to pool.
        sample_size: number of sequences to draw.
        source_strs: per-environment training sequences. Defaults to the
            module-level ``training_strs``.

    Returns:
        The sampled sequences. The list is empty when the pool is empty (e.g.
        ``i == 0``), because ``random.choices`` raises IndexError on an empty
        population.
    """
    if source_strs is None:
        source_strs = training_strs
    # Flatten the per-environment prefixes into one pool.
    pool = [x for j in range(i) for x in source_strs[j][0:train_size]]
    if not pool:
        return []
    return random.choices(pool, k=sample_size)

In [None]:
# Single seed, 8 cycles, 400 items, constant 50% REM share, default lr (0.001).
train_with_schedule_multiple_seeds(
    seeds=[0],
    num_cycles=8,
    start_fraction_rem=0.5,
    end_fraction_rem=0.5,
    total_items=400,
)

In [None]:
# Three seeds, 4 cycles, constant 50% REM share.
train_with_schedule_multiple_seeds(
    seeds=[0, 1, 2],
    num_cycles=4,
    start_fraction_rem=0.5,
    end_fraction_rem=0.5,
    lr=0.001,
)

In [None]:
# Three seeds, 8 cycles, constant 50% REM share.
train_with_schedule_multiple_seeds(
    seeds=[0, 1, 2],
    num_cycles=8,
    start_fraction_rem=0.5,
    end_fraction_rem=0.5,
    lr=0.001,
)

In [None]:
# Three seeds, 16 cycles, constant 50% REM share.
train_with_schedule_multiple_seeds(
    seeds=[0, 1, 2],
    num_cycles=16,
    start_fraction_rem=0.5,
    end_fraction_rem=0.5,
    lr=0.001,
)

In [None]:
# NOTE(review): identical arguments to the previous cell, so this reruns the
# same 16-cycle experiment and overwrites the same PDF. Confirm whether a
# different num_cycles was intended here.
train_with_schedule_multiple_seeds(
    seeds=[0, 1, 2],
    num_cycles=16,
    start_fraction_rem=0.5,
    end_fraction_rem=0.5,
    lr=0.001,
)