### Modelling statistical learning

To explore the effect of sleep on statistical learning, Durrant et al. (2011) constructed two types of sequence, both made up of regular tones at differing frequencies. One type had a structure in which the preceding two tones determined the next, except for a few transitions which were random to avoid repetition. The other type was the reverse – most transitions were random. After listening to a structured sequence, participants were tested on their ability to distinguish short structured and unstructured sequences. Delayed recall was then tested, after a night’s sleep for one group and after a period of waking rest for the other. Durrant et al. (2011) found that sleep improved performance more than waking rest, suggesting that systems consolidation promotes statistical learning.

Here, we generate a set of sequences based on the transition structure in Durrant et al. (2011). A model with the GPT-2 architecture is trained from scratch on the structured sequences only. At the end of each training epoch, perplexity is calculated on two test sets, one of structured and one of unstructured sequences. We find that the difference in perplexity between these two sets increases over time, corresponding to an improved ability to distinguish them. In addition, outputs from the trained model are structured in the same way as the training data.

Tested with conda_pytorch_latest_p36 kernel in AWS SageMaker.

#### Installation:

In [None]:
!pip install simpletransformers --upgrade

#### Imports:

In [None]:
import sys
sys.path.append('../scripts/')

import pandas as pd
import random
import logging
from random import shuffle
from matplotlib import pyplot as plt
from statistical_learning_utils import *
from gpt import GPT
import os
import glob
import evaluate
from evaluate import load
import numpy as np

In [None]:
def get_random_sequence():
    """Return an unstructured tone sequence as a comma-separated string.

    Every tone is drawn independently and uniformly from 1-5, so no
    transition is predictable. The sequence has 52 tones: two leading
    draws followed by 50 more.
    """
    n_tones = 2 + 50
    tones = [random.randint(1, 5) for _ in range(n_tones)]
    return ','.join(map(str, tones))

def _write_sequences(path, walks):
    """Shuffle *walks* in place, then write them to *path*, one per line.

    The file is opened only after the sequences exist and is closed by the
    ``with`` statement even if the write fails. No trailing newline is written
    after the last sequence.
    """
    shuffle(walks)
    with open(path, "w") as text_file:
        text_file.write('\n'.join(walks))


# NOTE(review): ``get_sequence`` is not defined in this notebook; presumably it
# comes from the ``statistical_learning_utils`` star import — verify there.
# Structured sequences: training, validation and structured test sets.
_write_sequences("train.txt", [get_sequence() for _ in range(2000)])
_write_sequences("val.txt", [get_sequence() for _ in range(100)])
_write_sequences("structured_test.txt", [get_sequence() for _ in range(100)])
# Fully random sequences for the unstructured test set.
_write_sequences("unstructured_test.txt", [get_random_sequence() for _ in range(100)])

#### Train generative model

Train GPT-2 from scratch on the dataset created above.

In [None]:
import shutil

for trial in range(0, 3):
    # Start each trial from an empty output directory. This has the same effect
    # as the shell `rm -rf` + `mkdir`, but also runs outside IPython/Jupyter.
    out_dir = f'durrant_{trial}'
    shutil.rmtree(out_dir, ignore_errors=True)
    os.makedirs(out_dir)

    # Train a fresh model for this trial; the trial index is passed as the seed
    # so the three runs use different seeds. `eps=3` presumably sets the number
    # of epochs (the evaluation cell looks for epoch-1..3 checkpoints).
    gpt = GPT(vocab_size=10)
    model = gpt.train(segmented_sequence_list=[],
                      best_model_dir=out_dir,
                      train_file="train.txt",
                      test_file="val.txt",
                      eps=3,
                      seed=trial)

#### Test model

In [None]:
# Held-out test sets written by the data-generation cell.
structured_test_file, unstructured_test_file = (
    "structured_test.txt",
    "unstructured_test.txt",
)

In [None]:
perplexity = load("perplexity", module_type="metric")

# The test sets do not change between checkpoints, so read them once.
# NOTE(review): readlines() keeps the trailing '\n' on every example except the
# last one in each file (the files are written without a final newline);
# confirm whether that newline is meant to be scored.
with open(structured_test_file, 'r') as f:
    structured_test_examples = f.readlines()
with open(unstructured_test_file, 'r') as f:
    unstructured_test_examples = f.readlines()


def _checkpoint_dir(trial, ep):
    """Return the checkpoint directory saved for *trial* at the end of epoch *ep*.

    Raises FileNotFoundError with the searched pattern if no checkpoint exists.
    """
    pattern = os.path.join(f'./durrant_{trial}', f'*-epoch-{ep}')
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No checkpoint matching {pattern!r}")
    return matches[0]


def _mean_perplexity(model_dir, examples):
    """Return the 'mean_perplexity' the evaluate metric reports for *examples*."""
    return perplexity.compute(model_id=model_dir,
                              add_start_token=False,
                              predictions=examples)['mean_perplexity']


# One list per trial, each holding one value per epoch.
all_structured = []
all_unstructured = []

for trial in range(3):
    perplexity_structured = []
    perplexity_unstructured = []

    for ep in [1, 2, 3]:
        model_dir = _checkpoint_dir(trial, ep)
        perplexity_structured.append(
            _mean_perplexity(model_dir, structured_test_examples))
        perplexity_unstructured.append(
            _mean_perplexity(model_dir, unstructured_test_examples))

    all_unstructured.append(perplexity_unstructured)
    all_structured.append(perplexity_structured)

#### Plot perplexity against time

In [None]:
# Convert lists to numpy arrays for easier computation
structured_array = np.array(all_structured)
unstructured_array = np.array(all_unstructured)

# Calculate the average perplexity for each epoch across the three trials
avg_structured = np.mean(structured_array, axis=0)
avg_unstructured = np.mean(unstructured_array, axis=0)

# Calculate the SEM for each epoch across the three trials
sem_structured = np.std(structured_array, axis=1) / np.sqrt(structured_array.shape[0])
sem_unstructured = np.std(unstructured_array, axis=1) / np.sqrt(unstructured_array.shape[0])

# Print the averages and SEM
print("Average structured perplexity:", avg_structured)
print("SEM structured perplexity:", sem_structured)
print("Average unstructured perplexity:", avg_unstructured)
print("SEM unstructured perplexity:", sem_unstructured)

# Define the number of epochs
epochs = [1, 2, 3]

# Create the bar chart
fig, ax = plt.subplots(figsize=(3.9, 2.3))

# Bar width
bar_width = 0.35

# Set positions of the bars on the x-axis
r1 = np.arange(len(epochs))
r2 = [x + bar_width for x in r1]

# Create bars for structured perplexity with error bars
ax.bar(r1, avg_structured, color='b', alpha=0.4, width=bar_width, yerr=sem_structured, capsize=2, edgecolor='grey', label='Structured')

# Create bars for unstructured perplexity with error bars
ax.bar(r2, avg_unstructured, color='r', alpha=0.4, width=bar_width, yerr=sem_unstructured, capsize=2, edgecolor='grey', label='Unstructured')

# Add labels
ax.set_xlabel('Epoch')
ax.set_ylabel('Perplexity')
ax.set_xticks([r + bar_width / 2 for r in range(len(epochs))])
ax.set_xticklabels(epochs)

# Add legend
ax.legend()

# Show the plot
plt.savefig('perplexities.png', dpi=300, bbox_inches='tight')
plt.show()


#### Plot transition structure of generated data

In [None]:
gpt = GPT(base_model='durrant_0', base_model_name='gpt2')

# Draw 50 low-temperature samples for each possible starting tone (1-5).
outputs = []
for _ in range(50):
    for start_tone in range(1, 6):
        outputs.append(gpt.continue_input(str(start_tone), do_sample=True, temperature=0.1))

# Join with ',' rather than plain concatenation so that the last tone of one
# sample can never fuse with the first tone of the next (e.g. "...,3" + "1,..."
# would otherwise read as 31). If a sample already ends with ',', this yields an
# empty field, which should be skipped when the string is parsed.
data = ','.join(outputs)

In [None]:
# Parse the generated tones, skipping empty fields between separators.
data_list = [int(x) for x in data.split(',') if x]

tones = range(1, 6)

# Count how often each next tone follows each ordered pair of preceding tones.
# A tone outside 1-5 raises KeyError here, since only 1-5 keys exist.
transition_counts = {((i, j), k): 0 for i in tones for j in tones for k in tones}
for first, second, nxt in zip(data_list, data_list[1:], data_list[2:]):
    transition_counts[((first, second), nxt)] += 1

# Normalise counts per preceding pair; pairs never observed get probability 0.
transition_probabilities = {}
for (prev_pair, nxt), count in transition_counts.items():
    total_transitions = sum(transition_counts[(prev_pair, k)] for k in tones)
    transition_probabilities[(prev_pair, nxt)] = count / total_transitions if total_transitions > 0 else 0

# Row = previous pair (25 rows, ordered 1,1 ... 5,5); column = next tone.
plot_data = np.zeros((25, 5))
for ((first, second), nxt), prob in transition_probabilities.items():
    plot_data[(first - 1) * 5 + (second - 1), nxt - 1] = prob

# Plot the probabilities themselves with the reversed grey map. The image is
# the same as plotting 1 - p with 'Greys' (high probability = light), but the
# colorbar now shows transition probability, matching the title.
fig, ax = plt.subplots(figsize=(5, 5))
cax = ax.matshow(plot_data, cmap='Greys_r')

# Set ticks
ax.set_xticks(range(5))
ax.set_xticklabels(range(1, 6))
ax.set_yticks(range(25))
ax.set_yticklabels([f'{i//5+1},{i%5+1}' for i in range(25)])

ax.set_xlabel('Next Number')
ax.set_ylabel('Previous Pair')
ax.set_title('Transition Probabilities')

plt.colorbar(cax)
plt.savefig('trps.png', dpi=500)
plt.show()

#### Plot loss

In [None]:
file_paths = ['training_progress_scores.csv']

# Load every progress log. `df` is the last one; its global_step column
# provides the x-axis values below.
frames = [pd.read_csv(path) for path in file_paths]
df = frames[-1]

eval_losses = [frame['eval_loss'] for frame in frames]
train_losses = [frame['train_loss'] for frame in frames]

# Average each loss column-wise across all logs.
mean_eval_loss = pd.concat(eval_losses, axis=1).mean(axis=1)
mean_train_loss = pd.concat(train_losses, axis=1).mean(axis=1)

# Convert optimiser steps to fractional epochs: 4902 steps cover 3 epochs.
total_train_steps = 4902
num_train_epochs = 3
epochs = [step * (num_train_epochs / total_train_steps) for step in df['global_step'].tolist()]

fig, ax = plt.subplots(figsize=(3.9, 2.3))
ax.plot(epochs, mean_eval_loss, label='Val loss', color='red', alpha=0.5, marker='.', markersize=8)
ax.plot(epochs, mean_train_loss, label='Train loss', color='blue', alpha=0.5, marker='.', markersize=8)
ax.set_xlabel('Epoch')
ax.set_xticks([0, 1, 2, 3])
ax.set_xlim(0, 3.1)
ax.set_ylabel('Loss')
ax.legend()
plt.savefig('loss.png', dpi=300, bbox_inches='tight')
plt.show()
