### Sleep simulations

#### Installation:

In [None]:
!pip install numpy==1.24.2
!pip tensorflow-macos==2.11.0
!pip install evaluate

#### Imports:

In [None]:
from sleep_utils import *
import random
import pandas as pd
import networkx as nx
import logging
from random import shuffle
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import random
import string
import os
import re
import glob
import csv
import torch
from wonderwords import RandomWord
import os
import gc
import pickle
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr
from itertools import permutations
import logging
from random import shuffle
from matplotlib import pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import math
import evaluate

# Set WANDB_DISABLED in this process's environment before any training is launched.
# NOTE(review): presumably this stops the training script started by
# train_model_script() from logging to Weights & Biases — confirm.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
class GPT:
    """Pairs a GPT-2 tokenizer with a GPT-2 language-model head for text continuation."""

    def __init__(self, base_model):
        """Load the tokenizer and the model weights from ``base_model``.

        Args:
            base_model: Name or path handed unchanged to both ``from_pretrained`` calls.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(base_model)
        self.model = GPT2LMHeadModel.from_pretrained(base_model)

    def continue_input(self, input_sequence, max_length=200, num_return_sequences=1, no_repeat_ngram_size=0,
                       do_sample=False, temperature=0.7, num_beams=1):
        """Generate a continuation of ``input_sequence`` and return it as text.

        Args:
            input_sequence: Prompt string to encode and feed to ``generate``.
            max_length, num_return_sequences, no_repeat_ngram_size, do_sample,
            temperature, num_beams: Forwarded as-is to ``self.model.generate``.

        Returns:
            The decoded text of the first sequence returned by ``generate``.
            Other sequences (when ``num_return_sequences > 1``) are discarded.
            Callers in this file slice the prompt off the front of the result,
            so the output is expected to begin with the prompt.
        """
        generation_options = dict(
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            do_sample=do_sample,
            temperature=temperature,
        )

        prompt_ids = self.tokenizer.encode(input_sequence, return_tensors='pt')
        generated = self.model.generate(prompt_ids, **generation_options)

        # Only the first generated sequence is decoded and returned.
        first_sequence = generated[0]
        return self.tokenizer.decode(first_sequence.tolist())

In [None]:
def train_model_script(name_or_path='spatial_model', 
                       num_epochs=3,
                       output_dir='./clm_script',
                       save_steps=100,
                       lr=5e-05 ):
    """Fine-tune a language model by shelling out to ``./run_clm.py``.

    The command is issued with the IPython ``!`` shell escape, so this function
    only works inside IPython/Jupyter; ``{...}`` expressions are interpolated
    from the Python variables in scope.

    Args:
        name_or_path: Value passed as ``--model_name_or_path`` (starting checkpoint).
        num_epochs: Value passed as ``--num_train_epochs``.
        output_dir: Directory that must already contain ``train.txt``; also
            passed as ``--output_dir`` (with ``--overwrite_output_dir``).
        save_steps: Value passed as ``--save_steps``.
        lr: Value passed as ``--learning_rate``.

    Note:
        ``train.txt`` is passed as both ``--train_file`` and
        ``--validation_file``, so the script's evaluation runs on the training
        data. NOTE(review): confirm this is intentional.
    """
    # Release cached GPU memory and collect garbage before starting the training subprocess.
    torch.cuda.empty_cache()
    gc.collect()
    # Batch size is fixed at 1 for both training and evaluation.
    ! python ./run_clm.py \
        --model_name_or_path {name_or_path} \
        --train_file {os.path.join(output_dir, 'train.txt')} \
        --validation_file {os.path.join(output_dir, 'train.txt')} \
        --per_device_train_batch_size 1 \
        --per_device_eval_batch_size 1 \
        --do_train \
        --do_eval \
        --output_dir {output_dir} \
        --overwrite_output_dir \
        --num_train_epochs {num_epochs} \
        --save_strategy 'steps' \
        --save_steps {save_steps} \
        --learning_rate {lr}  

In [None]:
# Shared random-word generator; prepare_data() draws the grid nouns from it.
r = RandomWord()

def create_unique_random_grid(nouns, size=3):
    """Build a ``size`` x ``size`` grid from a random sample of ``nouns``.

    ``size * size`` entries are drawn without replacement (by position) and
    laid out row by row. Entries are therefore distinct only if ``nouns``
    itself contains no repeated values.

    Args:
        nouns: Sequence of candidate names; needs at least ``size * size`` items.
        size: Number of rows and columns.

    Returns:
        A list of ``size`` lists, each holding ``size`` names.
    """
    picked = random.sample(nouns, size * size)
    grid = []
    for row_index in range(size):
        row_start = row_index * size
        grid.append(picked[row_start:row_start + size])
    return grid

def find_shortest_paths(grid, start_name, end_name):
    """Finds all shortest paths from start_name to end_name in a grid. """
    # Find coordinates of start and end points
    start = end = None
    for i, row in enumerate(grid):
        for j, name in enumerate(row):
            if name == start_name:
                start = (i, j)
            if name == end_name:
                end = (i, j)
    
    # Check if start or end points were not found
    if start is None or end is None:
        print ("start or end not found")
        return []

    paths = []
    start_x, start_y = start
    end_x, end_y = end

    # Total horizontal and vertical distances
    x_dist = end_x - start_x
    y_dist = end_y - start_y

    # Generate a list of directions taken in the shortest path
    # We know that the shortest route is x_dist EAST or WESTs, and y_dist NORTH or SOUTHs
    hor_moves = ['EAST' if x_dist > 0 else 'WEST'] * abs(x_dist)
    ver_moves = ['SOUTH' if y_dist > 0 else 'NORTH'] * abs(y_dist)
    all_moves = hor_moves + ver_moves

    # We have a list, e.g. [NORTH, NORTH, EAST, EAST] and we want to find all possible orderings
    # Each ordering (i.e. permutation) is a possible shortest path from start_name to end_name
    for path in set(permutations(all_moves, len(all_moves))):
        sequence = [f'FROM: {start_name}, TO: {end_name}, PATH: {start_name}']
        x, y = start
        for direction in path:
            if direction == 'EAST' and x < 2:
                x += 1
            elif direction == 'WEST' and x > 0:
                x -= 1
            elif direction == 'SOUTH' and y < 2:
                y += 1
            elif direction == 'NORTH' and y > 0:
                y -= 1
            else:
                # Invalid move, skip this path
                break
            sequence.append(f"{direction} {grid[x][y]}")

            # add the path when it successfully reaches the end point
            if (x, y) == end:
                paths.append(' '.join(sequence))

    return paths
  
# # example usage
# grid = create_unique_random_grid(nouns)
# paths = find_shortest_paths(grid, grid[0][0], grid[2][2])

# # print the grid and the paths to see the output
# print("Grid:", grid)
# print("Shortest Paths:", paths)


In [None]:
def shuffle_stimuli(stimuli):
    """Shuffle ``stimuli`` in place and hand the same list object back.

    Args:
        stimuli: Mutable sequence to reorder using the ``random`` module's
            global generator.

    Returns:
        The input object itself (not a copy), now in shuffled order.
    """
    random.shuffle(stimuli)
    return stimuli

def get_all_paths_for_grid(grid):
    """Collect the shortest-path strings for every ordered pair of distinct cell names.

    Args:
        grid: List of lists of names, as accepted by ``find_shortest_paths``.

    Returns:
        A single list with all path strings, in shuffled order.
    """
    # Flatten the grid row by row.
    names = []
    for row in grid:
        names.extend(row)

    collected = []
    for origin, destination in permutations(names, 2):
        # Pairs are formed by position; skip any pair whose names are equal.
        if origin != destination:
            collected += find_shortest_paths(grid, origin, destination)
    return shuffle_stimuli(collected)

#### Test generative replay

Let's first create training data for 5 environments.

In [None]:
def train_on_env(training_strs, testing_strs, env=0, base_model='base_model', generated_strs=None, num_train=1000, eps=20):
    """Write the train/test files for one environment and fine-tune a model on them.

    The training file holds ``num_train`` strings sampled with replacement
    from ``training_strs[env]`` plus all of ``generated_strs``, shuffled
    together. The result is saved under ``spatial_model_{env}``.

    Args:
        training_strs: Per-environment lists of training strings.
        testing_strs: Per-environment lists of test strings.
        env: Index of the environment to train on; also selects the output
            directory ``spatial_model_{env}``.
        base_model: Checkpoint name or path to start fine-tuning from.
        generated_strs: Optional extra strings (e.g. replayed samples) to mix
            into the training file. ``None`` means no extra strings.
        num_train: Number of real training strings to sample.
        eps: Number of training epochs.
    """
    model_dir = f"spatial_model_{env}"
    # The directory must exist before the text files can be written.
    os.makedirs(model_dir, exist_ok=True)

    # Build the full training set before touching any file, so a sampling
    # error cannot leave a truncated train.txt behind.
    list_to_write = np.random.choice(training_strs[env], num_train).tolist()
    if generated_strs:
        list_to_write.extend(generated_strs)
    shuffle(list_to_write)

    with open(f"{model_dir}/train.txt", "w") as text_file:
        text_file.write('\n'.join(list_to_write))

    with open(f"{model_dir}/test.txt", "w") as text_file:
        text_file.write('\n'.join(testing_strs[env]))

    print(f"About to train model from {base_model} and save to spatial_model_{env}")

    train_model_script(name_or_path=base_model,
                       output_dir=model_dir,
                       num_epochs=eps,
                       save_steps=2000)

In [None]:
def generative_replay(model, num=100, temperature=1):
    """Sample text from ``model`` until at least ``num`` lines have been collected.

    Each generation is started from the prompt ``"FROM:"`` with sampling
    enabled, then split on newlines. Every resulting piece is kept, so the
    returned list can be longer than ``num``.

    Args:
        model: Object exposing ``continue_input(prompt, do_sample=..., temperature=...)``.
        num: Minimum number of lines to collect.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        A list of generated lines (length >= ``num``; empty when ``num <= 0``).
    """
    collected = []
    while len(collected) < num:
        generated_text = model.continue_input("FROM:",
                                              do_sample=True,
                                              temperature=temperature)
        collected += generated_text.split('\n')
    return collected

In [None]:
def test_accuracy(model, test_data):
    """Measure how often the model predicts the first location after the first move.

    For each sequence, the text up to and including the first direction word
    (NORTH/EAST/SOUTH/WEST) is used as the prompt. The first
    whitespace-separated word the model generates after the prompt is
    compared with the word that follows that direction in the sequence.

    Sequences with no direction word, or with nothing after the first
    direction word, are skipped. If the model generates nothing after the
    prompt, the prediction counts as incorrect.

    Args:
        model: Object exposing ``continue_input(prompt)`` that returns the
            prompt followed by the generated continuation.
        test_data: Iterable of path strings.

    Returns:
        Fraction of evaluated sequences predicted correctly, or ``0`` if no
        sequence could be evaluated.
    """
    directions = {'NORTH', 'EAST', 'SOUTH', 'WEST'}
    correct_predictions = 0
    total_predictions = 0

    for sequence in test_data:
        words = sequence.split()

        # Find the first direction in the PATH
        first_direction_index = next((i for i, word in enumerate(words) if word in directions), None)
        if first_direction_index is None or first_direction_index + 1 >= len(words):
            # No direction, or no location after it: nothing to predict.
            continue

        # Prompt: everything up to and including the first direction
        input_sequence = ' '.join(words[:first_direction_index + 1])
        target_token = words[first_direction_index + 1]

        # Generate, then drop the prompt from the front of the output
        full_predicted_sequence = model.continue_input(input_sequence)
        predicted_words = full_predicted_sequence[len(input_sequence):].split()
        # An empty continuation is a wrong prediction rather than a crash.
        predicted_token = predicted_words[0] if predicted_words else None

        total_predictions += 1
        print(f"Correct location: {target_token}, Predicted location: {predicted_token}")
        if predicted_token == target_token:
            correct_predictions += 1

    # Calculate accuracy
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    return accuracy


# The name starts with "test_" but this is an evaluation helper, not a unit
# test; opt out of pytest collection in case the module is ever imported by one.
test_accuracy.__test__ = False

def get_mean_perplexity(input_texts, model_path):
    """Return the mean perplexity of ``input_texts`` under the model at ``model_path``.

    Args:
        input_texts: List of strings to score.
        model_path: Model identifier or path passed to the metric as ``model_id``.

    Returns:
        The ``'mean_perplexity'`` entry of the metric's result.
    """
    metric = evaluate.load("perplexity", module_type="metric")
    scores = metric.compute(
        predictions=input_texts,
        model_id=model_path,
        add_start_token=False,
    )
    return scores['mean_perplexity']

In [None]:
def prepare_data(num_envs=5):
    """Create random 3x3 environments and their train/test path strings.

    For each environment a grid of 9 distinct random nouns is built, all
    shortest paths between its cells are generated (in shuffled order), the
    first 100 are kept for training and the remainder for testing. The test
    strings are also written to ``spatial_model_{env}/test.txt``.

    Args:
        num_envs: Number of environments to create.

    Returns:
        ``(training_strs, testing_strs)``: two lists with one list of path
        strings per environment.
    """
    grid_size = 3
    num_train_paths = 100

    training_strs = []
    testing_strs = []
    for _ in range(num_envs):
        # Draw nouns until there are enough *distinct* ones: a repeated name
        # would make the start/end lookup in the grid ambiguous.
        nouns = []
        while len(nouns) < grid_size * grid_size:
            noun = r.word(include_parts_of_speech=["nouns"]).replace(" ", "_")
            if noun not in nouns:
                nouns.append(noun)

        grid = create_unique_random_grid(nouns, size=grid_size)
        print(grid)
        pths = get_all_paths_for_grid(grid)
        training_strs.append(pths[:num_train_paths])
        testing_strs.append(pths[num_train_paths:])

    for env, env_test_strs in enumerate(testing_strs):
        env_dir = f"spatial_model_{env}"
        os.makedirs(env_dir, exist_ok=True)
        with open(f"{env_dir}/test.txt", "w") as text_file:
            text_file.write('\n'.join(env_test_strs))

    return training_strs, testing_strs

In [None]:
def accuracy_paired_bar(results):
    """Plot mean accuracy before and after 'sleep' as grouped bars, one group member per test environment.

    Args:
        results: Iterable of per-seed results whose first element is a list
            of ``[phase, test_env, accuracy]`` rows, where ``phase`` is 0
            ("Before") or 1 ("After").

    Returns:
        The matplotlib figure (also shown via ``plt.show()``).
    """
    # `sem` is not imported explicitly anywhere in this file; import it here
    # rather than rely on a star import happening to provide it.
    from scipy.stats import sem

    # Stack the accuracy rows of all seeds into one 2-D array
    structured_data = np.concatenate([np.array(res[0]) for res in results])

    # Determine unique test environments
    test_envs = np.unique(structured_data[:, 1])

    phase_labels = ['Before', 'After']
    phase_positions = np.arange(len(phase_labels))

    # Initialize plot
    fig, ax = plt.subplots(figsize=(4, 3))
    bar_width = 0.35
    opacity = 0.8

    for i, env in enumerate(test_envs):
        env_rows = structured_data[structured_data[:, 1] == env]
        before_acc = env_rows[env_rows[:, 0] == 0][:, 2]
        after_acc = env_rows[env_rows[:, 0] == 1][:, 2]

        # Mean and standard error across seeds for each phase
        means = [np.mean(before_acc), np.mean(after_acc)]
        errors = [sem(before_acc), sem(after_acc)]

        ax.bar(phase_positions + i * bar_width, means, bar_width, alpha=opacity,
               color=plt.cm.Paired(i), yerr=errors, label=f'Test Env {int(env)}')

    ax.set_xlabel('Training Phase')
    ax.set_ylabel('Accuracy')
    ax.set_title('Accuracy before and after \'sleep\'')
    # Centre each tick under its group of bars (equals bar_width / 2 for two environments).
    ax.set_xticks(phase_positions + bar_width * (len(test_envs) - 1) / 2)
    ax.set_xticklabels(phase_labels)
    ax.legend()

    plt.tight_layout()
    plt.show()
    return fig

def perplexity_plot(results):
    """Plot mean perplexity per cycle, with one line per test environment.

    Args:
        results: Iterable of per-seed results whose second element is a list
            of ``[cycle, test_env, perplexity]`` rows.

    Returns:
        The matplotlib figure (also shown via ``plt.show()``).
    """
    # `sem` is not imported explicitly anywhere in this file; import it here
    # rather than rely on a star import happening to provide it.
    from scipy.stats import sem

    # Combine all perplexity results
    combined_perplexity = np.concatenate([np.array(res[1]) for res in results])

    fig, ax = plt.subplots(figsize=(4, 3))
    # Fixed colours for the first two environments; any further environment
    # falls back to matplotlib's default colour cycle.
    colors = ['blue', 'green']

    # Take the environments from the data instead of assuming exactly two.
    test_envs = np.unique(combined_perplexity[:, 1])

    for idx, env in enumerate(test_envs):
        env_perplexity = combined_perplexity[combined_perplexity[:, 1] == env]
        cycles = np.unique(env_perplexity[:, 0])

        per_cycle = [env_perplexity[env_perplexity[:, 0] == cycle, 2] for cycle in cycles]
        mean_perplexities = [np.mean(values) for values in per_cycle]
        perplexity_sems = [sem(values) for values in per_cycle]

        color = colors[idx] if idx < len(colors) else None
        ax.errorbar(cycles, mean_perplexities, yerr=perplexity_sems,
                    label=f'Test: Env {int(env)}', marker='o', color=color)

    ax.set_xlabel('Cycle')
    ax.set_ylabel('Perplexity')
    ax.set_title('Model Perplexity Across Cycles')
    ax.legend()
    plt.show()
    return fig

In [None]:
import matplotlib.backends.backend_pdf
import matplotlib.pyplot as plt
import tensorflow as tf
from generative_model import VAE
from generative_tests import check_generative_recall
from tensorflow import keras
import numpy as np
import matplotlib.backends.backend_pdf
from generative_model import models_dict
import matplotlib
import random

# Build the per-environment path strings once, when this cell runs; prepare_data()
# also writes spatial_model_<env>/test.txt. train_with_schedule() below reads
# these two names as module-level globals.
training_strs, testing_strs = prepare_data()


def train_with_schedule(total_eps=100, num_cycles=20, start_fraction_rem=0.2, end_fraction_rem=0.8,
                        seed=0, lr=0.001, num=100, temperature=1):
    """Train on environment 0, then on environment 1 in alternating NREM/REM cycles.

    After an initial training run on environment 0, each cycle consists of:

    * NREM: train ``spatial_model_1`` on ``num`` real strings from environment 1;
    * REM: train ``spatial_model_1`` on strings sampled from ``spatial_model_1`` itself.

    The REM share of each cycle's epochs moves linearly from
    ``start_fraction_rem`` (first cycle) to ``end_fraction_rem`` (last cycle).
    Uses the module-level ``training_strs`` and ``testing_strs``.

    Args:
        total_eps: Total number of epochs across all cycles.
        num_cycles: Number of NREM/REM cycles.
        start_fraction_rem: REM fraction of the first cycle.
        end_fraction_rem: REM fraction of the last cycle.
        seed: Seed for NumPy, ``random`` and torch.
        lr: Currently unused; training runs with ``train_model_script``'s
            default learning rate. Kept for interface compatibility.
        num: Number of real strings per NREM phase and minimum number of
            generated strings per REM phase.
        temperature: Sampling temperature for generative replay.

    Returns:
        ``(results, perplexity_results)`` where ``results`` is a list of
        ``[model_index, test_env, accuracy]`` and ``perplexity_results`` is a
        list of ``[cycle, test_env, perplexity]``.
    """
    # Seed every RNG used downstream: NumPy (sampling training strings),
    # `random` (shuffling the training file) and torch (sampling during replay).
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    eps_per_cycle = int(total_eps / num_cycles)

    def load_test_data(env):
        with open(f"spatial_model_{env}/test.txt", 'r') as file:
            return [line.strip() for line in file]

    # Train on env. 0
    train_on_env(training_strs, testing_strs, env=0, base_model='base_model', generated_strs=[])

    perplexity_results = []
    for cycle in range(num_cycles):

        # Linear interpolation of the REM fraction; a single cycle simply
        # uses the starting fraction (avoids dividing by zero).
        if num_cycles > 1:
            current_fraction_rem = start_fraction_rem + (end_fraction_rem - start_fraction_rem) * cycle / (num_cycles - 1)
        else:
            current_fraction_rem = start_fraction_rem
        rem_eps = int(current_fraction_rem * eps_per_cycle)
        nrem_eps = eps_per_cycle - rem_eps

        # NREM: train for nrem_eps epochs on real paths from environment 1.
        # The first cycle starts from the environment-0 model; later cycles
        # continue from spatial_model_1.
        print("NREM phase")
        train_on_env(training_strs,
                     testing_strs,
                     env=1,
                     base_model='spatial_model_0' if cycle == 0 else 'spatial_model_1',
                     generated_strs=[],
                     num_train=num,
                     eps=nrem_eps)

        # REM: train for rem_eps epochs on paths sampled from the current model.
        print("REM phase")
        generated_strs = generative_replay(GPT(base_model='spatial_model_1'), num=num, temperature=temperature)
        print(generated_strs[0:5])
        train_on_env(training_strs,
                     testing_strs,
                     env=1,
                     base_model='spatial_model_1',
                     generated_strs=generated_strs,
                     num_train=0,
                     eps=rem_eps)

        # Track perplexity of the current model on both environments' test sets.
        for j in range(2):
            perplexity = get_mean_perplexity(load_test_data(j), 'spatial_model_1')
            perplexity_results.append([cycle, j, perplexity])
        print("Perplexity results so far:")
        print(perplexity_results)

    # Test both saved models on both environments
    test_sets = [load_test_data(j) for j in range(2)]
    results = []
    for i in range(2):
        model = GPT(base_model=f'spatial_model_{i}')
        for j, test_data in enumerate(test_sets):
            accuracy = test_accuracy(model, test_data)
            results.append([i, j, accuracy])

    return results, perplexity_results


def train_with_schedule_multiple_seeds(seeds=(0,), total_eps=100, num_cycles=20, start_fraction_rem=0.2, end_fraction_rem=0.8,
                                       lr=0.001, num=10, temperature=1):
    """Run ``train_with_schedule`` once per seed and save the summary plots to a PDF.

    The PDF is written to ``./outputs/`` and contains the schedule diagram,
    the accuracy bar chart and the perplexity plot.

    Args:
        seeds: Iterable of seeds; one full training run per seed.
        total_eps, num_cycles, start_fraction_rem, end_fraction_rem, lr, num:
            Forwarded to ``train_with_schedule``.
        temperature: Replay sampling temperature, forwarded to
            ``train_with_schedule``.

    Returns:
        A list with one ``(results, perplexity_results)`` tuple per seed.
    """
    output_dir = "./outputs"
    os.makedirs(output_dir, exist_ok=True)

    pdf_name = (f"total_eps={total_eps}_num_cycles={num_cycles}_start_rem={start_fraction_rem}"
                f"_end_rem={end_fraction_rem}_lr={lr}_num={num}")
    if temperature != 1:
        # Keep the historical file name for the default temperature, but do
        # not let runs at other temperatures overwrite each other.
        pdf_name += f"_temperature={temperature}"
    pdf_path = f"{output_dir}/{pdf_name}.pdf"

    # The context manager closes the PDF even if a training run fails.
    with matplotlib.backends.backend_pdf.PdfPages(pdf_path) as pdf:
        eps_per_cycle = int(total_eps / num_cycles)
        fig = plot_schedule(num_cycles, total_eps, eps_per_cycle, start_fraction_rem, end_fraction_rem)
        pdf.savefig(fig, bbox_inches="tight")

        results = [train_with_schedule(total_eps=total_eps,
                                       num_cycles=num_cycles,
                                       start_fraction_rem=start_fraction_rem,
                                       end_fraction_rem=end_fraction_rem,
                                       lr=lr,
                                       seed=s,
                                       num=num,
                                       temperature=temperature) for s in seeds]

        fig = accuracy_paired_bar(results)
        pdf.savefig(fig, bbox_inches="tight")
        fig = perplexity_plot(results)
        pdf.savefig(fig, bbox_inches="tight")

    return results
   

def plot_schedule(NUM_CYCLES, TOTAL_EPS, eps_per_cycle, starting_fraction_rem, ending_fraction_rem):
    """Draw the NREM/REM epoch schedule as a two-row timeline.

    Each cycle occupies ``int(eps_per_cycle)`` epochs on the x-axis: an NREM
    segment (red, upper row) followed by a REM segment (blue, lower row). The
    REM share moves linearly from ``starting_fraction_rem`` in the first
    cycle to ``ending_fraction_rem`` in the last.

    Args:
        NUM_CYCLES: Number of cycles to draw.
        TOTAL_EPS: Total number of epochs. Not needed to place the segments;
            kept for interface compatibility.
        eps_per_cycle: Epochs per cycle.
        starting_fraction_rem: REM fraction of the first cycle.
        ending_fraction_rem: REM fraction of the last cycle.

    Returns:
        The matplotlib figure (also shown via ``plt.show()``).
    """
    fig, ax = plt.subplots()
    fig.set_size_inches(10, 2)

    cycle_length = int(eps_per_cycle)

    for cycle in range(NUM_CYCLES):
        # Current REM fraction; a single cycle uses the starting fraction
        # (avoids dividing by zero).
        if NUM_CYCLES > 1:
            current_fraction_rem = starting_fraction_rem + (ending_fraction_rem - starting_fraction_rem) * cycle / (NUM_CYCLES - 1)
        else:
            current_fraction_rem = starting_fraction_rem

        rem_eps = int(current_fraction_rem * eps_per_cycle)
        nrem_eps = eps_per_cycle - rem_eps

        # Each cycle starts where the previous one ended; REM follows NREM.
        nrem_start = cycle * cycle_length
        rem_start = nrem_start + nrem_eps

        ax.broken_barh([(rem_start, rem_eps)],
                       (10, 9),
                       facecolors='tab:blue')
        ax.broken_barh([(nrem_start, nrem_eps)],
                       (20, 9),
                       facecolors='tab:red')

    ax.set_xlabel('Epochs')
    ax.set_yticks([15, 25], labels=['REM', 'NREM'])
    ax.grid(True)

    plt.show()
    return fig

#### Baseline

In [None]:
# Baseline: one seed, 100 epochs over 25 cycles (int(100 / 25) = 4 epochs per
# cycle), REM fraction fixed at 0.5 (2 NREM + 2 REM epochs per cycle).
train_with_schedule_multiple_seeds(seeds=[0], 
                                   total_eps=100, 
                                   num_cycles=25, 
                                   start_fraction_rem=0.5, 
                                   end_fraction_rem=0.5,
                                   lr=0.001,
                                   num=100)

#### Number of cycles comparison

In [None]:
# Three seeds, 10 cycles of 10 epochs each, REM fraction fixed at 0.5
# (5 NREM + 5 REM epochs per cycle).
train_with_schedule_multiple_seeds(seeds=[0,1,2], 
                                   total_eps=100, 
                                   num_cycles=10, 
                                   start_fraction_rem=0.5, 
                                   end_fraction_rem=0.5,
                                   lr=0.001,
                                   num=100)

In [None]:
# Three seeds, 5 cycles of 20 epochs each, REM fraction fixed at 0.5
# (10 NREM + 10 REM epochs per cycle).
train_with_schedule_multiple_seeds(seeds=[0,1,2], 
                                   total_eps=100, 
                                   num_cycles=5, 
                                   start_fraction_rem=0.5, 
                                   end_fraction_rem=0.5,
                                   lr=0.001,
                                   num=100)

In [None]:
# Two seeds, 10 cycles of 10 epochs each, REM fraction fixed at 0.8
# (2 NREM + 8 REM epochs per cycle).
train_with_schedule_multiple_seeds(seeds=[0,1], 
                                   total_eps=100, 
                                   num_cycles=10, 
                                   start_fraction_rem=0.8, 
                                   end_fraction_rem=0.8,
                                   lr=0.001,
                                   num=100)

In [None]:
# One seed, 5 cycles of 20 epochs each, REM fraction fixed at 1.0: every NREM
# phase is invoked with 0 epochs and all 20 epochs of a cycle go to REM.
train_with_schedule_multiple_seeds(seeds=[0], 
                                   total_eps=100, 
                                   num_cycles=5, 
                                   start_fraction_rem=1.0, 
                                   end_fraction_rem=1.0,
                                   lr=0.001,
                                   num=100)

#### NREM/REM ratio comparison

#### Temperature comparison

In [None]:
# train_with_schedule_multiple_seeds(seeds=[0], 
#                                    total_eps=100, 
#                                    num_cycles=10, 
#                                    start_fraction_rem=0.5, 
#                                    end_fraction_rem=0.5,
#                                    lr=0.001,
#                                    num=100,
#                                    temperature=0.5)

In [None]:
# train_with_schedule_multiple_seeds(seeds=[0], 
#                                    total_eps=100, 
#                                    num_cycles=10, 
#                                    start_fraction_rem=0.5, 
#                                    end_fraction_rem=0.5,
#                                    lr=0.001,
#                                    num=100,
#                                    temperature=0.75)

In [None]:
# train_with_schedule_multiple_seeds(seeds=[0], 
#                                    total_eps=100, 
#                                    num_cycles=10, 
#                                    start_fraction_rem=0.5, 
#                                    end_fraction_rem=0.5,
#                                    lr=0.001,
#                                    num=100,
#                                    temperature=1.25)